Fix LoRA Trainer bugs with FP8 models. (#9854)
* Fix adapter weight init
* Fix fp8 model training
* Avoid inference tensor
@@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         in_dim = weight.shape[1:].numel()
-        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
         torch.nn.init.constant_(mat2, 0.0)
         return LoraDiff(
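For context only (not part of the commit): a minimal sketch of why the trainable LoRA matrices are now allocated in float32 instead of the base weight's dtype. The shapes and the torch.float8_e4m3fn dtype below are assumptions for illustration (a recent PyTorch build is assumed); the sketch only mirrors the allocation pattern in the diff above and is not the project's code.

import torch

# Hypothetical FP8 base weight (assumed shape and dtype, for illustration only).
weight = torch.zeros(128, 64).to(torch.float8_e4m3fn)
rank = 4

# Old behaviour: allocate the adapter matrices in weight.dtype. In-place init such
# as kaiming_uniform_ is generally not implemented for float8 tensors, so this fails.
try:
    bad = torch.empty(weight.shape[0], rank, dtype=weight.dtype)
    torch.nn.init.kaiming_uniform_(bad, a=5**0.5)
except RuntimeError as err:
    print("fp8 init failed:", err)

# Fixed behaviour, matching the diff: keep the trainable matrices in float32.
mat1 = torch.empty(weight.shape[0], rank, dtype=torch.float32)
mat2 = torch.empty(rank, weight.shape[1:].numel(), dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)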
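Background for the "Avoid inference tensor" item (a general illustration of the PyTorch pitfall, not the commit's actual change, which is not shown in the diff above): tensors created under torch.inference_mode() cannot take part in autograd, so a training path has to clone them into regular tensors first.

import torch

# Hypothetical sketch: a weight produced under inference mode is an inference tensor.
with torch.inference_mode():
    w = torch.randn(4, 4)

x = torch.randn(4, 4, requires_grad=True)

try:
    # Using the inference tensor in a graph that needs gradients raises a RuntimeError
    # along the lines of "Inference tensors cannot be saved for backward".
    (x @ w).sum().backward()
except RuntimeError as err:
    print("inference tensor error:", err)

# Cloning outside inference mode yields an ordinary tensor that autograd can use.
w_ok = w.clone()
(x @ w_ok).sum().backward()
print(x.grad.shape)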