Sync from v0.13
This commit is contained in:
71
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
Normal file
71
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rejection_sample_kernel(
|
||||
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
|
||||
sampled_stride,
|
||||
num_sampled_ptr, # [num_reqs]
|
||||
target_sampled_ptr, # [num_draft_tokens + num_reqs]
|
||||
input_ids_ptr, # [num_draft_tokens + num_reqs]
|
||||
cu_num_logits_ptr, # [num_reqs + 1]
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
num_sampled = 0
|
||||
rejected = False
|
||||
for i in range(num_tokens - 1):
|
||||
if not rejected:
|
||||
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
|
||||
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
|
||||
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
|
||||
num_sampled += 1
|
||||
if target_sampled != draft_sampled:
|
||||
rejected = True
|
||||
if not rejected:
|
||||
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
|
||||
tl.store(
|
||||
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
|
||||
)
|
||||
num_sampled += 1
|
||||
tl.store(num_sampled_ptr + req_idx, num_sampled)
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_draft_tokens + num_reqs]
|
||||
target_sampled: torch.Tensor,
|
||||
# [num_draft_tokens + num_reqs]
|
||||
input_ids: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
sampled = torch.empty(
|
||||
num_reqs,
|
||||
num_speculative_steps + 1,
|
||||
dtype=target_sampled.dtype,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
num_sampled = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
_rejection_sample_kernel[(num_reqs,)](
|
||||
sampled,
|
||||
sampled.stride(0),
|
||||
num_sampled,
|
||||
target_sampled,
|
||||
input_ids,
|
||||
cu_num_logits,
|
||||
num_warps=1,
|
||||
)
|
||||
return sampled, num_sampled
|
||||
Reference in New Issue
Block a user