Add default device argument. (#9023)
This commit is contained in:
@@ -880,6 +880,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
return d
|
||||
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
|
||||
Reference in New Issue
Block a user