Revert "Add MagicMTP(block verify) and Triton optimization (#4443)" (#5380)

### What this PR does / why we need it?
#4443 introduces a precision issue in scenarios where MTP >= 3 + deepseek v3.1, and this pr reverts it

- vLLM version: release/v0.13.0
- vLLM main:
bc0a5a0c08

Signed-off-by: GDzhu01 <809721801@qq.com>
This commit is contained in:
Zhu Yi Lin
2025-12-26 15:06:13 +08:00
committed by GitHub
parent 45c5bcd962
commit 18302c8467
2 changed files with 40 additions and 337 deletions

View File

@@ -114,9 +114,6 @@ def rejection_sample(
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
# When num_speculative_tokens>=3, using block verify.
using_block_verify = max_spec_len >= 3
# Create output buffer.
output_token_ids = torch.empty(
(batch_size, max_spec_len + 1),
@@ -194,81 +191,52 @@ def rejection_sample(
sampling_metadata.generators,
device,
)
if not using_block_verify:
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
# Rejection sampling for random sampling requests.
if HAS_TRITON:
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
bonus_token_ids,
recovered_token_ids,
uniform_probs.to(torch.float32),
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
)
# Rejection sampling for random sampling requests.
if HAS_TRITON:
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs.to(torch.float32),
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
)
else:
rejection_random_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
)
else:
# MagicMTP: Improving acceptance rate with Block Verify.
if HAS_TRITON:
rejection_random_sample_block_verify_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
uniform_probs.to(torch.float32),
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
multibuffer=True,
)
else:
rejection_random_sample_block_verify_pytorch(output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs
is None)
rejection_random_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
)
return output_token_ids
@@ -532,71 +500,6 @@ def rejection_random_sample_pytorch(
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
def rejection_random_sample_block_verify_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
draft_probs, # [num_tokens, vocab_size] or None
target_probs, # [num_tokens, vocab_size]
bonus_token_ids, # [batch_size]
uniform_probs, # [num_tokens]
is_greedy, # [batch_size]
max_spec_len,
vocab_size,
IS_NGRAM=False,
):
batch_size = output_token_ids.shape[0]
for req_idx in range(batch_size):
if is_greedy[req_idx]:
continue
if req_idx == 0:
start_idx = 0
else:
start_idx = cu_num_draft_tokens[req_idx - 1].item()
end_idx = cu_num_draft_tokens[req_idx].item()
num_draft_tokens = end_idx - start_idx
rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
for pos in range(num_draft_tokens):
draft_token_id = draft_token_ids[start_idx + pos].item()
target_prob = target_probs[start_idx + pos, draft_token_id].item()
uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()
if IS_NGRAM:
draft_prob = 1.0
else:
draft_prob = draft_probs[start_idx + pos,
draft_token_id].item()
pi = min(pi * target_prob / draft_prob, 1.0)
if draft_prob > 0 and pi >= uniform_prob:
last_accepted_token_pos = pos
rejected = False
else:
rejected = True
if last_accepted_token_pos > -1:
for pos in range(last_accepted_token_pos + 1):
draft_token_id = draft_token_ids[start_idx + pos].item()
output_token_ids[req_idx, pos] = draft_token_id
if rejected:
recovered_token_id = torch.argmax(
target_probs[start_idx + last_accepted_token_pos + 1]).item()
output_token_ids[req_idx,
last_accepted_token_pos + 1] = recovered_token_id
else:
bonus_token_id = bonus_token_ids[req_idx].item()
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
def expand_pytorch(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
@@ -834,92 +737,6 @@ def rejection_random_sample_kernel(
)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
NO_DRAFT_PROBS: tl.constexpr,
SUB_BLOCK: tl.constexpr = 1500,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exit for greedy sampling requests.
return
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
for pos in range(num_draft_tokens):
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size + draft_token_id)
tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
uniform_prob = uniform_prob * tmp_uniform_prob
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
pi = min(pi * target_prob / draft_prob, 1.0)
if draft_prob > 0 and pi >= uniform_prob:
last_accepted_token_pos = pos
rejected = False
else:
rejected = True
if last_accepted_token_pos > -1:
for pos in range(last_accepted_token_pos + 1):
token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
if rejected:
loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
global_recovered_id = -1
global_max_p = -1.0
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
tmp_target_prob = tl.load(
target_probs_ptr +
(start_idx + last_accepted_token_pos + 1) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
recovered_id = tl.argmax(tmp_target_prob, axis=-1)
max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
last_accepted_token_pos + 1, global_recovered_id)
else:
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]