Rename files in sgl kernel to avoid nested folder structure (#4213)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
2
.github/workflows/release-pypi-kernel.yml
vendored
2
.github/workflows/release-pypi-kernel.yml
vendored
@@ -5,7 +5,7 @@ on:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- sgl-kernel/src/sgl-kernel/version.py
|
||||
- sgl-kernel/python/sgl_kernel/version.py
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
|
||||
4
.github/workflows/release-whl-kernel.yml
vendored
4
.github/workflows/release-whl-kernel.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- sgl-kernel/src/sgl-kernel/version.py
|
||||
- sgl-kernel/python/sgl_kernel/version.py
|
||||
|
||||
jobs:
|
||||
build-wheels:
|
||||
@@ -59,7 +59,7 @@ jobs:
|
||||
id: set_tag_name
|
||||
run: |
|
||||
if [ -z "${{ inputs.tag_name }}" ]; then
|
||||
TAG_NAME="v$(cat sgl-kernel/src/sgl-kernel/version.py | cut -d'"' -f2)"
|
||||
TAG_NAME="v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'"' -f2)"
|
||||
echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT
|
||||
|
||||
@@ -75,42 +75,42 @@ else:
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return sgl_kernel.ops.allreduce.init_custom_ar(
|
||||
return sgl_kernel.allreduce.init_custom_ar(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.ops.allreduce.all_reduce_reg(fa, inp, out)
|
||||
sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
|
||||
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
sgl_kernel.ops.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.ops.allreduce.dispose(fa)
|
||||
sgl_kernel.allreduce.dispose(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return sgl_kernel.ops.allreduce.meta_size()
|
||||
return sgl_kernel.allreduce.meta_size()
|
||||
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return sgl_kernel.ops.allreduce.register_buffer(fa, t, handles, offsets)
|
||||
return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return sgl_kernel.ops.allreduce.get_graph_buffer_ipc_meta(fa)
|
||||
return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.ops.allreduce.register_graph_buffers(fa, handles, offsets)
|
||||
sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return sgl_kernel.ops.allreduce.allocate_meta_buffer(size)
|
||||
return sgl_kernel.allreduce.allocate_meta_buffer(size)
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return sgl_kernel.ops.allreduce.get_meta_buffer_ipc_handle(inp)
|
||||
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
else:
|
||||
# TRTLLM custom allreduce
|
||||
@@ -123,7 +123,7 @@ else:
|
||||
barrier_in: List[int],
|
||||
barrier_out: List[int],
|
||||
) -> int:
|
||||
return sgl_kernel.ops.init_custom_reduce(
|
||||
return sgl_kernel.init_custom_reduce(
|
||||
rank_id,
|
||||
world_size,
|
||||
rank_data_base,
|
||||
@@ -134,15 +134,15 @@ else:
|
||||
)
|
||||
|
||||
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
sgl_kernel.ops.custom_reduce(fa, inp, out)
|
||||
sgl_kernel.custom_reduce(fa, inp, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
sgl_kernel.ops.custom_dispose(fa)
|
||||
sgl_kernel.custom_dispose(fa)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
|
||||
return sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
|
||||
sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
@@ -38,12 +38,12 @@ test: ## Run all tests
|
||||
|
||||
format: check-deps ## Format all source files
|
||||
@echo "Formatting source files..."
|
||||
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
|
||||
@find src tests -name '*.py' | xargs isort
|
||||
@find src tests -name '*.py' | xargs black
|
||||
@find csrc tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
|
||||
@find python tests -name '*.py' | xargs isort
|
||||
@find python tests -name '*.py' | xargs black
|
||||
@pre-commit run --all-files
|
||||
|
||||
FILES_TO_UPDATE = src/sgl-kernel/version.py \
|
||||
FILES_TO_UPDATE = python/sgl_kernel/version.py \
|
||||
pyproject.toml
|
||||
|
||||
update: ## Update version numbers across project files. Usage: make update <new_version>
|
||||
@@ -51,7 +51,7 @@ update: ## Update version numbers across project files. Usage: make update <new_
|
||||
echo "Version required. Usage: make update <new_version>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@OLD_VERSION=$$(grep "version" src/sgl-kernel/version.py | cut -d '"' -f2); \
|
||||
@OLD_VERSION=$$(grep "version" python/sgl_kernel/version.py | cut -d '"' -f2); \
|
||||
NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
|
||||
echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
|
||||
for file in $(FILES_TO_UPDATE); do \
|
||||
|
||||
@@ -45,12 +45,11 @@ Third-party libraries:
|
||||
|
||||
Steps to add a new kernel:
|
||||
|
||||
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
|
||||
2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
|
||||
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
|
||||
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
|
||||
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
|
||||
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
||||
1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc)
|
||||
2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h)
|
||||
3. Create torch extension in [csrc/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/torch_extension.cc)
|
||||
4. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
||||
5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)
|
||||
|
||||
### Build & Install
|
||||
|
||||
@@ -72,4 +71,4 @@ The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, t
|
||||
|
||||
### Release new version
|
||||
|
||||
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
|
||||
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py)
|
||||
|
||||
@@ -16,33 +16,9 @@ limitations under the License.
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernels_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
/*
|
||||
* From csrc/activation
|
||||
*/
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
@@ -67,6 +43,30 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
*/
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
@@ -93,6 +93,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
|
||||
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
|
||||
|
||||
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
|
||||
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
|
||||
|
||||
m.def(
|
||||
"cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
|
||||
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
|
||||
@@ -171,9 +174,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
|
||||
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_kernels)
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
@@ -16,9 +16,9 @@ limitations under the License.
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernels_ops.h"
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
@@ -36,18 +36,6 @@ limitations under the License.
|
||||
|
||||
using fptr_t = int64_t;
|
||||
|
||||
/*
|
||||
* From csrc/activation
|
||||
*/
|
||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
|
||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void gemma_fused_add_rmsnorm(
|
||||
at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
@@ -88,6 +76,30 @@ void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
void lightning_attention_decode(
|
||||
const torch::Tensor& q,
|
||||
const torch::Tensor& k,
|
||||
const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv,
|
||||
const torch::Tensor& slope,
|
||||
torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
|
||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void gemma_fused_add_rmsnorm(
|
||||
at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
@@ -120,6 +132,7 @@ void sgl_per_token_group_quant_fp8(
|
||||
double fp8_min,
|
||||
double fp8_max);
|
||||
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
|
||||
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
||||
void cublas_grouped_gemm(
|
||||
const std::vector<torch::Tensor>& inputs,
|
||||
const std::vector<torch::Tensor>& weights,
|
||||
@@ -254,18 +267,3 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* Other
|
||||
*/
|
||||
void lightning_attention_decode(
|
||||
const torch::Tensor& q,
|
||||
const torch::Tensor& k,
|
||||
const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv,
|
||||
const torch::Tensor& slope,
|
||||
torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
|
||||
// sgl_per_token_quant_fp8
|
||||
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
||||
@@ -20,10 +20,6 @@ dependencies = []
|
||||
"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel"
|
||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"sgl_kernel" = "src/sgl-kernel"}
|
||||
packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"]
|
||||
|
||||
[tool.wheel]
|
||||
exclude = [
|
||||
"dist*",
|
||||
|
||||
@@ -9,7 +9,10 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
||||
mode=ctypes.RTLD_GLOBAL,
|
||||
)
|
||||
|
||||
from sgl_kernel.ops.activation import (
|
||||
from sgl_kernel import common_ops
|
||||
from sgl_kernel.allreduce import *
|
||||
from sgl_kernel.attention import lightning_attention_decode
|
||||
from sgl_kernel.elementwise import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
@@ -19,9 +22,7 @@ from sgl_kernel.ops.activation import (
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.ops.allreduce import *
|
||||
from sgl_kernel.ops.attention import lightning_attention_decode
|
||||
from sgl_kernel.ops.gemm import (
|
||||
from sgl_kernel.gemm import (
|
||||
bmm_fp8,
|
||||
cublas_grouped_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
@@ -31,15 +32,15 @@ from sgl_kernel.ops.gemm import (
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
from sgl_kernel.ops.moe import moe_align_block_size
|
||||
from sgl_kernel.ops.sampling import (
|
||||
from sgl_kernel.moe import moe_align_block_size
|
||||
from sgl_kernel.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
top_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.ops.speculative import (
|
||||
from sgl_kernel.speculative import (
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
tree_speculative_sampling_target_only,
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
|
||||
@@ -13,49 +12,49 @@ if torch.version.hip is not None:
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernels.init_custom_ar(
|
||||
return torch.ops.sgl_kernel.init_custom_ar(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
|
||||
torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
|
||||
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops.sgl_kernels.dispose(fa)
|
||||
torch.ops.sgl_kernel.dispose(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops.sgl_kernels.meta_size()
|
||||
return torch.ops.sgl_kernel.meta_size()
|
||||
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
|
||||
return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernels.allocate_meta_buffer(size)
|
||||
return torch.ops.sgl_kernel.allocate_meta_buffer(size)
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
|
||||
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
else:
|
||||
# TRTLLM custom allreduce
|
||||
def init_custom_reduce(
|
||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||
):
|
||||
return torch.ops.sgl_kernels.init_custom_ar(
|
||||
return torch.ops.sgl_kernel.init_custom_ar(
|
||||
rank_id,
|
||||
num_devices,
|
||||
rank_data,
|
||||
@@ -66,13 +65,13 @@ else:
|
||||
)
|
||||
|
||||
def custom_dispose(fa):
|
||||
torch.ops.sgl_kernels.dispose(fa)
|
||||
torch.ops.sgl_kernel.dispose(fa)
|
||||
|
||||
def custom_reduce(fa, inp, out):
|
||||
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
|
||||
torch.ops.sgl_kernel.all_reduce(fa, inp, out)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa):
|
||||
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(fa, handles, offsets):
|
||||
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
||||
@@ -1,8 +1,7 @@
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
|
||||
|
||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
torch.ops.sgl_kernels.lightning_attention_decode(
|
||||
torch.ops.sgl_kernel.lightning_attention_decode(
|
||||
q, k, v, past_kv, slope, output, new_kv
|
||||
)
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import get_cuda_stream
|
||||
from sgl_kernel.utils import get_cuda_stream
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||
@@ -15,14 +14,14 @@ def rmsnorm(
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
|
||||
|
||||
|
||||
def gemma_rmsnorm(
|
||||
@@ -33,14 +32,14 @@ def gemma_rmsnorm(
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def gemma_fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
|
||||
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
|
||||
input, residual, weight, eps, get_cuda_stream()
|
||||
)
|
||||
|
||||
@@ -66,7 +65,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
@@ -81,7 +80,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
@@ -96,7 +95,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream())
|
||||
torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
@@ -141,7 +140,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
positions = positions.int()
|
||||
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
|
||||
q=query.view(query.shape[0], -1, head_size),
|
||||
k=key.view(key.shape[0], -1, head_size),
|
||||
q_rope=query.view(query.shape[0], -1, head_size),
|
||||
@@ -1,12 +1,11 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream
|
||||
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return torch.ops.sgl_kernels.int8_scaled_mm(
|
||||
return torch.ops.sgl_kernel.int8_scaled_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
@@ -17,7 +16,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
||||
return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
|
||||
return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
@@ -27,7 +26,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
||||
|
||||
|
||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return torch.ops.sgl_kernels.fp8_scaled_mm(
|
||||
return torch.ops.sgl_kernel.fp8_scaled_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
@@ -46,7 +45,7 @@ def _bmm_fp8_internal(
|
||||
B_scale: torch.Tensor,
|
||||
) -> None:
|
||||
cublas_handle = torch.cuda.current_blas_handle()
|
||||
torch.ops.sgl_kernels.bmm_fp8(
|
||||
torch.ops.sgl_kernel.bmm_fp8(
|
||||
A,
|
||||
B,
|
||||
D,
|
||||
@@ -86,7 +85,7 @@ def sgl_per_token_group_quant_fp8(
|
||||
fp8_min: float,
|
||||
fp8_max: float,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
|
||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max
|
||||
)
|
||||
|
||||
@@ -97,7 +96,7 @@ def sgl_per_tensor_quant_fp8(
|
||||
output_s: torch.Tensor,
|
||||
is_static: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
|
||||
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
|
||||
|
||||
|
||||
def cublas_grouped_gemm(
|
||||
@@ -110,7 +109,7 @@ def cublas_grouped_gemm(
|
||||
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
|
||||
), "Inputs/weights/outputs should not be empty!"
|
||||
cublas_handle = torch.cuda.current_blas_handle()
|
||||
torch.ops.sgl_kernels.cublas_grouped_gemm(
|
||||
torch.ops.sgl_kernel.cublas_grouped_gemm(
|
||||
inputs,
|
||||
weights,
|
||||
outputs,
|
||||
@@ -125,4 +124,4 @@ def sgl_per_token_quant_fp8(
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.sgl_per_token_quant_fp8(input, output_q, output_s)
|
||||
torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
|
||||
@@ -1,4 +1,3 @@
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
|
||||
|
||||
@@ -12,7 +11,7 @@ def moe_align_block_size(
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
):
|
||||
torch.ops.sgl_kernels.moe_align_block_size(
|
||||
torch.ops.sgl_kernel.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream
|
||||
from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
|
||||
|
||||
|
||||
def _top_k_renorm_probs_internal(
|
||||
@@ -13,7 +12,7 @@ def _top_k_renorm_probs_internal(
|
||||
probs = probs.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
|
||||
torch.ops.sgl_kernel.top_k_renorm_probs_wrapper(
|
||||
probs,
|
||||
renorm_probs,
|
||||
maybe_top_k_arr,
|
||||
@@ -41,7 +40,7 @@ def _top_p_renorm_probs_internal(
|
||||
probs = probs.float()
|
||||
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
torch.ops.sgl_kernels.top_p_renorm_probs(
|
||||
torch.ops.sgl_kernel.top_p_renorm_probs(
|
||||
probs,
|
||||
renorm_probs,
|
||||
maybe_top_p_arr,
|
||||
@@ -76,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||
torch.ops.sgl_kernels.top_p_sampling_from_probs(
|
||||
torch.ops.sgl_kernel.top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
@@ -122,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||
torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
|
||||
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
@@ -180,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
|
||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
torch.ops.sgl_kernels.min_p_sampling_from_probs(
|
||||
torch.ops.sgl_kernel.min_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
@@ -1,6 +1,5 @@
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import get_cuda_stream
|
||||
from sgl_kernel.utils import get_cuda_stream
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
|
||||
@@ -16,7 +15,7 @@ def tree_speculative_sampling_target_only(
|
||||
draft_probs: torch.Tensor,
|
||||
deterministic: bool = True,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
|
||||
torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
|
||||
predicts,
|
||||
accept_index,
|
||||
accept_token_num,
|
||||
@@ -45,7 +44,7 @@ def build_tree_kernel_efficient(
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.build_tree_kernel_efficient(
|
||||
torch.ops.sgl_kernel.build_tree_kernel_efficient(
|
||||
parent_list,
|
||||
selected_index,
|
||||
verified_seq_len,
|
||||
@@ -71,7 +70,7 @@ def build_tree_kernel(
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernels.build_tree_kernel(
|
||||
torch.ops.sgl_kernel.build_tree_kernel(
|
||||
parent_list,
|
||||
selected_index,
|
||||
verified_seq_len,
|
||||
@@ -48,16 +48,16 @@ def _get_version():
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
operator_namespace = "sgl_kernels"
|
||||
operator_namespace = "sgl_kernel"
|
||||
cutlass_default = root / "3rdparty" / "cutlass"
|
||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||
flashinfer = root / "3rdparty" / "flashinfer"
|
||||
turbomind = root / "3rdparty" / "turbomind"
|
||||
include_dirs = [
|
||||
root / "include",
|
||||
root / "csrc",
|
||||
cutlass.resolve() / "include",
|
||||
cutlass.resolve() / "tools" / "util" / "include",
|
||||
root / "src" / "sgl-kernel" / "include",
|
||||
root / "src" / "sgl-kernel" / "csrc",
|
||||
flashinfer.resolve() / "include",
|
||||
flashinfer.resolve() / "include" / "gemm",
|
||||
flashinfer.resolve() / "csrc",
|
||||
@@ -96,21 +96,21 @@ nvcc_flags_fp8 = [
|
||||
]
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension.cc",
|
||||
"src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
|
||||
"src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/speculative/eagle_utils.cu",
|
||||
"src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
|
||||
"csrc/allreduce/trt_reduce_internal.cu",
|
||||
"csrc/allreduce/trt_reduce_kernel.cu",
|
||||
"csrc/attention/lightning_attention_decode_kernel.cu",
|
||||
"csrc/elementwise/fused_add_rms_norm_kernel.cu",
|
||||
"csrc/gemm/cublas_grouped_gemm.cu",
|
||||
"csrc/gemm/fp8_gemm_kernel.cu",
|
||||
"csrc/gemm/fp8_blockwise_gemm_kernel.cu",
|
||||
"csrc/gemm/int8_gemm_kernel.cu",
|
||||
"csrc/gemm/per_token_group_quant_fp8.cu",
|
||||
"csrc/gemm/per_token_quant_fp8.cu",
|
||||
"csrc/gemm/per_tensor_quant_fp8.cu",
|
||||
"csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/speculative/eagle_utils.cu",
|
||||
"csrc/speculative/speculative_sampling.cu",
|
||||
"csrc/torch_extension.cc",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||
"3rdparty/flashinfer/csrc/norm.cu",
|
||||
@@ -158,7 +158,7 @@ extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linu
|
||||
|
||||
ext_modules = [
|
||||
CUDAExtension(
|
||||
name="sgl_kernel.ops._kernels",
|
||||
name="sgl_kernel.common_ops",
|
||||
sources=sources,
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
@@ -174,8 +174,8 @@ ext_modules = [
|
||||
setup(
|
||||
name="sgl-kernel",
|
||||
version=_get_version(),
|
||||
packages=find_packages(),
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages(where="python"),
|
||||
package_dir={"": "python"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
|
||||
@@ -13,12 +13,9 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
@@ -35,16 +32,16 @@ def _get_version():
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
operator_namespace = "sgl_kernels"
|
||||
operator_namespace = "sgl_kernel"
|
||||
include_dirs = [
|
||||
root / "src" / "sgl-kernel" / "include",
|
||||
root / "src" / "sgl-kernel" / "csrc",
|
||||
root / "include",
|
||||
root / "csrc",
|
||||
]
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension_rocm.cc",
|
||||
"src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip",
|
||||
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/allreduce/custom_all_reduce.hip",
|
||||
"csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/torch_extension_rocm.cc",
|
||||
]
|
||||
|
||||
cxx_flags = ["-O3"]
|
||||
@@ -64,26 +61,27 @@ hipcc_flags = [
|
||||
"-DENABLE_FP8",
|
||||
]
|
||||
|
||||
ext_modules = [
|
||||
CUDAExtension(
|
||||
name="sgl_kernel.common_ops",
|
||||
sources=sources,
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
"nvcc": hipcc_flags,
|
||||
"cxx": cxx_flags,
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
),
|
||||
]
|
||||
|
||||
setup(
|
||||
name="sgl-kernel",
|
||||
version=_get_version(),
|
||||
packages=find_packages(),
|
||||
package_dir={"": "src"},
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="sgl_kernel.ops._kernels",
|
||||
sources=sources,
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
"nvcc": hipcc_flags,
|
||||
"cxx": cxx_flags,
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
),
|
||||
],
|
||||
package_dir={"": "python"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
install_requires=["torch"],
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ import unittest
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import ray
|
||||
import sgl_kernel.ops.allreduce as custom_ops
|
||||
import sgl_kernel.allreduce as custom_ops
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
Reference in New Issue
Block a user