121 lines
4.2 KiB
Python
121 lines
4.2 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
|
|
import os
|
|
from typing import Any, Callable, Dict
|
|
|
|
import pybrml
|
|
import torch
|
|
import torch_br
|
|
|
|
# The begin-* and end-* markers below are used by the documentation
# generator to extract the env vars defined in this module.
|
|
|
|
|
|
# begin-env-vars-definition
|
|
# Value of the ``type`` field returned by pybrml.brmlDeviceGetP2PStatus_v3()
# that this module treats as a direct peer-to-peer link.
_P2P_DIRECT_LINK_TYPE = 2

# Direct-link count (as seen from device 0) -> fused all-reduce setting.
# Any count not listed here falls back to 1.
_DIRECT_LINK_COUNT_TO_ALL_REDUCE = {3: 4, 4: 8}


def _all_reduce_from_direct_link_count(direct_link_count):
    """Map a direct-P2P-link count to the fused all-reduce setting.

    A count of 3 maps to 4 and a count of 4 maps to 8; every other count
    (including 0) maps to 1.
    """
    return _DIRECT_LINK_COUNT_TO_ALL_REDUCE.get(direct_link_count, 1)


def check_allreduce_available():
    """Probe device 0's P2P links and return the fused all-reduce setting.

    Initialises BRML and counts how many devices report a direct P2P link
    with device 0. The pair (0, 0) is included, as in the original row-0
    sum. The count is mapped through
    ``_all_reduce_from_direct_link_count``. Only device 0 is compared
    against the other devices, because only that row of the topology
    matrix determines the result. BRML is always shut down again, even if
    a query raises.

    NOTE(review): 4 and 8 presumably denote the size of a fully
    direct-linked device group. TODO confirm.

    Returns:
        int: 4 or 8 for the recognised topologies, otherwise 1. Returns 1
        when no devices are reported.

    Raises:
        Any exception raised by pybrml during init or queries is propagated
        (after shutdown, if init succeeded).
    """
    pybrml.brmlInit()
    try:
        device_count = pybrml.brmlDeviceGetCount()
        if device_count <= 0:
            # Nothing to link against; the original would hit IndexError.
            return 1
        first_dev = pybrml.brmlDeviceGetHandleByIndex(0)
        direct_link_count = sum(
            pybrml.brmlDeviceGetP2PStatus_v3(
                first_dev, pybrml.brmlDeviceGetHandleByIndex(j)).type
            == _P2P_DIRECT_LINK_TYPE
            for j in range(device_count))
    finally:
        # Release BRML on both the success and the error paths.
        pybrml.brmlShutdown()
    return _all_reduce_from_direct_link_count(direct_link_count)
|
|
|
|
|
|
# Probe the P2P topology once, when this module is imported. The result is
# the default for VLLM_BR_USE_FUSED_ALLREDUCE in env_variables below.
# NOTE(review): the probe runs, and needs a working pybrml, even when
# VLLM_BR_USE_FUSED_ALLREDUCE is set in the environment.
_VLLM_BR_USE_FUSED_ALLREDUCE_CACHE = check_allreduce_available()
|
|
|
|
def _env_flag_enabled(name: str) -> bool:
    """Return whether the boolean-style environment variable *name* is on.

    The flag is off when the variable is unset, or when its value
    (whitespace stripped, case-insensitive) is "", "0", "false", "no" or
    "off". Any other value counts as enabled.
    """
    value = os.getenv(name)
    if value is None:
        return False
    return value.strip().lower() not in ("", "0", "false", "no", "off")


def _device_spc_num() -> int:
    """Return VLLM_BR_DEVICE_SPC_NUM if set, else the device's compute units.

    The device properties are queried only when the override is unset. Setting
    the variable therefore avoids touching the device at all.
    """
    override = os.getenv("VLLM_BR_DEVICE_SPC_NUM")
    if override is not None:
        return int(override)
    props = torch_br.supa.get_device_properties(torch.device("supa"))
    return int(props.max_compute_units)


# Maps each env var name to a zero-argument factory. The factory is invoked
# on every module attribute access (see __getattr__).
env_variables: Dict[str, Callable[[], Any]] = {
    # Raw string, or None when unset.
    "VLLM_VERSION":
    lambda: os.getenv("VLLM_VERSION", None),
    # NOTE(review): returns the raw string when set (so "0" is truthy) and
    # False when unset. Left as-is because callers may rely on the string.
    "VLLM_BR_USE_PAGED_ATTN":
    lambda: os.getenv("VLLM_BR_USE_PAGED_ATTN", False),
    "VLLM_BR_WEIGHT_TYPE":
    lambda: os.getenv("VLLM_BR_WEIGHT_TYPE", "NUMA"),
    "VLLM_BR_QUANT_METHOD":
    lambda: os.getenv("VLLM_BR_QUANT_METHOD", "INT8"),
    # Defaults to the value probed by check_allreduce_available() at import.
    "VLLM_BR_USE_FUSED_ALLREDUCE":
    lambda: int(
        os.getenv("VLLM_BR_USE_FUSED_ALLREDUCE",
                  _VLLM_BR_USE_FUSED_ALLREDUCE_CACHE)),
    "VLLM_BR_EMBEDDING_S0B":
    lambda: bool(int(os.getenv("VLLM_BR_EMBEDDING_S0B", "0"))),
    # MoE (DeepSeek)
    "VLLM_BR_STATIC_MOE_DECODER_MAX_LEN":
    lambda: int(os.getenv("VLLM_BR_STATIC_MOE_DECODER_MAX_LEN", "256")),
    # NOTE: following are device properties
    "VLLM_BR_DEVICE_SPC_NUM":
    _device_spc_num,
    "VLLM_BR_DEVICE_WARP_SIZE":
    lambda: int(os.getenv("VLLM_BR_DEVICE_WARP_SIZE", 32)),
    "VLLM_BR_USE_CPU_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_CPU_ALL_REDUCE", 0)),
    "VLLM_SCCL_SO_PATH":
    lambda: os.getenv(
        "VLLM_SCCL_SO_PATH",
        "/usr/local/birensupa/base/latest/succl/lib/x86_64-linux-gnu/libsuccl.so"
    ),
    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
    lambda: bool(int(os.getenv("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0"))),
    "VLLM_PP_CPU_SEND_RECV":
    lambda: bool(int(os.getenv("VLLM_PP_CPU_SEND_RECV", "0"))),
    "VLLM_BR_USE_FP32_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_FP32_ALL_REDUCE", 0)),
    # Previously bool(os.getenv(...)), which treated "0"/"false" as enabled.
    "VLLM_BR_USE_MROPE_0_9_2":
    lambda: _env_flag_enabled("VLLM_BR_USE_MROPE_0_9_2"),
    "VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE":
    lambda: bool(int(os.getenv("VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE", "0"))),
}
|
|
|
|
# end-env-vars-definition
|
|
|
|
|
|
def __getattr__(name: str):
    """Resolve module attributes lazily from ``env_variables`` (PEP 562 hook).

    The registered zero-argument factory is invoked on every access; no
    caching is done here.

    Raises:
        AttributeError: if *name* is not a registered env variable.
    """
    factory = env_variables.get(name)
    if factory is None:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}")
    return factory()
|
|
|
|
|
|
def __dir__():
    """Report the registered env-var names as this module's attributes (PEP 562)."""
    return list(env_variables)
|