[Feature] Add the MagicMTP speculative decoding acceleration algorithm (#5542)

### What this PR does / why we need it?

1. Introduces MagicMTP (from the paper "Block Verification Accelerates Speculative
Decoding"), which verifies the draft tokens jointly as a block so that the dependence
among multiple draft tokens is taken into account, improving the acceptance rate
without compromising accuracy. A minimal sketch of the acceptance rule follows this list.
2. Adds both Triton and PyTorch implementations, along with E2E test cases.
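
As a rough intuition for item 1: block verification accepts a prefix of draft tokens whenever the running product of the clamped target/draft probability ratios still dominates the running product of the uniform samples, instead of stopping at the first per-token rejection. Below is a minimal single-request sketch; the helper name and the numbers are illustrative only, not part of this PR:

```python
import torch


def block_verify_accept_len(target_p, draft_p, u):
    """Illustrative only: how many draft tokens one request keeps.

    target_p / draft_p / u are 1-D tensors over the draft positions holding
    the target probability, draft probability, and uniform sample per token.
    """
    ratio = (target_p / draft_p).clamp(max=1.0)
    # Compare running products rather than per-token ratios.
    pi = torch.cumprod(ratio, dim=-1)
    cu = torch.cumprod(u, dim=-1)
    ok = (draft_p > 0) & (pi >= cu)
    if not ok.any():
        return 0
    # Accept everything up to and including the last passing position.
    return int(ok.nonzero().max().item()) + 1


# Token-by-token verification would reject the third token (0.5/0.6 < 0.9),
# but the block test still accepts all three.
target_p = torch.tensor([0.9, 0.4, 0.5])
draft_p = torch.tensor([0.5, 0.8, 0.6])
u = torch.tensor([0.7, 0.3, 0.9])
print(block_verify_accept_len(target_p, draft_p, u))  # -> 3
```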

### Does this PR introduce _any_ user-facing change?
MagicMTP takes effect automatically when `num_speculative_tokens` >= 3; no additional flag is needed.
- vLLM version: v0.13.0
- vLLM main: 7157596103
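
As a usage note for the `num_speculative_tokens` threshold above, the feature keys off the existing speculative-decoding configuration. A hedged example (the model name and `method` value are placeholders for whatever MTP setup you already run):

```python
from vllm import LLM

llm = LLM(
    model="your/mtp-capable-model",    # placeholder
    speculative_config={
        "method": "deepseek_mtp",      # placeholder: keep your existing method
        "num_speculative_tokens": 3,   # >= 3 switches on the block-verify path
    },
)
```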

Signed-off-by: chenaoxuan <cax1165@163.com>
Aoxuan Chen committed 2026-01-08 09:15:55 +08:00 (via GitHub)
parent 481138e1d2, commit 8763953f56
3 changed files with 372 additions and 37 deletions


@@ -10,8 +10,9 @@ from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN,
 from vllm_ascend.ops.triton.reject_sample import (
     cal_grid_and_block_size, expand_triton,
-    rejection_greedy_sample_with_triton, rejection_random_sample_kernel,
-    sample_recovered_tokens_kernel)
+    rejection_greedy_sample_with_triton,
+    rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel, sample_recovered_tokens_kernel)
 from vllm_ascend.sample.sampler import apply_top_k_top_p
@@ -104,6 +105,9 @@ def rejection_sample(
     assert bonus_token_ids.is_contiguous()
     assert target_probs.shape == (num_tokens, vocab_size)
+    # When num_speculative_tokens>=3, using block verify.
+    using_block_verify = max_spec_len >= 3
     # Create output buffer.
     output_token_ids = torch.empty(
         (batch_size, max_spec_len + 1),
@@ -172,41 +176,74 @@ def rejection_sample(
         sampling_metadata,
         device,
     )
-    # Rejection sampling for random sampling requests.
-    if HAS_TRITON:
-        rejection_random_sample_kernel[(grid, )](
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            draft_probs,
-            target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs.to(torch.float32),
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            batch_size,
-            NO_DRAFT_PROBS=draft_probs is None,
-            BLOCK_SIZE=block_size,
-        )
+    if not using_block_verify:
+        # Rejection sampling for random sampling requests.
+        if HAS_TRITON:
+            rejection_random_sample_kernel[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                batch_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                BLOCK_SIZE=block_size,
+            )
+        else:
+            rejection_random_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                IS_NGRAM=draft_probs is None,
+                # num_warps=1,
+            )
     else:
-        rejection_random_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            draft_probs,
-            target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs,
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            IS_NGRAM=draft_probs is None,
-            # num_warps=1,
-        )
+        # MagicMTP: Improving acceptance rate with Block Verify.
+        if HAS_TRITON:
+            rejection_random_sample_block_verify_kernel[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                batch_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                BLOCK_SIZE=block_size,
+            )
+        else:
+            rejection_random_sample_block_verify_pytorch(output_token_ids,
+                                                         cu_num_draft_tokens,
+                                                         draft_token_ids,
+                                                         draft_probs,
+                                                         target_probs,
+                                                         bonus_token_ids,
+                                                         recovered_token_ids,
+                                                         uniform_probs,
+                                                         is_greedy,
+                                                         max_spec_len,
+                                                         vocab_size,
+                                                         IS_NGRAM=draft_probs
+                                                         is None)
     return output_token_ids
@@ -676,3 +713,86 @@ def sample_recovered_tokens_pytorch(
     recovered_ids = torch.argmax(prob_over_q, dim=1)
     output_token_ids[:] = recovered_ids
+
+
+def rejection_random_sample_block_verify_pytorch(
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    draft_probs,  # [num_tokens, vocab_size] or None
+    target_probs,  # [num_tokens, vocab_size]
+    bonus_token_ids,  # [batch_size]
+    recovered_token_ids,  # [num_tokens]
+    uniform_probs,  # [num_tokens]
+    is_greedy,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    IS_NGRAM=False,
+):
+    batch_size = output_token_ids.shape[0]
+    device = output_token_ids.device
+
+    zero_cpu = torch.tensor([0], pin_memory=True)
+    zero_device = zero_cpu.to(device, non_blocking=True)
+    cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]])
+    cu_end = cu_num_draft_tokens
+    num_draft_per_batch = (cu_end - cu_start)[:, None]
+
+    pos_indices_cpu = torch.arange(max_spec_len, pin_memory=True)
+    pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
+    valid_mask = pos_indices < num_draft_per_batch
+    global_token_indices = cu_start[:, None] + pos_indices
+    global_token_indices = global_token_indices.clamp(
+        0, draft_token_ids.shape[0] - 1)
+    draft_tokens = draft_token_ids[global_token_indices]
+
+    if IS_NGRAM:
+        ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
+        draft_token_probs = ones_cpu.to(
+            device, non_blocking=True).expand_as(draft_tokens)
+    else:
+        flat_indices = global_token_indices.flatten()
+        flat_draft_tokens = draft_tokens.flatten()
+        flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens]
+        draft_token_probs = flat_draft_probs.view(batch_size, max_spec_len)
+
+    flat_indices = global_token_indices.flatten()
+    flat_draft_tokens = draft_tokens.flatten()
+    flat_target_probs = target_probs[flat_indices, flat_draft_tokens]
+    target_token_probs = flat_target_probs.view(batch_size, max_spec_len)
+
+    uniform_token_probs = uniform_probs[global_token_indices]
+    recovered_tokens = recovered_token_ids[global_token_indices]
+    pi = target_token_probs / draft_token_probs
+    pi = pi.clamp(max=1.0)
+    pi = torch.cumprod(pi, dim=-1)
+    uniform_token_probs = torch.cumprod(uniform_token_probs, dim=-1)
+
+    legal_mask = (draft_token_probs > 0) & (pi >= uniform_token_probs)
+    legal_mask = legal_mask & valid_mask
+    last_accept_pos = torch.where(
+        legal_mask.any(dim=-1, keepdim=True),
+        (max_spec_len -
+         legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
+        -1)
+
+    non_greedy_mask = (~is_greedy)[:, None]
+    accept_mask = (pos_indices
+                   <= last_accept_pos) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(
+        accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
+
+    reject_mask = (pos_indices
+                   == last_accept_pos + 1) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(
+        reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
+
+    bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask
+    all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True)
+    all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]
+    bonus_pos_match = (all_positions == num_draft_per_batch)
+    bonus_mask = bonus_mask & bonus_pos_match
+    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(
+        -1, max_spec_len + 1)
+    output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded,
+                                      output_token_ids)
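
For context, here is a minimal smoke test of the PyTorch reference path added above. The import path, the toy shapes, and the `npu` device string are assumptions (the helper stages small index tensors through pinned host memory, so it expects an accelerator to be present); the tensor shapes follow the signature comments in the diff.

```python
import torch

# Hypothetical import path; adjust to wherever this commit defines the helper.
from vllm_ascend.sample.rejection_sampler import (
    rejection_random_sample_block_verify_pytorch)

device = "npu"  # assumes torch_npu is installed; "cuda" also works for this helper
batch_size, max_spec_len, vocab_size = 1, 3, 4

# Output buffer pre-filled with a placeholder id, as rejection_sample() does.
output = torch.full((batch_size, max_spec_len + 1), -1,
                    dtype=torch.int64, device=device)
rejection_random_sample_block_verify_pytorch(
    output,
    torch.tensor([3], device=device),                   # cu_num_draft_tokens
    torch.tensor([1, 2, 3], device=device),             # draft_token_ids
    None,                                               # draft_probs (ngram-style)
    torch.full((3, vocab_size), 0.25, device=device),   # target_probs
    torch.tensor([9], device=device),                   # bonus_token_ids
    torch.tensor([0, 0, 0], device=device),             # recovered_token_ids
    torch.zeros(3, device=device),                      # uniform_probs -> accept all
    torch.tensor([False], device=device),               # is_greedy
    max_spec_len,
    vocab_size,
    IS_NGRAM=True,
)
# Zero uniform samples make every cumulative test pass, so all three drafts
# are accepted and the bonus token lands in the final slot: [1, 2, 3, 9].
print(output)
```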