Fix torch compile regression on fp8 ops. (#10580)

This commit is contained in:
comfyanonymous
2025-10-31 21:25:17 -07:00
committed by GitHub
parent 7f374e42c8
commit c58c13b2ba
4 changed files with 43 additions and 36 deletions

View File

@@ -401,15 +401,9 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = self.scale_weight
@@ -422,24 +416,20 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
uncast_bias_weight(self, w, bias, offload_stream)
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return o
return None
@@ -540,12 +530,12 @@ if CUBLAS_IS_AVAILABLE:
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
from .quant_ops import QuantizedTensor
QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": {
"dtype": torch.float8_e4m3fn,
"layout_type": TensorCoreFP8Layout,
"layout_type": "TensorCoreFP8Layout",
"parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),

View File

@@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor):
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata.contiguous()
@@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor):
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return self._layout_type.dequantize(self._qdata, **self._layout_params)
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
return qtensor._qdata, qtensor._layout_params['scale']
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
@@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
@@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)