import os
from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
# from .interface import Platform, PlatformEnum, _Backend
from vllm.platforms.interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
    from vllm.config import VllmConfig, ModelConfig
else:
    VllmConfig = None
    ModelConfig = None

logger = init_logger(__name__)


class VaccPlatform(Platform):
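    """Out-of-tree (OOT) vLLM platform plugin for VACC accelerators.

    Registers the device type, attention backends, worker classes and
    platform-specific configuration adjustments with vLLM core.
    """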
    try:
        import torch_vacc  # noqa: F401
        is_vacc = True
    except Exception as e:
        raise ImportError(f"Failed to import torch_vacc: {e}") from e
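    # Identity attributes consumed by vLLM core: "PrivateUse1" is the PyTorch
    # dispatch key reserved for out-of-tree device backends, and
    # VACC_VISIBLE_MODULES controls device visibility (analogous to
    # CUDA_VISIBLE_DEVICES).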
_enum = PlatformEnum.OOT
device_name: str = "vacc"
device_type: str = "vacc"
dispatch_key: str = "PrivateUse1"
ray_device_key: str = "GPU"
device_control_env_var: str = "VACC_VISIBLE_MODULES"
simple_compile_backend: str = "eager" # Disable torch.compile()
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool, has_sink: bool, use_sparse: bool) -> str:
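        """Return the dotted path of the attention backend class to load.

        MLA (multi-head latent attention) models get the VACCMLA backend,
        everything else the standard VACCAttention backend; the v1 engine
        uses the implementations under vllm_vacc.vllm.v1.
        """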
if use_mla:
logger.info("Using VACCMLA backend.")
if use_v1:
return "vllm_vacc.vllm.v1.attention.backends.vacc_mla.VACCMLABackend"
return "vllm_vacc.vllm.attention.backends.vacc_mla.VACCMLABackend"
if use_v1:
return "vllm_vacc.vllm.v1.attention.backends.vacc_attn.VACCAttentionBackend"
else:
logger.info("Using VACCAttention backend.")
return "vllm_vacc.vllm.attention.backends.vacc_attn.VACCAttentionBackend"
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@staticmethod
def inference_mode():
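        """Context manager used around model execution; plain no_grad() is
        used here instead of torch.inference_mode()."""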
return torch.no_grad()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
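        """Validate and adjust the vLLM config for VACC.

        Rejects unsupported features (KV transfer, FP8 KV cache combined
        with chunked prefill or prefix caching), selects the VACC worker
        class, applies the default block size, and switches the worker
        multiprocessing method to 'spawn' unless 'fork' was explicitly
        requested.
        """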
import vllm.envs as envs
if vllm_config.kv_transfer_config:
raise NotImplementedError("kv-transfer-config is not implemented for VACC")
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
if ((scheduler_config.chunked_prefill_enabled
or cache_config.enable_prefix_caching)
and cache_config.cache_dtype != "auto"):
            raise RuntimeError(
                "Chunked prefill and prefix caching on the VACC backend "
                "are not compatible with an FP8 KV cache.")
# scheduling_polity = scheduler_config.policy
# model_config = vllm_config.model_config
# use_async_output_proc = model_config.use_async_output_proc
# if scheduling_polity == "priority" and use_async_output_proc: # probably a bug
# logger.warning("WARNING scheduling_polity priority is not fully supported for VACC, "
# "use fcfs instead automatically")
# vllm_config.scheduler_config.scheduling_polity = "fcfs"
# if vllm_config.speculative_config is not None:
# raise NotImplementedError(
# "Speculative decoding is not implemented for VACC")
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = "vllm_vacc.vllm.v1.worker.vacc_worker.VACCWorker"
else:
parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = "vllm_vacc.vllm.worker.vacc_worker.VACCWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm_vacc.vllm.v1.worker.vacc_worker.VACCWorker"
                    logger.info("Using v1 VACCWorker.")
else:
parallel_config.worker_cls = \
"vllm_vacc.vllm.worker.vacc_worker.VACCWorker"
        # gpu_memory_utilization has no effect on VACC.
        if cache_config and cache_config.gpu_memory_utilization:
            logger.warning("gpu_memory_utilization is not supported on VACC.")
        # if cache_config and cache_config.enable_prefix_caching:
        #     raise NotImplementedError("Prefix-caching is not implemented for VACC")
        # Default block size on VACC when none was configured explicitly.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16
if (parallel_config.distributed_executor_backend == 'mp'
and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD",
None) is not None:
logger.warning("On VACC, VLLM_WORKER_MULTIPROC_METHOD=fork "
"might cause application hangs on exit. Using "
"VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
"as it was explicitly requested.")
else:
logger.warning(
"On VACC, VLLM_WORKER_MULTIPROC_METHOD=fork "
"might cause application hangs on exit. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"To override that behavior, please set "
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on VACC.")
return False
@classmethod
def get_punica_wrapper(cls) -> str:
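        """Return the dotted path of the VACC-specific Punica wrapper used
        by vLLM's LoRA layers."""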
return "vllm_vacc.vllm.lora.punica_wrapper.punica_vacc.PunicaWrapperVACC"
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
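        # torch.vacc is assumed to expose the same memory-stats API as
        # torch.cuda (provided by torch_vacc). Resetting the peak counter
        # first makes max_memory_allocated() reflect the current allocation.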
torch.vacc.reset_peak_memory_stats(device)
return torch.vacc.max_memory_allocated(device)
@classmethod
def use_all_gather(cls) -> bool:
"""
Whether to use allgather in LogitsProcessor to gather the logits.
"""
return True
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
# return False # or export VLLM_USE_V1=0 to use v0
        if os.getenv("VLLM_USE_V1", "1") == "0":
return False
return True
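

# A minimal sketch of how a platform plugin like this is typically registered
# with vLLM's plugin system. The package and module names below ("vllm-vacc",
# "vllm_vacc.platform") are assumptions for illustration; vLLM discovers OOT
# platforms through the "vllm.platform_plugins" entry-point group, whose entry
# function must return the dotted path of the Platform subclass (or None to
# opt out on unsupported hosts):
#
#   # vllm_vacc/__init__.py
#   def register() -> str:
#       return "vllm_vacc.platform.VaccPlatform"
#
#   # pyproject.toml
#   [project.entry-points."vllm.platform_plugins"]
#   vacc = "vllm_vacc:register"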