################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import multiprocessing
import os
import sys
import threading
import time
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Callable, Optional
import psutil
import torch
from packaging.version import InvalidVersion, Tuple, Version
from torch.library import Library
import vllm_br.envs as envs
from vllm import utils
from vllm_br.platform import SUPAPlatform
def adapt_patch():
    """No-op placeholder for patch hooks; kept for interface compatibility."""
def vllm_version_is(target_vllm_version: str):
    """Return True iff the effective vllm version equals *target_vllm_version*.

    The env override ``VLLM_VERSION`` wins over the installed package's
    ``__version__``. Raises ValueError when the effective version string is
    not PEP 440 parseable (e.g. a dev build).
    """
    installed = envs.VLLM_VERSION
    if installed is None:
        import vllm
        installed = vllm.__version__
    try:
        return Version(installed) == Version(target_vllm_version)
    except InvalidVersion as e:
        raise ValueError(
            f"Invalid vllm version {installed} found. A dev version of vllm "
            "is installed probably. Set the environment variable VLLM_VERSION "
            "to control it by hand. And please make sure the value follows the "
            "format of x.y.z.") from e
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Optional[list[str]] = None,
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: Optional[str] = None,
    tags: tuple[torch.Tag, ...] = (),
):
    """No-op replacement for vllm's ``direct_register_custom_op``.

    The SUPA backend disables custom-op registration entirely; this stub is
    installed over ``vllm.utils.direct_register_custom_op`` at import time
    (see the module-level patching at the bottom of this file), so the
    signature must stay call-compatible with vllm's original.

    Fix: the original annotated ``tags`` as ``Tuple[torch.Tag, ...]`` with
    ``Tuple`` imported from ``packaging.version`` — an accidental re-export
    of ``typing.Tuple``, not a public API. Use the builtin generic
    ``tuple[...]`` instead, consistent with ``list[str]`` above.
    """
def _apply_bnb_4bit_fake(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
raise NotImplementedError("bnb 4bit is not supported for supa")
def _fused_mul_mat_gguf_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
) -> torch.Tensor:
raise NotImplementedError("fused_mul_mat_gguf is not supported for supa")
def _fused_moe_gguf(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
raise NotImplementedError("fused_moe_gguf is not supported for supa")
def _apply_gguf_embedding_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
raise NotImplementedError("gguf_embedding is not supported for supa")
def thread_safe_singleton(cls):
    """Class decorator: at most one instance per decorated class.

    The first call's arguments win; later calls return the cached instance.
    A single lock serializes construction, so the constructor runs at most
    once even under concurrent first calls.
    """
    cache = {}
    guard = threading.Lock()

    def wrapper(*args, **kwargs):
        with guard:
            try:
                return cache[cls]
            except KeyError:
                instance = cls(*args, **kwargs)
                cache[cls] = instance
                return instance

    return wrapper
@thread_safe_singleton
class CPUSharedMemory:
    """Owner of four POSIX shared-memory segments used for CPU-side
    all-reduce data exchange: two 256 MiB data buffers plus two 128-byte
    flag blocks, all named with a per-instance suffix (typically a PID).

    Decorated with @thread_safe_singleton, so at most one instance exists
    per process. Stale segments from earlier runs are removed before the
    new ones are created; on any creation failure the partially created
    set is released and the exception re-raised.
    """

    def __init__(self, suffix_name: int):
        # Suffix embedded in every segment name (callers pass os.getpid()).
        self.suffix_name = suffix_name
        # Pre-set all handles to None so _cleanup() is safe to call even
        # if creation fails partway through.
        self.shm_for_reduce = None
        self.shm_for_copy_flag = None
        self.shm_for_cpu_rd = None
        self.shm_for_rd_flag = None
        # Drop any leftover segments with the same names from a prior run.
        self._cleanup_shm(suffix_name)
        try:
            # 256 MiB reduce-payload buffer.
            self.shm_for_reduce = shared_memory.SharedMemory(
                create=True,
                size=256 * 1024 * 1024,
                name=f'/vllm_rd_shm_{suffix_name}')
            # 128-byte copy-completion flag block.
            self.shm_for_copy_flag = shared_memory.SharedMemory(
                create=True, size=128, name=f'/vllm_rd_cp_flag_{suffix_name}')
            # Second 256 MiB buffer for the CPU reduce path.
            self.shm_for_cpu_rd = shared_memory.SharedMemory(
                create=True,
                size=256 * 1024 * 1024,
                name=f'/vllm_rd_shm_rd_{suffix_name}')
            # 128-byte reduce-ready flag block.
            self.shm_for_rd_flag = shared_memory.SharedMemory(
                create=True, size=128, name=f'/vllm_rd_rd_flag_{suffix_name}')
        except FileExistsError as e:
            # A segment appeared between cleanup and creation; release what
            # we did create and propagate.
            print(f"cpu shm create failed, file exists {str(e)}",
                  file=sys.stderr)
            self._cleanup()
            raise
        except Exception as e:
            print(f"cpu shm create failed: {str(e)}", file=sys.stderr)
            self._cleanup()
            raise

    def _cleanup_shm(self, suffix_name):
        """Detach (and, in the main process, unlink) any pre-existing
        segments carrying *suffix_name*; missing segments are ignored."""
        shm_names = [
            f'/vllm_rd_shm_{suffix_name}', f'/vllm_rd_cp_flag_{suffix_name}',
            f'/vllm_rd_shm_rd_{suffix_name}', f'/vllm_rd_rd_flag_{suffix_name}'
        ]
        for name in shm_names:
            try:
                existing_shm = shared_memory.SharedMemory(name=name)
                existing_shm.close()
                # Only the main process removes the backing object, so
                # worker processes never yank segments out from under it.
                if multiprocessing.current_process().name == 'MainProcess':
                    existing_shm.unlink()
                print(f"free shm: {name}", file=sys.stderr)
            except FileNotFoundError:
                continue
            except Exception as e:
                # Best effort: report and keep cleaning the remaining names.
                print(f"free shm {name} failed : {str(e)}", file=sys.stderr)

    def _cleanup(self):
        """free resources"""
        shm_objects = [
            self.shm_for_reduce, self.shm_for_copy_flag, self.shm_for_cpu_rd,
            self.shm_for_rd_flag
        ]
        for shm in shm_objects:
            if shm is None:
                # Never created (or already cleaned) — nothing to release.
                continue
            try:
                shm.close()
                # Unlink only from the main process, mirroring _cleanup_shm.
                if multiprocessing.current_process().name == 'MainProcess':
                    shm.unlink()
                print(f"free shm: {shm.name}", file=sys.stderr)
            except Exception as e:
                print(f"free shm {shm.name} failed: {str(e)}", file=sys.stderr)
        # clean
        self.shm_for_reduce = None
        self.shm_for_copy_flag = None
        self.shm_for_cpu_rd = None
        self.shm_for_rd_flag = None

    def __del__(self):
        # Last-resort release; idempotent because _cleanup() nulls handles.
        self._cleanup()
# Process-wide handle to the CPU all-reduce shared memory; None until created.
cpu_shared_memory: Optional[CPUSharedMemory] = None


def create_cpu_all_reduce_shared_mem():
    """Create the process-wide CPUSharedMemory, keyed by this process's PID."""
    global cpu_shared_memory
    cpu_shared_memory = CPUSharedMemory(os.getpid())
def get_cpu_all_reduce_shared_mem() -> Optional[CPUSharedMemory]:
    """Return the process-wide CPUSharedMemory, or None if not yet created.

    Fix: dropped the needless ``global`` statement — reading a module-level
    name never requires one; ``global`` is only needed for assignment.
    """
    return cpu_shared_memory
def get_grandparent_pid():
    """Return the PID of this process's grandparent, or None if it has none
    or the lookup fails (failures are printed, never raised)."""
    try:
        parent = psutil.Process().parent()
        if parent is None:
            return None
        grandparent = parent.parent()
        return None if grandparent is None else grandparent.pid
    except psutil.NoSuchProcess:
        print("Parent or grandparent process does not exist")
        return None
    except Exception as e:
        print(f"Failed to get grandparent PID: {e}")
        return None
GiB_bytes = 1 << 30  # number of bytes in one GiB (2**30)
@dataclass
class SUPAMemorySnapshot:
    """Point-in-time record of SUPA device memory usage, in bytes.

    With ``auto_measure=True`` the snapshot is taken immediately on
    construction via measure(); otherwise all fields keep the values
    passed in (default 0). Subtracting two snapshots yields a field-wise
    delta snapshot.
    """
    torch_peak: int = 0
    free_memory: int = 0
    total_memory: int = 0
    cuda_memory: int = 0
    torch_memory: int = 0
    non_torch_memory: int = 0
    timestamp: float = 0.0
    auto_measure: bool = False

    def __post_init__(self):
        # Take the snapshot right away when requested.
        if self.auto_measure:
            self.measure()

    def measure(self):
        """Populate every field from the current device state."""
        self.torch_peak = SUPAPlatform.get_memory_stats(
            "supa", "allocated_bytes.all.peak")
        free, total = torch.supa.mem_get_info()
        self.free_memory = free
        self.total_memory = total
        # Memory in use on the device = total minus free.
        self.cuda_memory = total - free
        self.torch_memory = torch.supa.memory_reserved()
        # Whatever the device holds beyond torch's reservation.
        self.non_torch_memory = self.cuda_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: "SUPAMemorySnapshot") -> "SUPAMemorySnapshot":
        """Field-wise difference; the result never auto-measures."""
        numeric_fields = ("torch_peak", "free_memory", "total_memory",
                          "cuda_memory", "torch_memory", "non_torch_memory",
                          "timestamp")
        deltas = {
            name: getattr(self, name) - getattr(other, name)
            for name in numeric_fields
        }
        return SUPAMemorySnapshot(auto_measure=False, **deltas)
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
raise NotImplementedError("_dequant_mxfp4 is not supported for supa")
def _quant_dequant_mxfp4(x: torch.Tensor,
scale_calculation_mode: str = "even") -> torch.Tensor:
raise NotImplementedError("_quant_dequant_mxfp4 is not supported for supa")
# Import-time monkey-patching: replace vllm's custom-op registration and the
# quantization kernels that have no SUPA implementation with the stubs above.
# Each stub raises NotImplementedError if a model actually tries to use it.
utils.direct_register_custom_op = direct_register_custom_op
torch.ops.vllm.apply_bnb_4bit = _apply_bnb_4bit_fake
torch.ops.vllm._fused_mul_mat_gguf = _fused_mul_mat_gguf_fake
torch.ops.vllm._fused_moe_gguf = _fused_moe_gguf
torch.ops.vllm._apply_gguf_embedding = _apply_gguf_embedding_fake
torch.ops.vllm.dequant_mxfp4 = _dequant_mxfp4
torch.ops.vllm.quant_dequant_mxfp4 = _quant_dequant_mxfp4
# NOTE(review): the patch below is disabled; presumably superseded by the
# utils-level patch above — confirm before deleting.
#import vllm.model_executor.layers.quantization.bitsandbytes
#vllm.model_executor.layers.quantization.bitsandbytes.direct_register_custom_op = direct_register_custom_op