Improve: Tiny fix Olmo2 (#3348)
@@ -64,24 +64,24 @@ class Olmo2Attention(nn.Module):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = config.num_attention_heads
 
         assert self.hidden_size % self.total_num_heads == 0
-        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_heads % self.tp_size == 0
 
-        self.num_heads = self.total_num_heads // tp_size
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = self.config.num_key_value_heads
 
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
 
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
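The two asserts above enforce the same invariant from opposite directions: either the TP group evenly splits the KV heads, or the KV heads evenly tile the TP group. A minimal standalone sketch of that rule (plain ints stand in for the config and TP-group values; the helper name is illustrative, not part of the patched file):

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # Enough KV heads to partition them across the TP ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate each head on several ranks.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

# 8 KV heads on 4 ranks -> 2 per rank; 2 KV heads on 8 ranks -> 1 (replicated).
assert kv_heads_per_rank(8, 4) == 2
assert kv_heads_per_rank(2, 8) == 1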
@@ -343,7 +343,7 @@ class Olmo2ForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
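Passing the lm_head module instead of its bare weight tensor lets the logits processor decide how to project hidden states, e.g. honoring a quantized or otherwise wrapped head. A hedged sketch of the difference under that assumption (SimpleLogitsProcessor and the quant_method hook are illustrative, not the actual sglang API):

import torch
import torch.nn as nn

class SimpleLogitsProcessor(nn.Module):
    def forward(self, hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        # Given the module, we can dispatch on how it computes logits;
        # given only a weight tensor, a plain matmul is the only option.
        quant = getattr(lm_head, "quant_method", None)  # hypothetical hook
        if quant is not None:
            return quant.apply(lm_head, hidden_states)
        return hidden_states @ lm_head.weight.t()

head = nn.Linear(16, 32, bias=False)  # stand-in for the real ParallelLMHead
logits = SimpleLogitsProcessor()(torch.randn(2, 16), head)
assert logits.shape == (2, 32)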