perf: use multicast to avoid padding decode request to prefill size (#1555)
### What this PR does / why we need it?
Use a naive multicast across the data-parallel group when `ep_size == 1` and prefill is active: tokens are gathered into a buffer sized by the cumulative token counts (`cu_tokens_across_dp_cpu`) instead of padding every rank's batch, decode requests included, up to the largest (prefill) batch before an all-gather.
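For intuition, a back-of-the-envelope sketch of the padding this avoids; the token counts are hypothetical, only the sizing logic mirrors the patch:

```python
# Hypothetical: one DP rank prefills 4096 tokens while three others each decode 1 token.
tokens_per_rank = [4096, 1, 1, 1]

# AllGather path: every rank pads its batch to the largest batch in the DP group
# (max_num_tokens_across_dp) before the gather.
allgather_rows = max(tokens_per_rank) * len(tokens_per_rank)  # 16384 rows

# NaiveMulticast path: the shared buffer holds exactly the real tokens.
multicast_rows = sum(tokens_per_rank)  # 4099 rows

print(allgather_rows, multicast_rows)
```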
### How was this patch tested?
- vLLM version: v0.9.1
- vLLM main: 1fd471e957
Signed-off-by: boying <897013703@qq.com>
@@ -1048,7 +1048,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 shared_experts=shared_experts)
-        elif fused_moe_state == FusedMoEState.AllGather:
+        elif fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
+        ]:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
@@ -1225,6 +1227,22 @@ class AscendFusedMoE(FusedMoE):
         self.tp_group = get_tp_group().device_group
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            get_dp_group().broadcast(buffer[start:end, :], idx)
+        return buffer
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
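For illustration, a single-process sketch of how the cumulative token counts drive the buffer layout that `naive_multicast` fills; the tensor sizes are hypothetical, and in a real DP group the loop of `get_dp_group().broadcast(...)` calls is what populates the other slices:

```python
import torch

# Hypothetical example: 3 DP ranks contribute 2, 5 and 3 tokens respectively.
tokens_per_rank = torch.tensor([2, 5, 3])
cu_tokens_across_dp_cpu = torch.cumsum(tokens_per_rank, dim=0)  # tensor([ 2,  7, 10])

dp_rank, hidden_size = 1, 4
x = torch.randn(int(tokens_per_rank[dp_rank]), hidden_size)  # this rank's tokens only

# The buffer covers the exact global token count (10 rows) rather than
# dp_size * max_tokens (15 rows), which padding + all_gather would require.
buffer = torch.empty((int(cu_tokens_across_dp_cpu[-1]), hidden_size), dtype=x.dtype)

# Each rank copies its own slice; the per-rank broadcasts then fill every
# other slice from its owning rank.
start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
end = int(cu_tokens_across_dp_cpu[dp_rank])
buffer[start:end, :].copy_(x)
print(buffer.shape)  # torch.Size([10, 4])
```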
@@ -1250,9 +1268,10 @@ class AscendFusedMoE(FusedMoE):
             shared_hidden_states = shared_experts(hidden_states)
 
         tp_size = get_tensor_model_parallel_world_size()
-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1267,21 +1286,31 @@ class AscendFusedMoE(FusedMoE):
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
-        if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            # NOTE: When in torchair graph, it has been padded in model_runner_v1
-            if not self.torchair_graph_enabled or is_prefill:
-                attn_metadata = get_forward_context().attn_metadata
-                if attn_metadata is not None:
-                    max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
-                    if num_tokens < max_num_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-                        router_logits = nn.functional.pad(
-                            router_logits,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-            hidden_states = get_dp_group().all_gather(hidden_states, 0)
-            router_logits = get_dp_group().all_gather(router_logits, 0)
+        if self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.AllGather:
+                # NOTE: When in torchair graph, it has been padded in model_runner_v1
+                if not self.torchair_graph_enabled:
+                    attn_metadata = get_forward_context().attn_metadata
+                    if attn_metadata is not None:
+                        max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
+                        if num_tokens < max_num_tokens_across_dp:
+                            hidden_states = nn.functional.pad(
+                                hidden_states,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                            router_logits = nn.functional.pad(
+                                router_logits,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                router_logits = get_dp_group().all_gather(router_logits, 0)
+            elif fused_moe_state == FusedMoEState.NaiveMulticast:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+                hidden_states = self.naive_multicast(hidden_states,
+                                                     cu_tokens_across_dp_cpu)
+                router_logits = self.naive_multicast(router_logits,
+                                                     cu_tokens_across_dp_cpu)
 
         # Matrix multiply.
         e_hidden_states = self.quant_method.apply(
@@ -1310,28 +1339,40 @@ class AscendFusedMoE(FusedMoE):
         if isinstance(e_hidden_states, tuple):
             e_hidden_states, shared_hidden_states = e_hidden_states
 
-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
             if num_tokens < tp_size:
                 final_hidden_states = final_hidden_states[:num_tokens]
             dispose_tensor(e_hidden_states)
-        elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                e_hidden_states,
-                "sum",
-                scatter_dim=0,
-                group=get_dp_group().device_group)
-            final_hidden_states = final_hidden_states[:num_tokens]
-            dispose_tensor(e_hidden_states)
+        elif self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.NaiveMulticast:
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                    self.dp_rank - 1]
+                end = cu_tokens_across_dp_cpu[self.dp_rank]
+                final_hidden_states = get_dp_group().all_reduce(
+                    e_hidden_states)
+                final_hidden_states = final_hidden_states[start:end, :]
+                dispose_tensor(e_hidden_states)
+            elif fused_moe_state == FusedMoEState.AllGather:
+                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    e_hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_dp_group().device_group)
+                final_hidden_states = final_hidden_states[:num_tokens]
+                dispose_tensor(e_hidden_states)
         else:
             final_hidden_states = e_hidden_states
 
-        if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
-                            or fused_moe_state == FusedMoEState.AllGatherEP):
+        if tp_size > 1 and fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ]:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
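A minimal single-process sketch of the combine step in the NaiveMulticast branch above, with ranks emulated locally and hypothetical sizes: after the multicast each rank holds (possibly partial) expert outputs for the full buffer, the DP all-reduce sums them, and slicing with the same cumulative offsets recovers this rank's own rows:

```python
import torch

tokens_per_rank = torch.tensor([2, 5, 3])
cu_tokens_across_dp_cpu = torch.cumsum(tokens_per_rank, dim=0)
dp_rank, hidden_size = 1, 4

# Stand-ins for the e_hidden_states each of the 3 DP ranks would produce for
# the full 10-row multicast buffer.
partial_outputs = [torch.randn(int(cu_tokens_across_dp_cpu[-1]), hidden_size)
                   for _ in range(len(tokens_per_rank))]

# get_dp_group().all_reduce(e_hidden_states) sums these across ranks;
# emulated here with a local sum.
final_hidden_states = torch.stack(partial_outputs).sum(dim=0)

# Each rank keeps only the rows belonging to its own requests.
start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
end = int(cu_tokens_across_dp_cpu[dp_rank])
final_hidden_states = final_hidden_states[start:end, :]
print(final_hidden_states.shape)  # torch.Size([5, 4])
```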
@@ -780,7 +780,9 @@ class AscendW8A8DynamicFusedMoEMethod:
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts)
-        elif fused_moe_state == FusedMoEState.AllGather:
+        elif fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
+        ]:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
@@ -419,6 +419,7 @@ class FusedMoEState(Enum):
     All2All = 1
     MC2 = 2
     AllGatherEP = 3
+    NaiveMulticast = 4
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
@@ -430,7 +431,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
             and is_deepseek_v3_r1):
         return FusedMoEState.AllGatherEP
     elif ep_size == 1:
-        return FusedMoEState.AllGather
+        if with_prefill:
+            return FusedMoEState.NaiveMulticast
+        else:
+            return FusedMoEState.AllGather
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All