94 lines
2.8 KiB
Python
94 lines
2.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.tokenizers.registry import tokenizer_args_from_config
|
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
|
|
|
from .base import BaseRenderer
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
|
|
# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)
|
|
|
|
|
|
# Built-in renderers shipped with vLLM.
#
# Maps renderer mode -> (module name relative to `vllm.renderers`,
# renderer class name). The relative module names are expanded into fully
# qualified paths when `RENDERER_REGISTRY` is constructed.
_VLLM_RENDERERS = dict(
    deepseek_v32=("deepseek_v32", "DeepseekV32Renderer"),
    hf=("hf", "HfRenderer"),
    grok2=("grok2", "Grok2Renderer"),
    mistral=("mistral", "MistralRenderer"),
    qwen_vl=("qwen_vl", "QwenVLRenderer"),
    terratorch=("terratorch", "TerratorchRenderer"),
)
|
|
|
|
|
|
@dataclass
class RendererRegistry:
    """Registry mapping renderer modes to renderer classes.

    Each entry stores a renderer's module path and class name as strings.
    The class object itself is only resolved, via `resolve_obj_by_qualname`,
    when `load_renderer_cls` / `load_renderer` is called for that mode.
    """

    # Renderer mode -> (renderer module, renderer class)
    renderers: dict[str, tuple[str, str]] = field(default_factory=dict)

    def register(self, renderer_mode: str, module: str, class_name: str) -> None:
        """Register the renderer class to use for `renderer_mode`.

        Args:
            renderer_mode: Key under which the renderer is looked up.
            module: Module path containing the renderer class.
            class_name: Name of the renderer class inside `module`.

        If `renderer_mode` already has an entry, it is replaced and a warning
        naming both the previous and the new renderer is logged.
        """
        if renderer_mode in self.renderers:
            prev_module, prev_class_name = self.renderers[renderer_mode]
            # Report the entry being replaced (not the new one) so the log
            # accurately describes what was overwritten.
            logger.warning(
                "%s.%s is already registered for renderer_mode=%r. "
                "It is overwritten by %s.%s.",
                prev_module,
                prev_class_name,
                renderer_mode,
                module,
                class_name,
            )

        self.renderers[renderer_mode] = (module, class_name)

    def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
        """Resolve and return the renderer class registered for `renderer_mode`.

        Args:
            renderer_mode: Key previously passed to `register` (or present in
                the initial `renderers` mapping).

        Returns:
            The object found at `"{module}.{class_name}"`.

        Raises:
            ValueError: If no renderer is registered for `renderer_mode`. The
                message lists the registered modes to aid debugging.
        """
        try:
            module, class_name = self.renderers[renderer_mode]
        except KeyError:
            # Suppress the KeyError context; the ValueError is the contract.
            raise ValueError(
                f"No renderer registered for {renderer_mode=!r}. "
                f"Registered renderer modes: {sorted(self.renderers)}"
            ) from None

        logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")

        return resolve_obj_by_qualname(f"{module}.{class_name}")

    def load_renderer(
        self,
        renderer_mode: str,
        config: "VllmConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> BaseRenderer:
        """Instantiate the renderer registered for `renderer_mode`.

        Args:
            renderer_mode: Key identifying which renderer class to use.
            config: vLLM configuration forwarded to the renderer's
                `from_config` constructor.
            tokenizer_kwargs: Tokenizer arguments forwarded unchanged to
                `from_config`.

        Returns:
            The result of `renderer_cls.from_config(config, tokenizer_kwargs)`.

        Raises:
            ValueError: If no renderer is registered for `renderer_mode`.
        """
        renderer_cls = self.load_renderer_cls(renderer_mode)

        return renderer_cls.from_config(config, tokenizer_kwargs)
|
|
|
|
|
|
RENDERER_REGISTRY = RendererRegistry(
    renderers={
        renderer_mode: ("vllm.renderers." + relative_module, renderer_class)
        for renderer_mode, (relative_module, renderer_class) in _VLLM_RENDERERS.items()
    }
)
"""The global `RendererRegistry` instance, pre-populated with vLLM's built-in
renderers (module paths qualified under `vllm.renderers`)."""
|
|
|
|
|
|
def renderer_from_config(config: "VllmConfig", **kwargs):
    """Build the renderer selected by ``config``.

    Tokenizer arguments are resolved from ``config.model_config`` (together
    with any extra ``kwargs``) through ``tokenizer_args_from_config``. The
    renderer mode is the resolved tokenizer mode, except when the model config
    has ``tokenizer_mode == "auto"`` and ``model_impl == "terratorch"``, in
    which case the ``"terratorch"`` renderer is used.

    Returns:
        Whatever ``RENDERER_REGISTRY.load_renderer`` builds for the chosen
        mode, with the resolved tokenizer kwargs plus ``"tokenizer_name"``.
    """
    model_cfg = config.model_config

    # Positional tokenizer args returned by the resolver are not forwarded
    # to the renderer.
    resolved_mode, resolved_name, _unused_args, tok_kwargs = (
        tokenizer_args_from_config(model_cfg, **kwargs)
    )

    # This reads the *configured* tokenizer_mode, not `resolved_mode`.
    # NOTE(review): presumably because resolution may replace "auto" with a
    # concrete mode — confirm before changing.
    wants_terratorch = (
        model_cfg.tokenizer_mode == "auto"
        and model_cfg.model_impl == "terratorch"
    )
    mode = "terratorch" if wants_terratorch else resolved_mode

    renderer_tok_kwargs = {**tok_kwargs, "tokenizer_name": resolved_name}
    return RENDERER_REGISTRY.load_renderer(
        mode,
        config,
        tokenizer_kwargs=renderer_tok_kwargs,
    )
|