Some fixes/cleanups to pixart code.
Commented out the masking related code because it is never used in this implementation.
This commit is contained in:
@@ -127,12 +127,8 @@ class PixArt(nn.Module):
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
x = x.to(self.dtype)
|
||||
timestep = t.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
pos_embed = self.pos_embed.to(self.dtype)
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
|
||||
t = self.t_embedder(timestep) # (N, D)
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, 1, L, D)
|
||||
if mask is not None:
|
||||
@@ -142,7 +138,7 @@ class PixArt(nn.Module):
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y_lens = None
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens) # (N, T, D)
|
||||
@@ -164,13 +160,12 @@ class PixArt(nn.Module):
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_raw(
|
||||
x = x.to(self.dtype),
|
||||
t = timesteps.to(self.dtype),
|
||||
y = context.to(self.dtype),
|
||||
x = x,
|
||||
t = timesteps,
|
||||
y = context,
|
||||
)
|
||||
|
||||
## only return EPS
|
||||
out = out.to(torch.float)
|
||||
eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
|
||||
return eps
|
||||
|
||||
|
||||
Reference in New Issue
Block a user