minor changes for tiled sampler
This commit is contained in:
@@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||
"""
|
||||
B, N, _ = metric.shape
|
||||
|
||||
if r <= 0:
|
||||
if r <= 0 or w == 1 or h == 1:
|
||||
return do_nothing, do_nothing
|
||||
|
||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||
|
||||
Reference in New Issue
Block a user