################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from typing import TYPE_CHECKING, Any, Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms.interface import Platform, PlatformEnum, _Backend
from vllm.utils import FlexibleArgumentParser

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)

# SUPA only supports a limited set of tasks on the legacy V0 engine, so warn
# early (at import time) if the V1 engine is not enabled.
if not envs.VLLM_USE_V1:
    logger.warning(
        "SUPAPlatform is only supported limited set of tasks in vLLM V0. "
        "Please check if `VLLM_USE_V1=1` is set before launch.")


class SUPAPlatform(Platform):
    """vLLM out-of-tree (OOT) platform plugin for Biren SUPA devices.

    Registers the ``supa`` device type backed by the PyTorch ``PrivateUse1``
    dispatch key, wires up SUPA-specific workers, attention backends,
    communicators and graph-compilation classes, and constrains engine
    configuration (block size, compilation backend, quantization) to what
    the SUPA stack supports.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "supa"
    device_type: str = "supa"
    dispatch_key: str = "PrivateUse1"
    ray_device_key: str = "GPU"
    dist_backend: str = "sccl"
    device_control_env_var: str = "SUPA_VISIBLE_DEVICES"
    # Environment variables to control weight type for SUPA, which will be
    # copied from the driver to workers in the Ray distributed executor.
    # NOTE: Related code:
    # vllm/executor/ray_distributed_executor.py::_init_workers_ray
    additional_env_vars: list[str] = [
        "SUPA_WEIGHT_TYPE", "VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE"
    ]
    supported_quantization: list[str] = []
    # Lazily-created graph memory pool handle shared process-wide; see
    # get_global_graph_pool().
    _global_graph_pool: Optional[Any] = None

    def is_supa(self) -> bool:
        return True

    def is_sleep_mode_available(self) -> bool:
        # Sleep mode (weight offload/restore) is not implemented for SUPA.
        return False

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.supa.set_device(device)

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> None:
        # SUPA does not expose a CUDA-style (major, minor) capability tuple.
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.supa.get_device_name(device_id)

    @classmethod
    def get_memory_stats(cls, device: torch.device, info_key: str) -> int:
        """Return a single entry of the device allocator's statistics dict."""
        # FIX: the original called torch.cuda.memory_stats here, while every
        # other device call in this platform goes through torch.supa.
        # NOTE(review): assumes torch.supa mirrors the torch.cuda memory-stats
        # API like its other calls in this file — confirm against the backend.
        return torch.supa.memory_stats(device)[info_key]

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        # Async output processing only pays off together with SUPA graph
        # capture, which enforce-eager disables.
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable SUPA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def inference_mode(cls) -> torch.inference_mode:
        return torch.inference_mode()

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """Restrict CLI argument choices to values supported by SUPA.

        Only block size 128 and the ``auto``/``supa`` device are valid.
        """
        if parser is not None:
            for action in parser._actions:
                opts = action.option_strings
                if opts:
                    if opts[0] == "--block-size":
                        action.choices = [128]
                    elif opts[0] == "--device":
                        action.choices = ["auto", "supa"]

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.supa.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Validate and rewrite the engine config for SUPA.

        Selects the SUPA worker class, forces the 128-token block size,
        disables inductor compilation and rejects unsupported quantization
        methods.
        """
        import vllm_br.envs as biren_envs
        vllm_config.model_config.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE
        logger.warning("update model with weight type %s for supa Matmul ops",
                       vllm_config.model_config.weight_type)
        vllm_config.model_config.use_ds_mla = False
        vllm_config.model_config.use_ds_mla_sparse = False
        parallel_config = vllm_config.parallel_config
        if parallel_config and parallel_config.worker_cls == "auto":
            # TODO: remove this once support Multi-step scheduling on SUPA
            # if scheduler_config.is_multi_step:
            #     raise NotImplementedError(
            #         "Multi-step scheduling is not supported (and not "
            #         "needed) on vLLM V1. Please launch without "
            #         "--num-scheduler-steps.")
            if vllm_config.speculative_config:
                if not envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = \
                    "vllm_br.v1.worker.worker.SUPAWorker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm_br.v1.worker.worker.SUPAWorker"
                else:
                    parallel_config.worker_cls = \
                        "vllm_br.v0.worker.worker.SUPAWorker"

        cache_config = vllm_config.cache_config
        if cache_config:
            if cache_config.block_size is None:
                cache_config.block_size = 128
            # Prefix caching requires the fixed SUPA block size.
            if cache_config.enable_prefix_caching and \
                    cache_config.block_size != 128:
                cache_config.block_size = 128
                logger.warning(
                    "If prefix caching is enabled, block size must be set to 128, and the block size has been set to 128"
                )  # noqa: E501

        compilation_config = vllm_config.compilation_config
        # NOTE: always disable inductor for SUPA
        if compilation_config.use_inductor:
            logger.warning("Inductor is not supported for SUPA. Disabling it.")
            compilation_config.use_inductor = False
        logger.warning(
            "Use %s compilation backend.",
            'graph' if compilation_config.use_cudagraph else 'eager')
        if compilation_config.use_cudagraph:
            compilation_config.level = 0  # CompilationLevel.NO_COMPILATION

        if vllm_config.model_config.quantization is not None:
            UNSUPPORTED_QUANTIZATION = ['gptq_marlin']
            if vllm_config.model_config.quantization in \
                    UNSUPPORTED_QUANTIZATION:
                raise NotImplementedError(
                    f"Unsupported quantization {vllm_config.model_config.quantization} for SUPAPlatform."
                )

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        # Reset the peak counter first so the reported peak reflects the
        # current allocation level, not an earlier historical high-water mark.
        torch.supa.reset_peak_memory_stats(device)
        return torch.supa.max_memory_allocated(device)

    # FIXME: this matters when serving vit models, need adaptation
    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink, use_sparse):
        """Return the fully-qualified class path of the attention backend.

        V1 engine: sparse MLA, dense MLA, or flash attention depending on
        the flags; V0 engine: the legacy flash-attention backend.
        """
        if use_v1:
            if use_mla:
                if use_sparse:
                    logger.info_once("Using Sparse MLA backend on V1 engine.")
                    return (
                        "vllm_br.v1.attention.backends.mla.flashmla_sparse."
                        "SupaFlashMLASparseBackend")
                return "vllm_br.v1.attention.backends.mla.flashmla.SupaFlashMLABackend"  # noqa: E501
            return "vllm_br.v1.attention.backends.attention_v1.SUPAFlashAttentionBackend"  # noqa: E501
        else:
            return "vllm_br.v0.attention.backends.attention_v0.SUPAFlashAttentionBackend"  # noqa: E501

    @classmethod
    def get_punica_wrapper(cls) -> str:
        # LoRA (punica) kernels are not available on SUPA.
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm_br.distributed.communicator.SUPACommunicator"

    @classmethod
    def supports_fp8(cls) -> bool:
        return False

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        return True

    # FIX: this classmethod was defined twice in the original file (the later
    # duplicate shadowed this one); both returned False, so keeping a single
    # definition preserves behavior.
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        return False

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm_br.compilation.supa_graph.SUPAGraphWrapper"

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        return "vllm_br.compilation.supa_piecewise_backend.SUPAPiecewiseBackend"  # noqa

    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            # TODO(liming): Check this handle is thread-safe.
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return True