Files
sglang/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py
2025-10-07 21:51:41 -07:00

305 lines
9.1 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import dispose_tensor
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
@dataclass
class DeepGemmRunnerInput(RunnerInput):
    """Permuted, quantized input bundle consumed by DeepGemmRunnerCore."""

    # Per-expert permuted activations; on the masked path these are fed as the
    # fp8 operand of an f8f8bf16 grouped GEMM — see _run_masked_gemm.
    hidden_states: torch.Tensor
    # Blockwise quantization scales for `hidden_states`, shape (b, s_mn, s_k)
    # on the masked path (asserted 4-aligned when UE8M0 scales are enabled).
    hidden_states_scale: torch.Tensor
    # Per-expert row counts for the masked grouped GEMM — presumably the
    # number of valid tokens per expert; confirm against the dispatcher.
    masked_m: torch.Tensor
    # Expected tokens-per-expert hint forwarded to the grouped GEMM kernel.
    expected_m: int
    # True selects the masked grouped-GEMM path in run(); False selects the
    # contiguous path (not yet implemented in this file).
    use_masked_gemm: bool

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmRunnerOutput(RunnerOutput):
    """Output of DeepGemmRunnerCore.run: expert outputs in bf16, still in the
    permuted per-expert layout (un-permuted later by the post-permute step)."""

    # bf16 tensor produced by the down-projection grouped GEMM.
    hidden_states: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmMoeQuantInfo(MoeQuantInfo):
    """Expert weights and quantization parameters for the DeepGEMM backend."""

    # Per-expert gate+up projection weights (the GEMM output is split in half
    # by the fused silu-and-mul activation kernel).
    w13_weight: torch.Tensor
    # Per-expert down projection weights.
    w2_weight: torch.Tensor
    # Whether the weights are fp8-quantized — not read in this file; TODO
    # confirm semantics against callers.
    use_fp8: bool
    # Dequantization scales for w13_weight (required for the fp8 GEMM path).
    w13_scale: Optional[torch.Tensor] = None
    # Dequantization scales for w2_weight.
    w2_scale: Optional[torch.Tensor] = None
    # Blockwise-quantization block shape forwarded to the preprocess kernel.
    block_shape: Optional[List[int]] = None
class DeepGemmRunnerCore(MoeRunnerCore):
    """MoE runner core that executes the expert GEMMs with DeepGEMM fp8 kernels.

    Currently only the masked grouped-GEMM path is implemented; the contiguous
    path raises NotImplementedError.
    """

    def __init__(self, config: MoeRunnerConfig):
        super().__init__(config)
        # The fused activation kernel below hard-codes silu-and-mul.
        assert self.config.activation == "silu"

    def run(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> DeepGemmRunnerOutput:
        """Dispatch to the masked or contiguous grouped-GEMM implementation."""
        if runner_input.use_masked_gemm:
            hidden_states = self._run_masked_gemm(
                runner_input,
                quant_info,
                running_state,
            )
        else:
            hidden_states = self._run_contiguous_gemm(
                runner_input,
                quant_info,
                running_state,
            )
        return DeepGemmRunnerOutput(hidden_states=hidden_states)

    def _run_masked_gemm(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> torch.Tensor:
        """Run gateup GEMM -> fused silu&mul + fp8 requant -> down GEMM.

        Returns the bf16 down-projection output of shape (num_groups, m, n2),
        where n2 is w2_weight's output dimension. Disposes the input
        activations and frees intermediates as soon as they are consumed.
        """
        from sglang.srt.layers.moe.ep_moe.kernels import (
            silu_and_mul_masked_post_quant_fwd,
        )
        from sglang.srt.layers.quantization import deep_gemm_wrapper

        hidden_states = runner_input.hidden_states
        hidden_states_scale = runner_input.hidden_states_scale
        masked_m = runner_input.masked_m
        expected_m = runner_input.expected_m

        w13_weight = quant_info.w13_weight
        w2_weight = quant_info.w2_weight
        w13_scale = quant_info.w13_scale
        w2_scale = quant_info.w2_scale

        hidden_states_device = running_state["hidden_states_device"]

        # GroupGemm-0: activation scales must be either e8m0-encoded
        # (UE8M0 mode, which additionally requires 4-aligned scale dims)
        # or TMA-aligned mn-major.
        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            b, s_mn, s_k = hidden_states_scale.shape
            assert (
                s_mn % 4 == 0 and s_k % 4 == 0
            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
        else:
            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                hidden_states_scale
            )

        num_groups, m, k = hidden_states.shape
        n = w13_weight.size(1)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            (hidden_states, hidden_states_scale),
            (w13_weight, w13_scale),
            gateup_output,
            masked_m,
            expected_m,
        )
        # The fp8 activations are no longer needed; free them eagerly.
        dispose_tensor(hidden_states)

        # Act: fused silu-and-mul halves the hidden width and requantizes the
        # result to fp8 with per-128-block scales for the second GEMM.
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=hidden_states_device,
            dtype=torch.float8_e4m3fn,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=hidden_states_device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output

        # GroupGemm-1: down projection back to bf16.
        n = w2_weight.shape[1]
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                down_input_scale
            )
        down_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            (down_input, down_input_scale),
            (w2_weight, w2_scale),
            down_output,
            masked_m,
            expected_m,
        )
        del down_input
        return down_output

    def _run_contiguous_gemm(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> torch.Tensor:
        # The original body was `pass`, which silently returned None and would
        # let a None hidden_states propagate into DeepGemmRunnerOutput. Fail
        # loudly until the contiguous path is implemented.
        raise NotImplementedError(
            "contiguous grouped GEMM is not implemented for the DeepGEMM runner"
        )

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.DEEP_GEMM
@register_pre_permute("standard", "deep_gemm")
def pre_permute_standard_to_deep_gemm(
dispatch_output: StandardDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
hidden_states, topk_output = dispatch_output
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
hidden_states_ref = hidden_states
topk_weights, topk_ids = topk_weights, topk_ids
# PreReorder
masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
runner_config.num_local_experts,
hidden_states,
runner_config.top_k,
quant_info.block_shape,
)
)
dispose_tensor(hidden_states_ref)
running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights
running_state["hidden_states_shape"] = hidden_states_shape
running_state["hidden_states_dtype"] = hidden_states_dtype
running_state["hidden_states_device"] = hidden_states_device
running_state["src2dst"] = src2dst
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
masked_m=masked_m,
expected_m=expected_m,
use_masked_gemm=True,
)
@register_post_permute("deep_gemm", "standard")
def post_permute_deep_gemm_to_standard(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
hidden_states_shape = running_state["hidden_states_shape"]
hidden_states_dtype = running_state["hidden_states_dtype"]
hidden_states_device = running_state["hidden_states_device"]
src2dst = running_state["src2dst"]
topk_ids = running_state["topk_ids"]
topk_weights = running_state["topk_weights"]
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
runner_output.hidden_states,
output,
src2dst,
topk_ids,
topk_weights,
runner_config.top_k,
hidden_states_shape[1],
BLOCK_SIZE=512,
)
dispose_tensor(runner_output.hidden_states)
if runner_config.routed_scaling_factor is not None:
output *= runner_config.routed_scaling_factor
return StandardCombineInput(
hidden_states=output,
)