feat: support compile multiple batch graph (#1085)

### What this PR does / why we need it?

support compile multiple batch graph with different code object to avoid
cache invalidation

### How was this patch tested?

```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --no-enforce-eager \
    --max-num-seqs 24 \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --block-size 128 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_config": {"enabled": true,"use_cached_graph": true,"graph_batch_sizes": [8,16,24]},"ascend_scheduler_config": {"enabled":true,"chunked_prefill_enabled":false},"expert_tensor_parallel_size":16}' \
    --gpu-memory-utilization 0.95 &> run.log &
disown
```

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-06-06 20:17:51 +08:00
committed by GitHub
parent c46632439a
commit c7f1c59911

View File

@@ -20,6 +20,7 @@
import gc import gc
import os import os
import time import time
import types
import weakref import weakref
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
@@ -321,6 +322,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.sampler = Sampler() self.sampler = Sampler()
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
@@ -713,7 +716,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model_kwargs["kv_caches"] = self.kv_caches model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled and not with_prefill: if self.torchair_graph_enabled and not with_prefill:
hidden_states = self.compile_model( compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size)
hidden_states = compiled_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
@@ -1190,7 +1195,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv, tuple), "kv_cache must be a tuple" kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0]) torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1]) torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model( compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
hidden_states = compiled_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
@@ -1264,30 +1271,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30)) m.consumed_memory / float(2**30))
# adapter torch compile with npu_backend def _get_torchair_lazy_compiled_model(self, batch_size: int):
if self.torchair_graph_enabled: if batch_size < 0 or batch_size > self.max_num_reqs:
import torchair # type: ignore raise ValueError(
from torchair import patch_for_hcom # type: ignore f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
)
patch_for_hcom() compiled_model = self.torchair_compiled_models.get(
config = torchair.CompilerConfig() batch_size
config.experimental_config.frozen_parameter = True ) if self.use_cached_npu_graph else self.torchair_compiled_model
config.experimental_config.tiling_schedule_optimize = True
torch.npu.set_compile_mode(jit_compile=False) if compiled_model:
if not self.use_cached_npu_graph: return compiled_model
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.compile_model = torch.compile( import torchair # type: ignore
self.model, from torchair import patch_for_hcom # type: ignore
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, patch_for_hcom()
backend=npu_backend) config = torchair.CompilerConfig()
else: config.experimental_config.frozen_parameter = True
self.compile_model = torchair.inference.cache_compile( config.experimental_config.tiling_schedule_optimize = True
self.model.forward, torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True, dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config, config=config,
ge_cache=False) ge_cache=False)
return self.torchair_compiled_models[batch_size]
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
""" """