fix custom_allreduce namespace (#6039)
This commit is contained in:
@@ -29,8 +29,8 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
return (fptr_t) new sglang::CustomAllreduce(
|
||||
reinterpret_cast<sglang::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
hipStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
@@ -110,22 +110,22 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int64_t meta_size() { return sizeof(vllm::Signal); }
|
||||
int64_t meta_size() { return sizeof(sglang::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
@@ -137,7 +137,7 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user