Mixed Precision Quantization System (#10498)
* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. * Updated design using Tensor Subclasses * Fix FP8 MM * An actually functional POC * Remove CK reference and ensure correct compute dtype * Update unit tests * ruff lint * Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. * Updated design using Tensor Subclasses * Fix FP8 MM * An actually functional POC * Remove CK reference and ensure correct compute dtype * Update unit tests * ruff lint * Fix missing keys * Rename quant dtype parameter * Rename quant dtype parameter * Fix unittests for CPU build
This commit is contained in:
@@ -6,6 +6,20 @@ import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
|
||||
def detect_layer_quantization(metadata):
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
layer_quant_config = detect_layer_quantization(metadata)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
|
||||
Reference in New Issue
Block a user