Add an option --fp16-unet to force using fp16 for the unet.
@@ -466,6 +466,8 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0):
     if args.bf16_unet:
         return torch.bfloat16
+    if args.fp16_unet:
+        return torch.float16
     if args.fp8_e4m3fn_unet:
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
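For reference, a minimal sketch of how such a flag could be defined and consumed, assuming an argparse-based CLI. Only --fp16-unet and the args.* attribute names used above come from this commit; the other flag spellings, the mutually exclusive grouping, and the float32 fallback are assumptions for illustration.

import argparse

import torch

# Hypothetical stand-in for the project's real CLI setup.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()  # only one unet dtype override at a time
group.add_argument("--bf16-unet", action="store_true", help="Run the unet in bf16.")
group.add_argument("--fp16-unet", action="store_true", help="Force fp16 for the unet.")
group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
args = parser.parse_args()

def unet_dtype(device=None, model_params=0):
    # Explicit command line overrides take priority; argparse maps
    # --fp16-unet to args.fp16_unet, matching the checks in the diff.
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
        return torch.float16
    if args.fp8_e4m3fn_unet:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
    # Placeholder fallback; the real heuristic depends on device and model size.
    return torch.float32

if __name__ == "__main__":
    print(unet_dtype())

Running the sketch with --fp16-unet prints torch.float16, which is how the new option forces fp16 for the unet regardless of what the automatic dtype selection would otherwise pick.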