Move text projection into the CLIP model code.

Fix an issue where the SSD1B CLIP model was not loaded correctly.
This commit is contained in:
comfyanonymous
2024-02-25 01:41:08 -05:00
parent 6533b172c1
commit 1cb3f6a83b
5 changed files with 33 additions and 15 deletions

View File

@@ -119,6 +119,9 @@ class CLIPTextModel(torch.nn.Module):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
@@ -128,7 +131,10 @@ class CLIPTextModel(torch.nn.Module):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
    """Run the wrapped text model and project its third output.

    All positional and keyword arguments are forwarded unchanged to
    ``self.text_model``.

    Returns:
        A 3-tuple ``(x[0], x[1], projected)`` where ``x`` is the result of
        ``self.text_model`` and ``projected`` is ``self.text_projection``
        applied to ``x[2]``.
    """
    x = self.text_model(*args, **kwargs)
    # The projection is applied here, inside the model, so callers receive
    # the already-projected third output.
    # NOTE(review): x[2] is presumably the pooled embedding -- confirm
    # against CLIPTextModel_.forward.
    out = self.text_projection(x[2])
    return (x[0], x[1], out)
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):