[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm (#9200)
This commit is contained in:
@@ -27,6 +27,15 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
// SM100a CUDA kernel entry point: applies silu_and_mul to `input` and
// quantizes the result to NVFP4 for grouped (per-expert) GEMM.
// Declared here; the definition lives in the CUDA translation unit.
// NOTE(review): parameter semantics below are inferred from names only —
// confirm against the kernel implementation.
//   output / output_scale  — written: quantized values and their scales
//   input                  — activation input to silu_and_mul
//   input_global_scale     — global quantization scale(s)
//   input_offset_by_experts / output_scale_offset_by_experts
//                          — per-expert offsets into the packed buffers
//   mask                   — presumably selects valid rows per expert; verify
void silu_and_mul_scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts,
    torch::Tensor const& mask);
|
||||
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant(
|
||||
@@ -50,3 +59,18 @@ void scaled_fp4_experts_quant(
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
// Dispatch wrapper for the NVFP4 silu_and_mul expert quantization op.
// When the build compiled in NVFP4 support (ENABLE_NVFP4), forwards all
// arguments to the SM100a kernel and returns; otherwise falls through to
// raise a "not implemented" error. The fall-through (rather than #else)
// mirrors the sibling scaled_fp4_experts_quant wrapper in this file.
void silu_and_mul_scaled_fp4_experts_quant(
    torch::Tensor& output,
    torch::Tensor& output_scale,
    torch::Tensor const& input,
    torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts,
    torch::Tensor const& mask) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  // Unconditional return: the TORCH_CHECK below is unreachable in NVFP4 builds.
  return silu_and_mul_scaled_fp4_experts_quant_sm100a(
      output, output_scale, input, input_global_scale, input_offset_by_experts,
      output_scale_offset_by_experts, mask);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
}
|
||||
|
||||
Reference in New Issue
Block a user