#include "shm_worker.h"

// Attaches to the shared-memory segment named by get_shm_name().
//
// The segment is opened without O_CREAT, so it has to exist already.
// Throws std::runtime_error if the segment cannot be opened or mapped.
ShmWorker::ShmWorker() {
  gpu_slot = -1;
  worker_id = -1;

  const std::string shm_name = get_shm_name();
  const int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666);
  if (shm_fd == -1) {
    spdlog::error(
        "Failed to open shared memory. Maybe the daemon is not started.");
    throw std::runtime_error("Failed to open shared memory");
  }

  void *ptr =
      mmap(nullptr, SHM_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
  // A mapping stays valid after its descriptor is closed, so release the
  // descriptor right away; doing it before the error check keeps the failure
  // path from leaking it.
  close(shm_fd);
  if (ptr == MAP_FAILED) {
    spdlog::error("Failed to map shared memory");
    throw std::runtime_error("Failed to map shared memory");
  }

  shm_helper = static_cast<ShmHelper *>(ptr);
}

// Stops the heart-beat thread (if one was started) and unmaps the segment.
ShmWorker::~ShmWorker() {
  stop_heart_beat.store(true, std::memory_order_release);
  // The thread only exists after a successful register_worker(); joining a
  // non-joinable std::thread throws, which would terminate from a destructor.
  if (heart_beat_thread.joinable()) {
    heart_beat_thread.join();
  }
  munmap(shm_helper, SHM_SIZE);
}

// Registers this process as a worker and starts the heart-beat thread.
//
// device_id            - local device index, must be in [0, MAX_DEVICES).
// out_shareable_handle - receives a copy of xpu_mem_handle[gpu_slot].
// out_vmem_size        - receives vmem_size[gpu_slot].
//
// Returns true on success. Returns false if an output pointer is null, if
// the worker is already registered, if registration through shared memory
// fails, or if the device's PCI address is not present in the shared table.
// Throws std::runtime_error if device_id is out of range.
bool ShmWorker::register_worker(int device_id,
                                XPUIpcMemHandle *out_shareable_handle,
                                uint64_t *out_vmem_size) {
  if (device_id < 0 || device_id >= MAX_DEVICES) {
    spdlog::error("Invalid device ID {}", device_id);
    throw std::runtime_error("Invalid device ID");
  }
  if (out_shareable_handle == nullptr || out_vmem_size == nullptr) {
    spdlog::error("register_worker called with a null output pointer");
    return false;
  }
  // Assigning a new std::thread over a joinable one calls std::terminate, so
  // a repeated registration is rejected instead.
  if (heart_beat_thread.joinable()) {
    spdlog::error("Worker {} is already registered", worker_id);
    return false;
  }

  const int slot = register_worker_shm();
  if (slot < 0) {
    spdlog::error("Failed to register as worker");
    return false;
  }

  // get GPU info
  const uint32_t gpu_pci_addr = get_device_pci_addr(device_id);
  for (int i = 0; i < MAX_DEVICES; ++i) {
    if (shm_helper->gpu_pci_addr[i] == gpu_pci_addr) {
      this->gpu_slot = i;
      break;
    }
  }
  if (this->gpu_slot == -1) {
    spdlog::error("GPU with PCI address {:x} not found in manager",
                  gpu_pci_addr);
    return false;
  }

  stop_heart_beat.store(false, std::memory_order_release);
  heart_beat_thread = std::thread(&ShmWorker::heart_beat_loop, this, slot);

  memcpy(out_shareable_handle, &shm_helper->xpu_mem_handle[this->gpu_slot],
         sizeof(XPUIpcMemHandle));
  *out_vmem_size = shm_helper->vmem_size[this->gpu_slot];
  return true;
}

// Body of the heart-beat thread.
//
// Until stop_heart_beat is set, each iteration checks that heart_beats[slot]
// still carries this worker's id, re-registers when it does not, writes the
// current heartbeat_ts_us() value into the slot and sleeps heartbeat_us
// microseconds.
//
// NOTE(review): worker_id is read here and rewritten by register_worker_shm()
// on this thread, while the locking functions read it as well; if those run
// on another thread and worker_id is not atomic this is a data race - confirm
// against the declaration in shm_worker.h.
void ShmWorker::heart_beat_loop(int slot) {
  while (!stop_heart_beat.load(std::memory_order_acquire)) {
    // update heart beat
    int32_t shm_worker_id =
        shm_helper->heart_beats[slot].worker_id.load(std::memory_order_acquire);
    if (shm_worker_id != worker_id) {
      spdlog::error("Maybe bug: Heart beat slot {} worker_id mismatch (local: "
                    "{}, shm: {})",
                    slot, worker_id, shm_worker_id);
      // re-register
      slot = register_worker_shm();
      if (slot < 0) {
        spdlog::error("Failed to re-register as worker");
        // An exception leaving a std::thread entry function ends in
        // std::terminate, so this is effectively a fatal abort.
        throw std::runtime_error("Failed to re-register as worker");
      }
    }

    uint64_t now = heartbeat_ts_us();
    shm_helper->heart_beats[slot].timestamp.store(now,
                                                  std::memory_order_release);
    usleep(heartbeat_us);
  }
}

// Makes one non-blocking attempt to take the lock stored in
// gpu_flag[gpu_slot].
//
// gpu_slot is used as an array index and is only assigned by
// register_worker(), so that call has to have succeeded first.
//
// out_self_hold - set to true when the worker-id field of the flag observed
//                 before the attempt already equals this worker's id (either
//                 the lock was already held by this worker, or the free flag
//                 still carried its id); set to false otherwise, including on
//                 failure.
//
// Returns true if this worker holds the lock when the call returns.
bool ShmWorker::try_lock_gpu(bool &out_self_hold) {
  // Counts consecutive failed attempts so the "still waiting" message is only
  // logged every 2000th failure. Function-local static: one counter shared by
  // all callers.
  static int retry_cnt = 0;

  uint64_t old_flag =
      shm_helper->gpu_flag[gpu_slot].load(std::memory_order_acquire);

  if (unpack_lock_field(old_flag) == 0) { // free
    uint64_t new_flag = pack_locked_worker_id(worker_id);
    // Strong CAS: there is no tight retry loop around this call, and a
    // spurious failure of the weak form would cost lock_gpu() a needless
    // sleep although the lock was free.
    if (shm_helper->gpu_flag[gpu_slot].compare_exchange_strong(
            old_flag, new_flag, std::memory_order_acq_rel,
            std::memory_order_acquire)) {
      spdlog::info("Worker {} acquired GPU {} lock", worker_id, gpu_slot);
      int32_t prev_worker_id = unpack_worker_id_field(old_flag);
      out_self_hold = prev_worker_id == worker_id;
      retry_cnt = 0;
      return true;
    }
    // On failure the CAS has stored the value it observed into old_flag, so
    // the log below reports the current flag contents.
  } else { // locked
    if (unpack_worker_id_field(old_flag) == worker_id) {
      spdlog::info("Worker {} already holds the GPU {} lock", worker_id,
                   gpu_slot);
      out_self_hold = true;
      retry_cnt = 0;
      return true;
    }
  }

  // failed
  if (++retry_cnt % 2000 == 0) {
    spdlog::info("Worker {} trying to acquire GPU {} lock, current lock holder "
                 "is worker {}",
                 worker_id, gpu_slot, unpack_worker_id_field(old_flag));
  }
  out_self_hold = false;
  return false;
}

// Blocks until try_lock_gpu() succeeds, sleeping 1000 us between attempts.
// There is no timeout. Always returns true; out_self_hold is as documented
// on try_lock_gpu().
bool ShmWorker::lock_gpu(bool &out_self_hold) {
  while (true) {
    if (try_lock_gpu(out_self_hold)) {
      return true;
    }
    // failed
    usleep(1000);
  }
}

// Releases the lock in gpu_flag[gpu_slot] when its worker-id field names this
// worker; otherwise only logs and leaves the flag untouched.
void ShmWorker::unlock_gpu() {
  uint64_t old_flag =
      shm_helper->gpu_flag[gpu_slot].load(std::memory_order_acquire);
  if (unpack_worker_id_field(old_flag) != worker_id) {
    spdlog::info("Worker {} does not hold GPU {} lock", worker_id, gpu_slot);
  } else {
    // The unlocked flag is built from this worker's id as well, which is what
    // lets a later try_lock_gpu() report out_self_hold == true.
    uint64_t new_flag = pack_unlocked_worker_id(worker_id);
    shm_helper->gpu_flag[gpu_slot].store(new_flag, std::memory_order_release);
    spdlog::info("Worker {} released GPU {} lock", worker_id, gpu_slot);
  }
}

// Sends one request through the single request slot in shared memory and
// waits for the answer.
//
// req_ready is driven through these states by this function:
//   NO_REQUEST -> PREPARING_REQUEST -> REQUEST_READY
//   ... wait for REQUEST_PROCESSED ... -> NO_REQUEST
// The REQUEST_PROCESSED transition is made elsewhere (presumably by the
// daemon - confirm). Both waits poll every 1000 us and have no timeout.
//
// Returns the value found in request.response once the request is processed.
uint64_t ShmWorker::make_request(uint32_t type, uint64_t parameter) {
  while (true) {
    // Plain load first so the CAS is only attempted when the slot looks free.
    if (shm_helper->req_ready.load(std::memory_order_acquire) ==
        ShmHelper::READY_STATE_NO_REQUEST) {
      uint64_t expected = ShmHelper::READY_STATE_NO_REQUEST;
      // Strong CAS: a spurious failure would otherwise cost a 1 ms sleep.
      if (shm_helper->req_ready.compare_exchange_strong(
              expected, ShmHelper::READY_STATE_PREPARING_REQUEST,
              std::memory_order_acq_rel, std::memory_order_acquire)) {
        break;
      }
    }
    usleep(1000);
  }

  // prepare request
  shm_helper->request.type = type;
  shm_helper->request.worker_id = worker_id;
  shm_helper->request.parameter = parameter;

  // Publish the request; the release store orders the field writes above
  // before the state change.
  shm_helper->req_ready.store(ShmHelper::READY_STATE_REQUEST_READY,
                              std::memory_order_release);

  // wait until processed
  while (shm_helper->req_ready.load(std::memory_order_acquire) !=
         ShmHelper::READY_STATE_REQUEST_PROCESSED) {
    usleep(1000);
  }

  // get response
  uint64_t response = shm_helper->request.response;

  // Hand the slot back so other requesters can proceed.
  shm_helper->req_ready.store(ShmHelper::READY_STATE_NO_REQUEST,
                              std::memory_order_release);
  return response;
}

// Issues REQUEST_TYPE_REGISTER_WORKER and stores the worker id taken from the
// response in this->worker_id.
//
// The response carries the slot in its upper 32 bits and the worker id in its
// lower 32 bits. Returns the slot; callers treat a negative slot as failure.
int ShmWorker::register_worker_shm() {
  // Parameter is all ones (the original passed -1, converted to uint64_t).
  uint64_t resp = make_request(ShmHelper::REQUEST_TYPE_REGISTER_WORKER,
                               static_cast<uint64_t>(-1));

  // response = slot | worker_id
  int32_t slot = static_cast<int32_t>(resp >> 32);
  int32_t new_worker_id = static_cast<int32_t>(resp & 0xFFFFFFFF);

  spdlog::info("Registered as worker {} & slot {}", new_worker_id, slot);
  this->worker_id = new_worker_id;
  return slot;
}