diff --git a/tests/e2e/nightly/ops/triton/test_rejection_sampler.py b/tests/e2e/nightly/ops/triton/test_rejection_sampler.py
deleted file mode 100644
index 3820fd11..00000000
--- a/tests/e2e/nightly/ops/triton/test_rejection_sampler.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import pytest
-import torch
-from torch.testing import assert_close
-
-from vllm_ascend.sample.rejection_sampler import (
-    rejection_random_sample_block_verify_kernel,
-    rejection_random_sample_block_verify_pytorch)
-
-DEVICE = "npu"
-BATCH_SIZE = 3
-MAX_SPEC_LEN = 3
-VOCAB_SIZE = 5
-NUM_TOKENS = BATCH_SIZE * MAX_SPEC_LEN
-CU_NUM_DRAFT_TOKENS = torch.arange(start=MAX_SPEC_LEN,
-                                   end=NUM_TOKENS + 1,
-                                   step=MAX_SPEC_LEN,
-                                   dtype=torch.int32,
-                                   device=DEVICE)
-DRAFT_TOKEN_IDS = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2],
-                               dtype=torch.int64,
-                               device=DEVICE)
-DRAFT_PROBS = None
-TARGET_PROBS = torch.tensor(
-    [
-        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
-        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
-        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
-        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
-        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
-        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
-        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
-        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
-        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
-    ],
-    dtype=torch.float32,
-    device=DEVICE)
-UNIFORM_PROBS = torch.tensor([
-    0.9,
-    0.7,
-    0.8,
-    0.5,
-    0.45,
-    1.0,
-    0.39,
-    0.4,
-    0.1,
-],
-                             dtype=torch.float32,
-                             device=DEVICE)
-BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
-                             MAX_SPEC_LEN + 1,
-                             dtype=torch.int64,
-                             device=DEVICE)
-IS_GREEDY = torch.zeros(NUM_TOKENS, dtype=torch.bool, device=DEVICE)
-
-
-@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
-@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
-@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
-@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
-@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
-@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
-@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
-@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
-@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
-@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
-@torch.inference_mode()
-def test_rejection_sampler_block_verify_triton_kernel(
-        cu_num_draft_tokens,  # [batch_size]
-        draft_token_ids,  # [num_tokens]
-        draft_probs,  # [num_tokens, vocab_size] or None
-        target_probs,  # [num_tokens, vocab_size]
-        bonus_token_ids,  # [batch_size]
-        uniform_probs,  # [num_tokens]
-        is_greedy,  # [batch_size]
-        batch_size,  # int
-        max_spec_len,  # int
-        vocab_size,  # int
-) -> None:
-    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
-                                      -1,
-                                      dtype=torch.int64,
-                                      device=DEVICE)
-
-    output_token_ids_triton = output_token_ids_ref.clone()
-
-    rejection_random_sample_block_verify_pytorch(
-        output_token_ids=output_token_ids_ref,
-        cu_num_draft_tokens=cu_num_draft_tokens,
-        draft_token_ids=draft_token_ids,
-        draft_probs=draft_probs,
-        target_probs=target_probs,
-        bonus_token_ids=bonus_token_ids,
-        uniform_probs=uniform_probs,
-        is_greedy=is_greedy,
-        max_spec_len=max_spec_len,
-        vocab_size=vocab_size,
-        IS_NGRAM=draft_probs is None)
-
-    rejection_random_sample_block_verify_kernel[(batch_size, )](
-        output_token_ids_ptr=output_token_ids_triton,
-        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
-        draft_token_ids_ptr=draft_token_ids,
-        draft_probs_ptr=draft_probs,
-        target_probs_ptr=target_probs,
-        bonus_token_ids_ptr=bonus_token_ids,
-        uniform_probs_ptr=uniform_probs,
-        is_greedy_ptr=is_greedy,
-        max_spec_len=max_spec_len,
-        vocab_size=vocab_size,
-        NO_DRAFT_PROBS=draft_probs is None,
-        multibuffer=True)
-
-    assert_close(output_token_ids_ref, output_token_ids_triton)
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index 7556438f..b0e6f848 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -114,9 +114,6 @@ def rejection_sample(
     assert bonus_token_ids.is_contiguous()
     assert target_probs.shape == (num_tokens, vocab_size)
 
-    # When num_speculative_tokens>=3, using block verify.
-    using_block_verify = max_spec_len >= 3
-
     # Create output buffer.
     output_token_ids = torch.empty(
         (batch_size, max_spec_len + 1),
@@ -194,81 +191,52 @@ def rejection_sample(
         sampling_metadata.generators,
         device,
     )
-    if not using_block_verify:
-        # Sample recovered tokens for each position.
-        # [num_tokens]
-        recovered_token_ids = sample_recovered_tokens(
-            max_spec_len,
-            num_draft_tokens,
+
+    # Sample recovered tokens for each position.
+    # [num_tokens]
+    recovered_token_ids = sample_recovered_tokens(
+        max_spec_len,
+        num_draft_tokens,
+        cu_num_draft_tokens,
+        draft_token_ids,
+        draft_probs,
+        target_probs,
+        sampling_metadata,
+        device,
+    )
+
+    # Rejection sampling for random sampling requests.
+    if HAS_TRITON:
+        rejection_random_sample_kernel[(batch_size, )](
+            output_token_ids,
             cu_num_draft_tokens,
             draft_token_ids,
             draft_probs,
             target_probs,
-            sampling_metadata,
-            device,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs.to(torch.float32),
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            NO_DRAFT_PROBS=draft_probs is None,
         )
-
-        # Rejection sampling for random sampling requests.
-        if HAS_TRITON:
-            rejection_random_sample_kernel[(batch_size, )](
-                output_token_ids,
-                cu_num_draft_tokens,
-                draft_token_ids,
-                draft_probs,
-                target_probs,
-                bonus_token_ids,
-                recovered_token_ids,
-                uniform_probs.to(torch.float32),
-                is_greedy,
-                max_spec_len,
-                vocab_size,
-                NO_DRAFT_PROBS=draft_probs is None,
-            )
-        else:
-            rejection_random_sample_pytorch(
-                output_token_ids,
-                cu_num_draft_tokens,
-                draft_token_ids,
-                draft_probs,
-                target_probs,
-                bonus_token_ids,
-                recovered_token_ids,
-                uniform_probs,
-                is_greedy,
-                max_spec_len,
-                vocab_size,
-                IS_NGRAM=draft_probs is None,
-            )
     else:
-        # MagicMTP: Improving acceptance rate with Block Verify.
-        if HAS_TRITON:
-            rejection_random_sample_block_verify_kernel[(batch_size, )](
-                output_token_ids,
-                cu_num_draft_tokens,
-                draft_token_ids,
-                draft_probs,
-                target_probs,
-                bonus_token_ids,
-                uniform_probs.to(torch.float32),
-                is_greedy,
-                max_spec_len,
-                vocab_size,
-                NO_DRAFT_PROBS=draft_probs is None,
-                multibuffer=True,
-            )
-        else:
-            rejection_random_sample_block_verify_pytorch(output_token_ids,
-                                                         cu_num_draft_tokens,
-                                                         draft_token_ids,
-                                                         draft_probs,
-                                                         target_probs,
-                                                         bonus_token_ids,
-                                                         uniform_probs,
-                                                         is_greedy,
-                                                         max_spec_len,
-                                                         vocab_size,
-                                                         IS_NGRAM=draft_probs
-                                                         is None)
+        rejection_random_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+            # num_warps=1,
+        )
     return output_token_ids
 
 
@@ -532,71 +500,6 @@ def rejection_random_sample_pytorch(
         output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
 
 
-def rejection_random_sample_block_verify_pytorch(
-        output_token_ids,  # [batch_size, max_spec_len + 1]
-        cu_num_draft_tokens,  # [batch_size]
-        draft_token_ids,  # [num_tokens]
-        draft_probs,  # [num_tokens, vocab_size] or None
-        target_probs,  # [num_tokens, vocab_size]
-        bonus_token_ids,  # [batch_size]
-        uniform_probs,  # [num_tokens]
-        is_greedy,  # [batch_size]
-        max_spec_len,
-        vocab_size,
-        IS_NGRAM=False,
-):
-    batch_size = output_token_ids.shape[0]
-
-    for req_idx in range(batch_size):
-        if is_greedy[req_idx]:
-            continue
-
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
-
-        rejected = False
-        pi = 1.0
-        uniform_prob = 1.0
-        last_accepted_token_pos = -1
-        for pos in range(num_draft_tokens):
-            draft_token_id = draft_token_ids[start_idx + pos].item()
-
-            target_prob = target_probs[start_idx + pos, draft_token_id].item()
-            uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()
-
-            if IS_NGRAM:
-                draft_prob = 1.0
-            else:
-                draft_prob = draft_probs[start_idx + pos,
-                                         draft_token_id].item()
-
-            pi = min(pi * target_prob / draft_prob, 1.0)
-
-            if draft_prob > 0 and pi >= uniform_prob:
-                last_accepted_token_pos = pos
-                rejected = False
-            else:
-                rejected = True
-
-        if last_accepted_token_pos > -1:
-            for pos in range(last_accepted_token_pos + 1):
-                draft_token_id = draft_token_ids[start_idx + pos].item()
-                output_token_ids[req_idx, pos] = draft_token_id
-
-        if rejected:
-            recovered_token_id = torch.argmax(
-                target_probs[start_idx + last_accepted_token_pos + 1]).item()
-            output_token_ids[req_idx,
-                             last_accepted_token_pos + 1] = recovered_token_id
-        else:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
-
-
 def expand_pytorch(
         output_ptr,  # [num_tokens]
         input_ptr,  # [batch_size]
@@ -834,92 +737,6 @@ def rejection_random_sample_kernel(
         )
 
 
-@triton.jit(do_not_specialize=["max_spec_len"])
-def rejection_random_sample_block_verify_kernel(
-        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
-        cu_num_draft_tokens_ptr,  # [batch_size]
-        draft_token_ids_ptr,  # [num_tokens]
-        draft_probs_ptr,  # [num_tokens, vocab_size] or None
-        target_probs_ptr,  # [num_tokens, vocab_size]
-        bonus_token_ids_ptr,  # [batch_size]
-        uniform_probs_ptr,  # [num_tokens]
-        is_greedy_ptr,  # [batch_size]
-        max_spec_len,
-        vocab_size,
-        NO_DRAFT_PROBS: tl.constexpr,
-        SUB_BLOCK: tl.constexpr = 1500,
-):
-    req_idx = tl.program_id(0)
-    is_greedy = tl.load(is_greedy_ptr + req_idx)
-    if is_greedy:
-        # Early exit for greedy sampling requests.
-        return
-
-    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
-                                               req_idx - 1)
-    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
-    num_draft_tokens = end_idx - start_idx
-
-    rejected = False
-    pi = 1.0
-    uniform_prob = 1.0
-    last_accepted_token_pos = -1
-
-    for pos in range(num_draft_tokens):
-        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-        target_prob = tl.load(target_probs_ptr +
-                              (start_idx + pos) * vocab_size + draft_token_id)
-        tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
-        uniform_prob = uniform_prob * tmp_uniform_prob
-
-        if NO_DRAFT_PROBS:
-            draft_prob = 1
-        else:
-            draft_prob = tl.load(draft_probs_ptr +
-                                 (start_idx + pos) * vocab_size +
-                                 draft_token_id)
-
-        pi = min(pi * target_prob / draft_prob, 1.0)
-        if draft_prob > 0 and pi >= uniform_prob:
-            last_accepted_token_pos = pos
-            rejected = False
-        else:
-            rejected = True
-
-    if last_accepted_token_pos > -1:
-        for pos in range(last_accepted_token_pos + 1):
-            token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
-                     token_id)
-
-    if rejected:
-        loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
-        global_recovered_id = -1
-        global_max_p = -1.0
-        for loop_i in range(loop):
-            vocab_start = loop_i * SUB_BLOCK
-            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
-            tmp_target_prob = tl.load(
-                target_probs_ptr +
-                (start_idx + last_accepted_token_pos + 1) * vocab_size +
-                vocab_offset,
-                mask=vocab_offset < vocab_size,
-                other=0)
-            recovered_id = tl.argmax(tmp_target_prob, axis=-1)
-            max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
-            if max_p > global_max_p:
-                global_max_p = max_p
-                global_recovered_id = vocab_start + recovered_id
-        tl.store(
-            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
-            last_accepted_token_pos + 1, global_recovered_id)
-    else:
-        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
-        tl.store(
-            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
-            num_draft_tokens, bonus_token_id)
-
-
 @triton.jit(do_not_specialize=["replace_from", "replace_to"])
 def expand_kernel(
         output_ptr,  # [num_tokens]