# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton

# Threshold for min_p: values at or below this count as "min_p disabled".
_SAMPLING_EPS = 1e-5


def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
    """Return True when a request cannot be served with speculative decoding.

    The request is considered incompatible if any of the following hold
    (checked in this order, stopping at the first match):

    * ``frequency_penalty`` is non-zero,
    * ``presence_penalty`` is non-zero,
    * ``repetition_penalty`` is not 1.0,
    * ``min_p`` exceeds ``_SAMPLING_EPS``,
    * ``logprobs`` is set (not ``None``).
    """
    params = sampling_params
    if params.frequency_penalty != 0.0:
        return True
    if params.presence_penalty != 0.0:
        return True
    if params.repetition_penalty != 1.0:
        return True
    if params.min_p > _SAMPLING_EPS:
        return True
    return params.logprobs is not None


@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program instance per entry selected by program_id(0). It fills
    # out_ptr[start:end] with consecutive values base, base + 1, ...,
    # where [start, end) comes from cu_num_tokens_ptr[req], [req + 1]
    # and base comes from cu_query_lens_ptr[req]. Work is done
    # BLOCK_SIZE elements at a time.
    req = tl.program_id(0)

    # Half-open output span [start, end) for this program instance.
    start = tl.load(cu_num_tokens_ptr + req)
    end = tl.load(cu_num_tokens_ptr + req + 1)
    count = end - start

    # First value to emit. Presumably this is the request's offset into the
    # flattened query tokens -- NOTE(review): confirm against the caller.
    base = tl.load(cu_query_lens_ptr + req)

    # Loop-invariant lane offsets and destination base pointer.
    lanes = tl.arange(0, BLOCK_SIZE)
    dst = out_ptr + start
    for blk in tl.range(tl.cdiv(count, BLOCK_SIZE)):
        pos = blk * BLOCK_SIZE + lanes
        # Mask the tail of the final block so nothing is written at or past `end`.
        tl.store(dst + pos, base + pos, mask=pos < count)