81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from typing import Type, Union
|
|
|
|
import torch.nn as nn
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.models.registry import (
|
|
_LazyRegisteredModel, _RegisteredModel, _ModelRegistry)
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def vllm__model_executor__models__registry___ModelRegistry__register_model(
    self,
    model_arch: str,
    model_cls: Union[type[nn.Module], str],
) -> None:
    """
    Register an external model to be used in vLLM.

    `model_cls` can be either:

    - A [`torch.nn.Module`][] class directly referencing the model.
    - A string in the format `<module>:<class>` which can be used to
      lazily import the model. This is useful to avoid initializing CUDA
      when importing the model and thus the related error
      `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

    MLU change: when `model_arch` is already present in `self.models`, the
    overwrite is logged at DEBUG level if `model_cls` is a string containing
    "MLU", and at WARNING level otherwise.

    Args:
        model_arch: Architecture name; used as the key into `self.models`.
        model_cls: Model class or `<module>:<class>` string (see above).

    Raises:
        TypeError: If `model_arch` is not a string, or if `model_cls` is
            neither a string nor an `nn.Module` subclass.
        ValueError: If `model_cls` is a string that does not split into
            exactly two parts on ":".
    """
    if not isinstance(model_arch, str):
        msg = f"`model_arch` should be a string, not a {type(model_arch)}"
        raise TypeError(msg)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change mlu models register log level
    '''
    if model_arch in self.models:
        # Overwrites whose lazy-import path mentions "MLU" (presumably
        # vllm_mlu's own replacements of upstream models) are expected, so
        # they are logged quietly; any other overwrite still warns.
        is_mlu_override = isinstance(model_cls, str) and "MLU" in model_cls
        log = logger.debug if is_mlu_override else logger.warning
        log(
            "Model architecture %s is already registered, and will be "
            "overwritten by the new model class %s.", model_arch,
            model_cls)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if isinstance(model_cls, str):
        split_str = model_cls.split(":")
        if len(split_str) != 2:
            msg = "Expected a string in the format `<module>:<class>`"
            raise ValueError(msg)

        model = _LazyRegisteredModel(*split_str)
    elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
        model = _RegisteredModel.from_model_cls(model_cls)
    else:
        # Report the type of the rejected `model_cls`; `model_arch` is known
        # to be a `str` at this point, so its type would be uninformative.
        msg = ("`model_cls` should be a string or PyTorch model class, "
               f"not a {type(model_cls)}")
        raise TypeError(msg)

    self.models[model_arch] = model
|
|
|
|
|
|
# Runs at import time of this module: hands `_ModelRegistry`, its original
# `register_model`, and the MLU replacement function to
# `MluHijackObject.apply_hijack`. Presumably this patches `_ModelRegistry` so
# the replacement is used for all registrations — confirm in
# vllm_mlu.mlu_hijack_utils.
MluHijackObject.apply_hijack(
    _ModelRegistry,
    _ModelRegistry.register_model,
    vllm__model_executor__models__registry___ModelRegistry__register_model
)