# Source listing metadata (extraction residue; kept as a comment so the file
# remains valid Python):
#   enginex-biren-vllm/vllm_br/platform.py
#   2026-03-10 13:31:25 +08:00 — 253 lines, 9.4 KiB, Python
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import TYPE_CHECKING, Any, Optional
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms.interface import Platform, PlatformEnum, _Backend
from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
# Warn at import time when the legacy V0 engine is selected: this plugin
# targets vLLM V1 and only a limited feature set works on V0.
if not envs.VLLM_USE_V1:
    logger.warning(
        "SUPAPlatform is only supported limited set of tasks in vLLM V0. "
        "Please check if `VLLM_USE_V1=1` is set before launch.")
class SUPAPlatform(Platform):
    """vLLM out-of-tree (OOT) platform plugin for Biren SUPA devices."""

    _enum = PlatformEnum.OOT
    # Identifiers vLLM uses for this device and its torch device type.
    device_name: str = "supa"
    device_type: str = "supa"
    # SUPA kernels are registered under torch's PrivateUse1 dispatch key.
    dispatch_key: str = "PrivateUse1"
    # Ray schedules SUPA devices under the generic "GPU" resource key.
    ray_device_key: str = "GPU"
    # Collective-communication backend name for distributed init.
    dist_backend: str = "sccl"
    # Env var that restricts which SUPA devices a process may see.
    device_control_env_var: str = "SUPA_VISIBLE_DEVICES"
    # Environment variables to control weight type for SUPA, which will be
    # copied from the driver to workers by the Ray distributed executor.
    # NOTE: Related code: vllm/executor/ray_distributed_executor.py::_init_workers_ray
    additional_env_vars: list[str] = [
        "SUPA_WEIGHT_TYPE", "VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE"
    ]
    # Empty: no quantization scheme is advertised as supported here.
    supported_quantization: list[str] = []
    # Lazily-created, class-wide graph memory pool
    # (see get_global_graph_pool).
    _global_graph_pool: Optional[Any] = None
def is_supa(self) -> bool:
return True
def is_sleep_mode_available(self) -> bool:
return False
@property
def supported_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.float32]
@classmethod
def set_device(cls, device: torch.device) -> None:
torch.supa.set_device(device)
@classmethod
def get_device_capability(cls, device_id: int = 0) -> None:
return None
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return torch.supa.get_device_name(device_id)
@classmethod
def get_memory_stats(cls, device: torch.device, info_key: str) -> int:
return torch.cuda.memory_stats(device)[info_key]
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
if enforce_eager:
logger.warning(
"To see benefits of async output processing, enable SUPA "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used")
return False
return True
@classmethod
def inference_mode(cls) -> torch.inference_mode:
return torch.inference_mode()
@classmethod
def pre_register_and_update(cls,
parser: Optional[FlexibleArgumentParser] = None
) -> None:
if parser is not None:
for action in parser._actions:
opts = action.option_strings
if opts:
if opts[0] == "--block-size":
action.choices = [128]
elif opts[0] == "--device":
action.choices = ["auto", "supa"]
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.supa.get_device_properties(device_id)
return device_props.total_memory
    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Validate and patch the vLLM config for the SUPA platform.

        Mutates ``vllm_config`` in place: propagates the SUPA weight type,
        selects the SUPA worker class, forces a 128 KV-cache block size where
        required, disables Inductor compilation, and rejects unsupported
        quantization schemes.

        Raises:
            NotImplementedError: for speculative decoding on V0, or for a
                quantization method on the unsupported list.
        """
        import vllm_br.envs as biren_envs

        # Propagate the env-configured weight type so SUPA Matmul ops use it.
        vllm_config.model_config.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE
        logger.warning("update model with weight type %s for supa Matmul ops",
                       vllm_config.model_config.weight_type)
        # Disable these model-config toggles unconditionally on SUPA
        # (presumably DeepSeek-MLA related flags — TODO confirm semantics).
        vllm_config.model_config.use_ds_mla = False
        vllm_config.model_config.use_ds_mla_sparse = False
        parallel_config = vllm_config.parallel_config
        # Only override the worker class when the user left it on "auto".
        if parallel_config and parallel_config.worker_cls == "auto":
            # TODO: remove this once support Multi-step scheduling on SUPA
            # if scheduler_config.is_multi_step:
            #     raise NotImplementedError(
            #         "Multi-step scheduling is not supported (and not "
            #         "needed) on vLLM V1. Please launch without "
            #         "--num-scheduler-steps.")
            if vllm_config.speculative_config:
                # Speculative decoding requires the V1 engine.
                if not envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = "vllm_br.v1.worker.worker.SUPAWorker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm_br.v1.worker.worker.SUPAWorker"
                else:
                    parallel_config.worker_cls = \
                        "vllm_br.v0.worker.worker.SUPAWorker"
        cache_config = vllm_config.cache_config
        if cache_config:
            # Default the KV-cache block size to 128 when unset.
            if cache_config.block_size is None:
                cache_config.block_size = 128
            # Prefix caching on SUPA requires block size 128; force it.
            if cache_config.enable_prefix_caching and cache_config.block_size != 128:
                cache_config.block_size = 128
                logger.warning(
                    "If prefix caching is enabled, block size must be set to 128, and the block size has been set to 128"
                )  # noqa: E501
        compilation_config = vllm_config.compilation_config
        # NOTE: always disable inductor for SUPA
        if compilation_config.use_inductor:
            logger.warning("Inductor is not supported for SUPA. Disabling it.")
            compilation_config.use_inductor = False
        logger.warning(
            "Use %s compilation backend.",
            'graph' if compilation_config.use_cudagraph else 'eager')
        if compilation_config.use_cudagraph:
            compilation_config.level = 0  # CompilationLevel.NO_COMPILATION
        # Reject quantization methods known not to work on SUPA.
        if vllm_config.model_config.quantization is not None:
            UNSUPPORTED_QUANTIZATION = ['gptq_marlin']
            if vllm_config.model_config.quantization in UNSUPPORTED_QUANTIZATION:
                raise NotImplementedError(
                    f"Unsupported quantization {vllm_config.model_config.quantization} for SUPAPlatform."
                )
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
torch.supa.reset_peak_memory_stats(device)
return torch.supa.max_memory_allocated(device)
# FIXME: this matters when serving vit models, need adaptation
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
return _Backend.TORCH_SDPA
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse):
if use_v1:
if use_mla:
if use_sparse:
logger.info_once("Using Sparse MLA backend on V1 engine.")
return (
"vllm_br.v1.attention.backends.mla.flashmla_sparse."
"SupaFlashMLASparseBackend")
return "vllm_br.v1.attention.backends.mla.flashmla.SupaFlashMLABackend" # noqa: E501
return "vllm_br.v1.attention.backends.attention_v1.SUPAFlashAttentionBackend" # noqa: E501
else:
return "vllm_br.v0.attention.backends.attention_v0.SUPAFlashAttentionBackend" # noqa: E501
@classmethod
def get_punica_wrapper(cls) -> str:
raise NotImplementedError
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm_br.distributed.communicator.SUPACommunicator"
@classmethod
def supports_fp8(cls) -> bool:
return False
@classmethod
def use_custom_allreduce(cls) -> bool:
return True
    # NOTE(review): this definition is shadowed by a later duplicate
    # ``opaque_attention_op`` in the same class body (the last definition
    # wins at class-creation time); one of the two should be removed.
    @classmethod
    def opaque_attention_op(cls) -> bool:
        return False
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm_br.compilation.supa_graph.SUPAGraphWrapper"
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm_br.compilation.supa_piecewise_backend.SUPAPiecewiseBackend" # noqa
def get_global_graph_pool(self) -> Any:
"""
Return the global graph pool for this platform.
"""
cls = self.__class__
if cls._global_graph_pool is None:
# TODO(liming): Check this handle is thread-safe.
cls._global_graph_pool = self.graph_pool_handle()
return cls._global_graph_pool
    # NOTE(review): duplicate definition — an identical ``opaque_attention_op``
    # appears earlier in this class; this later one is the effective binding.
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        return False
@classmethod
def support_static_graph_mode(cls) -> bool:
return True