minor changes for tiled sampler

This commit is contained in:
BlenderNeko
2023-05-12 23:49:09 +02:00
parent 8ea165dd1e
commit d9e088ddfd
2 changed files with 10 additions and 7 deletions

View File

@@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
"""
B, N, _ = metric.shape
if r <= 0:
if r <= 0 or w == 1 or h == 1:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather