# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch _SMALLEST_LOGIT = float("-inf") def _apply_bad_words_single_batch( logits: torch.Tensor, bad_words_token_ids: list[list[int]], past_tokens_ids: list[int], ) -> None: for bad_word_ids in bad_words_token_ids: if len(bad_word_ids) > len(past_tokens_ids) + 1: continue prefix_length = len(bad_word_ids) - 1 last_token_id = bad_word_ids[-1] actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else [] expected_prefix = bad_word_ids[:prefix_length] assert len(actual_prefix) == len(expected_prefix) if actual_prefix == expected_prefix: logits[last_token_id] = _SMALLEST_LOGIT def apply_bad_words( logits: torch.Tensor, bad_words_token_ids: dict[int, list[list[int]]], past_tokens_ids: list[list[int]], ) -> None: for i, bad_words_ids in bad_words_token_ids.items(): _apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i]) def apply_bad_words_with_drafts( logits: torch.Tensor, bad_words_token_ids: dict[int, list[list[int]]], past_tokens_ids: list[list[int]], num_draft_tokens: list[int], ) -> None: start_idx = 0 remaining = len(bad_words_token_ids) for i, n in enumerate(num_draft_tokens): if (bad_words_ids := bad_words_token_ids.get(i)) is not None: for draft_idx in range(start_idx, start_idx + n): _apply_bad_words_single_batch( logits[draft_idx], bad_words_ids, past_tokens_ids[draft_idx], ) remaining -= 1 if not remaining: break start_idx += n