sgl-kernel: transfer custom allreduce from trt kernel to vllm kernel (#5079)

Yi Zhang
2025-04-06 05:23:20 +08:00
committed by GitHub
parent 0d99adb715
commit bcbbf519f9
10 changed files with 692 additions and 937 deletions


@@ -21,6 +21,7 @@ limitations under the License.
#include <torch/library.h>
#include <torch/torch.h>
#include <tuple>
#include <vector>
#define _CONCAT(A, B) A##B
@@ -63,18 +64,14 @@ void register_graph_buffers(
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// TRTLLM custom allreduce
fptr_t init_custom_ar(
    int64_t rank_id,
    int64_t world_size,
    torch::Tensor& rank_data,
    const std::vector<fptr_t>& buffers,
    const std::vector<fptr_t>& tmp_result_buffers,
    const std::vector<fptr_t>& barrier_in,
    const std::vector<fptr_t>& barrier_out);
// custom allreduce
fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
int64_t meta_size();
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
#endif
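The new declarations collapse the TRT-style constructor, which took separate data, tmp-result, and barrier pointer lists, into the vLLM-style one that takes a single list of IPC pointers plus a rank-data tensor and an explicit full-NVLink flag; all_reduce likewise now routes through a pre-registered staging buffer. A minimal sketch of the new call sequence (buffer allocation and IPC-handle exchange are elided, and names such as ipc_ptrs and reg_buffer are illustrative, not taken from this diff):

// Hypothetical host-side usage of the new interface; placeholder names.
fptr_t fa = init_custom_ar(ipc_ptrs, rank_data, rank, /*full_nvlink=*/true);
register_buffer(fa, reg_buffer_ipc_ptrs);  // shared staging buffer, one pointer per rank
all_reduce(fa, inp, out, reg_buffer_ptr, reg_buffer_sz_bytes);
dispose(fa);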


@@ -1,109 +0,0 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>
namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
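// Reduction strategies (see the TRT-LLM reference above): ONESHOT pushes each
// rank's full input to every peer and reduces locally; TWOSHOT does a
// reduce-scatter followed by an all-gather. RING is declared but unimplemented
// here, and AUTO defers the choice to a runtime heuristic such as
// SelectImplementation below.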
enum class AllReduceStrategyType : int8_t {
  RING = 0,
  ONESHOT = 1,
  TWOSHOT = 2,
  AUTO = 3,
};
struct RankData {
  void* ptrs[MAX_RANKS_PER_NODE];
};
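// Per-launch arguments for the kernels below. The peer_barrier_ptrs_in/out
// arrays hold IPC-mapped flags used to synchronize ranks at kernel entry and
// exit; barrier_flag is bumped per call so stale flags are never re-observed.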
struct AllReduceParams {
  size_t elts_size;
  size_t elts_total;
  size_t elts_per_rank;
  size_t elts_per_block;
  size_t rank_offset;
  size_t ranks_per_node, rank, local_rank;
  uint32_t barrier_flag;
  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
  RankData* peer_comm_buffer_ptrs;
  void* local_input_buffer_ptr;
  void* local_output_buffer_ptr;
  bool is_capturing;
};
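// Size bound for the pre-allocated IPC workspace; messages above this limit
// would need the RING fallback, which SelectImplementation below rejects as
// unimplemented.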
inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
  if (world_size <= 2) {
    return 16 * 1024 * 1024;
  }
  return 8 * 1024 * 1024;
}
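// Heuristic strategy selection: one-shot for small, latency-bound messages,
// two-shot for bandwidth-bound ones. Two ranks always use ONESHOT; the
// cutover is 1 MiB at 3-4 ranks and 512 KiB beyond that.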
inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
  const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);
  if (message_size > maxWorkspaceSize) {
    assert(false && "Custom allreduce does not support the RING strategy currently");
    return AllReduceStrategyType::RING;
  }
  if (world_size <= 2) {
    return AllReduceStrategyType::ONESHOT;
  }
  if (world_size <= 4) {
    if (message_size < 1 * 1024 * 1024) {
      return AllReduceStrategyType::ONESHOT;
    }
    return AllReduceStrategyType::TWOSHOT;
  }
  if (message_size < 512 * 1024) {
    return AllReduceStrategyType::ONESHOT;
  }
  return AllReduceStrategyType::TWOSHOT;
}
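// Launches the kernel for the chosen strategy on `stream`; the definition
// lives in the companion .cu translation unit.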
void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);
} // namespace trt_llm
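For context, here is a sketch of how the removed internals fit together on the host side; params population is abbreviated, and names like world_size and stream are placeholders rather than code from this diff:

// Hypothetical dispatch through the removed TRT path.
trt_llm::AllReduceParams params{};
params.elts_total = inp.numel();        // number of elements to reduce
params.elts_size = inp.element_size();  // bytes per element
params.local_input_buffer_ptr = inp.data_ptr();
params.local_output_buffer_ptr = out.data_ptr();
size_t message_size = params.elts_total * params.elts_size;
auto strat = trt_llm::SelectImplementation(message_size, world_size);
// e.g. world_size = 8: 256 KiB picks ONESHOT, 2 MiB picks TWOSHOT.
trt_llm::trtCustomAllReduce(params, inp.scalar_type(), strat, stream);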