[Feature] support eager mode in model runner v2 (#5210)

### What this PR does / why we need it?
#5051 only implement a basic framework for model runner v2, but there
are still some bugs for e2e functionality, this PR aim to enable basic
functionality.
model runner v2 plans:
https://github.com/vllm-project/vllm-ascend/issues/5208

- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2025-12-29 15:28:34 +08:00
committed by GitHub
parent 4da46da9bf
commit e7e1a7dc05
19 changed files with 528 additions and 44 deletions

View File

View File

@@ -0,0 +1,128 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/gumbel.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
logits_ptr,
logits_stride,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
# NOTE(Ronald1995): change pos's dtype to tl.int32, because triton-ascend's
# compiler doesn't support unint64 of pos arg.
pos = tl.load(pos_ptr + req_idx).to(tl.int32)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
# NOTE(Ronald1995): r is tl.float64 in vllm, change it to tl.float32,
# or triton-ascend's compiler will raise error.
r = tl.rand(gumbel_seed, block).to(tl.float32)
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
gumbel_noise = gumbel_noise.to(tl.float32)
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx,
token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
"""Override the function because there are some bugs
when _gumbel_sample_kernel runs on npu, we need to make some fixes.
you could read NOTE(Ronald1995) comments to understand.
"""
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
# TODO(Ronald1995): Optimize the performance of the kernel in npu.
_gumbel_sample_kernel[(num_reqs, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
seed,
pos,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
APPLY_TEMPERATURE=apply_temperature,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled

View File

@@ -0,0 +1,137 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/penalties.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
@triton.jit
def _penalties_and_temperature_kernel(
logits_ptr,
logits_stride,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
temperature_ptr,
idx_mapping_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
temperature = tl.load(temperature_ptr + batch_idx)
temperature = tl.where(temperature == 0.0, 1.0, temperature)
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
# NOTE(Ronald1995): vllm original grammar `use_rep_penalty or
# use_freq_penalty or use_pres_penalty`,
# change it to `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`,
# because triton-ascend's compiler doesn't support chained boolean operator.
use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
use_temperature = temperature != 1.0
if not (use_penalty or use_temperature):
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
if use_penalty:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride +
block,
mask=mask,
)
output_bin_mask = output_bin_counts > 0
# Apply repetition penalties.
if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(
0, BLOCK_SIZE // 32)
packed_mask = tl.load(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride +
packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32),
)
prompt_bin_mask = (packed_mask[:, None] >>
(tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale)
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * output_bin_mask
# Apply temperature.
logits = logits / temperature
# Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_penalties_and_temperature(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> None:
"""Override the function because there are some bugs
when _penalties_and_temperature_kernel runs on npu, we need to make some fixes.
you could read NOTE(Ronald1995) comments to understand.
"""
num_reqs, vocab_size = logits.shape
# NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 in case UB overflow
# in triton-ascend.
BLOCK_SIZE = 4096
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
# TODO(Ronald1995): Optimize the performance of the kernel in npu.
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty,
sampling_metadata.temperature,
sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts,
sampling_metadata.output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -0,0 +1,58 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/sampler.py.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
from vllm_ascend.worker.v2.sample.penalties import \
apply_penalties_and_temperature
class AscendSampler(Sampler):
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Override sample method because we need to override triton operators
called in the method.
"""
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply penalties and temperature in place.
apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place.
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(logits, sampling_metadata.top_k,
sampling_metadata.top_p)
sampled = gumbel_sample(
logits,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
apply_temperature=False,
)
return sampled, logits