fix custom_allreduce namespace (#6039)
This commit is contained in:
@@ -18,11 +18,11 @@ init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_dat
|
|||||||
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
|
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||||
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
|
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
|
||||||
|
|
||||||
vllm::Signal* ipc_ptrs[8];
|
sglang::Signal* ipc_ptrs[8];
|
||||||
for (int i = 0; i < world_size; i++) {
|
for (int i = 0; i < world_size; i++) {
|
||||||
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
|
ipc_ptrs[i] = reinterpret_cast<sglang::Signal*>(fake_ipc_ptrs[i]);
|
||||||
}
|
}
|
||||||
return (fptr_t) new vllm::CustomAllreduce(
|
return (fptr_t) new sglang::CustomAllreduce(
|
||||||
ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
|
ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
|
|||||||
* copied into _reg_buffer.
|
* copied into _reg_buffer.
|
||||||
*/
|
*/
|
||||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||||
|
|
||||||
@@ -98,15 +98,15 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_
|
|||||||
}
|
}
|
||||||
|
|
||||||
void dispose(fptr_t _fa) {
|
void dispose(fptr_t _fa) {
|
||||||
delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
delete reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t meta_size() {
|
int64_t meta_size() {
|
||||||
return sizeof(vllm::Signal);
|
return sizeof(sglang::Signal);
|
||||||
}
|
}
|
||||||
|
|
||||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
||||||
void* ipc_ptrs[8];
|
void* ipc_ptrs[8];
|
||||||
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
|
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
|
||||||
@@ -117,7 +117,7 @@ void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
|||||||
|
|
||||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
|
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
|
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||||
std::vector<int64_t> bytes(handle.begin(), handle.end());
|
std::vector<int64_t> bytes(handle.begin(), handle.end());
|
||||||
return std::make_tuple(bytes, offsets);
|
return std::make_tuple(bytes, offsets);
|
||||||
@@ -126,7 +126,7 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
|
|||||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||||
void register_graph_buffers(
|
void register_graph_buffers(
|
||||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
std::vector<std::string> bytes;
|
std::vector<std::string> bytes;
|
||||||
bytes.reserve(handles.size());
|
bytes.reserve(handles.size());
|
||||||
for (int i = 0; i < handles.size(); i++) {
|
for (int i = 0; i < handles.size(); i++) {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace sglang {
|
||||||
|
|
||||||
constexpr int kMaxBlocks = 36;
|
constexpr int kMaxBlocks = 36;
|
||||||
// Counter may overflow, but it's fine since unsigned int overflow is
|
// Counter may overflow, but it's fine since unsigned int overflow is
|
||||||
@@ -483,7 +483,7 @@ class CustomAllreduce {
|
|||||||
/**
|
/**
|
||||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||||
a template instantiation:
|
a template instantiation:
|
||||||
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
* template void sglang::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
||||||
half *, int, int, int);
|
half *, int, int, int);
|
||||||
*/
|
*/
|
||||||
} // namespace vllm
|
} // namespace sglang
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
|||||||
for (int i = 0; i < world_size; i++) {
|
for (int i = 0; i < world_size; i++) {
|
||||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
|
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
|
||||||
}
|
}
|
||||||
return (fptr_t) new vllm::CustomAllreduce(
|
return (fptr_t) new sglang::CustomAllreduce(
|
||||||
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
reinterpret_cast<sglang::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
||||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
|
|||||||
|
|
||||||
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||||
hipStream_t stream) {
|
hipStream_t stream) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
TORCH_CHECK(_is_weak_contiguous(out));
|
TORCH_CHECK(_is_weak_contiguous(out));
|
||||||
switch (out.scalar_type()) {
|
switch (out.scalar_type()) {
|
||||||
case at::ScalarType::Float: {
|
case at::ScalarType::Float: {
|
||||||
@@ -110,22 +110,22 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void dispose(fptr_t _fa) {
|
void dispose(fptr_t _fa) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
delete fa;
|
delete fa;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t meta_size() { return sizeof(vllm::Signal); }
|
int64_t meta_size() { return sizeof(sglang::Signal); }
|
||||||
|
|
||||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||||
const std::vector<std::string>& handles,
|
const std::vector<std::string>& handles,
|
||||||
const std::vector<int64_t>& offsets) {
|
const std::vector<int64_t>& offsets) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||||
fptr_t _fa) {
|
fptr_t _fa) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
|
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||||
auto options =
|
auto options =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||||
@@ -137,7 +137,7 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
|||||||
|
|
||||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||||
const std::vector<std::vector<int64_t>>& offsets) {
|
const std::vector<std::vector<int64_t>>& offsets) {
|
||||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||||
fa->register_graph_buffers(handles, offsets);
|
fa->register_graph_buffers(handles, offsets);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ typedef __hip_bfloat16 nv_bfloat16;
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
namespace vllm {
|
namespace sglang {
|
||||||
|
|
||||||
constexpr int kMaxBlocks = 64;
|
constexpr int kMaxBlocks = 64;
|
||||||
// note: we don't want to use atomics for signals because peer atomics are no
|
// note: we don't want to use atomics for signals because peer atomics are no
|
||||||
@@ -572,11 +572,11 @@ class CustomAllreduce {
|
|||||||
CUDACHECK(hipIpcCloseMemHandle(ptr));
|
CUDACHECK(hipIpcCloseMemHandle(ptr));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}; // namespace vllm
|
}; // namespace sglang
|
||||||
/**
|
/**
|
||||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||||
a template instantiation:
|
a template instantiation:
|
||||||
* template void vllm::CustomAllreduce::allreduce<half>(hipStream_t, half *,
|
* template void sglang::CustomAllreduce::allreduce<half>(hipStream_t, half *,
|
||||||
half *, int, int, int);
|
half *, int, int, int);
|
||||||
*/
|
*/
|
||||||
} // namespace vllm
|
} // namespace sglang
|
||||||
|
|||||||
Reference in New Issue
Block a user