Fix torch compile regression on fp8 ops. (#10580)
This commit is contained in:
@@ -14,7 +14,7 @@ if not has_gpu():
|
||||
args.cpu = True
|
||||
|
||||
from comfy import ops
|
||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
@@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
# Verify weights are wrapped in QuantizedTensor
|
||||
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
|
||||
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Layer 2 should NOT be quantized
|
||||
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
||||
|
||||
# Layer 3 should be quantized
|
||||
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
|
||||
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Verify scales were loaded
|
||||
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
||||
@@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Verify non-quantized layers are standard tensors
|
||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||
|
||||
@@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
|
||||
scale = torch.tensor(2.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
||||
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.shape, (256, 128))
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt._layout_params['scale'], scale)
|
||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
|
||||
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
def test_dequantize(self):
|
||||
"""Test explicit dequantization"""
|
||||
@@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
|
||||
scale = torch.tensor(3.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
dequantized = qt.dequantize()
|
||||
|
||||
self.assertEqual(dequantized.dtype, torch.float32)
|
||||
@@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):
|
||||
|
||||
qt = QuantizedTensor.from_float(
|
||||
float_tensor,
|
||||
TensorCoreFP8Layout,
|
||||
"TensorCoreFP8Layout",
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
@@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Detach should return a new QuantizedTensor
|
||||
qt_detached = qt.detach()
|
||||
|
||||
self.assertIsInstance(qt_detached, QuantizedTensor)
|
||||
self.assertEqual(qt_detached.shape, qt.shape)
|
||||
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
|
||||
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
def test_clone(self):
|
||||
"""Test clone operation on quantized tensor"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Clone should return a new QuantizedTensor
|
||||
qt_cloned = qt.clone()
|
||||
|
||||
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
||||
self.assertEqual(qt_cloned.shape, qt.shape)
|
||||
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
|
||||
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Verify it's a deep copy
|
||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
||||
@@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Moving to same device should work (CPU to CPU)
|
||||
qt_cpu = qt.to('cpu')
|
||||
@@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
|
||||
scale = torch.tensor(1.0)
|
||||
a_q = QuantizedTensor.from_float(
|
||||
a_fp32,
|
||||
TensorCoreFP8Layout,
|
||||
"TensorCoreFP8Layout",
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user