diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index c06637962..e1688df01 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -100,6 +100,7 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,14 +133,14 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
@@ -194,6 +195,11 @@ class LlamaDecoderLayer(nn.Module):
         )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -206,6 +212,7 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            bias=attention_bias,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
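For reference, a minimal sketch (not part of the patch) of how the new fallback resolves the flag for the two model families named in the comments: configs exposing attention_bias (e.g. the llamafied Qwen2.5 checkpoint) and configs exposing bias (e.g. internlm/internlm-7b). The resolve_attention_bias helper and the SimpleNamespace stand-ins for HF config objects are illustrative assumptions, not sglang code.

# Illustrative only: mimics the getattr fallback added in LlamaDecoderLayer.
from types import SimpleNamespace

def resolve_attention_bias(config) -> bool:
    # Prefer `attention_bias` (llamafied Qwen2.5-style configs),
    # fall back to `bias` (InternLM-style configs), default to False.
    return getattr(config, "attention_bias", False) or getattr(config, "bias", False)

print(resolve_attention_bias(SimpleNamespace(attention_bias=True)))  # True
print(resolve_attention_bias(SimpleNamespace(bias=True)))            # True
print(resolve_attention_bias(SimpleNamespace()))                      # False (stock Llama)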