Fix LoRA Trainer bugs with FP8 models. (#9854)

* Fix adapter weight init

* Fix fp8 model training

* Avoid inference tensor
This commit is contained in:
Kohaku-Blueleaf
2025-09-21 09:24:48 +08:00
committed by GitHub
parent 9ed3c5cc09
commit 7be2b49b6b
6 changed files with 34 additions and 15 deletions

View File

@@ -365,12 +365,13 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
    """Linear forward that prefers the fused fp8 path, inference only.

    When the module is NOT training, first attempt the optimized
    ``fp8_linear`` kernel; a ``None`` result or any exception (logged at
    info level) falls through to the generic path. During training
    (e.g. LoRA fine-tuning on an fp8 model) the fp8 fast path is skipped
    entirely so gradients flow through the cast weights.

    Args:
        input: activation tensor passed to the linear layer.

    Returns:
        The layer output tensor, from either the fp8 kernel or the
        standard ``torch.nn.functional.linear`` fallback.
    """
    # NOTE: the merged diff left the pre-fix unconditional fp8_linear
    # attempt in front of the guarded one, so fp8_linear still ran in
    # training mode; only the `not self.training` guarded attempt is kept.
    if not self.training:
        try:
            out = fp8_linear(self, input)
            if out is not None:
                return out
        except Exception as e:
            # Best-effort fast path: log and fall back to the generic path.
            logging.info("Exception during fp8 op: {}".format(e))
    # Generic fallback: cast weight/bias to match the input's dtype/device,
    # then do a standard linear.
    weight, bias = cast_bias_weight(self, input)
    return torch.nn.functional.linear(input, weight, bias)