[Feature] add the magicmtp speculative decoding acceleration algorithm (#5542)

### What this PR does / why we need it?

1. MagicMTP (paper: "Block Verification Accelerates Speculative
Decoding") was introduced to consider the influence among multiple draft
tokens, improving the acceptance rate without compromising accuracy.
2. Added Triton and PyTorch implementations, and added E2E test cases.

### Does this PR introduce _any_ user-facing change?
MagicMTP will automatically take effect when the parameter
"num_speculative_tokens" >= 3.
- vLLM version: v0.13.0
- vLLM main:
7157596103

Signed-off-by: chenaoxuan <cax1165@163.com>
This commit is contained in:
Aoxuan Chen
2026-01-08 09:15:55 +08:00
committed by GitHub
parent 481138e1d2
commit 8763953f56
3 changed files with 372 additions and 37 deletions

View File

@@ -378,3 +378,84 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
MAX_NUM_TOKENS=max_num_tokens, # To avoid recompilation.
BLOCK_SIZE=block_size,
)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
vec_len,
NO_DRAFT_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr):
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < vec_len
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
not_greedy_mask = is_greedy == 0
start_idxs = tl.where(
offsets == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
if not_greedy:
rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
start_idx = tl.get_element(start_idxs, (req_i, ))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
for pos in range(num_draft_tokens):
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
uniform_prob = uniform_prob * tmp_uniform_prob
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
pi = min(pi * target_prob / draft_prob, 1.0)
if draft_prob > 0 and pi >= uniform_prob:
last_accepted_token_pos = pos
rejected = False
else:
rejected = True
if last_accepted_token_pos > -1:
for pos in range(last_accepted_token_pos + 1):
token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
pos, token_id)
if rejected:
recovered_token_id = tl.load(recovered_token_ids_ptr +
start_idx +
last_accepted_token_pos + 1)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
last_accepted_token_pos + 1, recovered_token_id)
else:
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)