[Trainer] FP4, 8, 16 training by native dtype support and quant linear autograd function (#12681)

This commit is contained in:
Kohaku-Blueleaf
2026-03-17 09:31:50 +08:00
committed by GitHub
parent 7a16e8aa4e
commit 20561aa919
3 changed files with 150 additions and 23 deletions
+4
View File
@@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):