# Source-viewer residue (not code): "Files" listing; last modified
# 2026-04-24 09:58:03 +08:00; 305 lines; 11 KiB; Python.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import contextlib
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, Tuple
import os
import torch
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.platforms.interface import (
DeviceCapability,
Platform,
PlatformEnum,
)
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu.logger import logger
if TYPE_CHECKING:
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import ModelConfig, VllmConfig
    from vllm.config.cache import CacheDType
    from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
    # Runtime stand-in: ``pre_register_and_update`` uses this name in an
    # unquoted annotation, which is evaluated when the method is defined.
    FlexibleArgumentParser = object

# Add MLU_VISIBLE_DEVICES to vLLM's env-var table (presumably so it can be
# read as ``envs.MLU_VISIBLE_DEVICES``); yields None when the variable is unset.
envs.environment_variables["MLU_VISIBLE_DEVICES"] = (
    lambda: os.getenv("MLU_VISIBLE_DEVICES")
)

# NOTE: this rebinds the ``logger`` name imported from ``vllm_mlu.logger``;
# all code below this line resolves ``logger`` to this instance.
logger = init_logger(__name__)
class MLUPlatform(Platform):
    """vLLM out-of-tree platform implementation for MLU devices.

    Maps the generic ``Platform`` hooks onto ``torch.mlu`` and points vLLM at
    the vLLM-MLU worker, attention backends, schedulers, communicator, LoRA
    punica wrapper and graph wrapper.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "mlu"
    device_type: str = "mlu"
    dispatch_key: str = "MLU"
    # NOTE(review): MLUs are advertised to Ray under the "GPU" resource key.
    ray_device_key: str = "GPU"
    device_control_env_var: str = "MLU_VISIBLE_DEVICES"
    simple_compile_backend: str = "inductor"
    dist_backend: str = "cncl"
    supported_quantization: list[str] = ["weightonly", "smoothquant",
                                         "awq_mlu", "gptq_mlu", "fp8"]
    additional_env_vars: list[str] = ["VLLM_LATENCY_DEBUG",
                                      "VLLM_LATENCY_DEBUG_NO_DEVICE",
                                      "MLU_GRAPH_CAPTURE_LIST",
                                      "VLLM_LOGITS_USE_ALL_GATHER",
                                      "VLLM_V1_USE_FULL_GRAPH",
                                      "VLLM_MTP_FIXED_ACCEPTANCE_RATE"]

    @classmethod
    def import_kernels(cls) -> None:
        """Import platform-specific kernel extensions.

        ``vllm._C`` is deliberately not imported on MLU. ``vllm._moe_C`` is
        imported when available. ``torch_mlu_ops`` is imported on a
        best-effort basis, and a warning is logged if it cannot be imported.
        """
        # Keep a single definition of this hook: a second ``import_kernels``
        # in the class body would silently replace this one.
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401
        try:
            import torch_mlu_ops  # noqa: F401
        except ImportError as e:
            logger.warning("Failed to import from torch_mlu_ops with %r", e)

    @classmethod
    def pre_register_and_update(
        cls, parser: FlexibleArgumentParser | None = None
    ) -> None:
        """Register the MLU quantization methods before config parsing."""
        from vllm_mlu.model_executor.layers.quantization import (
            register_real_mlu_quantization_methods
        )
        register_real_mlu_quantization_methods()

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        attn_type: str | None = None,
    ) -> str:
        """Return the dotted path of the attention backend class.

        Only ``use_mla`` affects the choice. The other arguments are accepted
        to match the ``Platform`` interface.
        """
        if use_mla:
            logger.info("[MLU-V1][MLA] Select FlashMLABackend.")
            return "vllm_mlu.v1.attention.backends.mla.flashmla.FlashMLABackend"
        logger.info("[MLU-V1] Select FlashAttentionBackend.")
        return "vllm_mlu.v1.attention.backends.flash_attn.MLUFlashAttentionBackend"

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Return the device's (major, minor) capability, or None if the query
        raises ``RuntimeError``. Results are cached (up to 8 entries)."""
        try:
            major, minor = torch.mlu.get_device_capability(device_id)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.mlu`` (cached)."""
        return torch.mlu.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device properties."""
        device_props = torch.mlu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def set_device(cls, device: torch.device):
        """Make ``device`` the current MLU device."""
        torch.mlu.set_device(device)

    @classmethod
    def empty_cache(cls):
        """Release cached allocator memory via ``torch.mlu.empty_cache``."""
        torch.mlu.empty_cache()

    @classmethod
    def synchronize(cls):
        """Block until queued MLU work finishes (``torch.mlu.synchronize``)."""
        torch.mlu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        """Return ``torch.mlu.mem_get_info()``.

        Presumably this is (free, total) bytes, mirroring
        ``torch.cuda.mem_get_info``.
        """
        return torch.mlu.mem_get_info()

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is always reported as supported."""
        return True

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference: ``torch.no_grad()``."""
        return torch.no_grad()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is reported as supported."""
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for MLU.

        This method:
        - forces no compilation with full graphs captured for decode only;
        - selects the MLU worker and scheduler classes;
        - fills in a cache block-size default;
        - applies MLA-specific restrictions;
        - validates the batched-token budget;
        - registers the MLU KV connectors when KV transfer is configured.

        Raises:
            ValueError: if ``max_num_batched_tokens < max_model_len`` while
                chunked prefill is disabled and ``VLLM_V1_BENCHMARK`` is not
                set. When it is set, only a warning is logged.
        """
        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        model_config = vllm_config.model_config
        kv_transfer_config = vllm_config.kv_transfer_config
        mlu_config = vllm_config.mlu_config
        # Decode uses full mlugraph: V1 mode + VLLM_V1_USE_FULL_GRAPH=true.
        use_full_mlugraph = mlu_envs.VLLM_V1_USE_FULL_GRAPH

        # Compilation: always overridden, regardless of user settings.
        from vllm.config import CompilationMode, CUDAGraphMode
        logger.info(
            "[MLU] Force select CompilationMode.None, "
            "CUDAGraphMode.FULL_DECODE_ONLY."
        )
        compilation_config.level = None
        compilation_config.mode = CompilationMode.NONE
        compilation_config.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY

        # Worker dispatch.
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm_mlu.v1.worker.gpu_worker.MLUWorker"
        cls.simple_compile_backend = "inductor"

        # Activate all custom ops for v1 and split graphs at rope_forward.
        compilation_config.custom_ops = ["all"]
        if compilation_config.splitting_ops is None:
            compilation_config.splitting_ops = []
        # Avoid duplicate entries if this hook runs more than once.
        if "vllm.rope_forward" not in compilation_config.splitting_ops:
            compilation_config.splitting_ops.append("vllm.rope_forward")

        # FIXME: support cascade attention in VLLM-1710
        if model_config:
            model_config.disable_cascade_attn = True

        # Select the v1 scheduler type.
        if scheduler_config:
            use_unchunk = (mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED
                           and not scheduler_config.enable_chunked_prefill)
            if not scheduler_config.async_scheduling:
                if use_unchunk:
                    scheduler_config.scheduler_cls = (
                        "vllm_mlu.v1.core.sched.scheduler.MLUUnchunkScheduler")
                    logger.info("[MLU-V1] Select UnchunkScheduler.")
                else:
                    scheduler_config.scheduler_cls = (
                        "vllm_mlu.v1.core.sched.scheduler.SchedulerWithProfiler")
                    logger.info("[MLU-V1] Select ChunkScheduler.")
            elif use_unchunk:
                scheduler_config.scheduler_cls = (
                    "vllm_mlu.v1.core.sched.async_scheduler."
                    "MLUUnchunkAsyncScheduler")
                logger.info("[MLU-V1] Select UnchunkAsyncScheduler.")
            # Async scheduling without unchunk leaves scheduler_cls unchanged.

        # Cache config.
        if cache_config:
            logger.info("[MLU] Select kv_cache_dtype=%s.", cache_config.cache_dtype)
            if cache_config.block_size is None:
                cache_config.block_size = 16

        # MLA restrictions.
        if model_config and model_config.use_mla:
            if mlu_config.is_dpsk_mcc_enabled or not use_full_mlugraph:
                scheduler_config.enable_chunked_prefill = False
                scheduler_config.chunked_prefill_enabled = False
                logger.warning(
                    "[MLA] Chunked prefill is disabled when deepseek mcc is enabled, "
                    "or not use full mlugraph.")
            if mlu_config.is_dpsk_mcc_enabled:
                cache_config.enable_prefix_caching = False
                logger.warning(
                    "[MLA] Prefix Caching is disabled when deepseek mcc is enabled.")

        # For mlu benchmark, we allow max_num_batched_tokens < max_model_len
        # in certain scenarios.
        if (
            model_config
            and scheduler_config.max_num_batched_tokens < model_config.max_model_len
            and not scheduler_config.chunked_prefill_enabled
        ):
            msg = (
                f"max_num_batched_tokens ({scheduler_config.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({model_config.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len."
            )
            if not mlu_envs.VLLM_V1_BENCHMARK:
                raise ValueError(msg)
            logger.warning(msg)

        # The flag is cleared only when BOTH conditions below hold.
        if (mlu_config.dispatch_shared_expert_parallel
                and parallel_config.data_parallel_size <= 1
                and not mlu_config.prefill_use_sequence_parallel):
            mlu_config.dispatch_shared_expert_parallel = False
            logger.info(
                "Disabling `mlu_config.dispatch_shared_expert_parallel` because "
                "data_parallel_size <= 1 and prefill sequence parallel is not "
                "enabled."
            )

        # KV transfer.
        if kv_transfer_config:
            # Imported only for its side effect: registering MLU kv_connectors.
            import vllm_mlu.distributed.kv_transfer.kv_connector.factory  # noqa: F401

    @classmethod
    def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None
    ) -> float:
        """Return the allocated memory for ``device``.

        Resets the peak counter before reading it back, so the value returned
        is the allocation at call time. This assumes ``torch.mlu`` follows
        ``torch.cuda`` peak-stat semantics.
        """
        torch.mlu.reset_peak_memory_stats(device)
        return torch.mlu.max_memory_allocated(device)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the MLU LoRA punica wrapper."""
        return "vllm_mlu.lora.punica_wrapper.punica_mlu.PunicaWrapperMLU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the MLU device communicator."""
        return "vllm_mlu.distributed.device_communicators.mlu_communicator.MLUCommunicator"

    @classmethod
    def use_all_gather(cls) -> bool:
        """Report that all-gather is used (always True)."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the MLU graph wrapper class."""
        return "vllm_mlu.compilation.mlu_graph.MLUGraphWrapper"

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    # Instance method (not a classmethod), called on the platform instance.
    def is_sleep_mode_available(self) -> bool:
        """Sleep mode is reported as available."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return True