[Feat] Multi-stream for eplb heat collection and aggregation (#4214)
### What this PR does / why we need it?
This PR adds multi-stream support for EPLB heat collection and aggregation: the expert-load (heat) statistics are accumulated and gathered on a dedicated NPU side stream, so this bookkeeping overlaps with the main compute stream instead of blocking it.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

---------

Signed-off-by: daishixun <dsxsteven@sina.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
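In short, both the per-forward-pass heat accumulation and the cross-rank aggregation are wrapped in a side stream via `npu_stream_switch`. A minimal sketch of the pattern follows; the `record_moe_load` wrapper is hypothetical and not part of the diff, while the helper names mirror the code below:

```python
# Conceptual sketch only: run EPLB heat bookkeeping on a dedicated NPU side
# stream so it overlaps with the main compute stream, then re-synchronize
# before the result is consumed again.
import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)

from vllm_ascend.eplb.utils import moe_load_async_stream
from vllm_ascend.utils import npu_stream_switch


def record_moe_load(moe_load: torch.Tensor, expert_tokens: torch.Tensor) -> None:
    side_stream = moe_load_async_stream()   # lazily created singleton stream
    cur_stream = torch.npu.current_stream()

    side_stream.wait_stream(cur_stream)     # make sure expert_tokens is ready
    with npu_stream_switch(side_stream):
        moe_load += expert_tokens           # accumulation runs on the side stream
    cur_stream.wait_stream(side_stream)     # main stream waits before reusing moe_load
```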
@@ -23,6 +23,8 @@ from vllm.logger import logger
 from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
+from vllm_ascend.eplb.utils import moe_load_async_stream
+from vllm_ascend.utils import npu_stream_switch
 
 
 class EplbUpdator:
@@ -153,21 +155,22 @@ class EplbUpdator:
         self._gather_buffer = None
         if dist.is_initialized():
-            self.world_size = dist.get_world_size()
-            self.device = local_load.device
-            if self._gather_buffer is None:
-                shape = (self.world_size, *local_load.shape)
-                self._gather_buffer = torch.empty(shape,
-                                                  dtype=local_load.dtype,
-                                                  device=self.device)
-
-            dist.all_gather_into_tensor(self._gather_buffer, local_load)
-
-            moe_load = self._gather_buffer.permute(1, 0, 2)
-            self.shared_dict["moe_load"] = moe_load.cpu()
-            logger.debug(
-                f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
-            )
+            with npu_stream_switch(moe_load_async_stream()):
+                self.world_size = dist.get_world_size()
+                self.device = local_load.device
+                if self._gather_buffer is None:
+                    shape = (self.world_size, *local_load.shape)
+                    self._gather_buffer = torch.empty(shape,
+                                                      dtype=local_load.dtype,
+                                                      device=self.device)
+
+                dist.all_gather_into_tensor(self._gather_buffer, local_load)
+
+                moe_load = self._gather_buffer.permute(1, 0, 2)
+                self.shared_dict["moe_load"] = moe_load.cpu()
+                logger.debug(
+                    f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
+                )
         else:
             moe_load = local_load.unsqueeze(1)
             self.shared_dict["moe_load"] = moe_load.cpu()
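Not part of the diff: a shape-only sketch of the aggregation above, assuming `local_load` is a per-rank `(num_layers, num_experts)` tensor (a CPU stand-in replaces the collective):

```python
# Shape-only sketch: torch.stack stands in for dist.all_gather_into_tensor,
# which fills a (world_size, num_layers, num_experts) buffer; permute(1, 0, 2)
# reorders it to (num_layers, world_size, num_experts) so each layer's
# per-rank loads are grouped together before being handed to the EPLB worker
# through shared_dict["moe_load"].
import torch

world_size, num_layers, num_experts = 4, 2, 8            # assumed sizes
local_load = torch.randint(0, 100, (num_layers, num_experts))

gather_buffer = torch.stack([local_load] * world_size)   # (4, 2, 8)
moe_load = gather_buffer.permute(1, 0, 2)                 # (2, 4, 8)
assert moe_load.shape == (num_layers, world_size, num_experts)
```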
@@ -18,6 +18,9 @@
 import types
 
 import torch
+import torch_npu
+
+_MOE_LOAD_ASYNC_STREAM = None
 
 
 def get_expert_map(self, layer_id):
@@ -75,3 +78,12 @@ def model_register(model, model_config):
         model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
     else:
         raise NotImplementedError("EPLB is not supported.")
+
+
+def moe_load_async_stream() -> torch_npu.npu.Stream:
+    global _MOE_LOAD_ASYNC_STREAM
+    if _MOE_LOAD_ASYNC_STREAM is None:
+        # Lazily create the dedicated side stream used for moe_load
+        # collection the first time this helper is called.
+        _MOE_LOAD_ASYNC_STREAM = torch_npu.npu.Stream()
+    return _MOE_LOAD_ASYNC_STREAM
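A quick usage note on the new helper (sketch assumes a torch_npu / Ascend environment): because the stream is a module-level singleton, every call site gets the same object, so the `wait_stream` pairs in different files all order their work against one side stream.

```python
# Minimal check of the lazy-singleton behaviour of moe_load_async_stream():
# the first call creates the side stream, later calls return the same object.
from vllm_ascend.eplb.utils import moe_load_async_stream

s1 = moe_load_async_stream()
s2 = moe_load_async_stream()
assert s1 is s2  # one shared side stream for all EPLB heat collection
```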
@@ -36,6 +36,7 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
+from vllm_ascend.eplb.utils import moe_load_async_stream
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
@@ -368,8 +369,15 @@ class AscendFusedMoE(FusedMoE):
         if isinstance(final_hidden_states, tuple):
             final_hidden_states, group_list_type, expert_tokens = final_hidden_states
             if self.dynamic_eplb:
-                self.moe_load += expert_tokens if group_list_type == 1 else \
-                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+                moe_load_stream = moe_load_async_stream()
+                cur_stream = torch.npu.current_stream()
+
+                moe_load_stream.wait_stream(cur_stream)
+                with npu_stream_switch(moe_load_stream):
+                    self.moe_load += expert_tokens if group_list_type == 1 else \
+                        torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+                cur_stream.wait_stream(moe_load_stream)
 
         final_hidden_states = forward_context.moe_comm_method.finalize(
             hidden_states=final_hidden_states,
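Not part of the diff: a small CPU check of the load-delta expression that now runs on the side stream. When `group_list_type != 1`, `expert_tokens` is assumed to hold cumulative token counts, and the `torch.cat` expression recovers per-expert counts before they are added to `self.moe_load`:

```python
import torch

# Assumed cumulative layout: 3 tokens routed up to expert 0, 5 up to expert 1,
# 9 up to expert 2 in total.
expert_tokens = torch.tensor([3, 5, 9])
per_expert = torch.cat(
    [expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
print(per_expert)  # tensor([3, 2, 4]) -> per-expert token counts
```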