[Feature]refactor the npugraph_ex config, support online-infer with static kernel (#5775)

### What this PR does / why we need it?
This is a part of
https://github.com/vllm-project/vllm-ascend/issues/4715#issue-3694310762
1. refactor the npugraph_ex config,modified the default configuration of
the static kernel, new default value of static kernel is false
2. support online-infer with static kernel
3. fixed the issue where manually modifying FX graphs caused an abnormal
model return type, and removed the related redundant code.

### Does this PR introduce _any_ user-facing change?
yes,the new config of npugraph_ex is as follow:
```
additional_config={
            "npugraph_ex_config": {
                "enable": True,
                "enable_static_kernel": False
            }
        }
```
### How was this patch tested?
```
vllm serve /data/DeepSeek-V3.1-Terminus-w4a8 \
    --host 0.0.0.0 \
    --port 8004 \
    --data-parallel-size 4 \
    --tensor-parallel-size 4 \
    --quantization ascend \
    --seed 1024 \
    --served-model-name deepseek_v3 \
    --enable-expert-parallel \
    --max-num-seqs 48 \
    --max-model-len 40000 \
    --async-scheduling \
    --max-num-batched-tokens 9000 \
    --trust-remote-code \
    --no-enable-prefix-caching \
    --speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp","disable_padded_drafter_batch": false}' \
    --gpu-memory-utilization 0.9 \
    --compilation-config '{"cudagraph_capture_sizes":[4,32,64,112,160,176,192], "cudagraph_mode": "FULL_DECODE_ONLY"}' \
    --additional-config \
    '{"enable_shared_expert_dp": true,"multistream_overlap_shared_expert": true,"npugraph_ex_config":{"enable":true}}'
```

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: chencangtao <chencangtao@huawei.com>
Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
ChenCangtao
2026-01-20 21:31:38 +08:00
committed by GitHub
parent 0c0514579f
commit 6c30f8bf87
6 changed files with 91 additions and 17 deletions

View File

@@ -26,9 +26,10 @@ from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tu
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
@@ -68,6 +69,8 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
vllm_config: VllmConfig,
npugraph_ex_config: NpugraphExConfig,
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
@@ -85,7 +88,6 @@ def npugraph_ex_compile(
tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
output_node.args = (tuple_node,)
graph.recompile()
import torchair
# TODO: use a better way to lazy register replacement, instead of import one by one
@@ -98,10 +100,24 @@ def npugraph_ex_compile(
config.mode = "reduce-overhead"
# execute FX graph in eager mode before graph mode to optimize FX graph.
config.debug.run_eagerly = True
# static kernel switch, suitable for static shapes or scenes with less shape changes.
config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
if npugraph_ex_config.enable_static_kernel:
config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
# According to the cudagraph_capture_size configuration, set the shapes
# that can trigger the compilation of static kernel. If this configuration is
# not applied, new shapes will trigger the compilation of static kernels,
# affecting program execution.
num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0
uniform_decode_query_len = num_spec_tokens + 1
max_num_tokens = vllm_config.scheduler_config.max_num_seq * uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x
for x in vllm_config.compilation_config.cudagraph_capture_size
if max_num_tokens >= x >= uniform_decode_query_len
]
config.experimental_config.aclgraph._aclnn_static_shape_kernel_sym_value_range = decode_cudagraph_batch_sizes
npugraph_ex = torchair.get_npu_backend(compiler_config=config)
compile_graph = npugraph_ex(graph, example_inputs)
return compile_graph, None
@@ -115,6 +131,12 @@ class AscendCompiler(CompilerInterface):
name = "AscendCompiler"
def compute_hash(self, vllm_config: VllmConfig) -> str:
npugraph_ex_config = get_ascend_config().npugraph_ex_config
if npugraph_ex_config.enable:
self.vllm_config = vllm_config
return vllm_config.compute_hash()
def compile(
self,
graph: fx.GraphModule,
@@ -123,8 +145,11 @@ class AscendCompiler(CompilerInterface):
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key)
npugraph_ex_config = get_ascend_config().npugraph_ex_config
if npugraph_ex_config.enable:
assert hasattr(self, "vllm_config")
return npugraph_ex_compile(
graph, example_inputs, compiler_config, self.vllm_config, npugraph_ex_config, compile_range, key
)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)