Files

217 lines
8.2 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
from dataclasses import dataclass
from typing import Any
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
# Module-level logger named after this module.
logger = init_logger(__name__)
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
# ctypes stand-ins for the cudart types of the same name; both are
# represented as a plain C int in the function signatures declared below.
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
    """ctypes mirror of cudart's opaque ``cudaIpcMemHandle_t`` struct.

    The handle is treated as an opaque byte blob; Python code never
    interprets its contents, it only passes the structure to and from the
    runtime library.
    """

    # NOTE(review): the buffer is 128 bytes here; the CUDA headers appear to
    # define the handle as 64 bytes, so this looks deliberately oversized --
    # confirm before changing, since the size is part of the struct layout.
    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
@dataclass
class Function:
    """Declarative description of one C function to bind through ctypes."""

    # Symbol name to look up in the shared library.
    name: str
    # ctypes type assigned to the bound function's ``restype``.
    restype: Any
    # ctypes types assigned, in order, to the bound function's ``argtypes``.
    argtypes: list[Any]
def find_loaded_library(lib_name) -> str | None:
"""
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of the
a loaded library.
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found = True
break
if not found:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = line.index("/")
path = line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), (
f"Unexpected filename: {filename} for library {lib_name}"
)
return path
class CudaRTLibrary:
    """Pure-Python (ctypes) binding for a few CUDA runtime functions.

    The functions listed in `exported_functions` are looked up in the shared
    library and exposed as methods of the same name. When
    `current_platform.is_rocm()` is true, the HIP symbol from
    `cuda_to_hip_mapping` is bound instead, while the Python-facing name
    stays the CUDA one.

    Loaded libraries and their bound function tables are cached at class
    level, keyed by library path, so constructing several instances for the
    same path reuses the same objects.
    """

    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),
        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function(
            "cudaMalloc",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
        ),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function(
            "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
        ),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function(
            "cudaMemcpy",
            cudaError_t,
            [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
        ),
        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function(
            "cudaIpcGetMemHandle",
            cudaError_t,
            [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
        ),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function(
            "cudaIpcOpenMemHandle",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
        ),
    ]

    # https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Runtime_API_functions_supported_by_HIP.html # noqa
    cuda_to_hip_mapping = {
        "cudaSetDevice": "hipSetDevice",
        "cudaDeviceSynchronize": "hipDeviceSynchronize",
        "cudaDeviceReset": "hipDeviceReset",
        "cudaGetErrorString": "hipGetErrorString",
        "cudaMalloc": "hipMalloc",
        "cudaFree": "hipFree",
        "cudaMemset": "hipMemset",
        "cudaMemcpy": "hipMemcpy",
        "cudaIpcGetMemHandle": "hipIpcGetMemHandle",
        "cudaIpcOpenMemHandle": "hipIpcOpenMemHandle",
    }

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

    def __init__(self, so_file: str | None = None):
        """Load the runtime library and bind the exported functions.

        Args:
            so_file: Path of the shared library to load. When ``None``, the
                path is discovered by looking for an already-loaded
                ``libcudart``, then an already-loaded ``libamdhip64``, and
                finally by reading ``envs.VLLM_CUDART_SO_PATH``.

        Raises:
            AssertionError: If no library path could be determined.
        """
        if so_file is None:
            so_file = find_loaded_library("libcudart")
            if so_file is None:
                # libcudart is not loaded in the current process, try hip.
                # If libamdhip64 is not loaded either, the env-var fallback
                # and the final check below take over.
                so_file = find_loaded_library("libamdhip64")
            if so_file is None:
                so_file = envs.VLLM_CUDART_SO_PATH  # fallback to env var
            if so_file is None:
                # Raised explicitly (rather than via `assert`) so the check
                # is still enforced when Python runs with -O.
                raise AssertionError(
                    "libcudart is not loaded in the current process, "
                    "try setting VLLM_CUDART_SO_PATH"
                )

        if so_file not in CudaRTLibrary.path_to_library_cache:
            lib = ctypes.CDLL(so_file)
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]

        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            # The platform does not change while binding, so query it once.
            use_hip = current_platform.is_rocm()
            _funcs = {}
            for func in CudaRTLibrary.exported_functions:
                symbol = (
                    CudaRTLibrary.cuda_to_hip_mapping[func.name]
                    if use_hip
                    else func.name
                )
                f = getattr(self.lib, symbol)
                f.restype = func.restype
                f.argtypes = func.argtypes
                # Always keyed by the CUDA name, whichever symbol was bound.
                _funcs[func.name] = f
            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        """Raise ``RuntimeError`` if ``result`` is a non-zero status code."""
        if result != 0:
            error_str = self.cudaGetErrorString(result)
            raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        """Return the runtime's message for ``error``, decoded as UTF-8."""
        return self.funcs["cudaGetErrorString"](error).decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        """Call ``cudaSetDevice(device)``; raise ``RuntimeError`` on failure."""
        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

    def cudaDeviceSynchronize(self) -> None:
        """Call ``cudaDeviceSynchronize()``; raise ``RuntimeError`` on failure."""
        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

    def cudaDeviceReset(self) -> None:
        """Call ``cudaDeviceReset()``; raise ``RuntimeError`` on failure."""
        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        """Call ``cudaMalloc`` for ``size`` bytes and return the pointer.

        Raises:
            RuntimeError: If the runtime reports a non-zero status.
        """
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
        return devPtr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        """Call ``cudaFree(devPtr)``; raise ``RuntimeError`` on failure."""
        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
        """Call ``cudaMemset(devPtr, value, count)``; raise on failure."""
        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

    def cudaMemcpy(
        self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
    ) -> None:
        """Call ``cudaMemcpy`` for ``count`` bytes from ``src`` to ``dst``.

        The copy kind is always ``cudaMemcpyDefault`` (4); callers cannot
        choose a different kind through this wrapper.

        Raises:
            RuntimeError: If the runtime reports a non-zero status.
        """
        cudaMemcpyDefault = 4
        kind = cudaMemcpyDefault
        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

    def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        """Call ``cudaIpcGetMemHandle`` for ``devPtr`` and return the handle.

        Raises:
            RuntimeError: If the runtime reports a non-zero status.
        """
        handle = cudaIpcMemHandle_t()
        self.CUDART_CHECK(
            self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
        )
        return handle

    def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        """Call ``cudaIpcOpenMemHandle`` on ``handle`` and return the pointer.

        The flags argument is always ``cudaIpcMemLazyEnablePeerAccess`` (1).

        Raises:
            RuntimeError: If the runtime reports a non-zero status.
        """
        cudaIpcMemLazyEnablePeerAccess = 1
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(
            self.funcs["cudaIpcOpenMemHandle"](
                ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
            )
        )
        return devPtr