Sync from v0.13
This commit is contained in:
18
tests/plugins/vllm_add_dummy_platform/setup.py
Normal file
18
tests/plugins/vllm_add_dummy_platform/setup.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
# Packaging stub for the dummy-platform test plugin. The entry points below
# are what vLLM's plugin loader discovers at startup; the dotted targets
# resolve to callables in the vllm_add_dummy_platform package.
setup(
    name="vllm_add_dummy_platform",
    version="0.1",
    packages=["vllm_add_dummy_platform"],
    entry_points={
        # Out-of-tree platform hook: returns the DummyPlatform class path.
        "vllm.platform_plugins": [
            "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
        ],
        # General plugin hook: registers the dummy custom ops on load.
        "vllm.general_plugins": [
            "dummy_custom_ops = vllm_add_dummy_platform:register_ops"
        ],
    },
)
|
||||
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
def dummy_platform_plugin() -> str | None:
|
||||
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
|
||||
|
||||
|
||||
def register_ops():
    """General-plugin entry point.

    Importing the module is the registration: its top-level code registers
    the dummy custom ops as a side effect.
    """
    import vllm_add_dummy_platform.dummy_custom_ops  # noqa: F401
|
||||
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend
|
||||
|
||||
|
||||
class DummyAttentionBackend(PlaceholderAttentionBackend):
    """Attention backend stub served by the dummy platform plugin."""

    @staticmethod
    def get_name() -> str:
        # Identifier this backend reports to vLLM's backend selection.
        return "Dummy_Backend"
|
||||
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
# Register CustomRotaryEmbedding to CustomOP.
|
||||
# Register CustomRotaryEmbedding to CustomOP.
@RotaryEmbedding.register_oot
class DummyRotaryEmbedding(RotaryEmbedding):
    """Original rotary positional embedding."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Marker attribute so tests can detect that this OOT
        # implementation was picked up instead of the built-in one.
        self.addition_config = True

    def forward_oot(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        # Delegate straight to the base implementation; behavior is
        # intentionally identical to the in-tree op.
        return super().forward_oot(*args, **kwargs)
|
||||
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.platforms.interface import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
class DummyPlatform(Platform):
    """Minimal out-of-tree platform used to exercise vLLM's plugin machinery."""

    _enum = PlatformEnum.OOT
    device_name = "DummyDevice"
    # Torch "PrivateUse1" backend slot, the conventional hook for
    # out-of-tree device integrations.
    device_type: str = "privateuseone"
    dispatch_key: str = "PrivateUse1"

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Force custom-op implementations so the dummy ops get exercised."""
        vllm_config.compilation_config.custom_ops = ["all"]

    def get_attn_backend_cls(
        self,
        backend_name,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
    ):
        """Always resolve to this plugin's placeholder attention backend.

        All selector arguments are ignored; the dotted path is imported
        lazily by vLLM.
        """
        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
|
||||
Reference in New Issue
Block a user