[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to newest commit of vllm main branch.
- supply a unified interface of extra forward context for both model
runner v1 and model runner v2.
- implement graph mode for main model.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -23,8 +23,8 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
@@ -43,7 +43,7 @@ def get_attn_mask_builder(device: torch.device):
|
||||
|
||||
def build_attn_metadata(
|
||||
*,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc_gpu: torch.Tensor,
|
||||
@@ -54,6 +54,7 @@ def build_attn_metadata(
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
dcp_local_seq_lens: torch.Tensor | None = None,
|
||||
# extra attributes for ascend npus.
|
||||
seq_lens_np: np.ndarray | None = None,
|
||||
num_computed_tokens_cpu: torch.Tensor | None = None,
|
||||
@@ -72,9 +73,6 @@ def build_attn_metadata(
|
||||
if seq_lens_np is None:
|
||||
seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32)
|
||||
seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs]
|
||||
# torch_npu._reshape_and_cache operator requires slot_mappings to
|
||||
# be torch.int32.
|
||||
slot_mappings = slot_mappings.to(torch.int32)
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
@@ -100,13 +98,14 @@ def build_attn_metadata(
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata, # type: ignore
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
for attn_group in attn_groups[i]:
|
||||
attn_metadata_builder = attn_group.get_metadata_builder(0)
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
return attn_metadata
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user