Mixed Precision Quantization System (#10498)

* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Fix missing keys

* Rename quant dtype parameter

* Rename quant dtype parameter

* Fix unittests for CPU build
This commit is contained in:
contentis
2025-10-28 21:20:53 +01:00
committed by GitHub
parent 22e40d2ace
commit 8817f8fc14
8 changed files with 1030 additions and 19 deletions

View File

@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION: