From 78e5b22f29e756667fb60a98c67dc142d3fe95e3 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 18 Jan 2025 02:57:18 +0800
Subject: [PATCH] feat: use get_rope for gemma2 (#2954)

---
 python/sglang/srt/models/gemma2.py | 22 ++++------------------
 1 file changed, 4 insertions(+), 18 deletions(-)

diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index af51ba41b..ee0a762aa 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -20,6 +20,7 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
@@ -48,19 +49,6 @@ def get_attention_sliding_window_size(config):
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 
-class GemmaRotaryEmbedding(RotaryEmbedding):
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
-                / self.rotary_dim
-            )
-        )
-        return inv_freq
-
-
 class Gemma2MLP(nn.Module):
     def __init__(
         self,
@@ -143,14 +131,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # from vLLM: TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
-            self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
 
         use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
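
For reference, a minimal sketch of the call shape this patch moves to, assuming vLLM's get_rope helper returns a rotary-embedding module that is applied as q, k = rotary_emb(positions, q, k); the numeric values below are illustrative and not taken from gemma2.py.

    from vllm.model_executor.layers.rotary_embedding import get_rope

    # Build the RoPE module the same way the patched Gemma2Attention does:
    # head_dim passed positionally, the rest as keyword arguments.
    rotary_emb = get_rope(
        256,                 # head_dim (illustrative value)
        rotary_dim=256,      # illustrative; the patch uses head_dim here too
        max_position=8192,   # max_position_embeddings (illustrative)
        base=10000,          # rope_theta (illustrative)
        is_neox_style=True,
    )
    # positions: (num_tokens,); q, k: (num_tokens, num_heads * head_dim)
    # q, k = rotary_emb(positions, q, k)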