Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/_mlu_utils.py
2026-02-04 17:22:39 +08:00

123 lines
4.4 KiB
Python

from torch.utils import collect_env as torch_collect_env
import os
import re
def _check_env(env, default=False):
if env in os.environ:
return os.environ[env].lower() in ["true", "1"]
return default
def _check_env_value(env, default=0):
if env in os.environ:
if not os.environ[env].isdigit():
raise ValueError(f"'{env}' should be set with integer")
value = int(os.environ[env])
return value
return default
def get_device_name(device_id: int = 0) -> str:
    r"""Gets the name of a device.

    Shells out to ``cnmon -l`` and scans its output for names matching
    ``MLU<digits>`` with an optional ``-suffix``.

    Args:
        device_id (int): device id for which to return the device name.

    Returns:
        str: the name of the device. eg. MLU370.

    Raises:
        Exception: when ``cnmon`` fails or fewer than ``device_id + 1``
            devices are listed; the underlying error is chained as the cause.
    """
    run_lambda = torch_collect_env.run
    try:
        out = torch_collect_env.run_and_read_all(run_lambda, "cnmon -l")
        matches = re.findall(r'MLU\d+(?:-\w+)?', out)
        return matches[device_id]
    except Exception as e:
        # Chain the original failure (missing cnmon, unparsable output, bad
        # index) instead of silently discarding it.
        raise Exception(f"No device found with ID {device_id}.") from e
def get_device_major_capability(device_id: int = 0) -> int:
    r"""Gets the major capability of a device.

    Parsed out of the device name: for e.g. ``MLU370`` the character at
    index 3 is the major generation digit (3).

    Args:
        device_id (int): device id for which to return the device capability.

    Returns:
        int: the major capability of the device.

    Raises:
        Exception: when the device name cannot be obtained or parsed; the
            underlying error is chained as the cause.
    """
    try:
        device_name = get_device_name(device_id)
        # Device names look like "MLU370"; index 3 holds the major digit.
        return int(device_name[3])
    except Exception as e:
        raise Exception(f"Fail to parse device capability with ID: {device_id}.") from e
# USE_PAGED: Select the vLLM running mode, default value depends on current platform.
# Defaults to True on devices with major capability > 3 (i.e. newer than MLU3xx).
USE_PAGED = _check_env("USE_PAGED", default=(get_device_major_capability() > 3))
# VLLM_LATENCY_DEBUG: Get more kernel info for benchmark latency.
VLLM_LATENCY_DEBUG = _check_env("VLLM_LATENCY_DEBUG", default=False)
# VLLM_LATENCY_DEBUG_NO_DEVICE: Get more kernel info(without device) for benchmark latency.
VLLM_LATENCY_DEBUG_NO_DEVICE = _check_env("VLLM_LATENCY_DEBUG_NO_DEVICE", default=False)
# VLLM_DUMP_OUTPUTS: Dump each layer outputs when running vLLM inference.
VLLM_DUMP_OUTPUTS = _check_env("VLLM_DUMP_OUTPUTS", default=False)
# VLLM_DUMP_CPU_INFO: Get cpu info when running vLLM inference.
VLLM_DUMP_CPU_INFO = _check_env("VLLM_DUMP_CPU_INFO", default=False)
# VLLM_DUMP_MLU_INFO: Get device info when running vLLM inference.
VLLM_DUMP_MLU_INFO = _check_env("VLLM_DUMP_MLU_INFO", default=False)
# VLLM_SCHEDULER_PROFILE: Profiling vLLM scheduler.
VLLM_SCHEDULER_PROFILE = _check_env("VLLM_SCHEDULER_PROFILE", default=False)
# VLLM_GRAPH_DEBUG: Debug the graph status when running decoder, default value is True.
# Set to False to disable warning messages.
VLLM_GRAPH_DEBUG = _check_env("VLLM_GRAPH_DEBUG", default=True)
# CHUNKED_PIPELINE_PARALLEL_EN: use chunked pipeline parallel, default value is False.
CHUNKED_PIPELINE_PARALLEL_EN = _check_env("CHUNKED_PIPELINE_PARALLEL_EN", default=False)
# CONTEXT_PARALLEL_EN: use context parallel, default value is False.
CONTEXT_PARALLEL_EN = _check_env("CONTEXT_PARALLEL_EN", default=False)
# EXPERT_PARALLEL_EN: use expert parallel, default value is False.
EXPERT_PARALLEL_EN = _check_env("EXPERT_PARALLEL_EN", default=False)
# Derived flags combining the raw env switches above:
# latency debug is active if either variant (with or without device) is on.
VLLM_LATENCY_DEBUG_EN = (VLLM_LATENCY_DEBUG or VLLM_LATENCY_DEBUG_NO_DEVICE)
# device-side latency debug only when NOT explicitly restricted to no-device mode.
VLLM_LATENCY_DEBUG_WITH_DEVICE_EN = (VLLM_LATENCY_DEBUG and not VLLM_LATENCY_DEBUG_NO_DEVICE)
# CPU/MLU info dumps are only meaningful while device-side latency debug is on.
VLLM_DUMP_CPU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_CPU_INFO)
VLLM_DUMP_MLU_INFO_EN = (VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and VLLM_DUMP_MLU_INFO)
# Any of the custom parallel modes requires the custom vLLM hijack path.
CUSTOM_VLLM_HIJACK_EN = (CHUNKED_PIPELINE_PARALLEL_EN or CONTEXT_PARALLEL_EN or EXPERT_PARALLEL_EN)
# VLLM_PRELOAD_SIZE: integer-valued knob read at import time, 0 when unset.
VLLM_PRELOAD_SIZE = _check_env_value("VLLM_PRELOAD_SIZE", default=0)
# ATTN_PARALLEL_NUM & FFN_PARALLEL_NUM: use context comm cmpt parallel.
# These hold the env-var *names* (checked by check_context_comm_cmpt_parallel),
# not their values.
ATTN_PARALLEL_NUM = 'ATTN_PARALLEL_NUM'
FFN_PARALLEL_NUM = 'FFN_PARALLEL_NUM'
# this class is used by layers, add BlockSizeInfo to get BLOCKSIZE in model/layer
class BlockSizeInfo:
    """Class-level holder for the configured block size.

    BLOCK_SIZE stays -1 until set_block_size() is called; layers read it
    directly from the class.
    """

    BLOCK_SIZE = -1

    @classmethod
    def set_block_size(cls, a: int):
        """Validate and record the block size (-1 requests the default)."""
        if not USE_PAGED:
            # Non-paged mode: any explicit size is accepted; -1 means 2048.
            cls.BLOCK_SIZE = a if a != -1 else 2048
            return
        if a not in (-1, 16):
            raise ValueError("BLOCKSIZE other than 16 are not supported in paged mode, please check '--block-size' value.")
        cls.BLOCK_SIZE = 16
def check_context_comm_cmpt_parallel():
    """Return True when either context comm/cmpt parallel env var is set."""
    env = os.environ
    return any(name in env for name in (ATTN_PARALLEL_NUM, FFN_PARALLEL_NUM))
def set_is_prompt(flag):
    # Record the current phase in the module-level global IS_PROMPT,
    # read back via get_is_prompt(). Presumably True during the prompt
    # (prefill) phase and False during decode — confirm against callers.
    global IS_PROMPT
    IS_PROMPT=flag
def get_is_prompt():
    # Return the flag last stored by set_is_prompt().
    # NOTE(review): IS_PROMPT is never initialized at module level, so calling
    # this before set_is_prompt() raises NameError — confirm callers always
    # set it first.
    return IS_PROMPT