Support Lumina 2 model.
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
|
||||
@@ -21,15 +20,41 @@ class Llama2Config:
|
||||
max_position_embeddings: int = 8192
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 500000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
vocab_size: int = 256000
|
||||
hidden_size: int = 2304
|
||||
intermediate_size: int = 9216
|
||||
num_hidden_layers: int = 26
|
||||
num_attention_heads: int = 8
|
||||
num_key_value_heads: int = 4
|
||||
max_position_embeddings: int = 8192
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
transformer_type: str = "gemma2"
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
self.add = add
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
w = self.weight
|
||||
if self.add:
|
||||
w = w + 1.0
|
||||
|
||||
return comfy.ldm.common_dit.rms_norm(x, w, self.eps)
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
@@ -68,13 +93,15 @@ class Attention(nn.Module):
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
|
||||
self.head_dim = config.head_dim
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
|
||||
ops = ops or nn
|
||||
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -84,7 +111,6 @@ class Attention(nn.Module):
|
||||
optimized_attention=None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
xq = self.q_proj(hidden_states)
|
||||
xk = self.k_proj(hidden_states)
|
||||
xv = self.v_proj(hidden_states)
|
||||
@@ -108,9 +134,13 @@ class MLP(nn.Module):
|
||||
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
if config.mlp_activation == "silu":
|
||||
self.activation = torch.nn.functional.silu
|
||||
elif config.mlp_activation == "gelu_pytorch_tanh":
|
||||
self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
@@ -146,6 +176,45 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
class TransformerBlockGemma2(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
# MLP
|
||||
residual = x
|
||||
x = self.pre_feedforward_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = self.post_feedforward_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
@@ -158,17 +227,27 @@ class Llama2_(nn.Module):
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.normalize_in = True
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.normalize_in = False
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
|
||||
transformer(config, device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
x.shape[1],
|
||||
self.config.rope_theta,
|
||||
device=x.device)
|
||||
@@ -206,16 +285,7 @@ class Llama2_(nn.Module):
|
||||
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Llama2(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Llama2Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class BaseLlama:
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
@@ -224,3 +294,23 @@ class Llama2(torch.nn.Module):
|
||||
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
|
||||
class Llama2(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Llama2Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma2_2B_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
44
comfy/text_encoders/lumina2.py
Normal file
44
comfy/text_encoders/lumina2.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
|
||||
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False})
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
|
||||
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
|
||||
|
||||
|
||||
class Gemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return LuminaTEModel_
|
||||
@@ -1,21 +1,21 @@
|
||||
import torch
|
||||
|
||||
class SPieceTokenizer:
|
||||
add_eos = True
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path):
|
||||
return SPieceTokenizer(path)
|
||||
def from_pretrained(path, **kwargs):
|
||||
return SPieceTokenizer(path, **kwargs)
|
||||
|
||||
def __init__(self, tokenizer_path):
|
||||
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
import sentencepiece
|
||||
if torch.is_tensor(tokenizer_path):
|
||||
tokenizer_path = tokenizer_path.numpy().tobytes()
|
||||
|
||||
if isinstance(tokenizer_path, bytes):
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||
else:
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
|
||||
|
||||
def get_vocab(self):
|
||||
out = {}
|
||||
|
||||
Reference in New Issue
Block a user