[MODELRUNNERV2] fix penalty ops (#7013)
### What this PR does / why we need it?
Fix the penalty ops for the new version, and achieve a ~10% performance improvement.
### How was this patch tested?
pytest tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
Signed-off-by: shiyuan680 <917935075@qq.com>
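For orientation, the reworked entry point drops the `SamplingMetadata` bundle and takes flat tensors plus a speculative-token count. A minimal invocation sketch (shapes follow the test below; it assumes an Ascend NPU device with vllm-ascend installed, and picks `num_status <= num_tokens` so the per-state penalty lookup stays in range of the sliced penalties tensor):

```python
import torch
from vllm_ascend.worker.v2.sample.penalties import apply_penalties

num_tokens, num_status, vocab_size, num_spec = 8, 4, 151936, 3
device = "npu:0"

logits = torch.randn(num_tokens, vocab_size, device=device, dtype=torch.bfloat16)
# Per-request-state penalty parameters; each token looks its state up via idx_mapping.
repetition_penalty = torch.full((num_status,), 1.2, device=device)
frequency_penalty = torch.zeros(num_status, device=device)
presence_penalty = torch.zeros(num_status, device=device)
idx_mapping = torch.randint(0, num_status, (num_tokens,), device=device, dtype=torch.int32)
token_ids = torch.randint(0, vocab_size, (num_tokens,), device=device, dtype=torch.int32)
# All zeros: no draft tokens precede any token in this sketch.
expanded_local_pos = torch.zeros(num_tokens, device=device, dtype=torch.int32)
# One bit per vocab token, packed into int32 words; all zeros here.
prompt_bin_mask = torch.zeros(num_status, (vocab_size + 31) // 32, device=device, dtype=torch.int32)
output_bin_counts = torch.zeros(num_status, vocab_size, device=device, dtype=torch.int32)

# Mutates `logits` in place on the NPU.
apply_penalties(logits, idx_mapping, token_ids, expanded_local_pos,
                repetition_penalty, frequency_penalty, presence_penalty,
                prompt_bin_mask, output_bin_counts, num_spec)
```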
`tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py`:

```diff
@@ -1,63 +1,47 @@
 import pytest
 import torch
 
-from vllm_ascend.worker.v2.sample.penalties import apply_penalties_and_temperature
+from vllm_ascend.worker.v2.sample.penalties import apply_penalties
 
 DTYPES = [torch.bfloat16, torch.float16]
-NUM_REQS = [2, 4, 8]
+NUM_TOKENS = [2, 4, 8]
 VOCAB_SIZE = [151936]
 NUM_STATUS = [1, 4, 8, 16]
 SEEDS = [0]
 DEVICES = [f"npu:{0}"]
+NUM_SPECULATIVE_TOKENS = [0, 1, 3]
 DEFAULT_ATOL = 1e-3
 DEFAULT_RTOL = 1e-3
 
 
-class SamplingMetadata:
-    def __init__(self,
-                 repetition_penalty: torch.Tensor,
-                 frequency_penalty: torch.Tensor,
-                 presence_penalty: torch.Tensor,
-                 temperature: torch.Tensor,
-                 idx_mapping: torch.Tensor,
-                 prompt_bin_mask: torch.Tensor,
-                 output_bin_counts: torch.Tensor):
-        self.repetition_penalty = repetition_penalty
-        self.frequency_penalty = frequency_penalty
-        self.presence_penalty = presence_penalty
-        self.temperature = temperature
-        self.idx_mapping = idx_mapping
-        self.prompt_bin_mask = prompt_bin_mask
-        self.output_bin_counts = output_bin_counts
 
 
-def pytorch_apply_penalties_and_temperature(
+def pytorch_apply_penalties(
     logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
+    idx_mapping: torch.Tensor,
+    token_ids: torch.Tensor,
+    expanded_local_pos: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    frequency_penalty: torch.Tensor,
+    presence_penalty: torch.Tensor,
+    prompt_bin_mask: torch.Tensor,
+    output_bin_counts: torch.Tensor,
+    num_speculative_tokens: int,
 ) -> torch.Tensor:
     """
     Pytorch equivalent implementation
     """
-    num_reqs, vocab_size = logits.shape
+    num_tokens, vocab_size = logits.shape
     device = logits.device
     dtype = logits.dtype
 
     logits_float = logits.float()
 
-    repetition_penalty = sampling_metadata.repetition_penalty
-    frequency_penalty = sampling_metadata.frequency_penalty
-    presence_penalty = sampling_metadata.presence_penalty
-    temperature = sampling_metadata.temperature
-    idx_mapping = sampling_metadata.idx_mapping
-    prompt_bin_mask = sampling_metadata.prompt_bin_mask
-    output_bin_counts = sampling_metadata.output_bin_counts
-
-    temperature = torch.where(temperature == 0.0, torch.ones_like(temperature), temperature)
-
     num_status = prompt_bin_mask.shape[0]
     num_packed = prompt_bin_mask.shape[1]
 
-    prompt_masks_unpacked = torch.zeros(num_status, vocab_size, dtype=torch.bool, device=device)
+    prompt_masks_unpacked = torch.zeros(
+        num_status, vocab_size, dtype=torch.bool,
+        device=device
+    )
 
     for state_idx in range(num_status):
         for packed_idx in range(num_packed):
@@ -69,82 +53,99 @@ def pytorch_apply_penalties_and_temperature(
                 if (packed_val >> bit_pos) & 1:
                     prompt_masks_unpacked[state_idx, start_idx + bit_pos] = True
 
-    for batch_idx in range(num_reqs):
-        req_state_idx = idx_mapping[batch_idx].item()
+    for token_idx in range(num_tokens):
+        req_state_idx = idx_mapping[token_idx].item()
 
-        rep_penalty = repetition_penalty[batch_idx].item()
-        freq_penalty = frequency_penalty[batch_idx].item()
-        pres_penalty = presence_penalty[batch_idx].item()
-        temp = temperature[batch_idx].item()
+        rep_penalty = repetition_penalty[req_state_idx].item()
+        freq_penalty = frequency_penalty[req_state_idx].item()
+        pres_penalty = presence_penalty[req_state_idx].item()
 
         use_rep_penalty = rep_penalty != 1.0
         use_freq_penalty = freq_penalty != 0.0
         use_pres_penalty = pres_penalty != 0.0
-        use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
-        use_temperature = temp != 1.0
+        use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
 
-        if not (use_penalty or use_temperature):
+        if not use_penalty:
             continue
 
         current_prompt_mask = prompt_masks_unpacked[req_state_idx]
-        current_output_counts = output_bin_counts[req_state_idx]
-        output_bin_mask = current_output_counts > 0
+        base_output_counts = output_bin_counts[req_state_idx]
+
+        # Compute cumulative draft counts
+        pos = expanded_local_pos[token_idx].item()
+        start_idx_in_batch = token_idx - pos
+        draft_counts = torch.zeros(vocab_size, device=device, dtype=torch.int32)
+
+        for prev_pos in range(num_speculative_tokens):
+            if prev_pos < pos:
+                prev_token = token_ids[start_idx_in_batch + prev_pos + 1].item()
+                draft_counts[prev_token] += 1
+
+        # Total counts = base output counts + cumulative draft counts
+        total_output_counts = base_output_counts + draft_counts
+        output_bin_mask = total_output_counts > 0
 
         if use_rep_penalty:
             scale = torch.ones(vocab_size, device=device)
             mask = current_prompt_mask | output_bin_mask
             scale[mask] = rep_penalty
 
-            pos_mask = logits_float[batch_idx] > 0
+            pos_mask = logits_float[token_idx] > 0
             scale_factor = torch.where(pos_mask, 1.0 / scale, scale)
-            logits_float[batch_idx] *= scale_factor
+            logits_float[token_idx] *= scale_factor
 
         if use_freq_penalty:
-            logits_float[batch_idx] -= freq_penalty * current_output_counts.float()
+            logits_float[token_idx] -= freq_penalty * total_output_counts.float()
 
         if use_pres_penalty:
-            logits_float[batch_idx] -= pres_penalty * output_bin_mask.float()
-
-        if use_temperature:
-            logits_float[batch_idx] /= temp
+            logits_float[token_idx] -= pres_penalty * output_bin_mask.float()
 
     return logits_float.to(dtype)
 
 
 def create_test_data(
-    num_reqs: int = 8,
+    num_tokens: int = 8,
     vocab_size: int = 51200,
     num_status: int = 16,
+    num_speculative_tokens: int = 3,
     device: str = "npu",
     dtype: torch.dtype = torch.bfloat16,
     seed: int = 42,
 ):
-    """Create test data for penalties and temperature"""
+    """Create test data for penalties"""
     torch.manual_seed(seed)
 
-    logits = torch.randn(num_reqs, vocab_size, device=device, dtype=dtype)
+    logits = torch.randn(num_tokens, vocab_size, device=device, dtype=dtype)
 
-    repetition_penalty = torch.ones(num_reqs, device=device, dtype=torch.float32)
-    for i in range(num_reqs):
+    repetition_penalty = torch.ones(num_status, device=device, dtype=torch.float32)
+    for i in range(num_status):
         if torch.rand(1) > 0.3:
             repetition_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6
 
-    frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
-    for i in range(num_reqs):
+    frequency_penalty = torch.zeros(num_status, device=device, dtype=torch.float32)
+    for i in range(num_status):
         if torch.rand(1) > 0.5:
             frequency_penalty[i] = torch.rand(1, device=device).item() * 0.2
 
-    presence_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32)
-    for i in range(num_reqs):
+    presence_penalty = torch.zeros(num_status, device=device, dtype=torch.float32)
+    for i in range(num_status):
         if torch.rand(1) > 0.5:
             presence_penalty[i] = torch.rand(1, device=device).item() * 0.2
 
-    temperature = torch.ones(num_reqs, device=device, dtype=torch.float32)
-    for i in range(num_reqs):
-        if torch.rand(1) > 0.2:
-            presence_penalty[i] = torch.rand(1, device=device).item() * 1.8 + 0.2
-
-    idx_mapping = torch.randint(0, num_status, (num_reqs,), device=device, dtype=torch.int32)
+    idx_mapping = torch.randint(
        0, num_status, (num_tokens,), device=device,
        dtype=torch.int32
+    )
+
+    # Create token_ids for speculative decoding
+    token_ids = torch.randint(0, vocab_size, (num_tokens,), device=device, dtype=torch.int32)
+
+    # Create expanded_local_pos (position within speculative decoding window)
+    expanded_local_pos = torch.zeros(num_tokens, device=device, dtype=torch.int32)
+    for i in range(num_tokens):
+        expanded_local_pos[i] = torch.randint(
+            0, num_speculative_tokens + 1, (1,)
+        ).item()
 
     num_packed = (vocab_size + 31) // 32
     prompt_bin_mask = torch.zeros(num_status, num_packed, device=device, dtype=torch.int32)
@@ -161,44 +162,60 @@ def create_test_data(
     output_bin_counts = torch.zeros(num_status, vocab_size, device=device, dtype=torch.int32)
     for state_idx in range(num_status):
         num_output_tokens = max(1, vocab_size // 20)
-        output_tokens = torch.randint(0, vocab_size, (num_output_tokens, ))
+        output_tokens = torch.randint(0, vocab_size,
+                                      (num_output_tokens, ))
         counts = torch.randint(1, 10, (num_output_tokens,))
 
         for token, count in zip(output_tokens, counts):
             output_bin_counts[state_idx, token] = count
 
-    sampling_metadata = SamplingMetadata(
-        repetition_penalty=repetition_penalty,
-        frequency_penalty=frequency_penalty,
-        presence_penalty=presence_penalty,
-        temperature=temperature,
-        idx_mapping=idx_mapping,
-        prompt_bin_mask=prompt_bin_mask,
-        output_bin_counts=output_bin_counts
+    return (
+        logits,
+        idx_mapping,
+        token_ids,
+        expanded_local_pos,
+        repetition_penalty,
+        frequency_penalty,
+        presence_penalty,
+        prompt_bin_mask,
+        output_bin_counts,
+        num_speculative_tokens,
     )
 
-    return logits, sampling_metadata
-
 
-@pytest.mark.parametrize("num_reqs", NUM_REQS)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("vocab_size", VOCAB_SIZE)
 @pytest.mark.parametrize("num_status", NUM_STATUS)
+@pytest.mark.parametrize("num_speculative_tokens", NUM_SPECULATIVE_TOKENS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
-def test_apply_penalties_and_temperature(
-    num_reqs,
+def test_apply_penalties(
+    num_tokens,
     vocab_size,
     num_status,
+    num_speculative_tokens,
     dtype,
     seed,
     device
 ):
-    logits_triton, sampling_metadata = create_test_data(
-        num_reqs=num_reqs,
+    (
+        logits_triton,
+        idx_mapping,
+        token_ids,
+        expanded_local_pos,
+        repetition_penalty,
+        frequency_penalty,
+        presence_penalty,
+        prompt_bin_mask,
+        output_bin_counts,
+        num_spec_tokens,
+    ) = create_test_data(
+        num_tokens=num_tokens,
         vocab_size=vocab_size,
         num_status=num_status,
+        num_speculative_tokens=num_speculative_tokens,
         device=device,
         dtype=dtype,
         seed=seed
@@ -206,14 +223,35 @@ def test_apply_penalties_and_temperature(
 
     logits_pytorch = logits_triton.clone()
 
-    apply_penalties_and_temperature(logits_triton, sampling_metadata)
+    apply_penalties(
+        logits_triton,
+        idx_mapping,
+        token_ids,
+        expanded_local_pos,
+        repetition_penalty,
+        frequency_penalty,
+        presence_penalty,
+        prompt_bin_mask,
+        output_bin_counts,
+        num_spec_tokens,
+    )
 
-    logits_pytorch_result = pytorch_apply_penalties_and_temperature(logits_pytorch,
-                                                                    sampling_metadata)
+    logits_pytorch_result = pytorch_apply_penalties(
+        logits_pytorch,
+        idx_mapping,
+        token_ids,
+        expanded_local_pos,
+        repetition_penalty,
+        frequency_penalty,
+        presence_penalty,
+        prompt_bin_mask,
+        output_bin_counts,
+        num_spec_tokens,
+    )
 
     atol = DEFAULT_ATOL
     rtol = DEFAULT_RTOL
     if dtype == torch.bfloat16:
         atol = 1e-02
         rtol = 1e-02
     assert torch.allclose(logits_triton, logits_pytorch_result, atol=atol, rtol=rtol)
```
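The reference implementation above unpacks `prompt_bin_mask`, which stores one bit per vocabulary token in int32 words (`num_packed = (vocab_size + 31) // 32`). A minimal CPU-only sketch of the packing convention it decodes with `(packed_val >> bit_pos) & 1` (the `pack_prompt_mask` helper is hypothetical, for illustration only):

```python
import torch

def pack_prompt_mask(mask: torch.Tensor) -> torch.Tensor:
    """Pack a boolean (vocab_size,) mask so bit b of word w covers token w * 32 + b."""
    vocab_size = mask.numel()
    num_packed = (vocab_size + 31) // 32
    words = [0] * num_packed
    for token in mask.nonzero().flatten().tolist():
        words[token // 32] |= 1 << (token % 32)
    # Wrap to the signed int32 range an int32 tensor stores (bit 31 becomes the sign bit).
    words = [w - (1 << 32) if w >= (1 << 31) else w for w in words]
    return torch.tensor(words, dtype=torch.int32)

mask = torch.zeros(70, dtype=torch.bool)
mask[torch.tensor([0, 31, 69])] = True
packed = pack_prompt_mask(mask)
# Decode one token the same way the test's reference loop does.
for token in (0, 31, 42, 69):
    bit = (packed[token // 32].item() >> (token % 32)) & 1
    assert bool(bit) == mask[token].item()
```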
`vllm_ascend/worker/v2/sample/penalties.py`:

```diff
@@ -20,122 +20,141 @@
 
 import torch
 from vllm.triton_utils import tl, triton
-from vllm.v1.sample.metadata import SamplingMetadata
 
 
 @triton.jit
-def _penalties_and_temperature_kernel(
+def _penalties_kernel(
     logits_ptr,
     logits_stride,
-    repetition_penalty_ptr,
-    frequency_penalty_ptr,
-    presence_penalty_ptr,
-    temperature_ptr,
     idx_mapping_ptr,
+    token_ids_ptr,
+    expanded_local_pos_ptr,
+    penalties_ptr,
+    penalties_stride,
     prompt_bin_mask_ptr,
     prompt_bin_mask_stride,
     output_bin_counts_ptr,
     output_bin_counts_stride,
     vocab_size,
     BLOCK_SIZE: tl.constexpr,
+    INNER_BLOCK_SIZE: tl.constexpr,
+    MAX_SPEC_LEN: tl.constexpr,
 ):
-    batch_idx = tl.program_id(0)
-    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
-    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
-    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
-    temperature = tl.load(temperature_ptr + batch_idx)
-    temperature = tl.where(temperature == 0.0, 1.0, temperature)
+    token_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + token_idx)
+
+    # first load penalties once
+    rep_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 0)
+    freq_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 1)
+    pres_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 2)
 
     use_rep_penalty = rep_penalty != 1.0
     use_freq_penalty = freq_penalty != 0.0
     use_pres_penalty = pres_penalty != 0.0
-    # NOTE(Ronald1995): vllm original grammar `use_rep_penalty or
-    # use_freq_penalty or use_pres_penalty`,
-    # change it to `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`,
-    # because triton-ascend's compiler doesn't support chained boolean operator.
-    use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
-    use_temperature = temperature != 1.0
-    if not (use_penalty or use_temperature):
+    # NPU doesn't support chained 'or' operations like 'A or B or C'
+    use_penalty = use_rep_penalty or use_freq_penalty
+    use_penalty = use_penalty or use_pres_penalty
+
+    if not use_penalty:
         # Early return to avoid loading logits.
         return
 
+    bit_masks = tl.full((INNER_BLOCK_SIZE // 32, 32), 1, dtype=tl.int32) << tl.arange(0, 32)
     block_idx = tl.program_id(1)
-    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = block < vocab_size
-    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
-    logits = logits.to(tl.float32)
+    block_start = block_idx * BLOCK_SIZE
 
-    if use_penalty:
-        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
-        output_bin_counts = tl.load(
-            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
-            mask=mask,
+    pos = tl.load(expanded_local_pos_ptr + token_idx)
+    start_idx = token_idx - pos
+
+    inv_rep = 1.0 / rep_penalty
+
+    for inner_offset in tl.static_range(0, BLOCK_SIZE, INNER_BLOCK_SIZE):
+        inner_block_start = block_start + inner_offset
+        inner_block = inner_block_start + tl.arange(0, INNER_BLOCK_SIZE)
+        inner_mask = inner_block < vocab_size
+
+        logits = tl.load(logits_ptr + token_idx * logits_stride + inner_block, mask=inner_mask, other=0.0)
+        logits = logits.to(tl.float32)
+
+        base_output_counts = tl.load(
+            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + inner_block,
+            mask=inner_mask,
+            other=0,
         )
-        # to use vector core, if use > 0 will use scalar to slow down performance
-        output_bin_mask = output_bin_counts != 0
+
+        # Compute cumulative draft_counts from previous positions in this request
+        total_counts = base_output_counts.to(tl.int32)
+        for prev_pos in tl.static_range(MAX_SPEC_LEN):
+            if prev_pos < pos:
+                load_idx = start_idx + prev_pos + 1
+                prev_token = tl.load(token_ids_ptr + load_idx)
+                total_counts += inner_block == prev_token
+
+        is_present = total_counts != 0
 
         # Apply repetition penalties.
         if use_rep_penalty:
-            packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
-            packed_mask = tl.load(
+            packed_inner_block_start = inner_block_start // 32
+            packed_block = packed_inner_block_start + tl.arange(0, INNER_BLOCK_SIZE // 32)
+            valid_packed_mask = packed_block < tl.cdiv(vocab_size, 32)
+
+            packed_mask_val = tl.load(
                 prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
-                mask=packed_block < tl.cdiv(vocab_size, 32),
+                mask=valid_packed_mask,
+                other=0,
             )
-            # the compiler itself does not optimize right-shift operations, so we change the same func
-            bit_masks = 1 << tl.arange(0, 32)
-            bit_masks_expanded = bit_masks[None, :]
-            packed_expanded = packed_mask[:, None]
-            bits_matrix = (packed_expanded & bit_masks_expanded) != 0
-            prompt_bin_mask = bits_matrix.reshape(BLOCK_SIZE)
-            prompt_bin_mask = prompt_bin_mask.to(tl.int1)
-            prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
-
-            # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
-            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
-            # If logits are positive, divide by penalty, otherwise multiply by penalty.
-            logits *= tl.where(logits > 0, 1.0 / scale, scale)
-
-        # Apply frequency penalties.
-        logits -= freq_penalty * output_bin_counts
-        # Apply presence penalties.
-        logits -= pres_penalty * output_bin_mask
-
-    # Apply temperature.
-    logits = logits / temperature
-
-    # Store back to logits.
-    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
+            prompt_mask = ((packed_mask_val[:, None] & bit_masks) != 0).reshape(INNER_BLOCK_SIZE)
+
+            needs_scaling = prompt_mask | is_present
+
+            base_factor = tl.where(logits > 0, inv_rep, rep_penalty)
+            logits = tl.where(needs_scaling, logits * base_factor, logits)
+
+        freq_term = freq_penalty * total_counts.to(tl.float32)
+        pres_term = pres_penalty * is_present.to(tl.float32)
+
+        logits = logits - freq_term - pres_term
+        # Store back to logits.
+        tl.store(logits_ptr + token_idx * logits_stride + inner_block, logits, mask=inner_mask)
 
 
-def apply_penalties_and_temperature(
+def apply_penalties(
     logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
+    idx_mapping: torch.Tensor,
+    token_ids: torch.Tensor,
+    expanded_local_pos: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    frequency_penalty: torch.Tensor,
+    presence_penalty: torch.Tensor,
+    prompt_bin_mask: torch.Tensor,
+    output_bin_counts: torch.Tensor,
+    num_speculative_tokens: int,
 ) -> None:
-    """Override the function because there are some bugs
-    when _penalties_and_temperature_kernel runs on npu, we need to make some fixes.
-    you could read NOTE(Ronald1995) comments to understand.
-    """
-    num_reqs, vocab_size = logits.shape
-    # NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 in case UB overflow
-    # in triton-ascend.
-    BLOCK_SIZE = 4096
+    num_tokens, vocab_size = logits.shape
+    BLOCK_SIZE = 8192
+    INNER_BLOCK_SIZE = 4096
+
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    # TODO(Ronald1995): Optimize the performance of the kernel in npu.
-    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
+
+    penalties = torch.stack(
+        [repetition_penalty[:num_tokens], frequency_penalty[:num_tokens], presence_penalty[:num_tokens]], dim=1
+    ).contiguous()
+    penalties_stride = penalties.stride(0)
+
+    _penalties_kernel[(num_tokens, num_blocks)](
         logits,
         logits.stride(0),
-        sampling_metadata.repetition_penalty,
-        sampling_metadata.frequency_penalty,
-        sampling_metadata.presence_penalty,
-        sampling_metadata.temperature,
-        sampling_metadata.idx_mapping,
-        sampling_metadata.prompt_bin_mask,
-        sampling_metadata.prompt_bin_mask.stride(0),
-        sampling_metadata.output_bin_counts,
-        sampling_metadata.output_bin_counts.stride(0),
+        idx_mapping,
+        token_ids,
+        expanded_local_pos,
+        penalties,
+        penalties_stride,
+        prompt_bin_mask,
+        prompt_bin_mask.stride(0),
+        output_bin_counts,
+        output_bin_counts.stride(0),
         vocab_size,
         BLOCK_SIZE=BLOCK_SIZE,
+        INNER_BLOCK_SIZE=INNER_BLOCK_SIZE,
+        MAX_SPEC_LEN=num_speculative_tokens,
     )
```
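One of the layout changes the kernel relies on: the three penalty vectors are stacked into a single contiguous `(n, 3)` tensor, so each program reads its `(rep, freq, pres)` triple with three scalar loads from one base pointer instead of three separate tensors. A CPU-side sketch of the equivalent indexing (standalone, for illustration):

```python
import torch

num_status, num_tokens = 4, 8
repetition_penalty = torch.rand(num_status) + 0.5
frequency_penalty = torch.rand(num_status) * 0.2
presence_penalty = torch.rand(num_status) * 0.2
idx_mapping = torch.randint(0, num_status, (num_tokens,), dtype=torch.int32)

# Row r holds (rep, freq, pres) for request state r, as in the wrapper above.
penalties = torch.stack(
    [repetition_penalty, frequency_penalty, presence_penalty], dim=1
).contiguous()
stride = penalties.stride(0)  # 3 elements per row
flat = penalties.flatten()

token_idx = 5
req_state_idx = idx_mapping[token_idx].item()
# Mirrors tl.load(penalties_ptr + req_state_idx * penalties_stride + k) for k in 0..2.
rep = flat[req_state_idx * stride + 0]
freq = flat[req_state_idx * stride + 1]
pres = flat[req_state_idx * stride + 2]
assert torch.equal(rep, repetition_penalty[req_state_idx])
```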
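Both implementations also fold speculative draft tokens into the output counts: for a token at local position `pos` within its window, the ids at `token_ids[start + 1 .. start + pos]` (with `start = token_idx - pos`) are counted before the penalties are applied. A small CPU sketch of that accumulation, using the same indexing as the diff (the example values are illustrative):

```python
import torch

vocab_size, num_spec = 16, 3
token_ids = torch.tensor([3, 7, 7, 2, 9, 9], dtype=torch.int32)
# Local position of each token inside its speculative window.
expanded_local_pos = torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.int32)

def draft_counts_for(token_idx: int) -> torch.Tensor:
    pos = expanded_local_pos[token_idx].item()
    start = token_idx - pos
    counts = torch.zeros(vocab_size, dtype=torch.int32)
    for prev_pos in range(num_spec):
        if prev_pos < pos:
            counts[token_ids[start + prev_pos + 1].item()] += 1
    return counts

# Token 2 has pos=2, so it accumulates token_ids[1] and token_ids[2] (both 7).
assert draft_counts_for(2)[7] == 2
# Token 3 starts a new window (pos=0): nothing is accumulated.
assert draft_counts_for(3).sum() == 0
```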