Fix dtype issue in embeddings connector. (#12570)
This commit is contained in:
@@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
|
|
||||||
return indices
|
return indices
|
||||||
|
|
||||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||||
dim = self.inner_dim
|
dim = self.inner_dim
|
||||||
n_elem = 2 # 2 because of cos and sin
|
n_elem = 2 # 2 because of cos and sin
|
||||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||||
@@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
indices_grid = indices_grid[None, None, :]
|
indices_grid = indices_grid[None, None, :]
|
||||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||||
|
|
||||||
# 2. Blocks
|
# 2. Blocks
|
||||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||||
|
|||||||
Reference in New Issue
Block a user