Sync from v0.13
This commit is contained in:
52
vllm/v1/sample/ops/bad_words.py
Normal file
52
vllm/v1/sample/ops/bad_words.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
_SMALLEST_LOGIT = float("-inf")
|
||||
|
||||
|
||||
def _apply_bad_words_single_batch(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: list[list[int]],
|
||||
past_tokens_ids: list[int],
|
||||
) -> None:
|
||||
for bad_word_ids in bad_words_token_ids:
|
||||
if len(bad_word_ids) > len(past_tokens_ids) + 1:
|
||||
continue
|
||||
|
||||
prefix_length = len(bad_word_ids) - 1
|
||||
last_token_id = bad_word_ids[-1]
|
||||
actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else []
|
||||
expected_prefix = bad_word_ids[:prefix_length]
|
||||
|
||||
assert len(actual_prefix) == len(expected_prefix)
|
||||
|
||||
if actual_prefix == expected_prefix:
|
||||
logits[last_token_id] = _SMALLEST_LOGIT
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: dict[int, list[list[int]]],
|
||||
past_tokens_ids: list[list[int]],
|
||||
) -> None:
|
||||
for i, bad_words_ids in bad_words_token_ids.items():
|
||||
_apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i])
|
||||
|
||||
|
||||
def apply_bad_words_with_drafts(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: dict[int, list[list[int]]],
|
||||
past_tokens_ids: list[list[int]],
|
||||
num_draft_tokens: list[int],
|
||||
) -> None:
|
||||
start_idx = 0
|
||||
for i, bad_words_ids in bad_words_token_ids.items():
|
||||
for draft_idx in range(num_draft_tokens[i]):
|
||||
_apply_bad_words_single_batch(
|
||||
logits[start_idx + draft_idx],
|
||||
bad_words_ids,
|
||||
past_tokens_ids[start_idx + draft_idx],
|
||||
)
|
||||
start_idx += num_draft_tokens[i]
|
||||
Reference in New Issue
Block a user