[MODELRUNNERV2]fix penality ops (#7013)

### What this PR does / why we need it?
fix penality ops for new version, and achieved a 10% performance
improvement

### How was this patch tested?
pytest
‎tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

Signed-off-by: shiyuan680 <917935075@qq.com>
This commit is contained in:
shiyuan680
2026-03-11 17:13:34 +08:00
committed by GitHub
parent 830f39dd70
commit 3b6b3c4214
2 changed files with 222 additions and 165 deletions

View File

@@ -1,63 +1,47 @@
import pytest import pytest
import torch import torch
from vllm_ascend.worker.v2.sample.penalties import apply_penalties_and_temperature from vllm_ascend.worker.v2.sample.penalties import apply_penalties
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
NUM_REQS = [2, 4, 8] NUM_TOKENS = [2, 4, 8]
VOCAB_SIZE = [151936] VOCAB_SIZE = [151936]
NUM_STATUS = [1, 4, 8, 16] NUM_STATUS = [1, 4, 8, 16]
SEEDS = [0] SEEDS = [0]
DEVICES = [f"npu:{0}"] DEVICES = [f"npu:{0}"]
NUM_SPECULATIVE_TOKENS = [0, 1, 3]
DEFAULT_ATOL = 1e-3 DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3 DEFAULT_RTOL = 1e-3
class SamplingMetadata: def pytorch_apply_penalties(
def __init__(self, logits: torch.Tensor,
idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor, repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor, frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor, presence_penalty: torch.Tensor,
temperature: torch.Tensor,
idx_mapping: torch.Tensor,
prompt_bin_mask: torch.Tensor, prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor): output_bin_counts: torch.Tensor,
self.repetition_penalty = repetition_penalty num_speculative_tokens: int,
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.temperature = temperature
self.idx_mapping = idx_mapping
self.prompt_bin_mask = prompt_bin_mask
self.output_bin_counts = output_bin_counts
def pytorch_apply_penalties_and_temperature(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Pytorch equivalent implementation Pytorch equivalent implementation
""" """
num_reqs, vocab_size = logits.shape num_tokens, vocab_size = logits.shape
device = logits.device device = logits.device
dtype = logits.dtype dtype = logits.dtype
logits_float = logits.float() logits_float = logits.float()
repetition_penalty = sampling_metadata.repetition_penalty
frequency_penalty = sampling_metadata.frequency_penalty
presence_penalty = sampling_metadata.presence_penalty
temperature = sampling_metadata.temperature
idx_mapping = sampling_metadata.idx_mapping
prompt_bin_mask = sampling_metadata.prompt_bin_mask
output_bin_counts = sampling_metadata.output_bin_counts
temperature = torch.where(temperature == 0.0, torch.ones_like(temperature), temperature)
num_status = prompt_bin_mask.shape[0] num_status = prompt_bin_mask.shape[0]
num_packed = prompt_bin_mask.shape[1] num_packed = prompt_bin_mask.shape[1]
prompt_masks_unpacked = torch.zeros(num_status, vocab_size, dtype=torch.bool, device=device) prompt_masks_unpacked = torch.zeros(
num_status, vocab_size, dtype=torch.bool,
device=device
)
for state_idx in range(num_status): for state_idx in range(num_status):
for packed_idx in range(num_packed): for packed_idx in range(num_packed):
@@ -69,82 +53,99 @@ def pytorch_apply_penalties_and_temperature(
if (packed_val >> bit_pos) & 1: if (packed_val >> bit_pos) & 1:
prompt_masks_unpacked[state_idx, start_idx + bit_pos] = True prompt_masks_unpacked[state_idx, start_idx + bit_pos] = True
for batch_idx in range(num_reqs): for token_idx in range(num_tokens):
req_state_idx = idx_mapping[batch_idx].item() req_state_idx = idx_mapping[token_idx].item()
rep_penalty = repetition_penalty[batch_idx].item() rep_penalty = repetition_penalty[req_state_idx].item()
freq_penalty = frequency_penalty[batch_idx].item() freq_penalty = frequency_penalty[req_state_idx].item()
pres_penalty = presence_penalty[batch_idx].item() pres_penalty = presence_penalty[req_state_idx].item()
temp = temperature[batch_idx].item()
use_rep_penalty = rep_penalty != 1.0 use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0 use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0 use_pres_penalty = pres_penalty != 0.0
use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
use_temperature = temp != 1.0
if not (use_penalty or use_temperature): if not use_penalty:
continue continue
current_prompt_mask = prompt_masks_unpacked[req_state_idx] current_prompt_mask = prompt_masks_unpacked[req_state_idx]
current_output_counts = output_bin_counts[req_state_idx] base_output_counts = output_bin_counts[req_state_idx]
output_bin_mask = current_output_counts > 0
# Compute cumulative draft counts
pos = expanded_local_pos[token_idx].item()
start_idx_in_batch = token_idx - pos
draft_counts = torch.zeros(vocab_size, device=device, dtype=torch.int32)
for prev_pos in range(num_speculative_tokens):
if prev_pos < pos:
prev_token = token_ids[start_idx_in_batch + prev_pos + 1].item()
draft_counts[prev_token] += 1
# Total counts = base output counts + cumulative draft counts
total_output_counts = base_output_counts + draft_counts
output_bin_mask = total_output_counts > 0
if use_rep_penalty: if use_rep_penalty:
scale = torch.ones(vocab_size, device=device) scale = torch.ones(vocab_size, device=device)
mask = current_prompt_mask | output_bin_mask mask = current_prompt_mask | output_bin_mask
scale[mask] = rep_penalty scale[mask] = rep_penalty
pos_mask = logits_float[batch_idx] > 0 pos_mask = logits_float[token_idx] > 0
scale_factor = torch.where(pos_mask, 1.0 / scale, scale) scale_factor = torch.where(pos_mask, 1.0 / scale, scale)
logits_float[batch_idx] *= scale_factor logits_float[token_idx] *= scale_factor
if use_freq_penalty: if use_freq_penalty:
logits_float[batch_idx] -= freq_penalty * current_output_counts.float() logits_float[token_idx] -= freq_penalty * total_output_counts.float()
if use_pres_penalty: if use_pres_penalty:
logits_float[batch_idx] -= pres_penalty * output_bin_mask.float() logits_float[token_idx] -= pres_penalty * output_bin_mask.float()
if use_temperature:
logits_float[batch_idx] /= temp
return logits_float.to(dtype) return logits_float.to(dtype)
def create_test_data( def create_test_data(
num_reqs: int = 8, num_tokens: int = 8,
vocab_size: int = 51200, vocab_size: int = 51200,
num_status: int = 16, num_status: int = 16,
num_speculative_tokens: int = 3,
device: str = "npu", device: str = "npu",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
seed: int = 42, seed: int = 42,
): ):
"""Create test data for penalties and temperature""" """Create test data for penalties"""
torch.manual_seed(seed) torch.manual_seed(seed)
logits = torch.randn(num_reqs, vocab_size, device=device, dtype=dtype) logits = torch.randn(num_tokens, vocab_size, device=device, dtype=dtype)
repetition_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32) repetition_penalty = torch.ones(num_status, device=device, dtype=torch.float32)
for i in range(num_reqs): for i in range(num_status):
if torch.rand(1) > 0.3: if torch.rand(1) > 0.3:
repetition_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6 repetition_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6
frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32) frequency_penalty = torch.zeros(num_status, device=device, dtype=torch.float32)
for i in range(num_reqs): for i in range(num_status):
if torch.rand(1) > 0.5: if torch.rand(1) > 0.5:
frequency_penalty[i] = torch.rand(1, device=device).item() * 0.2 frequency_penalty[i] = torch.rand(1, device=device).item() * 0.2
presence_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32) presence_penalty = torch.zeros(num_status, device=device, dtype=torch.float32)
for i in range(num_reqs): for i in range(num_status):
if torch.rand(1) > 0.5: if torch.rand(1) > 0.5:
presence_penalty[i] = torch.rand(1, device=device).item() * 0.2 presence_penalty[i] = torch.rand(1, device=device).item() * 0.2
temperature = torch.ones(num_reqs, device=device, dtype=torch.float32) idx_mapping = torch.randint(
for i in range(num_reqs): 0, num_status, (num_tokens,), device=device,
if torch.rand(1) > 0.2: dtype=torch.int32
presence_penalty[i] = torch.rand(1, device=device).item() * 1.8 + 0.2 )
idx_mapping = torch.randint(0, num_status, (num_reqs,), device=device, dtype=torch.int32) # Create token_ids for speculative decoding
token_ids = torch.randint(0, vocab_size, (num_tokens,), device=device, dtype=torch.int32)
# Create expanded_local_pos (position within speculative decoding window)
expanded_local_pos = torch.zeros(num_tokens, device=device, dtype=torch.int32)
for i in range(num_tokens):
expanded_local_pos[i] = torch.randint(
0, num_speculative_tokens + 1, (1,)
).item()
num_packed = (vocab_size + 31) // 32 num_packed = (vocab_size + 31) // 32
prompt_bin_mask = torch.zeros(num_status, num_packed, device=device, dtype=torch.int32) prompt_bin_mask = torch.zeros(num_status, num_packed, device=device, dtype=torch.int32)
@@ -161,44 +162,60 @@ def create_test_data(
output_bin_counts = torch.zeros(num_status, vocab_size, device=device, dtype=torch.int32) output_bin_counts = torch.zeros(num_status, vocab_size, device=device, dtype=torch.int32)
for state_idx in range(num_status): for state_idx in range(num_status):
num_output_tokens = max(1, vocab_size // 20) num_output_tokens = max(1, vocab_size // 20)
output_tokens = torch.randint(0, vocab_size, (num_output_tokens, )) output_tokens = torch.randint(0, vocab_size,
(num_output_tokens, ))
counts = torch.randint(1, 10, (num_output_tokens,)) counts = torch.randint(1, 10, (num_output_tokens,))
for token, count in zip(output_tokens, counts): for token, count in zip(output_tokens, counts):
output_bin_counts[state_idx, token] = count output_bin_counts[state_idx, token] = count
sampling_metadata = SamplingMetadata( return (
repetition_penalty=repetition_penalty, logits,
frequency_penalty=frequency_penalty, idx_mapping,
presence_penalty=presence_penalty, token_ids,
temperature=temperature, expanded_local_pos,
idx_mapping=idx_mapping, repetition_penalty,
prompt_bin_mask=prompt_bin_mask, frequency_penalty,
output_bin_counts=output_bin_counts presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_speculative_tokens,
) )
return logits, sampling_metadata
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_reqs", NUM_REQS)
@pytest.mark.parametrize("vocab_size", VOCAB_SIZE) @pytest.mark.parametrize("vocab_size", VOCAB_SIZE)
@pytest.mark.parametrize("num_status", NUM_STATUS) @pytest.mark.parametrize("num_status", NUM_STATUS)
@pytest.mark.parametrize("num_speculative_tokens", NUM_SPECULATIVE_TOKENS)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_apply_penalties_and_temperature( def test_apply_penalties(
num_reqs, num_tokens,
vocab_size, vocab_size,
num_status, num_status,
num_speculative_tokens,
dtype, dtype,
seed, seed,
device device
): ):
logits_triton, sampling_metadata = create_test_data( (
num_reqs=num_reqs, logits_triton,
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_spec_tokens,
) = create_test_data(
num_tokens=num_tokens,
vocab_size=vocab_size, vocab_size=vocab_size,
num_status=num_status, num_status=num_status,
num_speculative_tokens=num_speculative_tokens,
device=device, device=device,
dtype=dtype, dtype=dtype,
seed=seed seed=seed
@@ -206,10 +223,31 @@ def test_apply_penalties_and_temperature(
logits_pytorch = logits_triton.clone() logits_pytorch = logits_triton.clone()
apply_penalties_and_temperature(logits_triton, sampling_metadata) apply_penalties(
logits_triton,
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_spec_tokens,
)
logits_pytorch_result = pytorch_apply_penalties_and_temperature(logits_pytorch, logits_pytorch_result = pytorch_apply_penalties(
sampling_metadata) logits_pytorch,
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_spec_tokens,
)
atol = DEFAULT_ATOL atol = DEFAULT_ATOL
rtol = DEFAULT_RTOL rtol = DEFAULT_RTOL

View File

@@ -20,122 +20,141 @@
import torch import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
@triton.jit @triton.jit
def _penalties_and_temperature_kernel( def _penalties_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
temperature_ptr,
idx_mapping_ptr, idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
penalties_ptr,
penalties_stride,
prompt_bin_mask_ptr, prompt_bin_mask_ptr,
prompt_bin_mask_stride, prompt_bin_mask_stride,
output_bin_counts_ptr, output_bin_counts_ptr,
output_bin_counts_stride, output_bin_counts_stride,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
INNER_BLOCK_SIZE: tl.constexpr,
MAX_SPEC_LEN: tl.constexpr,
): ):
batch_idx = tl.program_id(0) token_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx) req_state_idx = tl.load(idx_mapping_ptr + token_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx) # first load penalties once
temperature = tl.load(temperature_ptr + batch_idx) rep_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 0)
temperature = tl.where(temperature == 0.0, 1.0, temperature) freq_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 1)
pres_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 2)
use_rep_penalty = rep_penalty != 1.0 use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0 use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0 use_pres_penalty = pres_penalty != 0.0
# NOTE(Ronald1995): vllm original grammar `use_rep_penalty or
# use_freq_penalty or use_pres_penalty`, # NPU doesn't support chained 'or' operations like 'A or B or C'
# change it to `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`, use_penalty = use_rep_penalty or use_freq_penalty
# because triton-ascend's compiler doesn't support chained boolean operator. use_penalty = use_penalty or use_pres_penalty
use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
use_temperature = temperature != 1.0 if not use_penalty:
if not (use_penalty or use_temperature):
# Early return to avoid loading logits. # Early return to avoid loading logits.
return return
bit_masks = tl.full((INNER_BLOCK_SIZE // 32, 32), 1, dtype=tl.int32) << tl.arange(0, 32)
block_idx = tl.program_id(1) block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) block_start = block_idx * BLOCK_SIZE
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) pos = tl.load(expanded_local_pos_ptr + token_idx)
start_idx = token_idx - pos
inv_rep = 1.0 / rep_penalty
for inner_offset in tl.static_range(0, BLOCK_SIZE, INNER_BLOCK_SIZE):
inner_block_start = block_start + inner_offset
inner_block = inner_block_start + tl.arange(0, INNER_BLOCK_SIZE)
inner_mask = inner_block < vocab_size
logits = tl.load(logits_ptr + token_idx * logits_stride + inner_block, mask=inner_mask, other=0.0)
logits = logits.to(tl.float32) logits = logits.to(tl.float32)
if use_penalty: base_output_counts = tl.load(
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + inner_block,
output_bin_counts = tl.load( mask=inner_mask,
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, other=0,
mask=mask,
) )
# to use vector core, if use > 0 will use scalar to slow down performance
output_bin_mask = output_bin_counts != 0 # Compute cumulative draft_counts from previous positions in this request
total_counts = base_output_counts.to(tl.int32)
for prev_pos in tl.static_range(MAX_SPEC_LEN):
if prev_pos < pos:
load_idx = start_idx + prev_pos + 1
prev_token = tl.load(token_ids_ptr + load_idx)
total_counts += inner_block == prev_token
is_present = total_counts != 0
# Apply repetition penalties. # Apply repetition penalties.
if use_rep_penalty: if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32) packed_inner_block_start = inner_block_start // 32
packed_mask = tl.load( packed_block = packed_inner_block_start + tl.arange(0, INNER_BLOCK_SIZE // 32)
valid_packed_mask = packed_block < tl.cdiv(vocab_size, 32)
packed_mask_val = tl.load(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block, prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32), mask=valid_packed_mask,
other=0,
) )
# the compiler itself does not optimize right-shift operations, so we change the same func prompt_mask = ((packed_mask_val[:, None] & bit_masks) != 0).reshape(INNER_BLOCK_SIZE)
bit_masks = 1 << tl.arange(0, 32)
bit_masks_expanded = bit_masks[None, :]
packed_expanded = packed_mask[:, None] needs_scaling = prompt_mask | is_present
bits_matrix = (packed_expanded & bit_masks_expanded) != 0
prompt_bin_mask = bits_matrix.reshape(BLOCK_SIZE) base_factor = tl.where(logits > 0, inv_rep, rep_penalty)
logits = tl.where(needs_scaling, logits * base_factor, logits)
prompt_bin_mask = prompt_bin_mask.to(tl.int1) freq_term = freq_penalty * total_counts.to(tl.float32)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE) pres_term = pres_penalty * is_present.to(tl.float32)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale)
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * output_bin_mask
# Apply temperature.
logits = logits / temperature
logits = logits - freq_term - pres_term
# Store back to logits. # Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask) tl.store(logits_ptr + token_idx * logits_stride + inner_block, logits, mask=inner_mask)
def apply_penalties_and_temperature( def apply_penalties(
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
num_speculative_tokens: int,
) -> None: ) -> None:
"""Override the function because there are some bugs num_tokens, vocab_size = logits.shape
when _penalties_and_temperature_kernel runs on npu, we need to make some fixes. BLOCK_SIZE = 8192
you could read NOTE(Ronald1995) comments to understand. INNER_BLOCK_SIZE = 4096
"""
num_reqs, vocab_size = logits.shape
# NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 in case UB overflow
# in triton-ascend.
BLOCK_SIZE = 4096
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
# TODO(Ronald1995): Optimize the performance of the kernel in npu.
_penalties_and_temperature_kernel[(num_reqs, num_blocks)]( penalties = torch.stack(
[repetition_penalty[:num_tokens], frequency_penalty[:num_tokens], presence_penalty[:num_tokens]], dim=1
).contiguous()
penalties_stride = penalties.stride(0)
_penalties_kernel[(num_tokens, num_blocks)](
logits, logits,
logits.stride(0), logits.stride(0),
sampling_metadata.repetition_penalty, idx_mapping,
sampling_metadata.frequency_penalty, token_ids,
sampling_metadata.presence_penalty, expanded_local_pos,
sampling_metadata.temperature, penalties,
sampling_metadata.idx_mapping, penalties_stride,
sampling_metadata.prompt_bin_mask, prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0), prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts, output_bin_counts,
sampling_metadata.output_bin_counts.stride(0), output_bin_counts.stride(0),
vocab_size, vocab_size,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
INNER_BLOCK_SIZE=INNER_BLOCK_SIZE,
MAX_SPEC_LEN=num_speculative_tokens,
) )