diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index f5da6dcb3..fb808abed 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -70,6 +70,7 @@ class LlamaMLP(nn.Module):
 class LlamaAttention(nn.Module):
     def __init__(
         self,
+        config: LlamaConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -96,7 +97,10 @@ class LlamaAttention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -168,6 +172,7 @@ class LlamaDecoderLayer(nn.Module):
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
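
The core of the change is the `getattr(config, "head_dim", ...)` fallback: configs that set `head_dim` explicitly (as Mistral-Nemo does) use that value, while configs without the attribute keep the old derived value. Below is a minimal standalone sketch of that resolution logic, not part of the patch; the config objects and numbers are illustrative (the Nemo-like values mirror Mistral-Nemo, where `head_dim` differs from `hidden_size // num_attention_heads`).

```python
from types import SimpleNamespace

# Illustrative configs (hypothetical objects, not real HF configs):
# a Mistral-Nemo-like config sets head_dim explicitly; a plain Llama config does not.
nemo_like = SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)
llama_like = SimpleNamespace(hidden_size=4096, num_attention_heads=32)


def resolve_head_dim(config):
    # Same pattern as the patched LlamaAttention.__init__:
    # prefer config.head_dim when present, otherwise derive it from hidden_size.
    return getattr(
        config, "head_dim", config.hidden_size // config.num_attention_heads
    )


print(resolve_head_dim(nemo_like))   # 128 -> explicit head_dim wins over 5120 // 32
print(resolve_head_dim(llama_like))  # 128 -> derived as 4096 // 32, old behavior
```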