[model_runner_v2]optimize the performance of the post_update. (#7496)
### What this PR does / why we need it? - This PR aims to enhance the operator performance in the `post_update` phase of `model_runner_v2` on NPUs. By optimizing the relevant operations, it is expected to improve the overall efficiency and speed of the model running on NPU hardware, which is crucial for scenarios where high-performance inference is required. - when bs = 256, time cost reduce from 26us to 11 us; ### Does this PR introduce _any_ user-facing change? No, there are no changes to the API, interface, or other high-level behaviors that would directly affect the user's code or interaction with the system beyond the performance improvement. ### How was this patch tested? CI passed with new added/existing tests. In addition to the regular CI tests, specific benchmark tests were conducted on NPU hardware to measure the performance improvement of the `post_update` operators. --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
from typing import Dict, Any
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
from vllm.v1.worker.gpu.input_batch import post_update as post_update_gpu
|
||||
from vllm_ascend.worker.v2.input_batch import post_update as post_update_npu
|
||||
|
||||
|
||||
def generate_test_data(num_reqs: int, max_num_reqs: int, vocab_size: int, num_speculative_steps: int, device: str) -> \
|
||||
Dict[str, Any]:
|
||||
"""
|
||||
Generate random test data.
|
||||
Return a dictionary containing all input tensors and the additional field 'expected_query_lens' for validation.
|
||||
"""
|
||||
num_cols = num_speculative_steps + 1
|
||||
|
||||
if num_reqs > max_num_reqs:
|
||||
raise ValueError("num_reqs cannot be larger than max_num_reqs")
|
||||
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
num_computed_tokens = torch.randint(0, 100, (max_num_reqs,), dtype=torch.int32, device=device)
|
||||
last_sampled_tokens = torch.randint(0, vocab_size, (max_num_reqs,), dtype=torch.int32, device=device)
|
||||
output_bin_counts = torch.randint(0, 10, (max_num_reqs, vocab_size), dtype=torch.int32, device=device)
|
||||
sampled_tokens = torch.randint(0, vocab_size, (num_reqs, num_speculative_steps + 1), dtype=torch.int32,
|
||||
device=device)
|
||||
num_sampled = torch.randint(1, num_speculative_steps + 2, (num_reqs,), dtype=torch.int32, device=device)
|
||||
num_rejected = torch.randint(0, num_speculative_steps + 1, (num_reqs,), dtype=torch.int32, device=device)
|
||||
num_rejected = torch.min(num_rejected, num_sampled - 1)
|
||||
|
||||
query_lengths = torch.randint(1, 20, (num_reqs,), dtype=torch.int32, device=device)
|
||||
query_start_loc = torch.cat([
|
||||
torch.tensor([0], dtype=torch.int32, device=device),
|
||||
torch.cumsum(query_lengths, dim=0)
|
||||
])
|
||||
total_len = torch.randint(50, 200, (max_num_reqs,), dtype=torch.int32, device=device)
|
||||
|
||||
max_model_len = 3000 # 或者可以从total_len的最大值获取
|
||||
all_token_ids = torch.randint(0, vocab_size, (max_num_reqs, max_model_len), dtype=torch.int32, device=device)
|
||||
|
||||
return {
|
||||
"idx_mapping": idx_mapping,
|
||||
"num_computed_tokens": num_computed_tokens,
|
||||
"last_sampled_tokens": last_sampled_tokens,
|
||||
"output_bin_counts": output_bin_counts,
|
||||
"sampled_tokens": sampled_tokens,
|
||||
"num_sampled": num_sampled,
|
||||
"num_rejected": num_rejected,
|
||||
"query_start_loc": query_start_loc,
|
||||
"all_token_ids": all_token_ids,
|
||||
"total_len": total_len
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_reqs,max_num_reqs,vocab_size,num_speculative_steps", [
|
||||
(36, 36, 200, 2),
|
||||
(48, 48, 32000, 5),
|
||||
(128, 128, 32000, 5),
|
||||
])
|
||||
def test_post_update(num_reqs: int, max_num_reqs: int, vocab_size: int, num_speculative_steps: int):
|
||||
"""Test _topk_log_softmax_kernel for computing log probabilities
|
||||
Args:
|
||||
batch_size: Number of sequences in the batch
|
||||
vocab_size: Size of the vocabulary
|
||||
num_logprobs: Number of tokens to compute log probabilities for
|
||||
"""
|
||||
torch.manual_seed(42)
|
||||
|
||||
post_update_params = ["idx_mapping",
|
||||
"num_computed_tokens",
|
||||
"last_sampled_tokens",
|
||||
"output_bin_counts",
|
||||
"sampled_tokens",
|
||||
"num_sampled",
|
||||
"num_rejected",
|
||||
"query_start_loc",
|
||||
"all_token_ids",
|
||||
"total_len"
|
||||
]
|
||||
|
||||
data = generate_test_data(num_reqs, max_num_reqs, vocab_size, num_speculative_steps, device="npu")
|
||||
kernel_inputs_gpu = {k: data[k].clone() for k in post_update_params}
|
||||
kernel_inputs_npu = {k: data[k].clone() for k in post_update_params}
|
||||
|
||||
# Invoke Triton kernel
|
||||
post_update_gpu(**kernel_inputs_gpu)
|
||||
torch.npu.synchronize()
|
||||
|
||||
post_update_npu(**kernel_inputs_npu)
|
||||
torch.npu.synchronize()
|
||||
|
||||
# ========== Verify results ==========
|
||||
assert torch.allclose(kernel_inputs_gpu["output_bin_counts"], kernel_inputs_npu["output_bin_counts"], rtol=1e-3,
|
||||
atol=1e-3), \
|
||||
f"Triton output differs from PyTorch reference.\n" \
|
||||
f"Max diff: {torch.max(torch.abs(kernel_inputs_npu['output_bin_counts'] - kernel_inputs_npu['output_bin_counts']))}\n" \
|
||||
f"Mean diff: {torch.mean(torch.abs(kernel_inputs_npu['output_bin_counts'] - kernel_inputs_npu['output_bin_counts']))}"
|
||||
|
||||
@@ -20,9 +20,11 @@ from dataclasses import asdict, dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
|
||||
|
||||
class AscendInputBuffers(InputBuffers):
|
||||
@@ -101,3 +103,107 @@ class AscendInputBatch(InputBatch):
|
||||
# we can also set attn_state to AscendAttentionState.DecodeOnly.
|
||||
input_batch.attn_state = AscendAttentionState.DecodeOnly
|
||||
return cls(**asdict(input_batch), seq_lens_np=seq_lens_np)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _post_update_kernel(
|
||||
idx_mapping_ptr,
|
||||
idx_mapping_stride,
|
||||
num_computed_tokens_ptr,
|
||||
last_sampled_tokens_ptr,
|
||||
output_bin_counts_ptr,
|
||||
output_bin_counts_stride,
|
||||
sampled_tokens_ptr,
|
||||
sampled_tokens_stride,
|
||||
num_rows,
|
||||
num_sampled_ptr,
|
||||
num_rejected_ptr,
|
||||
query_start_loc_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
total_len_ptr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
n_programs = tl.num_programs(0)
|
||||
|
||||
rows_per_program = (num_rows + n_programs - 1) // n_programs
|
||||
start_row = pid * rows_per_program
|
||||
end_row = tl.minimum(start_row + rows_per_program, num_rows)
|
||||
|
||||
for row_idx in range(start_row, end_row):
|
||||
req_state_idx = tl.load(idx_mapping_ptr + row_idx * idx_mapping_stride)
|
||||
total_len = tl.load(total_len_ptr + req_state_idx)
|
||||
num_sampled = tl.load(num_sampled_ptr + row_idx)
|
||||
|
||||
if num_sampled > 0:
|
||||
token_id = tl.load(sampled_tokens_ptr + row_idx * sampled_tokens_stride + num_sampled - 1)
|
||||
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
|
||||
tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
|
||||
|
||||
for i in range(num_sampled):
|
||||
token_id = tl.load(sampled_tokens_ptr + row_idx * sampled_tokens_stride + i)
|
||||
|
||||
token_ptr = output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
|
||||
count = tl.load(token_ptr)
|
||||
count += 1
|
||||
tl.store(token_ptr, count)
|
||||
|
||||
tl.store(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
|
||||
token_id,
|
||||
)
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + row_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + row_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
num_rejected = tl.load(num_rejected_ptr + row_idx)
|
||||
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
num_computed += query_len - num_rejected
|
||||
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
|
||||
|
||||
|
||||
def post_update(
|
||||
# [num_reqs]
|
||||
idx_mapping: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
num_computed_tokens: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
last_sampled_tokens: torch.Tensor,
|
||||
# [max_num_reqs, vocab_size]
|
||||
output_bin_counts: torch.Tensor,
|
||||
# [num_reqs, num_speculative_steps + 1]
|
||||
sampled_tokens: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_rejected: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor,
|
||||
# [max_num_reqs, max_model_len]
|
||||
all_token_ids: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
total_len: torch.Tensor,
|
||||
) -> None:
|
||||
num_rows = idx_mapping.shape[0]
|
||||
|
||||
core_num = get_vectorcore_num()
|
||||
|
||||
grid = (min(num_rows, core_num),)
|
||||
_post_update_kernel[grid](
|
||||
idx_mapping,
|
||||
idx_mapping.stride(0),
|
||||
num_computed_tokens,
|
||||
last_sampled_tokens,
|
||||
output_bin_counts,
|
||||
output_bin_counts.stride(0),
|
||||
sampled_tokens,
|
||||
sampled_tokens.stride(0),
|
||||
num_rows,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
query_start_loc,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
total_len,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user