Make text generation work with the Ministral model. (#13395)
A template is still needed before it works properly.
This commit is contained in:
@@ -82,6 +82,7 @@ class Ministral3_3BConfig:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [2]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen25_3BConfig:
|
class Qwen25_3BConfig:
|
||||||
@@ -969,7 +970,7 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
class Ministral3_3B(BaseLlama, torch.nn.Module):
|
class Ministral3_3B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Ministral3_3BConfig(**config_dict)
|
config = Ministral3_3BConfig(**config_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user