[Fix] Fix update_aclgraph_sizes when running MoE models (#913)
### What this PR does / why we need it?

Fix update_aclgraph_sizes when running MoE models.

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
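In `update_aclgraph_sizes`, the ACL graph capture budget is divided by a parallel factor that previously counted only data and tensor parallelism. MoE models also create expert parallel (EP) and expert tensor parallel (ETP) communication groups, so this change derives `expert_parallel_size` and `expert_tensor_parallel_size` once in `NPUPlatform.check_and_update_config`, stores them on `parallel_config`, and has both workers and `update_aclgraph_sizes` read them from there instead of re-deriving them from `additional_config` with divergent defaults.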
`init_ascend_model_parallel` now takes `expert_parallel_size` and an optional `world_size` directly instead of the tensor/pipeline sizes:

```diff
@@ -22,18 +22,17 @@ def get_etp_group() -> GroupCoordinator:
 def init_ascend_model_parallel(
-    tensor_model_parallel_size: int = 1,
-    pipeline_model_parallel_size: int = 1,
+    expert_parallel_size: int = 1,
     expert_tensor_parallel_size: int = 1,
+    world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
     assert torch.distributed.is_initialized()
-    world_size: int = torch.distributed.get_world_size()
+    world_size = world_size or torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
-    num_expert_parallel_groups: int = expert_tensor_parallel_size
-    num_expert_tensor_parallel_groups: int = (world_size //
-                                              expert_tensor_parallel_size)
+    num_expert_parallel_groups = expert_tensor_parallel_size
+    num_expert_tensor_parallel_groups = expert_parallel_size
 
     global _EP
     group_ranks = []
```
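For intuition, here is a minimal, self-contained sketch of how EP and ETP groups can be laid out once the two sizes are given explicitly. The grouping below (contiguous ETP ranks, strided EP ranks) is an assumption for illustration; the actual `group_ranks` construction lives in the unchanged remainder of `init_ascend_model_parallel`.

```python
# Hypothetical illustration: derive EP/ETP rank groups for a world of
# ep_size * etp_size ranks. Not the vllm-ascend implementation itself.
def layout_groups(world_size: int, ep_size: int, etp_size: int):
    assert world_size == ep_size * etp_size
    # One ETP group per EP "slot": contiguous blocks of etp_size ranks.
    etp_groups = [
        list(range(i * etp_size, (i + 1) * etp_size))
        for i in range(ep_size)  # num_expert_tensor_parallel_groups == ep_size
    ]
    # One EP group per ETP "slot": ranks strided by etp_size.
    ep_groups = [
        list(range(i, world_size, etp_size))
        for i in range(etp_size)  # num_expert_parallel_groups == etp_size
    ]
    return ep_groups, etp_groups

if __name__ == "__main__":
    ep_groups, etp_groups = layout_groups(world_size=8, ep_size=4, etp_size=2)
    print("EP groups: ", ep_groups)   # [[0, 2, 4, 6], [1, 3, 5, 7]]
    print("ETP groups:", etp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```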
`NPUPlatform` now resolves the expert parallel sizes once, up front, and stores them on `parallel_config`:

```diff
@@ -119,6 +119,26 @@ class NPUPlatform(Platform):
         from vllm.config import CompilationLevel  # noqa: E402
         compilation_config = vllm_config.compilation_config
         model_config = vllm_config.model_config
+        additional_config = vllm_config.additional_config
+        parallel_config = vllm_config.parallel_config
+        cache_config = vllm_config.cache_config
+
+        if parallel_config:
+            # Default value for expert tensor parallel size
+            parallel_config.expert_tensor_parallel_size = parallel_config.tensor_parallel_size
+
+            # NOTE: When enable_expert_parallel is True, we follow vLLM convention:
+            # ep_size = world_size, which means expert_tensor_parallel_size must be 1
+            if (additional_config
+                    and "expert_tensor_parallel_size" in additional_config
+                    and not parallel_config.enable_expert_parallel):
+                parallel_config.expert_tensor_parallel_size = int(
+                    additional_config["expert_tensor_parallel_size"])
+
+            # Calculate expert parallel size based on world size
+            parallel_config.expert_parallel_size = (
+                parallel_config.world_size //
+                parallel_config.expert_tensor_parallel_size)
+
         if model_config is None:
             logger.warning("Model config is missing. This may indicate "
```
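To see what the new block computes, here is a standalone sketch of the same derivation with plain values. The tiny helper and the sample numbers are illustrative only; the parameter names mirror the diff.

```python
from typing import Optional

# Standalone sketch of the ETP/EP derivation added to
# NPUPlatform.check_and_update_config. Sample values are hypothetical.
def derive_expert_sizes(world_size: int,
                        tensor_parallel_size: int,
                        enable_expert_parallel: bool,
                        additional_config: Optional[dict]):
    # Default: expert tensor parallelism follows tensor parallelism.
    etp_size = tensor_parallel_size
    # An explicit override is honored only when expert parallelism is off;
    # with enable_expert_parallel, vLLM convention fixes ep_size = world_size.
    if (additional_config
            and "expert_tensor_parallel_size" in additional_config
            and not enable_expert_parallel):
        etp_size = int(additional_config["expert_tensor_parallel_size"])
    ep_size = world_size // etp_size
    return ep_size, etp_size

# world_size=8, TP=4, no override: ETP defaults to 4, so EP = 8 // 4 = 2.
print(derive_expert_sizes(8, 4, False, None))                                # (2, 4)
# Same world with ETP forced to 1: every rank is its own ETP group, EP = 8.
print(derive_expert_sizes(8, 4, False, {"expert_tensor_parallel_size": 1}))  # (8, 1)
```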
The graph-mode handling reuses the new `additional_config` local:

```diff
@@ -127,9 +147,9 @@ class NPUPlatform(Platform):
         else:
             enforce_eager = getattr(model_config, "enforce_eager", False)
 
-        if vllm_config.additional_config is not None:
-            enable_graph_mode = vllm_config.additional_config.get(
-                "enable_graph_mode", False)
+        if additional_config is not None:
+            enable_graph_mode = additional_config.get("enable_graph_mode",
+                                                      False)
             if enable_graph_mode:
                 if enforce_eager:
                     raise RuntimeError(
@@ -139,7 +159,7 @@ class NPUPlatform(Platform):
                     logger.warning(
                         "NPU graph mode is still experimental and not supported for V1 without mla currently, "
                         "it has been disabled automatically.")
-                    vllm_config.additional_config["enable_graph_mode"] = False
+                    additional_config["enable_graph_mode"] = False
             if model_config:
                 model_type = model_config.hf_config.model_type
                 if "deepseek" not in model_type:
```
The now-redundant local lookups of `parallel_config` and `cache_config` are dropped:

```diff
@@ -178,7 +198,6 @@ class NPUPlatform(Platform):
                 ["vllm.unified_ascend_attention_with_output"])
             update_aclgraph_sizes(vllm_config)
 
-        parallel_config = vllm_config.parallel_config
         if parallel_config and parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -190,7 +209,6 @@ class NPUPlatform(Platform):
             else:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
 
-        cache_config = vllm_config.cache_config
         if cache_config:
             if cache_config.block_size is None:
                 cache_config.block_size = 128
```
Likewise for `compilation_config` and `additional_config` in the V1 branch:

```diff
@@ -202,11 +220,10 @@ class NPUPlatform(Platform):
 
         if envs.VLLM_USE_V1:
             # Activate custom ops for v1.
-            vllm_config.compilation_config.custom_ops = ["all"]
+            compilation_config.custom_ops = ["all"]
             # If ascend_scheduler_config exists in additional_config,
             # extents original scheduler_config to use AscendScheduler.
 
-            additional_config = vllm_config.additional_config
             if additional_config and additional_config.get(
                     "ascend_scheduler_config", None) is not None:
                 additional_scheduler_config = additional_config.get(
```
The core fix in `update_aclgraph_sizes`: the parallel factor now also counts EP and ETP:

```diff
@@ -126,14 +126,16 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     original_sizes, compilation_config.cudagraph_capture_sizes = \
         compilation_config.cudagraph_capture_sizes, None
 
-    # Calculate parallel configuration factor (increases with DP or TP)
-    # TODO(Yizhou): This is a temporary solution, need to be improved
-    # in the future, taking into account the other parallel configurations.
+    # Calculate parallel configuration factor
     num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
     parallel_config = vllm_config.parallel_config
 
+    # TODO: Find out whether we need to take into account the pp_size
     parallel_factor = 1 + sum(size > 1 for size in [
-        parallel_config.data_parallel_size,
-        parallel_config.tensor_parallel_size
+        parallel_config.data_parallel_size_local,
+        parallel_config.tensor_parallel_size,
+        parallel_config.expert_parallel_size,
+        parallel_config.expert_tensor_parallel_size,
     ])
 
     # Calculate maximum supported batch sizes considering model architecture
```
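The effect on graph capture is easiest to see with numbers. A rough sketch, assuming the factor divides a fixed capture budget as the surrounding function suggests (the exact budget formula is outside this hunk):

```python
# Sketch of the parallel_factor change in update_aclgraph_sizes.
# Config values below are hypothetical examples, not defaults.
def parallel_factor(dp_local: int, tp: int, ep: int, etp: int) -> int:
    # New: every parallel dimension larger than 1 adds to the factor.
    return 1 + sum(size > 1 for size in [dp_local, tp, ep, etp])

def old_parallel_factor(dp: int, tp: int) -> int:
    # Old: only DP and TP were counted, so MoE groups were invisible.
    return 1 + sum(size > 1 for size in [dp, tp])

# A MoE deployment: world_size=8, TP=4 -> ETP=4, EP=2 (see the platform hunk).
old = old_parallel_factor(dp=1, tp=4)                 # 2
new = parallel_factor(dp_local=1, tp=4, ep=2, etp=4)  # 4
print(old, new)

# Each captured ACL graph also spans communication groups, so the larger
# factor shrinks how many batch sizes get captured, e.g. with a budget of
# 1024 "slots" and a 32-layer model:
budget, layers = 1024, 32
print(budget // (layers * old), budget // (layers * new))  # 16 vs 8
```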
The V0 worker passes the resolved sizes through instead of re-reading `additional_config`:

```diff
@@ -534,7 +534,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
             backend: str = "hccl") -> None:
         """Initialize the distributed environment."""
         parallel_config = self.parallel_config
-        additional_config = self.vllm_config.additional_config
         set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
         init_distributed_environment(parallel_config.world_size, rank,
                                      distributed_init_method, local_rank,
@@ -542,13 +541,11 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
             parallel_config.pipeline_parallel_size)
-        expert_tensor_parallel_size = 1
-        if additional_config:
-            expert_tensor_parallel_size = additional_config.get(
-                "expert_tensor_parallel_size", 1)
-        init_ascend_model_parallel(parallel_config.tensor_parallel_size,
-                                   parallel_config.pipeline_parallel_size,
-                                   expert_tensor_parallel_size)
+        init_ascend_model_parallel(
+            parallel_config.expert_parallel_size,
+            parallel_config.expert_tensor_parallel_size,
+            parallel_config.world_size,
+        )
         ensure_kv_transfer_initialized(vllm_config)
```
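The worker-side change is the same in both workers: instead of each worker re-reading `expert_tensor_parallel_size` from `additional_config` with a default of 1 (which could disagree with the platform's default of `tensor_parallel_size`), the values are now taken from `parallel_config`, where they have already been resolved. A minimal before/after sketch, with plain dicts standing in for the real config objects:

```python
# Before/after sketch of how a worker resolved ETP size. The dicts are
# stand-ins for vLLM's ParallelConfig / additional_config (hypothetical).
def old_worker_etp(additional_config):
    # Each worker applied its own default of 1 ...
    etp = 1
    if additional_config:
        etp = additional_config.get("expert_tensor_parallel_size", 1)
    return etp

def new_worker_etp(parallel_config):
    # ... now it just reads the value the platform already resolved.
    return parallel_config["expert_tensor_parallel_size"]

# With TP=4 and no explicit override, the platform resolves ETP to 4,
# but the old worker default silently produced 1: a mismatch.
platform_resolved = {"expert_tensor_parallel_size": 4}
print(old_worker_etp(None))               # 1
print(new_worker_etp(platform_resolved))  # 4
```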
And the same change in the V1 worker:

```diff
@@ -234,7 +234,6 @@ class NPUWorker(WorkerBase):
 
     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
-        additional_config = self.vllm_config.additional_config
         parallel_config = self.vllm_config.parallel_config
         set_custom_all_reduce(
             not self.parallel_config.disable_custom_all_reduce)
@@ -244,13 +243,11 @@ class NPUWorker(WorkerBase):
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
-        expert_tensor_parallel_size = 1
-        if additional_config is not None and "expert_tensor_parallel_size" in additional_config:
-            expert_tensor_parallel_size = int(
-                additional_config["expert_tensor_parallel_size"])
-        init_ascend_model_parallel(parallel_config.tensor_parallel_size,
-                                   parallel_config.pipeline_parallel_size,
-                                   expert_tensor_parallel_size)
+        init_ascend_model_parallel(
+            parallel_config.expert_parallel_size,
+            parallel_config.expert_tensor_parallel_size,
+            parallel_config.world_size,
+        )
         ensure_kv_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
```