From 8763953f567f12291e956370c77d349686ee7fc7 Mon Sep 17 00:00:00 2001
From: Aoxuan Chen <43376869+chenaoxuan@users.noreply.github.com>
Date: Thu, 8 Jan 2026 09:15:55 +0800
Subject: [PATCH] [Feature] add the magicmtp speculative decoding acceleration
 algorithm (#5542)

### What this PR does / why we need it?
1. Introduced MagicMTP (from the paper "Block Verification Accelerates
   Speculative Decoding"), which verifies a request's draft tokens jointly as
   a block. Taking the influence among multiple draft tokens into account
   improves the acceptance rate without compromising accuracy.
2. Added both Triton and PyTorch implementations, together with E2E test
   cases.

### Does this PR introduce _any_ user-facing change?
MagicMTP takes effect automatically when the parameter
"num_speculative_tokens" is >= 3.

- vLLM version: v0.13.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/7157596103666ee7ccb7008acee8bff8a8ff1731

Signed-off-by: chenaoxuan
---
 .../triton/test_rejection_sample.py     | 136 ++++++++++++-
 vllm_ascend/ops/triton/reject_sample.py |  81 ++++++++
 vllm_ascend/sample/rejection_sampler.py | 192 ++++++++++++++----
 3 files changed, 372 insertions(+), 37 deletions(-)

diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py
index c8a85749..95c1157a 100644
--- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py
@@ -4,8 +4,11 @@
 from vllm.v1.sample.rejection_sampler import \
     rejection_random_sample_kernel as original_rejection_random_sample_kernel
 from vllm_ascend.ops.triton.reject_sample import (
-    cal_grid_and_block_size, rejection_random_sample_kernel)
+    cal_grid_and_block_size, rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel)
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+from vllm_ascend.sample.rejection_sampler import \
+    rejection_random_sample_block_verify_pytorch
 
 
 @pytest.fixture(scope="function", autouse=True)
@@ -93,3 +96,134 @@ def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
                                        BLOCK_SIZE=block_size)
     torch.npu.synchronize()
     assert torch.equal(original_output_token_ids, output_token_ids)
+
+
+DEVICE = "npu"
+BATCH_SIZE = 7
+MAX_SPEC_LEN = 3
+VOCAB_SIZE = 5
+CU_NUM_DRAFT_TOKENS = torch.tensor([2, 2, 5, 8, 11, 14, 15],
+                                   dtype=torch.int32,
+                                   device=DEVICE)
+DRAFT_TOKEN_IDS = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
+                               dtype=torch.int64,
+                               device=DEVICE)
+NUM_TOKENS = DRAFT_TOKEN_IDS.shape[0]
+DRAFT_PROBS = None
+TARGET_PROBS = torch.tensor(
+    [
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.1, 0.9, 0.0, 0.0, 0.0],  # 1
+        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
+        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
+        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
+        [0.4, 0.1, 0.3, 0.1, 0.1],  # 1
+    ],
+    dtype=torch.float32,
+    device=DEVICE)
+UNIFORM_PROBS = torch.tensor([
+    0.9,
+    0.0,
+    0.9,
+    0.7,
+    0.8,
+    0.5,
+    0.45,
+    1.0,
+    0.5,
+    0.45,
+    1.0,
+    0.39,
+    0.4,
+    0.1,
+    0.3,
+],
+                             dtype=torch.float32,
+                             device=DEVICE)
+BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
+                             MAX_SPEC_LEN + 1,
+                             dtype=torch.int64,
+                             device=DEVICE)
+RECOVERED_TOKEN_IDS = torch.full((NUM_TOKENS, ),
+                                 MAX_SPEC_LEN,
+                                 
dtype=torch.int64, + device=DEVICE) +IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE) +IS_GREEDY[4] = True + + +@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS]) +@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS]) +@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS]) +@pytest.mark.parametrize("target_probs", [TARGET_PROBS]) +@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS]) +@pytest.mark.parametrize("recovered_token_ids", [RECOVERED_TOKEN_IDS]) +@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS]) +@pytest.mark.parametrize("is_greedy", [IS_GREEDY]) +@pytest.mark.parametrize("batch_size", [BATCH_SIZE]) +@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN]) +@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE]) +@torch.inference_mode() +def test_rejection_sampler_block_verify_triton_kernel( + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + draft_probs, # [num_tokens, vocab_size] or None + target_probs, # [num_tokens, vocab_size] + bonus_token_ids, # [batch_size] + recovered_token_ids, # [num_tokens] + uniform_probs, # [num_tokens] + is_greedy, # [batch_size] + batch_size, # int + max_spec_len, # int + vocab_size, # int +) -> None: + + grid, block_size = cal_grid_and_block_size(batch_size) + + output_token_ids_ref = torch.full((batch_size, max_spec_len + 1), + -1, + dtype=torch.int64, + device=DEVICE) + + output_token_ids_triton = output_token_ids_ref.clone() + + rejection_random_sample_block_verify_pytorch( + output_token_ids=output_token_ids_ref, + cu_num_draft_tokens=cu_num_draft_tokens, + draft_token_ids=draft_token_ids, + draft_probs=draft_probs, + target_probs=target_probs, + bonus_token_ids=bonus_token_ids, + recovered_token_ids=recovered_token_ids, + uniform_probs=uniform_probs, + is_greedy=is_greedy, + max_spec_len=max_spec_len, + vocab_size=vocab_size, + IS_NGRAM=draft_probs is None) + + rejection_random_sample_block_verify_kernel[(grid, )]( + output_token_ids_ptr=output_token_ids_triton, + cu_num_draft_tokens_ptr=cu_num_draft_tokens, + draft_token_ids_ptr=draft_token_ids, + draft_probs_ptr=draft_probs, + target_probs_ptr=target_probs, + bonus_token_ids_ptr=bonus_token_ids, + recovered_token_ids_ptr=recovered_token_ids, + uniform_probs_ptr=uniform_probs, + is_greedy_ptr=is_greedy, + max_spec_len=max_spec_len, + vocab_size=vocab_size, + vec_len=batch_size, + NO_DRAFT_PROBS=draft_probs is None, + BLOCK_SIZE=block_size) + torch.npu.synchronize() + assert torch.equal(output_token_ids_ref, output_token_ids_triton) diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py index 6de1ae64..14281557 100644 --- a/vllm_ascend/ops/triton/reject_sample.py +++ b/vllm_ascend/ops/triton/reject_sample.py @@ -378,3 +378,84 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, MAX_NUM_TOKENS=max_num_tokens, # To avoid recompilation. 
BLOCK_SIZE=block_size, ) + + +@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_random_sample_block_verify_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + vec_len, + NO_DRAFT_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr): + block_idx = tl.program_id(0) + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < vec_len + is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1) + not_greedy_mask = is_greedy == 0 + start_idxs = tl.where( + offsets == 0, 0, + tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask)) + end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask) + n_num_draft_tokens = end_idxs - start_idxs + for req_i in range(BLOCK_SIZE): + not_greedy = tl.get_element(not_greedy_mask, (req_i, )) + if not_greedy: + + rejected = False + pi = 1.0 + uniform_prob = 1.0 + last_accepted_token_pos = -1 + start_idx = tl.get_element(start_idxs, (req_i, )) + req_idx = block_idx * BLOCK_SIZE + req_i + num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, )) + + for pos in range(num_draft_tokens): + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) + uniform_prob = uniform_prob * tmp_uniform_prob + + if NO_DRAFT_PROBS: + draft_prob = 1 + else: + draft_prob = tl.load(draft_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + + pi = min(pi * target_prob / draft_prob, 1.0) + if draft_prob > 0 and pi >= uniform_prob: + last_accepted_token_pos = pos + rejected = False + else: + rejected = True + + if last_accepted_token_pos > -1: + for pos in range(last_accepted_token_pos + 1): + token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + pos, token_id) + + if rejected: + recovered_token_id = tl.load(recovered_token_ids_ptr + + start_idx + + last_accepted_token_pos + 1) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + last_accepted_token_pos + 1, recovered_token_id) + else: + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, bonus_token_id) diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 16f89541..7d8b8078 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -10,8 +10,9 @@ from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN, from vllm_ascend.ops.triton.reject_sample import ( cal_grid_and_block_size, expand_triton, - rejection_greedy_sample_with_triton, rejection_random_sample_kernel, - sample_recovered_tokens_kernel) + rejection_greedy_sample_with_triton, + rejection_random_sample_block_verify_kernel, + rejection_random_sample_kernel, sample_recovered_tokens_kernel) from vllm_ascend.sample.sampler import apply_top_k_top_p @@ -104,6 +105,9 @@ def rejection_sample( assert bonus_token_ids.is_contiguous() assert target_probs.shape == (num_tokens, vocab_size) + # When 
num_speculative_tokens >= 3, use block verification.
+    using_block_verify = max_spec_len >= 3
+
     # Create output buffer.
     output_token_ids = torch.empty(
         (batch_size, max_spec_len + 1),
@@ -172,41 +176,74 @@
             sampling_metadata,
             device,
         )
-
-    # Rejection sampling for random sampling requests.
-    if HAS_TRITON:
-        rejection_random_sample_kernel[(grid, )](
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            draft_probs,
-            target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs.to(torch.float32),
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            batch_size,
-            NO_DRAFT_PROBS=draft_probs is None,
-            BLOCK_SIZE=block_size,
-        )
+    if not using_block_verify:
+        # Rejection sampling for random sampling requests.
+        if HAS_TRITON:
+            rejection_random_sample_kernel[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                batch_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                BLOCK_SIZE=block_size,
+            )
+        else:
+            rejection_random_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                IS_NGRAM=draft_probs is None,
+                # num_warps=1,
+            )
     else:
-        rejection_random_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            draft_probs,
-            target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs,
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            IS_NGRAM=draft_probs is None,
-            # num_warps=1,
-        )
+        # MagicMTP: improve the acceptance rate with block verification.
+        if HAS_TRITON:
+            rejection_random_sample_block_verify_kernel[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                batch_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                BLOCK_SIZE=block_size,
+            )
+        else:
+            rejection_random_sample_block_verify_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                IS_NGRAM=draft_probs is None)
 
     return output_token_ids
@@ -676,3 +713,86 @@ def sample_recovered_tokens_pytorch(
         recovered_ids = torch.argmax(prob_over_q, dim=1)
 
     output_token_ids[:] = recovered_ids
+
+
+def rejection_random_sample_block_verify_pytorch(
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    draft_probs,  # [num_tokens, vocab_size] or None
+    target_probs,  # [num_tokens, vocab_size]
+    bonus_token_ids,  # [batch_size]
+    recovered_token_ids,  # [num_tokens]
+    uniform_probs,  # [num_tokens]
+    is_greedy,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    IS_NGRAM=False,
+):
+    batch_size = output_token_ids.shape[0]
+    device = output_token_ids.device
+
+    zero_cpu = torch.tensor([0], pin_memory=True)
+    zero_device = zero_cpu.to(device, non_blocking=True)
+
+    cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]])
+    cu_end = cu_num_draft_tokens
+    num_draft_per_batch = (cu_end - cu_start)[:, None]
+    pos_indices_cpu = torch.arange(max_spec_len, pin_memory=True)
+    pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
+    valid_mask = pos_indices < num_draft_per_batch
+    global_token_indices = cu_start[:, None] + pos_indices
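+    # Positions past a request's own draft length still generate gather
+    # indices here; the clamp below only keeps them in bounds, and
+    # valid_mask excludes them from the accept/reject decisions later on.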
+    global_token_indices = global_token_indices.clamp(
+        0, draft_token_ids.shape[0] - 1)
+    draft_tokens = draft_token_ids[global_token_indices]
+
+    if IS_NGRAM:
+        ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
+        draft_token_probs = ones_cpu.to(
+            device, non_blocking=True).expand_as(draft_tokens)
+    else:
+        flat_indices = global_token_indices.flatten()
+        flat_draft_tokens = draft_tokens.flatten()
+        flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens]
+        draft_token_probs = flat_draft_probs.view(batch_size, max_spec_len)
+
+    flat_indices = global_token_indices.flatten()
+    flat_draft_tokens = draft_tokens.flatten()
+    flat_target_probs = target_probs[flat_indices, flat_draft_tokens]
+    target_token_probs = flat_target_probs.view(batch_size, max_spec_len)
+    uniform_token_probs = uniform_probs[global_token_indices]
+    recovered_tokens = recovered_token_ids[global_token_indices]
+
+    # Running capped product pi_k = min(pi_{k-1} * p_k / q_k, 1.0), matching
+    # the recursion in the Triton kernel. Clamping each ratio individually
+    # and then taking a plain cumprod is not equivalent once a ratio
+    # exceeds 1.0, so the scan is done step by step here.
+    ratios = target_token_probs / draft_token_probs
+    pi_steps = []
+    running_pi = torch.ones(batch_size, device=device)
+    for pos in range(max_spec_len):
+        running_pi = torch.clamp(running_pi * ratios[:, pos], max=1.0)
+        pi_steps.append(running_pi)
+    pi = torch.stack(pi_steps, dim=-1)
+    uniform_token_probs = torch.cumprod(uniform_token_probs, dim=-1)
+    legal_mask = (draft_token_probs > 0) & (pi >= uniform_token_probs)
+    legal_mask = legal_mask & valid_mask
+
+    last_accept_pos = torch.where(
+        legal_mask.any(dim=-1, keepdim=True),
+        (max_spec_len -
+         legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
+        -1)
+    non_greedy_mask = (~is_greedy)[:, None]
+
+    accept_mask = (pos_indices
+                   <= last_accept_pos) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(
+        accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
+
+    reject_mask = (pos_indices
+                   == last_accept_pos + 1) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(
+        reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
+
+    bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask
+    all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True)
+    all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]
+    bonus_pos_match = (all_positions == num_draft_per_batch)
+    bonus_mask = bonus_mask & bonus_pos_match
+    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(
+        -1, max_spec_len + 1)
+    output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded,
+                                      output_token_ids)
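
For reference, the acceptance rule that both new implementations apply per non-greedy request can be summarized in a few lines. The sketch below is illustrative only and not part of the patch; `block_verify_prefix_len`, `ratios`, and `uniforms` are hypothetical names for the per-position target/draft probability ratios and the per-position U(0, 1) samples, and the guard for zero draft probability is omitted for brevity:

```python
import torch


def block_verify_prefix_len(ratios: torch.Tensor,
                            uniforms: torch.Tensor) -> int:
    """Number of draft tokens accepted by block verification.

    ratios:   [k] target_prob / draft_prob at each draft position
    uniforms: [k] independent U(0, 1) samples, one per position
    """
    pi, u, last_accept = 1.0, 1.0, -1
    for n in range(ratios.shape[0]):
        # Running capped product of likelihood ratios, compared against the
        # running product of uniforms; accepting at position n accepts the
        # whole prefix 0..n.
        pi = min(pi * float(ratios[n]), 1.0)
        u *= float(uniforms[n])
        if pi >= u:
            last_accept = n
    return last_accept + 1


# Position 0 alone would be rejected (0.5 < 0.6), but the strong ratio at
# position 1 lifts the running product back to 1.0, so the block test
# accepts both tokens.
print(block_verify_prefix_len(torch.tensor([0.5, 3.0]),
                              torch.tensor([0.6, 0.7])))  # -> 2
```

Token-by-token rejection sampling would stop at the first failing position; block verification instead keeps the longest prefix whose cumulative test passes, after which the sampler emits a recovered token at the first unaccepted position, or the bonus token when every draft position is accepted.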