fix custom_allreduce namespace (#6039)

This commit is contained in:
Xiaoyu Zhang
2025-05-07 10:13:06 +08:00
committed by GitHub
parent 8a828666a3
commit d25398cbc8
4 changed files with 24 additions and 24 deletions

View File

@@ -18,11 +18,11 @@ init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_dat
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
vllm::Signal* ipc_ptrs[8];
sglang::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
ipc_ptrs[i] = reinterpret_cast<sglang::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new vllm::CustomAllreduce(
return (fptr_t) new sglang::CustomAllreduce(
ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
}
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
@@ -98,15 +98,15 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_
}
void dispose(fptr_t _fa) {
delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete reinterpret_cast<sglang::CustomAllreduce*>(_fa);
}
int64_t meta_size() {
return sizeof(vllm::Signal);
return sizeof(sglang::Signal);
}
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
@@ -117,7 +117,7 @@ void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
std::vector<int64_t> bytes(handle.begin(), handle.end());
return std::make_tuple(bytes, offsets);
@@ -126,7 +126,7 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
std::vector<std::string> bytes;
bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {