#include "shm_worker.h"

// === ShmWorker ===

// Opens the shared-memory segment named by get_shm_name() and maps SHM_SIZE
// bytes of it read/write into this process.
//
// Throws std::runtime_error if the segment cannot be opened or mapped.
ShmWorker::ShmWorker() {
  const std::string shm_name = get_shm_name();
  const int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666);
  if (shm_fd == -1) {
    spdlog::error("Failed to open shared memory segment. Maybe the daemon is "
                  "not started.");
    throw std::runtime_error("Failed to open shared memory segment");
  }

  void *ptr =
      mmap(nullptr, SHM_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
  // A mapping stays valid after its descriptor is closed, so release the
  // descriptor right away; this also keeps the failure path below from
  // leaking it.
  close(shm_fd);
  if (ptr == MAP_FAILED) {
    spdlog::error("Failed to map shared memory segment");
    throw std::runtime_error("Failed to map shared memory segment");
  }

  shm_helper = static_cast<decltype(shm_helper)>(ptr);
}

// Stops the heartbeat thread (if one was started) and unmaps the segment.
ShmWorker::~ShmWorker() {
  stop_heart_beat.store(true, std::memory_order_release);
  // The thread only exists after a successful register_worker(); joining a
  // non-joinable std::thread throws, which would terminate the process from
  // inside a destructor.
  if (heart_beat_thread.joinable()) {
    heart_beat_thread.join();
  }
  munmap(shm_helper, SHM_SIZE);
}

// Registers this process (identified by `tgid`) through the shared-memory
// request channel and starts the heartbeat thread for the assigned slot.
//
// On success, copies shm_helper->shareable_handle and
// shm_helper->total_vmem_size into the two out-parameters (both must be
// non-null) and returns true.
//
// Throws std::runtime_error if register_worker_shm() rejects the slot.
bool ShmWorker::register_worker(int32_t tgid, uint64_t *out_shareable_handle,
                                uint64_t *out_vmem_size) {
  // If a heartbeat thread from an earlier registration is still running, stop
  // it first: assigning to a joinable std::thread terminates the process, and
  // that thread reads this->tgid, which is about to be overwritten.
  if (heart_beat_thread.joinable()) {
    stop_heart_beat.store(true, std::memory_order_release);
    heart_beat_thread.join();
  }

  this->tgid = tgid;
  const int slot = register_worker_shm();
  if (slot == -1) {
    // Defensive: register_worker_shm() as written throws instead of
    // returning -1.
    return false;
  }

  *out_shareable_handle = shm_helper->shareable_handle;
  *out_vmem_size = shm_helper->total_vmem_size;

  stop_heart_beat.store(false, std::memory_order_release);
  heart_beat_thread = std::thread(&ShmWorker::heart_beat_loop, this, slot);
  return true;
}

// Body of the heartbeat thread: every `heartbeat_us` microseconds, writes the
// current heartbeat timestamp into heart_beats[slot] until stop_heart_beat is
// set.
//
// If the slot's TGID no longer matches ours, the worker re-registers and
// continues with the newly assigned slot.
//
// NOTE(review): an exception leaving this function (including one thrown by
// register_worker_shm()) escapes the thread entry point and therefore calls
// std::terminate. Confirm that crashing the worker is the intended reaction.
void ShmWorker::heart_beat_loop(int slot) {
  while (!stop_heart_beat.load(std::memory_order_acquire)) {
    const int32_t shm_tgid =
        shm_helper->heart_beats[slot].tgid.load(std::memory_order_acquire);
    if (shm_tgid != tgid) {
      spdlog::error(
          "Maybe bug: Heart beat slot {} TGID mismatch (local: {}, shm: {})",
          slot, tgid, shm_tgid);
      // re-register
      slot = register_worker_shm();
      if (slot == -1) {
        spdlog::error("TGID {} failed to re-register as worker", tgid);
        throw std::runtime_error("Failed to re-register as worker");
      }
    }

    // update heart beat
    const uint64_t now = heartbeat_ts_us();
    shm_helper->heart_beats[slot].timestamp.store(now,
                                                  std::memory_order_release);
    usleep(heartbeat_us);
  }
}

// Spins (sleeping 1 ms between attempts) until the shared gpu_flag is locked
// by this worker's TGID.
//
// Returns true if the TGID field of the flag already named this worker when
// the lock was taken (or the lock was already held by this worker), and
// false if the flag carried a different TGID at acquisition time.
bool ShmWorker::lock_gpu() {
  int retry_cnt = 0;
  uint64_t old_flag = shm_helper->gpu_flag.load(std::memory_order_acquire);
  while (true) {
    if (unpack_lock_field(old_flag) == 0) {
      // Unlocked: try to swing the flag to "locked by us".
      const uint64_t new_flag = pack_locked_tgid(tgid);
      if (shm_helper->gpu_flag.compare_exchange_weak(
              old_flag, new_flag, std::memory_order_acq_rel,
              std::memory_order_acquire)) {
        spdlog::info("TGID {} acquired GPU lock", tgid);
        const int32_t old_tgid = unpack_tgid_field(old_flag);
        return old_tgid == tgid;
      }
    } else if (unpack_tgid_field(old_flag) == tgid) {
      spdlog::info("TGID {} already holds the GPU lock", tgid);
      return true;
    }

    // failed
    ++retry_cnt;
    if (retry_cnt % 1000 == 0) {
      spdlog::info("TGID {} waiting for GPU lock, current lock holder TGID {}",
                   tgid, unpack_tgid_field(old_flag));
    }
    usleep(1000);
    old_flag = shm_helper->gpu_flag.load(std::memory_order_acquire);
  }
}

// Releases the GPU lock if the shared flag's TGID field names this worker;
// otherwise only logs a warning and leaves the flag untouched.
void ShmWorker::unlock_gpu() {
  uint64_t old_flag = shm_helper->gpu_flag.load(std::memory_order_acquire);
  if (unpack_tgid_field(old_flag) != tgid) {
    spdlog::warn("previous gpu flag {} does not match expected locked flag for "
                 "TGID {}. This may be a bug, unless during startup.",
                 old_flag, tgid);
    return;
  }

  // Compare-and-swap rather than a blind store: if the flag changed between
  // the ownership check above and this write (for example another worker
  // locked it after we had already released), we must not overwrite it.
  const uint64_t new_flag = pack_unlocked_tgid(tgid);
  if (shm_helper->gpu_flag.compare_exchange_strong(
          old_flag, new_flag, std::memory_order_release,
          std::memory_order_relaxed)) {
    spdlog::info("TGID {} released GPU lock", tgid);
  } else {
    spdlog::warn("gpu flag changed to {} while TGID {} was releasing the GPU "
                 "lock; leaving it untouched.",
                 old_flag, tgid);
  }
}

// Sends one request through the single shared request slot and returns the
// value read back from request.response.
//
// Protocol on shm_helper->req_ready:
//   NO_REQUEST -> PREPARING_REQUEST (claimed here by CAS)
//              -> REQUEST_READY     (published here)
//              -> REQUEST_PROCESSED (set by whoever serves the request)
//              -> NO_REQUEST        (reset here after reading the response)
//
// Both wait loops poll every 1 ms and have no timeout, so this call blocks
// for as long as the request is not served.
uint64_t ShmWorker::make_request(uint32_t type, uint64_t parameter) {
  while (true) {
    uint64_t expected = ShmHelper::READY_STATE_NO_REQUEST;
    // Cheap load first; only attempt the CAS when the slot looks free.
    if (shm_helper->req_ready.load(std::memory_order_acquire) ==
        ShmHelper::READY_STATE_NO_REQUEST) {
      if (shm_helper->req_ready.compare_exchange_weak(
              expected, ShmHelper::READY_STATE_PREPARING_REQUEST,
              std::memory_order_acq_rel, std::memory_order_acquire)) {
        break;
      }
    }
    usleep(1000);
  }

  // prepare request
  shm_helper->request.type = type;
  shm_helper->request.tgid = tgid;
  shm_helper->request.parameter = parameter;

  // set ready
  shm_helper->req_ready.store(ShmHelper::READY_STATE_REQUEST_READY,
                              std::memory_order_release);

  // wait until processed
  while (shm_helper->req_ready.load(std::memory_order_acquire) !=
         ShmHelper::READY_STATE_REQUEST_PROCESSED) {
    usleep(1000);
  }

  // get response
  const uint64_t response = shm_helper->request.response;

  // hand the slot back
  shm_helper->req_ready.store(ShmHelper::READY_STATE_NO_REQUEST,
                              std::memory_order_release);
  return response;
}

// Issues a REQUEST_TYPE_REGISTER_WORKER request for this TGID and returns
// the slot index from the response.
//
// Throws std::runtime_error if the response is (uint64_t)-1 or is not below
// MAX_WORKERS.
int ShmWorker::register_worker_shm() {
  const uint64_t slot =
      make_request(ShmHelper::REQUEST_TYPE_REGISTER_WORKER, tgid);
  if (slot == static_cast<uint64_t>(-1) || slot >= MAX_WORKERS) {
    spdlog::error("TGID {} failed to register as worker", tgid);
    throw std::runtime_error("Failed to register as worker");
  }
  // Log success only after the slot has been validated.
  spdlog::info("TGID {} registered as worker in slot {}", tgid, slot);
  return static_cast<int>(slot);
}