305 lines
11 KiB
Python
305 lines
11 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||
|
|
|
||
|
|
import contextlib
|
||
|
|
from functools import lru_cache
|
||
|
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||
|
|
|
||
|
|
import os
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
|
||
|
|
import vllm.envs as envs
|
||
|
|
from vllm.platforms.interface import (
|
||
|
|
DeviceCapability,
|
||
|
|
Platform,
|
||
|
|
PlatformEnum,
|
||
|
|
)
|
||
|
|
|
||
|
|
import vllm_mlu._mlu_utils as mlu_envs
|
||
|
|
from vllm_mlu.logger import logger
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||
|
|
from vllm.config import ModelConfig, VllmConfig
|
||
|
|
from vllm.config.cache import CacheDType
|
||
|
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||
|
|
else:
|
||
|
|
FlexibleArgumentParser = object
|
||
|
|
|
||
|
|
|
||
|
|
# Expose MLU_VISIBLE_DEVICES through vLLM's environment-variable table. The
# value is read from the process environment each time the lambda is called
# and is ``None`` when the variable is unset.
envs.environment_variables["MLU_VISIBLE_DEVICES"] = (
    lambda: os.getenv("MLU_VISIBLE_DEVICES")
)

# NOTE(review): this rebinds the name ``logger`` that the module imports from
# ``vllm_mlu.logger``; all code below logs through vLLM's ``init_logger``
# instance instead. Confirm which logger is intended and drop the other.
logger = init_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class MLUPlatform(Platform):
    """vLLM platform implementation for Cambricon MLU devices.

    Registered as an out-of-tree platform (``PlatformEnum.OOT``). Device
    queries and memory management are forwarded to ``torch.mlu``. The
    ``get_*`` hooks return dotted paths of the ``vllm_mlu`` attention
    backends, communicator, LoRA wrapper and graph wrapper.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "mlu"
    device_type: str = "mlu"
    dispatch_key: str = "MLU"
    # Ray resource name requested for MLU workers.
    ray_device_key: str = "GPU"
    # Environment variable that controls which MLUs a process can see.
    device_control_env_var: str = "MLU_VISIBLE_DEVICES"
    simple_compile_backend: str = "inductor"
    dist_backend: str = "cncl"

    supported_quantization: list[str] = [
        "weightonly",
        "smoothquant",
        "awq_mlu",
        "gptq_mlu",
        "fp8",
    ]
    # MLU-specific environment variables declared to vLLM's platform hook.
    additional_env_vars: list[str] = [
        "VLLM_LATENCY_DEBUG",
        "VLLM_LATENCY_DEBUG_NO_DEVICE",
        "MLU_GRAPH_CAPTURE_LIST",
        "VLLM_LOGITS_USE_ALL_GATHER",
        "VLLM_V1_USE_FULL_GRAPH",
        "VLLM_MTP_FIXED_ACCEPTANCE_RATE",
    ]

    @classmethod
    def import_kernels(cls) -> None:
        """Import optional compiled extensions for this platform.

        ``vllm._C`` is intentionally not imported. ``vllm._moe_C`` is
        imported when present; an ``ImportError`` is silently ignored.
        """
        # NOTE(review): ``torch_mlu_ops`` is not imported here. If it must be
        # loaded eagerly (e.g. for op registration), add it below and log a
        # warning on ImportError — TODO confirm intended behavior.
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def pre_register_and_update(
        cls, parser: FlexibleArgumentParser | None = None
    ) -> None:
        """Register the MLU quantization methods with vLLM.

        Args:
            parser: Accepted for interface compatibility; not used here.
        """
        from vllm_mlu.model_executor.layers.quantization import (
            register_real_mlu_quantization_methods,
        )

        register_real_mlu_quantization_methods()

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        attn_type: str | None = None,
    ) -> str:
        """Return the dotted path of the attention backend class.

        Only ``use_mla`` affects the choice. MLA models get
        ``FlashMLABackend``; all others get ``MLUFlashAttentionBackend``.
        The remaining arguments belong to the platform interface and are
        ignored.
        """
        if use_mla:
            logger.info("[MLU-V1][MLA] Select FlashMLABackend.")
            return "vllm_mlu.v1.attention.backends.mla.flashmla.FlashMLABackend"
        logger.info("[MLU-V1] Select FlashAttentionBackend.")
        return "vllm_mlu.v1.attention.backends.flash_attn.MLUFlashAttentionBackend"

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Return the (major, minor) capability of MLU ``device_id``.

        Returns ``None`` when ``torch.mlu.get_device_capability`` raises
        ``RuntimeError``. Results are memoized per ``(cls, device_id)``.
        """
        try:
            major, minor = torch.mlu.get_device_capability(device_id)
        except RuntimeError:
            return None
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of MLU ``device_id`` (memoized)."""
        return torch.mlu.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the properties of MLU ``device_id``."""
        return torch.mlu.get_device_properties(device_id).total_memory

    @classmethod
    def set_device(cls, device: torch.device):
        """Make ``device`` the current MLU device."""
        torch.mlu.set_device(device)

    @classmethod
    def empty_cache(cls):
        """Forward to ``torch.mlu.empty_cache``."""
        torch.mlu.empty_cache()

    @classmethod
    def synchronize(cls):
        """Forward to ``torch.mlu.synchronize``."""
        torch.mlu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        """Forward to ``torch.mlu.mem_get_info``.

        Presumably returns ``(free_bytes, total_bytes)`` like
        ``torch.cuda.mem_get_info`` — TODO confirm for torch_mlu.
        """
        return torch.mlu.mem_get_info()

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output is always reported as supported, regardless of
        ``enforce_eager``."""
        return True

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.no_grad()`` context manager."""
        return torch.no_grad()

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is reported as supported."""
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place for MLU execution.

        Steps:
        - Force ``CompilationMode.NONE`` with
          ``CUDAGraphMode.FULL_DECODE_ONLY``, enable all custom ops, and add
          ``vllm.rope_forward`` to the splitting ops.
        - Use the MLU worker when ``worker_cls`` is ``"auto"``.
        - Disable cascade attention.
        - Choose an MLU scheduler class.
        - Default ``block_size`` to 16.
        - For MLA models, possibly disable chunked prefill and prefix caching.
        - Validate ``max_num_batched_tokens`` against ``max_model_len``.
        - Register the MLU KV connectors when KV transfer is configured.

        Raises:
            ValueError: if ``max_num_batched_tokens < max_model_len`` while
                chunked prefill is disabled and ``VLLM_V1_BENCHMARK`` is not
                set. When it is set, only a warning is logged.
        """
        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        model_config = vllm_config.model_config
        kv_transfer_config = vllm_config.kv_transfer_config
        mlu_config = vllm_config.mlu_config

        # Decode use full mlugraph: V1 mode + VLLM_V1_USE_FULL_GRAPH=true
        use_full_mlugraph = mlu_envs.VLLM_V1_USE_FULL_GRAPH

        # Compilation settings are forced regardless of user input.
        from vllm.config import CompilationMode, CUDAGraphMode
        logger.info(
            "[MLU] Force select CompilationMode.None, CUDAGraphMode.FULL_DECODE_ONLY."
        )
        compilation_config.level = None
        compilation_config.mode = CompilationMode.NONE
        compilation_config.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY

        # Dispatch worker
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm_mlu.v1.worker.gpu_worker.MLUWorker"

        cls.simple_compile_backend = "inductor"
        # Activate custom ops for v1.
        compilation_config.custom_ops = ["all"]
        if compilation_config.splitting_ops is None:
            compilation_config.splitting_ops = []
        # Avoid appending a duplicate entry if this hook runs more than once
        # on the same compilation config.
        if "vllm.rope_forward" not in compilation_config.splitting_ops:
            compilation_config.splitting_ops.append("vllm.rope_forward")

        # FIXME: support cascade attention in VLLM-1710
        if model_config:
            model_config.disable_cascade_attn = True

        # Select the v1 scheduler class.
        # NOTE(review): this runs before the MLA checks below, which may turn
        # chunked prefill off afterwards. The scheduler chosen here then
        # reflects the pre-MLA ``enable_chunked_prefill`` value — confirm
        # this ordering is intended.
        if scheduler_config:
            use_unchunk_sched = (
                mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED
                and not scheduler_config.enable_chunked_prefill
            )
            if not scheduler_config.async_scheduling:
                if use_unchunk_sched:
                    scheduler_config.scheduler_cls = (
                        "vllm_mlu.v1.core.sched.scheduler.MLUUnchunkScheduler"
                    )
                    logger.info("[MLU-V1] Select UnchunkScheduler.")
                else:
                    scheduler_config.scheduler_cls = (
                        "vllm_mlu.v1.core.sched.scheduler.SchedulerWithProfiler"
                    )
                    logger.info("[MLU-V1] Select ChunkScheduler.")
            elif use_unchunk_sched:
                scheduler_config.scheduler_cls = (
                    "vllm_mlu.v1.core.sched.async_scheduler."
                    "MLUUnchunkAsyncScheduler"
                )
                logger.info("[MLU-V1] Select UnchunkAsyncScheduler.")
            # Async scheduling without the unchunk scheduler keeps the
            # configured scheduler_cls unchanged.

        # Check cache config
        if cache_config:
            logger.info("[MLU] Select kv_cache_dtype=%s.", cache_config.cache_dtype)
            if cache_config.block_size is None:
                cache_config.block_size = 16

        # Check mla config
        if model_config and model_config.use_mla:
            if mlu_config.is_dpsk_mcc_enabled or not use_full_mlugraph:
                scheduler_config.enable_chunked_prefill = False
                scheduler_config.chunked_prefill_enabled = False
                logger.warning(
                    "[MLA] Chunked prefill is disabled when deepseek mcc is enabled, "
                    "or not use full mlugraph.")
            if mlu_config.is_dpsk_mcc_enabled:
                cache_config.enable_prefix_caching = False
                logger.warning(
                    "[MLA] Prefix Caching is disabled when deepseek mcc is "
                    "enabled.")

        # For mlu benchmark, we allow max_num_batched_tokens < max_model_len
        # in certain scenarios (warning instead of error).
        if (
            model_config
            and scheduler_config.max_num_batched_tokens < model_config.max_model_len
            and not scheduler_config.chunked_prefill_enabled
        ):
            msg = (
                f"max_num_batched_tokens ({scheduler_config.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({model_config.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len."
            )
            if not mlu_envs.VLLM_V1_BENCHMARK:
                raise ValueError(msg)
            logger.warning(msg)

        # Shared-expert dispatch is kept only when data parallelism > 1 or
        # prefill sequence parallelism is enabled.
        if (
            mlu_config.dispatch_shared_expert_parallel
            and parallel_config.data_parallel_size <= 1
            and not mlu_config.prefill_use_sequence_parallel
        ):
            mlu_config.dispatch_shared_expert_parallel = False
            logger.info(
                "Disabling `mlu_config.dispatch_shared_expert_parallel` when "
                "data_parallel_size == 1 and not using sequence parallel."
            )

        # Check kv_transfer config
        if kv_transfer_config:
            # Importing this module registers the MLU kv_connectors.
            import vllm_mlu.distributed.kv_transfer.kv_connector.factory  # noqa: F401

    @classmethod
    def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None
    ) -> float:
        """Return the memory currently allocated on ``device``.

        The peak counter is reset first, so ``max_memory_allocated`` reflects
        the present allocation rather than a historical peak. This assumes
        ``torch.mlu`` mirrors ``torch.cuda`` semantics — TODO confirm.
        """
        torch.mlu.reset_peak_memory_stats(device)
        return torch.mlu.max_memory_allocated(device)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the MLU LoRA Punica wrapper."""
        return "vllm_mlu.lora.punica_wrapper.punica_mlu.PunicaWrapperMLU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the MLU device communicator."""
        return "vllm_mlu.distributed.device_communicators.mlu_communicator.MLUCommunicator"

    @classmethod
    def use_all_gather(cls) -> bool:
        """Always report that all-gather should be used."""
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """Return the dotted path of the MLU graph wrapper."""
        return "vllm_mlu.compilation.mlu_graph.MLUGraphWrapper"

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    def is_sleep_mode_available(self) -> bool:
        """Sleep mode is reported as available (instance method, like the
        base ``Platform`` hook)."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return True