Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)

Co-authored-by: ZX <zx@lbx.dev>
Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
This commit is contained in:
Qubitium
2024-05-12 07:37:49 +08:00
committed by GitHub
parent a511a2d089
commit 33b242df30
20 changed files with 611 additions and 187 deletions

View File

@@ -1,27 +1,28 @@
# Adapted from llama2.py
# Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
@@ -39,17 +40,17 @@ class Qwen2MLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, linear_method=linear_method
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
)
if hidden_act != "silu":
raise ValueError(
@@ -75,7 +76,7 @@ class Qwen2Attention(nn.Module):
rope_theta: float = 1000000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 32768,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -106,13 +107,13 @@ class Qwen2Attention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -149,7 +150,7 @@ class Qwen2DecoderLayer(nn.Module):
self,
config: Qwen2Config,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -164,13 +165,13 @@ class Qwen2DecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
@@ -206,7 +207,7 @@ class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
@@ -218,7 +219,7 @@ class Qwen2Model(nn.Module):
)
self.layers = nn.ModuleList(
[
Qwen2DecoderLayer(config, i, linear_method)
Qwen2DecoderLayer(config, i, quant_config=quant_config)
for i in range(config.num_hidden_layers)
]
)
@@ -252,12 +253,12 @@ class Qwen2ForCausalLM(nn.Module):
def __init__(
self,
config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method)
self.quant_config = quant_config
self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)