Fix torch compile regression on fp8 ops. (#10580)
This commit is contained in:
24
comfy/ops.py
24
comfy/ops.py
@@ -401,15 +401,9 @@ def fp8_linear(self, input):
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
input_shape = input.shape
|
||||
input_dtype = input.dtype
|
||||
|
||||
if len(input.shape) == 3:
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
@@ -422,24 +416,20 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
|
||||
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input_shape[0], -1)
|
||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||
return o
|
||||
|
||||
return None
|
||||
|
||||
@@ -540,12 +530,12 @@ if CUBLAS_IS_AVAILABLE:
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||
from .quant_ops import QuantizedTensor
|
||||
|
||||
QUANT_FORMAT_MIXINS = {
|
||||
"float8_e4m3fn": {
|
||||
"dtype": torch.float8_e4m3fn,
|
||||
"layout_type": TensorCoreFP8Layout,
|
||||
"layout_type": "TensorCoreFP8Layout",
|
||||
"parameters": {
|
||||
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
|
||||
Reference in New Issue
Block a user