Add Gemma2 (#592)
python/sglang/bench_latency.py

@@ -165,6 +165,7 @@ def decode(input_token_ids, batch, model_runner):
    return next_token_ids, output.next_token_logits


@torch.inference_mode()
def correctness_test(
    server_args,
    bench_args,
@@ -178,9 +179,10 @@ def correctness_test(
    # Prepare inputs
    input_ids, reqs = prepare_inputs(bench_args, tokenizer)

    # Prefill
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print("prefill logits (first half)", next_token_logits)
    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print("prefill logits (first half)", next_token_logits)

    # Prepare extend inputs
    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
@@ -190,7 +192,7 @@ def correctness_test(
    rank_print("prefill logits (final)", next_token_logits)

    # Decode
    output_ids = [list(req.input_ids) for req in reqs]
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for i in range(len(reqs)):
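The changed `output_ids` line above seeds each output sequence with its prompt ids plus the first token produced by prefill, so the decode loop only appends. A standalone sketch of that accumulation pattern (the token ids and the `decode_step` helper below are made up for illustration, not sglang APIs):

# Sketch only: how the correctness test accumulates output tokens.
input_ids = [[1, 5, 9], [1, 7]]          # two prompts
first_tokens = [42, 43]                   # next-token ids from the prefill pass

# Seed each output with prompt + first generated token, as in the diff.
output_ids = [input_ids[i] + [first_tokens[i]] for i in range(len(input_ids))]

def decode_step(prev_tokens):
    # placeholder for decode(next_token_ids, batch, model_runner)
    return [t + 1 for t in prev_tokens]

next_token_ids = first_tokens
for _ in range(4):                        # bench_args.output_len in the real test
    next_token_ids = decode_step(next_token_ids)
    for i in range(len(output_ids)):
        output_ids[i].append(next_token_ids[i])

print(output_ids)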
python/sglang/srt/layers/extend_attention.py

@@ -191,6 +191,7 @@ def extend_attention_fwd(
    b_seq_len_extend,
    max_len_in_batch,
    max_len_extend,
    sm_scale=None,
    logit_cap=-1,
):
    """
@@ -213,7 +214,7 @@ def extend_attention_fwd(
    else:
        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

    sm_scale = 1.0 / (Lq**0.5)
    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
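These two hunks make the softmax scale overridable instead of always deriving it from the head dimension. Gemma2 needs this because its attention scale comes from `query_pre_attn_scalar`, which need not equal `head_dim` (see `self.scaling = config.query_pre_attn_scalar**-0.5` in the new model file). A minimal sketch of the two scales; the numbers are illustrative, not taken from any particular checkpoint:

import math

head_dim = 128                 # illustrative value
query_pre_attn_scalar = 144    # illustrative value; may differ from head_dim

default_sm_scale = 1.0 / math.sqrt(head_dim)        # old kernel behavior
gemma2_sm_scale = query_pre_attn_scalar ** -0.5     # passed in via sm_scale

print(default_sm_scale, gemma2_sm_scale)  # the two differ, so the override matters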
python/sglang/srt/layers/logits_processor.py

@@ -108,6 +108,11 @@ class LogitsProcessor(nn.Module):
        last_logits = tensor_model_parallel_all_gather(last_logits)
        last_logits = last_logits[:, : self.config.vocab_size]

        if hasattr(self.config, "final_logit_softcapping"):
            last_logits /= self.config.final_logit_softcapping
            last_logits = torch.tanh(last_logits)
            last_logits *= self.config.final_logit_softcapping

        # Return only last_logits if logprob is not requested
        if not input_metadata.return_logprob:
            return LogitProcessorOutput(
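The added block applies Gemma2's final logit soft-capping: dividing by the cap, applying tanh, and multiplying back is the same as `cap * tanh(logits / cap)`, which bounds every logit to the open interval (-cap, cap). A small self-contained check (the cap value is illustrative; the real code reads it from `config.final_logit_softcapping`):

import torch

cap = 30.0                                # illustrative; taken from the config in the real code
logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])

capped = cap * torch.tanh(logits / cap)   # same as the three in-place ops in the diff
print(capped)                              # every value now lies strictly in (-30, 30)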
python/sglang/srt/layers/radix_attention.py

@@ -1,11 +1,9 @@
"""Radix attention."""

import numpy as np
import torch
from torch import nn

from sglang.global_config import global_config
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
@@ -21,10 +19,9 @@ class RadixAttention(nn.Module):
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.scaling = scaling
        self.layer_id = layer_id

        assert np.allclose(scaling, 1.0 / (head_dim**0.5))

        from sglang.srt.managers.controller.model_runner import global_server_args_dict

        if not global_server_args_dict.get("disable_flashinfer", False):
@@ -32,29 +29,17 @@ class RadixAttention(nn.Module):
            self.extend_forward = self.prefill_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
            # flashinfer now accepts float logit_cap argument
            self.logit_cap = logit_cap if logit_cap > 0 else 0
            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
        else:
            self.prefill_forward = self.prefill_forward_triton
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton
            self.logit_cap = logit_cap
            self.logit_cap = logit_cap if logit_cap is not None else 0

    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)

        context_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            k,
            v,
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.max_seq_len,
            self.logit_cap,
        )
        self.store_kv_cache(k, v, input_metadata)

        return o
        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
        # See the extend_forward_xxx functions.
        raise NotImplementedError()

    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)
@@ -75,7 +60,8 @@ class RadixAttention(nn.Module):
            input_metadata.extend_seq_lens,
            input_metadata.max_seq_len,
            input_metadata.max_extend_len,
            self.logit_cap,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o
@@ -96,7 +82,8 @@ class RadixAttention(nn.Module):
            input_metadata.max_seq_len,
            input_metadata.other_kv_index,
            input_metadata.total_num_tokens,
            self.logit_cap,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o
@@ -108,6 +95,8 @@ class RadixAttention(nn.Module):
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
            causal=True,
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )

@@ -118,6 +107,7 @@ class RadixAttention(nn.Module):
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
            causal=False,
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )

@@ -135,6 +125,7 @@ class RadixAttention(nn.Module):
        o = input_metadata.flashinfer_decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )
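Across these attention paths, `logit_cap` is now threaded through to the kernels (as `logit_cap` for the Triton paths and `logits_soft_cap` for flashinfer), together with the module's `sm_scale`. Conceptually it applies the same tanh squashing as the final-logit cap, but to the pre-softmax attention scores; Gemma2 enables it through `attn_logit_softcapping`. A plain-PyTorch reference sketch of the math the kernels implement (shapes and the cap value are illustrative, this is not the kernel code):

import torch

def soft_capped_attention(q, k, v, sm_scale, logit_cap):
    # q, k, v: [num_heads, seq_len, head_dim]; reference math only.
    scores = torch.matmul(q, k.transpose(-1, -2)) * sm_scale
    if logit_cap > 0:
        scores = logit_cap * torch.tanh(scores / logit_cap)  # cap pre-softmax scores
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)

q = torch.randn(8, 16, 128)
k = torch.randn(8, 16, 128)
v = torch.randn(8, 16, 128)
out = soft_capped_attention(q, k, v, sm_scale=128 ** -0.5, logit_cap=50.0)
print(out.shape)  # torch.Size([8, 16, 128])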
python/sglang/srt/layers/token_attention.py

@@ -176,6 +176,7 @@ def _token_att_m_fwd(
    B_Start_Loc,
    B_Seqlen,
    max_len_in_batch,
    sm_scale,
    logit_cap,
):
    BLOCK = 32
@@ -183,7 +184,6 @@ def _token_att_m_fwd(
    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
    assert Lq == Lk
    assert Lk in {16, 32, 64, 128, 256}
    sm_scale = 1.0 / (Lk**0.5)

    batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -317,6 +317,7 @@ def token_attention_fwd(
    max_len_in_batch,
    other_kv_index,
    total_num_tokens,
    sm_scale=None,
    logit_cap=-1,
    att_m=None,
):
@@ -324,6 +325,7 @@ def token_attention_fwd(
        att_m = torch.empty(
            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
        )
    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale

    _token_att_m_fwd(
        q,
@@ -334,6 +336,7 @@ def token_attention_fwd(
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        sm_scale,
        logit_cap,
    )
    _token_softmax_reducev_fwd(
python/sglang/srt/models/gemma2.py (new file, 427 lines)

@@ -0,0 +1,427 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import Gemma2Config

from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul
# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata


# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.custom_op import CustomOp
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + self.weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
        return self.forward_native(x, residual)


# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
class GemmaRotaryEmbedding(RotaryEmbedding):

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
        inv_freq = 1.0 / (base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
            self.rotary_dim))
        return inv_freq


class Gemma2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
            raise ValueError(
                "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`.")
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Gemma2Attention(nn.Module):

    def __init__(self,
                 layer_idx: int,
                 config: Gemma2Config,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int,
                 rope_theta: float,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        # from vLLM: TODO(woosuk): Use the `get_rope` interface.
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            self.head_dim,
            max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
            dtype=torch.get_default_dtype(),
        )

        # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
        # odd layer, vLLM currently ignores it and uses global attention for
        # all layers.
        use_sliding_window = (layer_idx % 2 == 1
                              and config.sliding_window is not None)
        del use_sliding_window  # Unused.
        self.attn = RadixAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads,
                                   layer_id=layer_idx,
                                   logit_cap=self.config.attn_logit_softcapping)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Gemma2DecoderLayer(nn.Module):

    def __init__(
        self,
        layer_idx: int,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma2Attention(
            layer_idx=layer_idx,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.hidden_size = config.hidden_size
        self.mlp = Gemma2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                       eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        return hidden_states, residual


class Gemma2Model(nn.Module):

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=torch.float16)
        hidden_states *= normalizer

        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Gemma2ForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma2Model(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")


EntryClass = Gemma2ForCausalLM
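Two Gemma2-specific details in the file above are easy to miss: `GemmaRMSNorm` scales by `(1 + weight)` rather than `weight` (so its zero-initialized weight starts out as an identity scale), and the embedding output is multiplied by `sqrt(hidden_size)` before the first decoder layer. A standalone sketch of both, with an illustrative hidden size:

import torch

hidden_size = 8                      # illustrative; real configs are much larger
eps = 1e-6
weight = torch.zeros(hidden_size)    # GemmaRMSNorm initializes its weight to zeros

x = torch.randn(2, hidden_size)

# GemmaRMSNorm: normalize, then scale by (1 + weight); with the zero-init weight
# this reduces to plain RMS normalization, unlike Llama-style RMSNorm (x * weight).
variance = x.float().pow(2).mean(-1, keepdim=True)
normed = x.float() * torch.rsqrt(variance + eps)
out = (normed * (1.0 + weight.float())).to(x.dtype)

# Gemma2Model.forward: embeddings are multiplied by sqrt(hidden_size) before the
# first decoder layer (the "normalizer" buffer registered in __init__).
hidden_states = torch.randn(2, hidden_size)
hidden_states = hidden_states * torch.tensor(hidden_size**0.5)
print(out.shape, hidden_states.shape)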