[model_runner_v2] Optimize the performance of post_update (#7496)

### What this PR does / why we need it?
- This PR improves the performance of the `post_update` phase of
`model_runner_v2` on NPUs by fusing its per-request bookkeeping updates
into a single Triton kernel (see the sketch below). This matters for
latency-sensitive, high-throughput inference on NPU hardware.
- At batch size 256, the time cost of `post_update` drops from 26 µs to 11 µs.
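
For intuition, the gain comes from replacing several small per-field device ops (one kernel launch each) with a single fused Triton launch. A minimal, self-contained sketch of the eager pattern being avoided; all tensor names and shapes here are hypothetical, since the pre-optimization code is not shown in this diff:

```python
import torch

# Hypothetical eager-mode bookkeeping: every field update below is its
# own small device op, so launch overhead dominates at bs = 256.
vocab_size, max_num_reqs, num_reqs = 32, 8, 4
last_sampled = torch.zeros(max_num_reqs, dtype=torch.int64)
bin_counts = torch.zeros(max_num_reqs, vocab_size, dtype=torch.int64)
num_computed = torch.zeros(max_num_reqs, dtype=torch.int64)

req_idx = torch.tensor([0, 2, 3, 5])            # batch row -> request slot
tokens = torch.randint(vocab_size, (num_reqs,))
query_lens = torch.tensor([1, 1, 2, 1])

last_sampled[req_idx] = tokens                  # launch 1
bin_counts[req_idx, tokens] += 1                # launch 2
num_computed[req_idx] += query_lens             # launch 3
# This PR collapses all of this bookkeeping (plus the token-log append)
# into one _post_update_kernel launch.
```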

### Does this PR introduce _any_ user-facing change?
No. The API, interfaces, and high-level behavior are unchanged; the only
user-visible effect is the performance improvement.

### How was this patch tested?
CI passed with newly added and existing tests. In addition, dedicated
benchmark tests were run on NPU hardware to measure the performance
improvement of the `post_update` operators.
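
The exact benchmark harness is not part of this diff. A minimal sketch of how a per-call timing like the one quoted above could be taken, assuming the torch_npu backend (which provides `torch.npu.synchronize()`):

```python
import time

import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)


def time_op_us(fn, iters: int = 100) -> float:
    """Mean wall-clock time per call, in microseconds."""
    fn()  # warm-up: the first call triggers Triton compilation
    torch.npu.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.npu.synchronize()
    return (time.perf_counter() - start) / iters * 1e6


# Example: time_op_us(lambda: post_update(*args)), where *args are the
# tensors documented in the wrapper below.
```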

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Author: weijinqian0
Date: 2026-03-23 20:29:55 +08:00
Committed by: GitHub
Parent: 170dcbda62
Commit: bdd90c0088
2 changed files with 203 additions and 0 deletions


@@ -20,9 +20,11 @@ from dataclasses import asdict, dataclass
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num


class AscendInputBuffers(InputBuffers):
@@ -101,3 +103,107 @@ class AscendInputBatch(InputBatch):
        # we can also set attn_state to AscendAttentionState.DecodeOnly.
        input_batch.attn_state = AscendAttentionState.DecodeOnly
        return cls(**asdict(input_batch), seq_lens_np=seq_lens_np)


@triton.jit
def _post_update_kernel(
    idx_mapping_ptr,
    idx_mapping_stride,
    num_computed_tokens_ptr,
    last_sampled_tokens_ptr,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    sampled_tokens_ptr,
    sampled_tokens_stride,
    num_rows,
    num_sampled_ptr,
    num_rejected_ptr,
    query_start_loc_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    total_len_ptr,
):
    pid = tl.program_id(0)
    n_programs = tl.num_programs(0)
    # Each program handles a contiguous chunk of
    # ceil(num_rows / n_programs) rows.
    rows_per_program = (num_rows + n_programs - 1) // n_programs
    start_row = pid * rows_per_program
    end_row = tl.minimum(start_row + rows_per_program, num_rows)
    for row_idx in range(start_row, end_row):
        req_state_idx = tl.load(idx_mapping_ptr + row_idx * idx_mapping_stride)
        total_len = tl.load(total_len_ptr + req_state_idx)
        num_sampled = tl.load(num_sampled_ptr + row_idx)
        if num_sampled > 0:
            # Remember the last accepted token and the new total length
            # for this request.
            token_id = tl.load(sampled_tokens_ptr + row_idx * sampled_tokens_stride + num_sampled - 1)
            tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
            tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
            for i in range(num_sampled):
                # Update the per-request output histogram and append the
                # token to the request's token log.
                token_id = tl.load(sampled_tokens_ptr + row_idx * sampled_tokens_stride + i)
                token_ptr = output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
                count = tl.load(token_ptr)
                count += 1
                tl.store(token_ptr, count)
                tl.store(
                    all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
                    token_id,
                )
        # Advance the computed-token count by the scheduled query length
        # minus any tokens rejected by speculative decoding.
        query_start = tl.load(query_start_loc_ptr + row_idx)
        query_end = tl.load(query_start_loc_ptr + row_idx + 1)
        query_len = query_end - query_start
        num_rejected = tl.load(num_rejected_ptr + row_idx)
        num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
        num_computed += query_len - num_rejected
        tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)


def post_update(
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    num_computed_tokens: torch.Tensor,
    # [max_num_reqs]
    last_sampled_tokens: torch.Tensor,
    # [max_num_reqs, vocab_size]
    output_bin_counts: torch.Tensor,
    # [num_reqs, num_speculative_steps + 1]
    sampled_tokens: torch.Tensor,
    # [num_reqs]
    num_sampled: torch.Tensor,
    # [num_reqs]
    num_rejected: torch.Tensor,
    # [num_reqs + 1]
    query_start_loc: torch.Tensor,
    # [max_num_reqs, max_model_len]
    all_token_ids: torch.Tensor,
    # [max_num_reqs]
    total_len: torch.Tensor,
) -> None:
    num_rows = idx_mapping.shape[0]
    core_num = get_vectorcore_num()
    # Launch at most one program per NPU vector core; the kernel then
    # chunks the rows across the launched programs.
    grid = (min(num_rows, core_num),)
    _post_update_kernel[grid](
        idx_mapping,
        idx_mapping.stride(0),
        num_computed_tokens,
        last_sampled_tokens,
        output_bin_counts,
        output_bin_counts.stride(0),
        sampled_tokens,
        sampled_tokens.stride(0),
        num_rows,
        num_sampled,
        num_rejected,
        query_start_loc,
        all_token_ids,
        all_token_ids.stride(0),
        total_len,
    )
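
As a reading aid, here is a plain-PyTorch mirror of what the fused kernel computes, derived line by line from `_post_update_kernel` above. It is a sketch for understanding and CPU-side validation of the semantics, not code from the PR:

```python
import torch


def post_update_reference(
    idx_mapping: torch.Tensor,          # [num_reqs]
    num_computed_tokens: torch.Tensor,  # [max_num_reqs]
    last_sampled_tokens: torch.Tensor,  # [max_num_reqs]
    output_bin_counts: torch.Tensor,    # [max_num_reqs, vocab_size]
    sampled_tokens: torch.Tensor,       # [num_reqs, num_spec_steps + 1]
    num_sampled: torch.Tensor,          # [num_reqs]
    num_rejected: torch.Tensor,         # [num_reqs]
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
    all_token_ids: torch.Tensor,        # [max_num_reqs, max_model_len]
    total_len: torch.Tensor,            # [max_num_reqs]
) -> None:
    for row in range(idx_mapping.shape[0]):
        req = int(idx_mapping[row])
        n = int(num_sampled[row])
        start = int(total_len[req])
        if n > 0:
            tokens = sampled_tokens[row, :n].tolist()
            # Last accepted token and new total length for this request.
            last_sampled_tokens[req] = tokens[-1]
            total_len[req] = start + n
            for i, tok in enumerate(tokens):
                # Histogram update plus append to the token log.
                output_bin_counts[req, tok] += 1
                all_token_ids[req, start + i] = tok
        # Computed tokens advance by the accepted part of the query.
        query_len = int(query_start_loc[row + 1] - query_start_loc[row])
        num_computed_tokens[req] += query_len - int(num_rejected[row])
```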