[FEAT] Refactor spec decode to support efficient padded speculation (#3528)

### What this PR does / why we need it?
1. Refactor `mtp_proposer.py`, splitting the torchair-related code out into
`torchair_mtp_proposer.py`.
2. Following https://github.com/vllm-project/vllm/pull/24539, implement
padded speculative decoding as described in
https://github.com/vllm-project/vllm/issues/21984 (see the sketch below).
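
To make the behavior change easier to follow, here is a minimal sketch of the
new drafter dispatch. The names `disable_padded_drafter_batch`,
`propose_draft_token_ids`, `sampled_token_ids`, and `valid_sampled_token_ids`
come from the diff below; the helper function itself is illustrative, not the
actual implementation.
```
from typing import Optional


def dispatch_drafter(
    speculative_config: Optional[object],
    gpu_sampled_token_ids,       # GPU tensor produced by the sampler
    valid_sampled_token_ids,     # CPU list[list[int]] available after bookkeeping
    propose_draft_token_ids,     # callable that runs the draft (MTP) model
):
    """Pick the drafter input depending on padded vs. unpadded batches."""
    if speculative_config is None:
        return None
    use_padded_batch = (
        speculative_config.method == "deepseek_mtp"
        and not speculative_config.disable_padded_drafter_batch
    )
    if use_padded_batch:
        # Padded path: the drafter consumes the GPU-resident sampled tokens
        # directly and does not have to wait for CPU bookkeeping to finish.
        return propose_draft_token_ids(gpu_sampled_token_ids)
    # Unpadded path (and methods such as ngram): run after bookkeeping,
    # using the CPU-side valid sampled tokens.
    return propose_draft_token_ids(valid_sampled_token_ids)
```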
### Does this PR introduce _any_ user-facing change?
Users can set `disable_padded_drafter_batch` in `speculative_config` to
disable or enable padded speculation; the default is `False`, i.e. padded
speculation is enabled.

Offline example:
```
speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}
```
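
A fuller offline snippet might look like the following (the model path,
tensor-parallel size, prompt, and sampling parameters are illustrative
placeholders, not part of this PR):
```
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",     # placeholder model path
    tensor_parallel_size=16,             # placeholder parallel size
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
        # Set to True to fall back to the unpadded drafter batch.
        "disable_padded_drafter_batch": False,
    },
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```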

### How was this patch tested?

- [x] eager with pad/unpad
- [x] aclgraph with pad/unpad
- [x] torchair with pad/unpad

Performance test of deepseek-r1 with tp16, dp1:
- aclgraph with pad ITL: 168 ms
- aclgraph with unpad ITL: 169 ms
- original: 178 ms


- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>


@@ -133,6 +133,7 @@ from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
enable_sp, get_ascend_soc_version, is_310p,
@@ -369,32 +370,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask_builder = AttentionMaskBuilder(
self.model_config.max_model_len, self.dtype)
# Set up speculative decoding.
self.spec_attn_mask = None
self.drafter: Optional[Union[NgramProposer, EagleProposer,
MtpProposer]] = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
self.spec_attn_mask = torch.triu(torch.ones(2048,
2048,
dtype=torch.bool),
diagonal=1).to(self.device)
if get_pp_group().is_last_rank:
self.drafter = get_spec_decode_method(
self.speculative_config.method, self.vllm_config,
self.device, self)
if vllm_version_is("0.11.0"):
self.rejection_sampler = AscendRejectionSampler()
else:
self.rejection_sampler = AscendRejectionSampler(
self.sampler)
self.actual_seq_lengths_q = list(
range(self.decode_token_per_req, self.max_num_tokens + 1,
self.decode_token_per_req))
self._set_up_drafter()
# kv role
self.is_kv_producer = False
@@ -590,6 +566,39 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO: EVS Support (Video tokens pruning) (see vllm#22980)
self.is_multimodal_pruning_enabled = False
def _set_up_drafter(self):
# Set up speculative decoding.
self.spec_attn_mask = None
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
TorchairMtpProposer]] = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
self.spec_attn_mask = torch.triu(torch.ones(2048,
2048,
dtype=torch.bool),
diagonal=1).to(self.device)
if get_pp_group().is_last_rank:
self.drafter = self._get_drafter()
if vllm_version_is("0.11.0"):
self.rejection_sampler = AscendRejectionSampler()
else:
self.rejection_sampler = AscendRejectionSampler(
self.sampler)
self.actual_seq_lengths_q = list(
range(self.decode_token_per_req, self.max_num_tokens + 1,
self.decode_token_per_req))
self.discard_request_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
self.num_discarded_requests = 0
def _get_drafter(self):
return get_spec_decode_method(self.speculative_config.method,
self.vllm_config, self.device, self)
def _may_pad_kv_consumer_num_seq(self):
# For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
# we may want to pad self.max_num_seqs in kv_consumer nodes to avoid
@@ -609,7 +618,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
tp_size = self.parallel_config.tensor_parallel_size
# Use integer arithmetic for ceiling division.
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
self.mc2_tokens_capacity: int = num_tokens_per_tp_rank * tp_size
def _make_buffer(self,
*size: Union[int, torch.SymInt],
@@ -1522,6 +1531,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
attn_metadata: dict[str, Any] = {}
# Record the index of requests that should not be sampled,
# so that we could clear the sampled tokens before returning
num_tokens = [
self.requests[r].num_tokens for r in self.input_batch.req_ids
]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
num_reqs = self.input_batch.num_reqs
discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
discard_request_indices = np.nonzero(discard_requests_mask)[0]
self.num_discarded_requests = len(discard_request_indices)
self.discard_request_indices.np[:self.num_discarded_requests] = (
discard_request_indices)
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
# _prepare_inputs may reorder the batch, so we must gather
# multi-modal outputs after that to ensure the correct order
if self.is_multimodal_model:
@@ -1615,7 +1638,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None
self.spec_decode_common_attn_metadata = None
if use_spec_decode and self.need_accepted_tokens:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
@@ -1676,7 +1699,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu,
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=slot_mapping_size,
@@ -1700,8 +1723,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
if self.speculative_config and \
spec_decode_common_attn_metadata is None:
spec_decode_common_attn_metadata = common_attn_metadata
self.spec_decode_common_attn_metadata is None:
self.spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]:
common_prefix_len = 0
@@ -1998,7 +2021,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def propose_draft_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
valid_sampled_token_ids: Union[torch.Tensor, list[list[int]]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
@@ -2255,6 +2278,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits = self.apply_grammar_bitmask(
scheduler_output, logits)
with ProfileExecuteDuration().capture_async("Sample"):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
@@ -2296,21 +2320,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)
discard_sampled_tokens_req_indices: list[int] = []
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices = []
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
discard_sampled_tokens_req_indices.append(i)
discard_sampled_tokens_req_indices = \
self.discard_request_indices.np[:self.num_discarded_requests]
for i in discard_sampled_tokens_req_indices:
generator = self.input_batch.generators.get(int(i))
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
@@ -2346,10 +2361,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = list(discard_sampled_tokens_req_indices)
invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
)
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1
@@ -2394,18 +2410,33 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
def propose_draft_token_ids(sampled_token_ids):
assert self.spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
scheduler_output.total_num_scheduled_tokens,
hidden_states,
attn_metadata,
aux_hidden_states,
)
with ProfileExecuteDuration().capture_async("Draft"):
if self.speculative_config:
self._draft_token_ids = self.propose_draft_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
scheduler_output.total_num_scheduled_tokens,
hidden_states,
attn_metadata,
aux_hidden_states,
)
use_padded_batch_for_eagle = self.speculative_config and \
self.speculative_config.method == "deepseek_mtp" and \
not self.speculative_config.disable_padded_drafter_batch
if use_padded_batch_for_eagle:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
if self.speculative_config and not use_padded_batch_for_eagle:
# ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()


@@ -92,8 +92,10 @@ class CachedRequestState:
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
else:
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
else:
return -1
class InputBatch: