Files
enginex-mthreads-vllm/vllm/v1/worker/gpu/spec_decode/rejection_sample.py
2026-01-19 10:38:50 +08:00

72 lines
2.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
    sampled_ptr,  # [num_reqs, num_speculative_steps + 1]
    sampled_stride,
    num_sampled_ptr,  # [num_reqs]
    target_sampled_ptr,  # [num_draft_tokens + num_reqs]
    input_ids_ptr,  # [num_draft_tokens + num_reqs]
    cu_num_logits_ptr,  # [num_reqs + 1]
):
    # Each program instance (grid axis 0) processes exactly one request.
    #
    # For request `req_idx`, the flat range
    # [cu_num_logits[req_idx], cu_num_logits[req_idx + 1]) selects its entries
    # in `target_sampled` and `input_ids`. Walking that range in order, the
    # target token at offset `pos` is copied into the request's output row and
    # compared with the input id at offset `pos + 1`. The first mismatch stops
    # any further writes. If no mismatch occurs, the last target token of the
    # range (which has no following input id to compare with) is stored too.
    # The number of tokens written is stored to `num_sampled_ptr[req_idx]`;
    # output-row slots past that count are never written by this kernel.
    #
    # NOTE(review): the trailing load/store indexes `num_tokens - 1`, so this
    # presumably relies on every request having at least one logit -- confirm
    # against the caller.
    req_idx = tl.program_id(0)
    logits_begin = tl.load(cu_num_logits_ptr + req_idx)
    logits_end = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = logits_end - logits_begin

    # Per-request base addresses, computed once instead of at every access.
    out_row_ptr = sampled_ptr + req_idx * sampled_stride
    target_ptr = target_sampled_ptr + logits_begin
    draft_ptr = input_ids_ptr + logits_begin

    num_accepted = 0
    rejected = False
    for pos in range(num_tokens - 1):
        if not rejected:
            target_token = tl.load(target_ptr + pos)
            # The token being verified sits one slot ahead in `input_ids`.
            draft_token = tl.load(draft_ptr + pos + 1)
            # The target token is emitted even when it disagrees with the
            # draft; the mismatch only blocks the positions that follow.
            tl.store(out_row_ptr + pos, target_token)
            num_accepted += 1
            if target_token != draft_token:
                rejected = True
    if not rejected:
        # Every compared position matched: also emit the final target token.
        final_token = tl.load(target_ptr + num_tokens - 1)
        tl.store(out_row_ptr + num_tokens - 1, final_token)
        num_accepted += 1
    tl.store(num_sampled_ptr + req_idx, num_accepted)
def rejection_sample(
    # [num_draft_tokens + num_reqs]
    target_sampled: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    input_ids: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run `_rejection_sample_kernel` with one program per request.

    Args:
        target_sampled: Flat tensor of target-sampled token ids; its dtype and
            device are reused for the `sampled` output.
        input_ids: Flat tensor of input token ids that the kernel compares
            against `target_sampled`.
        cu_num_logits: Per-request boundaries into the two flat tensors; its
            length minus one is taken as the number of requests.
        num_speculative_steps: Sets the output row width to
            `num_speculative_steps + 1`.

    Returns:
        A `(sampled, num_sampled)` pair. `sampled` has shape
        `[num_reqs, num_speculative_steps + 1]` and `num_sampled` is an int32
        tensor of shape `[num_reqs]`. Both are allocated with `torch.empty`
        and filled by the kernel, so only the first `num_sampled[i]` entries
        of row `i` in `sampled` hold defined values.
    """
    device = target_sampled.device
    num_reqs = cu_num_logits.shape[0] - 1
    row_width = num_speculative_steps + 1

    # Uninitialised outputs: the kernel writes the valid prefix of each row
    # and the per-request count.
    sampled = torch.empty(
        (num_reqs, row_width),
        dtype=target_sampled.dtype,
        device=device,
    )
    num_sampled = torch.empty((num_reqs,), dtype=torch.int32, device=device)

    grid = (num_reqs,)
    _rejection_sample_kernel[grid](
        sampled,
        sampled.stride(0),
        num_sampled,
        target_sampled,
        input_ids,
        cu_num_logits,
        num_warps=1,
    )
    return sampled, num_sampled