################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from typing import TYPE_CHECKING, Any, Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms.interface import Platform, PlatformEnum, _Backend
from vllm.utils import FlexibleArgumentParser

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)

if not envs.VLLM_USE_V1:
    # SUPA support on the legacy V0 engine is limited; steer users to V1.
    _v0_warning = (
        "SUPAPlatform is only supported limited set of tasks in vLLM V0. "
        "Please check if `VLLM_USE_V1=1` is set before launch.")
    logger.warning(_v0_warning)


class SUPAPlatform(Platform):
    """vLLM out-of-tree platform plugin for Shanghai Biren SUPA devices.

    Registers the ``supa`` device type (PyTorch ``PrivateUse1`` dispatch
    key) and routes vLLM to the SUPA-specific workers, attention backends,
    communicator and graph-capture wrappers provided by ``vllm_br``.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "supa"
    device_type: str = "supa"
    dispatch_key: str = "PrivateUse1"

    ray_device_key: str = "GPU"
    dist_backend: str = "sccl"
    device_control_env_var: str = "SUPA_VISIBLE_DEVICES"
    # Environment variables that must be propagated from the driver to the
    # workers by the Ray distributed executor.
    # NOTE: Related code: vllm/executor/ray_distributed_executor.py::_init_workers_ray
    additional_env_vars: list[str] = [
        "SUPA_WEIGHT_TYPE", "VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE"
    ]

    supported_quantization: list[str] = []

    # Process-wide graph memory pool, created lazily by
    # get_global_graph_pool().
    _global_graph_pool: Optional[Any] = None

    def is_supa(self) -> bool:
        return True

    def is_sleep_mode_available(self) -> bool:
        # Sleep mode (weight offload / restore) is not implemented on SUPA.
        return False

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.supa.set_device(device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> None:
        # SUPA does not expose a CUDA-style (major, minor) compute capability.
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.supa.get_device_name(device_id)

    @classmethod
    def get_memory_stats(cls, device: torch.device, info_key: str) -> int:
        # NOTE(review): every other device accessor in this class goes through
        # torch.supa; confirm torch.cuda.memory_stats is really intended here
        # (it may be aliased for PrivateUse1 backends) or should be
        # torch.supa.memory_stats.
        return torch.cuda.memory_stats(device)[info_key]

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        # Async output processing only pays off together with SUPA graph
        # capture, which enforce-eager disables.
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable SUPA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def inference_mode(cls) -> torch.inference_mode:
        return torch.inference_mode()

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """Restrict CLI argument choices to values supported by SUPA."""
        if parser is None:
            return
        for action in parser._actions:
            opts = action.option_strings
            if not opts:
                continue
            if opts[0] == "--block-size":
                # SUPA kernels only support a KV-cache block size of 128.
                action.choices = [128]
            elif opts[0] == "--device":
                action.choices = ["auto", "supa"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.supa.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust the vLLM config in place for SUPA.

        Forces the SUPA weight type, disables DS-MLA variants, selects the
        SUPA worker class, pins the KV-cache block size to 128, disables
        Inductor compilation, and rejects unsupported quantization schemes.
        """
        # Local import: vllm_br may not be importable at module-import time.
        import vllm_br.envs as biren_envs

        vllm_config.model_config.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE
        logger.warning("update model with weight type %s for supa Matmul ops",
                       vllm_config.model_config.weight_type)
        vllm_config.model_config.use_ds_mla = False
        vllm_config.model_config.use_ds_mla_sparse = False

        parallel_config = vllm_config.parallel_config

        if parallel_config and parallel_config.worker_cls == "auto":
            # TODO: reject multi-step scheduling once it can be detected here;
            # it is not supported on SUPA.
            if vllm_config.speculative_config and not envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Speculative decoding is not supported on vLLM V0.")
            # Speculative decoding (when allowed above) always runs on V1,
            # so a single V1/V0 split covers both cases.
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm_br.v1.worker.worker.SUPAWorker"
            else:
                parallel_config.worker_cls = \
                    "vllm_br.v0.worker.worker.SUPAWorker"

        cache_config = vllm_config.cache_config
        if cache_config:
            if cache_config.block_size is None:
                cache_config.block_size = 128
            if cache_config.enable_prefix_caching and cache_config.block_size != 128:
                cache_config.block_size = 128
                logger.warning(
                    "If prefix caching is enabled, block size must be set to 128, and the block size has been set to 128"
                )  # noqa: E501

        compilation_config = vllm_config.compilation_config
        # NOTE: always disable inductor for SUPA
        if compilation_config.use_inductor:
            logger.warning("Inductor is not supported for SUPA. Disabling it.")
            compilation_config.use_inductor = False
        logger.warning(
            "Use %s compilation backend.",
            'graph' if compilation_config.use_cudagraph else 'eager')
        if compilation_config.use_cudagraph:
            compilation_config.level = 0  # CompilationLevel.NO_COMPILATION

        if vllm_config.model_config.quantization is not None:
            UNSUPPORTED_QUANTIZATION = ['gptq_marlin']
            if vllm_config.model_config.quantization in UNSUPPORTED_QUANTIZATION:
                raise NotImplementedError(
                    f"Unsupported quantization {vllm_config.model_config.quantization} for SUPAPlatform."
                )

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        # Reset the peak counter first so subsequent calls report usage
        # accumulated since this call.
        torch.supa.reset_peak_memory_stats(device)
        return torch.supa.max_memory_allocated(device)

    # FIXME: this matters when serving vit models, need adaptation
    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink, use_sparse):
        """Return the fully-qualified attention backend class path for SUPA."""
        if not use_v1:
            return "vllm_br.v0.attention.backends.attention_v0.SUPAFlashAttentionBackend"  # noqa: E501
        if use_mla:
            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return (
                    "vllm_br.v1.attention.backends.mla.flashmla_sparse."
                    "SupaFlashMLASparseBackend")
            return "vllm_br.v1.attention.backends.mla.flashmla.SupaFlashMLABackend"  # noqa: E501
        return "vllm_br.v1.attention.backends.attention_v1.SUPAFlashAttentionBackend"  # noqa: E501

    @classmethod
    def get_punica_wrapper(cls) -> str:
        # LoRA (Punica) kernels are not available on SUPA.
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm_br.distributed.communicator.SUPACommunicator"

    @classmethod
    def supports_fp8(cls) -> bool:
        return False

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        # NOTE: this method was previously defined twice with identical
        # behavior; the duplicate definition has been removed.
        return False

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm_br.compilation.supa_graph.SUPAGraphWrapper"

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        return "vllm_br.compilation.supa_piecewise_backend.SUPAPiecewiseBackend"  # noqa

    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            # TODO(liming): Check this handle is thread-safe.
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True