feat: use get_rope for gemma2 (#2954)
@@ -20,6 +20,7 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
@@ -48,19 +49,6 @@ def get_attention_sliding_window_size(config):
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 
-class GemmaRotaryEmbedding(RotaryEmbedding):
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
-                / self.rotary_dim
-            )
-        )
-        return inv_freq
-
-
 class Gemma2MLP(nn.Module):
     def __init__(
         self,
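
Note (not part of the diff): the removed `_compute_inv_freq` just restates the standard RoPE inverse-frequency formula, inv_freq[i] = 1 / base**(2*i / rotary_dim), building the exponent from an int64 arange cast to float as the linked HF Gemma code does. vLLM's stock RotaryEmbedding, which `get_rope` returns when no rope scaling is configured, computes the same values from a float arange, so dropping the subclass should not change the resulting frequencies. A minimal sketch of that check follows; the helper names and the head_dim/base values are illustrative assumptions, not taken from the model config.

    import torch

    def gemma_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
        # Formula from the removed GemmaRotaryEmbedding._compute_inv_freq.
        exponent = torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim
        return 1.0 / (base ** exponent)

    def default_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
        # Standard RoPE inverse frequencies (float arange), as in vLLM's base RotaryEmbedding.
        exponent = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
        return 1.0 / (base ** exponent)

    head_dim, rope_theta = 256, 10000.0  # illustrative values only
    assert torch.allclose(gemma_inv_freq(head_dim, rope_theta),
                          default_inv_freq(head_dim, rope_theta))
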
@@ -143,14 +131,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # from vLLM: TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
-            self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
 
         use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
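
Note (also outside the diff): `get_rope` acts as a cached factory and, for a plain NeoX-style rope with no `rope_scaling`, returns vLLM's base RotaryEmbedding; in recent vLLM versions it also falls back to `torch.get_default_dtype()` when no `dtype` is passed, which is why the explicit `dtype` argument can go away here. A rough standalone sketch of exercising the new call, assuming a working vLLM install; the sizes, positions, and tensors are toy values rather than anything read from Gemma-2's config:

    import torch
    from vllm.model_executor.layers.rotary_embedding import get_rope

    head_dim, max_position_embeddings, rope_theta = 256, 8192, 10000.0  # toy values

    rotary_emb = get_rope(
        head_dim,
        rotary_dim=head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        is_neox_style=True,
    )

    # Rotate dummy query/key tensors shaped [num_tokens, num_heads * head_dim],
    # mirroring how the attention forward pass calls self.rotary_emb(positions, q, k).
    positions = torch.arange(4)
    q = torch.randn(4, 2 * head_dim)
    k = torch.randn(4, 1 * head_dim)
    q, k = rotary_emb(positions, q, k)
    print(q.shape, k.shape)  # expected: torch.Size([4, 512]) torch.Size([4, 256])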