[aclgraph] implement NPUPiecewiseBackend to enable aclgraph (#836)

### What this PR does / why we need it?
1. Implement `NPUPiecewiseBackend` to enable aclgraph
2. Enable aclgraph by default in V1, but raise an error when running
deepseek models and log a warning when running models other than qwen

### How was this patch tested?
CI passes with the new unit tests.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-05-29 11:58:26 +08:00
committed by GitHub
parent cc74b97f74
commit a93bed4535
8 changed files with 380 additions and 33 deletions

View File

@@ -23,7 +23,6 @@ import torch
import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum
from vllm.utils import supports_dynamo
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
@@ -119,24 +118,48 @@ class NPUPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel # noqa: E402
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
if vllm_config.model_config is None:
if model_config is None:
logger.warning("Model config is missing. This may indicate "
"that we are running a test case")
enforce_eager = False
else:
enforce_eager = getattr(vllm_config.model_config, "enforce_eager",
False)
enforce_eager = getattr(model_config, "enforce_eager", False)
# TODO(Yizhou): Override the value of enforce_eager to True before
# the CANN and torch_npu support NPU compilation.
enforce_eager = True
logger.warning(
"NPU compilation support pending. Will be available in future CANN and "
"torch_npu releases. NPU graph mode is currently experimental and disabled "
"by default. You can just adopt additional_config={'enable_graph_mode': True} "
"to serve deepseek models with NPU graph mode on vllm-ascend with V0 engine. "
)
if vllm_config.additional_config is not None:
enable_graph_mode = vllm_config.additional_config.get(
"enable_graph_mode", False)
if enable_graph_mode:
if enforce_eager:
raise RuntimeError(
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
)
elif envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
logger.warning(
"NPU graph mode is still experimental and not supported for V1 without mla currently, "
"it has been disabled automatically.")
vllm_config.additional_config["enable_graph_mode"] = False
if model_config:
model_type = model_config.hf_config.model_type
if "deepseek" not in model_type:
raise NotImplementedError(
"enable_graph_mode only works with deepseek model."
)
elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
model_type = model_config.hf_config.model_type
if "deepseek" in model_type:
raise NotImplementedError(
"ACL Graph does not support deepseek. Please "
"adopt additional_config={'enable_graph_mode': True} "
"to serve deepseek models with NPU graph mode on vllm-ascend with V1 engine."
" Or set `enforce_eager=True` to use eager mode.")
elif "qwen" not in model_type:
logger.warning(
"ACL Graph is currently experimental. Please "
"raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
" if you encourage any Error")
if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
logger.info("Compilation disabled, using eager mode by default")
@@ -155,20 +178,6 @@ class NPUPlatform(Platform):
["vllm.unified_ascend_attention_with_output"])
update_aclgraph_sizes(vllm_config)
if vllm_config.additional_config is not None:
enable_graph_mode = vllm_config.additional_config.get(
"enable_graph_mode", False)
if enable_graph_mode and not supports_dynamo():
logger.warning(
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
)
vllm_config.additional_config["enable_graph_mode"] = False
if enable_graph_mode and envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
logger.warning(
"NPU graph mode is still experimental and not supported for V1 without mla currently, "
"it has been disabled automatically.")
vllm_config.additional_config["enable_graph_mode"] = False
parallel_config = vllm_config.parallel_config
if parallel_config and parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1:
@@ -244,3 +253,10 @@ class NPUPlatform(Platform):
model configuration.
"""
return True
@classmethod
def get_piecewise_backend_cls(cls) -> str:
    """Return the dotted import path of the piecewise graph backend class.

    Returns:
        The fully qualified name of the backend class, as a string.
    """
    backend_path = (
        "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
    )
    return backend_path