[Feature] Support moe multi-stream for aclgraph. (#2946)
This PR moves the computation of shared experts onto a separate stream,
so that it overlaps with the computation of the routed experts.
- vLLM version: v0.10.2
- vLLM main:
fbd6523ac0
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import atexit
|
||||
import functools
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
@@ -321,7 +321,9 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
|
||||
# TODO: Find out whether we need to take into account the pp_size
|
||||
parallel_factor = 1 + num_comm_groups + int(
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.enable_expert_parallel) + int(
|
||||
vllm_config.additional_config.get(
|
||||
"multistream_overlap_shared_expert", False))
|
||||
if is_moe_model(vllm_config):
|
||||
parallel_factor += (parallel_config.data_parallel_size > 1)
|
||||
# Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
|
||||
@@ -617,3 +619,16 @@ def weak_ref_tensors(
|
||||
if isinstance(tensors, tuple):
|
||||
return tuple(weak_ref_tensor(t) for t in tensors)
|
||||
raise ValueError("Invalid type for tensors")
|
||||
|
||||
|
||||
def npu_stream_switch(target_stream: torch.npu.Stream,
                      *,
                      enabled: bool = True):
    """
    Return a context manager that optionally switches the NPU stream.

    Args:
        target_stream: Stream to switch to inside the ``with`` block.
            Only consulted when ``enabled`` is True.
        enabled: When False, no stream switch is requested and a no-op
            ``nullcontext()`` is returned instead.

    Returns:
        ``torch.npu.stream(target_stream)`` when ``enabled`` is True,
        otherwise ``nullcontext()``.

    Raises:
        ValueError: If ``enabled`` is True but ``target_stream`` is None.
    """
    if not enabled:
        return nullcontext()
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # under ``python -O``, which would let a missing stream through unchecked.
    if target_stream is None:
        raise ValueError(
            "target_stream must not be None when stream switching is enabled")
    return torch.npu.stream(target_stream)
|
||||
|
||||
Reference in New Issue
Block a user