Support mixing CuteDSL and DeepGEMM backends (#11807)
This commit is contained in:
@@ -191,11 +191,15 @@ class DeepEPMoE(FusedMoE):
|
|||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||||
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
if (
|
||||||
|
get_moe_runner_backend().is_flashinfer_cutedsl()
|
||||||
|
and self.quant_config.get_name() == "modelopt_fp4"
|
||||||
|
):
|
||||||
return self.forward_flashinfer_cutedsl(
|
return self.forward_flashinfer_cutedsl(
|
||||||
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
||||||
)
|
)
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
|
assert down_gemm_overlap_args is None
|
||||||
return self.forward_deepgemm_masked(dispatch_output)
|
return self.forward_deepgemm_masked(dispatch_output)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Reference in New Issue
Block a user