Add support for Chroma Radiance (#9682)

* Initial Chroma Radiance support

* Minor Chroma Radiance cleanups

* Update Radiance nodes to ensure latents/images are on the intermediate device

* Fix Chroma Radiance memory estimation.

* Increase Chroma Radiance memory usage factor

* Increase Chroma Radiance memory usage factor once again

* Ensure images are multiples of 16 for Chroma Radiance
Add batch dimension and fix channels when necessary in ChromaRadianceImageToLatent node

* Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor
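
  A rough sketch of the tiling idea (not the actual Radiance code): run the per-pixel
  NeRF head over fixed-size spatial tiles so only one tile's activations are resident
  at once. The `head` callable, the BCHW layout, and the shape-preserving output are
  assumptions for illustration.

    import torch

    def apply_head_tiled(head, x: torch.Tensor, tile_size: int = 32) -> torch.Tensor:
        # Assumes head() maps a BCHW tile to a tile of the same shape.
        out = torch.empty_like(x)
        for y in range(0, x.shape[-2], tile_size):
            for x0 in range(0, x.shape[-1], tile_size):
                out[..., y:y + tile_size, x0:x0 + tile_size] = head(
                    x[..., y:y + tile_size, x0:x0 + tile_size])
        return out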

* Update Radiance to support the conv NeRF final head type.

* Allow setting NeRF embedder dtype for Radiance
Bump Radiance NeRF tile size to 32
Support EasyCache/LazyCache on Radiance (maybe)

* Add ChromaRadianceStubVAE node

* Crop Radiance image inputs to multiples of 16 instead of erroring, in line with existing VAE behavior (e.g. a 1000x700 input is center-cropped to 992x688)

* Convert Chroma Radiance nodes to V3 schema.

* Add ChromaRadianceOptions node and backend support.
Cleanups/refactoring to reduce code duplication with Chroma.

* Fix overriding the NeRF embedder dtype for Chroma Radiance

* Minor Chroma Radiance cleanups

* Move Chroma Radiance to its own directory in ldm
Minor code cleanups and tooltip improvements

* Fix Chroma Radiance embedder dtype overriding

* Remove Radiance dynamic nerf_embedder dtype override feature

* Unbork Radiance NeRF embedder init

* Remove Chroma Radiance image conversion and stub VAE nodes
Add a chroma_radiance option to the VAELoader builtin node which uses comfy.sd.PixelspaceConversionVAE
Add a PixelspaceConversionVAE to comfy.sd for converting BHWC 0..1 <-> BCHW -1..1
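
  The conversion itself is just a channel-layout swap plus an affine rescale; a minimal
  sketch of the mapping described above (the full class is in the diff below):

    import torch

    def bhwc01_to_bchw_pm1(img: torch.Tensor) -> torch.Tensor:
        # IMAGE (B, H, W, C) in 0..1 -> LATENT (B, C, H, W) in -1..1
        return img.clamp(0, 1).movedim(-1, 1) * 2.0 - 1.0

    def bchw_pm1_to_bhwc01(lat: torch.Tensor) -> torch.Tensor:
        # LATENT (B, C, H, W) in -1..1 -> IMAGE (B, H, W, C) in 0..1
        return (lat.clamp(-1, 1).movedim(1, -1) + 1.0) * 0.5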
blepping
2025-09-13 15:58:43 -06:00
committed by GitHub
parent e5e70636e7
commit c1297f4eb3
10 changed files with 770 additions and 9 deletions


@@ -785,6 +785,66 @@ class VAE:
        except:
            return None


# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
# to LATENT B, C, H, W and values on the scale of -1..1.
class PixelspaceConversionVAE:
    def __init__(self, size_increment: int = 16):
        self.intermediate_device = comfy.model_management.intermediate_device()
        self.size_increment = size_increment

    def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
        if self.size_increment == 1:
            return pixels
        # Center-crop each spatial dimension down to the nearest multiple
        # of size_increment.
        dims = pixels.shape[1:-1]
        for d in range(len(dims)):
            d_adj = (dims[d] // self.size_increment) * self.size_increment
            if d_adj == dims[d]:
                continue
            d_offset = (dims[d] % self.size_increment) // 2
            pixels = pixels.narrow(d + 1, d_offset, d_adj)
        return pixels

    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
        if pixels.ndim == 3:
            pixels = pixels.unsqueeze(0)
        elif pixels.ndim != 4:
            raise ValueError("Unexpected input image shape")
        # Ensure the image has spatial dimensions that are multiples of 16.
        pixels = self.vae_encode_crop_pixels(pixels)
        h, w, c = pixels.shape[1:]
        if h < self.size_increment or w < self.size_increment:
            raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
        # Drop any extra channels (e.g. alpha) and expand grayscale to RGB.
        pixels = pixels[..., :3]
        if c == 1:
            pixels = pixels.expand(-1, -1, -1, 3)
        elif c != 3:
            raise ValueError("Unexpected number of channels in input image")
        # Rescale to -1..1 and move the channel dimension to position 1.
        latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
        latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
        latent -= 0.5
        latent *= 2
        return latent.clamp_(-1, 1)

    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
        # Rescale to 0..1 and move the channel dimension to the end.
        img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
        img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
        img += 1.0
        img *= 0.5
        return img.clamp_(0, 1)

    # Tiling is a no-op for a pixel-space "VAE".
    encode_tiled = encode
    decode_tiled = decode

    @classmethod
    def spacial_compression_decode(cls) -> int:
        # This just exists so the tiled VAE nodes don't crash.
        return 1

    spacial_compression_encode = spacial_compression_decode
    temporal_compression_decode = spacial_compression_decode
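
For reference, a minimal sketch of exercising this class directly (in a real workflow
the VAELoader "chroma_radiance" option constructs it; the shapes below are just
example values):

    import torch
    import comfy.sd

    vae = comfy.sd.PixelspaceConversionVAE(size_increment=16)

    image = torch.rand(1, 512, 512, 3)   # IMAGE: BHWC, values in 0..1
    latent = vae.encode(image)           # LATENT: BCHW, values in -1..1
    assert latent.shape == (1, 3, 512, 512)

    restored = vae.decode(latent)        # back to BHWC, 0..1
    assert torch.allclose(restored, image, atol=1e-6)
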
class StyleModel:
    def __init__(self, model, device="cpu"):
        self.model = model