Support for qwen edit plus model. Use the new TextEncodeQwenImageEditPlus. (#9986)

This commit is contained in:
comfyanonymous
2025-09-22 13:49:48 -07:00
committed by GitHub
parent 27bc181c49
commit 1fee8827cb
2 changed files with 65 additions and 6 deletions

View File

@@ -400,21 +400,25 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
grid = None
position_ids = None
offset = 0
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
position_ids[0, start:end] = start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None:
position_ids = None