fix custom_allreduce namespace (#6039)
This commit is contained in:
@@ -18,11 +18,11 @@ init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_dat
|
||||
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
vllm::Signal* ipc_ptrs[8];
|
||||
sglang::Signal* ipc_ptrs[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
|
||||
ipc_ptrs[i] = reinterpret_cast<sglang::Signal*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
return (fptr_t) new sglang::CustomAllreduce(
|
||||
ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
* copied into _reg_buffer.
|
||||
*/
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
@@ -98,15 +98,15 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
delete reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
}
|
||||
|
||||
int64_t meta_size() {
|
||||
return sizeof(vllm::Signal);
|
||||
return sizeof(sglang::Signal);
|
||||
}
|
||||
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
||||
void* ipc_ptrs[8];
|
||||
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
|
||||
@@ -117,7 +117,7 @@ void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
||||
|
||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||
std::vector<int64_t> bytes(handle.begin(), handle.end());
|
||||
return std::make_tuple(bytes, offsets);
|
||||
@@ -126,7 +126,7 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
|
||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||
void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
std::vector<std::string> bytes;
|
||||
bytes.reserve(handles.size());
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
|
||||
Reference in New Issue
Block a user