ltxv: add noise to guidance image to ensure generated motion. (#5937)

This commit is contained in:
Michael Kupchick
2024-12-06 12:46:08 +02:00
committed by GitHub
parent 1e21f4c14e
commit 005d2d3a13
2 changed files with 18 additions and 1 deletions

View File

@@ -379,6 +379,7 @@ class LTXVModel(torch.nn.Module):
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
@@ -417,6 +418,7 @@ class LTXVModel(torch.nn.Module):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
image_noise_scale = transformer_options.get("image_noise_scale", 0.15)
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
@@ -435,6 +437,17 @@ class LTXVModel(torch.nn.Module):
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if image_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
guiding_noise = image_noise_scale * (input_ts ** 2) * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] += guiding_noise[:, :, 0]
orig_shape = list(x.shape)