[Feat] Multi-stream for eplb heat collection and aggregation (#4214)

### What this PR does / why we need it?
This PR runs EPLB heat (expert load) collection and aggregation on a dedicated NPU stream, so the per-layer load accumulation in `AscendFusedMoE` and the cross-rank aggregation in `EplbUpdator` overlap with the main forward stream instead of running on it.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0
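At a high level, the change moves the expert-load update onto a dedicated NPU side stream and fences it against the forward (current) stream on both sides. Below is a minimal sketch of that pattern using the same primitives the diffs rely on (`torch.npu.current_stream`, `Stream.wait_stream`, and vllm_ascend's `npu_stream_switch`); the helper name `update_moe_load` is illustrative, not part of the PR:

```python
import torch
import torch_npu  # provides the torch.npu backend

from vllm_ascend.utils import npu_stream_switch


def update_moe_load(moe_load: torch.Tensor,
                    expert_tokens: torch.Tensor,
                    side_stream: torch_npu.npu.Stream) -> None:
    """Accumulate expert token counts on a side stream so the main
    forward stream is not blocked by the bookkeeping."""
    cur_stream = torch.npu.current_stream()
    # The side stream must not read expert_tokens before the forward
    # stream has finished producing it.
    side_stream.wait_stream(cur_stream)
    with npu_stream_switch(side_stream):
        moe_load += expert_tokens  # enqueued on the side stream
    # Later forward-stream work that consumes moe_load has to wait for
    # the side-stream update to complete.
    cur_stream.wait_stream(side_stream)
```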

---------

Signed-off-by: daishixun <dsxsteven@sina.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Author: dsxsteven
Committed: 2025-12-09 16:16:55 +08:00 (committed by GitHub)
Commit: 9a885d08d0 (parent dda027e680)
3 changed files with 38 additions and 15 deletions


@@ -23,6 +23,8 @@ from vllm.logger import logger
 from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
+from vllm_ascend.eplb.utils import moe_load_async_stream
+from vllm_ascend.utils import npu_stream_switch


 class EplbUpdator:

@@ -153,6 +155,7 @@ class EplbUpdator:
         self._gather_buffer = None
         if dist.is_initialized():
+            with npu_stream_switch(moe_load_async_stream()):
                 self.world_size = dist.get_world_size()
                 self.device = local_load.device
                 if self._gather_buffer is None:
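In this hunk, the aggregation path of `EplbUpdator` is wrapped in `npu_stream_switch(moe_load_async_stream())`, so the cross-rank gather of per-rank loads is issued on the same side stream used for collection rather than on the default stream. A hedged sketch of what that path can look like; the gather-buffer handling and the use of `dist.all_gather_into_tensor` are assumptions for illustration, only the stream-switch wrapper is taken from the diff:

```python
import torch
import torch.distributed as dist

from vllm_ascend.eplb.utils import moe_load_async_stream
from vllm_ascend.utils import npu_stream_switch


def gather_moe_load(local_load: torch.Tensor) -> torch.Tensor:
    """Illustrative aggregation step: collect every rank's expert-load
    tensor into one buffer on the dedicated load-collection stream."""
    with npu_stream_switch(moe_load_async_stream()):
        world_size = dist.get_world_size()
        gather_buffer = torch.empty((world_size, *local_load.shape),
                                    dtype=local_load.dtype,
                                    device=local_load.device)
        # Work issued inside the switch runs on the side stream,
        # overlapping with the forward stream.
        dist.all_gather_into_tensor(gather_buffer, local_load)
    return gather_buffer
```

Keeping both collection and aggregation on the one side stream also orders them with respect to each other without extra synchronization against the forward stream.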


@@ -18,6 +18,9 @@
 import types

 import torch
+import torch_npu
+
+_MOE_LOAD_ASYNC_STREAM = None


 def get_expert_map(self, layer_id):
@@ -75,3 +78,12 @@ def model_register(model, model_config):
         model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
     else:
         raise NotImplementedError("EPLB is not supported.")
+
+
+def moe_load_async_stream() -> torch_npu.npu.Stream:
+    global _MOE_LOAD_ASYNC_STREAM
+    if _MOE_LOAD_ASYNC_STREAM is None:
+        # Lazily create the dedicated stream for MoE load collection
+        # the first time it is requested.
+        _MOE_LOAD_ASYNC_STREAM = torch_npu.npu.Stream()
+    return _MOE_LOAD_ASYNC_STREAM
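`moe_load_async_stream()` caches a single stream at module scope, so the fused-MoE layers that collect loads and the `EplbUpdator` that aggregates them all enqueue onto the same side stream. A short usage sketch of the lazy-singleton behaviour:

```python
import torch_npu

from vllm_ascend.eplb.utils import moe_load_async_stream

# The first call creates the dedicated NPU stream; later calls return
# the exact same Stream object, so all callers share one side stream.
stream_a = moe_load_async_stream()
stream_b = moe_load_async_stream()
assert stream_a is stream_b
assert isinstance(stream_a, torch_npu.npu.Stream)
```

Creating the stream lazily also avoids allocating an NPU stream at import time, before the worker process has bound to its device.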


@@ -36,6 +36,7 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
+from vllm_ascend.eplb.utils import moe_load_async_stream
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method

@@ -368,8 +369,15 @@ class AscendFusedMoE(FusedMoE):
         if isinstance(final_hidden_states, tuple):
             final_hidden_states, group_list_type, expert_tokens = final_hidden_states
         if self.dynamic_eplb:
+            moe_load_stream = moe_load_async_stream()
+            cur_stream = torch.npu.current_stream()
+            moe_load_stream.wait_stream(cur_stream)
+            with npu_stream_switch(moe_load_stream):
                 self.moe_load += expert_tokens if group_list_type == 1 else \
                     torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+            cur_stream.wait_stream(moe_load_stream)

         final_hidden_states = forward_context.moe_comm_method.finalize(
             hidden_states=final_hidden_states,
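One detail worth calling out in the accumulation expression: when `group_list_type != 1`, `expert_tokens` holds cumulative token counts, so the code first converts it back to per-expert counts by differencing adjacent entries. A small numeric illustration of that `torch.cat` expression (the values are made up):

```python
import torch

# Cumulative counts: experts received 3, 0, 5 and 2 tokens in turn.
expert_tokens = torch.tensor([3, 3, 8, 10])

per_expert = torch.cat(
    [expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
print(per_expert)  # tensor([3, 0, 5, 2])
```

The surrounding `wait_stream` calls fence the update in both directions: the side stream waits until the forward stream has produced `expert_tokens`, and the forward stream then waits for the side-stream write so later reads of `self.moe_load` observe the accumulated value.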