block info (#10843)

This commit is contained in:
Haoming
2025-11-25 02:30:40 +08:00
committed by GitHub
parent 1286fcfe40
commit 3d1fdaf9f4

View File

@@ -179,7 +179,10 @@ class Chroma(nn.Module):
pe = self.pe_embedder(ids) pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks): for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if i not in self.skip_mmdit: if i not in self.skip_mmdit:
double_mod = ( double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i), self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -222,7 +225,10 @@ class Chroma(nn.Module):
img = torch.cat((txt, img), 1) img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit: if i not in self.skip_dit:
single_mod = self.get_modulations(mod_vectors, "single", idx=i) single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace: