Fix loras not working on mixed fp8. (#10899)
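
Applying a LoRA to fp8 weights means the patched result has to be re-quantized back to fp8; with plain round-to-nearest, deltas smaller than the local quantization step can be rounded away entirely. The diff below threads an optional stochastic_rounding seed (plus an inplace_ops flag) through TensorCoreFP8Layout.quantize so that re-quantization can go through comfy.float.stochastic_rounding instead. As a rough, self-contained illustration of the idea (not the comfy.float implementation, and rounding to an integer grid rather than to fp8), stochastic rounding picks the lower or upper neighbour with probability proportional to proximity, so the result is unbiased in expectation:

    import torch

    def stochastic_round_to_int(x: torch.Tensor, seed: int = 0) -> torch.Tensor:
        # Round each element down or up with probability equal to its fractional
        # part, so E[result] == x and tiny deltas survive on average instead of
        # always rounding to the same neighbour. Illustration only.
        gen = torch.Generator(device=x.device).manual_seed(seed)
        lower = torch.floor(x)
        frac = x - lower
        noise = torch.rand(x.shape, generator=gen, device=x.device, dtype=x.dtype)
        return lower + (noise < frac).to(x.dtype)
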
@@ -1,6 +1,7 @@
 import torch
 import logging
 from typing import Tuple, Dict
+import comfy.float

 _LAYOUT_REGISTRY = {}
 _GENERIC_UTILS = {}
@@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
         - orig_dtype: Original dtype before quantization (for casting back)
     """
     @classmethod
-    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype

         if scale is None:
@@ -403,17 +404,23 @@
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)

-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
-        # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        lp_amax = torch.finfo(dtype).max
-        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
+        if inplace_ops:
+            tensor *= (1.0 / scale).to(tensor.dtype)
+        else:
+            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+
+        if stochastic_rounding > 0:
+            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
+        else:
+            lp_amax = torch.finfo(dtype).max
+            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
+            tensor = tensor.to(dtype, memory_format=torch.contiguous_format)

         layout_params = {
             'scale': scale,
             'orig_dtype': orig_dtype
         }
-        return qdata, layout_params
+        return tensor, layout_params

     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
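
A minimal sketch of how a caller might use the new parameters, assuming TensorCoreFP8Layout has been imported from the module patched above (the import path and the actual lora call site are not part of this diff). A non-zero stochastic_rounding value doubles as the RNG seed, and inplace_ops=True lets quantize scale the input tensor in place rather than allocating a copy:

    import torch
    # from <module patched above> import TensorCoreFP8Layout  # path not shown in the diff

    w = torch.randn(128, 256, dtype=torch.bfloat16)  # e.g. a base weight with the lora delta already added

    qdata, params = TensorCoreFP8Layout.quantize(
        w,
        dtype=torch.float8_e4m3fn,
        stochastic_rounding=1234,  # non-zero enables stochastic rounding; the value is used as the seed
        inplace_ops=True,          # scale `w` in place before casting to fp8
    )

    # params is {'scale': ..., 'orig_dtype': torch.bfloat16}, matching dequantize's signature:
    w_back = TensorCoreFP8Layout.dequantize(qdata, **params)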