# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Type, Union

import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.models.registry import (
    _LazyRegisteredModel, _RegisteredModel, _ModelRegistry)

from vllm_mlu.mlu_hijack_utils import MluHijackObject

logger = init_logger(__name__)


def vllm__model_executor__models__registry___ModelRegistry__register_model(
    self,
    model_arch: str,
    model_cls: Union[type[nn.Module], str],
) -> None:
    """
    Register an external model to be used in vLLM.

    This function is handed to `MluHijackObject.apply_hijack` at the bottom
    of this module together with `_ModelRegistry.register_model`. The only
    MLU-specific part is the log level used when an already-registered
    architecture is overwritten.

    `model_cls` can be either:

    - A [`torch.nn.Module`][] class directly referencing the model.
    - A string in the format `<module>:<class>` which can be used to
      lazily import the model. This is useful to avoid initializing CUDA
      when importing the model and thus the related error
      `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

    Args:
        model_arch: Architecture name used as the key in `self.models`.
        model_cls: The model class, or a `<module>:<class>` string.

    Raises:
        TypeError: If `model_arch` is not a string, or if `model_cls` is
            neither a string nor a subclass of `torch.nn.Module`.
        ValueError: If `model_cls` is a string that does not split into
            exactly two parts on ":".
    """
    if not isinstance(model_arch, str):
        msg = f"`model_arch` should be a string, not a {type(model_arch)}"
        raise TypeError(msg)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change mlu models register log level
    '''
    if model_arch in self.models:
        # Overwrites requested through a string containing "MLU" are logged
        # at debug level; every other overwrite keeps the warning level.
        # NOTE(review): presumably MLU overrides are expected and would
        # otherwise be noisy -- confirm.
        is_mlu_override = isinstance(model_cls, str) and "MLU" in model_cls
        log_overwrite = logger.debug if is_mlu_override else logger.warning
        log_overwrite(
            "Model architecture %s is already registered, and will be "
            "overwritten by the new model class %s.", model_arch, model_cls)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if isinstance(model_cls, str):
        split_str = model_cls.split(":")
        if len(split_str) != 2:
            msg = "Expected a string in the format `<module>:<class>`"
            raise ValueError(msg)

        model = _LazyRegisteredModel(*split_str)
    elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
        model = _RegisteredModel.from_model_cls(model_cls)
    else:
        # Report the type of the offending argument (`model_cls`);
        # `model_arch` is already known to be a `str` at this point.
        msg = ("`model_cls` should be a string or PyTorch model class, "
               f"not a {type(model_cls)}")
        raise TypeError(msg)

    self.models[model_arch] = model


MluHijackObject.apply_hijack(
    _ModelRegistry,
    _ModelRegistry.register_model,
    vllm__model_executor__models__registry___ModelRegistry__register_model
)