# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
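"""Renderer for chat requests served with the Mistral (mistral-common) tokenizer."""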
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async

from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams

logger = init_logger(__name__)


def safe_apply_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    **kwargs,
) -> str | list[int]:
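    """Apply the tokenizer's chat template with normalized error handling.

    mistral-common signals malformed input with assert statements and its
    own exception type; both are converted to ``ValueError`` so that
    callers in the preprocessing step can handle them uniformly.
    """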
    from mistral_common.exceptions import MistralCommonException

    try:
        return tokenizer.apply_chat_template(messages, **kwargs)
    # mistral-common uses assert statements to stop processing input that
    # does not comply with the expected format. We convert those assertion
    # errors to ValueErrors so they can be properly caught in the
    # preprocessing_input step.
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e

    # Other exceptions from the external library can still occur despite
    # its internal exception handling.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e


class MistralRenderer(BaseRenderer[MistralTokenizer]):
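    """Renderer for models served with a ``MistralTokenizer``.

    Chat-template application goes through ``safe_apply_chat_template`` so
    that errors raised by mistral-common surface as ``ValueError``.
    """
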
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "MistralRenderer":
        model_config = config.model_config
        if model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=MistralTokenizer,
                **tokenizer_kwargs,
            )

        return cls(config, tokenizer)

    def __init__(
        self,
        config: VllmConfig,
        tokenizer: MistralTokenizer | None,
    ) -> None:
        super().__init__(config, tokenizer)

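        # Applying the chat template is blocking work, so it is dispatched
        # to a dedicated executor via make_async; max_workers=1 also
        # serializes calls, as the tokenizer is presumably not thread-safe.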
        self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_chat_template_async = make_async(
            safe_apply_chat_template, executor=self._apply_chat_template_executor
        )

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
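        """Render chat messages into a decoder-only prompt (synchronous).

        Returns the parsed conversation along with a prompt dict carrying
        any multimodal data and UUIDs extracted from the messages.
        """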
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
        )

        prompt_raw = safe_apply_chat_template(
            tokenizer,
            messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        # Attach any multimodal payloads extracted during message parsing.
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
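        """Async variant of ``render_messages``.

        Message parsing is awaited and the chat template is applied on the
        dedicated executor so the event loop is not blocked.
        """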
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
        )

        prompt_raw = await self._apply_chat_template_async(
            tokenizer,
            messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt