Files
xc-llm-ascend/csrc/idle_offload/shm_worker.cpp
2025-12-26 07:37:35 +00:00

159 lines
5.3 KiB
C++

#include "shm_worker.h"
// === ShmWorker ===
// Attaches this worker to an already existing shared memory segment.
//
// Opens the segment named by get_shm_name() read/write (no O_CREAT, so the
// segment must already exist), maps SHM_SIZE bytes of it as a shared mapping
// and stores the mapping in shm_helper.
//
// Throws std::runtime_error if the segment cannot be opened or mapped.
ShmWorker::ShmWorker() {
  const std::string shm_name = get_shm_name();
  const int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666);
  if (shm_fd == -1) {
    spdlog::error("Failed to open shared memory segment. Maybe the daemon is "
                  "not started.");
    throw std::runtime_error("Failed to open shared memory segment");
  }
  void *ptr =
      mmap(nullptr, SHM_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
  // A successful mapping stays valid after its descriptor is closed, so the
  // descriptor is released before the failure check. Closing it only on the
  // success path would leak it whenever mmap fails and we throw.
  close(shm_fd);
  if (ptr == MAP_FAILED) {
    spdlog::error("Failed to map shared memory segment");
    throw std::runtime_error("Failed to map shared memory segment");
  }
  shm_helper = static_cast<ShmHelper *>(ptr);
}
// Stops the heartbeat thread (if one is running) and unmaps the shared
// memory segment.
ShmWorker::~ShmWorker() {
  // Ask heart_beat_loop() to leave its loop on the next iteration.
  stop_heart_beat.store(true, std::memory_order_release);
  // In this file the thread is only started by a successful
  // register_worker(). Joining a thread that was never started throws
  // std::system_error, which would terminate the process when raised from a
  // destructor, so only join when there is something to join.
  if (heart_beat_thread.joinable()) {
    heart_beat_thread.join();
  }
  munmap(shm_helper, SHM_SIZE);
}
// Registers this process as a worker and starts its heartbeat thread.
//
// tgid                  identifier stored in this object and sent with the
//                       registration request.
// out_shareable_handle  receives shm_helper->shareable_handle.
// out_vmem_size         receives shm_helper->total_vmem_size.
//
// Returns true once the heartbeat thread has been started, false if
// register_worker_shm() reports slot -1. register_worker_shm() may also throw
// std::runtime_error on failure.
bool ShmWorker::register_worker(int32_t tgid, uint64_t *out_shareable_handle,
                                uint64_t *out_vmem_size) {
  // Assigning to a std::thread that is still joinable calls std::terminate,
  // and the running heartbeat thread reads this->tgid. Shut down any
  // heartbeat thread left over from a previous call before touching either.
  if (heart_beat_thread.joinable()) {
    stop_heart_beat.store(true, std::memory_order_release);
    heart_beat_thread.join();
  }
  this->tgid = tgid;
  int slot = register_worker_shm();
  if (slot == -1) {
    return false;
  }
  *out_shareable_handle = shm_helper->shareable_handle;
  *out_vmem_size = shm_helper->total_vmem_size;
  // Clear the stop flag before the thread starts so the loop actually runs.
  stop_heart_beat.store(false, std::memory_order_release);
  heart_beat_thread = std::thread(&ShmWorker::heart_beat_loop, this, slot);
  return true;
}
// Body of the heartbeat thread: until stop_heart_beat is set, periodically
// writes a fresh timestamp into heart_beats[slot] of the shared segment.
//
// slot  index into shm_helper->heart_beats this worker currently writes to.
//
// If the TGID stored in the slot no longer matches ours, the worker
// re-registers through register_worker_shm() and continues with the returned
// slot.
//
// NOTE(review): the std::runtime_error thrown below (and any exception from
// register_worker_shm()) leaves a std::thread entry function, which terminates
// the process - confirm that this fail-fast behaviour is intended.
void ShmWorker::heart_beat_loop(int slot) {
  for (;;) {
    if (stop_heart_beat.load(std::memory_order_acquire)) {
      return;
    }

    // Check that the slot still belongs to this worker.
    const int32_t slot_owner =
        shm_helper->heart_beats[slot].tgid.load(std::memory_order_acquire);
    const bool slot_is_ours = (slot_owner == tgid);
    if (!slot_is_ours) {
      spdlog::error(
          "Maybe bug: Heart beat slot {} TGID mismatch (local: {}, shm: {})",
          slot, tgid, slot_owner);
      // Ask for a (possibly different) slot and keep beating there.
      slot = register_worker_shm();
      if (slot == -1) {
        spdlog::error("TGID {} failed to re-register as worker", tgid);
        throw std::runtime_error("Failed to re-register as worker");
      }
    }

    // Publish the current time as this worker's latest heartbeat.
    const uint64_t stamp = heartbeat_ts_us();
    shm_helper->heart_beats[slot].timestamp.store(stamp,
                                                  std::memory_order_release);
    usleep(heartbeat_us);
  }
}
bool ShmWorker::lock_gpu() {
int retry_cnt = 0;
uint64_t old_flag = shm_helper->gpu_flag.load(std::memory_order_acquire);
while (true) {
if (unpack_lock_field(old_flag) == 0) {
uint64_t new_flag = pack_locked_tgid(tgid);
if (shm_helper->gpu_flag.compare_exchange_weak(old_flag, new_flag,
std::memory_order_acq_rel,
std::memory_order_acquire)) {
spdlog::info("TGID {} acquired GPU lock", tgid);
int32_t old_tgid = unpack_tgid_field(old_flag);
return old_tgid == tgid;
}
} else {
if (unpack_tgid_field(old_flag) == tgid) {
spdlog::info("TGID {} already holds the GPU lock", tgid);
return true;
}
}
// failed
++retry_cnt;
if (retry_cnt % 1000 == 0) {
spdlog::info(
"TGID {} waiting for GPU lock, current lock holder TGID {}", tgid,
unpack_tgid_field(old_flag));
}
usleep(1000);
old_flag = shm_helper->gpu_flag.load(std::memory_order_acquire);
}
}
// Releases the GPU lock by storing the "unlocked" flag value for this TGID.
//
// If the TGID field of the current flag is not ours, the flag is left
// untouched and only a warning is logged.
void ShmWorker::unlock_gpu() {
  const uint64_t current_flag =
      shm_helper->gpu_flag.load(std::memory_order_acquire);

  const bool flag_is_ours = (unpack_tgid_field(current_flag) == tgid);
  if (!flag_is_ours) {
    spdlog::warn("previous gpu flag {} does not match expected locked flag for "
                 "TGID {}. This may be a bug, unless during startup.",
                 current_flag, tgid);
    return;
  }

  const uint64_t released_flag = pack_unlocked_tgid(tgid);
  shm_helper->gpu_flag.store(released_flag, std::memory_order_release);
  spdlog::info("TGID {} released GPU lock", tgid);
}
// Sends one request through the single request slot in shared memory and
// waits for its response.
//
// type       request type written to shm_helper->request.type.
// parameter  request argument written to shm_helper->request.parameter.
//
// Returns the value read from shm_helper->request.response.
//
// req_ready is stepped through NO_REQUEST -> PREPARING_REQUEST ->
// REQUEST_READY by this function; it then polls until the value becomes
// REQUEST_PROCESSED and finally resets it to NO_REQUEST. All waits poll once
// per millisecond.
uint64_t ShmWorker::make_request(uint32_t type, uint64_t parameter) {
  // Claim the request slot: move NO_REQUEST -> PREPARING_REQUEST.
  for (;;) {
    const bool slot_idle =
        shm_helper->req_ready.load(std::memory_order_acquire) ==
        ShmHelper::READY_STATE_NO_REQUEST;
    if (slot_idle) {
      uint64_t expected_state = ShmHelper::READY_STATE_NO_REQUEST;
      const bool claimed = shm_helper->req_ready.compare_exchange_weak(
          expected_state, ShmHelper::READY_STATE_PREPARING_REQUEST,
          std::memory_order_acq_rel, std::memory_order_acquire);
      if (claimed) {
        break;
      }
    }
    usleep(1000);
  }

  // Fill in the request while we own the slot.
  shm_helper->request.type = type;
  shm_helper->request.tgid = tgid;
  shm_helper->request.parameter = parameter;

  // Publish it; the release store makes the fields above visible first.
  shm_helper->req_ready.store(ShmHelper::READY_STATE_REQUEST_READY,
                              std::memory_order_release);

  // Poll until the request has been marked as processed.
  for (;;) {
    const bool processed =
        shm_helper->req_ready.load(std::memory_order_acquire) ==
        ShmHelper::READY_STATE_REQUEST_PROCESSED;
    if (processed) {
      break;
    }
    usleep(1000);
  }

  // Copy the response out before handing the slot back.
  const uint64_t result = shm_helper->request.response;
  shm_helper->req_ready.store(ShmHelper::READY_STATE_NO_REQUEST,
                              std::memory_order_release);
  return result;
}
// Sends a REQUEST_TYPE_REGISTER_WORKER request carrying this worker's TGID
// and returns the slot index found in the response.
//
// Returns the slot as an int; it is always below MAX_WORKERS.
// Throws std::runtime_error if the response is (uint64_t)-1 or is not below
// MAX_WORKERS.
int ShmWorker::register_worker_shm() {
  const uint64_t slot =
      make_request(ShmHelper::REQUEST_TYPE_REGISTER_WORKER, tgid);
  if (slot == static_cast<uint64_t>(-1) || slot >= MAX_WORKERS) {
    spdlog::error("TGID {} failed to register as worker", tgid);
    throw std::runtime_error("Failed to register as worker");
  }
  // Report success only after validation, so a failed registration does not
  // leave a misleading "registered ... in slot <garbage>" line in the log.
  spdlog::info("TGID {} registered as worker in slot {}", tgid, slot);
  return static_cast<int>(slot);
}