Support DP MLA (#1970)
This commit is contained in:
@@ -22,7 +22,9 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
use_dp=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.num_heads = num_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert num_heads % tp_size == 0
|
||||
self.num_local_heads = num_heads // tp_size
|
||||
self.num_local_heads = num_heads if use_dp else num_heads // tp_size
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
if use_dp:
|
||||
# For data parallel attention
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ReplicatedLinear(
|
||||
q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.q_proj = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.kv_b_proj = ReplicatedLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
# O projection.
|
||||
self.o_proj = ReplicatedLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
# For tensor parallel attention
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def all_gather(
|
||||
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
|
||||
):
|
||||
if world_size == 1:
|
||||
return input_tensor
|
||||
|
||||
all_lens = forward_batch.global_num_tokens
|
||||
max_len = max(forward_batch.global_num_tokens)
|
||||
|
||||
padded_tensor = torch.nn.functional.pad(
|
||||
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
|
||||
)
|
||||
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
forward_batch.gathered_buffer, padded_tensor, group=group
|
||||
)
|
||||
|
||||
gathered_tensors = torch.concat(
|
||||
[
|
||||
forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
|
||||
for i in range(world_size)
|
||||
]
|
||||
)
|
||||
|
||||
start_index = 0 if rank == 0 else sum(all_lens[:rank])
|
||||
end_index = start_index + all_lens[rank]
|
||||
|
||||
return gathered_tensors, start_index, end_index
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.enable_dp_attention = (
|
||||
not global_server_args_dict["disable_mla"]
|
||||
and global_server_args_dict["enable_dp_attention"]
|
||||
)
|
||||
if self.enable_dp_attention:
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_group = get_tp_group().device_group
|
||||
if not global_server_args_dict["disable_mla"]:
|
||||
self.self_attn = DeepseekV2AttentionMLA(
|
||||
config=config,
|
||||
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
use_dp=self.enable_dp_attention,
|
||||
)
|
||||
else:
|
||||
self.self_attn = DeepseekV2Attention(
|
||||
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if self.enable_dp_attention:
|
||||
hidden_states, start_idx, end_idx = all_gather(
|
||||
hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = hidden_states[start_idx:end_idx]
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
@@ -630,7 +723,8 @@ class DeepseekV2Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekV2Model(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
if global_server_args_dict["enable_dp_attention"]:
|
||||
self.lm_head = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
Reference in New Issue
Block a user