forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
495
vllm-v0.6.2/vllm/model_executor/models/registry.py
Normal file
495
vllm-v0.6.2/vllm/model_executor/models/registry.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
Whenever you add an architecture to this page, please also update
|
||||
`tests/models/registry.py` with example HuggingFace models for it.
|
||||
"""
|
||||
import importlib
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
|
||||
TypeVar, Union)
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_multimodal, supports_pp)
|
||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# yapf: disable
|
||||
_TEXT_GENERATION_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AquilaModel": ("llama", "LlamaForCausalLM"),
|
||||
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
||||
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
||||
# baichuan-7b, upper case 'C' in the class name
|
||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
|
||||
# baichuan-13b, lower case 'c' in the class name
|
||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
|
||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||
# ChatGLMModel supports multimodal
|
||||
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
|
||||
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
||||
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
|
||||
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
|
||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
||||
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
|
||||
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
|
||||
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
|
||||
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
|
||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
||||
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
|
||||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
|
||||
# transformers's mpt class has lower case
|
||||
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
|
||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
|
||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
||||
# QWenLMHeadModel supports multimodal
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
||||
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
||||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||
# [Encoder-decoder]
|
||||
"BartModel": ("bart", "BartForConditionalGeneration"),
|
||||
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
|
||||
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
|
||||
"HunYuanForCausalLM": ("hunyuan", "HunYuanForCausalLM"),
|
||||
}
|
||||
|
||||
_EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
|
||||
**{
|
||||
# Multiple models share the same architecture, so we include them all
|
||||
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
|
||||
if arch == "LlamaForCausalLM"
|
||||
},
|
||||
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
|
||||
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
# [Decoder-only]
|
||||
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
|
||||
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
|
||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
||||
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
||||
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
|
||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
|
||||
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
|
||||
"LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
|
||||
"MiniCPMV": ("minicpmv", "MiniCPMV"),
|
||||
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
|
||||
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
|
||||
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
|
||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
|
||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||
# [Encoder-decoder]
|
||||
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
|
||||
}
|
||||
|
||||
_SPECULATIVE_DECODING_MODELS = {
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
_VLLM_MODELS = {
|
||||
**_TEXT_GENERATION_MODELS,
|
||||
**_EMBEDDING_MODELS,
|
||||
**_MULTIMODAL_MODELS,
|
||||
**_SPECULATIVE_DECODING_MODELS,
|
||||
}
|
||||
|
||||
# Models not supported by ROCm.
|
||||
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
||||
|
||||
# Models partially supported by ROCm.
|
||||
# Architecture -> Reason.
|
||||
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
|
||||
"Triton flash attention. For half-precision SWA support, "
|
||||
"please use CK flash attention by setting "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
"Qwen2ForCausalLM":
|
||||
_ROCM_SWA_REASON,
|
||||
"MistralForCausalLM":
|
||||
_ROCM_SWA_REASON,
|
||||
"MixtralForCausalLM":
|
||||
_ROCM_SWA_REASON,
|
||||
"PaliGemmaForConditionalGeneration":
|
||||
("ROCm flash attention does not yet "
|
||||
"fully support 32-bit precision on PaliGemma"),
|
||||
"Phi3VForCausalLM":
|
||||
("ROCm Triton flash attention may run into compilation errors due to "
|
||||
"excessive use of shared memory. If this happens, disable Triton FA "
|
||||
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ModelInfo:
|
||||
is_text_generation_model: bool
|
||||
is_embedding_model: bool
|
||||
supports_multimodal: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
is_attention_free: bool
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
return _ModelInfo(
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_embedding_model=is_embedding_model(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
)
|
||||
|
||||
|
||||
class _BaseRegisteredModel(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def inspect_model_cls(self) -> _ModelInfo:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_model_cls(self) -> Type[nn.Module]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RegisteredModel(_BaseRegisteredModel):
|
||||
"""
|
||||
Represents a model that has already been imported in the main process.
|
||||
"""
|
||||
|
||||
interfaces: _ModelInfo
|
||||
model_cls: Type[nn.Module]
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model_cls: Type[nn.Module]):
|
||||
return _RegisteredModel(
|
||||
interfaces=_ModelInfo.from_model_cls(model_cls),
|
||||
model_cls=model_cls,
|
||||
)
|
||||
|
||||
def inspect_model_cls(self) -> _ModelInfo:
|
||||
return self.interfaces
|
||||
|
||||
def load_model_cls(self) -> Type[nn.Module]:
|
||||
return self.model_cls
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
"""
|
||||
Represents a model that has not been imported in the main process.
|
||||
"""
|
||||
module_name: str
|
||||
class_name: str
|
||||
|
||||
# Performed in another process to avoid initializing CUDA
|
||||
def inspect_model_cls(self) -> _ModelInfo:
|
||||
return _run_in_subprocess(
|
||||
lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
|
||||
|
||||
def load_model_cls(self) -> Type[nn.Module]:
|
||||
mod = importlib.import_module(self.module_name)
|
||||
return getattr(mod, self.class_name)
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _try_load_model_cls(
|
||||
model_arch: str,
|
||||
model: _BaseRegisteredModel,
|
||||
) -> Optional[Type[nn.Module]]:
|
||||
if current_platform.is_rocm():
|
||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(f"Model architecture '{model_arch}' is not "
|
||||
"supported by ROCm for now.")
|
||||
|
||||
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
|
||||
logger.warning(
|
||||
"Model architecture '%s' is partially "
|
||||
"supported by ROCm: %s", model_arch, msg)
|
||||
|
||||
try:
|
||||
return model.load_model_cls()
|
||||
except Exception:
|
||||
logger.exception("Error in loading model architecture '%s'",
|
||||
model_arch)
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _try_inspect_model_cls(
|
||||
model_arch: str,
|
||||
model: _BaseRegisteredModel,
|
||||
) -> Optional[_ModelInfo]:
|
||||
try:
|
||||
return model.inspect_model_cls()
|
||||
except Exception:
|
||||
logger.exception("Error in inspecting model architecture '%s'",
|
||||
model_arch)
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ModelRegistry:
|
||||
# Keyed by model_arch
|
||||
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
|
||||
|
||||
def get_supported_archs(self) -> AbstractSet[str]:
|
||||
return self.models.keys()
|
||||
|
||||
def register_model(
|
||||
self,
|
||||
model_arch: str,
|
||||
model_cls: Union[Type[nn.Module], str],
|
||||
) -> None:
|
||||
"""
|
||||
Register an external model to be used in vLLM.
|
||||
|
||||
:code:`model_cls` can be either:
|
||||
|
||||
- A :class:`torch.nn.Module` class directly referencing the model.
|
||||
- A string in the format :code:`<module>:<class>` which can be used to
|
||||
lazily import the model. This is useful to avoid initializing CUDA
|
||||
when importing the model and thus the related error
|
||||
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
|
||||
"""
|
||||
if model_arch in self.models:
|
||||
logger.warning(
|
||||
"Model architecture %s is already registered, and will be "
|
||||
"overwritten by the new model class %s.", model_arch,
|
||||
model_cls)
|
||||
|
||||
if isinstance(model_cls, str):
|
||||
split_str = model_cls.split(":")
|
||||
if len(split_str) != 2:
|
||||
msg = "Expected a string in the format `<module>:<class>`"
|
||||
raise ValueError(msg)
|
||||
|
||||
model = _LazyRegisteredModel(*split_str)
|
||||
else:
|
||||
model = _RegisteredModel.from_model_cls(model_cls)
|
||||
|
||||
self.models[model_arch] = model
|
||||
|
||||
def _raise_for_unsupported(self, architectures: List[str]):
|
||||
all_supported_archs = self.get_supported_archs()
|
||||
|
||||
if any(arch in all_supported_archs for arch in architectures):
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} failed "
|
||||
"to be inspected. Please check the logs for more details.")
|
||||
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {all_supported_archs}")
|
||||
|
||||
def _try_load_model_cls(self,
|
||||
model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch not in self.models:
|
||||
return None
|
||||
|
||||
return _try_load_model_cls(model_arch, self.models[model_arch])
|
||||
|
||||
def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
|
||||
if model_arch not in self.models:
|
||||
return None
|
||||
|
||||
return _try_inspect_model_cls(model_arch, self.models[model_arch])
|
||||
|
||||
def _normalize_archs(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> List[str]:
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
if not architectures:
|
||||
logger.warning("No model architectures are specified")
|
||||
|
||||
return architectures
|
||||
|
||||
def inspect_model_cls(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> _ModelInfo:
|
||||
architectures = self._normalize_archs(architectures)
|
||||
|
||||
for arch in architectures:
|
||||
model_info = self._try_inspect_model_cls(arch)
|
||||
if model_info is not None:
|
||||
return model_info
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
def resolve_model_cls(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> Tuple[Type[nn.Module], str]:
|
||||
architectures = self._normalize_archs(architectures)
|
||||
|
||||
for arch in architectures:
|
||||
model_cls = self._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return (model_cls, arch)
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
def is_text_generation_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_text_generation_model
|
||||
|
||||
def is_embedding_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_embedding_model
|
||||
|
||||
def is_multimodal_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_multimodal
|
||||
|
||||
def is_pp_supported_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_pp
|
||||
|
||||
def model_has_inner_state(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).has_inner_state
|
||||
|
||||
def is_attention_free_model(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_attention_free
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry({
|
||||
model_arch: _LazyRegisteredModel(
|
||||
module_name=f"vllm.model_executor.models.{mod_relname}",
|
||||
class_name=cls_name,
|
||||
)
|
||||
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
|
||||
})
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
|
||||
# NOTE: We use a temporary directory instead of a temporary file to avoid
|
||||
# issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
output_filepath = os.path.join(tempdir, "registry_output.tmp")
|
||||
|
||||
# `cloudpickle` allows pickling lambda functions directly
|
||||
input_bytes = cloudpickle.dumps((fn, output_filepath))
|
||||
|
||||
# cannot use `sys.executable __file__` here because the script
|
||||
# contains relative imports
|
||||
returned = subprocess.run(
|
||||
[sys.executable, "-m", "vllm.model_executor.models.registry"],
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(f"Error raised in subprocess:\n"
|
||||
f"{returned.stderr.decode()}") from e
|
||||
|
||||
with open(output_filepath, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def _run() -> None:
|
||||
# Setup plugins
|
||||
from vllm.plugins import load_general_plugins
|
||||
load_general_plugins()
|
||||
|
||||
fn, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
|
||||
result = fn()
|
||||
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_run()
|
||||
Reference in New Issue
Block a user