42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
_SAMPLING_EPS = 1e-5
|
|
|
|
|
|
def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
    """Report whether a request cannot be used with speculative decoding.

    The checks run in the order listed and stop at the first match:

    * ``frequency_penalty`` differs from 0.0
    * ``presence_penalty`` differs from 0.0
    * ``repetition_penalty`` differs from 1.0
    * ``min_p`` is greater than ``_SAMPLING_EPS``
    * ``logprobs`` is not ``None``

    Args:
        sampling_params: Sampling configuration of the request.

    Returns:
        ``True`` if any of the conditions above holds, ``False`` otherwise.
    """
    params = sampling_params

    # Any active penalty makes the request incompatible.
    if params.frequency_penalty != 0.0:
        return True
    if params.presence_penalty != 0.0:
        return True
    if params.repetition_penalty != 1.0:
        return True

    # min_p counts as enabled only once it exceeds the epsilon threshold.
    if params.min_p > _SAMPLING_EPS:
        return True

    # Last condition: the request asked for logprobs.
    return params.logprobs is not None
|
|
|
|
|
|
@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program (axis 0) handles one segment `seg_id`:
    #   * reads its output range [seg_begin, seg_end) from
    #     cu_num_tokens_ptr[seg_id] and cu_num_tokens_ptr[seg_id + 1],
    #   * reads a base value from cu_query_lens_ptr[seg_id],
    #   * writes out_ptr[seg_begin + k] = base_index + k
    #     for every k in [0, seg_end - seg_begin).
    # The range is processed in chunks of BLOCK_SIZE elements.
    seg_id = tl.program_id(0)

    # Half-open output range owned by this program: [seg_begin, seg_end).
    seg_begin = tl.load(cu_num_tokens_ptr + seg_id)
    seg_end = tl.load(cu_num_tokens_ptr + seg_id + 1)
    seg_len = seg_end - seg_begin

    # Value stored at the first position of the range.
    base_index = tl.load(cu_query_lens_ptr + seg_id)

    # Per-chunk lane offsets 0..BLOCK_SIZE-1 (loop invariant).
    lanes = tl.arange(0, BLOCK_SIZE)
    for chunk in tl.range(tl.cdiv(seg_len, BLOCK_SIZE)):
        rel = chunk * BLOCK_SIZE + lanes
        # The mask drops lanes past the end of the range in the last chunk.
        tl.store(
            out_ptr + seg_begin + rel,
            base_index + rel,
            mask=rel < seg_len,
        )
|