from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
|
|
|
import torch
|
|
|
|
from sglang.srt.distributed import (
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
)
|
|
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
ep_gather,
|
|
ep_scatter,
|
|
gelu_and_mul_triton_kernel,
|
|
grouped_gemm_triton,
|
|
moe_ep_deepgemm_preprocess,
|
|
post_reorder_triton_kernel,
|
|
pre_reorder_triton_kernel,
|
|
pre_reorder_triton_kernel_for_cutlass_moe,
|
|
run_cutlass_moe_ep_preproess,
|
|
run_moe_ep_preproess,
|
|
silu_and_mul_masked_post_quant_fwd,
|
|
silu_and_mul_triton_kernel,
|
|
tma_align_input_scale,
|
|
)
|
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
|
from sglang.srt.layers.moe.topk import TopKOutput
|
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
|
from sglang.srt.layers.quantization.base_config import (
|
|
QuantizationConfig,
|
|
QuantizeMethodBase,
|
|
)
|
|
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
|
is_fp8_fnuz,
|
|
sglang_per_token_group_quant_fp8,
|
|
sglang_per_token_quant_fp8,
|
|
)
|
|
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.utils import (
|
|
DeepEPMode,
|
|
ceil_div,
|
|
dispose_tensor,
|
|
get_bool_env_var,
|
|
is_hip,
|
|
is_npu,
|
|
next_power_of_2,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
|
|
DeepEPLLOutput,
|
|
DeepEPNormalOutput,
|
|
DispatchOutput,
|
|
)
|
|
|
|
# Platform / feature-detection flags, evaluated once at import time.
_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
# AITER kernels are ROCm-only and must be explicitly enabled via env var.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
# The FlashInfer TRT-LLM MoE path requires both its own flag and EP-MoE mode.
use_flashinfer_trtllm_moe = (
    global_server_args_dict["enable_flashinfer_trtllm_moe"]
    and global_server_args_dict["enable_ep_moe"]
)

# sgl_kernel's fused silu_and_mul is only imported on CUDA builds
# (not available on NPU/HIP).
if not (_is_npu or _is_hip):
    from sgl_kernel import silu_and_mul

if _use_aiter:
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight

# Degrade gracefully when flashinfer is not installed: disable the flag so
# get_moe_impl_class() falls through to another implementation.
if use_flashinfer_trtllm_moe:
    try:
        import flashinfer.fused_moe as fi_fused_moe
    except ImportError:
        fi_fused_moe = None
        use_flashinfer_trtllm_moe = False

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the tokens-per-CTA-tile size for the flashinfer MoE kernel.

    Estimates tokens per expert assuming a perfectly uniform routing
    distribution, rounds that estimate up to a power of two, and clamps the
    result to the kernel-supported range of 8..64 tokens per tile.
    """
    # Expected number of tokens routed to each expert under an even split.
    tokens_per_expert = (num_tokens * top_k) // num_experts
    # Round up to a power of two, then clamp into [8, 64].
    return max(8, min(64, next_power_of_2(tokens_per_expert)))
|
|
|
|
|
|
class EPMoE(FusedMoE):
    """
    MoE Expert Parallel Impl

    Extends FusedMoE with expert parallelism: each EP rank owns the
    contiguous expert slice [start_expert_id, end_expert_id] (inclusive).
    When JIT DeepGEMM is available and weights are FP8 w8a8, forward() uses
    a dedicated grouped-GEMM path; otherwise it falls back to FusedMoE.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
    ):
        super().__init__(
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_fused_shared_experts=num_fused_shared_experts,
            layer_id=layer_id,
            top_k=top_k,
            params_dtype=params_dtype,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=prefix,
            activation=activation,
            # apply_router_weight_on_input=apply_router_weight_on_input,
            routed_scaling_factor=routed_scaling_factor,
            enable_ep_moe=True,
        )

        # Global ids of the experts hosted on this EP rank; end id is
        # inclusive.
        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1

        self.intermediate_size = intermediate_size

        if isinstance(quant_config, Fp8Config):
            # Block quant: weights carry per-block scales (weight_block_size);
            # otherwise scales are per-tensor and are expanded on the fly in
            # forward_deepgemm.
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None

    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        """Route to the DeepGEMM FP8 path when available, else FusedMoE."""
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
            return self.forward_deepgemm(hidden_states, topk_output)
        else:
            return super().forward(hidden_states, topk_output)

    def forward_deepgemm(
        self,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
    ):
        """FP8 forward pass via DeepGEMM masked grouped GEMMs.

        Pipeline: pre-reorder (scatter tokens into per-expert padded buffers
        and quantize activations) -> grouped GEMM 0 (w13) -> fused
        SiLU-and-mul with FP8 re-quant -> grouped GEMM 1 (w2) -> post-reorder
        (gather results back to token order, weighted by top-k weights).

        Note: `hidden_states` is disposed in place; callers must not use it
        after this returns.
        """
        # NOTE(review): these (weight, scale) tuples are rebuilt on every
        # forward call; the underlying tensors are the module's parameters.
        self.w13_weight_fp8 = (
            self.w13_weight,
            (
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
        )
        self.w2_weight_fp8 = (
            self.w2_weight,
            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
        )

        assert self.quant_method is not None
        assert self.activation == "silu"
        hidden_states_shape = hidden_states.shape
        hidden_states_dtype = hidden_states.dtype
        hidden_states_device = hidden_states.device

        topk_weights, topk_ids, _ = topk_output

        if not self.use_block_quant:
            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
            scale_block_size = 128
            # w13 packs gate and up projections, hence the factor of 2 on the
            # N-dimension block count.
            w13_weight_scale_n = 2 * (
                (self.intermediate_size + scale_block_size - 1) // scale_block_size
            )
            w13_weight_scale_k = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w13_weight_scale = (
                self.w13_weight_scale.unsqueeze(1)
                .repeat_interleave(w13_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w13_weight_scale_k, dim=2)
            )
            self.w13_weight_fp8 = (
                self.w13_weight,
                w13_weight_scale,
            )
            w2_weight_scale_n = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale_k = (
                self.intermediate_size + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale = (
                self.w2_weight_scale.unsqueeze(1)
                .repeat_interleave(w2_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w2_weight_scale_k, dim=2)
            )
            self.w2_weight_fp8 = (
                self.w2_weight,
                w2_weight_scale,
            )

        # PreReorder: scatter + quantize tokens into per-expert buffers padded
        # to m_max rows each; src2dst records the permutation for the gather.
        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
            moe_ep_deepgemm_preprocess(
                topk_ids,
                self.num_experts,
                hidden_states,
                self.top_k,
                self.start_expert_id,
                self.end_expert_id,
                self.block_shape,
            )
        )

        # Original activations are no longer needed; free them early to cap
        # peak memory.
        dispose_tensor(hidden_states)

        # GroupGemm-0
        gateup_input_fp8 = (
            gateup_input,
            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
        )
        num_groups, m, k = gateup_input_fp8[0].size()
        n = self.w13_weight.size(1)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
        )
        del gateup_input
        del gateup_input_fp8

        # Act: fused SiLU(gate) * up (halving the last dim), re-quantized to
        # FP8 with per-128-element-group scales for the second GEMM.
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=hidden_states_device,
            dtype=self.fp8_dtype,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=hidden_states_device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
        )
        del gateup_output

        # GroupGemm-1
        n = self.w2_weight.size(1)
        down_input_fp8 = (
            down_input,
            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
        )
        down_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
        )
        del down_input
        del down_input_fp8

        # PostReorder: gather expert outputs back to the original token order,
        # applying the top-k routing weights. One program per input token.
        output = torch.empty(
            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
        )
        post_reorder_triton_kernel[(hidden_states_shape[0],)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states_shape[1],
            # presumably the row offset of this rank's first expert within the
            # m_max-padded per-expert layout — TODO confirm against the kernel
            m_max * self.start_expert_id,
            BLOCK_SIZE=512,
        )
        return output
|
|
|
|
|
|
class DeepEPMoE(EPMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)

    forward() runs dispatch (all-to-all tokens to expert-owning ranks) ->
    local expert computation (aiter fused kernel, DeepGEMM contiguous, or
    DeepGEMM masked depending on platform/dispatch format) -> combine
    (all-to-all reduction back to token order).
    """

    # Class-level guard for one-time log messages (shared across instances).
    _has_printed = False

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
        deepep_mode: DeepEPMode = DeepEPMode.auto,
    ):
        super().__init__(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_id=layer_id,
            num_fused_shared_experts=num_fused_shared_experts,
            params_dtype=params_dtype,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=prefix,
            activation=activation,
            routed_scaling_factor=routed_scaling_factor,
        )
        self.deepep_mode = deepep_mode

        # TODO: move to the beginning of the file
        from sglang.srt.distributed.parallel_state import get_tp_group
        from sglang.srt.managers.schedule_batch import global_server_args_dict
        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher

        # Handles the cross-rank token movement: dispatch before expert
        # computation and combine afterwards.
        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
            group=get_tp_group().device_group,
            router_topk=self.top_k,
            permute_fusion=True,
            num_experts=self.num_experts,
            num_local_experts=self.num_local_experts,
            hidden_size=hidden_size,
            params_dtype=params_dtype,
            deepep_mode=deepep_mode,
            async_finish=True,  # TODO
            return_recv_hook=True,
        )

        if self.deepep_mode.enable_low_latency():
            # The low-latency (masked) path is only implemented via DeepGEMM.
            assert (
                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
        if _use_aiter:
            # expert_mask is of size (self.num_local_experts + 1),
            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
            # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
            #     self.expert_mask = [1, 1, 1, 1, 0]
            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
            self.expert_mask = torch.zeros(
                (self.num_local_experts + 1),
                device=torch.cuda.current_device(),
                dtype=torch.int,
            )
            # the last one is invalid rank_id
            self.expert_mask[:-1] = 1
        else:
            # Pre-pack (weight, scale) tuples once; the DeepGEMM paths consume
            # them on every forward call.
            self.w13_weight_fp8 = (
                self.w13_weight,
                (
                    self.w13_weight_scale_inv
                    if self.use_block_quant
                    else self.w13_weight_scale
                ),
            )
            self.w2_weight_fp8 = (
                self.w2_weight,
                (
                    self.w2_weight_scale_inv
                    if self.use_block_quant
                    else self.w2_weight_scale
                ),
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Full DeepEP MoE step: dispatch -> local experts -> combine."""
        dispatch_output = self.dispatch(
            hidden_states, topk_idx, topk_weights, forward_batch
        )
        hidden_states = self.moe_impl(dispatch_output)
        # Combine uses the (possibly permuted) topk info returned by dispatch,
        # not the original arguments.
        hidden_states = self.combine(
            hidden_states,
            dispatch_output.topk_idx,
            dispatch_output.topk_weights,
            forward_batch,
        )
        return hidden_states

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Send tokens to the EP ranks that own their routed experts."""
        return self.deepep_dispatcher.dispatch(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )

    def moe_impl(self, dispatch_output: DispatchOutput):
        """Run the local experts, choosing the backend by dispatch format."""
        if _use_aiter:
            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
            return self.forward_aiter(dispatch_output)
        if dispatch_output.format.is_deepep_normal():
            # Normal dispatch produces a contiguous token layout.
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif dispatch_output.format.is_deepep_ll():
            # Low-latency dispatch produces a masked per-expert layout.
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_masked(dispatch_output)
        else:
            raise ValueError(
                f"Dispatch output format {dispatch_output.format} is not supported"
            )

    def combine(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Reduce expert outputs back to the originating ranks/tokens."""
        return self.deepep_dispatcher.combine(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )

    def forward_aiter(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        """ROCm path: single aiter fused_moe kernel over dispatched tokens."""
        hidden_states, topk_idx, topk_weights = (
            dispatch_output.hidden_states,
            dispatch_output.topk_idx,
            dispatch_output.topk_weights,
        )
        # No tokens were routed to this rank; nothing to compute.
        if hidden_states.shape[0] == 0:
            return hidden_states
        # in original deepep, idx == -1 meaning invalid and will not be processed.
        # aiter does not accept -1, we use a expert mask to make these idx invalid
        # (idx == num_local_experts) meaning not used in aiter fused_moe
        topk_idx_copy = topk_idx.to(torch.int32)
        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts

        return fused_moe(
            hidden_states,
            self.w13_weight,
            self.w2_weight,
            topk_weights,
            topk_idx_copy,
            w1_scale=self.w13_weight_scale_inv,
            w2_scale=self.w2_weight_scale_inv,
            quant_type=QuantType.per_128x128,
            activation=(
                ActivationType.Silu
                if self.activation == "silu"
                else ActivationType.Gelu
            ),
            expert_mask=self.expert_mask,
        )

    def forward_deepgemm_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        """DeepGEMM path for normal dispatch (contiguous token layout).

        ep_scatter groups tokens by expert, two contiguous grouped GEMMs run
        the experts, and ep_gather accumulates the weighted results back to
        token order.
        """
        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
            dispatch_output
        )
        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
        assert self.quant_method is not None
        assert self.activation == "silu"
        if num_recv_tokens_per_expert is None:
            # Nothing dispatched; return (empty) activations upcast to bf16.
            return hidden_states_fp8.bfloat16()
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states_fp8.bfloat16()
        M, K = hidden_states_fp8.size()
        N = self.w13_weight.size(1)
        scale_block_size = 128

        hidden_states_fp8_shape = hidden_states_fp8.shape
        hidden_states_fp8_device = hidden_states_fp8.device
        hidden_states_fp8_dtype = hidden_states_fp8.dtype

        # [activations, activation scales]; the scale layout differs between
        # the UE8M0 (packed int) and plain float32 scale formats.
        input_tensor = [
            torch.empty(
                (all_tokens, K),
                device=hidden_states_fp8.device,
                dtype=hidden_states_fp8.dtype,
            ),
            (
                # TODO check whether need `zeros`
                torch.zeros(
                    (ceil_div(K // 128, 4), all_tokens),
                    device=hidden_states_fp8.device,
                    dtype=torch.int,
                ).transpose(0, 1)
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else torch.empty(
                    (all_tokens, K // 128),
                    device=hidden_states_fp8.device,
                    dtype=torch.float32,
                )
            ),
        ]
        # Per-row expert index for the contiguous grouped GEMMs.
        m_indices = torch.empty(
            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
        )
        # Scatter destination of each (token, topk) pair; reused by ep_gather.
        output_index = torch.empty_like(topk_idx)

        # Host list -> pinned CPU tensor -> async H2D copy.
        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states_fp8,
            hidden_states_scale,
            topk_idx,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            input_tensor[0],
            input_tensor[1],
            m_indices,
            output_index,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        # Dispatched activations have been scattered; free them early.
        dispose_tensor(hidden_states_fp8)

        gateup_output = torch.empty(
            (all_tokens, N),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            input_tensor[1] = tma_align_input_scale(input_tensor[1])
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
        )
        del input_tensor
        down_input = torch.empty(
            (
                all_tokens,
                N // 2,
            ),
            device=gateup_output.device,
            dtype=torch.bfloat16,
        )
        # Fused SiLU(gate) * up, halving the hidden dimension.
        silu_and_mul(gateup_output.view(-1, N), down_input)
        del gateup_output
        down_output = torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        # Re-quantize the activation for the second FP8 GEMM.
        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
            down_input,
            scale_block_size,
            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del down_input
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            down_input_scale = tma_align_input_scale(down_input_scale)
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            (down_input_fp8, down_input_scale),
            self.w2_weight_fp8,
            down_output,
            m_indices,
        )
        del down_input_fp8, down_input_scale

        # Gather expert outputs back to token order, weighted by topk_weights.
        gather_out = torch.empty(
            hidden_states_fp8_shape,
            device=hidden_states_fp8_device,
            dtype=torch.bfloat16,
        )
        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

        return gather_out

    def forward_deepgemm_masked(
        self,
        dispatch_output: DeepEPLLOutput,
    ):
        """DeepGEMM path for low-latency dispatch (masked per-expert layout).

        Tokens arrive pre-grouped per expert with masked_m valid rows per
        group; runs grouped GEMM 0 (w13), fused SiLU+FP8 re-quant, then
        grouped GEMM 1 (w2) and returns the still-grouped output for combine.
        """
        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
        assert self.quant_method is not None
        assert self.activation == "silu"

        # GroupGemm-0
        num_groups, m, k = hidden_states_fp8[0].size()
        n = self.w13_weight.size(1)
        # expected_m guides kernel scheduling; it cannot exceed the padded
        # group size m.
        expected_m = min(expected_m, m)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            hidden_states_fp8,
            self.w13_weight_fp8,
            gateup_output,
            masked_m,
            expected_m,
            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )
        # Inputs consumed; free the dispatched activations early.
        dispose_tensor(hidden_states_fp8[0])

        # Act: fused SiLU(gate) * up with FP8 re-quant (per-128 group scales).
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=gateup_output.device,
            dtype=self.fp8_dtype,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=gateup_output.device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output

        # GroupGemm-1
        n = self.w2_weight.size(1)
        down_input_fp8 = (
            down_input,
            (
                down_input_scale
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
                    down_input_scale
                )
            ),
        )
        down_output = torch.empty(
            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            down_input_fp8,
            self.w2_weight_fp8,
            down_output,
            masked_m,
            expected_m,
            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )

        return down_output
|
|
|
|
|
|
class FlashInferEPMoE(EPMoE):
    """EP MoE backed by flashinfer's TRT-LLM FP8 block-scale fused MoE kernel.

    Routing (including grouped top-k and the correction bias) happens inside
    the fused kernel, so forward() takes raw router logits instead of a
    precomputed TopKOutput.
    """

    def __init__(self, *args, **kwargs):
        # Pop the routing-related kwargs that EPMoE.__init__ does not accept.
        renormalize = kwargs.pop("renormalize", True)
        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
        num_expert_group = kwargs.pop("num_expert_group", None)
        topk_group = kwargs.pop("topk_group", None)
        correction_bias = kwargs.pop("correction_bias", None)
        super().__init__(*args, **kwargs)
        self.renormalize = renormalize
        self.num_fused_shared_experts = num_fused_shared_experts
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.correction_bias = correction_bias
        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        """Quantize activations per token group and run the fused MoE kernel."""
        assert use_flashinfer_trtllm_moe
        assert (
            self.activation == "silu"
        ), "Only silu is supported for flashinfer blockscale fp8 moe"
        assert (
            self.renormalize
        ), "Renormalize is required for flashinfer blockscale fp8 moe"
        assert (
            self.num_fused_shared_experts == 0
        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
        # Per-token-group FP8 quantization of the activations.
        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
        # NOTE: scales of hidden states have to be transposed!
        a_sf_t = a_sf.t().contiguous()
        assert fi_fused_moe is not None
        return fi_fused_moe.trtllm_fp8_block_scale_moe(
            routing_logits=router_logits.to(torch.float32),
            routing_bias=self.correction_bias.to(hidden_states.dtype),
            hidden_states=a_q,
            hidden_states_scale=a_sf_t,
            gemm1_weights=self.w13_weight,
            gemm1_weights_scale=self.w13_weight_scale_inv,
            gemm2_weights=self.w2_weight,
            gemm2_weights_scale=self.w2_weight_scale_inv,
            num_experts=self.num_experts,
            top_k=self.top_k,
            n_group=self.num_expert_group,
            topk_group=self.topk_group,
            intermediate_size=self.w2_weight.shape[2],
            # The kernel computes global routing; tell it which expert slice
            # this rank owns.
            local_expert_offset=self.start_expert_id,
            local_num_experts=self.num_local_experts,
            routed_scaling_factor=self.routed_scaling_factor,
            tile_tokens_dim=_get_tile_tokens_dim(
                hidden_states.shape[0], self.top_k, self.num_experts
            ),
            routing_method_type=2,  # DeepSeek-styled routing method
            use_shuffled_weight=False,
        )
|
|
|
|
|
|
def get_moe_impl_class():
    """Return the MoE layer class matching the active server arguments.

    The checks are ordered by priority: DeepEP first, then the FlashInfer
    variants (which must be tested before plain EP-MoE, since FusedMoE also
    supports enable_ep_moe), then EP-MoE, with FusedMoE as the fallback.
    """
    args = global_server_args_dict
    if args["enable_deepep_moe"]:
        return DeepEPMoE
    # Must come before EPMoE because FusedMoE also supports enable_ep_moe
    if args["enable_flashinfer_cutlass_moe"]:
        return FusedMoE
    # Must come before EPMoE because FusedMoE also supports enable_ep_moe
    if use_flashinfer_trtllm_moe:
        return FlashInferEPMoE
    return EPMoE if args["enable_ep_moe"] else FusedMoE
|