update
This commit is contained in:
92
vllm/renderers/registry.py
Normal file
92
vllm/renderers/registry.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
from .base import BaseRenderer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_VLLM_RENDERERS = {
|
||||
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
|
||||
"hf": ("hf", "HfRenderer"),
|
||||
"grok2": ("grok2", "Grok2Renderer"),
|
||||
"mistral": ("mistral", "MistralRenderer"),
|
||||
"terratorch": ("terratorch", "TerratorchRenderer"),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RendererRegistry:
|
||||
# Renderer mode -> (renderer module, renderer class)
|
||||
renderers: dict[str, tuple[str, str]] = field(default_factory=dict)
|
||||
|
||||
def register(self, renderer_mode: str, module: str, class_name: str) -> None:
|
||||
if renderer_mode in self.renderers:
|
||||
logger.warning(
|
||||
"%s.%s is already registered for renderer_mode=%r. "
|
||||
"It is overwritten by the new one.",
|
||||
module,
|
||||
class_name,
|
||||
renderer_mode,
|
||||
)
|
||||
|
||||
self.renderers[renderer_mode] = (module, class_name)
|
||||
|
||||
return None
|
||||
|
||||
def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
|
||||
if renderer_mode not in self.renderers:
|
||||
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
|
||||
|
||||
module, class_name = self.renderers[renderer_mode]
|
||||
logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")
|
||||
|
||||
return resolve_obj_by_qualname(f"{module}.{class_name}")
|
||||
|
||||
def load_renderer(
|
||||
self,
|
||||
renderer_mode: str,
|
||||
config: "VllmConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> BaseRenderer:
|
||||
renderer_cls = self.load_renderer_cls(renderer_mode)
|
||||
return renderer_cls.from_config(config, tokenizer_kwargs)
|
||||
|
||||
|
||||
RENDERER_REGISTRY = RendererRegistry(
|
||||
{
|
||||
mode: (f"vllm.renderers.{mod_relname}", cls_name)
|
||||
for mode, (mod_relname, cls_name) in _VLLM_RENDERERS.items()
|
||||
}
|
||||
)
|
||||
"""The global `RendererRegistry` instance."""
|
||||
|
||||
|
||||
def renderer_from_config(config: "VllmConfig", **kwargs):
|
||||
model_config = config.model_config
|
||||
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
|
||||
model_config, **kwargs
|
||||
)
|
||||
|
||||
if (
|
||||
model_config.tokenizer_mode == "auto"
|
||||
and model_config.model_impl == "terratorch"
|
||||
):
|
||||
renderer_mode = "terratorch"
|
||||
else:
|
||||
renderer_mode = tokenizer_mode
|
||||
|
||||
return RENDERER_REGISTRY.load_renderer(
|
||||
renderer_mode,
|
||||
config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
Reference in New Issue
Block a user