[MODELRUNNERV2] fix penalty ops (#7013)
### What this PR does / why we need it?
Fix the penalty ops for the new vLLM version; the rewrite also achieves a 10% performance improvement.

### How was this patch tested?
pytest tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

Signed-off-by: shiyuan680 <917935075@qq.com>
```diff
@@ -20,122 +20,141 @@
 import torch

 from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata


 @triton.jit
-def _penalties_and_temperature_kernel(
+def _penalties_kernel(
     logits_ptr,
     logits_stride,
-    repetition_penalty_ptr,
-    frequency_penalty_ptr,
-    presence_penalty_ptr,
-    temperature_ptr,
     idx_mapping_ptr,
+    token_ids_ptr,
+    expanded_local_pos_ptr,
+    penalties_ptr,
+    penalties_stride,
     prompt_bin_mask_ptr,
     prompt_bin_mask_stride,
     output_bin_counts_ptr,
     output_bin_counts_stride,
     vocab_size,
     BLOCK_SIZE: tl.constexpr,
+    INNER_BLOCK_SIZE: tl.constexpr,
+    MAX_SPEC_LEN: tl.constexpr,
 ):
-    batch_idx = tl.program_id(0)
-    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
-    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
-    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
-    temperature = tl.load(temperature_ptr + batch_idx)
-    temperature = tl.where(temperature == 0.0, 1.0, temperature)
+    token_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + token_idx)
+
+    # Load all three penalties once, from a single packed row.
+    rep_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 0)
+    freq_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 1)
+    pres_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 2)
+
     use_rep_penalty = rep_penalty != 1.0
     use_freq_penalty = freq_penalty != 0.0
     use_pres_penalty = pres_penalty != 0.0
-    # NOTE(Ronald1995): vllm's original expression is `use_rep_penalty or
-    # use_freq_penalty or use_pres_penalty`; change it to
-    # `(use_rep_penalty or use_freq_penalty) or use_pres_penalty` because
-    # triton-ascend's compiler doesn't support chained boolean operators.
-    use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
-    use_temperature = temperature != 1.0
-    if not (use_penalty or use_temperature):
+
+    # The NPU doesn't support chained 'or' operations like 'A or B or C'.
+    use_penalty = use_rep_penalty or use_freq_penalty
+    use_penalty = use_penalty or use_pres_penalty
+
+    if not use_penalty:
+        # Early return to avoid loading logits.
         return
```
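Both versions implement vLLM's standard penalty semantics; the new kernel only drops the temperature scaling (removed further down), which is now applied outside this op. As a reference for what the kernel computes per token, here is a minimal dense PyTorch sketch; the function name and the dense-mask shapes are mine, since the kernel itself reads a bit-packed prompt mask:

```python
import torch

def penalties_reference(
    logits: torch.Tensor,         # (num_tokens, vocab_size), float
    rep: torch.Tensor,            # (num_tokens,) repetition penalties
    freq: torch.Tensor,           # (num_tokens,) frequency penalties
    pres: torch.Tensor,           # (num_tokens,) presence penalties
    prompt_mask: torch.Tensor,    # (num_tokens, vocab_size), bool
    output_counts: torch.Tensor,  # (num_tokens, vocab_size), int
) -> torch.Tensor:
    in_output = output_counts != 0
    # Repetition: scale logits of tokens seen in the prompt or the output.
    # Positive logits are divided by the penalty, negative ones multiplied.
    needs_scaling = prompt_mask | in_output
    scale = torch.where(needs_scaling, rep[:, None], torch.ones_like(rep)[:, None])
    logits = torch.where(logits > 0, logits / scale, logits * scale)
    # Frequency: subtract penalty * occurrence count in the output.
    logits = logits - freq[:, None] * output_counts
    # Presence: subtract the penalty once for any token already in the output.
    logits = logits - pres[:, None] * in_output.float()
    return logits
```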
```diff
+    bit_masks = tl.full((INNER_BLOCK_SIZE // 32, 32), 1, dtype=tl.int32) << tl.arange(0, 32)
     block_idx = tl.program_id(1)
-    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = block < vocab_size
-    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
-    logits = logits.to(tl.float32)
-    if use_penalty:
-        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
-        output_bin_counts = tl.load(
-            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
-            mask=mask,
-            other=0,
-        )
-        # Compare with `!= 0` to stay on the vector core; `> 0` would fall
-        # back to scalar code and slow down performance.
-        output_bin_mask = output_bin_counts != 0
+    block_start = block_idx * BLOCK_SIZE
+
+    pos = tl.load(expanded_local_pos_ptr + token_idx)
+    start_idx = token_idx - pos
+
+    inv_rep = 1.0 / rep_penalty
+
+    for inner_offset in tl.static_range(0, BLOCK_SIZE, INNER_BLOCK_SIZE):
+        inner_block_start = block_start + inner_offset
+        inner_block = inner_block_start + tl.arange(0, INNER_BLOCK_SIZE)
+        inner_mask = inner_block < vocab_size
```
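The old kernel halved BLOCK_SIZE to 4096 to avoid overflowing the NPU's unified buffer (UB); the new kernel keeps an 8192-wide outer block but walks it in 4096-wide inner chunks via `tl.static_range`, so each program covers more of the vocab while the live tiles still fit in the UB. A plain-Python sketch of the iteration space, with an illustrative vocab size:

```python
# Each (token, block) program covers BLOCK_SIZE vocab entries but only ever
# materializes INNER_BLOCK_SIZE of them at a time.
BLOCK_SIZE, INNER_BLOCK_SIZE, VOCAB_SIZE = 8192, 4096, 152064

def covered_ranges(block_idx: int):
    block_start = block_idx * BLOCK_SIZE
    for inner_offset in range(0, BLOCK_SIZE, INNER_BLOCK_SIZE):
        start = block_start + inner_offset
        yield start, min(start + INNER_BLOCK_SIZE, VOCAB_SIZE)

# list(covered_ranges(0)) -> [(0, 4096), (4096, 8192)]
```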
```diff
+        logits = tl.load(logits_ptr + token_idx * logits_stride + inner_block, mask=inner_mask, other=0.0)
+        logits = logits.to(tl.float32)
+
+        base_output_counts = tl.load(
+            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + inner_block,
+            mask=inner_mask,
+            other=0,
+        )
+
+        # Compute the cumulative draft counts from previous positions in this request.
+        total_counts = base_output_counts.to(tl.int32)
+        for prev_pos in tl.static_range(MAX_SPEC_LEN):
+            if prev_pos < pos:
+                load_idx = start_idx + prev_pos + 1
+                prev_token = tl.load(token_ids_ptr + load_idx)
+                total_counts += inner_block == prev_token
+
+        # `!= 0` keeps the comparison on the vector core.
+        is_present = total_counts != 0
```
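With speculative decoding, several positions of the same request are scored in one batch, while `output_bin_counts` only reflects tokens already committed to the request state; the loop above therefore also counts the draft tokens that precede the current position. A scalar restatement of that adjustment (hypothetical helper; the `start_idx + prev_pos + 1` indexing mirrors the kernel's layout assumption):

```python
import torch

def total_counts_at(
    base_output_counts: torch.Tensor,  # (vocab_size,) committed-output counts
    token_ids: torch.Tensor,           # flattened batch token ids
    token_idx: int,                    # global index of the position being scored
    pos: int,                          # local position inside this request's draft run
) -> torch.Tensor:
    total = base_output_counts.clone().to(torch.int32)
    start_idx = token_idx - pos
    # Count each draft token sampled at an earlier position of this request.
    for prev_pos in range(pos):
        prev_token = token_ids[start_idx + prev_pos + 1]
        total[prev_token] += 1
    return total
```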
```diff
         # Apply repetition penalties.
-        packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
-        packed_mask = tl.load(
-            prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
-            mask=packed_block < tl.cdiv(vocab_size, 32),
-            other=0,
-        )
-        # The compiler does not optimize right-shift operations, so build the
-        # bit masks with a left shift instead.
-        bit_masks = 1 << tl.arange(0, 32)
-        bit_masks_expanded = bit_masks[None, :]
-        packed_expanded = packed_mask[:, None]
-        bits_matrix = (packed_expanded & bit_masks_expanded) != 0
-        prompt_bin_mask = bits_matrix.reshape(BLOCK_SIZE)
-        prompt_bin_mask = prompt_bin_mask.to(tl.int1)
-        prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
-        # If the token appears in the prompt or output, apply the penalty;
-        # otherwise use 1.0 as a no-op.
-        scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
-        # If logits are positive, divide by the penalty; otherwise multiply by it.
-        logits *= tl.where(logits > 0, 1.0 / scale, scale)
-        # Apply frequency penalties.
-        logits -= freq_penalty * output_bin_counts
-        # Apply presence penalties.
-        logits -= pres_penalty * output_bin_mask
-    # Apply temperature.
-    logits = logits / temperature
-    # Store back to logits.
-    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
+        if use_rep_penalty:
+            packed_inner_block_start = inner_block_start // 32
+            packed_block = packed_inner_block_start + tl.arange(0, INNER_BLOCK_SIZE // 32)
+            valid_packed_mask = packed_block < tl.cdiv(vocab_size, 32)
+
+            packed_mask_val = tl.load(
+                prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
+                mask=valid_packed_mask,
+                other=0,
+            )
+            prompt_mask = ((packed_mask_val[:, None] & bit_masks) != 0).reshape(INNER_BLOCK_SIZE)
+
+            needs_scaling = prompt_mask | is_present
+
+            # If logits are positive, divide by the penalty; otherwise multiply by it.
+            base_factor = tl.where(logits > 0, inv_rep, rep_penalty)
+            logits = tl.where(needs_scaling, logits * base_factor, logits)
+
+        # Apply frequency and presence penalties.
+        freq_term = freq_penalty * total_counts.to(tl.float32)
+        pres_term = pres_penalty * is_present.to(tl.float32)
+        logits = logits - freq_term - pres_term
+
+        # Store back to logits.
+        tl.store(logits_ptr + token_idx * logits_stride + inner_block, logits, mask=inner_mask)
```
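The prompt mask arrives bit-packed, 32 vocab entries per word, and the kernel expands it with a broadcasted AND against precomputed powers of two (hoisted out of the loop as `bit_masks`). A round-trip sketch of that layout; the helper names are mine, and int64 is used to sidestep int32 sign trouble on bit 31:

```python
import torch

WEIGHTS = torch.ones(32, dtype=torch.int64) << torch.arange(32)  # 1, 2, ..., 2**31

def pack_bool_mask(mask: torch.Tensor) -> torch.Tensor:
    """Pack a (vocab_size,) bool mask into vocab_size // 32 integer words."""
    return (mask.view(-1, 32).to(torch.int64) * WEIGHTS).sum(dim=1)

def unpack_bool_mask(packed: torch.Tensor) -> torch.Tensor:
    """Invert pack_bool_mask; mirrors the kernel's `(packed & bit_masks) != 0`."""
    return ((packed[:, None] & WEIGHTS) != 0).reshape(-1)

mask = torch.rand(4096) < 0.1
assert torch.equal(unpack_bool_mask(pack_bool_mask(mask)), mask)
```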
```diff


-def apply_penalties_and_temperature(
+def apply_penalties(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
+    idx_mapping: torch.Tensor,
+    token_ids: torch.Tensor,
+    expanded_local_pos: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    frequency_penalty: torch.Tensor,
+    presence_penalty: torch.Tensor,
+    prompt_bin_mask: torch.Tensor,
+    output_bin_counts: torch.Tensor,
+    num_speculative_tokens: int,
 ) -> None:
     """Override the upstream function because there are some bugs when the
     penalties kernel runs on the NPU; some fixes are needed here. See the
     NOTE(Ronald1995) comments for details.
     """
-    num_reqs, vocab_size = logits.shape
-    # NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 in case of UB
-    # overflow in triton-ascend.
-    BLOCK_SIZE = 4096
+    num_tokens, vocab_size = logits.shape
+    BLOCK_SIZE = 8192
+    INNER_BLOCK_SIZE = 4096
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    # TODO(Ronald1995): Optimize the performance of the kernel in npu.
-    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
+
+    penalties = torch.stack(
+        [repetition_penalty[:num_tokens], frequency_penalty[:num_tokens], presence_penalty[:num_tokens]], dim=1
+    ).contiguous()
+    penalties_stride = penalties.stride(0)
+
+    _penalties_kernel[(num_tokens, num_blocks)](
         logits,
         logits.stride(0),
-        sampling_metadata.repetition_penalty,
-        sampling_metadata.frequency_penalty,
-        sampling_metadata.presence_penalty,
-        sampling_metadata.temperature,
-        sampling_metadata.idx_mapping,
-        sampling_metadata.prompt_bin_mask,
-        sampling_metadata.prompt_bin_mask.stride(0),
-        sampling_metadata.output_bin_counts,
-        sampling_metadata.output_bin_counts.stride(0),
+        idx_mapping,
+        token_ids,
+        expanded_local_pos,
+        penalties,
+        penalties_stride,
+        prompt_bin_mask,
+        prompt_bin_mask.stride(0),
+        output_bin_counts,
+        output_bin_counts.stride(0),
         vocab_size,
         BLOCK_SIZE=BLOCK_SIZE,
+        INNER_BLOCK_SIZE=INNER_BLOCK_SIZE,
+        MAX_SPEC_LEN=num_speculative_tokens,
     )
```
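One design point worth noting: rather than three separate gathers per token, the wrapper packs the penalties into a single contiguous `(num_tokens, 3)` tensor, so the kernel reads all three values from one row. A minimal illustration (values made up):

```python
import torch

rep = torch.tensor([1.2, 1.0])
freq = torch.tensor([0.0, 0.5])
pres = torch.tensor([0.0, 1.0])

penalties = torch.stack([rep, freq, pres], dim=1).contiguous()
# penalties[t] holds (rep, freq, pres) for row t:
# tensor([[1.2000, 0.0000, 0.0000],
#         [1.0000, 0.5000, 1.0000]])
assert penalties.stride(0) == 3  # one row load fetches all three penalties
```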