Sync from v0.13
This commit is contained in:
35
tests/model_executor/model_loader/test_registry.py
Normal file
35
tests/model_executor/model_loader/test_registry.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader import get_model_loader, register_model_loader
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
|
||||
|
||||
@register_model_loader("custom_load_format")
|
||||
class CustomModelLoader(BaseModelLoader):
|
||||
def __init__(self, load_config: LoadConfig) -> None:
|
||||
super().__init__(load_config)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
pass
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_register_model_loader():
|
||||
load_config = LoadConfig(load_format="custom_load_format")
|
||||
assert isinstance(get_model_loader(load_config), CustomModelLoader)
|
||||
|
||||
|
||||
def test_invalid_model_loader():
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@register_model_loader("invalid_load_format")
|
||||
class InValidModelLoader:
|
||||
pass
|
||||
Reference in New Issue
Block a user