Lower ltxv mem usage to what it was before the previous PR. (#10643)

Bring back Qwen behavior to what it was before the previous PR.
Author: comfyanonymous (committed by GitHub)
Date: 2025-11-04 19:47:35 -08:00
Commit: c4a6b389de
Parent: 4cd881866b
2 changed files with 12 additions and 12 deletions


@@ -415,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module):
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
+        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
         del ids, txt_ids, img_ids
         hidden_states = self.img_in(hidden_states)
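For context, here is a minimal sketch (not part of the commit) of why casting the rotary embedding to x.dtype instead of torch.float32 lowers memory: a half-precision buffer is half the size of a float32 one. The tensor shape below is purely illustrative; the real shape is produced by self.pe_embedder(ids) and depends on image resolution and prompt length.

import torch

# Hypothetical positional-embedding buffer; shape chosen for illustration only.
tokens, dim = 4608, 128

def footprint_mb(t: torch.Tensor) -> float:
    # Bytes actually held by the tensor's elements, in megabytes.
    return t.numel() * t.element_size() / 1e6

pe_fp32 = torch.zeros(1, tokens, dim, dtype=torch.float32)  # pre-commit behavior
pe_bf16 = pe_fp32.to(torch.bfloat16)                        # restored behavior (x.dtype)

print(f"float32: {footprint_mb(pe_fp32):.2f} MB")  # ~2.36 MB
print(f"bf16:    {footprint_mb(pe_bf16):.2f} MB")  # ~1.18 MB, half the float32 cost

The saving scales with token count, so at high resolutions (large img_ids grids) the float32 upcast is noticeably more expensive, which matches the commit's goal of restoring the earlier, lower memory usage.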