### What this PR does / why we need it?
PR #4443 introduced a precision issue in scenarios that combine MTP >= 3 (i.e. num_speculative_tokens >= 3, where block verify is enabled) with DeepSeek V3.1; this PR reverts it.
- vLLM version: release/v0.13.0
- vLLM main:
bc0a5a0c08
Signed-off-by: GDzhu01 <809721801@qq.com>
This commit is contained in:
@@ -1,114 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from torch.testing import assert_close
|
|
||||||
|
|
||||||
from vllm_ascend.sample.rejection_sampler import (
|
|
||||||
rejection_random_sample_block_verify_kernel,
|
|
||||||
rejection_random_sample_block_verify_pytorch)
|
|
||||||
|
|
||||||
# Fixture data for the block-verify rejection sampler test.
DEVICE = "npu"
BATCH_SIZE = 3
MAX_SPEC_LEN = 3
VOCAB_SIZE = 5
NUM_TOKENS = BATCH_SIZE * MAX_SPEC_LEN

# Cumulative (inclusive) draft-token counts per request: [3, 6, 9], i.e. every
# request drafts exactly MAX_SPEC_LEN tokens.
CU_NUM_DRAFT_TOKENS = torch.arange(start=MAX_SPEC_LEN,
                                   end=NUM_TOKENS + 1,
                                   step=MAX_SPEC_LEN,
                                   dtype=torch.int32,
                                   device=DEVICE)

# Each request drafts the token ids 0, 1, 2 in that order.
DRAFT_TOKEN_IDS = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2],
                               dtype=torch.int64,
                               device=DEVICE)

# No draft distribution: the samplers are driven through their
# "draft_probs is None" path (draft probability treated as 1).
DRAFT_PROBS = None

# One row per draft position (three consecutive rows per request).
# The trailing 0/1 marks whether the block-verify test
# (running product of target probs >= running product of uniform samples)
# passes at that position for the UNIFORM_PROBS below.
TARGET_PROBS = torch.tensor(
    [
        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
    ],
    dtype=torch.float32,
    device=DEVICE)

# Fixed "random" samples, one per draft position, so the test is deterministic.
UNIFORM_PROBS = torch.tensor([
    0.9,
    0.7,
    0.8,
    0.5,
    0.45,
    1.0,
    0.39,
    0.4,
    0.1,
],
                             dtype=torch.float32,
                             device=DEVICE)

# Bonus token id MAX_SPEC_LEN + 1 (= 4) is distinct from every draft token id.
BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
                             MAX_SPEC_LEN + 1,
                             dtype=torch.int64,
                             device=DEVICE)

# One flag per request (the samplers index it by request index), all
# non-greedy. It was previously allocated with NUM_TOKENS entries, which
# contradicted the documented [batch_size] shape.
IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
@torch.inference_mode()
def test_rejection_sampler_block_verify_triton_kernel(
        cu_num_draft_tokens,  # [batch_size]
        draft_token_ids,  # [num_tokens]
        draft_probs,  # [num_tokens, vocab_size] or None
        target_probs,  # [num_tokens, vocab_size]
        bonus_token_ids,  # [batch_size]
        uniform_probs,  # [num_tokens]
        is_greedy,  # [batch_size]
        batch_size,  # int
        max_spec_len,  # int
        vocab_size,  # int
) -> None:
    """Check that the Triton block-verify kernel matches the PyTorch reference.

    Both implementations start from identical output buffers filled with -1
    and receive the same inputs; the token ids they write must be equal.
    """
    # Both samplers take a flag telling them no draft distribution exists.
    no_draft_probs = draft_probs is None

    expected = torch.full((batch_size, max_spec_len + 1),
                          -1,
                          dtype=torch.int64,
                          device=DEVICE)
    actual = expected.clone()

    # Reference implementation (fills `expected` in place).
    rejection_random_sample_block_verify_pytorch(
        expected,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        IS_NGRAM=no_draft_probs,
    )

    # Kernel under test: one program per request (fills `actual` in place).
    launch_grid = (batch_size, )
    rejection_random_sample_block_verify_kernel[launch_grid](
        actual,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=no_draft_probs,
        multibuffer=True,
    )

    assert_close(expected, actual)
|
|
||||||
@@ -114,9 +114,6 @@ def rejection_sample(
|
|||||||
assert bonus_token_ids.is_contiguous()
|
assert bonus_token_ids.is_contiguous()
|
||||||
assert target_probs.shape == (num_tokens, vocab_size)
|
assert target_probs.shape == (num_tokens, vocab_size)
|
||||||
|
|
||||||
# When num_speculative_tokens>=3, using block verify.
|
|
||||||
using_block_verify = max_spec_len >= 3
|
|
||||||
|
|
||||||
# Create output buffer.
|
# Create output buffer.
|
||||||
output_token_ids = torch.empty(
|
output_token_ids = torch.empty(
|
||||||
(batch_size, max_spec_len + 1),
|
(batch_size, max_spec_len + 1),
|
||||||
@@ -194,7 +191,7 @@ def rejection_sample(
|
|||||||
sampling_metadata.generators,
|
sampling_metadata.generators,
|
||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
if not using_block_verify:
|
|
||||||
# Sample recovered tokens for each position.
|
# Sample recovered tokens for each position.
|
||||||
# [num_tokens]
|
# [num_tokens]
|
||||||
recovered_token_ids = sample_recovered_tokens(
|
recovered_token_ids = sample_recovered_tokens(
|
||||||
@@ -238,37 +235,8 @@ def rejection_sample(
|
|||||||
max_spec_len,
|
max_spec_len,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
IS_NGRAM=draft_probs is None,
|
IS_NGRAM=draft_probs is None,
|
||||||
|
# num_warps=1,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# MagicMTP: Improving acceptance rate with Block Verify.
|
|
||||||
if HAS_TRITON:
|
|
||||||
rejection_random_sample_block_verify_kernel[(batch_size, )](
|
|
||||||
output_token_ids,
|
|
||||||
cu_num_draft_tokens,
|
|
||||||
draft_token_ids,
|
|
||||||
draft_probs,
|
|
||||||
target_probs,
|
|
||||||
bonus_token_ids,
|
|
||||||
uniform_probs.to(torch.float32),
|
|
||||||
is_greedy,
|
|
||||||
max_spec_len,
|
|
||||||
vocab_size,
|
|
||||||
NO_DRAFT_PROBS=draft_probs is None,
|
|
||||||
multibuffer=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rejection_random_sample_block_verify_pytorch(output_token_ids,
|
|
||||||
cu_num_draft_tokens,
|
|
||||||
draft_token_ids,
|
|
||||||
draft_probs,
|
|
||||||
target_probs,
|
|
||||||
bonus_token_ids,
|
|
||||||
uniform_probs,
|
|
||||||
is_greedy,
|
|
||||||
max_spec_len,
|
|
||||||
vocab_size,
|
|
||||||
IS_NGRAM=draft_probs
|
|
||||||
is None)
|
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
|
|
||||||
@@ -532,71 +500,6 @@ def rejection_random_sample_pytorch(
|
|||||||
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
|
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
|
||||||
|
|
||||||
|
|
||||||
def rejection_random_sample_block_verify_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens,  # [batch_size]
    draft_token_ids,  # [num_tokens]
    draft_probs,  # [num_tokens, vocab_size] or None
    target_probs,  # [num_tokens, vocab_size]
    bonus_token_ids,  # [batch_size]
    uniform_probs,  # [num_tokens]
    is_greedy,  # [batch_size]
    max_spec_len,
    vocab_size,
    IS_NGRAM=False,
):
    """Pure-PyTorch block-verify rejection sampling; fills output in place.

    For every request that is not greedy, walk its draft tokens while keeping
    two running products: ``pi`` (target/draft probability ratios, clamped to
    1.0 after each step) and the uniform samples. A position passes when its
    draft probability is positive and ``pi`` is at least the uniform product.
    The *last* passing position decides how many draft tokens are kept.

    Written into ``output_token_ids[req_idx]``:
      * the draft tokens up to and including the last passing position;
      * then, if the final draft position did not pass, the argmax of the
        target distribution at the position right after the kept prefix;
      * otherwise the request's bonus token, at column ``num_draft_tokens``.

    Rows of greedy requests are left untouched.

    Args:
        output_token_ids: Output buffer, modified in place.
        cu_num_draft_tokens: Inclusive cumulative draft-token counts.
        draft_token_ids: Flattened draft token ids of all requests.
        draft_probs: Draft distributions; not read when ``IS_NGRAM`` is true.
        target_probs: Target distributions, one row per draft position.
        bonus_token_ids: Token appended when the last draft position passes.
        uniform_probs: Uniform samples, one per draft position.
        is_greedy: Per-request flag; truthy entries are skipped.
        max_spec_len: Unused here; kept for signature parity with the kernel.
        vocab_size: Unused here; kept for signature parity with the kernel.
        IS_NGRAM: When true, every draft probability is taken to be 1.0.
    """
    batch_size = output_token_ids.shape[0]
    # A single host transfer instead of up to two `.item()` syncs per request.
    cu_offsets = cu_num_draft_tokens.tolist()

    for req_idx in range(batch_size):
        if is_greedy[req_idx]:
            continue

        start_idx = 0 if req_idx == 0 else cu_offsets[req_idx - 1]
        end_idx = cu_offsets[req_idx]
        num_draft_tokens = end_idx - start_idx

        rejected = False
        pi = 1.0
        uniform_prob = 1.0
        last_accepted_token_pos = -1
        for pos in range(num_draft_tokens):
            token_idx = start_idx + pos
            draft_token_id = draft_token_ids[token_idx].item()

            target_prob = target_probs[token_idx, draft_token_id].item()
            uniform_prob = uniform_prob * uniform_probs[token_idx].item()

            if IS_NGRAM:
                draft_prob = 1.0
            else:
                draft_prob = draft_probs[token_idx, draft_token_id].item()

            if draft_prob != 0:
                pi = min(pi * target_prob / draft_prob, 1.0)
            else:
                # Python float division would raise ZeroDivisionError here,
                # before the `draft_prob > 0` check below gets a chance to
                # reject the position. Use the limit of the clamped ratio
                # instead: it saturates at 1.0 for a positive numerator and
                # stays 0.0 otherwise.
                pi = 1.0 if pi * target_prob > 0 else 0.0

            if draft_prob > 0 and pi >= uniform_prob:
                last_accepted_token_pos = pos
                rejected = False
            else:
                rejected = True

        # Keep the draft prefix up to the last passing position (this is a
        # no-op when nothing passed, because the slice is then empty).
        num_kept = last_accepted_token_pos + 1
        if num_kept > 0:
            output_token_ids[req_idx, :num_kept] = draft_token_ids[
                start_idx:start_idx + num_kept]

        if rejected:
            # The final draft position failed, so `num_kept` is strictly
            # smaller than `num_draft_tokens` and the row index is in range.
            recovered_token_id = torch.argmax(
                target_probs[start_idx + num_kept]).item()
            output_token_ids[req_idx, num_kept] = recovered_token_id
        else:
            bonus_token_id = bonus_token_ids[req_idx].item()
            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
|
|
||||||
|
|
||||||
|
|
||||||
def expand_pytorch(
|
def expand_pytorch(
|
||||||
output_ptr, # [num_tokens]
|
output_ptr, # [num_tokens]
|
||||||
input_ptr, # [batch_size]
|
input_ptr, # [batch_size]
|
||||||
@@ -834,92 +737,6 @@ def rejection_random_sample_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
    SUB_BLOCK: tl.constexpr = 1500,
):
    """Block-verify rejection sampling; one program handles one request.

    The program keeps two running products over the request's draft tokens:
    `pi` (target/draft probability ratios, clamped to 1.0 after each step)
    and the uniform samples. The last position where the draft probability
    is positive and `pi >= uniform_prob` decides how many draft tokens are
    copied to the output row. After that prefix the program stores either
    the argmax of the target distribution (when the final draft position did
    not pass) or the request's bonus token (when it did).

    Greedy requests return immediately without writing anything.

    NOTE(review): SUB_BLOCK defaults to 1500, which is not a power of two;
    presumably the targeted Triton backend accepts that in `tl.arange` --
    confirm before reusing this kernel elsewhere.
    """
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        # Early exit for greedy sampling requests.
        return

    # cu_num_draft_tokens holds inclusive cumulative counts, so the previous
    # entry is this request's first token index (0 for the first request).
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                               req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    pi = 1.0
    uniform_prob = 1.0
    last_accepted_token_pos = -1

    for pos in range(num_draft_tokens):
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        # Target probability of the drafted token at this position.
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size + draft_token_id)
        tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
        uniform_prob = uniform_prob * tmp_uniform_prob

        if NO_DRAFT_PROBS:
            draft_prob = 1
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size +
                                 draft_token_id)

        # NOTE(review): the division happens before the `draft_prob > 0`
        # check, so a zero draft probability feeds a division by zero into
        # `pi` -- confirm the resulting value is acceptable for later steps.
        pi = min(pi * target_prob / draft_prob, 1.0)
        if draft_prob > 0 and pi >= uniform_prob:
            last_accepted_token_pos = pos
            rejected = False
        else:
            rejected = True

    # Copy the draft tokens up to and including the last passing position.
    if last_accepted_token_pos > -1:
        for pos in range(last_accepted_token_pos + 1):
            token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if rejected:
        # Argmax over the target distribution at the position right after
        # the kept prefix, scanned in chunks of SUB_BLOCK vocabulary entries.
        loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
        global_recovered_id = -1
        global_max_p = -1.0
        for loop_i in range(loop):
            vocab_start = loop_i * SUB_BLOCK
            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
            # Out-of-vocabulary lanes read as 0.
            tmp_target_prob = tl.load(
                target_probs_ptr +
                (start_idx + last_accepted_token_pos + 1) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0)
            recovered_id = tl.argmax(tmp_target_prob, axis=-1)
            max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
            # Strict comparison: on ties the earlier chunk's index is kept.
            if max_p > global_max_p:
                global_max_p = max_p
                global_recovered_id = vocab_start + recovered_id
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            last_accepted_token_pos + 1, global_recovered_id)
    else:
        # The final draft position passed: append the bonus token after all
        # draft tokens.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||||
def expand_kernel(
|
def expand_kernel(
|
||||||
output_ptr, # [num_tokens]
|
output_ptr, # [num_tokens]
|
||||||
|
|||||||
Reference in New Issue
Block a user