Move rope and bmm into sgl-kernel (#4241)

This commit is contained in:
Lianmin Zheng
2025-03-09 18:38:15 -07:00
committed by GitHub
parent 9dfafa743c
commit eb06dbcbf8
5 changed files with 183 additions and 13 deletions

View File

@@ -140,6 +140,15 @@ void cublas_grouped_gemm(
const torch::Dtype& out_dtype,
int64_t cublas_handle,
int64_t cuda_stream);
// Declaration of bmm_fp8; the definition is not part of this chunk.
// NOTE(review): the name suggests a batched matrix multiply on FP8 operands —
// confirm against the implementation before relying on this description.
//
// Parameters (meanings inferred from names only; TODO confirm):
//   A, B             - presumably the two input operand tensors.
//   D                - presumably the output tensor written by the call.
//   A_scale, B_scale - presumably scale factors associated with A and B.
//   workspace_buffer - presumably scratch memory handed to the GEMM backend.
//   cublas_handle    - integer-typed handle, same convention as the
//                      cublas_handle parameter of the declaration above;
//                      presumably an opaque pointer cast to int64_t — verify.
//   cuda_stream      - integer-typed stream, same convention as above;
//                      presumably a CUDA stream pointer cast to int64_t — verify.
void bmm_fp8(
at::Tensor A,
at::Tensor B,
at::Tensor D,
at::Tensor A_scale,
at::Tensor B_scale,
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream);
/*
* From csrc/moe
@@ -198,15 +207,6 @@ void build_tree_kernel(
/*
* From FlashInfer
*/
// Prototype for bmm_fp8 (definition not visible here).
// NOTE(review): judging by the name, this looks like an FP8 batched matmul
// entry point — confirm with the implementation.
// All tensors are taken by value as at::Tensor; the handle and stream are
// taken as plain int64_t values.
// Presumably A/B are inputs, D is the result, A_scale/B_scale are scaling
// factors, and workspace_buffer is scratch space — TODO confirm each role.
void bmm_fp8(
at::Tensor A,
at::Tensor B,
at::Tensor D,
at::Tensor A_scale,
at::Tensor B_scale,
at::Tensor workspace_buffer,
int64_t cublas_handle,
int64_t cuda_stream);
void min_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,