[Model] Support DeepSeek-V4
This commit is contained in:
81
vllm_mlu/model_executor/models/registry.py
Normal file
81
vllm_mlu/model_executor/models/registry.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Type, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.registry import (
|
||||
_LazyRegisteredModel, _RegisteredModel, _ModelRegistry)
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__model_executor__models__registry___ModelRegistry__register_model(
    self,
    model_arch: str,
    model_cls: Union[type[nn.Module], str],
) -> None:
    """
    Register an external model to be used in vLLM.

    `model_cls` can be either:

    - A [`torch.nn.Module`][] class directly referencing the model.
    - A string in the format `<module>:<class>` which can be used to
      lazily import the model. This is useful to avoid initializing CUDA
      when importing the model and thus the related error
      `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

    Raises:
        TypeError: if `model_arch` is not a string, or `model_cls` is
            neither a string nor an `nn.Module` subclass.
        ValueError: if a string `model_cls` is not in the
            `<module>:<class>` format.
    """
    if not isinstance(model_arch, str):
        msg = f"`model_arch` should be a string, not a {type(model_arch)}"
        raise TypeError(msg)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: change mlu models register log level
    '''
    if model_arch in self.models:
        # MLU model classes deliberately overwrite the upstream entries at
        # startup, so log those at debug level instead of warning to avoid
        # alarming-but-expected noise; everything else keeps the warning.
        if isinstance(model_cls, str) and "MLU" in model_cls:
            logger.debug(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)
        else:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if isinstance(model_cls, str):
        split_str = model_cls.split(":")
        if len(split_str) != 2:
            msg = "Expected a string in the format `<module>:<class>`"
            raise ValueError(msg)

        model = _LazyRegisteredModel(*split_str)
    elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
        model = _RegisteredModel.from_model_cls(model_cls)
    else:
        # Report the type of `model_cls` — the offending argument — not
        # `model_arch` (which is always `str` by this point; the original
        # message interpolated the wrong variable).
        msg = ("`model_cls` should be a string or PyTorch model class, "
               f"not a {type(model_cls)}")
        raise TypeError(msg)

    self.models[model_arch] = model
|
||||
|
||||
|
||||
# Module-level side effect: patch vLLM's _ModelRegistry.register_model with
# the MLU-aware implementation above (which demotes the "already registered"
# log to debug for MLU model classes). Runs on import of this module.
MluHijackObject.apply_hijack(
    _ModelRegistry,
    _ModelRegistry.register_model,
    vllm__model_executor__models__registry___ModelRegistry__register_model
)
|
||||
Reference in New Issue
Block a user