diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_post_update.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_post_update.py
new file mode 100644
index 00000000..ca1e9413
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_post_update.py
@@ -0,0 +1,97 @@
+from typing import Any, Dict
+
+import pytest
+import torch
+from vllm.v1.worker.gpu.input_batch import post_update as post_update_gpu
+
+from vllm_ascend.worker.v2.input_batch import post_update as post_update_npu
+
+
+def generate_test_data(num_reqs: int, max_num_reqs: int, vocab_size: int,
+                       num_speculative_steps: int, device: str) -> Dict[str, Any]:
+    """
+    Generate random test data.
+    Return a dictionary containing all input tensors consumed by post_update.
+    """
+    num_cols = num_speculative_steps + 1
+
+    if num_reqs > max_num_reqs:
+        raise ValueError("num_reqs cannot be larger than max_num_reqs")
+
+    idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
+    num_computed_tokens = torch.randint(0, 100, (max_num_reqs,), dtype=torch.int32, device=device)
+    last_sampled_tokens = torch.randint(0, vocab_size, (max_num_reqs,), dtype=torch.int32, device=device)
+    output_bin_counts = torch.randint(0, 10, (max_num_reqs, vocab_size), dtype=torch.int32, device=device)
+    sampled_tokens = torch.randint(0, vocab_size, (num_reqs, num_cols), dtype=torch.int32, device=device)
+    num_sampled = torch.randint(1, num_cols + 1, (num_reqs,), dtype=torch.int32, device=device)
+    num_rejected = torch.randint(0, num_cols, (num_reqs,), dtype=torch.int32, device=device)
+    # A request cannot reject more tokens than it sampled; at least one token is kept.
+    num_rejected = torch.min(num_rejected, num_sampled - 1)
+
+    query_lengths = torch.randint(1, 20, (num_reqs,), dtype=torch.int32, device=device)
+    query_start_loc = torch.cat([
+        torch.tensor([0], dtype=torch.int32, device=device),
+        torch.cumsum(query_lengths, dim=0)
+    ])
+    total_len = torch.randint(50, 200, (max_num_reqs,), dtype=torch.int32, device=device)
+
+    max_model_len = 3000  # or derive it from the maximum of total_len
+    all_token_ids = torch.randint(0, vocab_size, (max_num_reqs, max_model_len), dtype=torch.int32, device=device)
+
+    return {
+        "idx_mapping": idx_mapping,
+        "num_computed_tokens": num_computed_tokens,
+        "last_sampled_tokens": last_sampled_tokens,
+        "output_bin_counts": output_bin_counts,
+        "sampled_tokens": sampled_tokens,
+        "num_sampled": num_sampled,
+        "num_rejected": num_rejected,
+        "query_start_loc": query_start_loc,
+        "all_token_ids": all_token_ids,
+        "total_len": total_len
+    }
+
+
+@pytest.mark.parametrize("num_reqs,max_num_reqs,vocab_size,num_speculative_steps", [
+    (36, 36, 200, 2),
+    (48, 48, 32000, 5),
+    (128, 128, 32000, 5),
+])
+def test_post_update(num_reqs: int, max_num_reqs: int, vocab_size: int, num_speculative_steps: int):
+    """Compare the Ascend post_update kernel against the upstream vLLM kernel.
+
+    Args:
+        num_reqs: Number of requests in the batch
+        max_num_reqs: Capacity of the persistent per-request buffers
+        vocab_size: Size of the vocabulary
+        num_speculative_steps: Number of speculative decoding steps
+    """
+    torch.manual_seed(42)
+
+    post_update_params = [
+        "idx_mapping",
+        "num_computed_tokens",
+        "last_sampled_tokens",
+        "output_bin_counts",
+        "sampled_tokens",
+        "num_sampled",
+        "num_rejected",
+        "query_start_loc",
+        "all_token_ids",
+        "total_len",
+    ]
+
+    data = generate_test_data(num_reqs, max_num_reqs, vocab_size, num_speculative_steps, device="npu")
+    kernel_inputs_gpu = {k: data[k].clone() for k in post_update_params}
+    kernel_inputs_npu = {k: data[k].clone() for k in post_update_params}
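+    # post_update mutates its tensor arguments in place, so the reference and
+    # the kernel under test each run on their own clones of the same inputs.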
+
+    # Run the upstream vLLM Triton kernel as the reference.
+    post_update_gpu(**kernel_inputs_gpu)
+    torch.npu.synchronize()
+
+    # Run the Ascend Triton kernel under test.
+    post_update_npu(**kernel_inputs_npu)
+    torch.npu.synchronize()
+
+    # ========== Verify results ==========
+    diff = (kernel_inputs_gpu["output_bin_counts"] -
+            kernel_inputs_npu["output_bin_counts"]).abs()
+    assert torch.equal(kernel_inputs_gpu["output_bin_counts"],
+                       kernel_inputs_npu["output_bin_counts"]), \
+        f"Ascend kernel output differs from the reference.\n" \
+        f"Max diff: {diff.max()}\n" \
+        f"Mean diff: {diff.float().mean()}"
diff --git a/vllm_ascend/worker/v2/input_batch.py b/vllm_ascend/worker/v2/input_batch.py
index 9a1ccea0..1c8e78d2 100644
--- a/vllm_ascend/worker/v2/input_batch.py
+++ b/vllm_ascend/worker/v2/input_batch.py
@@ -20,9 +20,11 @@ from dataclasses import asdict, dataclass
 import numpy as np
 import torch
+from vllm.triton_utils import tl, triton
 from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
 
 
 class AscendInputBuffers(InputBuffers):
@@ -101,3 +103,107 @@ class AscendInputBatch(InputBatch):
         # we can also set attn_state to AscendAttentionState.DecodeOnly.
         input_batch.attn_state = AscendAttentionState.DecodeOnly
         return cls(**asdict(input_batch), seq_lens_np=seq_lens_np)
+
+
+@triton.jit
+def _post_update_kernel(
+    idx_mapping_ptr,
+    idx_mapping_stride,
+    num_computed_tokens_ptr,
+    last_sampled_tokens_ptr,
+    output_bin_counts_ptr,
+    output_bin_counts_stride,
+    sampled_tokens_ptr,
+    sampled_tokens_stride,
+    num_rows,
+    num_sampled_ptr,
+    num_rejected_ptr,
+    query_start_loc_ptr,
+    all_token_ids_ptr,
+    all_token_ids_stride,
+    total_len_ptr,
+):
+    pid = tl.program_id(0)
+    n_programs = tl.num_programs(0)
+
+    # Split the rows (requests) evenly across the launched programs.
+    rows_per_program = (num_rows + n_programs - 1) // n_programs
+    start_row = pid * rows_per_program
+    end_row = tl.minimum(start_row + rows_per_program, num_rows)
+
+    for row_idx in range(start_row, end_row):
+        req_state_idx = tl.load(idx_mapping_ptr + row_idx * idx_mapping_stride)
+        total_len = tl.load(total_len_ptr + req_state_idx)
+        num_sampled = tl.load(num_sampled_ptr + row_idx)
+
+        if num_sampled > 0:
+            token_id = tl.load(sampled_tokens_ptr + row_idx * sampled_tokens_stride + num_sampled - 1)
+            tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
+            tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
+
+        for i in range(num_sampled):
+            token_id = tl.load(sampled_tokens_ptr + row_idx * sampled_tokens_stride + i)
+
+            # Bump the per-request bin count of this token.
+            token_ptr = output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
+            count = tl.load(token_ptr)
+            count += 1
+            tl.store(token_ptr, count)
+
+            # Append the token to the request's token history.
+            tl.store(
+                all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
+                token_id,
+            )
+
+        query_start = tl.load(query_start_loc_ptr + row_idx)
+        query_end = tl.load(query_start_loc_ptr + row_idx + 1)
+        query_len = query_end - query_start
+        num_rejected = tl.load(num_rejected_ptr + row_idx)
+
+        num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
+        num_computed += query_len - num_rejected
+        tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
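+
+
+# Host-side entry point: one kernel launch applies the whole post-sampling
+# state update. For each scheduled request it appends the accepted tokens to
+# all_token_ids, bumps their bin counts, records the last sampled token,
+# updates total_len, and advances num_computed_tokens by
+# (query_len - num_rejected).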
+def post_update(
+    # [num_reqs]
+    idx_mapping: torch.Tensor,
+    # [max_num_reqs]
+    num_computed_tokens: torch.Tensor,
+    # [max_num_reqs]
+    last_sampled_tokens: torch.Tensor,
+    # [max_num_reqs, vocab_size]
+    output_bin_counts: torch.Tensor,
+    # [num_reqs, num_speculative_steps + 1]
+    sampled_tokens: torch.Tensor,
+    # [num_reqs]
+    num_sampled: torch.Tensor,
+    # [num_reqs]
+    num_rejected: torch.Tensor,
+    # [num_reqs + 1]
+    query_start_loc: torch.Tensor,
+    # [max_num_reqs, max_model_len]
+    all_token_ids: torch.Tensor,
+    # [max_num_reqs]
+    total_len: torch.Tensor,
+) -> None:
+    num_rows = idx_mapping.shape[0]
+
+    # Launch at most one program per vector core; each program then loops
+    # over its share of the rows.
+    core_num = get_vectorcore_num()
+    grid = (min(num_rows, core_num),)
+    _post_update_kernel[grid](
+        idx_mapping,
+        idx_mapping.stride(0),
+        num_computed_tokens,
+        last_sampled_tokens,
+        output_bin_counts,
+        output_bin_counts.stride(0),
+        sampled_tokens,
+        sampled_tokens.stride(0),
+        num_rows,
+        num_sampled,
+        num_rejected,
+        query_start_loc,
+        all_token_ids,
+        all_token_ids.stride(0),
+        total_len,
+    )
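
To make the kernel's semantics easy to audit, here is a pure-PyTorch sketch of the same in-place update. `post_update_reference` is a hypothetical helper for illustration only, not part of this diff; it assumes the tensor shapes documented in the wrapper above.

```python
import torch


def post_update_reference(
    idx_mapping: torch.Tensor,          # [num_reqs]
    num_computed_tokens: torch.Tensor,  # [max_num_reqs]
    last_sampled_tokens: torch.Tensor,  # [max_num_reqs]
    output_bin_counts: torch.Tensor,    # [max_num_reqs, vocab_size]
    sampled_tokens: torch.Tensor,       # [num_reqs, num_speculative_steps + 1]
    num_sampled: torch.Tensor,          # [num_reqs]
    num_rejected: torch.Tensor,         # [num_reqs]
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
    all_token_ids: torch.Tensor,        # [max_num_reqs, max_model_len]
    total_len: torch.Tensor,            # [max_num_reqs]
) -> None:
    """Per-request, in-place restatement of _post_update_kernel's logic."""
    for row in range(idx_mapping.shape[0]):
        req = int(idx_mapping[row])
        n = int(num_sampled[row])
        start = int(total_len[req])
        if n > 0:
            # The last accepted token seeds the next decoding step.
            last_sampled_tokens[req] = sampled_tokens[row, n - 1]
            total_len[req] = start + n
        for i in range(n):
            tok = int(sampled_tokens[row, i])
            output_bin_counts[req, tok] += 1
            all_token_ids[req, start + i] = tok
        # Rejected speculative tokens are not counted as computed.
        query_len = int(query_start_loc[row + 1] - query_start_loc[row])
        num_computed_tokens[req] += query_len - int(num_rejected[row])
```

Run on CPU clones of the same random inputs, this sketch should produce the same `output_bin_counts` the test compares, which makes it a convenient third point of reference when the GPU and NPU kernels disagree.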