Lower cosmos diffusion model memory usage.

This commit is contained in:
comfyanonymous
2025-01-15 23:46:42 -05:00
parent 4758fb64b9
commit 25683b5b02
3 changed files with 26 additions and 17 deletions

View File

@@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
len_w: int,
len_t: int,
device=None,
dtype=None,
**kwargs,
):
"""
@@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor: