Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -1,10 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.sampling_params import RepetitionDetectionParams
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
|
||||
def _has_repeating_pattern(
    token_ids: Sequence[int],
    pattern_len: int,
    repetition_min_count: int,
) -> bool:
    """Check if the tail of token_ids contains a repeating pattern.

    Compares the last pattern_len tokens against the preceding
    (repetition_min_count - 1) windows of the same length, position by
    position, working backwards from the end of the sequence.

    Args:
        token_ids: Token IDs whose tail is inspected.
        pattern_len: Length of the candidate pattern.
        repetition_min_count: Number of back-to-back occurrences of the
            pattern (including the final one) required for a match.

    Returns:
        True if the last ``pattern_len * repetition_min_count`` tokens are
        ``repetition_min_count`` copies of the same ``pattern_len``-token
        window. False otherwise, including when ``pattern_len`` is not
        positive or the sequence is too short to hold that many copies.
    """
    # An empty pattern would make the loop below vacuously succeed.
    if pattern_len <= 0:
        return False

    # The comparisons below reach back pattern_len * repetition_min_count
    # positions from the end; with fewer tokens the negative index would
    # raise IndexError, so report "no repetition" instead.
    if len(token_ids) < pattern_len * max(repetition_min_count, 1):
        return False

    for offset in range(1, pattern_len + 1):
        expected = token_ids[-offset]
        # Same position inside each earlier window, one window further
        # back per step.
        for rep in range(1, repetition_min_count):
            if token_ids[-(pattern_len * rep + offset)] != expected:
                return False
    return True
|
||||
|
||||
|
||||
def check_sequence_repetition(
    token_ids: Sequence[int],
    params: RepetitionDetectionParams,
) -> bool:
    """Report whether token_ids ends in a repeated pattern.

    Candidate pattern lengths are tried in increasing order, from
    ``params.min_pattern_size`` (treated as 1 when it is not positive)
    up to ``params.max_pattern_size``. The search succeeds as soon as one
    length is found whose pattern occurs ``params.min_count`` times
    back-to-back at the end of the sequence.

    Args:
        token_ids: List of token IDs.
        params: Repetition detection parameters; ``max_pattern_size``,
            ``min_pattern_size`` and ``min_count`` are read.

    Returns:
        True if a repetition pattern is found, False otherwise.
    """
    longest = params.max_pattern_size
    shortest = params.min_pattern_size if params.min_pattern_size > 0 else 1
    required_repeats = params.min_count

    # Nothing to search for: no usable pattern range, or fewer than two
    # occurrences requested.
    if longest <= 0 or required_repeats < 2 or shortest > longest:
        return False

    # A pattern of length p needs p * required_repeats tokens, so any
    # length above this bound cannot fit and is never tried.
    longest_that_fits = min(longest, len(token_ids) // required_repeats)

    return any(
        _has_repeating_pattern(token_ids, size, required_repeats)
        for size in range(shortest, longest_that_fits + 1)
    )
|
||||
|
||||
|
||||
def remove_all(lst: list, items_to_remove: set) -> list:
|
||||
"""Remove all items from a list that are in the items_to_remove set.
|
||||
|
||||
@@ -61,4 +115,16 @@ def check_stop(request: Request, max_model_len: int) -> bool:
|
||||
):
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
return True
|
||||
|
||||
repetition_detection = sampling_params.repetition_detection
|
||||
if repetition_detection is not None and (
|
||||
check_sequence_repetition(
|
||||
request.output_token_ids,
|
||||
repetition_detection,
|
||||
)
|
||||
):
|
||||
request.status = RequestStatus.FINISHED_REPETITION
|
||||
request.stop_reason = "repetition_detected"
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user