Add dsv3 router gemm kernel (#7627)
This commit is contained in:
@@ -200,6 +200,7 @@ void bmm_fp8(
|
||||
at::Tensor workspace_buffer,
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
|
||||
|
||||
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user