[Feature] adapt to uva buffer and main2main (#6657)

### What this PR does / why we need it?
vllm model runner v2 use uva buffer to prepare input data, but npu
doesn't support uva yet, this pr implement a uvawrapper class to mimic
gpu's uva backend. what's more, this pr make some modifications to adapt
to the newer main branch.

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM main:
13397841ab

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-02-12 10:36:31 +08:00
committed by GitHub
parent 56269eae0e
commit f1ffb5fb19
14 changed files with 407 additions and 179 deletions

View File

@@ -30,6 +30,7 @@ def _gumbel_sample_kernel(
local_max_stride,
logits_ptr,
logits_stride,
idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
@@ -37,24 +38,26 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
logits_ptr + batch_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
seed = tl.load(seeds_ptr + req_state_idx)
# NOTE(Ronald1995): change pos's dtype to tl.int32, because triton-ascend's
# compiler doesn't support unint64 of pos arg.
pos = tl.load(pos_ptr + req_idx).to(tl.int32)
pos = tl.load(pos_ptr + batch_idx).to(tl.int32)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
@@ -66,7 +69,7 @@ def _gumbel_sample_kernel(
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
@@ -76,21 +79,18 @@ def _gumbel_sample_kernel(
idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [num_reqs]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
"""Override the function because there are some bugs
when _gumbel_sample_kernel runs on npu, we need to make some fixes.
you could read NOTE(Ronald1995) comments to understand.
"""
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
@@ -114,6 +114,7 @@ def gumbel_sample(
local_max.stride(0),
logits,
logits.stride(0),
idx_mapping,
seed,
pos,
temperature,

View File

@@ -14,22 +14,25 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import numpy as np
import torch
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
from vllm_ascend.worker.v2.sample.penalties import apply_penalties_and_temperature
class AscendSampler(Sampler):
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Override sample method because we need to override triton operators
called in the method.
@@ -37,19 +40,42 @@ class AscendSampler(Sampler):
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply penalties and temperature in place.
apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place.
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(logits, sampling_metadata.top_k, sampling_metadata.top_p)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
# Apply penalties in place.
self.penalties_state.apply_penalties(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
self.num_speculative_tokens,
)
# Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
# Apply min_p in place if any request has a non-zero min_p.
do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
if do_min_p:
apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
# Apply top_k and/or top_p. This might return a new tensor.
do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
if do_top_k or do_top_p:
logits = apply_top_k_top_p(logits, top_k, top_p)
# Sample the next token.
sampled = gumbel_sample(
logits,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,
apply_temperature=False,
)
return sampled, logits