Fix dtype issue in embeddings connector. (#12570)

This commit is contained in:
comfyanonymous
2026-02-22 00:18:20 -08:00
committed by GitHub
parent f266b8d352
commit 07ca6852e8
+3 -3
View File
@@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
return indices return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"): def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
dim = self.inner_dim dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing) freqs = self.precompute_freqs(indices_grid, spacing)
@@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
) )
else: else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
def forward( def forward(
self, self,
@@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
) )
indices_grid = indices_grid[None, None, :] indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid) freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
# 2. Blocks # 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks): for block_idx, block in enumerate(self.transformer_1d_blocks):