Sync from v0.13
This commit is contained in:
13
tests/plugins/vllm_add_dummy_model/setup.py
Normal file
13
tests/plugins/vllm_add_dummy_model/setup.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
# Packaging stub for the dummy-model plugin used by vLLM's plugin tests.
# The entry point below is what vLLM's plugin loader discovers at start-up:
# it maps the "register_dummy_model" name to the package-level register()
# callable under the "vllm.general_plugins" group.
_PLUGIN_ENTRY_POINTS = {
    "vllm.general_plugins": [
        "register_dummy_model = vllm_add_dummy_model:register"
    ]
}

setup(
    name="vllm_add_dummy_model",
    version="0.1",
    packages=["vllm_add_dummy_model"],
    entry_points=_PLUGIN_ENTRY_POINTS,
)
@@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import ModelRegistry
|
||||
|
||||
|
||||
def register():
    """Plugin entry point: register the dummy test architectures with vLLM.

    Exercises both registration styles supported by ``ModelRegistry``:
    passing the model class object directly, and passing a lazy
    ``"module:ClassName"`` reference string that is resolved on first use.
    Registration is skipped for any architecture that is already known,
    so calling this repeatedly is harmless.
    """
    # Test directly passing the model (imported lazily so the plugin
    # entry point itself stays cheap to load).
    from .my_opt import MyOPTForCausalLM

    registrations = [
        # Direct class-object registration.
        ("MyOPTForCausalLM", MyOPTForCausalLM),
        # Test passing lazy model references (resolved only when needed).
        (
            "MyGemma2Embedding",
            "vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
        ),
        ("MyLlava", "vllm_add_dummy_model.my_llava:MyLlava"),
    ]

    for arch, target in registrations:
        if arch not in ModelRegistry.get_supported_archs():
            ModelRegistry.register_model(arch, target)
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class MyGemma2Embedding(nn.Module):
    """Dummy Gemma2-based embedding model for plugin-registration tests.

    Wraps ``Gemma2Model`` and exposes the pooling interface vLLM expects
    from an embedding model. ``forward`` deliberately emits all-zero
    embeddings so that test output is trivially predictable.
    """

    # Tells vLLM this architecture is a pooling (embedding) model.
    is_pooling_model = True

    # Strip the HF checkpoint's "model." prefix so weight names line up
    # with the wrapped backbone's parameter names.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.model = Gemma2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # One pooler per supported task, dispatched by task name.
        self.pooler = DispatchPooler(
            {
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
        )

        # Delegate pipeline-parallel placeholder construction to the backbone.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        backbone_out = self.model(
            input_ids,
            positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        # Mid-pipeline ranks hand their partial state straight through.
        if isinstance(backbone_out, IntermediateTensors):
            return backbone_out

        # Return all-zero embeddings (dummy, test-only behavior).
        return torch.zeros_like(backbone_out)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        remapped = self.hf_to_vllm_mapper.apply(weights)
        # This embedding model has no LM head, so drop those tensors.
        kept = (item for item in remapped if not item[0].startswith("lm_head."))
        return self.model.load_weights(kept)
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.llava import (
|
||||
LlavaDummyInputsBuilder,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaMultiModalProcessor,
|
||||
LlavaProcessingInfo,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
    LlavaMultiModalProcessor,
    info=LlavaProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MyLlava(LlavaForConditionalGeneration):
    """Llava variant rigged so greedy decoding always picks token id 0.

    Used by the plugin tests: a deterministic prediction makes generated
    output trivially checkable.
    """

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        logits = super().compute_logits(hidden_states)
        if logits is None:
            return None
        # Flatten the distribution in place, then make token 0 the sole
        # winner — this dummy model always predicts the first token.
        logits.zero_()
        logits[:, 0] += 1.0
        return logits
||||
@@ -0,0 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
|
||||
|
||||
class MyOPTForCausalLM(OPTForCausalLM):
    """OPT variant rigged so greedy decoding always picks token id 0.

    Used by the plugin tests: a deterministic prediction makes generated
    output trivially checkable.
    """

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        logits = super().compute_logits(hidden_states)
        if logits is None:
            return None
        # Flatten the distribution in place, then make token 0 the sole
        # winner — this dummy model always predicts the first token.
        logits.zero_()
        logits[:, 0] += 1.0
        return logits
Reference in New Issue
Block a user