[Feature] Add the MagicMTP speculative decoding acceleration algorithm (#5542)
### What this PR does / why we need it?
1. Introduces MagicMTP (from the paper "Block Verification Accelerates
Speculative Decoding"). Instead of verifying draft tokens one at a time, it
makes a joint accept decision over the whole draft block, accounting for the
interdependence among the draft tokens; this improves the acceptance rate
without compromising output accuracy. A baseline sketch follows this list.
2. Adds both Triton and PyTorch implementations, together with E2E test cases.
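
For context, here is a minimal sketch of the token-by-token verification that
block verification improves on. This is the standard rejection rule from the
speculative decoding literature, not the code added by this PR; the helper
name, shapes, and the single-sequence loop are illustrative simplifications
(the real kernels operate on flattened, batched tensors):

```python
import torch


def token_level_verify(draft_ids, draft_probs, target_probs, uniform):
    """Classic per-token rejection sampling (illustrative sketch only).

    Each draft token x_i is accepted with probability
    min(1, p(x_i) / q(x_i)); on the first rejection we resample from
    norm(max(0, p - q)) and stop. Block verification (MagicMTP) instead
    makes one joint decision over the whole block, which the paper shows
    accepts at least as many tokens in expectation while still sampling
    exactly from the target distribution.
    """
    accepted = []
    for i, tok in enumerate(draft_ids.tolist()):
        p = target_probs[i, tok].item()
        q = draft_probs[i, tok].item()
        if uniform[i].item() <= min(1.0, p / q):
            accepted.append(tok)  # accepted: move on to the next draft token
        else:
            # Rejected: sample a recovery token from the residual
            # distribution and terminate the block.
            residual = (target_probs[i] - draft_probs[i]).clamp(min=0)
            accepted.append(
                torch.multinomial(residual / residual.sum(), 1).item())
            break
    return accepted


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.softmax(torch.randn(3, 5), dim=-1)  # draft model probs
    p = torch.softmax(torch.randn(3, 5), dim=-1)  # target model probs
    ids = torch.multinomial(q, 1).squeeze(1)      # one draft token per step
    print(token_level_verify(ids, q, p, torch.rand(3)))
```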
### Does this PR introduce _any_ user-facing change?
MagicMTP takes effect automatically whenever the parameter
`num_speculative_tokens` is set to 3 or more.
- vLLM version: v0.13.0
- vLLM main: 7157596103
Signed-off-by: chenaoxuan <cax1165@163.com>
```diff
@@ -4,8 +4,11 @@ from vllm.v1.sample.rejection_sampler import \
     rejection_random_sample_kernel as original_rejection_random_sample_kernel
 
 from vllm_ascend.ops.triton.reject_sample import (
-    cal_grid_and_block_size, rejection_random_sample_kernel)
+    cal_grid_and_block_size, rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel)
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+from vllm_ascend.sample.rejection_sampler import \
+    rejection_random_sample_block_verify_pytorch
 
 
 @pytest.fixture(scope="function", autouse=True)
@@ -93,3 +96,134 @@ def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
         BLOCK_SIZE=block_size)
     torch.npu.synchronize()
     assert torch.equal(original_output_token_ids, output_token_ids)
+
+
+DEVICE = "npu"
+BATCH_SIZE = 7
+MAX_SPEC_LEN = 3
+VOCAB_SIZE = 5
+CU_NUM_DRAFT_TOKENS = torch.tensor([2, 2, 5, 8, 11, 14, 15],
+                                   dtype=torch.int32,
+                                   device=DEVICE)
+DRAFT_TOKEN_IDS = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
+                               dtype=torch.int64,
+                               device=DEVICE)
+NUM_TOKENS = DRAFT_TOKEN_IDS.shape[0]
+DRAFT_PROBS = None
+TARGET_PROBS = torch.tensor(
+    [
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.1, 0.9, 0.0, 0.0, 0.0],  # 1
+        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
+        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
+        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
+        [0.4, 0.1, 0.3, 0.1, 0.1],  # 1
+    ],
+    dtype=torch.float32,
+    device=DEVICE)
+UNIFORM_PROBS = torch.tensor([
+    0.9,
+    0.0,
+    0.9,
+    0.7,
+    0.8,
+    0.5,
+    0.45,
+    1.0,
+    0.5,
+    0.45,
+    1.0,
+    0.39,
+    0.4,
+    0.1,
+    0.3,
+],
+                             dtype=torch.float32,
+                             device=DEVICE)
+BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
+                             MAX_SPEC_LEN + 1,
+                             dtype=torch.int64,
+                             device=DEVICE)
+RECOVERED_TOKEN_IDS = torch.full((NUM_TOKENS, ),
+                                 MAX_SPEC_LEN,
+                                 dtype=torch.int64,
+                                 device=DEVICE)
+IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE)
+IS_GREEDY[4] = True
+
+
+@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
+@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
+@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
+@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
+@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
+@pytest.mark.parametrize("recovered_token_ids", [RECOVERED_TOKEN_IDS])
+@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
+@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
+@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
+@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
+@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
+@torch.inference_mode()
+def test_rejection_sampler_block_verify_triton_kernel(
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        draft_probs,  # [num_tokens, vocab_size] or None
+        target_probs,  # [num_tokens, vocab_size]
+        bonus_token_ids,  # [batch_size]
+        recovered_token_ids,  # [num_tokens]
+        uniform_probs,  # [num_tokens]
+        is_greedy,  # [batch_size]
+        batch_size,  # int
+        max_spec_len,  # int
+        vocab_size,  # int
+) -> None:
+
+    grid, block_size = cal_grid_and_block_size(batch_size)
+
+    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
+                                      -1,
+                                      dtype=torch.int64,
+                                      device=DEVICE)
+
+    output_token_ids_triton = output_token_ids_ref.clone()
+
+    rejection_random_sample_block_verify_pytorch(
+        output_token_ids=output_token_ids_ref,
+        cu_num_draft_tokens=cu_num_draft_tokens,
+        draft_token_ids=draft_token_ids,
+        draft_probs=draft_probs,
+        target_probs=target_probs,
+        bonus_token_ids=bonus_token_ids,
+        recovered_token_ids=recovered_token_ids,
+        uniform_probs=uniform_probs,
+        is_greedy=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        IS_NGRAM=draft_probs is None)
+
+    rejection_random_sample_block_verify_kernel[(grid, )](
+        output_token_ids_ptr=output_token_ids_triton,
+        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
+        draft_token_ids_ptr=draft_token_ids,
+        draft_probs_ptr=draft_probs,
+        target_probs_ptr=target_probs,
+        bonus_token_ids_ptr=bonus_token_ids,
+        recovered_token_ids_ptr=recovered_token_ids,
+        uniform_probs_ptr=uniform_probs,
+        is_greedy_ptr=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        vec_len=batch_size,
+        NO_DRAFT_PROBS=draft_probs is None,
+        BLOCK_SIZE=block_size)
+    torch.npu.synchronize()
+    assert torch.equal(output_token_ids_ref, output_token_ids_triton)
```
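
Two notes for readers of the fixture data above. First, `CU_NUM_DRAFT_TOKENS`
is an inclusive prefix sum, so per-request draft counts fall out by
differencing (the repeated `2` means request 1 contributes zero draft tokens):

```python
import torch

cu = torch.tensor([2, 2, 5, 8, 11, 14, 15], dtype=torch.int32)
# Recover per-request draft counts from the inclusive prefix sum.
num_draft = torch.diff(cu, prepend=torch.zeros(1, dtype=cu.dtype))
print(num_draft.tolist())  # [2, 0, 3, 3, 3, 3, 1] -- sums to NUM_TOKENS == 15
```

Second, feeding the same fixed `UNIFORM_PROBS` to both paths removes all
randomness, which is what lets the test demand bit-exact agreement between the
PyTorch reference and the Triton kernel via `torch.equal`.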