diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index 40bc5f8f4..ed933a190 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -35,6 +35,7 @@ class LlamaMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,12 +43,14 @@ class LlamaMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -76,6 +79,7 @@ class LlamaAttention(nn.Module):
         rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -110,12 +114,14 @@ class LlamaAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.rotary_emb = get_rope(
@@ -154,6 +160,7 @@ class LlamaDecoderLayer(nn.Module):
         config: LlamaConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -178,12 +185,14 @@ class LlamaDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -231,7 +240,9 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i, quant_config=quant_config)
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
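
Note (reviewer sketch, not part of the patch): after this change each parallel linear module receives a fully qualified name such as model.layers.0.self_attn.qkv_proj, presumably so the quant_config can resolve per-module quantization settings by weight name. The snippet below is a minimal, self-contained illustration of how the prefixes compose across LlamaModel -> LlamaDecoderLayer -> LlamaAttention/LlamaMLP; the helper function is hypothetical and not an sglang API.

# Hypothetical helper, for illustration only: reproduces the qualified names
# that the prefix arguments in this patch build up for one decoder layer.
def qualified_prefixes(layer_id: int) -> list[str]:
    layer = f"model.layers.{layer_id}"  # set by LlamaModel
    attn = f"{layer}.self_attn"         # appended by LlamaDecoderLayer
    mlp = f"{layer}.mlp"                # appended by LlamaDecoderLayer
    return [
        f"{attn}.qkv_proj",             # QKVParallelLinear
        f"{attn}.o_proj",               # RowParallelLinear
        f"{mlp}.gate_up_proj",          # MergedColumnParallelLinear
        f"{mlp}.down_proj",             # RowParallelLinear
    ]

print(qualified_prefixes(0))
# ['model.layers.0.self_attn.qkv_proj', 'model.layers.0.self_attn.o_proj',
#  'model.layers.0.mlp.gate_up_proj', 'model.layers.0.mlp.down_proj']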