Add support for Chroma Radiance (#9682)
* Initial Chroma Radiance support * Minor Chroma Radiance cleanups * Update Radiance nodes to ensure latents/images are on the intermediate device * Fix Chroma Radiance memory estimation. * Increase Chroma Radiance memory usage factor * Increase Chroma Radiance memory usage factor once again * Ensure images are multiples of 16 for Chroma Radiance Add batch dimension and fix channels when necessary in ChromaRadianceImageToLatent node * Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor * Update Radiance to support conv nerf final head type. * Allow setting NeRF embedder dtype for Radiance Bump Radiance nerf tile size to 32 Support EasyCache/LazyCache on Radiance (maybe) * Add ChromaRadianceStubVAE node * Crop Radiance image inputs to multiples of 16 instead of erroring to be in line with existing VAE behavior * Convert Chroma Radiance nodes to V3 schema. * Add ChromaRadianceOptions node and backend support. Cleanups/refactoring to reduce code duplication with Chroma. * Fix overriding the NeRF embedder dtype for Chroma Radiance * Minor Chroma Radiance cleanups * Move Chroma Radiance to its own directory in ldm Minor code cleanups and tooltip improvements * Fix Chroma Radiance embedder dtype overriding * Remove Radiance dynamic nerf_embedder dtype override feature * Unbork Radiance NeRF embedder init * Remove Chroma Radiance image conversion and stub VAE nodes Add a chroma_radiance option to the VAELoader builtin node which uses comfy.sd.PixelspaceConversionVAE Add a PixelspaceConversionVAE to comfy.sd for converting BHWC 0..1 <-> BCHW -1..1
This commit is contained in:
60
comfy/sd.py
60
comfy/sd.py
@@ -785,6 +785,66 @@ class VAE:
|
||||
except:
|
||||
return None
|
||||
|
||||
# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
|
||||
# to LATENT B, C, H, W and values on the scale of -1..1.
|
||||
class PixelspaceConversionVAE:
|
||||
def __init__(self, size_increment: int=16):
|
||||
self.intermediate_device = comfy.model_management.intermediate_device()
|
||||
self.size_increment = size_increment
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
|
||||
if self.size_increment == 1:
|
||||
return pixels
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
d_adj = (dims[d] // self.size_increment) * self.size_increment
|
||||
if d_adj == d:
|
||||
continue
|
||||
d_offset = (dims[d] % self.size_increment) // 2
|
||||
pixels = pixels.narrow(d + 1, d_offset, d_adj)
|
||||
return pixels
|
||||
|
||||
def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
|
||||
if pixels.ndim == 3:
|
||||
pixels = pixels.unsqueeze(0)
|
||||
elif pixels.ndim != 4:
|
||||
raise ValueError("Unexpected input image shape")
|
||||
# Ensure the image has spatial dimensions that are multiples of 16.
|
||||
pixels = self.vae_encode_crop_pixels(pixels)
|
||||
h, w, c = pixels.shape[1:]
|
||||
if h < self.size_increment or w < self.size_increment:
|
||||
raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
|
||||
pixels= pixels[..., :3]
|
||||
if c == 1:
|
||||
pixels = pixels.expand(-1, -1, -1, 3)
|
||||
elif c != 3:
|
||||
raise ValueError("Unexpected number of channels in input image")
|
||||
# Rescale to -1..1 and move the channel dimension to position 1.
|
||||
latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
|
||||
latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
|
||||
latent -= 0.5
|
||||
latent *= 2
|
||||
return latent.clamp_(-1, 1)
|
||||
|
||||
def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
|
||||
# Rescale to 0..1 and move the channel dimension to the end.
|
||||
img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
|
||||
img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
|
||||
img += 1.0
|
||||
img *= 0.5
|
||||
return img.clamp_(0, 1)
|
||||
|
||||
encode_tiled = encode
|
||||
decode_tiled = decode
|
||||
|
||||
@classmethod
|
||||
def spacial_compression_decode(cls) -> int:
|
||||
# This just exists so the tiled VAE nodes don't crash.
|
||||
return 1
|
||||
|
||||
spacial_compression_encode = spacial_compression_decode
|
||||
temporal_compression_decode = spacial_compression_decode
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
|
||||
Reference in New Issue
Block a user