Add fp8 fused_experts kernel for CPU in sgl-kernel and add UT (#6404)
This commit is contained in:
@@ -130,8 +130,10 @@ at::Tensor fused_experts_cpu(
|
||||
at::Tensor& topk_ids,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni);
|
||||
@@ -260,7 +262,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
// moe
|
||||
m.def(
|
||||
"fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
|
||||
"inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
|
||||
"inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? "
|
||||
"a1_scale, Tensor? a2_scale, bool "
|
||||
"is_vnni) -> Tensor");
|
||||
m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user