Changes to the previous radiance commit. (#9851)
comfy/sd.py (69 changed lines)
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import comfy.ldm.hunyuan_video.vae
+import comfy.pixel_space_convert
 import yaml
 import math
 import os
@@ -516,6 +517,15 @@ class VAE:
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             self.disable_offload = True
             self.extra_1d_channel = 16
+        elif "pixel_space_vae" in sd:
+            self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
+            self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.downscale_ratio = 1
+            self.upscale_ratio = 1
+            self.latent_channels = 3
+            self.latent_dim = 2
+            self.output_channels = 3
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initialized.")
             self.first_stage_model = None
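A note on the two memory estimators in this hunk: they are linear in pixel count and ignore batch size and channels, which is consistent with a conversion that holds no weights and allocates essentially one output tensor. A minimal standalone sketch of the same arithmetic, assuming the shape argument arrives in B, C, H, W layout and with a local dtype_size helper standing in for ComfyUI's model_management.dtype_size:

import torch

def dtype_size(dtype: torch.dtype) -> int:
    # Bytes per element; a local stand-in for model_management.dtype_size.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.bits // 8

# Same formula as the diff: one element per pixel, H * W * element size.
memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * dtype_size(dtype)

print(memory_used_encode((1, 3, 1024, 1024), torch.float32))  # 4194304 bytes, ~4 MiB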
@@ -785,65 +795,6 @@ class VAE:
         except:
             return None
 
-# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
-# to LATENT B, C, H, W and values on the scale of -1..1.
-class PixelspaceConversionVAE:
-    def __init__(self, size_increment: int=16):
-        self.intermediate_device = comfy.model_management.intermediate_device()
-        self.size_increment = size_increment
-
-    def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
-        if self.size_increment == 1:
-            return pixels
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            d_adj = (dims[d] // self.size_increment) * self.size_increment
-            if d_adj == dims[d]:
-                continue
-            d_offset = (dims[d] % self.size_increment) // 2
-            pixels = pixels.narrow(d + 1, d_offset, d_adj)
-        return pixels
-
-    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
-        if pixels.ndim == 3:
-            pixels = pixels.unsqueeze(0)
-        elif pixels.ndim != 4:
-            raise ValueError("Unexpected input image shape")
-        # Ensure the image has spatial dimensions that are multiples of 16.
-        pixels = self.vae_encode_crop_pixels(pixels)
-        h, w, c = pixels.shape[1:]
-        if h < self.size_increment or w < self.size_increment:
-            raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
-        pixels = pixels[..., :3]
-        if c == 1:
-            pixels = pixels.expand(-1, -1, -1, 3)
-        elif c != 3:
-            raise ValueError("Unexpected number of channels in input image")
-        # Rescale to -1..1 and move the channel dimension to position 1.
-        latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
-        latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
-        latent -= 0.5
-        latent *= 2
-        return latent.clamp_(-1, 1)
-
-    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
-        # Rescale to 0..1 and move the channel dimension to the end.
-        img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
-        img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
-        img += 1.0
-        img *= 0.5
-        return img.clamp_(0, 1)
-
-    encode_tiled = encode
-    decode_tiled = decode
-
-    @classmethod
-    def spacial_compression_decode(cls) -> int:
-        # This just exists so the tiled VAE nodes don't crash.
-        return 1
-
-    spacial_compression_encode = spacial_compression_decode
-    temporal_compression_decode = spacial_compression_decode
 
 class StyleModel:
     def __init__(self, model, device="cpu"):
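The vae_encode_crop_pixels helper deleted above (the class moves into comfy.pixel_space_convert, imported in the first hunk) trims each spatial dimension to the largest multiple of size_increment that fits, splitting the discarded remainder across both edges so the crop stays centered. A standalone sketch of the same arithmetic as a free function (hypothetical name, torch only):

import torch

def crop_to_multiple(pixels: torch.Tensor, size_increment: int = 16) -> torch.Tensor:
    # pixels: B, H, W, C. Trim H and W down to multiples of size_increment,
    # dropping (remainder // 2) pixels from the leading edge of each axis.
    dims = pixels.shape[1:-1]
    for d in range(len(dims)):
        d_adj = (dims[d] // size_increment) * size_increment
        if d_adj == dims[d]:
            continue  # already a multiple, nothing to trim
        d_offset = (dims[d] % size_increment) // 2
        pixels = pixels.narrow(d + 1, d_offset, d_adj)
    return pixels

img = torch.rand(1, 70, 45, 3)
print(crop_to_multiple(img).shape)  # torch.Size([1, 64, 32, 3]): 70 -> 64, 45 -> 32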
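Since the conversion is pure tensor plumbing, its contract is easy to check in isolation: IMAGE tensors in B, H, W, C with values in 0..1 map to LATENT tensors in B, C, H, W with values in -1..1, and back. A minimal round-trip sketch with the two rescalings re-declared as plain functions (illustrative names, not the module's API):

import torch

def encode(pixels: torch.Tensor) -> torch.Tensor:
    # IMAGE B, H, W, C in 0..1 -> LATENT B, C, H, W in -1..1.
    latent = pixels.clamp(0, 1).movedim(-1, 1).contiguous()
    return (latent * 2 - 1).clamp(-1, 1)

def decode(samples: torch.Tensor) -> torch.Tensor:
    # LATENT B, C, H, W in -1..1 -> IMAGE B, H, W, C in 0..1.
    img = samples.clamp(-1, 1).movedim(1, -1).contiguous()
    return ((img + 1) * 0.5).clamp(0, 1)

img = torch.rand(1, 32, 48, 3)                       # image batch in 0..1
lat = encode(img)
assert lat.shape == (1, 3, 32, 48)                   # channels moved to dim 1
assert lat.min() >= -1 and lat.max() <= 1
assert torch.allclose(decode(lat), img, atol=1e-6)   # round trip is (near) lossless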