Use faster manual cast for fp8 in unet.

This commit is contained in:
comfyanonymous
2023-12-11 18:24:44 -05:00
parent ab93abd4b2
commit ba07cb748e
5 changed files with 48 additions and 12 deletions

View File

@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
import comfy.ops
from enum import Enum
import contextlib
from . import utils
@@ -41,9 +42,14 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@@ -63,11 +69,8 @@ class BaseModel(torch.nn.Module):
context = c_crossattn
dtype = self.get_dtype()
if comfy.model_management.supports_dtype(xc.device, dtype):
precision_scope = lambda a: contextlib.nullcontext(a)
else:
precision_scope = torch.autocast
dtype = torch.float32
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
@@ -79,9 +82,7 @@ class BaseModel(torch.nn.Module):
extra = extra.to(dtype)
extra_conds[o] = extra
with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def get_dtype(self):