Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)

Co-authored-by: ZX <zx@lbx.dev>
Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
This commit is contained in:
Qubitium
2024-05-12 07:37:49 +08:00
committed by GitHub
parent a511a2d089
commit 33b242df30
20 changed files with 611 additions and 187 deletions

View File

@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
@@ -29,19 +29,20 @@ from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
@@ -92,7 +93,7 @@ class CohereMLP(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -102,13 +103,13 @@ class CohereMLP(nn.Module):
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
@@ -124,7 +125,7 @@ class CohereAttention(nn.Module):
self,
config: PretrainedConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
@@ -159,13 +160,13 @@ class CohereAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -221,16 +222,16 @@ class CohereDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(
config, layer_id=layer_id, linear_method=linear_method
config, layer_id=layer_id, quant_config=quant_config
)
self.mlp = CohereMLP(config, linear_method=linear_method)
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(
param_shape=(config.hidden_size), eps=config.layer_norm_eps
)
@@ -261,7 +262,7 @@ class CohereModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -271,7 +272,7 @@ class CohereModel(nn.Module):
)
self.layers = nn.ModuleList(
[
CohereDecoderLayer(config, i, linear_method=linear_method)
CohereDecoderLayer(config, i, quant_config=quant_config)
for i in range(config.num_hidden_layers)
]
)
@@ -303,13 +304,13 @@ class CohereForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config)
self.model = CohereModel(config, linear_method)
self.model = CohereModel(config, quant_config)
@torch.no_grad()
def forward(