feat: use get_rope for gemma2 (#2954)
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import Iterable, Optional, Set, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.activation import GeluAndMul
|
from sglang.srt.layers.activation import GeluAndMul
|
||||||
@@ -48,19 +49,6 @@ def get_attention_sliding_window_size(config):
|
|||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
class GemmaRotaryEmbedding(RotaryEmbedding):
|
|
||||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
|
||||||
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
|
|
||||||
inv_freq = 1.0 / (
|
|
||||||
base
|
|
||||||
** (
|
|
||||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
|
|
||||||
/ self.rotary_dim
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return inv_freq
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma2MLP(nn.Module):
|
class Gemma2MLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -143,14 +131,12 @@ class Gemma2Attention(nn.Module):
|
|||||||
bias=config.attention_bias,
|
bias=config.attention_bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
# from vLLM: TODO(woosuk): Use the `get_rope` interface.
|
self.rotary_emb = get_rope(
|
||||||
self.rotary_emb = GemmaRotaryEmbedding(
|
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
base=self.rope_theta,
|
base=self.rope_theta,
|
||||||
is_neox_style=True,
|
is_neox_style=True,
|
||||||
dtype=torch.get_default_dtype(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
|
use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
|
||||||
|
|||||||
Reference in New Issue
Block a user