Use faster manual cast for fp8 in unet.
This commit is contained in:
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
from enum import Enum
|
||||
import contextlib
|
||||
from . import utils
|
||||
@@ -41,9 +42,14 @@ class BaseModel(torch.nn.Module):
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@@ -63,11 +69,8 @@ class BaseModel(torch.nn.Module):
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
|
||||
if comfy.model_management.supports_dtype(xc.device, dtype):
|
||||
precision_scope = lambda a: contextlib.nullcontext(a)
|
||||
else:
|
||||
precision_scope = torch.autocast
|
||||
dtype = torch.float32
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
xc = xc.to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
@@ -79,9 +82,7 @@ class BaseModel(torch.nn.Module):
|
||||
extra = extra.to(dtype)
|
||||
extra_conds[o] = extra
|
||||
|
||||
with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def get_dtype(self):
|
||||
|
||||
Reference in New Issue
Block a user