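"""DeepGEMM-backed MoE runner for SGLang.

Defines the DeepGEMM runner input/output/quant-info dataclasses, the
DeepGemmRunnerCore that executes the expert MLP with fp8 grouped GEMMs, and the
pre-/post-permute hooks that convert between the standard dispatch/combine
format and DeepGEMM's masked per-expert layout.
"""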
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

import torch

from sglang.srt.layers.moe.moe_runner.base import (
    MoeQuantInfo,
    MoeRunnerConfig,
    MoeRunnerCore,
    RunnerInput,
    RunnerOutput,
    register_post_permute,
    register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import dispose_tensor

if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher.standard import (
        StandardCombineInput,
        StandardDispatchOutput,
    )


# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
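    """Cast fp32 scales to UE8M0 (power-of-two) scales for DeepGEMM.

    Extracts the exponent bits of each fp32 scale and rounds the exponent up
    whenever the mantissa is non-zero (with guards for the maximum exponent and
    values near zero), so scales are rounded up rather than truncated. The
    uint8 exponents are reinterpreted as packed int32 and returned with the
    storage of the last two dimensions swapped.
    """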
    temp = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(temp, 23)
    mant = torch.bitwise_and(temp, 0x7FFFFF)
    is_ru = torch.logical_and(
        torch.logical_and((mant > 0), (exp != 0xFE)),
        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
    )
    exp = torch.where(is_ru, exp + 1, exp)
    new_x = exp.to(torch.uint8).view(torch.int)
    return new_x.transpose(1, 2).contiguous().transpose(1, 2)


@dataclass
class DeepGemmRunnerInput(RunnerInput):
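    """Per-expert inputs consumed by the DeepGEMM runner.

    hidden_states and hidden_states_scale hold the fp8 activations and their
    block scales grouped by expert (as produced by the pre-permute step),
    masked_m gives the number of valid rows per expert, and expected_m is the
    expected row count passed to the masked grouped GEMM.
    """
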
    hidden_states: torch.Tensor
    hidden_states_scale: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int
    use_masked_gemm: bool

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.DEEP_GEMM


@dataclass
class DeepGemmRunnerOutput(RunnerOutput):
    hidden_states: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.DEEP_GEMM


@dataclass
class DeepGemmMoeQuantInfo(MoeQuantInfo):
    w13_weight: torch.Tensor
    w2_weight: torch.Tensor
    use_fp8: bool
    w13_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    block_shape: Optional[List[int]] = None


class DeepGemmRunnerCore(MoeRunnerCore):
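    """DeepGEMM-backed MoE runner core.

    Executes the expert MLP as two fp8 grouped GEMMs (gate/up and down) with a
    fused SiLU-and-mul between them. Only the masked-layout path is implemented
    below, and the activation is required to be "silu".
    """
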
    def __init__(self, config: MoeRunnerConfig):
        super().__init__(config)
        assert self.config.activation == "silu"

    def run(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> DeepGemmRunnerOutput:

        if runner_input.use_masked_gemm:
            hidden_states = self._run_masked_gemm(
                runner_input,
                quant_info,
                running_state,
            )
        else:
            hidden_states = self._run_contiguous_gemm(
                runner_input,
                quant_info,
                running_state,
            )
        return DeepGemmRunnerOutput(hidden_states=hidden_states)

    def _run_masked_gemm(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> torch.Tensor:
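        """Masked grouped-GEMM path.

        GroupGemm-0 computes the gate/up projection per expert, the fused
        SiLU-and-mul kernel re-quantizes the result to fp8 with block scales,
        and GroupGemm-1 computes the down projection. masked_m restricts each
        expert's GEMM to its valid rows, and intermediate buffers are released
        as soon as they are no longer needed.
        """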

        from sglang.srt.layers.moe.ep_moe.kernels import (
            silu_and_mul_masked_post_quant_fwd,
        )
        from sglang.srt.layers.quantization import deep_gemm_wrapper

        hidden_states = runner_input.hidden_states
        hidden_states_scale = runner_input.hidden_states_scale
        masked_m = runner_input.masked_m
        expected_m = runner_input.expected_m

        w13_weight = quant_info.w13_weight
        w2_weight = quant_info.w2_weight
        w13_scale = quant_info.w13_scale
        w2_scale = quant_info.w2_scale

        hidden_states_device = running_state["hidden_states_device"]

        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            b, s_mn, s_k = hidden_states_scale.shape
            assert (
                s_mn % 4 == 0 and s_k % 4 == 0
            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"

        # GroupGemm-0
        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
        else:
            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                hidden_states_scale
            )

        num_groups, m, k = hidden_states.shape
        n = w13_weight.size(1)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            (hidden_states, hidden_states_scale),
            (w13_weight, w13_scale),
            gateup_output,
            masked_m,
            expected_m,
        )
        dispose_tensor(hidden_states)

        # Act
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=hidden_states_device,
            dtype=torch.float8_e4m3fn,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=hidden_states_device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output

        # GroupGemm-1
        n = w2_weight.shape[1]

        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                down_input_scale
            )

        down_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            (down_input, down_input_scale),
            (w2_weight, w2_scale),
            down_output,
            masked_m,
            expected_m,
        )
        del down_input

        return down_output

    def _run_contiguous_gemm(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> torch.Tensor:
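        # The contiguous (non-masked) grouped-GEMM path is not implemented in
        # this snapshot; the pre-permute hook below always requests the masked
        # path (use_masked_gemm=True).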
        pass

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.DEEP_GEMM


@register_pre_permute("standard", "deep_gemm")
def pre_permute_standard_to_deep_gemm(
    dispatch_output: StandardDispatchOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> DeepGemmRunnerInput:
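    """Quantize and reorder tokens into DeepGEMM's masked per-expert layout.

    moe_ep_deepgemm_preprocess groups tokens by expert, quantizes them to fp8
    with block scales, and returns the src2dst mapping; routing metadata and
    the original shape/dtype/device are stashed in running_state so the
    post-permute step can restore token order.
    """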
    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess

    hidden_states, topk_output = dispatch_output
    topk_weights, topk_ids, _ = topk_output

    hidden_states_shape = hidden_states.shape
    hidden_states_dtype = hidden_states.dtype
    hidden_states_device = hidden_states.device
    hidden_states_ref = hidden_states

    # PreReorder
    masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
        moe_ep_deepgemm_preprocess(
            topk_ids,
            runner_config.num_local_experts,
            hidden_states,
            runner_config.top_k,
            quant_info.block_shape,
        )
    )

    dispose_tensor(hidden_states_ref)

    running_state["topk_ids"] = topk_ids
    running_state["topk_weights"] = topk_weights
    running_state["hidden_states_shape"] = hidden_states_shape
    running_state["hidden_states_dtype"] = hidden_states_dtype
    running_state["hidden_states_device"] = hidden_states_device
    running_state["src2dst"] = src2dst

    return DeepGemmRunnerInput(
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        masked_m=masked_m,
        expected_m=expected_m,
        use_masked_gemm=True,
    )


@register_post_permute("deep_gemm", "standard")
def post_permute_deep_gemm_to_standard(
    runner_output: DeepGemmRunnerOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> StandardCombineInput:
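    """Scatter expert outputs back to the original token order.

    post_reorder_triton_kernel uses src2dst and the top-k routing weights to
    accumulate each token's expert outputs into `output`; the optional
    routed_scaling_factor is applied afterwards.
    """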
    from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput

    hidden_states_shape = running_state["hidden_states_shape"]
    hidden_states_dtype = running_state["hidden_states_dtype"]
    hidden_states_device = running_state["hidden_states_device"]
    src2dst = running_state["src2dst"]
    topk_ids = running_state["topk_ids"]
    topk_weights = running_state["topk_weights"]

    output = torch.empty(
        hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
    )
    post_reorder_triton_kernel[(hidden_states_shape[0],)](
        runner_output.hidden_states,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        runner_config.top_k,
        hidden_states_shape[1],
        BLOCK_SIZE=512,
    )

    dispose_tensor(runner_output.hidden_states)

    if runner_config.routed_scaling_factor is not None:
        output *= runner_config.routed_scaling_factor

    return StandardCombineInput(
        hidden_states=output,
    )