feat: support torchair graph mode in v1 engine (#789)

### What this PR does / why we need it?
Support the torchair graph mode with the v1 engine.

---------

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-05-12 19:14:07 +08:00
committed by GitHub
parent 4a2505f81f
commit efabd722eb
5 changed files with 585 additions and 82 deletions

View File

@@ -63,6 +63,8 @@ if TYPE_CHECKING:
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
import vllm.envs as envs
@dataclass
class GraphCaptureContext:
@@ -117,6 +119,12 @@ class NPUModelRunner:
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.graph_block_tables = np.zeros(
(self.vllm_config.scheduler_config.max_num_seqs,
(self.model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
# Model-related.
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
vllm_config.parallel_config, LayerBlockType.attention)
@@ -307,6 +315,15 @@ class NPUModelRunner:
self.attn_mask_len, self.dtype)
self.sampler = Sampler()
self.enable_torchair_graph_mode = False
self.use_cached_npu_graph = False
additional_config = vllm_config.additional_config
if additional_config:
self.enable_torchair_graph_mode = additional_config.get(
"enable_graph_mode",
False) and self.vllm_config.model_config.use_mla
self.use_cached_npu_graph = additional_config.get(
"use_cached_npu_graph", False)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
@@ -563,11 +580,19 @@ class NPUModelRunner:
self.attn_mask = attn_mask
self.attn_state = attn_state # type: ignore
extra_builder_kwargs = {}
# Pad the batch up to max_num_seqs when torchair graph mode is enabled
if self.enable_torchair_graph_mode:
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=None,
**extra_builder_kwargs,
)
# Prepare input_ids
@@ -582,15 +607,45 @@ class NPUModelRunner:
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
input_ids = self.input_ids[:total_num_scheduled_tokens]
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
padding = torch.zeros(graph_pad_size,
dtype=input_ids.dtype,
device=input_ids.device)
input_ids = torch.cat([input_ids, padding])
positions = torch.cat([positions, padding])
# Run forward pass
with set_forward_context(attn_metadata, self.vllm_config):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
)
model_kwargs = {}
if self.enable_torchair_graph_mode:
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
hidden_states = self.compile_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
else:
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
return hidden_states[sample_indices]
@@ -879,6 +934,31 @@ class NPUModelRunner:
logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30))
# Adapt torch.compile to use the npu_backend
if self.enable_torchair_graph_mode:
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.compile_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
else:
self.compile_model = torchair.inference.cache_compile(
self.model.forward,
dynamic=True,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@@ -909,10 +989,29 @@ class NPUModelRunner:
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
torch_npu.npu_format_cast(kv_caches[layer_name], 2)
if self.enable_torchair_graph_mode:
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ),
dtype=self.dtype,
pin_memory=True,
device=self.device)
layer_kv_cache_pe = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.qk_rope_head_dim,
),
dtype=self.dtype,
pin_memory=True,
device=self.device)
kv_caches[layer_name] = (layer_kv_cache_nope,
layer_kv_cache_pe)
torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
else:
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
torch_npu.npu_format_cast(kv_caches[layer_name], 2)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.