[CPU] support the case where num_attention_heads or intermediate_size is not divisible by the TP size (#6771)

This commit is contained in:
Chunyuan WU
2025-07-04 00:51:38 +08:00
committed by GitHub
parent 9fcc9a80e7
commit 1dce6c480f
11 changed files with 399 additions and 40 deletions

View File

@@ -100,6 +100,7 @@ class Qwen2Attention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: Optional[int] = None,
layer_id: int = 0,
rope_theta: float = 1000000,
rope_scaling: Optional[Dict[str, Any]] = None,
@@ -123,7 +124,10 @@ class Qwen2Attention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
if head_dim is not None:
self.head_dim = head_dim
else:
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -191,10 +195,12 @@ class Qwen2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
head_dim = getattr(config, "head_dim", None)
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=head_dim,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,