feat: support compile multiple batch graph (#1085)

### What this PR does / why we need it?

support compile multiple batch graph with different code object to avoid
cache invalidation

### How was this patch tested?

```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --no-enforce-eager \
    --max-num-seqs 24 \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --block-size 128 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_config": {"enabled": true,"use_cached_graph": true,"graph_batch_sizes": [8,16,24]},"ascend_scheduler_config": {"enabled":true,"chunked_prefill_enabled":false},"expert_tensor_parallel_size":16}' \
    --gpu-memory-utilization 0.95 &> run.log &
disown
```

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-06-06 20:17:51 +08:00
committed by GitHub
parent c46632439a
commit c7f1c59911

View File

@@ -20,6 +20,7 @@
import gc
import os
import time
import types
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -321,6 +322,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.sampler = Sampler()
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
@@ -713,7 +716,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled and not with_prefill:
hidden_states = self.compile_model(
compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
@@ -1190,7 +1195,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model(
compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
@@ -1264,30 +1271,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30))
# adapter torch compile with npu_backend
if self.torchair_graph_enabled:
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.max_num_reqs:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
)
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.compile_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
else:
self.compile_model = torchair.inference.cache_compile(
self.model.forward,
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""