Switch mochi and wan modes to use pytorch RMSNorm. (#7925)

* Switch genmo model to native RMSNorm.

* Switch WAN to native RMSNorm.
This commit is contained in:
comfyanonymous
2025-05-03 16:07:55 -07:00
committed by GitHub
parent 7689917113
commit 3041e5c354
3 changed files with 7 additions and 20 deletions

View File

@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
x = self.norm(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)