Support new flux model variants.

This commit is contained in:
comfyanonymous
2024-11-21 08:38:23 -05:00
parent 41444b5236
commit 8f0009aad0
9 changed files with 147 additions and 19 deletions

View File

@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
@@ -29,6 +30,7 @@ class FluxParams:
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
guidance_embed: bool
@@ -43,8 +45,9 @@ class Flux(nn.Module):
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
self.patch_size = params.patch_size
self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -165,7 +168,7 @@ class Flux(nn.Module):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

25
comfy/ldm/flux/redux.py Normal file
View File

@@ -0,0 +1,25 @@
import torch
import comfy.ops
ops = comfy.ops.manual_cast
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x