Fix torch compile regression on fp8 ops. (#10580)

comfyanonymous authored 2025-10-31 21:25:17 -07:00
committed by GitHub
parent 7f374e42c8
commit c58c13b2ba
4 changed files with 43 additions and 36 deletions


@@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor):
             layout_type: Layout class (subclass of QuantizedLayout)
             layout_params: Dict with layout-specific parameters
         """
-        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
+        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
 
     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata.contiguous()
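This hunk is the core of the fix: _make_subclass aliases the fp8 payload's storage directly, while _make_wrapper_subclass creates a storage-less wrapper that only advertises shape/dtype/device and routes every op through __torch_dispatch__ — the subclass pattern torch.compile's tracing supports (a presumed rationale; the commit message only says it fixes the compile regression). A minimal, self-contained sketch of that pattern, with hypothetical names rather than the ComfyUI implementation:

import torch

class WrappedTensor(torch.Tensor):
    """Hypothetical toy wrapper subclass illustrating _make_wrapper_subclass."""
    @staticmethod
    def __new__(cls, payload):
        # No storage of its own: only shape/dtype/device metadata is exposed,
        # mirroring the QuantizedTensor.__new__ change above.
        return torch.Tensor._make_wrapper_subclass(
            cls, payload.shape, device=payload.device,
            dtype=payload.dtype, requires_grad=False)

    def __init__(self, payload):
        self._payload = payload

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every aten op lands here; unwrap the payload and run the real op.
        unwrapped = [a._payload if isinstance(a, WrappedTensor) else a
                     for a in args]
        return func(*unwrapped, **(kwargs or {}))

print(WrappedTensor(torch.ones(2)) + 1)  # tensor([2., 2.])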
@@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor):
     @classmethod
     def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
-        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
         return cls(qdata, layout_type, layout_params)
 
     def dequantize(self) -> torch.Tensor:
-        return self._layout_type.dequantize(self._qdata, **self._layout_params)
+        return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
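Both call sites now resolve the layout through the LAYOUTS table (introduced in the next hunk), so the layout stored on the tensor is a plain string rather than a class object. A plausible reading, given the commit title, is that string metadata is easier for torch.compile to guard on than a live class reference. A toy sketch of the lookup shape — the layout math here is illustrative, not ComfyUI's:

import torch

class ToyFP8Layout:
    # Hypothetical stand-in for TensorCoreFP8Layout's quantize/dequantize.
    @staticmethod
    def quantize(tensor):
        scale = tensor.abs().amax().clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
        return (tensor / scale).to(torch.float8_e4m3fn), {"scale": scale, "orig_dtype": tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype):
        return qdata.to(orig_dtype) * scale

LAYOUTS = {"ToyFP8Layout": ToyFP8Layout}

layout_type = "ToyFP8Layout"  # a string travels with the tensor, not a class
qdata, params = LAYOUTS[layout_type].quantize(torch.randn(4, 4))
restored = LAYOUTS[layout_type].dequantize(qdata, **params)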
@@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
         return qtensor._qdata, qtensor._layout_params['scale']
 
-@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+LAYOUTS = {
+    "TensorCoreFP8Layout": TensorCoreFP8Layout,
+}
+
+@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
 def fp8_linear(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
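register_layout_op itself is defined outside this diff; the hunk only changes its second argument from the layout class to the layout's registry key. A hypothetical reconstruction of what such a decorator-based dispatch table could look like, keyed on the (aten op, layout name) pair:

# Hypothetical sketch; the real decorator is not shown in this commit.
_LAYOUT_OPS = {}

def register_layout_op(op, layout_name):
    def decorator(handler):
        # One handler per (aten op, layout-name) pair; stacking the decorator
        # registers the same handler for several ops, as fp8_func does below.
        _LAYOUT_OPS[(op, layout_name)] = handler
        return handler
    return decorator

Under this reading, QuantizedTensor.__torch_dispatch__ would look up _LAYOUT_OPS.get((func, self._layout_type)) and fall back to dequantizing when no handler is registered — an assumption consistent with the fallback visible at the end of fp8_linear.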
@@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
             'scale': output_scale,
             'orig_dtype': input_tensor._layout_params['orig_dtype']
         }
-        return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
+        return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
     else:
         return output
@@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
         input_tensor = input_tensor.dequantize()
     return torch.nn.functional.linear(input_tensor, weight, bias)
+
+
+@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
+@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
+def fp8_func(func, args, kwargs):
+    input_tensor = args[0]
+    if isinstance(input_tensor, QuantizedTensor):
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        ar = list(args)
+        ar[0] = plain_input
+        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
+    return func(*args, **kwargs)
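The new fp8_func handler is the other half of the fix: under torch.compile, linear can be decomposed into aten.t/aten.view plus a matmul, and without handlers for those ops the subclass would have to dequantize mid-graph. The handler runs the op on the raw fp8 payload and re-wraps the result with the same layout params, which is sound here because view and t do not change the per-tensor scale. A usage sketch, assuming QuantizedTensor and the handlers above are importable (illustrative, not part of the commit):

# Illustrative only; requires the patched module in scope.
qt = QuantizedTensor.from_float(torch.randn(8, 16), "TensorCoreFP8Layout")
qt_t = qt.t()  # aten.t.default -> fp8_func, result stays quantized
assert isinstance(qt_t, QuantizedTensor) and qt_t.shape == (16, 8)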