Fix small performance regression with fp8 fast and scaled fp8. (#10537)

This commit is contained in:
comfyanonymous
2025-10-29 16:29:01 -07:00
committed by GitHub
parent 25de7b1bfa
commit 906c089957
2 changed files with 8 additions and 3 deletions

View File

@@ -357,9 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
lp_amax = torch.finfo(dtype).max
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {