feat: support compile torchair graph while warming up (#839)

### What this PR does / why we need it?
feat: support compile torchair graph while warming up

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-05-31 06:03:03 +08:00
committed by GitHub
parent d9fb027068
commit 507ae627ca
7 changed files with 242 additions and 234 deletions

View File

@@ -28,10 +28,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
import numpy.typing as npt
import torch
import torch._dynamo.cache_size
import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
@@ -70,7 +72,9 @@ if TYPE_CHECKING:
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
import vllm.envs as envs
import vllm.envs as envs_vllm
import vllm_ascend.envs as envs_ascend
@dataclass
@@ -321,6 +325,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.sampler = Sampler()
self.enable_torchair_graph_mode = False
self.use_cached_npu_graph = False
self.torchair_graph_batch_sizes = []
additional_config = vllm_config.additional_config
if additional_config:
self.enable_torchair_graph_mode = additional_config.get(
@@ -328,6 +333,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
False) and self.vllm_config.model_config.use_mla
self.use_cached_npu_graph = additional_config.get(
"use_cached_npu_graph", False)
self.torchair_graph_batch_sizes = additional_config.get(
"torchair_graph_batch_sizes", [])
if not isinstance(self.torchair_graph_batch_sizes, list):
logger.warning("torchair_graph_batch_sizes must be list[int]")
self.torchair_graph_batch_sizes = []
if len(self.torchair_graph_batch_sizes
) == 0 and additional_config.get(
"torchair_graph_batch_sizes_init", False):
self.init_torchair_graph_batch_sizes()
if len(self.torchair_graph_batch_sizes) == 0:
#If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
if envs_ascend.VLLM_ENABLE_MC2:
self.torchair_graph_batch_sizes = [
self.scheduler_config.max_num_seqs
]
else:
self.torchair_graph_batch_sizes = [
1, self.scheduler_config.max_num_seqs
]
torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
@@ -618,7 +649,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
query_start_loc=query_start_loc, seq_lens=seq_lens)
# Add graph_pad_size here
if self.enable_torchair_graph_mode:
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
batchsize = len(seq_lens)
padded_batch_size = self.select_torchair_padded_batchsize(
batchsize)
graph_pad_size = padded_batch_size - batchsize
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
if self.vllm_config.model_config.use_mla:
@@ -653,11 +687,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
input_ids = self.input_ids[:num_input_tokens]
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
padding = torch.zeros(graph_pad_size,
dtype=input_ids.dtype,
device=input_ids.device)
input_ids = torch.cat([input_ids, padding])
positions = torch.cat([positions, padding])
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
# Run forward pass
with set_forward_context(attn_metadata,
@@ -668,15 +699,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model(
input_ids=input_ids,
positions=positions,
@@ -1068,7 +1090,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
@torch.inference_mode()
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
def _dummy_run(
self,
num_tokens: int,
is_compile: bool = False,
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill,
) -> torch.Tensor:
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
@@ -1112,12 +1139,38 @@ class NPUModelRunner(LoRAModelRunnerMixin):
})
with set_forward_context(None, self.vllm_config):
hidden_states = model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
# Only mark static while compiling
if is_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(
attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(
kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
)
else:
hidden_states = model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
@@ -1192,13 +1245,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.compile_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
else:
self.compile_model = torchair.inference.cache_compile(
self.model.forward,
dynamic=True,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
@@ -1316,25 +1369,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec
def capture_model(self) -> None:
if not self.use_aclgraph:
logger.warning(
"Skipping NPU graph capture. Please add "
"-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
return
start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes):
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
if self.enable_torchair_graph_mode:
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
attn_state = AscendAttentionState.DecodeOnly
# Trigger torchair graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(
reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
is_compile=True,
attn_state=attn_state)
self._dummy_run(num_tokens,
is_compile=True,
attn_state=attn_state)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, graph_num)
elif self.use_aclgraph:
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
else:
logger.warning(
"Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
"Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
CompilationLevel.PIECEWISE)
return
end_time = time.perf_counter()
end_free_npu_memory = torch.npu.mem_get_info()[0]
elapsed_time = end_time - start_time
@@ -1443,4 +1520,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
return spec_token_ids
def init_torchair_graph_batch_sizes(self):
tp_size = get_tensor_model_parallel_world_size()
batch_size_step = 8
largest_batch_size = 1
if envs_ascend.VLLM_ENABLE_MC2:
batch_size_step = max(batch_size_step, tp_size)
largest_batch_size = batch_size_step
while (largest_batch_size < 8):
self.torchair_graph_batch_sizes.append(largest_batch_size)
largest_batch_size *= 2
while (largest_batch_size <= self.scheduler_config.max_num_seqs):
self.torchair_graph_batch_sizes.append(largest_batch_size)
largest_batch_size += batch_size_step
def select_torchair_padded_batchsize(self, batchsize: int):
selected_batchsize = self.max_num_reqs
for padded_batchsize in self.torchair_graph_batch_sizes:
if batchsize <= padded_batchsize < selected_batchsize:
selected_batchsize = padded_batchsize
return selected_batchsize