feat: support compiling multiple batch graphs (#1085)
### What this PR does / why we need it?
Support compiling multiple batch graphs, each bound to its own code object, so that switching between graph batch sizes does not invalidate the compilation cache.
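For context, the cache torchair builds for a compiled `forward` can be invalidated when dynamo retraces the same function for a different batch size; the patch sidesteps this by giving each batch size its own `forward` proxy whose code object carries a unique name. A minimal, standalone sketch of that proxy construction (the toy `Model` class and the `proxies` dict are illustrative stand-ins, not part of the patch):
```
import types


class Model:
    """Toy stand-in for the real nn.Module; only the proxy mechanics matter."""

    def forward(self, x):
        return x * 2


model = Model()
proxies = {}
for batch_size in (8, 16, 24):
    proxy_name = f"{type(model).__name__}_forward_with_batch_size_{batch_size}"
    fn = Model.forward
    # Clone the code object under a new name so each batch size is traced
    # and cached as a distinct function.
    new_code = fn.__code__.replace(co_name=proxy_name)
    new_fn = types.FunctionType(new_code, fn.__globals__, name=proxy_name,
                                argdefs=fn.__defaults__)
    # Bind the proxy to the instance; the patch then hands such a bound method
    # to torchair.inference.cache_compile instead of model.forward.
    proxies[batch_size] = new_fn.__get__(model, Model)

assert proxies[8].__code__.co_name != proxies[16].__code__.co_name
print(proxies[8](3))  # -> 6
```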
### How was this patch tested?
```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
--quantization ascend \
--served-model-name auto \
--trust-remote-code \
--distributed-executor-backend=mp \
--port 8006 \
-tp=8 \
-dp=2 \
--no-enforce-eager \
--max-num-seqs 24 \
--max-model-len 32768 \
--max-num-batched-tokens 32768 \
--block-size 128 \
--no-enable-prefix-caching \
--additional-config '{"torchair_graph_config": {"enabled": true,"use_cached_graph": true,"graph_batch_sizes": [8,16,24]},"ascend_scheduler_config": {"enabled":true,"chunked_prefill_enabled":false},"expert_tensor_parallel_size":16}' \
--gpu-memory-utilization 0.95 &> run.log &
disown
```
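Once the server is up, a quick way to confirm it answers is a request to the OpenAI-compatible completions endpoint; the port and served model name below match the launch command above, while the prompt and token count are arbitrary:
```
# Smoke test against the server started above (assumes localhost:8006 and
# served-model-name "auto"; adjust if the launch command changes).
import requests

resp = requests.post(
    "http://localhost:8006/v1/completions",
    json={
        "model": "auto",
        "prompt": "The capital of France is",
        "max_tokens": 16,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```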
Signed-off-by: boying <897013703@qq.com>
```
@@ -20,6 +20,7 @@
 import gc
 import os
 import time
+import types
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -321,6 +322,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         self.sampler = Sampler()
 
+        self.torchair_compiled_model = None  # type: ignore
+        self.torchair_compiled_models = {}  # type: ignore
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
@@ -713,7 +716,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
             if self.torchair_graph_enabled and not with_prefill:
-                hidden_states = self.compile_model(
+                compiled_model = self._get_torchair_lazy_compiled_model(
+                    padded_batch_size)
+                hidden_states = compiled_model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
@@ -1190,7 +1195,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     kv, tuple), "kv_cache must be a tuple"
                 torch._dynamo.mark_static(kv[0])
                 torch._dynamo.mark_static(kv[1])
-            hidden_states = self.compile_model(
+            compiled_model = self._get_torchair_lazy_compiled_model(
+                num_tokens)
+            hidden_states = compiled_model(
                 input_ids=input_ids,
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
@@ -1264,30 +1271,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))
 
-        # adapter torch compile with npu_backend
-        if self.torchair_graph_enabled:
-            import torchair  # type: ignore
-            from torchair import patch_for_hcom  # type: ignore
-
-            patch_for_hcom()
-            config = torchair.CompilerConfig()
-            config.experimental_config.frozen_parameter = True
-            config.experimental_config.tiling_schedule_optimize = True
-            torch.npu.set_compile_mode(jit_compile=False)
-            if not self.use_cached_npu_graph:
-                npu_backend = torchair.get_npu_backend(compiler_config=config)
-                self.compile_model = torch.compile(
-                    self.model,
-                    dynamic=True,
-                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                    backend=npu_backend)
-            else:
-                self.compile_model = torchair.inference.cache_compile(
-                    self.model.forward,
-                    dynamic=True,
-                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                    config=config,
-                    ge_cache=False)
+    def _get_torchair_lazy_compiled_model(self, batch_size: int):
+        if batch_size < 0 or batch_size > self.max_num_reqs:
+            raise ValueError(
+                f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
+            )
+
+        compiled_model = self.torchair_compiled_models.get(
+            batch_size
+        ) if self.use_cached_npu_graph else self.torchair_compiled_model
+
+        if compiled_model:
+            return compiled_model
+
+        import torchair  # type: ignore
+        from torchair import patch_for_hcom  # type: ignore
+
+        patch_for_hcom()
+        config = torchair.CompilerConfig()
+        config.experimental_config.frozen_parameter = True
+        config.experimental_config.tiling_schedule_optimize = True
+        torch.npu.set_compile_mode(jit_compile=False)
+        if not self.use_cached_npu_graph:
+            npu_backend = torchair.get_npu_backend(compiler_config=config)
+            self.torchair_compiled_model = torch.compile(
+                self.model,
+                dynamic=True,
+                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                backend=npu_backend)
+            return self.torchair_compiled_model
+        else:
+            # Generate a new forward proxy code object to prevent the invalidation of
+            # compilation cache caused by dynamo retracing
+            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
+            forward_fn = self.model.forward
+            code = forward_fn.__code__
+            # Mark code object with a new proxy name
+            modified_code = code.replace(co_name=forward_proxy_name, )
+
+            modified_func = types.FunctionType(modified_code,
+                                               forward_fn.__globals__,
+                                               name=forward_proxy_name,
+                                               argdefs=forward_fn.__defaults__)
+
+            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
+                self.model, nn.Module)
+            self.torchair_compiled_models[
+                batch_size] = torchair.inference.cache_compile(
+                    self.model.__dict__[forward_proxy_name],
+                    dynamic=True,
+                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    config=config,
+                    ge_cache=False)
+            return self.torchair_compiled_models[batch_size]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
```