[releases/v0.18.0][Triton][Sampler] Add penalty-related Triton kernel for better performance of penalties (#7794)
### What this PR does / why we need it? Implement get_token_bin_counts_and_mask and apply_penalties with Triton-Ascend kernels. This significantly reduces latency of the sampling process when repetition/frequency/presence penalties are enabled. Cherry-pick from main PR #7569 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -0,0 +1,110 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Compare vllm_ascend.sample.penalties.apply_all_penalties (Triton-Ascend) with
|
||||||
|
# vllm.v1.sample.ops.penalties.apply_all_penalties (PyTorch via model_executor).
|
||||||
|
# Requires NPU and Triton-Ascend.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.v1.sample.ops.penalties import apply_all_penalties as v1_apply_all_penalties
|
||||||
|
from vllm_ascend.sample.penalties import apply_all_penalties as ascend_apply_all_penalties
|
||||||
|
|
||||||
|
# Same scenario grid as test_apply_penalties_model_executor (equivalence + boundaries).
|
||||||
|
APPLY_PENALTY_CASES = [
|
||||||
|
pytest.param(0, 0, "mixed", id="empty-both"),
|
||||||
|
pytest.param(0, 16, "mixed", id="empty-prompt"),
|
||||||
|
pytest.param(32, 0, "mixed", id="empty-output"),
|
||||||
|
pytest.param(1, 1, "mixed", id="single-token-each"),
|
||||||
|
pytest.param(32, 16, "mixed", id="typical-small"),
|
||||||
|
pytest.param(128, 64, "mixed", id="typical-large"),
|
||||||
|
pytest.param(128, 64, "all_padding", id="all-padding"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tokens(
|
||||||
|
num_seqs: int,
|
||||||
|
seq_len: int,
|
||||||
|
vocab_size: int,
|
||||||
|
mode: str,
|
||||||
|
device: str,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if mode == "all_padding":
|
||||||
|
return torch.full(
|
||||||
|
(num_seqs, seq_len), vocab_size, device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
if seq_len == 0:
|
||||||
|
return torch.empty((num_seqs, 0), device=device, dtype=torch.int64)
|
||||||
|
tokens = torch.randint(
|
||||||
|
0, vocab_size, (num_seqs, seq_len), device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
pad_mask = torch.rand(num_seqs, seq_len, device=device) > 0.7
|
||||||
|
tokens[pad_mask] = vocab_size
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_seqs", [1, 8, 32, 128])
|
||||||
|
@pytest.mark.parametrize("vocab_size", [5120, 151936])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"max_prompt_len,max_output_len,token_mode",
|
||||||
|
APPLY_PENALTY_CASES,
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_apply_all_penalties_v1_vs_ascend(
|
||||||
|
num_seqs,
|
||||||
|
vocab_size,
|
||||||
|
max_prompt_len,
|
||||||
|
max_output_len,
|
||||||
|
token_mode,
|
||||||
|
dtype,
|
||||||
|
device="npu",
|
||||||
|
seed=42,
|
||||||
|
):
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||||
|
|
||||||
|
init_device_properties_triton()
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
logits_v1 = torch.randn(num_seqs, vocab_size, device=device, dtype=dtype)
|
||||||
|
logits_ascend = logits_v1.clone()
|
||||||
|
|
||||||
|
prompt_tokens = _make_tokens(
|
||||||
|
num_seqs, max_prompt_len, vocab_size, token_mode, device
|
||||||
|
)
|
||||||
|
output_tokens = _make_tokens(
|
||||||
|
num_seqs, max_output_len, vocab_size, token_mode, device
|
||||||
|
)
|
||||||
|
output_token_ids = [row.tolist() for row in output_tokens.cpu()]
|
||||||
|
|
||||||
|
presence_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.2
|
||||||
|
frequency_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.2
|
||||||
|
repetition_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.4 + 1.0
|
||||||
|
|
||||||
|
v1_apply_all_penalties(
|
||||||
|
logits_v1,
|
||||||
|
prompt_tokens,
|
||||||
|
presence_penalties,
|
||||||
|
frequency_penalties,
|
||||||
|
repetition_penalties,
|
||||||
|
output_token_ids,
|
||||||
|
)
|
||||||
|
ascend_apply_all_penalties(
|
||||||
|
logits_ascend,
|
||||||
|
prompt_tokens,
|
||||||
|
presence_penalties,
|
||||||
|
frequency_penalties,
|
||||||
|
repetition_penalties,
|
||||||
|
output_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
atol = 1e-2 if dtype == torch.bfloat16 else 1e-3
|
||||||
|
rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
|
||||||
|
assert torch.allclose(
|
||||||
|
logits_ascend.float(), logits_v1.float(), atol=atol, rtol=rtol
|
||||||
|
), (
|
||||||
|
f"Max diff: {(logits_ascend.float() - logits_v1.float()).abs().max().item()}"
|
||||||
|
)
|
||||||
|
gc.collect()
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
torch.npu.reset_peak_memory_stats()
|
||||||
139
vllm_ascend/ops/triton/bincount.py
Normal file
139
vllm_ascend/ops/triton/bincount.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Triton-Ascend implementation of get_token_bin_counts_and_mask.
|
||||||
|
# Migrated from model_executor/layers/utils.get_token_bin_counts_and_mask.
|
||||||
|
# Reference: https://github.com/vllm-project/vllm-ascend/pull/6979
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def token_bin_counts_and_mask_kernel(
|
||||||
|
tokens_ptr,
|
||||||
|
tokens_batch_stride,
|
||||||
|
tokens_seq_stride,
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
vocab_size,
|
||||||
|
bin_counts_ptr,
|
||||||
|
counts_batch_stride,
|
||||||
|
counts_vocab_stride,
|
||||||
|
SEQ_BLOCK: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Count token occurrences per batch row.
|
||||||
|
|
||||||
|
2D tiling:
|
||||||
|
- axis=0: core/program group dimension
|
||||||
|
- axis=1: block id dimension
|
||||||
|
|
||||||
|
We linearize (batch_idx, seq_block_id) into a single global block id and
|
||||||
|
distribute blocks across all programs to improve utilization when
|
||||||
|
batch_size is small but seq_len is large (typical prefill).
|
||||||
|
|
||||||
|
Tokens with value >= vocab_size (e.g. padding) are skipped.
|
||||||
|
"""
|
||||||
|
pid0 = tl.program_id(axis=0)
|
||||||
|
pid1 = tl.program_id(axis=1)
|
||||||
|
progs = tl.num_programs(axis=0)
|
||||||
|
|
||||||
|
n_seq_blocks = tl.cdiv(seq_len, SEQ_BLOCK)
|
||||||
|
linear_block = pid1 * progs + pid0
|
||||||
|
total_blocks = batch_size * n_seq_blocks
|
||||||
|
if linear_block >= total_blocks:
|
||||||
|
return
|
||||||
|
|
||||||
|
batch_idx = linear_block // n_seq_blocks
|
||||||
|
seq_block_id = linear_block - batch_idx * n_seq_blocks
|
||||||
|
seq_start = seq_block_id * SEQ_BLOCK
|
||||||
|
|
||||||
|
batch_tokens_start = tokens_ptr + batch_idx * tokens_batch_stride
|
||||||
|
batch_counts_start = bin_counts_ptr + batch_idx * counts_batch_stride
|
||||||
|
|
||||||
|
pos_offsets = seq_start + tl.arange(0, SEQ_BLOCK)
|
||||||
|
pos_mask = pos_offsets < seq_len
|
||||||
|
token = tl.load(
|
||||||
|
batch_tokens_start + pos_offsets * tokens_seq_stride,
|
||||||
|
mask=pos_mask,
|
||||||
|
other=vocab_size, # force invalid
|
||||||
|
)
|
||||||
|
# Only count valid token ids in [0, vocab_size). Padding must use id >= vocab_size
|
||||||
|
# (see vLLM apply_penalties contract); those positions are masked out here.
|
||||||
|
token_in_range = (token >= 0) & (token < vocab_size) & pos_mask
|
||||||
|
count_ptr = batch_counts_start + token * counts_vocab_stride
|
||||||
|
tl.atomic_add(count_ptr, 1, mask=token_in_range)
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_bin_counts_and_mask_triton(
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
vocab_size: int,
|
||||||
|
num_seqs: int | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Triton-Ascend implementation of token bin counting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: [num_seqs, seq_len] tensor of token IDs. Padding value
|
||||||
|
should be vocab_size and will be ignored.
|
||||||
|
vocab_size: Vocabulary size.
|
||||||
|
num_seqs: If provided, asserts tokens.shape[0] == num_seqs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bin_counts: [num_seqs, vocab_size] int32 counts.
|
||||||
|
mask: [num_seqs, vocab_size] bool, True where count > 0.
|
||||||
|
"""
|
||||||
|
n_rows, n_cols = tokens.shape
|
||||||
|
if num_seqs is not None and num_seqs > 0:
|
||||||
|
assert n_rows == num_seqs, f"tokens rows must match num_seqs: tokens.shape[0]={n_rows}, num_seqs={num_seqs}"
|
||||||
|
n_rows = num_seqs if num_seqs is not None else n_rows
|
||||||
|
|
||||||
|
# seq_len == 0 is valid for empty decode history; return directly.
|
||||||
|
if n_cols == 0:
|
||||||
|
bin_counts = torch.zeros((n_rows, vocab_size), dtype=torch.int32, device=tokens.device)
|
||||||
|
return bin_counts, bin_counts > 0
|
||||||
|
|
||||||
|
core_num = get_vectorcore_num()
|
||||||
|
|
||||||
|
bin_counts = torch.zeros((n_rows, vocab_size), dtype=torch.int32, device=tokens.device)
|
||||||
|
if not tokens.is_contiguous():
|
||||||
|
tokens = tokens.contiguous()
|
||||||
|
|
||||||
|
# 2D grid: (progs, blocks_per_prog_group)
|
||||||
|
# Keep axis-0 bounded by vector core count, and distribute (batch, seq_block)
|
||||||
|
# blocks across all programs to increase utilization when n_rows is small.
|
||||||
|
SEQ_BLOCK = 256
|
||||||
|
n_seq_blocks = triton.cdiv(n_cols, SEQ_BLOCK)
|
||||||
|
total_blocks = n_rows * n_seq_blocks
|
||||||
|
progs = min(core_num, total_blocks)
|
||||||
|
grid = (progs, triton.cdiv(total_blocks, progs))
|
||||||
|
|
||||||
|
token_bin_counts_and_mask_kernel[grid](
|
||||||
|
tokens,
|
||||||
|
tokens.stride(0),
|
||||||
|
tokens.stride(1),
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
vocab_size,
|
||||||
|
bin_counts,
|
||||||
|
bin_counts.stride(0),
|
||||||
|
bin_counts.stride(1),
|
||||||
|
SEQ_BLOCK=SEQ_BLOCK,
|
||||||
|
)
|
||||||
|
return bin_counts, bin_counts > 0
|
||||||
158
vllm_ascend/ops/triton/penalty.py
Normal file
158
vllm_ascend/ops/triton/penalty.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Triton-Ascend implementation of apply_penalties.
|
||||||
|
# Migrated from model_executor/layers/utils.apply_penalties.
|
||||||
|
# Reference: https://github.com/vllm-project/vllm-ascend/pull/6979
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
from vllm_ascend.ops.triton.bincount import get_token_bin_counts_and_mask_triton
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def apply_all_penalties_kernel(
|
||||||
|
logits_ptr,
|
||||||
|
prompt_mask_ptr,
|
||||||
|
output_mask_ptr,
|
||||||
|
output_bin_counts_ptr,
|
||||||
|
repetition_penalties_ptr,
|
||||||
|
frequency_penalties_ptr,
|
||||||
|
presence_penalties_ptr,
|
||||||
|
num_seqs,
|
||||||
|
vocab_size,
|
||||||
|
stride_logits_seq,
|
||||||
|
stride_logits_vocab,
|
||||||
|
stride_prompt_mask_seq,
|
||||||
|
stride_prompt_mask_vocab,
|
||||||
|
stride_output_mask_seq,
|
||||||
|
stride_output_mask_vocab,
|
||||||
|
stride_bin_counts_seq,
|
||||||
|
stride_bin_counts_vocab,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Apply repetition, frequency, and presence penalties to logits in place."""
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
num_programs = tl.num_programs(axis=0)
|
||||||
|
seqs_per_program = (num_seqs + num_programs - 1) // num_programs
|
||||||
|
|
||||||
|
start_seq = pid * seqs_per_program
|
||||||
|
end_seq = tl.minimum(start_seq + seqs_per_program, num_seqs)
|
||||||
|
|
||||||
|
for seq_idx in range(start_seq, end_seq):
|
||||||
|
repetition_penalty = tl.load(repetition_penalties_ptr + seq_idx)
|
||||||
|
frequency_penalty = tl.load(frequency_penalties_ptr + seq_idx)
|
||||||
|
presence_penalty = tl.load(presence_penalties_ptr + seq_idx)
|
||||||
|
|
||||||
|
for vocab_start in range(0, vocab_size, BLOCK_SIZE):
|
||||||
|
vocab_offsets = vocab_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = vocab_offsets < vocab_size
|
||||||
|
|
||||||
|
logits_offset = seq_idx * stride_logits_seq + vocab_offsets * stride_logits_vocab
|
||||||
|
prompt_mask_offset = seq_idx * stride_prompt_mask_seq + vocab_offsets * stride_prompt_mask_vocab
|
||||||
|
output_mask_offset = seq_idx * stride_output_mask_seq + vocab_offsets * stride_output_mask_vocab
|
||||||
|
counts_offset = seq_idx * stride_bin_counts_seq + vocab_offsets * stride_bin_counts_vocab
|
||||||
|
|
||||||
|
logits = tl.load(logits_ptr + logits_offset, mask=mask, other=0.0)
|
||||||
|
prompt_mask_val = tl.load(prompt_mask_ptr + prompt_mask_offset, mask=mask, other=False)
|
||||||
|
output_mask_val = tl.load(output_mask_ptr + output_mask_offset, mask=mask, other=False)
|
||||||
|
output_bin_counts = tl.load(
|
||||||
|
output_bin_counts_ptr + counts_offset,
|
||||||
|
mask=mask,
|
||||||
|
other=0,
|
||||||
|
).to(tl.float32)
|
||||||
|
|
||||||
|
need_repetition_penalty = (prompt_mask_val | output_mask_val).to(tl.int1)
|
||||||
|
penalty_factor = tl.where(need_repetition_penalty, repetition_penalty, 1.0)
|
||||||
|
scaling = tl.where(
|
||||||
|
(logits > 0.0).to(tl.int1),
|
||||||
|
1.0 / penalty_factor,
|
||||||
|
penalty_factor,
|
||||||
|
)
|
||||||
|
updated = logits * scaling
|
||||||
|
|
||||||
|
updated -= frequency_penalty * output_bin_counts
|
||||||
|
updated -= presence_penalty * output_mask_val.to(tl.float32)
|
||||||
|
tl.store(logits_ptr + logits_offset, updated, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_penalties_triton(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
prompt_tokens_tensor: torch.Tensor,
|
||||||
|
output_tokens_tensor: torch.Tensor,
|
||||||
|
presence_penalties: torch.Tensor,
|
||||||
|
frequency_penalties: torch.Tensor,
|
||||||
|
repetition_penalties: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Apply penalties to logits in place. Same interface as
|
||||||
|
model_executor.layers.utils.apply_penalties.
|
||||||
|
"""
|
||||||
|
num_seqs, vocab_size = logits.shape
|
||||||
|
_, prompt_mask = get_token_bin_counts_and_mask_triton(prompt_tokens_tensor, vocab_size, num_seqs)
|
||||||
|
output_bin_counts, output_mask = get_token_bin_counts_and_mask_triton(
|
||||||
|
output_tokens_tensor,
|
||||||
|
vocab_size,
|
||||||
|
num_seqs,
|
||||||
|
)
|
||||||
|
_apply_all_penalties_triton(
|
||||||
|
logits,
|
||||||
|
prompt_mask,
|
||||||
|
output_mask,
|
||||||
|
output_bin_counts,
|
||||||
|
repetition_penalties,
|
||||||
|
frequency_penalties,
|
||||||
|
presence_penalties,
|
||||||
|
)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_all_penalties_triton(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
prompt_mask: torch.Tensor,
|
||||||
|
output_mask: torch.Tensor,
|
||||||
|
output_bin_counts: torch.Tensor,
|
||||||
|
repetition_penalties: torch.Tensor,
|
||||||
|
frequency_penalties: torch.Tensor,
|
||||||
|
presence_penalties: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""Apply all penalties given precomputed bin counts and masks."""
|
||||||
|
num_seqs, vocab_size = logits.shape
|
||||||
|
grid = (min(num_seqs, get_vectorcore_num()), 1, 1)
|
||||||
|
|
||||||
|
apply_all_penalties_kernel[grid](
|
||||||
|
logits,
|
||||||
|
prompt_mask,
|
||||||
|
output_mask,
|
||||||
|
output_bin_counts,
|
||||||
|
repetition_penalties,
|
||||||
|
frequency_penalties,
|
||||||
|
presence_penalties,
|
||||||
|
num_seqs=num_seqs,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
stride_logits_seq=logits.stride(0),
|
||||||
|
stride_logits_vocab=logits.stride(1),
|
||||||
|
stride_prompt_mask_seq=prompt_mask.stride(0),
|
||||||
|
stride_prompt_mask_vocab=prompt_mask.stride(1),
|
||||||
|
stride_output_mask_seq=output_mask.stride(0),
|
||||||
|
stride_output_mask_vocab=output_mask.stride(1),
|
||||||
|
stride_bin_counts_seq=output_bin_counts.stride(0),
|
||||||
|
stride_bin_counts_vocab=output_bin_counts.stride(1),
|
||||||
|
BLOCK_SIZE=2048,
|
||||||
|
)
|
||||||
45
vllm_ascend/sample/penalties.py
Normal file
45
vllm_ascend/sample/penalties.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# apply_all_penalties for AscendSampler - uses Triton-Ascend kernels.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
|
from vllm.utils.torch_utils import make_tensor_with_pad
|
||||||
|
|
||||||
|
from vllm_ascend.ops.triton.penalty import apply_penalties_triton
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, device: torch.device) -> torch.Tensor:
|
||||||
|
"""Convert output_token_ids (list of lists) to padded tensor."""
|
||||||
|
output_tokens_tensor = make_tensor_with_pad(
|
||||||
|
output_token_ids,
|
||||||
|
pad=vocab_size,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int64,
|
||||||
|
pin_memory=is_pin_memory_available(),
|
||||||
|
)
|
||||||
|
return output_tokens_tensor.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_all_penalties(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
prompt_token_ids: torch.Tensor,
|
||||||
|
presence_penalties: torch.Tensor,
|
||||||
|
frequency_penalties: torch.Tensor,
|
||||||
|
repetition_penalties: torch.Tensor,
|
||||||
|
output_token_ids: list[list[int]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Apply penalties to logits via Triton-Ascend."""
|
||||||
|
_, vocab_size = logits.shape
|
||||||
|
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)
|
||||||
|
output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)
|
||||||
|
|
||||||
|
return apply_penalties_triton(
|
||||||
|
logits,
|
||||||
|
prompt_token_ids,
|
||||||
|
output_tokens_t,
|
||||||
|
presence_penalties,
|
||||||
|
frequency_penalties,
|
||||||
|
repetition_penalties,
|
||||||
|
)
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||||
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||||
from vllm.v1.sample.sampler import Sampler
|
from vllm.v1.sample.sampler import Sampler
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.sample.penalties import apply_all_penalties
|
||||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, global_stream, npu_stream_switch
|
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, global_stream, npu_stream_switch
|
||||||
|
|
||||||
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
|
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
|
||||||
@@ -36,6 +39,28 @@ def random_sample(
|
|||||||
|
|
||||||
|
|
||||||
class AscendSampler(Sampler):
|
class AscendSampler(Sampler):
|
||||||
|
@staticmethod
|
||||||
|
def apply_penalties(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
output_token_ids: list[list[int]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Use Triton-Ascend penalties on NPU when Triton is available; else vLLM default."""
|
||||||
|
if not HAS_TRITON:
|
||||||
|
return Sampler.apply_penalties(logits, sampling_metadata, output_token_ids)
|
||||||
|
|
||||||
|
if sampling_metadata.no_penalties:
|
||||||
|
return logits
|
||||||
|
assert sampling_metadata.prompt_token_ids is not None
|
||||||
|
return apply_all_penalties(
|
||||||
|
logits,
|
||||||
|
sampling_metadata.prompt_token_ids,
|
||||||
|
sampling_metadata.presence_penalties,
|
||||||
|
sampling_metadata.frequency_penalties,
|
||||||
|
sampling_metadata.repetition_penalties,
|
||||||
|
output_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
|
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
|
||||||
# TODO: support logprobs_mode in vllm-ascend
|
# TODO: support logprobs_mode in vllm-ascend
|
||||||
super().__init__(logprobs_mode=logprobs_mode)
|
super().__init__(logprobs_mode=logprobs_mode)
|
||||||
|
|||||||
Reference in New Issue
Block a user