[Feature] Refactor the npugraph_ex config and support online inference with the static kernel (#5775)

### What this PR does / why we need it?
This is part of
https://github.com/vllm-project/vllm-ascend/issues/4715#issue-3694310762
1. Refactor the npugraph_ex config and change the default configuration of
the static kernel; the new default value of the static kernel is `false`.
2. Support online inference with the static kernel.
3. Fix the issue where manually modifying FX graphs caused an abnormal
model return type, and remove the related redundant code.

### Does this PR introduce _any_ user-facing change?
Yes, the new npugraph_ex config is as follows:
```
additional_config={
            "npugraph_ex_config": {
                "enable": True,
                "enable_static_kernel": False
            }
        }
```
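For reference, a minimal offline sketch of enabling the backend together with the static kernel. The model path is a placeholder, and passing `additional_config` through the `LLM` constructor is assumed to accept the same dict shown above:
```
from vllm import LLM

# Hypothetical offline example: enable the npugraph_ex backend and opt back
# into the static kernel, which defaults to False after this PR.
llm = LLM(
    model="/path/to/your-model",           # placeholder path
    additional_config={
        "npugraph_ex_config": {
            "enable": True,
            "enable_static_kernel": True,  # opt-in, no longer the default
        }
    },
)
```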
### How was this patch tested?
```
vllm serve /data/DeepSeek-V3.1-Terminus-w4a8 \
    --host 0.0.0.0 \
    --port 8004 \
    --data-parallel-size 4 \
    --tensor-parallel-size 4 \
    --quantization ascend \
    --seed 1024 \
    --served-model-name deepseek_v3 \
    --enable-expert-parallel \
    --max-num-seqs 48 \
    --max-model-len 40000 \
    --async-scheduling \
    --max-num-batched-tokens 9000 \
    --trust-remote-code \
    --no-enable-prefix-caching \
    --speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp","disable_padded_drafter_batch": false}' \
    --gpu-memory-utilization 0.9 \
    --compilation-config '{"cudagraph_capture_sizes":[4,32,64,112,160,176,192], "cudagraph_mode": "FULL_DECODE_ONLY"}' \
    --additional-config \
    '{"enable_shared_expert_dp": true,"multistream_overlap_shared_expert": true,"npugraph_ex_config":{"enable":true}}'
```

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: chencangtao <chencangtao@huawei.com>
Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>

@@ -31,6 +31,7 @@ The following table lists additional configuration options available in vLLM Asc
| `finegrained_tp_config` | dict | `{}` | Configuration options for module tensor parallelism |
| `ascend_compilation_config` | dict | `{}` | Configuration options for ascend compilation |
| `eplb_config` | dict | `{}` | Configuration options for EPLB (expert parallel load balancing) |
| `npugraph_ex_config` | dict | `{}` | Configuration options for npugraph_ex backend |
| `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
| `dump_config_path` | str | `None` | Configuration file path for msprobe dump(eager mode). |
| `enable_async_exponential` | bool | `False` | Whether to enable async exponential overlap. To enable async exponential, set this config to True. |
@@ -88,6 +89,13 @@ The details of each configuration option are as follows:
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory.|
| `num_redundant_experts` | int | `0` | Specify redundant experts during initialization. |
**npugraph_ex_config**
| Name | Type | Default | Description |
|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
| `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
### Example
An example of additional configuration is as follows:
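The doc's own example is not shown in this diff; a minimal sketch of an additional configuration that includes the new `npugraph_ex_config` entry might look like this (assumed values, not the verbatim example from the doc):
```
# Hypothetical additional_config passed to the engine; only npugraph_ex_config
# comes from this PR, the surrounding keys are illustrative.
additional_config = {
    "npugraph_ex_config": {
        "enable": True,
        "enable_static_kernel": False,
    },
    "refresh": False,
}
```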

@@ -126,7 +126,9 @@ def test_npugraph_ex_res_consistency(cur_case: LLMTestCase, monkeypatch):
"cudagraph_mode": "FULL_DECODE_ONLY"
},
"additional_config": {
"enable_npugraph_ex": True
"npugraph_ex_config": {
"enable": True
}
},
}
gen_and_valid(runner_kwargs=runner_kwargs,

@@ -65,7 +65,10 @@ class TestAscendConfig(TestBase):
ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(ascend_config.eplb_config.num_redundant_experts, 2)
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
self.assertFalse(ascend_config.enable_npugraph_ex)
npugraph_ex_config = ascend_config.npugraph_ex_config
self.assertFalse(npugraph_ex_config.enable)
self.assertFalse(npugraph_ex_config.enable_static_kernel)
ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertFalse(ascend_compilation_config.fuse_norm_quant)
@@ -79,11 +82,16 @@ class TestAscendConfig(TestBase):
def test_init_ascend_config_enable_npugraph_ex(self, mock_fix_incompatible_config):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"enable_npugraph_ex": True,
"refresh": True,
"npugraph_ex_config": {
"enable": True,
"enable_static_kernel": True
},
"refresh": True
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertTrue(ascend_config.enable_npugraph_ex)
npugraph_ex_config = init_ascend_config(
test_vllm_config).npugraph_ex_config
self.assertTrue(npugraph_ex_config.enable)
self.assertTrue(npugraph_ex_config.enable_static_kernel)
@_clean_up_ascend_config
@patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config")

@@ -102,7 +102,8 @@ class AscendConfig:
from vllm_ascend.utils import get_flashcomm2_config_and_validate
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)
self.enable_npugraph_ex = additional_config.get("enable_npugraph_ex", False)
npugraph_ex_config = additional_config.get("npugraph_ex_config", {})
self.npugraph_ex_config = NpugraphExConfig(**npugraph_ex_config)
# We find that _npu_paged_attention still performs better than
# npu_fused_infer_attention_score in some cases. We allow executing
# _npu_paged_attention in these cases. This should be removed once
@@ -211,6 +212,36 @@ class AscendFusionConfig:
self.fusion_ops_gmmswigluquant = fusion_ops_gmmswigluquant
class NpugraphExConfig:
"""
Configuration for controlling the behavior of npugraph_ex backend.
This class provides a way to configure whether to use the npugraph_ex backend and static kernel.
These configurations can directly impact the performance and behavior of models deployed on Ascend platforms.
"""
def __init__(self, enable: bool = False, enable_static_kernel: bool = False, **kwargs):
"""
Initialize the configuration.
Args:
enable (bool): Whether to enable npugraph_ex backend.
When set to True, the FX graph generated by Dynamo will be
optimized and compiled by the npugraph_ex backend.
Default: False
enable_static_kernel (bool): Whether to enable static kernel.
Static kernel is suitable for scenarios with purely static shapes
or minimal shape changes, and can improve network performance.
When set to True, during graph capture it will compile operator
binary files with the corresponding shapes based on the current batch_size,
which usually takes some time.
Default: False
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
"""
self.enable = enable
self.enable_static_kernel = enable_static_kernel
class XliteGraphConfig:
"""
Configuration Object for xlite_graph_config from additional_config

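As a quick illustration of how the new config object behaves, here is a standalone sketch based on the `NpugraphExConfig` class above; the `future_flag` key is hypothetical and only shows that unknown keys are absorbed by `**kwargs`:
```
# Parse the nested dict the same way AscendConfig does above.
additional_config = {"npugraph_ex_config": {"enable": True, "future_flag": 1}}
cfg = NpugraphExConfig(**additional_config.get("npugraph_ex_config", {}))

assert cfg.enable is True
assert cfg.enable_static_kernel is False  # the new default introduced by this PR
```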
@@ -26,9 +26,10 @@ from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tu
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
@@ -68,6 +69,8 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
vllm_config: VllmConfig,
npugraph_ex_config: NpugraphExConfig,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
@@ -85,7 +88,6 @@ def npugraph_ex_compile(
tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
output_node.args = (tuple_node,)
graph.recompile()
import torchair
# TODO: use a better way to lazy register replacement, instead of import one by one
@@ -98,10 +100,24 @@ def npugraph_ex_compile(
config.mode = "reduce-overhead"
# execute FX graph in eager mode before graph mode to optimize FX graph.
config.debug.run_eagerly = True
# static kernel switch, suitable for static shapes or scenes with less shape changes.
config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
if npugraph_ex_config.enable_static_kernel:
config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
# According to the cudagraph_capture_sizes configuration, set the shapes
# that can trigger static kernel compilation. If this configuration is
# not applied, new shapes will trigger static kernel compilation,
# affecting program execution.
num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
uniform_decode_query_len = num_spec_tokens + 1
max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x
for x in vllm_config.compilation_config.cudagraph_capture_sizes
if max_num_tokens >= x >= uniform_decode_query_len
]
config.experimental_config.aclgraph._aclnn_static_shape_kernel_sym_value_range = decode_cudagraph_batch_sizes
npugraph_ex = torchair.get_npu_backend(compiler_config=config)
compile_graph = npugraph_ex(graph, example_inputs)
return compile_graph, None
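To make the shape-range computation above concrete, here is a standalone sketch using the values from the test command in this PR (num_speculative_tokens=3, --max-num-seqs 48, and the listed cudagraph_capture_sizes); the numbers come from that command, not from an actual run:
```
# Reproduce the filtering logic above with the serve command's values.
num_spec_tokens = 3
uniform_decode_query_len = num_spec_tokens + 1      # 4
max_num_tokens = 48 * uniform_decode_query_len      # 48 seqs -> 192
cudagraph_capture_sizes = [4, 32, 64, 112, 160, 176, 192]

decode_cudagraph_batch_sizes = [
    x for x in cudagraph_capture_sizes
    if max_num_tokens >= x >= uniform_decode_query_len
]
print(decode_cudagraph_batch_sizes)  # [4, 32, 64, 112, 160, 176, 192]
```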
@@ -115,6 +131,12 @@ class AscendCompiler(CompilerInterface):
name = "AscendCompiler"
def compute_hash(self, vllm_config: VllmConfig) -> str:
npugraph_ex_config = get_ascend_config().npugraph_ex_config
if npugraph_ex_config.enable:
self.vllm_config = vllm_config
return vllm_config.compute_hash()
def compile(
self,
graph: fx.GraphModule,
@@ -123,8 +145,11 @@ class AscendCompiler(CompilerInterface):
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key)
npugraph_ex_config = get_ascend_config().npugraph_ex_config
if npugraph_ex_config.enable:
assert hasattr(self, "vllm_config")
return npugraph_ex_compile(
graph, example_inputs, compiler_config, self.vllm_config, npugraph_ex_config, compile_range, key
)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)

@@ -275,7 +275,7 @@ class NPUPlatform(Platform):
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False
ascend_config.npugraph_ex_config.enable = False
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
@@ -295,7 +295,7 @@ class NPUPlatform(Platform):
# not be detected in advance assert.
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config)
ascend_config.enable_npugraph_ex = False
ascend_config.npugraph_ex_config.enable = False
elif (
compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
@@ -324,7 +324,7 @@ class NPUPlatform(Platform):
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False
ascend_config.npugraph_ex_config.enable = False
# TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
# Then, we will have to discuss the error handling strategy and user experience