perf: use multicast to avoid padding decode request to prefill size (#1555)
### What this PR does / why we need it?
perf: use multicast to avoid padding decode request to prefill size
### How was this patch tested?
- vLLM version: v0.9.1
- vLLM main:
1fd471e957
Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -1048,7 +1048,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
shared_experts=shared_experts)
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
elif fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -1225,6 +1227,22 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -1250,9 +1268,10 @@ class AscendFusedMoE(FusedMoE):
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||
and fused_moe_state != FusedMoEState.AllGatherEP
|
||||
and not replace_allreduce):
|
||||
if (tp_size > 1 and fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce):
|
||||
if num_tokens < tp_size:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
||||
@@ -1267,21 +1286,31 @@ class AscendFusedMoE(FusedMoE):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
|
||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||
if not self.torchair_graph_enabled or is_prefill:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None:
|
||||
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
||||
if num_tokens < max_num_tokens_across_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
if self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.AllGather:
|
||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||
if not self.torchair_graph_enabled:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None:
|
||||
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
||||
if num_tokens < max_num_tokens_across_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, 0, 0,
|
||||
max_num_tokens_across_dp - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0,
|
||||
max_num_tokens_across_dp - num_tokens))
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_dp_cpu)
|
||||
|
||||
# Matrix multiply.
|
||||
e_hidden_states = self.quant_method.apply(
|
||||
@@ -1310,28 +1339,40 @@ class AscendFusedMoE(FusedMoE):
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||
and fused_moe_state != FusedMoEState.AllGatherEP
|
||||
and not replace_allreduce):
|
||||
if (tp_size > 1 and fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce):
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_tokens < tp_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
|
||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
e_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
final_hidden_states = get_dp_group().all_reduce(
|
||||
e_hidden_states)
|
||||
final_hidden_states = final_hidden_states[start:end, :]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
e_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
|
||||
if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
|
||||
or fused_moe_state == FusedMoEState.AllGatherEP):
|
||||
if tp_size > 1 and fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
|
||||
@@ -780,7 +780,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts)
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
elif fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
|
||||
@@ -419,6 +419,7 @@ class FusedMoEState(Enum):
|
||||
All2All = 1
|
||||
MC2 = 2
|
||||
AllGatherEP = 3
|
||||
NaiveMulticast = 4
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
@@ -430,7 +431,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
and is_deepseek_v3_r1):
|
||||
return FusedMoEState.AllGatherEP
|
||||
elif ep_size == 1:
|
||||
return FusedMoEState.AllGather
|
||||
if with_prefill:
|
||||
return FusedMoEState.NaiveMulticast
|
||||
else:
|
||||
return FusedMoEState.AllGather
|
||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||
elif ep_size < 16 or with_prefill:
|
||||
return FusedMoEState.All2All
|
||||
|
||||
Reference in New Issue
Block a user