Fix torch compile regression on fp8 ops. (#10580)

This commit is contained in:
comfyanonymous
2025-10-31 21:25:17 -07:00
committed by GitHub
parent 7f374e42c8
commit c58c13b2ba
4 changed files with 43 additions and 36 deletions

View File

@@ -14,7 +14,7 @@ if not has_gpu():
args.cpu = True
from comfy import ops
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
from comfy.quant_ops import QuantizedTensor
class SimpleModel(torch.nn.Module):
@@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
# Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
# Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
# Verify scales were loaded
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
@@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)