From 6d25372baaa0ef018a75b427b387fab8dd2e92b4 Mon Sep 17 00:00:00 2001
From: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com>
Date: Thu, 25 Dec 2025 09:00:25 +0800
Subject: [PATCH] Add MagicMTP(block verify) and Triton optimization (#4443)

### What this PR does / why we need it?
1. MagicMTP (paper: "Block Verification Accelerates Speculative Decoding") is introduced to take the joint influence of multiple draft tokens into account, improving the acceptance rate without compromising accuracy.
2. The rejection sampling logic in rejection_sampler.py is restructured with Triton-Ascend so that it scales under high concurrency, resolving the CPU and NPU operator bottlenecks and improving throughput.

### Does this PR introduce _any_ user-facing change?
MagicMTP takes effect automatically when the parameter "num_speculative_tokens" is >= 3.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: chenaoxuan
---
 .../ops/triton/test_rejection_sampler.py | 114 ++++++++
 vllm_ascend/sample/rejection_sampler.py  | 263 +++++++++++++++---
 2 files changed, 337 insertions(+), 40 deletions(-)
 create mode 100644 tests/e2e/nightly/ops/triton/test_rejection_sampler.py

diff --git a/tests/e2e/nightly/ops/triton/test_rejection_sampler.py b/tests/e2e/nightly/ops/triton/test_rejection_sampler.py
new file mode 100644
index 00000000..86992711
--- /dev/null
+++ b/tests/e2e/nightly/ops/triton/test_rejection_sampler.py
@@ -0,0 +1,114 @@
+import pytest
+import torch
+from torch.testing import assert_close
+
+from vllm_ascend.sample.rejection_sampler import (
+    rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_block_verify_pytorch)
+
+DEVICE = "npu"
+BATCH_SIZE = 3
+MAX_SPEC_LEN = 3
+VOCAB_SIZE = 5
+NUM_TOKENS = BATCH_SIZE * MAX_SPEC_LEN
+CU_NUM_DRAFT_TOKENS = torch.arange(start=MAX_SPEC_LEN,
+                                   end=NUM_TOKENS + 1,
+                                   step=MAX_SPEC_LEN,
+                                   dtype=torch.int32,
+                                   device=DEVICE)
+DRAFT_TOKEN_IDS = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2],
+                               dtype=torch.int64,
+                               device=DEVICE)
+DRAFT_PROBS = None
+TARGET_PROBS = torch.tensor(
+    [
+        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
+        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
+        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
+    ],
+    dtype=torch.float32,
+    device=DEVICE)
+UNIFORM_PROBS = torch.tensor([
+    0.9,
+    0.7,
+    0.8,
+    0.5,
+    0.45,
+    1.0,
+    0.39,
+    0.4,
+    0.1,
+],
+                             dtype=torch.float32,
+                             device=DEVICE)
+BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
+                             MAX_SPEC_LEN + 1,
+                             dtype=torch.int64,
+                             device=DEVICE)
+IS_GREEDY = torch.zeros(NUM_TOKENS, dtype=torch.bool, device=DEVICE)
+
+
+@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
+@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
+@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
+@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
+@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
+@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
+@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
+@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
+@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
+@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
+@torch.inference_mode()
+def test_rejection_sampler_block_verify_triton_kernel(
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        draft_probs,  # [num_tokens, vocab_size] or None
+        target_probs,  # [num_tokens, vocab_size]
+        bonus_token_ids,  # [batch_size]
+        uniform_probs,  # [num_tokens]
+        is_greedy,  # [batch_size]
+        batch_size,  # int
+        max_spec_len,  # int
+        vocab_size,  # int
+) -> None:
+    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
+                                      -1,
+                                      dtype=torch.int64,
+                                      device=DEVICE)
+
+    output_token_ids_triton = output_token_ids_ref.clone()
+
+    rejection_random_sample_block_verify_pytorch(
+        output_token_ids=output_token_ids_ref,
+        cu_num_draft_tokens=cu_num_draft_tokens,
+        draft_token_ids=draft_token_ids,
+        draft_probs=draft_probs,
+        target_probs=target_probs,
+        bonus_token_ids=bonus_token_ids,
+        uniform_probs=uniform_probs,
+        is_greedy=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        IS_NGRAM=draft_probs is None)
+
+    rejection_random_sample_block_verify_kernel[(batch_size, )](
+        output_token_ids_ptr=output_token_ids_triton,
+        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
+        draft_token_ids_ptr=draft_token_ids,
+        draft_probs_ptr=draft_probs,
+        target_probs_ptr=target_probs,
+        bonus_token_ids_ptr=bonus_token_ids,
+        uniform_probs_ptr=uniform_probs,
+        is_greedy_ptr=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        NO_DRAFT_PROBS=draft_probs is None,
+        multibuffer=True)
+
+    assert_close(output_token_ids_ref, output_token_ids_triton)
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index b0e6f848..7556438f 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -114,6 +114,9 @@ def rejection_sample(
     assert bonus_token_ids.is_contiguous()
     assert target_probs.shape == (num_tokens, vocab_size)
 
+    # Use block verification when num_speculative_tokens >= 3.
+    using_block_verify = max_spec_len >= 3
+
     # Create output buffer.
     output_token_ids = torch.empty(
         (batch_size, max_spec_len + 1),
@@ -191,52 +194,81 @@ def rejection_sample(
         sampling_metadata.generators,
         device,
     )
-
-    # Sample recovered tokens for each position.
-    # [num_tokens]
-    recovered_token_ids = sample_recovered_tokens(
-        max_spec_len,
-        num_draft_tokens,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        sampling_metadata,
-        device,
-    )
-
-    # Rejection sampling for random sampling requests.
-    if HAS_TRITON:
-        rejection_random_sample_kernel[(batch_size, )](
-            output_token_ids,
+    if not using_block_verify:
+        # Sample recovered tokens for each position.
+        # [num_tokens]
+        recovered_token_ids = sample_recovered_tokens(
+            max_spec_len,
+            num_draft_tokens,
             cu_num_draft_tokens,
             draft_token_ids,
             draft_probs,
             target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs.to(torch.float32),
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            NO_DRAFT_PROBS=draft_probs is None,
+            sampling_metadata,
+            device,
         )
+
+        # Rejection sampling for random sampling requests.
+        if HAS_TRITON:
+            rejection_random_sample_kernel[(batch_size, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+            )
+        else:
+            rejection_random_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                IS_NGRAM=draft_probs is None,
+            )
     else:
-        rejection_random_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            draft_probs,
-            target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs,
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            IS_NGRAM=draft_probs is None,
-            # num_warps=1,
-        )
+        # MagicMTP: improve the acceptance rate with block verification.
+        if HAS_TRITON:
+            rejection_random_sample_block_verify_kernel[(batch_size, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                multibuffer=True,
+            )
+        else:
+            rejection_random_sample_block_verify_pytorch(output_token_ids,
+                                                         cu_num_draft_tokens,
+                                                         draft_token_ids,
+                                                         draft_probs,
+                                                         target_probs,
+                                                         bonus_token_ids,
+                                                         uniform_probs,
+                                                         is_greedy,
+                                                         max_spec_len,
+                                                         vocab_size,
+                                                         IS_NGRAM=draft_probs
+                                                         is None)
     return output_token_ids
@@ -500,6 +532,71 @@ def rejection_random_sample_pytorch(
         output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
 
 
+def rejection_random_sample_block_verify_pytorch(
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    draft_probs,  # [num_tokens, vocab_size] or None
+    target_probs,  # [num_tokens, vocab_size]
+    bonus_token_ids,  # [batch_size]
+    uniform_probs,  # [num_tokens]
+    is_greedy,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    IS_NGRAM=False,
+):
+    batch_size = output_token_ids.shape[0]
+
+    for req_idx in range(batch_size):
+        if is_greedy[req_idx]:
+            continue
+
+        if req_idx == 0:
+            start_idx = 0
+        else:
+            start_idx = cu_num_draft_tokens[req_idx - 1].item()
+        end_idx = cu_num_draft_tokens[req_idx].item()
+        num_draft_tokens = end_idx - start_idx
+
+        rejected = False
+        pi = 1.0
+        uniform_prob = 1.0
+        last_accepted_token_pos = -1
+        for pos in range(num_draft_tokens):
+            draft_token_id = draft_token_ids[start_idx + pos].item()
+
+            target_prob = target_probs[start_idx + pos, draft_token_id].item()
+            uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()
+
+            if IS_NGRAM:
+                draft_prob = 1.0
+            else:
+                draft_prob = draft_probs[start_idx + pos,
+                                         draft_token_id].item()
+
+            pi = min(pi * target_prob / draft_prob, 1.0)
+
+            if draft_prob > 0 and pi >= uniform_prob:
+                last_accepted_token_pos = pos
+                rejected = False
+            else:
+                rejected = True
+
+        if last_accepted_token_pos > -1:
+            for pos in range(last_accepted_token_pos + 1):
+                draft_token_id = draft_token_ids[start_idx + pos].item()
+                output_token_ids[req_idx, pos] = draft_token_id
+
+        if rejected:
+            recovered_token_id = torch.argmax(
+                target_probs[start_idx + last_accepted_token_pos + 1]).item()
+            output_token_ids[req_idx,
+                             last_accepted_token_pos + 1] = recovered_token_id
+        else:
+            bonus_token_id = bonus_token_ids[req_idx].item()
+            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+
+
 def expand_pytorch(
     output_ptr,  # [num_tokens]
     input_ptr,  # [batch_size]
@@ -737,6 +834,92 @@ def rejection_random_sample_kernel(
     )
 
 
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_block_verify_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    bonus_token_ids_ptr,  # [batch_size]
+    uniform_probs_ptr,  # [num_tokens]
+    is_greedy_ptr,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    NO_DRAFT_PROBS: tl.constexpr,
+    SUB_BLOCK: tl.constexpr = 1500,
+):
+    req_idx = tl.program_id(0)
+    is_greedy = tl.load(is_greedy_ptr + req_idx)
+    if is_greedy:
+        # Early exit for greedy sampling requests.
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    pi = 1.0
+    uniform_prob = 1.0
+    last_accepted_token_pos = -1
+
+    for pos in range(num_draft_tokens):
+        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+        target_prob = tl.load(target_probs_ptr +
+                              (start_idx + pos) * vocab_size + draft_token_id)
+        tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+        uniform_prob = uniform_prob * tmp_uniform_prob
+
+        if NO_DRAFT_PROBS:
+            draft_prob = 1
+        else:
+            draft_prob = tl.load(draft_probs_ptr +
+                                 (start_idx + pos) * vocab_size +
+                                 draft_token_id)
+
+        pi = min(pi * target_prob / draft_prob, 1.0)
+        if draft_prob > 0 and pi >= uniform_prob:
+            last_accepted_token_pos = pos
+            rejected = False
+        else:
+            rejected = True
+
+    if last_accepted_token_pos > -1:
+        for pos in range(last_accepted_token_pos + 1):
+            token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                     token_id)
+
+    if rejected:
+        loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
+        global_recovered_id = -1
+        global_max_p = -1.0
+        for loop_i in range(loop):
+            vocab_start = loop_i * SUB_BLOCK
+            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
+            tmp_target_prob = tl.load(
+                target_probs_ptr +
+                (start_idx + last_accepted_token_pos + 1) * vocab_size +
+                vocab_offset,
+                mask=vocab_offset < vocab_size,
+                other=0)
+            recovered_id = tl.argmax(tmp_target_prob, axis=-1)
+            max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
+            if max_p > global_max_p:
+                global_max_p = max_p
+                global_recovered_id = vocab_start + recovered_id
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            last_accepted_token_pos + 1, global_recovered_id)
+    else:
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens, bonus_token_id)
+
+
 @triton.jit(do_not_specialize=["replace_from", "replace_to"])
 def expand_kernel(
     output_ptr,  # [num_tokens]
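The acceptance rule that `rejection_random_sample_block_verify_pytorch` and `rejection_random_sample_block_verify_kernel` implement can be reduced to a few lines of plain Python. The sketch below is only an illustration of that rule on toy inputs; the function name `block_verify_accept_len` and the example numbers are not part of the patch, and it assumes strictly positive draft probabilities (the kernels above additionally guard the `draft_prob > 0` case).

```python
# Illustrative sketch of the MagicMTP block-verification acceptance rule
# (names and numbers here are examples, not part of the patch).
# For one request, walk the draft block left to right and track
#   pi : running product of target_prob / draft_prob, clamped to 1.0
#   u  : running product of the per-position uniform samples
# The accepted prefix ends at the last position where pi >= u.

def block_verify_accept_len(target_probs, draft_probs, uniform_probs):
    """Return how many draft tokens of a single request are accepted.

    target_probs[i]  : target-model probability of draft token i
    draft_probs[i]   : draft-model probability of draft token i (1.0 for ngram)
    uniform_probs[i] : uniform(0, 1) sample for position i
    Assumes draft_probs[i] > 0 for all i.
    """
    pi = 1.0
    u = 1.0
    last_accepted = -1
    for i, (p_t, p_d, r) in enumerate(zip(target_probs, draft_probs,
                                          uniform_probs)):
        u *= r
        pi = min(pi * p_t / p_d, 1.0)
        if pi >= u:
            last_accepted = i
    return last_accepted + 1


if __name__ == "__main__":
    # Position 0 alone would be rejected (0.36 / 0.4 = 0.9 < 0.95), but the
    # block through positions 1 and 2 passes, so all three drafts are kept
    # and a bonus token would be appended.
    print(
        block_verify_accept_len(target_probs=[0.36, 0.5, 0.9],
                                draft_probs=[0.4, 0.8, 0.3],
                                uniform_probs=[0.95, 0.5, 0.6]))  # -> 3
```

Because the decision at each position depends on the cumulative product of probability ratios over the whole prefix, a draft token that would fail a per-token check can still be accepted as part of a longer block, which is where the acceptance-rate gain over standard rejection sampling comes from. Note also that the block-verify path above skips sample_recovered_tokens entirely and instead recovers a token with an argmax over target_probs inside the kernel.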