#include #include #include #include #include #include #include #include #include "vxpu_offload/xpu_helper.h" #include "vxpu_offload/shm_worker.h" namespace py = pybind11; namespace { // vxpu static std::atomic g_initialized(false); static void *g_vmem = nullptr; static size_t g_size = 0; static std::atomic_uint_fast64_t g_allocated_offset(0); ShmWorker *shm_worker = nullptr; static const size_t granularity = 2 * 1024 * 1024; // 2MB // Global references to Python callables // NOTE: this is borrowed reference, so we don't need to DECREF them. // This brings the limitation that the allocator needs to be singleton. // static PyObject* g_python_malloc_callback = nullptr; // static PyObject* g_python_free_callback = nullptr; static py::function g_python_malloc_callback; static py::function g_python_free_callback; extern "C" { void* my_malloc(ssize_t size, int device, cudaStream_t stream) { size_t aligned_size = ((size + granularity - 1) / granularity) * granularity; size_t alloc_offset = g_allocated_offset.fetch_add(aligned_size); if (alloc_offset + aligned_size > g_size) { throw std::runtime_error( "my_malloc ERROR: Out of memory in the reserved pool." + std::string(" ") + __FILE__ + ":" + std::to_string(__LINE__)); } void *d_ptr = (void *)((uintptr_t)g_vmem + alloc_offset); if (!g_python_malloc_callback) { throw std::runtime_error( "my_malloc ERROR: g_python_malloc_callback is not callable." + std::string(" ") + __FILE__ + ":" + std::to_string(__LINE__)); } else { py::gil_scoped_acquire gil; unsigned long long fake_handle = 0; auto handle_tuple = std::make_tuple( (unsigned long long)device, (unsigned long long)aligned_size, (unsigned long long)d_ptr, (unsigned long long)fake_handle); g_python_malloc_callback(handle_tuple); } return d_ptr; } void my_free(void *ptr, ssize_t size, int device, cudaStream_t stream) { if (!g_python_free_callback) { throw std::runtime_error( "my_free ERROR: g_python_free_callback is not callable." + std::string(" ") + __FILE__ + ":" + std::to_string(__LINE__)); } else { py::gil_scoped_acquire gil; py::object result = g_python_free_callback((unsigned long long)ptr); // nothing to do } } } // extern "C" void init_module(py::function malloc_cb, py::function free_cb, int device_id) { g_python_malloc_callback = malloc_cb; g_python_free_callback = free_cb; // init vxpu if (g_initialized.load()) { return; } g_initialized.store(true); shm_worker = new ShmWorker(); XPUIpcMemHandle mem_handle; bool res = shm_worker->register_worker(device_id, &mem_handle, &g_size); if (!res) { throw std::runtime_error( "init_module ERROR: Failed to register shm worker." + std::string(" ") + __FILE__ + ":" + std::to_string(__LINE__)); } // open mem handle int ret = xpu_ipc_open_memhandle(&g_vmem, mem_handle, 1); if (ret != XPU_SUCCESS) { throw std::runtime_error( "init_module ERROR: xpu_ipc_open_memhandle failed." + std::string(" ") + __FILE__ + ":" + std::to_string(__LINE__)); } } void create_and_map(unsigned long long device, size_t size, uintptr_t p_mem, uint64_t handle) { return; } void unmap_and_release(unsigned long long device, uintptr_t p_mem, size_t size, uint64_t handle) { return; } void my_xpu_memcpy(uintptr_t dst, uintptr_t src, uint64_t sz, int kind) { XPUMemcpyKind memcpy_kind = static_cast(kind); int ret = xpu_memcpy((void *)dst, (const void *)src, sz, memcpy_kind); if (ret != XPU_SUCCESS) { throw std::runtime_error("my_xpu_memcpy ERROR: xpu_memcpy failed." + std::string(" ") + __FILE__ + ":" + std::to_string(__LINE__)); } } std::tuple get_mem_info() { size_t allocated_bytes = g_allocated_offset.load(); size_t free_mem = 0; if (allocated_bytes >= g_size) { free_mem = 0; } else { free_mem = g_size - allocated_bytes; } return std::make_tuple(free_mem, g_size); } std::tuple try_lock_gpu() { bool prev_is_self = false; bool success = shm_worker->try_lock_gpu(prev_is_self); return std::make_tuple(success, prev_is_self); } void unlock_gpu() { shm_worker->unlock_gpu(); } } // namespace PYBIND11_MODULE(_kunlun_vxpu, m) { m.def("init_module", &init_module, py::arg("malloc_cb"), py::arg("free_cb"), py::arg("device_id")); m.def("create_and_map", &create_and_map); m.def("unmap_and_release", &unmap_and_release); m.def("my_xpu_memcpy", &my_xpu_memcpy); m.def("get_mem_info", &get_mem_info); m.def("try_lock_gpu", &try_lock_gpu); m.def("unlock_gpu", &unlock_gpu); }