Files
2026-03-10 13:31:25 +08:00

121 lines
4.2 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
from typing import Any, Callable, Dict
import pybrml
import torch
import torch_br
# The begin-* and end-* markers here are used by the documentation generator
# to extract the env var definitions.
# begin-env-vars-definition
def check_allreduce_available():
    """Derive the default fused all-reduce setting from device 0's P2P links.

    Initializes BRML, then counts how many device indices (every index
    reported by ``brmlDeviceGetCount``, index 0 itself included) have a P2P
    status ``type`` equal to ``P2P_DIRECT_LINK_TYPE`` relative to device 0.
    BRML is always shut down again, even if one of the queries raises.

    Returns:
        int: 4 when the count is 3, 8 when the count is 4, otherwise 1
        (including the case where no device is reported).
    """
    P2P_DIRECT_LINK_TYPE = 2
    # Number of direct links seen from device 0 -> returned setting.
    link_count_to_all_reduce = {3: 4, 4: 8}

    pybrml.brmlInit()
    try:
        device_count = pybrml.brmlDeviceGetCount()
        if device_count <= 0:
            # No device to probe: fall back to the default instead of
            # failing on a missing first row.
            return 1

        # Only device 0's links decide the result, so the other rows of the
        # device-by-device link matrix never need to be queried.
        first_dev = pybrml.brmlDeviceGetHandleByIndex(0)
        all_reduce_count = sum(
            pybrml.brmlDeviceGetP2PStatus_v3(
                first_dev,
                pybrml.brmlDeviceGetHandleByIndex(j)).type ==
            P2P_DIRECT_LINK_TYPE for j in range(device_count))
    finally:
        # Pair every successful brmlInit() with a shutdown.
        pybrml.brmlShutdown()

    return link_count_to_all_reduce.get(all_reduce_count, 1)
# Probed once, when this module is imported. The value is only used as the
# fallback for VLLM_BR_USE_FUSED_ALLREDUCE when that variable is not set.
_VLLM_BR_USE_FUSED_ALLREDUCE_CACHE = check_allreduce_available()
# Maps each supported environment variable name to a zero-argument callable
# that reads and parses it. Nothing is cached here: the environment is read
# again on every call.
env_variables: Dict[str, Callable[[], Any]] = {
    # Raw string, or None when unset.
    "VLLM_VERSION":
    lambda: os.getenv("VLLM_VERSION", None),
    # NOTE(review): returns the raw string when set (so "0" is truthy) and
    # False when unset -- confirm that callers expect this.
    "VLLM_BR_USE_PAGED_ATTN":
    lambda: os.getenv("VLLM_BR_USE_PAGED_ATTN", False),
    # Raw string; defaults to "NUMA".
    "VLLM_BR_WEIGHT_TYPE":
    lambda: os.getenv("VLLM_BR_WEIGHT_TYPE", "NUMA"),
    # Raw string; defaults to "INT8".
    "VLLM_BR_QUANT_METHOD":
    lambda: os.getenv("VLLM_BR_QUANT_METHOD", "INT8"),
    # Integer; defaults to the value probed at import time.
    "VLLM_BR_USE_FUSED_ALLREDUCE":
    lambda: int(
        os.getenv("VLLM_BR_USE_FUSED_ALLREDUCE",
                  _VLLM_BR_USE_FUSED_ALLREDUCE_CACHE)),
    # Boolean parsed from an integer string; defaults to False.
    "VLLM_BR_EMBEDDING_S0B":
    lambda: bool(int(os.getenv("VLLM_BR_EMBEDDING_S0B", False))),
    # MoE (DeepSeek)
    "VLLM_BR_STATIC_MOE_DECODER_MAX_LEN":
    lambda: int(os.getenv("VLLM_BR_STATIC_MOE_DECODER_MAX_LEN", "256")),
    # NOTE: following are device properties
    # Integer. The device is only queried when the variable is unset, so an
    # explicit override never triggers a device property lookup.
    "VLLM_BR_DEVICE_SPC_NUM":
    lambda: int(
        os.environ["VLLM_BR_DEVICE_SPC_NUM"]
        if "VLLM_BR_DEVICE_SPC_NUM" in os.environ else torch_br.supa.
        get_device_properties(torch.device("supa")).max_compute_units),
    # Integer; defaults to 32.
    "VLLM_BR_DEVICE_WARP_SIZE":
    lambda: int(os.getenv("VLLM_BR_DEVICE_WARP_SIZE", 32)),
    # Integer; defaults to 0.
    "VLLM_BR_USE_CPU_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_CPU_ALL_REDUCE", 0)),
    # Raw string path; defaults to the path below.
    "VLLM_SCCL_SO_PATH":
    lambda: os.getenv(
        "VLLM_SCCL_SO_PATH",
        "/usr/local/birensupa/base/latest/succl/lib/x86_64-linux-gnu/libsuccl.so"
    ),
    # Boolean parsed from an integer string; defaults to False.
    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
    lambda: bool(int(os.getenv("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", False))),
    # Boolean parsed from an integer string; defaults to False.
    "VLLM_PP_CPU_SEND_RECV":
    lambda: bool(int(os.getenv("VLLM_PP_CPU_SEND_RECV", False))),
    # Integer; defaults to 0.
    "VLLM_BR_USE_FP32_ALL_REDUCE":
    lambda: int(os.getenv("VLLM_BR_USE_FP32_ALL_REDUCE", 0)),
    # Boolean flag. Unset, empty, "0", "false", "no" and "off"
    # (case-insensitive) mean False; any other value means True. A plain
    # bool() on the raw string would treat "0" as enabled.
    "VLLM_BR_USE_MROPE_0_9_2":
    lambda: os.getenv("VLLM_BR_USE_MROPE_0_9_2", "0").strip().lower() not in
    ("", "0", "false", "no", "off"),
    # Boolean parsed from an integer string; defaults to False.
    "VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE":
    lambda: bool(int(os.getenv("VLLM_BR_ENABLE_TP_GROUPS_IN_SUPERNODE", "0"))),
}
# end-env-vars-definition
def __getattr__(name: str):
    """Resolve ``name`` as an environment variable on module attribute access.

    The matching getter in ``env_variables`` is invoked on every lookup, so
    the value is computed at access time rather than at import time.

    Raises:
        AttributeError: if ``name`` is not a key of ``env_variables``.
    """
    if name not in env_variables:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # lazy evaluation of environment variables
    getter = env_variables[name]
    return getter()
def __dir__():
    """Return the names of all entries in ``env_variables`` as a list."""
    return [*env_variables]