Add minimal vLLM 0.16.1 build repo for BI-V150
vllm/renderers/grok2.py (new file, +92 lines)
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer

from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams

logger = init_logger(__name__)

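# Renderer that converts OpenAI-style chat messages into decoder-only
# prompts using the Grok-2 tokenizer's chat template.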
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "Grok2Renderer":
        model_config = config.model_config
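        # No tokenizer is built when the engine is configured to skip
        # tokenizer initialization (e.g. token-id-only serving).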
        if model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=Grok2Tokenizer,
                **tokenizer_kwargs,
            )

        return cls(config, tokenizer)

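    # Synchronous path: parse the chat messages, render them through the
    # chat template, and package the result as a decoder-only prompt.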
    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
        )

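        # Render the parsed conversation to raw prompt text via the
        # tokenizer's chat template.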
        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

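        # Normalize to a decoder-only DictPrompt and attach any multimodal
        # payloads extracted from the messages.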
        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt

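    # Async variant: the same pipeline, but message parsing is awaited so
    # that multimodal content can be resolved without blocking.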
    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt
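For reference, a minimal usage sketch of the new renderer. The VllmConfig construction, the messages/params values, and the tokenizer_kwargs keys below are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch; vllm_config, messages, and chat_params are
# assumed to be constructed elsewhere, and the tokenizer_kwargs keys
# shown here are illustrative only.
renderer = Grok2Renderer.from_config(
    vllm_config,
    {"tokenizer_name": "xai-org/grok-2"},  # assumed kwargs forwarded to cached_get_tokenizer
)
conversation, prompt = renderer.render_messages(messages, chat_params)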