Fix LoRA Trainer bugs with FP8 models. (#9854)

* Fix adapter weight init

* Fix fp8 model training

* Avoid inference tensor
This commit is contained in:
Kohaku-Blueleaf
2025-09-21 09:24:48 +08:00
committed by GitHub
parent 9ed3c5cc09
commit 7be2b49b6b
6 changed files with 34 additions and 15 deletions

View File

@@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0):
    """Create a freshly initialized trainable OFT diff for *weight*.

    Args:
        weight: Base layer weight tensor; only its output dimension
            (``weight.shape[0]``) and device are read here.
        rank: Requested block rank, passed to ``factorization`` to split
            the output dimension into ``block_num`` blocks of
            ``block_size``. Defaults to 1.
        alpha: Scaling factor stored alongside the blocks. Defaults to 1.0.

    Returns:
        An ``OFTDiff`` wrapping ``(block, None, alpha, None)`` where
        ``block`` is a zero-initialized ``(block_num, block_size,
        block_size)`` tensor.
    """
    out_dim = weight.shape[0]
    block_size, block_num = factorization(out_dim, rank)
    # Allocate the trainable blocks once, in float32 regardless of the
    # base weight's dtype: training directly in low-precision dtypes
    # (e.g. fp8 base models) breaks optimizer updates. The previous
    # dead-store allocation with dtype=weight.dtype is removed.
    block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
    return OFTDiff(
        (block, None, alpha, None)
    )