Support single batch overlap (#10422)

This commit is contained in:
fzyzcjy
2025-10-02 18:04:36 +08:00
committed by GitHub
parent 0b9dfba787
commit 5e786cca3a
9 changed files with 268 additions and 20 deletions

View File

@@ -47,6 +47,7 @@ if TYPE_CHECKING:
CombineInput,
StandardDispatchOutput,
)
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
if is_cuda():
from sgl_kernel import scaled_fp4_quant
@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
masked_m: torch.Tensor,
moe_runner_config: MoeRunnerConfig,
down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
) -> torch.Tensor:
assert (
moe_runner_config.activation == "silu"
@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alpha=layer.g2_alphas,
masked_m=masked_m,
**(
dict(
down_sm_count=down_gemm_overlap_args.num_sms,
down_signals=down_gemm_overlap_args.signal,
down_start_event=down_gemm_overlap_args.start_event,
)
if down_gemm_overlap_args is not None
else {}
),
)
return out