[FEAT] Refactor spec decode to support efficient padded speculation (#3528)

### What this PR does / why we need it?
1. Refactors `mtp_proposer.py`, splitting the torchair-related code out into `torchair_mtp_proposer.py` (under `vllm_ascend/torchair/`).
2. Following https://github.com/vllm-project/vllm/pull/24539, implements padded speculative decoding as described in https://github.com/vllm-project/vllm/issues/21984.

### Does this PR introduce _any_ user-facing change?
Users can set `disable_padded_drafter_batch` in `speculative_config` to turn padded speculation off or on; the default is `False`, i.e. padded speculation is enabled.

Offline example:
```
speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}
```
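
For reference, a minimal offline sketch built around this config; it reuses the random-weight MTP test checkpoint from the e2e test in this PR, and the prompt and sampling parameters are illustrative:

```
from vllm import LLM, SamplingParams

# Sketch: padded speculation stays enabled because
# disable_padded_drafter_batch defaults to False.
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",  # random-weight test checkpoint
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
        "disable_padded_drafter_batch": False,
    },
    max_model_len=256,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)
```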

### How was this patch tested?

- [x] eager with pad/unpad
- [x] aclgraph with pad/unpad
- [x] torchair with pad/unpad

Performance test of deepseek-r1 with tp16, dp1:
- aclgraph with pad ITL: 168 ms
- aclgraph with unpad ITL: 169 ms
- original: 178 ms


- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
xuyexiong
2025-10-30 16:53:05 +08:00
committed by GitHub
parent 10772d94e3
commit eff3e5fc6f
7 changed files with 1203 additions and 440 deletions

View File

@@ -1,11 +1,15 @@
 from __future__ import annotations
 
+import os
+
 import pytest
 
 from vllm import SamplingParams
 from vllm.config import CompilationConfig, CUDAGraphMode
 
 from tests.e2e.conftest import VllmRunner
 
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
 
 @pytest.fixture
 def sampling_config():
@@ -17,12 +21,12 @@ def model_name():
     return "wemaster/deepseek_mtp_main_random_bf16"
 
 
-def mtp_correctness(
-        sampling_config: SamplingParams,
-        model_name: str,
-        num_speculative_tokens: int,
-        graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
-):
+def mtp_correctness(sampling_config: SamplingParams,
+                    model_name: str,
+                    num_speculative_tokens: int,
+                    graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
+                    enforce_eager=False,
+                    disable_padded_drafter_batch=True):
     example_prompts = [
         "Hello, my name is",
         "The president of the United States is",
@@ -37,7 +41,7 @@ def mtp_correctness(
                     tensor_parallel_size=1,
                     gpu_memory_utilization=0.7,
                     max_model_len=256,
-                    enforce_eager=False) as ref_llm:
+                    enforce_eager=enforce_eager) as ref_llm:
         ref_outputs = ref_llm.generate(example_prompts, sampling_config)
 
     graph_mode_str = "PIECEWISE"
@@ -54,8 +58,9 @@ def mtp_correctness(
             speculative_config={
                 "method": "deepseek_mtp",
                 "num_speculative_tokens": num_speculative_tokens,
+                "disable_padded_drafter_batch": disable_padded_drafter_batch,
             },
-            enforce_eager=False,
+            enforce_eager=enforce_eager,
             max_model_len=2000,
             compilation_config=CompilationConfig(
                 cudagraph_mode=graph_mode_str),
@@ -82,6 +87,20 @@ def mtp_correctness(
     del spec_llm
 
 
+def test_mtp1_correctness_eager(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config, model_name, 1, enforce_eager=True)
+
+
+def test_mtp2_correctness_eager(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config, model_name, 2, enforce_eager=True)
+
+
 @pytest.mark.skip("TODO(cmq): Revert me when mtp aclgraph is fixed")
 def test_mtp1_correctness_piecewise_graph(
         sampling_config: SamplingParams,
@@ -110,3 +129,47 @@ def test_mtp2_correctness_full_graph(
         model_name: str,
 ):
     mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
+
+
+def test_mtp1_correctness_eager_with_pad(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config,
+                    model_name,
+                    1,
+                    enforce_eager=True,
+                    disable_padded_drafter_batch=False)
+
+
+def test_mtp2_correctness_eager_with_pad(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config,
+                    model_name,
+                    2,
+                    enforce_eager=True,
+                    disable_padded_drafter_batch=False)
+
+
+@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
+def test_mtp1_correctness_piecewise_graph_with_pad(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config,
+                    model_name,
+                    1,
+                    disable_padded_drafter_batch=False)
+
+
+@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
+def test_mtp2_correctness_piecewise_graph_with_pad(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config,
+                    model_name,
+                    2,
+                    disable_padded_drafter_batch=False)

View File

@@ -19,14 +19,21 @@
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
+from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
 
 
-def get_spec_decode_method(method, vllm_config, device, runner):
+def get_spec_decode_method(method,
+                           vllm_config,
+                           device,
+                           runner,
+                           is_torchair_graph=False):
     if method == "ngram":
         return NgramProposer(vllm_config, device, runner)
     elif method in ["eagle", "eagle3"]:
         return EagleProposer(vllm_config, device, runner)
     elif method == 'deepseek_mtp':
+        if is_torchair_graph:
+            return TorchairMtpProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "

View File

@@ -1,37 +1,41 @@
import types from typing import Optional
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchair
from torchair import patch_for_hcom
from vllm.config import (CUDAGraphMode, VllmConfig, from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config) get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \ from vllm.model_executor.model_loader.utils import \
process_weights_after_loading process_weights_after_loading
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
TorchairCommonAttentionMetadata)
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
vllm_version_is) vllm_version_is)
if vllm_version_is("0.11.0"): if vllm_version_is("0.11.0"):
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.utils import is_pin_memory_available
else: else:
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
PADDING_SLOT_ID = -1 PADDING_SLOT_ID = -1
@@ -45,34 +49,77 @@ class MtpProposer(Proposer):
): ):
self.name = SpecDcodeType.MTP self.name = SpecDcodeType.MTP
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.device = device self.speculative_config = vllm_config.speculative_config
self.runner = runner assert self.speculative_config is not None
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
# persistent buffers for graph self.runner = runner
self.input_ids = torch.zeros(self.runner.max_num_tokens, self.device = device
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
self.draft_indexer_metadata_builder: Optional[
AttentionMetadataBuilder] = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
self.use_aclgraph = self.runner._use_aclgraph()
self.cudagraph_batch_sizes = (list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
if self.use_aclgraph else [])
# persistent buffers for aclgraph graph
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=device)
self.positions = torch.zeros(self.runner.max_num_tokens, self.uses_mrope = self.vllm_config.model_config.uses_mrope
if self.uses_mrope:
# M-RoPE need (3, max_num_tokens)
self.mrope_positions = torch.zeros((3, self.max_num_tokens),
dtype=torch.int64, dtype=torch.int64,
device=self.device) device=device)
else:
# RoPE need (max_num_tokens,)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros( self.hidden_states = torch.zeros(
(self.runner.max_num_tokens, (self.max_num_tokens, self.hidden_size),
vllm_config.model_config.get_hidden_size()), dtype=self.dtype,
dtype=self.runner.dtype, device=device)
device=self.device)
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
self.enable_shared_expert_dp = get_ascend_config(
).enable_shared_expert_dp
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + max_batch_size = vllm_config.scheduler_config.max_num_seqs
1, max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
device=self.runner.device, self.arange = torch.arange(max_num_slots_for_arange,
device=device,
dtype=torch.int32) dtype=torch.int32)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
self.use_sparse = hasattr(vllm_config.model_config.hf_config, self.use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk") "index_topk")
@@ -89,12 +136,6 @@ class MtpProposer(Proposer):
with set_default_torch_dtype( with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config( draft_model_config.dtype), set_current_vllm_config(
self.vllm_config): self.vllm_config):
if self.torchair_graph_enabled or (
self.enable_shared_expert_dp
and self.vllm_config.model_config.use_mla):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:
self.model = DeepSeekMTP( self.model = DeepSeekMTP(
vllm_config=self.vllm_config).to(target_device) vllm_config=self.vllm_config).to(target_device)
@@ -121,7 +162,7 @@ class MtpProposer(Proposer):
num_tokens_across_dp=None, num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None) -> None: batch_descriptor=None) -> None:
if not self.torchair_graph_enabled:
( (
num_tokens, num_tokens,
num_tokens_across_dp, num_tokens_across_dp,
@@ -131,24 +172,7 @@ class MtpProposer(Proposer):
moe_comm_type = self.runner._select_moe_comm_method( moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill) num_tokens, with_prefill)
is_running_torchair = self.torchair_graph_enabled and \
not with_prefill
if is_running_torchair:
skip_attn = False
if skip_attn:
attn_metadata = None attn_metadata = None
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
num_actual_tokens=1,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
decode_token_per_req=self.runner.decode_token_per_req,
)
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
common_attn_metadata)
input_ids = self.input_ids[:num_tokens] input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens] positions = self.positions[:num_tokens]
@@ -166,32 +190,6 @@ class MtpProposer(Proposer):
num_actual_tokens=0, num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode, aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor): batch_descriptor=batch_descriptor):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids, self.model(input_ids=input_ids,
positions=positions, positions=positions,
hidden_states=previous_hidden_states) hidden_states=previous_hidden_states)
@@ -199,7 +197,7 @@ class MtpProposer(Proposer):
break break
def generate_token_ids(self, def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]], sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata = None, sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None, scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None, spec_decode_metadata: SpecDecodeMetadata = None,
@@ -208,235 +206,240 @@ class MtpProposer(Proposer):
hidden_states: torch.Tensor = None, hidden_states: torch.Tensor = None,
attn_metadata=None, attn_metadata=None,
aux_hidden_states: torch.Tensor = None): aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict): if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids): if self.speculative_config.disable_padded_drafter_batch:
if token_ids: # When padded-batch is disabled, the sampled_token_ids should be
# Common case. # the cpu-side list[list[int]] of valid sampled tokens for each
next_token_id = token_ids[-1] # request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
next_token_ids = self.prepare_next_token_ids_cpu(
sampled_token_ids, self.runner.requests,
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
else: else:
# Partial prefill (rare case). # When using padded-batch, the sampled_token_ids should be
# Get the next token id from the request state. # the gpu tensor of sampled tokens for each request, of shape
req_id = self.runner.input_batch.req_ids[i] # (num_reqs, num_spec_tokens + 1) with rejected tokens having
req_state = self.runner.requests[req_id] # value -1.
seq_len = (req_state.num_computed_tokens + assert isinstance(sampled_token_ids, torch.Tensor), \
scheduler_output.num_scheduled_tokens[req_id]) "sampled_token_ids should be a torch.Tensor when" \
next_token_id = req_state.get_token_id(seq_len) "padded-batch is enabled."
next_token_ids.append(next_token_id) next_token_ids, valid_sampled_tokens_count = \
next_token_ids = torch.tensor(next_token_ids, self.prepare_next_token_ids_padded(
dtype=torch.int32, common_attn_metadata,
device=self.device) sampled_token_ids,
accepted_token_indices = None self.runner.requests,
self.runner.input_batch,
self.runner.discard_request_indices.gpu,
self.runner.num_discarded_requests
)
if spec_decode_metadata is None: if spec_decode_metadata is None:
token_indices_to_sample = None
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids[:num_scheduled_tokens] target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else: else:
# TODO(woosuk): Refactor this. if self.speculative_config.disable_padded_drafter_batch:
num_draft_tokens = spec_decode_metadata.num_draft_tokens token_indices_to_sample = None
num_rejected_tokens = [ common_attn_metadata, token_indices =\
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 self._prepare_inputs(
for i, n in enumerate(num_draft_tokens) common_attn_metadata,
] sampled_token_ids,
num_rejected_tokens = torch.tensor( spec_decode_metadata.num_draft_tokens)
num_rejected_tokens, else:
dtype=torch.int32, common_attn_metadata, token_indices, \
device=self.device, token_indices_to_sample =\
) self.prepare_inputs_padded(
cu_num_tokens, accepted_token_indices, target_token_ids, \ common_attn_metadata,
target_positions, target_hidden_states, target_slot_mapping = self._prepare_inputs( spec_decode_metadata,
attn_metadata.query_start_loc, valid_sampled_tokens_count)
num_rejected_tokens, target_token_ids = self.runner.input_ids[token_indices]
self.runner.input_ids[:num_scheduled_tokens], target_positions = positions[token_indices]
positions[:num_scheduled_tokens], target_hidden_states = hidden_states[token_indices]
hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens],
is_torchair_graph=self.runner._build_drafter_prepare_inputs_torchair_param(),
)
draft_token_ids = self._propose( draft_token_ids = self._propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens, last_token_indices=token_indices_to_sample,
block_table=attn_metadata.block_tables, common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
token_indices=accepted_token_indices) )
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids return draft_token_ids
def _prepare_inputs( def _prepare_inputs(
self, self,
# [batch_size + 1] common_attn_metadata: CommonAttentionMetadata,
cu_target_query_lens: torch.Tensor, sampled_token_ids: list[list[int]],
# [batch_size] num_draft_tokens: list[int],
num_rejected_tokens: torch.Tensor, ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
token_ids: torch.Tensor, """
positions: torch.Tensor, This function is used to prepare the inputs for speculative decoding.
hidden_states: torch.Tensor, It updates to the common_attn_metadata to account for the rejected
slot_mapping: torch.Tensor, tokens (and newly sampled tokens). It also returns the token indices
is_torchair_graph: bool = False of the tokens that should be fed to the speculator.
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, """
torch.Tensor, torch.Tensor]: # E.g.
# cu_target_query_lens: [0, a, a + b, a + b + c] # common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3] # num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3] # This function computes the intermediate values:
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# token_indices: [0, 1, ..., a - n1 - 1, # And returns:
# a, a + 1, ..., a + b - n2 - 1, # common_attn_metadata.query_start_loc{_cpu}:
# a + b, a + b + 1, ..., a + b + c - n3 - 1] # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# [0, a, a + b, a + b + c] -> [a, b, c] # common_attn_metadata.seq_lens{_cpu}:
query_len_per_req = (cu_target_query_lens[1:] - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
cu_target_query_lens[:-1]) # token_indices: [0, 1, ..., q1 - n1 - 1,
# [a, b, c] -> [a - n1, b - n2, c - n3] # q1, q1 + 1, ..., q1 + q2 - n2 - 1,
num_tokens_per_req = query_len_per_req - num_rejected_tokens # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
if is_torchair_graph:
cu_num_tokens = cu_target_query_lens
relative_index = query_len_per_req - num_rejected_tokens - 1
token_indices = cu_num_tokens[:-1] + relative_index
# the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model
target_token_ids = token_ids
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = slot_mapping
else:
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization. num_rejected_tokens = [
num_tokens = cu_num_tokens[-1].item() n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
token_indices = torch.zeros( for i, n in enumerate(num_draft_tokens)
num_tokens, ]
num_rejected_tokens = torch.tensor(num_rejected_tokens,
dtype=torch.int32)
device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = query_start_loc_cpu[
1:] - query_start_loc_cpu[:-1]
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
# [q1 - n1, q2 - n2, q3 - n3] ->
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32, dtype=torch.int32,
device=cu_num_tokens.device, pin_memory=is_pin_memory_available(),
) )
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
BLOCK_SIZE = 1024 total_num_tokens = new_query_start_loc_np[-1]
self._prepare_input_kernel( # Example assuming num_tokens_per_req_np = [2, 4, 3]
token_indices, # this implies that `new_query_start_locs` is:
cu_target_query_lens, # [0, 2, 6, 9] ->
cu_num_tokens, # [0, 0, 2, 2, 2, 2, 6, 6, 6]
block_size=BLOCK_SIZE, # _r1_ ____r2____ ___r3__
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
new_num_tokens_per_req_np)
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
token_offests = (self.token_arange_np[:total_num_tokens] -
new_query_start_locs_expanded)
# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
# _r1_ _____r2_______ ___________r3____________
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu.to(device,
non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
seq_lens_cpu=new_seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
positions=common_attn_metadata.positions[token_indices],
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
) )
target_token_ids = token_ids[token_indices] return spec_common_attn_metadata, token_indices
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = slot_mapping[token_indices]
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
def _propose( def _propose(
self, self,
# [num_tokens] # [num_tokens]
target_token_ids: torch.Tensor, target_token_ids: torch.Tensor,
# [num_tokens] # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
target_positions: torch.Tensor, target_positions: torch.Tensor,
# [num_tokens, hidden_size] # [num_tokens, hidden_size]
target_hidden_states: torch.Tensor, target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size] # [batch_size]
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0 last_token_indices: Optional[torch.Tensor],
cu_num_tokens: torch.Tensor, common_attn_metadata: CommonAttentionMetadata,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
token_indices=None) -> torch.Tensor: mm_embed_inputs: Optional[tuple[list[torch.Tensor],
torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token. # Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:] self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token. # Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
if token_indices is not None and self.torchair_graph_enabled:
last_token_indices = token_indices
self.input_ids[last_token_indices] = next_token_ids self.input_ids[last_token_indices] = next_token_ids
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] assert self.runner is not None
max_query_len = query_lens.max().item()
# FIXME: reorder_batch() needs to be called before build() builder = self.runner.attn_groups[0][0].get_metadata_builder()
# because fields of attn_metadata_builder needs to be updated. attn_metadata_mtp = builder.build(0, common_attn_metadata,
# However, currently reorder_batch() takes input_batch and self.runner.get_model())
# scheduler_output as arguments, we should probably refactor attn_metadata = {}
# the method to use new data structures which are independent for layer_name in self.attn_layer_name:
# from input_batch and scheduler_output. attn_metadata[layer_name] = attn_metadata_mtp
# self.runner.attn_metadata_builder.reorder_batch(
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
is_running_torchair = self.torchair_graph_enabled and \
not self.runner.with_prefill
if is_running_torchair: if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
# Torchair graph mode, padding is same as the main model
num_input_tokens = self.runner.graph_pad_size
elif (self.runner.use_aclgraph
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
# Acl graph mode, add padding to the batch size # Acl graph mode, add padding to the batch size
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else: else:
# Eager mode, no padding needed # Eager mode, no padding needed
num_input_tokens = num_tokens num_input_tokens = num_tokens
seq_lens = target_positions[last_token_indices] + 1 # copy inputs to buffer for cudagraph
seq_lens = seq_lens.int()
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=cu_num_tokens[:batch_size + 1],
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
seq_lens_cpu=seq_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
if not self.torchair_graph_enabled:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata_mtp = builder.build(0, common_attn_metadata,
self.runner.get_model())
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
else:
attn_metadata = self.runner.attn_metadata_builder.build(
0, common_attn_metadata, self.runner.get_model())
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
# eager/acl piecewise mode need to update num_tokens_across_dp
if not self.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp
(num_input_tokens, num_tokens_across_dp, (num_input_tokens, num_tokens_across_dp,
with_prefill) = self.runner._sync_metadata_across_dp( with_prefill) = self.runner._sync_metadata_across_dp(
num_input_tokens, self.runner.with_prefill) num_input_tokens, self.runner.with_prefill)
else:
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
moe_comm_type = self.runner._select_moe_comm_method( moe_comm_type = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill) num_input_tokens, with_prefill)
@@ -444,6 +447,15 @@ class MtpProposer(Proposer):
uniform_decode=False) uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \ aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
if aclgraph_runtime_mode not in [
CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE
]:
# Fallback to piecewise graph, when acl full graph is enabled
logger.debug(
"Currently the eagle proposer only supports cudagraph_mode "
f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
"to CUDAGraphMode.PIECEWISE")
aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
for step in range(self.num_speculative_tokens): for step in range(self.num_speculative_tokens):
with set_ascend_forward_context( with set_ascend_forward_context(
@@ -461,26 +473,11 @@ class MtpProposer(Proposer):
with ProfileExecuteDuration().capture_async('mtp_forward'): with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {} model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model( hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens], input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens], positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens] hidden_states=self.hidden_states[:num_input_tokens])
)
num_indices = last_token_indices.shape[0] num_indices = last_token_indices.shape[0]
if lmhead_tp_enable(): if lmhead_tp_enable():
@@ -515,10 +512,7 @@ class MtpProposer(Proposer):
if step == self.num_speculative_tokens - 1 or with_prefill: if step == self.num_speculative_tokens - 1 or with_prefill:
break break
if not self.torchair_graph_enabled:
attn_metadata_i = attn_metadata[self.attn_layer_name[0]] attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
else:
attn_metadata_i = attn_metadata
if step == 0: if step == 0:
positions = target_positions[last_token_indices] positions = target_positions[last_token_indices]
@@ -529,21 +523,16 @@ class MtpProposer(Proposer):
last_token_indices = self.arange[:batch_size] last_token_indices = self.arange[:batch_size]
if attn_metadata_i.num_decode_tokens != 0: if attn_metadata_i.num_decode_tokens != 0:
attn_metadata_i.num_decode_tokens = batch_size attn_metadata_i.num_decode_tokens = batch_size
if is_running_torchair:
attn_metadata_i.num_actual_tokens = batch_size
attn_metadata_i.query_lens = [1] * batch_size
input_ids = draft_token_ids_list[-1].int() input_ids = draft_token_ids_list[-1].int()
positions += 1 positions += 1
if not self.torchair_graph_enabled:
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist() 1:batch_size + 1].tolist()
attn_metadata_i.decode.cos = builder.cos_cache[ attn_metadata_i.decode.cos = builder.cos_cache[
positions].unsqueeze(1).unsqueeze(2) positions].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[ attn_metadata_i.decode.sin = builder.sin_cache[
positions].unsqueeze(1).unsqueeze(2) positions].unsqueeze(1).unsqueeze(2)
# NOTE(woosuk): We should handle the case where the draft model # NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex # generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch # to remove such requests from the batch, we keep them in the batch
@@ -601,61 +590,6 @@ class MtpProposer(Proposer):
draft_token_ids = torch.stack(draft_token_ids_list, dim=1) draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids return draft_token_ids
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.runner.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.use_sparse,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=not self.use_sparse,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
# TODO Using torch instead of triton may result in poor performance # TODO Using torch instead of triton may result in poor performance
def _prepare_input_kernel(self, out_ptr: torch.Tensor, def _prepare_input_kernel(self, out_ptr: torch.Tensor,
cu_query_lens: torch.Tensor, cu_query_lens: torch.Tensor,
@@ -676,3 +610,160 @@ class MtpProposer(Proposer):
global_indices_flat = global_indices[mask] global_indices_flat = global_indices[mask]
values_flat = values[mask] values_flat = values[mask]
out_ptr[global_indices_flat] = values_flat out_ptr[global_indices_flat] = values_flat
def prepare_next_token_ids_cpu(
self,
sampled_token_ids: list[list[int]],
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
num_scheduled_tokens: dict[str, int],
) -> torch.Tensor:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids for each request based on the sampled
token ids from the CPU. If a request has no sampled token ids (e.g.,
during the initial decoding steps), it falls back to using the request
state to get the next token id.
"""
req_ids = gpu_input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = req_ids[i]
req_state = requests[req_id]
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
req_id]
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.input_ids.device)
return next_token_ids
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array([
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item())
for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = discard_request_indices[:
num_discarded_requests]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
valid_sampled_token_ids_gpu.index_fill_(
0, discard_sampled_tokens_req_indices, -1)
# Generate a mask for all valid tokens within those requests
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1,
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1,
selected_tokens,
self.backup_next_token_ids.gpu[:batch_size],
)
return next_token_ids, valid_sampled_tokens_count
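
As a self-contained sketch of the selection logic in `prepare_next_token_ids_padded` above (all tensor values and the vocab size are made up for illustration, run on CPU for clarity):

```
import torch

# Rows = requests; -1 marks rejected / invalid sampled tokens.
sampled = torch.tensor([[11, 12, -1],   # 2 valid tokens -> take 12
                        [-1, -1, -1],   # no valid token -> fall back to backup
                        [21, 22, 23]])  # 3 valid tokens -> take 23
backup_next_token_ids = torch.tensor([7, 8, 9])  # precomputed fallback ids
vocab_size = 1000

valid_mask = (sampled != -1) & (sampled < vocab_size)
valid_sampled_tokens_count = valid_mask.sum(dim=1)      # [2, 0, 3]
last_valid_indices = valid_sampled_tokens_count - 1     # [1, -1, 2]
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
selected = torch.gather(sampled, 1,
                        last_valid_indices_safe.unsqueeze(1)).squeeze(1)
next_token_ids = torch.where(last_valid_indices != -1, selected,
                             backup_next_token_ids)
print(next_token_ids)  # tensor([12,  8, 23])
```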
def prepare_inputs_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor,
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_draft_tokens_gpu = torch.cat([
spec_decode_metadata.cu_num_draft_tokens[0:1],
spec_decode_metadata.cu_num_draft_tokens[1:] -
spec_decode_metadata.cu_num_draft_tokens[:-1],
])
num_rejected_tokens_gpu = torch.where(
num_draft_tokens_gpu > 0,
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
torch.zeros_like(num_draft_tokens_gpu),
)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_query_len_per_req = query_start_loc_cpu[
1:] - query_start_loc_cpu[:-1]
total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens]
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=common_attn_metadata.seq_lens,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
positions=common_attn_metadata.positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
seq_lens=common_attn_metadata.seq_lens)
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
1 - num_rejected_tokens_gpu)
return spec_common_attn_metadata, token_indices, token_indices_to_sample
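
To make the index arithmetic in `_prepare_inputs` easier to follow in isolation, here is a small NumPy sketch of the same rejected-token filtering; the query lengths and rejection counts are illustrative values, not taken from the diff:

```
import numpy as np

# Three requests with query lengths [3, 5, 4] and one rejected draft token each.
query_start_loc = np.array([0, 3, 8, 12])   # [0, q1, q1+q2, q1+q2+q3]
num_rejected_tokens = np.array([1, 1, 1])   # [n1, n2, n3]

query_len_per_req = np.diff(query_start_loc)                       # [3, 5, 4]
new_num_tokens_per_req = query_len_per_req - num_rejected_tokens   # [2, 4, 3]

new_query_start_loc = np.zeros_like(query_start_loc)
np.cumsum(new_num_tokens_per_req, out=new_query_start_loc[1:])     # [0, 2, 6, 9]
total_num_tokens = new_query_start_loc[-1]

token_arange = np.arange(total_num_tokens)
# Per-token offset inside its own (shrunken) request ...
new_starts_expanded = np.repeat(new_query_start_loc[:-1], new_num_tokens_per_req)
token_offsets = token_arange - new_starts_expanded
# ... added to the request's start in the original flattened layout.
old_starts_expanded = np.repeat(query_start_loc[:-1], new_num_tokens_per_req)
token_indices = token_offsets + old_starts_expanded
print(token_indices)  # [ 0  1  3  4  5  6  8  9 10]
```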

View File

@@ -34,6 +34,7 @@ from vllm.logger import logger
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.torchair.utils import (
     TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
     check_torchair_cache_exist, converting_weight_acl_format,
@@ -83,6 +84,20 @@ class NPUTorchairModelRunner(NPUModelRunner):
         self._check_batch_sizes_consistency()
 
+    def _set_up_drafter(self):
+        super()._set_up_drafter()
+        if self.speculative_config:
+            # Torchair do not support disable_padded_drafter_batch
+            # Enforce to disable this feature
+            self.speculative_config.disable_padded_drafter_batch = True
+
+    def _get_drafter(self):
+        return get_spec_decode_method(self.speculative_config.method,
+                                      self.vllm_config,
+                                      self.device,
+                                      self,
+                                      is_torchair_graph=True)
+
     def _may_pad_kv_consumer_num_seq(self):
         # pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
         # self.max_num_reqs here is greater than the actual maximum request number

View File

@@ -0,0 +1,554 @@
import types
import torch
import torch.nn as nn
import torchair
from torchair import patch_for_hcom
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode import MtpProposer
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
TorchairCommonAttentionMetadata)
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
vllm_version_is)
if vllm_version_is("0.11.0"):
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
else:
from vllm.utils.torch_utils import set_default_torch_dtype
PADDING_SLOT_ID = -1
class TorchairMtpProposer(MtpProposer):
def __init__(
self,
vllm_config: VllmConfig,
device,
runner,
):
super().__init__(vllm_config, device, runner)
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase).keys())
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
target_device = self.vllm_config.device_config.device
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
@torch.inference_mode()
def dummy_run(self,
num_tokens: int,
with_prefill: bool = False,
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None) -> None:
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
if not with_prefill:
skip_attn = False
if skip_attn:
attn_metadata = None
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
num_actual_tokens=1,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
decode_token_per_req=self.runner.decode_token_per_req,
)
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
common_attn_metadata)
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
for _ in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if not with_prefill:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states)
if with_prefill:
break
def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.runner.input_batch.req_ids[i]
req_state = self.runner.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
accepted_token_indices = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, accepted_token_indices, target_token_ids, \
target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
self.runner.input_ids[:num_scheduled_tokens],
positions[:num_scheduled_tokens],
hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens],
)
draft_token_ids = self._propose_torchair(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
token_indices=accepted_token_indices)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
def _torchair_prepare_inputs(
self,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
token_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
slot_mapping: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# In the padded (torchair) path the per-request query lengths are kept as
# scheduled, so cu_num_tokens stays equal to cu_target_query_lens and only
# token_indices is computed, pointing at the last accepted token of each
# request: [a - n1 - 1, a + b - n2 - 1, a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
cu_num_tokens = cu_target_query_lens
relative_index = query_len_per_req - num_rejected_tokens - 1
token_indices = cu_num_tokens[:-1] + relative_index
# The query of each request is padded to 1 + num_speculative_tokens tokens,
# so the drafter consumes the same inputs as the main model.
target_token_ids = token_ids
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = slot_mapping
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
def _propose_torchair(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
token_indices=None) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
if token_indices is not None:
last_token_indices = token_indices
self.input_ids[last_token_indices] = next_token_ids
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
# FIXME: reorder_batch() needs to be called before build()
# because fields of attn_metadata_builder needs to be updated.
# However, currently reorder_batch() takes input_batch and
# scheduler_output as arguments, we should probably refactor
# the method to use new data structures which are independent
# from input_batch and scheduler_output.
# self.runner.attn_metadata_builder.reorder_batch(
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
if not self.runner.with_prefill:
# Torchair graph mode: pad to the same size as the main model.
num_input_tokens = self.runner.graph_pad_size
elif (self.runner.use_aclgraph
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
# ACL graph mode: pad the token count up to the captured graph size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
# Eager mode: no padding needed.
num_input_tokens = num_tokens
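# The sequence length of each request is the position of its last accepted
# token plus one.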
seq_lens = target_positions[last_token_indices] + 1
seq_lens = seq_lens.int()
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=cu_num_tokens[:batch_size + 1],
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
seq_lens_cpu=seq_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
num_computed_tokens_cpu=None,
seq_lens=None)
attn_metadata = self.runner.attn_metadata_builder.build(
0, common_attn_metadata, self.runner.get_model())
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
moe_comm_type = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
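# Propose num_speculative_tokens drafts autoregressively, feeding each draft
# back as the input of the next step.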
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
aclgraph_runtime_mode=aclgraph_runtime_mode,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if not self.runner.with_prefill:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens]
)
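# With lm-head tensor parallelism the gather indices are padded so that all
# ranks compute logits over the same number of rows, then trimmed below.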
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices,
(0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if step == 0:
draft_token_ids_list = [draft_token_ids]
else:
draft_token_ids_list.append(draft_token_ids)
# Prepare inputs for the next MTP step.
# For num_speculative_tokens > 1: prefill batches run a single drafter step
# (the remaining drafts are replicated just below), and decode batches skip
# the loop body after the last step.
if with_prefill:
for _ in range(self.num_speculative_tokens - 1):
draft_token_ids_list.append(draft_token_ids)
if step == self.num_speculative_tokens - 1 or with_prefill:
break
attn_metadata_i = attn_metadata
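# From the second step on the drafter runs pure decode: one query token per
# request, so shrink the metadata to a batch_size-token input.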
if step == 0:
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
attn_metadata_i.slot_mapping.fill_(-1)
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata_i.num_decode_tokens != 0:
attn_metadata_i.num_decode_tokens = batch_size
if not self.runner.with_prefill:
attn_metadata_i.num_actual_tokens = batch_size
attn_metadata_i.query_lens = [1] * batch_size
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata_i.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata_i.seq_lens.device, non_blocking=True)
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping += 1
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:hidden_states.shape[0]] = hidden_states
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
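# Mirror the updated sequence lengths and positions into the prefill/decode
# sub-metadata so the next step sees a consistent view.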
if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
)
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.prefill.max_seq_lens += 1
attn_metadata_i.prefill.max_seq_lens = min(
attn_metadata_i.prefill.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata_i.decode is not None:
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
)
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens,
self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
)
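# With cached NPU graphs each batch size gets its own compiled entry;
# otherwise a single lazily compiled model is reused for all sizes.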
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.runner.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.use_sparse,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object so that dynamo retracing does not
# invalidate the compilation cache.
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=not self.use_sparse,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]

View File

@@ -133,6 +133,7 @@ from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
+from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
enable_sp, get_ascend_soc_version, is_310p,
@@ -369,32 +370,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask_builder = AttentionMaskBuilder(
self.model_config.max_model_len, self.dtype)
-# Set up speculative decoding.
-self.spec_attn_mask = None
-self.drafter: Optional[Union[NgramProposer, EagleProposer,
-MtpProposer]] = None
-self.actual_seq_lengths_q: list[int] = []
-self.decode_token_per_req = 1
-if self.speculative_config:
-spec_token_num = self.speculative_config.num_speculative_tokens
-assert spec_token_num > 0
-self.decode_token_per_req = 1 + spec_token_num
-self.spec_attn_mask = torch.triu(torch.ones(2048,
-2048,
-dtype=torch.bool),
-diagonal=1).to(self.device)
-if get_pp_group().is_last_rank:
-self.drafter = get_spec_decode_method(
-self.speculative_config.method, self.vllm_config,
-self.device, self)
-if vllm_version_is("0.11.0"):
-self.rejection_sampler = AscendRejectionSampler()
-else:
-self.rejection_sampler = AscendRejectionSampler(
-self.sampler)
-self.actual_seq_lengths_q = list(
-range(self.decode_token_per_req, self.max_num_tokens + 1,
-self.decode_token_per_req))
+self._set_up_drafter()
# kv role
self.is_kv_producer = False
@@ -590,6 +566,39 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO: EVS Support (Video tokens pruning) (see vllm#22980)
self.is_multimodal_pruning_enabled = False
+def _set_up_drafter(self):
+# Set up speculative decoding.
+self.spec_attn_mask = None
+self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
+TorchairMtpProposer]] = None
+self.actual_seq_lengths_q: list[int] = []
+self.decode_token_per_req = 1
+if self.speculative_config:
+spec_token_num = self.speculative_config.num_speculative_tokens
+assert spec_token_num > 0
+self.decode_token_per_req = 1 + spec_token_num
+self.spec_attn_mask = torch.triu(torch.ones(2048,
+2048,
+dtype=torch.bool),
+diagonal=1).to(self.device)
+if get_pp_group().is_last_rank:
+self.drafter = self._get_drafter()
+if vllm_version_is("0.11.0"):
+self.rejection_sampler = AscendRejectionSampler()
+else:
+self.rejection_sampler = AscendRejectionSampler(
+self.sampler)
+self.actual_seq_lengths_q = list(
+range(self.decode_token_per_req, self.max_num_tokens + 1,
+self.decode_token_per_req))
+self.discard_request_indices = self._make_buffer(self.max_num_reqs,
+dtype=torch.int64)
+self.num_discarded_requests = 0
+def _get_drafter(self):
+return get_spec_decode_method(self.speculative_config.method,
+self.vllm_config, self.device, self)
def _may_pad_kv_consumer_num_seq(self):
# For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
# we may want to pad self.max_num_seqs in kv_consumer nodes to avoid
@@ -609,7 +618,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
tp_size = self.parallel_config.tensor_parallel_size
# Use integer arithmetic for ceiling division.
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+self.mc2_tokens_capacity: int = num_tokens_per_tp_rank * tp_size
def _make_buffer(self,
*size: Union[int, torch.SymInt],
@@ -1522,6 +1531,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
attn_metadata: dict[str, Any] = {}
+# Record the index of requests that should not be sampled,
+# so that we could clear the sampled tokens before returning
+num_tokens = [
+self.requests[r].num_tokens for r in self.input_batch.req_ids
+]
+num_tokens_np = np.array(num_tokens, dtype=np.int32)
+num_reqs = self.input_batch.num_reqs
+discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+discard_request_indices = np.nonzero(discard_requests_mask)[0]
+self.num_discarded_requests = len(discard_request_indices)
+self.discard_request_indices.np[:self.num_discarded_requests] = (
+discard_request_indices)
+self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
# _prepare_inputs may reorder the batch, so we must gather
# multi-modal outputs after that to ensure the correct order
if self.is_multimodal_model:
@@ -1615,7 +1638,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
-spec_decode_common_attn_metadata = None
+self.spec_decode_common_attn_metadata = None
if use_spec_decode and self.need_accepted_tokens:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
@@ -1676,7 +1699,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-seq_lens_cpu=self.seq_lens_cpu,
+seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=slot_mapping_size,
@@ -1700,8 +1723,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
if self.speculative_config and \
-spec_decode_common_attn_metadata is None:
-spec_decode_common_attn_metadata = common_attn_metadata
+self.spec_decode_common_attn_metadata is None:
+self.spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]:
common_prefix_len = 0
@@ -1998,7 +2021,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def propose_draft_token_ids(
self,
-valid_sampled_token_ids: list[list[int]],
+valid_sampled_token_ids: Union[torch.Tensor, list[list[int]]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
@@ -2255,6 +2278,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits = self.apply_grammar_bitmask(
scheduler_output, logits)
+with ProfileExecuteDuration().capture_async("Sample"):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
@@ -2296,21 +2320,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)
-discard_sampled_tokens_req_indices: list[int] = []
-# TODO(woosuk): The following loop can be slow since it iterates over
-# the requests one by one. Optimize.
-discard_sampled_tokens_req_indices = []
-for i, req_id in enumerate(self.input_batch.req_ids):
-req_state = self.requests[req_id]
-seq_len = (req_state.num_computed_tokens +
-scheduler_output.num_scheduled_tokens[req_id])
-if seq_len < req_state.num_tokens:
-# Ignore the sampled token.
-# Rewind the generator state as if the token was not sampled.
-generator = self.input_batch.generators.get(i)
+discard_sampled_tokens_req_indices = \
+self.discard_request_indices.np[:self.num_discarded_requests]
+for i in discard_sampled_tokens_req_indices:
+generator = self.input_batch.generators.get(int(i))
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
-discard_sampled_tokens_req_indices.append(i)
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
@@ -2346,10 +2361,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
-valid_sampled_token_ids[i].clear()
+valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
-invalid_req_indices = list(discard_sampled_tokens_req_indices)
+invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
+)
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1
@@ -2394,9 +2410,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
-if self.speculative_config:
+def propose_draft_token_ids(sampled_token_ids):
+assert self.spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
-valid_sampled_token_ids,
+sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
@@ -2407,6 +2424,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states,
)
+with ProfileExecuteDuration().capture_async("Draft"):
+if self.speculative_config:
+use_padded_batch_for_eagle = self.speculative_config and \
+self.speculative_config.method == "deepseek_mtp" and \
+not self.speculative_config.disable_padded_drafter_batch
+if use_padded_batch_for_eagle:
+# EAGLE speculative decoding can use the GPU sampled tokens
+# as inputs, and does not need to wait for bookkeeping to finish.
+propose_draft_token_ids(sampler_output.sampled_token_ids)
+if self.speculative_config and not use_padded_batch_for_eagle:
+# ngram and other speculative decoding methods use the sampled
+# tokens on the CPU, so they are run after bookkeeping.
+propose_draft_token_ids(valid_sampled_token_ids)
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()

View File

@@ -92,8 +92,10 @@ class CachedRequestState:
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
-else:
+elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
+else:
+return -1
class InputBatch: