Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)
Co-authored-by: ZX <zx@lbx.dev> Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
This commit is contained in:
@@ -9,20 +9,21 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
from sglang.srt.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
@@ -34,7 +35,7 @@ from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
|
||||
class StablelmMLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
|
||||
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -44,10 +45,10 @@ class StablelmMLP(nn.Module):
|
||||
config.hidden_size,
|
||||
[config.intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
config.intermediate_size, config.hidden_size, bias=False
|
||||
config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
@@ -63,7 +64,7 @@ class StablelmAttention(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -105,13 +106,11 @@ class StablelmAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_key_value_heads,
|
||||
self.qkv_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -146,11 +145,11 @@ class StablelmDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = StablelmAttention(config, layer_id=layer_id)
|
||||
self.mlp = StablelmMLP(config, linear_method)
|
||||
self.mlp = StablelmMLP(config, quant_config=quant_config)
|
||||
norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
||||
@@ -182,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
|
||||
|
||||
class StableLMEpochModel(nn.Module):
|
||||
def __init__(
|
||||
self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
|
||||
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
@@ -191,7 +190,7 @@ class StableLMEpochModel(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
StablelmDecoderLayer(config, i, linear_method)
|
||||
StablelmDecoderLayer(config, i, quant_config=quant_config)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -224,12 +223,12 @@ class StableLmForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = StableLMEpochModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = StableLMEpochModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user