From d37f95511d6694bcc4a2c82b1a486d9b8138d312 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 22 Feb 2025 08:09:35 +0800 Subject: [PATCH] Improve: Tiny fix Olmo2 (#3348) --- python/sglang/srt/models/olmo2.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index f3e1979f8..a8af7bc1a 100644 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -64,24 +64,24 @@ class Olmo2Attention(nn.Module): super().__init__() self.config = config self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() + self.tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.hidden_size % self.total_num_heads == 0 - assert self.total_num_heads % tp_size == 0 + assert self.total_num_heads % self.tp_size == 0 - self.num_heads = self.total_num_heads // tp_size + self.num_heads = self.total_num_heads // self.tp_size self.total_num_kv_heads = self.config.num_key_value_heads - if self.total_num_kv_heads >= tp_size: + if self.total_num_kv_heads >= self.tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_kv_heads % self.tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings @@ -343,7 +343,7 @@ class Olmo2ForCausalLM(nn.Module): input_embeds=input_embeds, ) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):