feat: support compile multiple batch graph (#1085)
### What this PR does / why we need it?
support compile multiple batch graph with different code object to avoid
cache invalidation
### How was this patch tested?
```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
--quantization ascend \
--served-model-name auto \
--trust-remote-code \
--distributed-executor-backend=mp \
--port 8006 \
-tp=8 \
-dp=2 \
--no-enforce-eager \
--max-num-seqs 24 \
--max-model-len 32768 \
--max-num-batched-tokens 32768 \
--block-size 128 \
--no-enable-prefix-caching \
--additional-config '{"torchair_graph_config": {"enabled": true,"use_cached_graph": true,"graph_batch_sizes": [8,16,24]},"ascend_scheduler_config": {"enabled":true,"chunked_prefill_enabled":false},"expert_tensor_parallel_size":16}' \
--gpu-memory-utilization 0.95 &> run.log &
disown
```
Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -20,6 +20,7 @@
|
|||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import types
|
||||||
import weakref
|
import weakref
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -321,6 +322,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
self.torchair_compiled_model = None # type: ignore
|
||||||
|
self.torchair_compiled_models = {} # type: ignore
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
|
||||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||||
@@ -713,7 +716,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
model_kwargs["kv_caches"] = self.kv_caches
|
model_kwargs["kv_caches"] = self.kv_caches
|
||||||
model_kwargs["attn_metadata"] = attn_metadata
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
if self.torchair_graph_enabled and not with_prefill:
|
||||||
hidden_states = self.compile_model(
|
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||||
|
padded_batch_size)
|
||||||
|
hidden_states = compiled_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
@@ -1190,7 +1195,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
kv, tuple), "kv_cache must be a tuple"
|
kv, tuple), "kv_cache must be a tuple"
|
||||||
torch._dynamo.mark_static(kv[0])
|
torch._dynamo.mark_static(kv[0])
|
||||||
torch._dynamo.mark_static(kv[1])
|
torch._dynamo.mark_static(kv[1])
|
||||||
hidden_states = self.compile_model(
|
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||||
|
num_tokens)
|
||||||
|
hidden_states = compiled_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
@@ -1264,30 +1271,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logger.info("Loading model weights took %.4f GB",
|
logger.info("Loading model weights took %.4f GB",
|
||||||
m.consumed_memory / float(2**30))
|
m.consumed_memory / float(2**30))
|
||||||
|
|
||||||
# adapter torch compile with npu_backend
|
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||||
if self.torchair_graph_enabled:
|
if batch_size < 0 or batch_size > self.max_num_reqs:
|
||||||
import torchair # type: ignore
|
raise ValueError(
|
||||||
from torchair import patch_for_hcom # type: ignore
|
f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
|
||||||
|
)
|
||||||
|
|
||||||
patch_for_hcom()
|
compiled_model = self.torchair_compiled_models.get(
|
||||||
config = torchair.CompilerConfig()
|
batch_size
|
||||||
config.experimental_config.frozen_parameter = True
|
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||||
config.experimental_config.tiling_schedule_optimize = True
|
|
||||||
torch.npu.set_compile_mode(jit_compile=False)
|
if compiled_model:
|
||||||
if not self.use_cached_npu_graph:
|
return compiled_model
|
||||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
|
||||||
self.compile_model = torch.compile(
|
import torchair # type: ignore
|
||||||
self.model,
|
from torchair import patch_for_hcom # type: ignore
|
||||||
dynamic=True,
|
|
||||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
patch_for_hcom()
|
||||||
backend=npu_backend)
|
config = torchair.CompilerConfig()
|
||||||
else:
|
config.experimental_config.frozen_parameter = True
|
||||||
self.compile_model = torchair.inference.cache_compile(
|
config.experimental_config.tiling_schedule_optimize = True
|
||||||
self.model.forward,
|
torch.npu.set_compile_mode(jit_compile=False)
|
||||||
|
if not self.use_cached_npu_graph:
|
||||||
|
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||||
|
self.torchair_compiled_model = torch.compile(
|
||||||
|
self.model,
|
||||||
|
dynamic=True,
|
||||||
|
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
|
backend=npu_backend)
|
||||||
|
return self.torchair_compiled_model
|
||||||
|
else:
|
||||||
|
# Generate a new forward proxy code object to prevent the invalidation of
|
||||||
|
# compilation cache caused by dynamo retracing
|
||||||
|
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||||
|
forward_fn = self.model.forward
|
||||||
|
code = forward_fn.__code__
|
||||||
|
# Mark code object with a new proxy name
|
||||||
|
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||||
|
|
||||||
|
modified_func = types.FunctionType(modified_code,
|
||||||
|
forward_fn.__globals__,
|
||||||
|
name=forward_proxy_name,
|
||||||
|
argdefs=forward_fn.__defaults__)
|
||||||
|
|
||||||
|
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||||
|
self.model, nn.Module)
|
||||||
|
self.torchair_compiled_models[
|
||||||
|
batch_size] = torchair.inference.cache_compile(
|
||||||
|
self.model.__dict__[forward_proxy_name],
|
||||||
dynamic=True,
|
dynamic=True,
|
||||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
config=config,
|
config=config,
|
||||||
ge_cache=False)
|
ge_cache=False)
|
||||||
|
return self.torchair_compiled_models[batch_size]
|
||||||
|
|
||||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user