[perf] replace all_reduce for kv_consumer and support different num_tokens among all ranks (#4983)
Picked from https://github.com/vllm-project/vllm-ascend/pull/4736 to fix the merge conflict. ### What this PR does / why we need it? Currently, the all_reduce operation in _sync_metadata_across_dp is performed with the gloo backend, which is extremely time-consuming when DPEngineCores are on different nodes. The cost of this operation cannot be hidden by async scheduling in multi-node scenarios with speculative decoding (e.g., EAGLE, MTP). This PR eliminates the all_reduce operation for D nodes and changes the input parameters of the MoEDispatch & MoeCombine operators so that MC2EP supports different num_tokens across all ranks. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Tested with PD disaggregation (2P: DP2TP8EP16, 1D: DP8TP4EP32) scenarios while enabling async scheduling. This PR can remove the cross-node all_reduce with the gloo backend and further reduce latency while keeping accuracy correct. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -70,7 +70,6 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.assertFalse(self.dispatcher.with_quant)
|
||||
self.assertTrue(self.dispatcher.enable_dispatch_v2)
|
||||
self.assertTrue(self.dispatcher.need_extra_args)
|
||||
self.assertTrue(self.dispatcher.a3_need_extra_args)
|
||||
|
||||
def test_get_dispatch_mc2_kwargs_without_quant(self):
|
||||
hidden_states = torch.randn(10, 128)
|
||||
|
||||
@@ -116,7 +116,11 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
# to avoid accumulating too many tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
random_matrix = torch.rand(topk_ids.size(0),
|
||||
global_num_experts,
|
||||
device=topk_ids.device)
|
||||
topk_ids = torch.argsort(
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
@@ -100,15 +101,31 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.need_extra_args = (
|
||||
get_ascend_device_type() == AscendDeviceType._910_93)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_device_type() == AscendDeviceType._910_93
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||
self.with_quant = False
|
||||
|
||||
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
|
||||
# dispatch & combine operators with different input num_tokens per rank.
|
||||
vllm_config = get_current_vllm_config()
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
compilation_config = vllm_config.compilation_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
uniform_decode_query_len = 1 if not speculative_config else \
|
||||
1 + speculative_config.num_speculative_tokens
|
||||
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
|
||||
0)
|
||||
max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
|
||||
if compilation_config.cudagraph_capture_sizes:
|
||||
max_num_tokens = compilation_config.max_cudagraph_capture_size
|
||||
else:
|
||||
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
|
||||
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
||||
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -130,7 +147,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
"global_bs": self.global_bs,
|
||||
"expert_token_nums_type": 0,
|
||||
}
|
||||
|
||||
@@ -147,10 +164,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
if self.need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales":
|
||||
@@ -214,7 +227,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"mc2_mask": mc2_mask,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"tp_recv_counts": tp_recv_counts,
|
||||
@@ -243,7 +255,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
ep_recv_counts = context_metadata["ep_recv_counts"]
|
||||
tp_recv_counts = context_metadata["tp_recv_counts"]
|
||||
assist_info_for_combine = context_metadata["assist_info_for_combine"]
|
||||
mc2_mask = context_metadata["mc2_mask"]
|
||||
expand_scales = context_metadata["expand_scales"]
|
||||
|
||||
assert expert_map is not None
|
||||
@@ -256,7 +267,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
"global_bs": self.global_bs,
|
||||
}
|
||||
|
||||
if self.with_quant:
|
||||
@@ -285,9 +296,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage3_kwargs["x_active_mask"] = mc2_mask
|
||||
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
|
||||
@@ -371,8 +371,12 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
# to avoid accumulating too many tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(
|
||||
topk_ids, 0, global_num_experts - global_redundant_expert_num)
|
||||
random_matrix = torch.rand(topk_ids.size(0),
|
||||
global_num_experts -
|
||||
global_redundant_expert_num,
|
||||
device=topk_ids.device)
|
||||
topk_ids = torch.argsort(
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
|
||||
@@ -215,8 +215,12 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
# to avoid accumulating too many tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(
|
||||
topk_ids, 0, global_num_experts - global_redundant_expert_num)
|
||||
random_matrix = torch.rand(topk_ids.size(0),
|
||||
global_num_experts -
|
||||
global_redundant_expert_num,
|
||||
device=topk_ids.device)
|
||||
topk_ids = torch.argsort(
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
topk_weights = topk_weights.to(self.in_dtype)
|
||||
|
||||
|
||||
@@ -821,10 +821,7 @@ class MtpProposer(Proposer):
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
if not self.runner.with_prefill:
|
||||
max_num_reqs_across_dp = num_input_tokens
|
||||
else:
|
||||
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
|
||||
last_token_indices = nn.functional.pad(
|
||||
last_token_indices,
|
||||
(0, max_num_reqs_across_dp - num_indices))
|
||||
|
||||
@@ -428,6 +428,25 @@ class NPUModelRunner(GPUModelRunner):
|
||||
def _use_aclgraph(self) -> bool:
|
||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager
|
||||
|
||||
def _skip_all_reduce_acorss_dp_group(self) -> bool:
|
||||
# NOTE: We can skip the all_reduce operation and avoid padding tokens
|
||||
# to max_tokens_across_dp in D nodes. In MoE models, we must ensure that
|
||||
# num_tokens DOES NOT exceed mc2_tokens_capacity which means that moe_comm_method
|
||||
# of each rank is MC2. For dense models, skipping all_reduce is not necessary:
|
||||
# collective communication is not time-consuming there, because dp_size in dense
|
||||
# model deployments is always small and the all_reduce can be overlapped by async scheduling.
|
||||
if not is_moe_model(self.vllm_config):
|
||||
return False
|
||||
if self.compilation_config.cudagraph_capture_sizes:
|
||||
potential_max_num_tokens = self.compilation_config.max_cudagraph_capture_size
|
||||
else:
|
||||
potential_max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
|
||||
# To ensure skipping all_reduce across dp group is valid, we need to ensure that
|
||||
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
|
||||
# nodes. So here we check whether recompute_scheduler_enable is True.
|
||||
return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and self._select_moe_comm_method(
|
||||
potential_max_num_tokens) == MoECommType.MC2
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
|
||||
@@ -439,6 +458,14 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# immediately once the other two flags are no longer needed.
|
||||
if self.dp_size == 1:
|
||||
return num_tokens, None, with_prefill
|
||||
|
||||
if self._skip_all_reduce_acorss_dp_group():
|
||||
num_tokens_after_padding = torch.tensor([num_tokens] *
|
||||
self.dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return num_tokens, num_tokens_after_padding, with_prefill
|
||||
|
||||
# Sync num_tokens, with_prefill across dp ranks
|
||||
num_tokens_tensor = torch.tensor([
|
||||
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
|
||||
@@ -1097,7 +1124,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
if lmhead_tp_enable():
|
||||
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
|
||||
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
|
||||
logits_indices = nn.functional.pad(
|
||||
logits_indices,
|
||||
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
|
||||
@@ -2190,7 +2217,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
need_dummy_logits = (not self.in_profile_run
|
||||
and lmhead_tp_enable())
|
||||
max_num_reqs_across_dp = num_tokens_padded if not with_prefill else max_num_reqs
|
||||
max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len
|
||||
dummy_indices = torch.zeros(max_num_reqs_across_dp,
|
||||
dtype=torch.int32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user