[MODELRUNNERV2]fix penality ops (#7013)

### What this PR does / why we need it?
fix penality ops for new version, and achieved a 10% performance
improvement

### How was this patch tested?
pytest
‎tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

Signed-off-by: shiyuan680 <917935075@qq.com>
This commit is contained in:
shiyuan680
2026-03-11 17:13:34 +08:00
committed by GitHub
parent 830f39dd70
commit 3b6b3c4214
2 changed files with 222 additions and 165 deletions

View File

@@ -1,63 +1,47 @@
import pytest
import torch
from vllm_ascend.worker.v2.sample.penalties import apply_penalties_and_temperature
from vllm_ascend.worker.v2.sample.penalties import apply_penalties
DTYPES = [torch.bfloat16, torch.float16]
NUM_REQS = [2, 4, 8]
NUM_TOKENS = [2, 4, 8]
VOCAB_SIZE = [151936]
NUM_STATUS = [1, 4, 8, 16]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
NUM_SPECULATIVE_TOKENS = [0, 1, 3]
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
class SamplingMetadata:
def __init__(self,
repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor,
temperature: torch.Tensor,
idx_mapping: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor):
self.repetition_penalty = repetition_penalty
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.temperature = temperature
self.idx_mapping = idx_mapping
self.prompt_bin_mask = prompt_bin_mask
self.output_bin_counts = output_bin_counts
def pytorch_apply_penalties_and_temperature(
def pytorch_apply_penalties(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
num_speculative_tokens: int,
) -> torch.Tensor:
"""
Pytorch equivalent implementation
"""
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
device = logits.device
dtype = logits.dtype
logits_float = logits.float()
repetition_penalty = sampling_metadata.repetition_penalty
frequency_penalty = sampling_metadata.frequency_penalty
presence_penalty = sampling_metadata.presence_penalty
temperature = sampling_metadata.temperature
idx_mapping = sampling_metadata.idx_mapping
prompt_bin_mask = sampling_metadata.prompt_bin_mask
output_bin_counts = sampling_metadata.output_bin_counts
temperature = torch.where(temperature == 0.0, torch.ones_like(temperature), temperature)
num_status = prompt_bin_mask.shape[0]
num_packed = prompt_bin_mask.shape[1]
prompt_masks_unpacked = torch.zeros(num_status, vocab_size, dtype=torch.bool, device=device)
prompt_masks_unpacked = torch.zeros(
num_status, vocab_size, dtype=torch.bool,
device=device
)
for state_idx in range(num_status):
for packed_idx in range(num_packed):
@@ -69,82 +53,99 @@ def pytorch_apply_penalties_and_temperature(
if (packed_val >> bit_pos) & 1:
prompt_masks_unpacked[state_idx, start_idx + bit_pos] = True
for batch_idx in range(num_reqs):
req_state_idx = idx_mapping[batch_idx].item()
rep_penalty = repetition_penalty[batch_idx].item()
freq_penalty = frequency_penalty[batch_idx].item()
pres_penalty = presence_penalty[batch_idx].item()
temp = temperature[batch_idx].item()
for token_idx in range(num_tokens):
req_state_idx = idx_mapping[token_idx].item()
rep_penalty = repetition_penalty[req_state_idx].item()
freq_penalty = frequency_penalty[req_state_idx].item()
pres_penalty = presence_penalty[req_state_idx].item()
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
use_temperature = temp != 1.0
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
if not (use_penalty or use_temperature):
if not use_penalty:
continue
current_prompt_mask = prompt_masks_unpacked[req_state_idx]
current_output_counts = output_bin_counts[req_state_idx]
output_bin_mask = current_output_counts > 0
base_output_counts = output_bin_counts[req_state_idx]
# Compute cumulative draft counts
pos = expanded_local_pos[token_idx].item()
start_idx_in_batch = token_idx - pos
draft_counts = torch.zeros(vocab_size, device=device, dtype=torch.int32)
for prev_pos in range(num_speculative_tokens):
if prev_pos < pos:
prev_token = token_ids[start_idx_in_batch + prev_pos + 1].item()
draft_counts[prev_token] += 1
# Total counts = base output counts + cumulative draft counts
total_output_counts = base_output_counts + draft_counts
output_bin_mask = total_output_counts > 0
if use_rep_penalty:
scale = torch.ones(vocab_size, device=device)
mask = current_prompt_mask | output_bin_mask
scale[mask] = rep_penalty
pos_mask = logits_float[batch_idx] > 0
pos_mask = logits_float[token_idx] > 0
scale_factor = torch.where(pos_mask, 1.0 / scale, scale)
logits_float[batch_idx] *= scale_factor
logits_float[token_idx] *= scale_factor
if use_freq_penalty:
logits_float[batch_idx] -= freq_penalty * current_output_counts.float()
logits_float[token_idx] -= freq_penalty * total_output_counts.float()
if use_pres_penalty:
logits_float[batch_idx] -= pres_penalty * output_bin_mask.float()
if use_temperature:
logits_float[batch_idx] /= temp
logits_float[token_idx] -= pres_penalty * output_bin_mask.float()
return logits_float.to(dtype)
def create_test_data(
num_reqs: int = 8,
num_tokens: int = 8,
vocab_size: int = 51200,
num_status: int = 16,
num_speculative_tokens: int = 3,
device: str = "npu",
dtype: torch.dtype = torch.bfloat16,
seed: int = 42,
):
"""Create test data for penalties and temperature"""
"""Create test data for penalties"""
torch.manual_seed(seed)
logits = torch.randn(num_reqs, vocab_size, device=device, dtype=dtype)
logits = torch.randn(num_tokens, vocab_size, device=device, dtype=dtype)
repetition_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32)
for i in range(num_reqs):
repetition_penalty = torch.ones(num_status, device=device, dtype=torch.float32)
for i in range(num_status):
if torch.rand(1) > 0.3:
repetition_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6
frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
for i in range(num_reqs):
frequency_penalty = torch.zeros(num_status, device=device, dtype=torch.float32)
for i in range(num_status):
if torch.rand(1) > 0.5:
frequency_penalty[i] = torch.rand(1, device=device).item() * 0.2
presence_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
for i in range(num_reqs):
presence_penalty = torch.zeros(num_status, device=device, dtype=torch.float32)
for i in range(num_status):
if torch.rand(1) > 0.5:
presence_penalty[i] = torch.rand(1, device=device).item() * 0.2
temperature = torch.ones(num_reqs, device=device, dtype=torch.float32)
for i in range(num_reqs):
if torch.rand(1) > 0.2:
presence_penalty[i] = torch.rand(1, device=device).item() * 1.8 + 0.2
idx_mapping = torch.randint(0, num_status, (num_reqs,), device=device, dtype=torch.int32)
idx_mapping = torch.randint(
0, num_status, (num_tokens,), device=device,
dtype=torch.int32
)
# Create token_ids for speculative decoding
token_ids = torch.randint(0, vocab_size, (num_tokens,), device=device, dtype=torch.int32)
# Create expanded_local_pos (position within speculative decoding window)
expanded_local_pos = torch.zeros(num_tokens, device=device, dtype=torch.int32)
for i in range(num_tokens):
expanded_local_pos[i] = torch.randint(
0, num_speculative_tokens + 1, (1,)
).item()
num_packed = (vocab_size + 31) // 32
prompt_bin_mask = torch.zeros(num_status, num_packed, device=device, dtype=torch.int32)
@@ -161,44 +162,60 @@ def create_test_data(
output_bin_counts = torch.zeros(num_status, vocab_size, device=device, dtype=torch.int32)
for state_idx in range(num_status):
num_output_tokens = max(1, vocab_size // 20)
output_tokens = torch.randint(0, vocab_size, (num_output_tokens, ))
output_tokens = torch.randint(0, vocab_size,
(num_output_tokens, ))
counts = torch.randint(1, 10, (num_output_tokens,))
for token, count in zip(output_tokens, counts):
output_bin_counts[state_idx, token] = count
sampling_metadata = SamplingMetadata(
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
idx_mapping=idx_mapping,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts
return (
logits,
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_speculative_tokens,
)
return logits, sampling_metadata
@pytest.mark.parametrize("num_reqs", NUM_REQS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("vocab_size", VOCAB_SIZE)
@pytest.mark.parametrize("num_status", NUM_STATUS)
@pytest.mark.parametrize("num_speculative_tokens", NUM_SPECULATIVE_TOKENS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_apply_penalties_and_temperature(
num_reqs,
def test_apply_penalties(
num_tokens,
vocab_size,
num_status,
num_speculative_tokens,
dtype,
seed,
device
):
logits_triton, sampling_metadata = create_test_data(
num_reqs=num_reqs,
(
logits_triton,
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_spec_tokens,
) = create_test_data(
num_tokens=num_tokens,
vocab_size=vocab_size,
num_status=num_status,
num_speculative_tokens=num_speculative_tokens,
device=device,
dtype=dtype,
seed=seed
@@ -206,14 +223,35 @@ def test_apply_penalties_and_temperature(
logits_pytorch = logits_triton.clone()
apply_penalties_and_temperature(logits_triton, sampling_metadata)
apply_penalties(
logits_triton,
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_spec_tokens,
)
logits_pytorch_result = pytorch_apply_penalties_and_temperature(logits_pytorch,
sampling_metadata)
logits_pytorch_result = pytorch_apply_penalties(
logits_pytorch,
idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
frequency_penalty,
presence_penalty,
prompt_bin_mask,
output_bin_counts,
num_spec_tokens,
)
atol = DEFAULT_ATOL
rtol = DEFAULT_RTOL
if dtype == torch.bfloat16:
atol = 1e-02
rtol = 1e-02
assert torch.allclose(logits_triton, logits_pytorch_result, atol=atol, rtol=rtol)
assert torch.allclose(logits_triton, logits_pytorch_result, atol=atol, rtol=rtol)