Files
enginex-mlu590-vllm/vllm_mlu/model_executor/models/registry.py
2026-04-24 09:58:03 +08:00

81 lines
2.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Type, Union
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.models.registry import (
_LazyRegisteredModel, _RegisteredModel, _ModelRegistry)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
# Module-level logger; used below to report when an already-registered
# model architecture is being overwritten.
logger = init_logger(__name__)
def vllm__model_executor__models__registry___ModelRegistry__register_model(
    self,
    model_arch: str,
    model_cls: Union[type[nn.Module], str],
) -> None:
    """
    Register an external model to be used in vLLM.

    This function is installed in place of `_ModelRegistry.register_model`
    (see the `apply_hijack` call at the bottom of this module). The
    vllm_mlu-specific part is the log level used when an already
    registered architecture is overwritten.

    `model_cls` can be either:

    - A [`torch.nn.Module`][] class directly referencing the model.
    - A string in the format `<module>:<class>` which can be used to
      lazily import the model. This is useful to avoid initializing CUDA
      when importing the model and thus the related error
      `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

    Args:
        model_arch: Architecture name used as the key in `self.models`.
        model_cls: Model class, or a `<module>:<class>` string.

    Raises:
        TypeError: If `model_arch` is not a string, or if `model_cls` is
            neither a string nor an `nn.Module` subclass.
        ValueError: If `model_cls` is a string that does not consist of
            exactly two `:`-separated parts.
    """
    if not isinstance(model_arch, str):
        msg = f"`model_arch` should be a string, not a {type(model_arch)}"
        raise TypeError(msg)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change mlu models register log level
    '''
    if model_arch in self.models:
        # Overwrites by a lazily-referenced class whose `<module>:<class>`
        # string contains "MLU" are logged at debug level; every other
        # overwrite keeps the warning level.
        is_mlu_override = isinstance(model_cls, str) and "MLU" in model_cls
        log = logger.debug if is_mlu_override else logger.warning
        log(
            "Model architecture %s is already registered, and will be "
            "overwritten by the new model class %s.", model_arch,
            model_cls)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if isinstance(model_cls, str):
        split_str = model_cls.split(":")
        if len(split_str) != 2:
            msg = "Expected a string in the format `<module>:<class>`"
            raise ValueError(msg)
        model = _LazyRegisteredModel(*split_str)
    elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
        model = _RegisteredModel.from_model_cls(model_cls)
    else:
        # Report the type of the offending argument (`model_cls`), not of
        # `model_arch`, which is already known to be a string here.
        msg = ("`model_cls` should be a string or PyTorch model class, "
               f"not a {type(model_cls)}")
        raise TypeError(msg)

    self.models[model_arch] = model
# Register the function defined above as the hijack for
# `_ModelRegistry.register_model`. The arguments are the owning class, the
# original method, and the replacement function.
# NOTE(review): the exact patching semantics live in
# `vllm_mlu.mlu_hijack_utils.MluHijackObject` — presumably the original
# attribute is replaced on the class at import time; confirm there.
MluHijackObject.apply_hijack(
_ModelRegistry,
_ModelRegistry.register_model,
vllm__model_executor__models__registry___ModelRegistry__register_model
)