feat: support compile multiple batch graph (#1085)

### What this PR does / why we need it?

support compile multiple batch graph with different code object to avoid
cache invalidation

### How was this patch tested?

```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --no-enforce-eager \
    --max-num-seqs 24 \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --block-size 128 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_config": {"enabled": true,"use_cached_graph": true,"graph_batch_sizes": [8,16,24]},"ascend_scheduler_config": {"enabled":true,"chunked_prefill_enabled":false},"expert_tensor_parallel_size":16}' \
    --gpu-memory-utilization 0.95 &> run.log &
disown
```

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-06-06 20:17:51 +08:00
committed by GitHub
parent c46632439a
commit c7f1c59911

View File

@@ -20,6 +20,7 @@
import gc import gc
import os import os
import time import time
import types
import weakref import weakref
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
@@ -321,6 +322,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.sampler = Sampler() self.sampler = Sampler()
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
@@ -713,7 +716,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model_kwargs["kv_caches"] = self.kv_caches model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled and not with_prefill: if self.torchair_graph_enabled and not with_prefill:
hidden_states = self.compile_model( compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size)
hidden_states = compiled_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
@@ -1190,7 +1195,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv, tuple), "kv_cache must be a tuple" kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0]) torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1]) torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model( compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
hidden_states = compiled_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
@@ -1264,8 +1271,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30)) m.consumed_memory / float(2**30))
# adapter torch compile with npu_backend def _get_torchair_lazy_compiled_model(self, batch_size: int):
if self.torchair_graph_enabled: if batch_size < 0 or batch_size > self.max_num_reqs:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
import torchair # type: ignore import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore from torchair import patch_for_hcom # type: ignore
@@ -1276,18 +1294,36 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch.npu.set_compile_mode(jit_compile=False) torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph: if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config) npu_backend = torchair.get_npu_backend(compiler_config=config)
self.compile_model = torch.compile( self.torchair_compiled_model = torch.compile(
self.model, self.model,
dynamic=True, dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend) backend=npu_backend)
return self.torchair_compiled_model
else: else:
self.compile_model = torchair.inference.cache_compile( # Generate a new forward proxy code object to prevent the invalidation of
self.model.forward, # compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True, dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config, config=config,
ge_cache=False) ge_cache=False)
return self.torchair_compiled_models[batch_size]
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
""" """