feat: Add basic text generation support with native models, initially supporting Gemma3 (#12392)
This commit is contained in:
@@ -3,6 +3,8 @@ import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any, Tuple
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
import comfy.utils
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
@@ -313,6 +315,13 @@ class Gemma3_4B_Config:
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
|
||||
vision_config = GEMMA3_VISION_CONFIG
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
@dataclass
|
||||
class Gemma3_12B_Config:
|
||||
vocab_size: int = 262208
|
||||
@@ -336,7 +345,7 @@ class Gemma3_12B_Config:
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
vision_config = GEMMA3_VISION_CONFIG
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@@ -441,8 +450,10 @@ class Attention(nn.Module):
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
xq = self.q_proj(hidden_states)
|
||||
xk = self.k_proj(hidden_states)
|
||||
xv = self.v_proj(hidden_states)
|
||||
@@ -477,6 +488,11 @@ class Attention(nn.Module):
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||
xk = xk[:, :, -sliding_window:]
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
@@ -559,10 +575,12 @@ class TransformerBlockGemma2(nn.Module):
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
sliding_window = None
|
||||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
sliding_window = self.sliding_attention
|
||||
if x.shape[1] > self.sliding_attention:
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
|
||||
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask + sliding_mask
|
||||
@@ -581,6 +599,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
@@ -765,6 +784,104 @@ class BaseLlama:
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
class BaseGenerate:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
if hasattr(self.model, "lm_head"):
|
||||
module = self.model.lm_head
|
||||
else:
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
if module.comfy_cast_weights:
|
||||
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
|
||||
else:
|
||||
weight = self.model.embed_tokens.weight.to(x)
|
||||
|
||||
x = torch.nn.functional.linear(input, weight, None)
|
||||
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
|
||||
device = embeds.device
|
||||
model_config = self.model.config
|
||||
|
||||
if execution_dtype is None:
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
execution_dtype = torch.bfloat16
|
||||
else:
|
||||
execution_dtype = torch.float32
|
||||
embeds = embeds.to(execution_dtype)
|
||||
|
||||
if embeds.ndim == 2:
|
||||
embeds = embeds.unsqueeze(0)
|
||||
|
||||
past_key_values = [] #kv_cache init
|
||||
max_cache_len = embeds.shape[1] + max_length
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||
|
||||
generated_token_ids = []
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# Generation loop
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
break
|
||||
|
||||
return generated_token_ids
|
||||
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
||||
|
||||
if not do_sample or temperature == 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
|
||||
# Sampling mode
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
if top_k > 0:
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
if min_p > 0.0:
|
||||
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
|
||||
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
|
||||
min_threshold = min_p * top_probs
|
||||
indices_to_remove = probs_before_filter < min_threshold
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 0] = False
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
return torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
@@ -871,7 +988,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_7BVLI_Config(**config_dict)
|
||||
@@ -881,6 +998,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
# todo: should this be tied or not?
|
||||
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
|
||||
@@ -923,7 +1043,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma3_4B(BaseLlama, torch.nn.Module):
|
||||
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_4B_Config(**config_dict)
|
||||
@@ -932,7 +1052,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma3_12B(BaseLlama, torch.nn.Module):
|
||||
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_4B_Vision_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
|
||||
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
|
||||
self.image_size = config.vision_config["image_size"]
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
|
||||
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
|
||||
return None, None
|
||||
|
||||
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_12B_Config(**config_dict)
|
||||
|
||||
Reference in New Issue
Block a user