Add support for attention masking in Flux (#5942)

* fix attention OOM in xformers

* allow passing attention mask in flux attention

* allow an attn_mask in flux

* attn masks can be done using replace patches instead of a separate dict

* fix return types

* fix return order

* enumerate

* patch the right keys

* arg names

* fix a silly bug

* fix xformers masks

* replace match with if, elif, else

* mask with image_ref_size

* remove unused import

* remove unused import 2

* fix pytorch/xformers attention

This corrects a weird inconsistency with skip_reshape.
It also allows masks of various shapes to be passed, which will be
automtically expanded (in a memory-efficient way) to a size that is
compatible with xformers or pytorch sdpa respectively.

* fix mask shapes
This commit is contained in:
Raphael Walker
2024-12-17 00:21:17 +01:00
committed by GitHub
parent 0f954f34af
commit 61b50720d0
7 changed files with 182 additions and 48 deletions

View File

@@ -142,7 +142,7 @@ class DoubleStreamBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -163,7 +163,8 @@ class DoubleStreamBlock(nn.Module):
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2), pe=pe)
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@@ -217,7 +218,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -226,7 +227,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output

View File

@@ -1,14 +1,15 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x
@@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -4,6 +4,8 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from .layers import (
DoubleStreamBlock,
@@ -14,9 +16,6 @@ from .layers import (
timestep_embedding,
)
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
@@ -98,8 +97,9 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -124,14 +124,27 @@ class Flux(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -146,13 +159,20 @@ class Flux(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -181,5 +201,5 @@ class Flux(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]