Files
enginex-mlu590-vllm/csrc/cnmem_allocator.cpp

310 lines
11 KiB
C++
Raw Permalink Normal View History

2026-04-24 09:50:34 +08:00
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
// An MLU PluggableAllocator built on the cn_api driver APIs.
#include <iostream>
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <cn_api.h>
// DRV_CHECK_GET_RETURN(args...) expands to the token `return`: the IMPL macro
// always selects its second argument, which is the literal `return` when zero
// or one argument is passed in.  CN_CHECK emits its optional trailing argument
// right after it, producing `return ;` or e.g. `return nullptr;`.
#define DRV_CHECK_GET_RETURN(...) \
  DRV_CHECK_GET_RETURN_IMPL(__VA_ARGS__, return, )
#define DRV_CHECK_GET_RETURN_IMPL(_1, _2, ...) _2
// CN_CHECK(call[, retval]): evaluate a cn_api call; if it yields a non-zero
// CNresult, log the driver's error string with file/line to stderr and return
// from the enclosing function (with `retval` when one is given).
//  - The status variable has a deliberately obscure name so it cannot shadow
//    a name used inside `return_code` (CN_CHECK(f(rc)) would read itself).
//  - error_str starts as nullptr and falls back to a fixed message in case
//    cnGetErrorString does not recognise the code, so an uninitialised
//    pointer is never printed.
#define CN_CHECK(return_code, ...) \
  do { \
    CNresult cn_check_rc_ = (return_code); \
    if (cn_check_rc_) { \
      const char *error_str = nullptr; \
      if (cnGetErrorString(cn_check_rc_, &error_str) || !error_str) { \
        error_str = "unknown CNresult"; \
      } \
      std::cerr << "Error: " << error_str \
                << " at " << __FILE__ \
                << ":" << __LINE__ \
                << std::endl; \
      DRV_CHECK_GET_RETURN(__VA_ARGS__) \
      __VA_ARGS__; \
    } \
  } while (0)
// Python callables installed by init_module(); both are unset (nullptr) until
// then and are invoked by my_malloc / my_free respectively.
static PyObject* g_python_malloc_callback{nullptr};
static PyObject* g_python_free_callback{nullptr};
// ---------------------------------------------------------------------------
// Helper functions:
// Make sure the calling thread has a current cn_api context.  If none is
// bound, a new context is created on `device` and made current.  Driver
// errors are reported by CN_CHECK, which then returns early.
void ensure_context(CNdev device) {
  CNcontext current_ctx;
  CN_CHECK(cnCtxGetCurrent(&current_ctx));
  if (current_ctx) {
    return;  // a context is already bound to this thread; nothing to do
  }
  CN_CHECK(cnCtxCreate(&current_ctx, 0, device));
  CN_CHECK(cnCtxSetCurrent(current_ctx));
}
// Back the already-reserved virtual range [d_mem, d_mem + size) on `device`
// with physical memory:
//   1. create a physical allocation (handle written to *p_memHandle),
//   2. map it at d_mem,
//   3. grant the device read/write access.
// If step 2 or 3 fails, the earlier steps are undone so the physical
// allocation is not leaked.  Errors are logged by CN_CHECK; the function has
// no return value, so callers cannot observe a failure.
void create_and_map(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) {
  ensure_context(device);
  // Physical allocation properties: uncompressed device memory located on
  // `device`, with no exportable OS handle requested.
  CNmemAllocationProp prop = {};
  // The allocation type must currently be CN_MEM_ALLOCATION_TYPE_DEFAULT
  // according to the cndrv developer guide.
  prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT;
  prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = device;
  prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE;
  prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE;

  // Step 1: physical allocation.  Nothing to undo if this fails.
  CN_CHECK(cnMemCreate(p_memHandle, size, &prop, 0));

  // Step 2: map into the reserved range.  CN_CHECK's early `return false`
  // lands in the lambda, letting us roll back instead of leaking the handle.
  const bool mapped = [&]() -> bool {
    CN_CHECK(cnMemMap(d_mem, size, 0, *p_memHandle, 0), false);
    return true;
  }();
  if (!mapped) {
    CN_CHECK(cnMemRelease(*p_memHandle));
    return;
  }

  // Step 3: enable device read/write access; on failure undo steps 2 and 1.
  CNmemAccessDesc accessDesc = {};
  accessDesc.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
  accessDesc.location.id = device;
  accessDesc.accessFlags = CN_MEM_ACCESS_FLAGS_PROT_READWRITE;
  const bool accessible = [&]() -> bool {
    CN_CHECK(cnMemSetAccess(d_mem, size, &accessDesc, 1), false);
    return true;
  }();
  if (!accessible) {
    CN_CHECK(cnMemUnmap(d_mem, size));
    CN_CHECK(cnMemRelease(*p_memHandle));
  }
}
// Reverse of create_and_map: unmap [d_mem, d_mem + size) and then release the
// physical allocation behind *p_memHandle.  The virtual reservation itself is
// left in place.  If the unmap fails, CN_CHECK returns before the release.
void unmap_and_release(CNdev device, ssize_t size, CNaddr d_mem, CNmemGenericAllocationHandle* p_memHandle) {
  ensure_context(device);
  CN_CHECK(cnMemUnmap(d_mem, size));
  const CNmemGenericAllocationHandle handle = *p_memHandle;
  CN_CHECK(cnMemRelease(handle));
}
// Build the Python tuple (a, b, c, d) of ints from four unsigned 64-bit values.
// Returns a new reference, or NULL with a Python exception set on failure.
// Py_BuildValue checks every intermediate PyLong, so a tuple with NULL slots
// can never escape (filling PyTuple_New() slot by slot did not check this).
// The caller must hold the GIL.
PyObject* create_tuple_from_c_integers(unsigned long long a,
                                       unsigned long long b,
                                       unsigned long long c,
                                       unsigned long long d) {
  // "K" is the unsigned long long format unit; the parentheses force a tuple.
  return Py_BuildValue("(KKKK)", a, b, c, d);
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
__attribute__ ((visibility("default"))) void* my_malloc(ssize_t size, int device, CNqueue stream) {
ensure_context(device);
// first allocation, align the size, and reserve an address, and also allocate
// a CNmemGenericAllocationHandle
// Define memory allocation properties
CNmemAllocationProp prop = {};
// The memory allocation type requested, which must be CN_MEM_ALLOCATION_TYPE_DEFAULT currently according to cndrv developer guide.
prop.type = CN_MEM_ALLOCATION_TYPE_DEFAULT; //CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = CN_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.requestedHandleTypes = CN_MEM_HANDLE_TYPE_NONE;
prop.allocFlags.compressionType = CN_MEM_ALLOCATION_COMP_NONE;
//Check if the allocation is supported
size_t granularity;
CN_CHECK(cnMemGetAllocationGranularity(&granularity, &prop, CN_MEM_ALLOC_GRANULARITY_MINIMUM), nullptr);
size_t alignedSize = ((size+granularity-1)/granularity)*granularity;
CNaddr d_mem;
CN_CHECK(cnMemAddressReserve(&d_mem, alignedSize, 0, 0, 0), nullptr);
// allocate the CNmemGenericAllocationHandle
CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)malloc(sizeof(CNmemGenericAllocationHandle));
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
return nullptr;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
// Call g_python_malloc_callback
PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
Py_DECREF(arg_tuple);
if (!py_result) {
PyErr_Print();
PyGILState_Release(gstate);
return nullptr;
}
PyGILState_Release(gstate);
// do the final mapping
create_and_map(device, alignedSize, d_mem, p_memHandle);
return (void*)d_mem;
}
__attribute__ ((visibility("default"))) void my_free(void* ptr, ssize_t size, int device, CNqueue stream) {
// get memory handle from the pointer
if (!g_python_free_callback) {
std::cerr << "ERROR: g_python_free_callback not set.\n";
return;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* py_ptr = PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return;
}
PyGILState_Release(gstate);
// Free memory
CNaddr d_mem = (CNaddr)recv_d_mem;
CNmemGenericAllocationHandle* p_memHandle = (CNmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
//free address and the handle
CN_CHECK(cnMemAddressFree(d_mem, size));
free(p_memHandle);
}
// ---------------------------------------------------------------------------
// Python extension boilerplate:
// Python-exposed function: init_module(python_malloc, python_free)
// Python: init_module(python_malloc, python_free)
// Registers the two callables used by my_malloc / my_free.  Raises TypeError
// if either argument is not callable; returns None on success.
static PyObject* py_init_module(PyObject* self, PyObject* args) {
  PyObject* malloc_callback = nullptr;
  PyObject* free_callback = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
    return nullptr;
  }
  if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
    PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
    return nullptr;
  }
  // The module holds its own strong reference to each callable, so they stay
  // alive even if the caller drops theirs.  The new references are installed
  // before the old ones are dropped: Py_XDECREF can run finalizer code, and
  // the globals must not point at a freed object while that happens.
  PyObject* old_malloc_callback = g_python_malloc_callback;
  PyObject* old_free_callback = g_python_free_callback;
  Py_INCREF(malloc_callback);  // "O" guarantees non-NULL borrowed references
  Py_INCREF(free_callback);
  g_python_malloc_callback = malloc_callback;
  g_python_free_callback = free_callback;
  Py_XDECREF(old_malloc_callback);
  Py_XDECREF(old_free_callback);
  Py_RETURN_NONE;
}
// Python: python_unmap_and_release(device, size, d_mem, p_memHandle)
// Unmaps the range and releases its physical memory via unmap_and_release().
// The address reservation and the handle slot are kept.  Raises TypeError
// unless exactly four integer arguments are given.
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
  const bool is_four_tuple = args && PyTuple_Check(args) && PyTuple_Size(args) == 4;
  if (!is_four_tuple) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }
  unsigned long long device_id, n_bytes, va, handle_addr;
  if (!PyArg_ParseTuple(args, "KKKK", &device_id, &n_bytes, &va, &handle_addr)) {
    return nullptr;  // PyArg_ParseTuple has already set the exception
  }
  unmap_and_release(device_id, n_bytes, (CNaddr)va,
                    (CNmemGenericAllocationHandle*)handle_addr);
  Py_RETURN_NONE;
}
// Python: python_create_and_map(device, size, d_mem, p_memHandle)
// Re-creates and maps physical memory into an existing reservation via
// create_and_map().  Raises TypeError unless exactly four integer arguments
// are given.
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
  if (args == nullptr || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
    return nullptr;
  }
  // fields = {device, size, d_mem, handle-slot address}
  unsigned long long fields[4];
  const int ok = PyArg_ParseTuple(args, "KKKK", &fields[0], &fields[1],
                                  &fields[2], &fields[3]);
  if (!ok) {
    return nullptr;  // exception already set by PyArg_ParseTuple
  }
  create_and_map(fields[0], fields[1], (CNaddr)fields[2],
                 (CNmemGenericAllocationHandle*)fields[3]);
  Py_RETURN_NONE;
}
// Python: python_cn_memcpy(dst, src, bytes)
// Synchronous cnMemcpy of `bytes` bytes from device address `src` to `dst`.
// Raises TypeError on bad arguments.  Raises RuntimeError carrying the
// driver's error string if the copy fails (previously NULL was returned with
// no exception set, which surfaces as SystemError).
static PyObject* python_cn_memcpy(PyObject* self, PyObject* args){
  if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 3) {
    PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 3");
    return nullptr;
  }
  // "K" writes an unsigned long long.  Parse into exactly that type and
  // convert afterwards, rather than handing PyArg_ParseTuple pointers to
  // CNaddr / cn_uint64_t, whose underlying type may differ.
  unsigned long long recv_dst, recv_src, recv_bytes;
  if (!PyArg_ParseTuple(args, "KKK", &recv_dst, &recv_src, &recv_bytes)) {
    // PyArg_ParseTuple sets an error if it fails
    return nullptr;
  }
  const CNaddr dst = (CNaddr)recv_dst;
  const CNaddr src = (CNaddr)recv_src;
  const cn_uint64_t bytes = (cn_uint64_t)recv_bytes;
  const CNresult rc = cnMemcpy(dst, src, bytes);
  if (rc) {
    const char* error_str = nullptr;
    if (cnGetErrorString(rc, &error_str) || !error_str) {
      error_str = "unknown CNresult";
    }
    PyErr_Format(PyExc_RuntimeError, "cnMemcpy failed: %s", error_str);
    return nullptr;
  }
  Py_RETURN_NONE;
}
// Method table exposed by the module.  Each entry gives:
//   {Python name, C implementation, calling convention, docstring}
// All entries take positional arguments only (METH_VARARGS).  The table is
// terminated by an all-null sentinel.
static PyMethodDef module_methods[] = {
    {"init_module", py_init_module, METH_VARARGS,
     "Initialize module with python_malloc and python_free callables."},
    {"python_create_and_map", python_create_and_map, METH_VARARGS,
     "Create and map memory on the device."},
    {"python_unmap_and_release", python_unmap_and_release, METH_VARARGS,
     "Unmap and release memory on the device."},
    {"python_cn_memcpy", python_cn_memcpy, METH_VARARGS,
     "Copies data from source address to destination address."},
    {nullptr, nullptr, 0, nullptr},  // sentinel
};
// Module definition.  m_size is -1: the module keeps its state in file-level
// globals (the registered callbacks) rather than in per-module state.
static struct PyModuleDef cnmem_allocator_module = {
    PyModuleDef_HEAD_INIT,
    "cnmem_allocator",                                      // m_name
    "cnapi-mem-based allocator for MLUPluggableAllocator",  // m_doc
    -1,                                                     // m_size
    module_methods,                                         // m_methods
};
// Extension entry point.  PyModule_Create returns NULL (with an exception set)
// on failure, so its result can be handed straight back to the interpreter.
PyMODINIT_FUNC PyInit_vllm_mlu_C(void) {
  return PyModule_Create(&cnmem_allocator_module);
}
} // extern "C"