291 lines
9.3 KiB
Python
291 lines
9.3 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
import multiprocessing
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass
|
|
from multiprocessing import shared_memory
|
|
from typing import Callable, Optional
|
|
|
|
import psutil
|
|
import torch
|
|
from packaging.version import InvalidVersion, Tuple, Version
|
|
from torch.library import Library
|
|
|
|
import vllm_br.envs as envs
|
|
from vllm import utils
|
|
from vllm_br.platform import SUPAPlatform
|
|
|
|
|
|
def adapt_patch():
    """Entry point for applying vllm adaptation patches; currently a no-op."""
|
|
|
|
|
|
def vllm_version_is(target_vllm_version: str):
    """Return True iff the effective vllm version equals *target_vllm_version*.

    The VLLM_VERSION environment override (via envs) takes precedence over the
    installed package's ``__version__``.

    Raises:
        ValueError: if the effective version string is not valid PEP 440
            (typically a dev build) -- set VLLM_VERSION to work around it.
    """
    installed = envs.VLLM_VERSION
    if installed is None:
        # Fall back to the installed package; imported lazily so the override
        # path works even without querying the package.
        import vllm

        installed = vllm.__version__
    try:
        return Version(installed) == Version(target_vllm_version)
    except InvalidVersion as err:
        raise ValueError(
            f"Invalid vllm version {installed} found. A dev version of vllm "
            "is installed probably. Set the environment variable VLLM_VERSION "
            "to control it by hand. And please make sure the value follows the "
            "format of x.y.z.") from err
|
|
|
|
|
|
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Optional[list[str]] = None,
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: Optional[str] = None,
    tags: tuple[torch.Tag, ...] = (),
):
    """No-op replacement for vllm's custom-op registration on SUPA.

    Signature mirrors ``vllm.utils.direct_register_custom_op`` (it is patched
    over it at the bottom of this module); registration itself is skipped.

    Note: the ``tags`` annotation uses the builtin ``tuple[...]`` generic,
    consistent with ``list[str]`` above, instead of ``Tuple`` imported from
    ``packaging.version`` -- that name is not a public export of that module.
    """
    pass
|
|
|
|
|
|
def _apply_bnb_4bit_fake(
|
|
x: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
offsets: torch.Tensor,
|
|
out: torch.Tensor,
|
|
) -> None:
|
|
raise NotImplementedError("bnb 4bit is not supported for supa")
|
|
|
|
|
|
def _fused_mul_mat_gguf_fake(
|
|
x: torch.Tensor,
|
|
qweight: torch.Tensor,
|
|
qweight_type: int,
|
|
) -> torch.Tensor:
|
|
raise NotImplementedError("fused_mul_mat_gguf is not supported for supa")
|
|
|
|
|
|
def _fused_moe_gguf(
|
|
x: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
qweight_type: int,
|
|
qweight_type2: int,
|
|
activation: str,
|
|
) -> torch.Tensor:
|
|
raise NotImplementedError("fused_moe_gguf is not supported for supa")
|
|
|
|
|
|
def _apply_gguf_embedding_fake(
|
|
x: torch.Tensor,
|
|
qweight: torch.Tensor,
|
|
qweight_type: int,
|
|
hidden_size: int,
|
|
dtype: Optional[torch.dtype] = None,
|
|
) -> torch.Tensor:
|
|
raise NotImplementedError("gguf_embedding is not supported for supa")
|
|
|
|
|
|
def thread_safe_singleton(cls):
    """Class decorator caching a single instance of *cls* under a lock.

    The first call constructs the instance with the supplied arguments; every
    later call returns that same instance (later arguments are ignored).

    The wrapper is run through ``functools.wraps`` so the decorated callable
    keeps the class's ``__name__``/``__doc__`` instead of showing up as an
    anonymous ``wrapper`` in logs and debuggers.
    """
    from functools import wraps  # local import keeps the decorator self-contained

    instances = {}
    lock = threading.Lock()

    @wraps(cls)
    def wrapper(*args, **kwargs):
        # Lock serializes first-time construction across threads.
        with lock:
            if cls not in instances:
                instances[cls] = cls(*args, **kwargs)
            return instances[cls]

    return wrapper
|
|
|
|
|
|
@thread_safe_singleton
class CPUSharedMemory:
    """Owner of the named POSIX shared-memory segments used for CPU all-reduce.

    Four segments are created, all keyed by ``suffix_name``: two 256 MiB data
    buffers and two 128-byte flag buffers. Stale segments with the same names
    are removed first. Decorated with @thread_safe_singleton, so at most one
    instance exists per process.
    """

    def __init__(self, suffix_name: int):
        # suffix_name is embedded in every segment name; callers pass a PID
        # (see create_cpu_all_reduce_shared_mem), giving per-process names.
        self.suffix_name = suffix_name
        # Pre-set to None so _cleanup() is safe even if creation fails early.
        self.shm_for_reduce = None
        self.shm_for_copy_flag = None
        self.shm_for_cpu_rd = None
        self.shm_for_rd_flag = None

        # Drop leftovers from a previous run that used the same suffix.
        self._cleanup_shm(suffix_name)
        try:
            self.shm_for_reduce = shared_memory.SharedMemory(
                create=True,
                size=256 * 1024 * 1024,
                name=f'/vllm_rd_shm_{suffix_name}')
            self.shm_for_copy_flag = shared_memory.SharedMemory(
                create=True, size=128, name=f'/vllm_rd_cp_flag_{suffix_name}')
            self.shm_for_cpu_rd = shared_memory.SharedMemory(
                create=True,
                size=256 * 1024 * 1024,
                name=f'/vllm_rd_shm_rd_{suffix_name}')
            self.shm_for_rd_flag = shared_memory.SharedMemory(
                create=True, size=128, name=f'/vllm_rd_rd_flag_{suffix_name}')
        except FileExistsError as e:
            # Segment already exists despite the pre-cleanup (e.g. a
            # concurrent creator); release anything we made and re-raise.
            print(f"cpu shm create failed, file exists {str(e)}",
                  file=sys.stderr)
            self._cleanup()
            raise
        except Exception as e:
            print(f"cpu shm create failed: {str(e)}", file=sys.stderr)
            self._cleanup()
            raise

    def _cleanup_shm(self, suffix_name):
        # Best-effort removal of stale segments left behind by a previous
        # process using the same suffix. Missing segments are skipped.
        shm_names = [
            f'/vllm_rd_shm_{suffix_name}', f'/vllm_rd_cp_flag_{suffix_name}',
            f'/vllm_rd_shm_rd_{suffix_name}', f'/vllm_rd_rd_flag_{suffix_name}'
        ]
        for name in shm_names:
            try:
                existing_shm = shared_memory.SharedMemory(name=name)
                existing_shm.close()
                # Only the main process unlinks, so worker processes never
                # destroy a segment that others may still have mapped.
                if multiprocessing.current_process().name == 'MainProcess':
                    existing_shm.unlink()
                    print(f"free shm: {name}", file=sys.stderr)
            except FileNotFoundError:
                continue
            except Exception as e:
                print(f"free shm {name} failed : {str(e)}", file=sys.stderr)

    def _cleanup(self):
        """Close (and, in the main process only, unlink) all owned segments."""
        shm_objects = [
            self.shm_for_reduce, self.shm_for_copy_flag, self.shm_for_cpu_rd,
            self.shm_for_rd_flag
        ]

        for shm in shm_objects:
            if shm is None:
                continue
            try:
                shm.close()
                if multiprocessing.current_process().name == 'MainProcess':
                    shm.unlink()
                    print(f"free shm: {shm.name}", file=sys.stderr)
            except Exception as e:
                print(f"free shm {shm.name} failed: {str(e)}", file=sys.stderr)
        # Drop references so a repeated cleanup (e.g. __del__ after an
        # explicit _cleanup) becomes a no-op.
        self.shm_for_reduce = None
        self.shm_for_copy_flag = None
        self.shm_for_cpu_rd = None
        self.shm_for_rd_flag = None

    def __del__(self):
        # Ensure segments are released when the singleton is collected.
        self._cleanup()
|
|
|
|
|
|
# Process-wide singleton handle; populated by create_cpu_all_reduce_shared_mem().
cpu_shared_memory: Optional[CPUSharedMemory] = None


def create_cpu_all_reduce_shared_mem():
    """Create the CPU all-reduce shared-memory segments for this process.

    Segment names are keyed by the current PID; the result is stored in the
    module-level ``cpu_shared_memory``.
    """
    global cpu_shared_memory
    cpu_shared_memory = CPUSharedMemory(os.getpid())


def get_cpu_all_reduce_shared_mem() -> Optional[CPUSharedMemory]:
    """Return the process-wide CPUSharedMemory, or None if not yet created."""
    # Reading a module-level name needs no ``global`` declaration; the
    # redundant one in the original is dropped.
    return cpu_shared_memory
|
|
|
|
|
|
def get_grandparent_pid():
    """Return the PID of this process's grandparent, or None if unavailable.

    Returns None (instead of raising) when the parent or grandparent no
    longer exists, e.g. after the process has been reparented.
    """
    try:
        current_process = psutil.Process()
        parent_process = current_process.parent()
        if parent_process is None:
            return None
        grandparent_process = parent_process.parent()
        return grandparent_process.pid if grandparent_process else None
    except psutil.NoSuchProcess:
        # Diagnostics go to stderr, consistent with the rest of this module.
        print("Parent or grandparent process does not exist", file=sys.stderr)
        return None
    except Exception as e:
        print(f"Failed to get grandparent PID: {e}", file=sys.stderr)
        return None
|
|
|
|
|
|
# Number of bytes in one gibibyte (2**30).
GiB_bytes = 1 << 30
|
|
|
|
|
|
@dataclass
class SUPAMemorySnapshot:
    """Point-in-time snapshot of SUPA device-memory usage.

    All memory counters are in bytes. Construct with ``auto_measure=True`` to
    sample immediately; subtracting two snapshots yields a field-wise delta
    (with ``timestamp`` becoming the elapsed seconds between them).
    """

    torch_peak: int = 0        # peak bytes allocated by the torch allocator
    free_memory: int = 0       # free device memory from mem_get_info()
    total_memory: int = 0      # total device memory from mem_get_info()
    cuda_memory: int = 0       # total_memory - free_memory: bytes in use
    torch_memory: int = 0      # bytes reserved by torch's caching allocator
    non_torch_memory: int = 0  # cuda_memory - torch_memory
    timestamp: float = 0.0     # time.time() when the sample was taken
    auto_measure: bool = False # if True, measure() runs in __post_init__

    def __post_init__(self):
        # Optionally take the measurement as part of construction.
        if self.auto_measure:
            self.measure()

    def measure(self):
        """Sample the current device-memory counters into this snapshot."""
        # Peak bytes allocated, per the torch allocator stats key
        # "allocated_bytes.all.peak".
        self.torch_peak = SUPAPlatform.get_memory_stats(
            "supa", "allocated_bytes.all.peak")

        self.free_memory, self.total_memory = torch.supa.mem_get_info()
        self.cuda_memory = self.total_memory - self.free_memory

        self.torch_memory = torch.supa.memory_reserved()

        # Device memory in use but not reserved by torch (presumably the
        # driver/runtime and other libraries -- not verifiable from here).
        self.non_torch_memory = self.cuda_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: "SUPAMemorySnapshot") -> "SUPAMemorySnapshot":
        """Return the field-wise difference ``self - other`` as a new snapshot."""
        return SUPAMemorySnapshot(
            torch_peak=self.torch_peak - other.torch_peak,
            free_memory=self.free_memory - other.free_memory,
            total_memory=self.total_memory - other.total_memory,
            cuda_memory=self.cuda_memory - other.cuda_memory,
            torch_memory=self.torch_memory - other.torch_memory,
            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
            timestamp=self.timestamp - other.timestamp,
            auto_measure=False,
        )
|
|
|
|
|
|
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
|
float_dtype: torch.dtype) -> torch.Tensor:
|
|
raise NotImplementedError("_dequant_mxfp4 is not supported for supa")
|
|
|
|
|
|
def _quant_dequant_mxfp4(x: torch.Tensor,
|
|
scale_calculation_mode: str = "even") -> torch.Tensor:
|
|
raise NotImplementedError("_quant_dequant_mxfp4 is not supported for supa")
|
|
|
|
|
|
# Install the SUPA adaptations at import time: replace vllm's custom-op
# registration helper with the no-op version above, and point the vllm custom
# torch ops for unsupported quantization schemes (bnb 4bit, gguf, mxfp4) at
# placeholders that raise NotImplementedError, so those paths fail explicitly.
utils.direct_register_custom_op = direct_register_custom_op
torch.ops.vllm.apply_bnb_4bit = _apply_bnb_4bit_fake
torch.ops.vllm._fused_mul_mat_gguf = _fused_mul_mat_gguf_fake
torch.ops.vllm._fused_moe_gguf = _fused_moe_gguf
torch.ops.vllm._apply_gguf_embedding = _apply_gguf_embedding_fake
torch.ops.vllm.dequant_mxfp4 = _dequant_mxfp4
torch.ops.vllm.quant_dequant_mxfp4 = _quant_dequant_mxfp4
|
|
|
|
#import vllm.model_executor.layers.quantization.bitsandbytes
|
|
#vllm.model_executor.layers.quantization.bitsandbytes.direct_register_custom_op = direct_register_custom_op
|