47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
|
|
|
|
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
    """Tell whether request ``req_id`` may take part in speculative decoding.

    A request is rejected if it is registered for any sampling feature that
    spec decode does not support: min_p sampling, frequency / presence /
    repetition penalties, or logprobs. The checks short-circuit in that
    order.

    Args:
        req_id: Identifier of the request to check.
        input_batch: Batch whose per-feature request collections
            (``min_p_reqs``, ``*_penalties_reqs``, ``num_logprobs``) are
            queried with ``in``.

    Returns:
        ``False`` if the request uses any unsupported feature, else ``True``.
    """
    uses_unsupported_feature = (
        # Spec decode doesn't support min_p sampling.
        req_id in input_batch.min_p_reqs
        # Spec decode doesn't support penalties.
        or req_id in input_batch.frequency_penalties_reqs
        or req_id in input_batch.presence_penalties_reqs
        or req_id in input_batch.repetition_penalties_reqs
        # Spec decode doesn't support logprobs.
        or req_id in input_batch.num_logprobs
    )
    return not uses_unsupported_feature
|
|
|
|
|
|
# Write consecutive integer indices into ``out`` for each request.
#
# Program ``pid`` owns the half-open output span
#   [cu_num_tokens[pid], cu_num_tokens[pid + 1])
# and fills it so that, for 0 <= k < span length,
#   out[cu_num_tokens[pid] + k] = cu_query_lens[pid] + k.
# The span is processed in BLOCK_SIZE-wide tiles; a mask trims the last tile,
# and an empty span (length 0) writes nothing.
#
# NOTE(review): presumably launched with one program per request, i.e. a
# grid of ``len(cu_num_tokens) - 1``; confirm at the call site.
@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)

    # Half-open output span [span_begin, span_end) owned by this request.
    span_begin = tl.load(cu_num_tokens_ptr + req_idx)
    span_end = tl.load(cu_num_tokens_ptr + req_idx + 1)
    span_len = span_end - span_begin

    # First index value written for this request.
    first_index = tl.load(cu_query_lens_ptr + req_idx)

    # Per-tile lane offsets and the span's base pointer do not change
    # between tiles, so compute them once.
    lanes = tl.arange(0, BLOCK_SIZE)
    span_ptr = out_ptr + span_begin

    for tile in tl.range(tl.cdiv(span_len, BLOCK_SIZE)):
        pos = tile * BLOCK_SIZE + lanes
        # Lanes past the end of the span are masked off.
        tl.store(span_ptr + pos, first_index + pos, mask=pos < span_len)
|