diff --git a/docs/model_support.md b/docs/model_support.md index 3c67a89f5..a77a3c288 100644 --- a/docs/model_support.md +++ b/docs/model_support.md @@ -6,5 +6,13 @@ You can learn from existing model implementations and create new files for the n Another valuable resource is the vLLM model implementations. vLLM has extensive coverage of models, and SGLang has reused vLLM for most parts of the model implementations. This similarity makes it easy to port many models from vLLM to SGLang. -1. Compare these two files [SGLang LLaMA Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py) and [vLLM LLaMA Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of PagedAttention with RadixAttention. The other parts are almost identical. +1. Compare these two files [SGLang LLaMA Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py) and [vLLM LLaMA Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of PagedAttention with RadixAttention. The other parts are almost identical. Specifically, + - Replace `Attention` with `RadixAttention`. + - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`. + - Remove `Sample`. + - Change `forward()` functions, and add `input_metadata`. + - Add `EntryClass` at the end. + - Test correctness by comparing the final logits and outputs of two following commands: + - `python3 playground/reference_hf.py --model [new model]` + - `python3 -m sglang.bench_latency --model [new model] --correct --output-len 16` 2. Convert models from vLLM to SGLang by visiting the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index ca09028f4..14a2f1824 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -165,6 +165,7 @@ def decode(input_token_ids, batch, model_runner): return next_token_ids, output.next_token_logits +@torch.inference_mode() def correctness_test( server_args, bench_args, @@ -178,9 +179,10 @@ def correctness_test( # Prepare inputs input_ids, reqs = prepare_inputs(bench_args, tokenizer) - # Prefill - next_token_ids, next_token_logits, batch = extend(reqs, model_runner) - rank_print("prefill logits (first half)", next_token_logits) + if bench_args.cut_len > 0: + # Prefill + next_token_ids, next_token_logits, batch = extend(reqs, model_runner) + rank_print("prefill logits (first half)", next_token_logits) # Prepare extend inputs reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner) @@ -190,7 +192,7 @@ def correctness_test( rank_print("prefill logits (final)", next_token_logits) # Decode - output_ids = [list(req.input_ids) for req in reqs] + output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))] for _ in range(bench_args.output_len): next_token_ids, _ = decode(next_token_ids, batch, model_runner) for i in range(len(reqs)): diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 5b1dba40f..41c2ca7d1 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -191,6 +191,7 @@ def extend_attention_fwd( b_seq_len_extend, max_len_in_batch, max_len_extend, + sm_scale=None, logit_cap=-1, ): """ @@ -213,7 +214,7 @@ def extend_attention_fwd( else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) - sm_scale = 1.0 / (Lq**0.5) + sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index ca932bef0..ce9415337 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -108,6 +108,11 @@ class LogitsProcessor(nn.Module): last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = last_logits[:, : self.config.vocab_size] + if hasattr(self.config, "final_logit_softcapping"): + last_logits /= self.config.final_logit_softcapping + last_logits = torch.tanh(last_logits) + last_logits *= self.config.final_logit_softcapping + # Return only last_logits if logprob is not requested if not input_metadata.return_logprob: return LogitProcessorOutput( diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index c46c11237..25493eef5 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -1,11 +1,9 @@ """Radix attention.""" - import numpy as np import torch from torch import nn from sglang.global_config import global_config -from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd from sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.token_attention import token_attention_fwd from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata @@ -21,10 +19,9 @@ class RadixAttention(nn.Module): self.tp_k_head_num = num_kv_heads self.tp_v_head_num = num_kv_heads self.head_dim = head_dim + self.scaling = scaling self.layer_id = layer_id - assert np.allclose(scaling, 1.0 / (head_dim**0.5)) - from sglang.srt.managers.controller.model_runner import global_server_args_dict if not global_server_args_dict.get("disable_flashinfer", False): @@ -32,29 +29,17 @@ class RadixAttention(nn.Module): self.extend_forward = self.prefill_forward_flashinfer self.decode_forward = self.decode_forward_flashinfer # flashinfer now accepts float logit_cap argument - self.logit_cap = logit_cap if logit_cap > 0 else 0 + self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 else: self.prefill_forward = self.prefill_forward_triton self.extend_forward = self.extend_forward_triton self.decode_forward = self.decode_forward_triton - self.logit_cap = logit_cap + self.logit_cap = logit_cap if logit_cap is not None else 0 def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata): - o = torch.empty_like(q) - - context_attention_fwd( - q.view(-1, self.tp_q_head_num, self.head_dim), - k, - v, - o.view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.start_loc, - input_metadata.seq_lens, - input_metadata.max_seq_len, - self.logit_cap, - ) - self.store_kv_cache(k, v, input_metadata) - - return o + # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend". + # See the extend_forward_xxx functions. + raise NotImplementedError() def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): o = torch.empty_like(q) @@ -75,7 +60,8 @@ class RadixAttention(nn.Module): input_metadata.extend_seq_lens, input_metadata.max_seq_len, input_metadata.max_extend_len, - self.logit_cap, + sm_scale=self.scaling, + logit_cap=self.logit_cap, ) return o @@ -96,7 +82,8 @@ class RadixAttention(nn.Module): input_metadata.max_seq_len, input_metadata.other_kv_index, input_metadata.total_num_tokens, - self.logit_cap, + sm_scale=self.scaling, + logit_cap=self.logit_cap, ) return o @@ -108,6 +95,8 @@ class RadixAttention(nn.Module): q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), k.contiguous().view(-1, self.tp_k_head_num, self.head_dim), v.contiguous().view(-1, self.tp_v_head_num, self.head_dim), + causal=True, + sm_scale=self.scaling, logits_soft_cap=self.logit_cap, ) @@ -118,6 +107,7 @@ class RadixAttention(nn.Module): q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), input_metadata.token_to_kv_pool.kv_data[self.layer_id], causal=False, + sm_scale=self.scaling, logits_soft_cap=self.logit_cap, ) @@ -135,6 +125,7 @@ class RadixAttention(nn.Module): o = input_metadata.flashinfer_decode_wrapper.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), input_metadata.token_to_kv_pool.kv_data[self.layer_id], + sm_scale=self.scaling, logits_soft_cap=self.logit_cap, ) diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index 73b25aa85..f9d58ae27 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -176,6 +176,7 @@ def _token_att_m_fwd( B_Start_Loc, B_Seqlen, max_len_in_batch, + sm_scale, logit_cap, ): BLOCK = 32 @@ -183,7 +184,6 @@ def _token_att_m_fwd( Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk assert Lk in {16, 32, 64, 128, 256} - sm_scale = 1.0 / (Lk**0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -317,6 +317,7 @@ def token_attention_fwd( max_len_in_batch, other_kv_index, total_num_tokens, + sm_scale=None, logit_cap=-1, att_m=None, ): @@ -324,6 +325,7 @@ def token_attention_fwd( att_m = torch.empty( (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda" ) + sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale _token_att_m_fwd( q, @@ -334,6 +336,7 @@ def token_attention_fwd( b_start_loc, b_seq_len, max_len_in_batch, + sm_scale, logit_cap, ) _token_softmax_reducev_fwd( diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py new file mode 100644 index 000000000..3a54d25c2 --- /dev/null +++ b/python/sglang/srt/models/gemma2.py @@ -0,0 +1,427 @@ +# Adapted from: +# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Gemma2Config + +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import GeluAndMul +# from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.controller.model_runner import InputMetadata + + +# FIXME: temporary solution, remove after next vllm release +from vllm.model_executor.custom_op import CustomOp +class GemmaRMSNorm(CustomOp): + """RMS normalization for Gemma. + + Two differences from the above RMSNorm: + 1. x * (1 + w) instead of x * w. + 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + x = x * (1.0 + self.weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. + return self.forward_native(x, residual) + + +# FIXME: temporary solution, remove after next vllm release +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +class GemmaRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 + inv_freq = 1.0 / (base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / + self.rotary_dim)) + return inv_freq + + +class Gemma2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): + raise ValueError( + "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma2Attention(nn.Module): + + def __init__(self, + layer_idx: int, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + # from vLLM: TODO(woosuk): Use the `get_rope` interface. + self.rotary_emb = GemmaRotaryEmbedding( + self.head_dim, + self.head_dim, + max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), + ) + + # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every + # odd layer, vLLM currently ignores it and uses global attention for + # all layers. + use_sliding_window = (layer_idx % 2 == 1 + and config.sliding_window is not None) + del use_sliding_window # Unused. + self.attn = RadixAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_idx, + logit_cap=self.config.attn_logit_softcapping) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma2DecoderLayer(nn.Module): + + def __init__( + self, + layer_idx: int, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma2Attention( + layer_idx=layer_idx, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +class Gemma2Model(nn.Module): + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=torch.float16) + hidden_states *= normalizer + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Gemma2ForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer. + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + del lora_config # Unused. + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Gemma2Model(config, cache_config, quant_config) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + raise RuntimeError( + "Some weights are not initialized from checkpoints: " + f"{unloaded_params}") + + +EntryClass = Gemma2ForCausalLM \ No newline at end of file