Move rope and bmm into sgl-kernel (#4241)
This commit is contained in:
@@ -140,6 +140,15 @@ void cublas_grouped_gemm(
|
||||
const torch::Dtype& out_dtype,
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
void bmm_fp8(
|
||||
at::Tensor A,
|
||||
at::Tensor B,
|
||||
at::Tensor D,
|
||||
at::Tensor A_scale,
|
||||
at::Tensor B_scale,
|
||||
at::Tensor workspace_buffer,
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
@@ -198,15 +207,6 @@ void build_tree_kernel(
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
void bmm_fp8(
|
||||
at::Tensor A,
|
||||
at::Tensor B,
|
||||
at::Tensor D,
|
||||
at::Tensor A_scale,
|
||||
at::Tensor B_scale,
|
||||
at::Tensor workspace_buffer,
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
void min_p_sampling_from_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor uniform_samples,
|
||||
|
||||
Reference in New Issue
Block a user