Support single batch overlap (#10422)
This commit is contained in:
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import scaled_fp4_quant
|
||||
@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alpha=layer.g2_alphas,
|
||||
masked_m=masked_m,
|
||||
**(
|
||||
dict(
|
||||
down_sm_count=down_gemm_overlap_args.num_sms,
|
||||
down_signals=down_gemm_overlap_args.signal,
|
||||
down_start_event=down_gemm_overlap_args.start_event,
|
||||
)
|
||||
if down_gemm_overlap_args is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user