[Feature] implement eagle spec decoding for model runner v2 (#5840)
### What this PR does / why we need it?
This PR implements EAGLE speculative decoding for model runner v2. Please see the RFC: https://github.com/vllm-project/vllm-ascend/issues/5208

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
vLLM version: v0.13.0

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
#
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -50,13 +50,11 @@ def build_attn_metadata(
|
||||
query_start_loc_gpu: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
num_computed_tokens_cpu: torch.Tensor,
|
||||
seq_lens_np: np.ndarray,
|
||||
num_computed_tokens_cpu: torch.Tensor | None,
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
decode_token_per_req: int,
|
||||
actual_seq_lengths_q: list[int],
|
||||
positions: torch.Tensor | None = None,
|
||||
attn_state: Any | None = None,
|
||||
graph_pad_size: int = -1,
|
||||
@@ -67,7 +65,11 @@ def build_attn_metadata(
|
||||
"""Build attention metadata for Ascend NPUs."""
|
||||
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
|
||||
max_query_len = int(query_start_loc_cpu.max())
|
||||
seq_lens_cpu = torch.from_numpy(seq_lens_np)
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
# torch_npu._reshape_and_cache operator requires slot_mappings to
|
||||
# be torch.int32.
|
||||
slot_mappings = slot_mappings.to(torch.int32)
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
@@ -80,14 +82,11 @@ def build_attn_metadata(
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens_cpu=seq_lens_cpu[:num_reqs],
|
||||
seq_lens=seq_lens[:num_reqs],
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
decode_token_per_req=decode_token_per_req,
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
positions=positions,
|
||||
attn_state=attn_state,
|
||||
graph_pad_size=graph_pad_size,
|
||||
|
||||
Reference in New Issue
Block a user