Add slice_cond and per-model context window cond resizing (#12645)
* Add slice_cond and per-model context window cond resizing * Fix cond_value.size() call in context window cond resizing * Expose additional advanced inputs for ContextWindowsManualNode Necessary for WanAnimate context windows workflow, which needs cond_retain_index_list = 0 to work properly with its reference input. ---------
This commit is contained in:
@@ -93,6 +93,50 @@ class IndexListCallbacks:
|
||||
return {}
|
||||
|
||||
|
||||
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
|
||||
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
|
||||
return None
|
||||
cond_tensor = cond_value.cond
|
||||
if temporal_dim >= cond_tensor.ndim:
|
||||
return None
|
||||
|
||||
cond_size = cond_tensor.size(temporal_dim)
|
||||
|
||||
if temporal_scale == 1:
|
||||
expected_size = x_in.size(window.dim) - temporal_offset
|
||||
if cond_size != expected_size:
|
||||
return None
|
||||
|
||||
if temporal_offset == 0 and temporal_scale == 1:
|
||||
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
if temporal_scale > 1:
|
||||
scaled = []
|
||||
for i in indices:
|
||||
for k in range(temporal_scale):
|
||||
si = i * temporal_scale + k
|
||||
if si < cond_size:
|
||||
scaled.append(si)
|
||||
indices = scaled
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
idx = tuple([slice(None)] * temporal_dim + [indices])
|
||||
sliced = cond_tensor[idx].to(device)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if not handled and self._model is not None:
|
||||
result = self._model.resize_cond_for_context_window(
|
||||
cond_key, cond_value, window, x_in, device,
|
||||
retain_index_list=self.cond_retain_index_list)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
@@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
Reference in New Issue
Block a user