perf: use multicast to avoid padding decode request to prefill size (#1555)

### What this PR does / why we need it?
Use a naive multicast instead of a padded all-gather when dispatching tokens across data-parallel ranks. When `ep_size == 1` and the batch contains prefill requests, each DP rank now broadcasts only its real tokens into a buffer sized by the cumulative token counts, so decode requests (one token each) are no longer padded up to the largest prefill batch across ranks.
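
For intuition, a back-of-the-envelope comparison of the rows each DP rank must communicate on the dispatch step; the token counts and hidden size below are made up for illustration:

```python
# Illustrative only: compares the communication volume of the old padded
# all-gather with the new naive multicast for one mixed DP batch.
num_tokens_per_dp_rank = [512, 1, 1, 1]  # one prefill rank, three decode ranks
hidden_size = 7168                       # hypothetical hidden dimension

max_tokens = max(num_tokens_per_dp_rank)
padded_rows = max_tokens * len(num_tokens_per_dp_rank)  # every rank pads to the max
multicast_rows = sum(num_tokens_per_dp_rank)            # every rank sends its real rows

print(f"padded all-gather: {padded_rows} rows, {padded_rows * hidden_size:,} elements")
print(f"naive multicast:   {multicast_rows} rows, {multicast_rows * hidden_size:,} elements")
# The decode ranks no longer pad 1 token up to the 512-token prefill batch.
```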

### How was this patch tested?

- vLLM version: v0.9.1
- vLLM main: 1fd471e957

Signed-off-by: boying <897013703@qq.com>
Author: NeverRaR
Date: 2025-07-07 22:36:03 +08:00
Committed by: GitHub
Commit: df84cceca8 (parent: f08c4f15a2)
3 changed files with 81 additions and 34 deletions


```diff
@@ -1048,7 +1048,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 shared_experts=shared_experts)
-        elif fused_moe_state == FusedMoEState.AllGather:
+        elif fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
+        ]:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
```
```diff
@@ -1225,6 +1227,22 @@ class AscendFusedMoE(FusedMoE):
             self.tp_group = get_tp_group().device_group
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            get_dp_group().broadcast(buffer[start:end, :], idx)
+        return buffer
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
```
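
The new `naive_multicast` concatenates the per-rank batches without padding: each rank copies its rows into its own slice of a buffer sized by the cumulative token counts, then every slice is broadcast from the rank that owns it. A single-process sketch of the resulting layout, with the broadcasts replaced by direct copies so it runs standalone (counts are illustrative):

```python
import torch

num_tokens_per_rank = [3, 1, 2]                   # tokens on each DP rank
cu = torch.tensor(num_tokens_per_rank).cumsum(0)  # tensor([3, 4, 6])
hidden = 4
per_rank_x = [torch.full((n, hidden), float(r))   # rank r's local activations
              for r, n in enumerate(num_tokens_per_rank)]

buffer = torch.empty((int(cu[-1]), hidden))
for idx in range(len(num_tokens_per_rank)):
    start = 0 if idx == 0 else int(cu[idx - 1])
    end = int(cu[idx])
    # on device, get_dp_group().broadcast(buffer[start:end, :], idx) fills
    # this slice from rank idx; here we copy directly
    buffer[start:end, :] = per_rank_x[idx]

print(buffer[:, 0])  # rows grouped by rank: 0,0,0,1,2,2 (no padding rows)
```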
```diff
@@ -1250,9 +1268,10 @@ class AscendFusedMoE(FusedMoE):
             shared_hidden_states = shared_experts(hidden_states)
 
         tp_size = get_tensor_model_parallel_world_size()
-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
```
```diff
@@ -1267,21 +1286,31 @@ class AscendFusedMoE(FusedMoE):
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
-        if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            # NOTE: When in torchair graph, it has been padded in model_runner_v1
-            if not self.torchair_graph_enabled or is_prefill:
-                attn_metadata = get_forward_context().attn_metadata
-                if attn_metadata is not None:
-                    max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
-                    if num_tokens < max_num_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-                        router_logits = nn.functional.pad(
-                            router_logits,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-            hidden_states = get_dp_group().all_gather(hidden_states, 0)
-            router_logits = get_dp_group().all_gather(router_logits, 0)
+        if self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.AllGather:
+                # NOTE: When in torchair graph, it has been padded in model_runner_v1
+                if not self.torchair_graph_enabled:
+                    attn_metadata = get_forward_context().attn_metadata
+                    if attn_metadata is not None:
+                        max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
+                        if num_tokens < max_num_tokens_across_dp:
+                            hidden_states = nn.functional.pad(
+                                hidden_states,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                            router_logits = nn.functional.pad(
+                                router_logits,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                router_logits = get_dp_group().all_gather(router_logits, 0)
+            elif fused_moe_state == FusedMoEState.NaiveMulticast:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+                hidden_states = self.naive_multicast(hidden_states,
+                                                     cu_tokens_across_dp_cpu)
+                router_logits = self.naive_multicast(router_logits,
+                                                     cu_tokens_across_dp_cpu)
 
         # Matrix multiply.
         e_hidden_states = self.quant_method.apply(
```
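
On the dispatch side the multicast path reads `cu_tokens_across_dp_cpu` from the forward context's `dp_metadata` instead of padding to `max_num_tokens_across_dp`. Judging from how the diff uses it (an assumption, not confirmed by this patch), it is the CPU-side cumulative sum of per-rank token counts:

```python
import torch

# Hypothetical per-rank token counts for one step; the real values come
# from dp_metadata, this construction is only for illustration.
num_tokens_per_rank = torch.tensor([512, 1, 1, 1])
cu_tokens_across_dp_cpu = num_tokens_per_rank.cumsum(dim=0)
print(cu_tokens_across_dp_cpu)  # tensor([512, 513, 514, 515])
# buffer size is cu[-1] = 515 rows instead of 4 * 512 = 2048 padded rows
```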
```diff
@@ -1310,28 +1339,40 @@ class AscendFusedMoE(FusedMoE):
         if isinstance(e_hidden_states, tuple):
             e_hidden_states, shared_hidden_states = e_hidden_states
 
-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
             if num_tokens < tp_size:
                 final_hidden_states = final_hidden_states[:num_tokens]
             dispose_tensor(e_hidden_states)
-        elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                e_hidden_states,
-                "sum",
-                scatter_dim=0,
-                group=get_dp_group().device_group)
-            final_hidden_states = final_hidden_states[:num_tokens]
-            dispose_tensor(e_hidden_states)
+        elif self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.NaiveMulticast:
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                    self.dp_rank - 1]
+                end = cu_tokens_across_dp_cpu[self.dp_rank]
+                final_hidden_states = get_dp_group().all_reduce(
+                    e_hidden_states)
+                final_hidden_states = final_hidden_states[start:end, :]
+                dispose_tensor(e_hidden_states)
+            elif fused_moe_state == FusedMoEState.AllGather:
+                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    e_hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_dp_group().device_group)
+                final_hidden_states = final_hidden_states[:num_tokens]
+                dispose_tensor(e_hidden_states)
         else:
             final_hidden_states = e_hidden_states
 
-        if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
-                            or fused_moe_state == FusedMoEState.AllGatherEP):
+        if tp_size > 1 and fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ]:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
```
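
The combine side mirrors the dispatch: rather than reduce-scattering a padded batch, the multicast path all-reduces the expert outputs over the DP group and slices each rank's own rows back out with the same cumulative offsets. A sketch of that slice arithmetic (the partial outputs are stand-ins; how they arise from sharded expert weights is outside this snippet):

```python
import torch

cu = torch.tensor([3, 4, 6])            # cumulative tokens across 3 DP ranks
total, hidden = int(cu[-1]), 4

# stand-ins for each rank's partial expert output over all 6 rows; the
# partials sum to the full MoE output across ranks
partials = [torch.full((total, hidden), 1.0),
            torch.full((total, hidden), 2.0),
            torch.full((total, hidden), 3.0)]
reduced = sum(partials)                 # models get_dp_group().all_reduce(...)

dp_rank = 1                             # this rank originally contributed row 3
start = 0 if dp_rank == 0 else int(cu[dp_rank - 1])
end = int(cu[dp_rank])
final = reduced[start:end, :]           # shape (1, 4): only rank 1's token
print(final.shape, final[0, 0].item())  # torch.Size([1, 4]) 6.0
```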


```diff
@@ -780,7 +780,9 @@ class AscendW8A8DynamicFusedMoEMethod:
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts)
-        elif fused_moe_state == FusedMoEState.AllGather:
+        elif fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
+        ]:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
```


```diff
@@ -419,6 +419,7 @@ class FusedMoEState(Enum):
     All2All = 1
     MC2 = 2
     AllGatherEP = 3
+    NaiveMulticast = 4
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
```
```diff
@@ -430,7 +431,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
             and is_deepseek_v3_r1):
         return FusedMoEState.AllGatherEP
     elif ep_size == 1:
-        return FusedMoEState.AllGather
+        if with_prefill:
+            return FusedMoEState.NaiveMulticast
+        else:
+            return FusedMoEState.AllGather
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
```
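
A compressed restatement of the branch this hunk changes, for quick reference; the real `get_fused_moe_state` also takes `is_deepseek_v3_r1` and has the `AllGatherEP` branch, omitted here:

```python
def select_state(ep_size: int, with_prefill: bool) -> str:
    if ep_size == 1:
        # new: prefill batches take the multicast path, pure decode keeps all-gather
        return "NaiveMulticast" if with_prefill else "AllGather"
    if ep_size < 16 or with_prefill:
        return "All2All"  # mc2 needs ep_size >= 16
    return "MC2"

assert select_state(1, True) == "NaiveMulticast"
assert select_state(1, False) == "AllGather"
```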