Some fixes/cleanups to pixart code.

Commented out the masking related code because it is never used in this
implementation.
This commit is contained in:
comfyanonymous
2024-12-20 17:10:52 -05:00
parent d7969cb070
commit e946667216
4 changed files with 42 additions and 42 deletions

View File

@@ -127,12 +127,8 @@ class PixArt(nn.Module):
t: (N,) tensor of diffusion timesteps
y: (N, 1, 120, C) tensor of class labels
"""
x = x.to(self.dtype)
timestep = t.to(self.dtype)
y = y.to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t = self.t_embedder(timestep) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
@@ -142,7 +138,7 @@ class PixArt(nn.Module):
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y_lens = None
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
x = block(x, y, t0, y_lens) # (N, T, D)
@@ -164,13 +160,12 @@ class PixArt(nn.Module):
## run original forward pass
out = self.forward_raw(
x = x.to(self.dtype),
t = timesteps.to(self.dtype),
y = context.to(self.dtype),
x = x,
t = timesteps,
y = context,
)
## only return EPS
out = out.to(torch.float)
eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
return eps