Support token-level quantization for EP MoE (#6782)

This commit is contained in:
Cheng Wan
2025-05-30 17:26:30 -07:00
committed by GitHub
parent f18b068f15
commit ced3c07afe
2 changed files with 89 additions and 25 deletions

View File

@@ -50,7 +50,10 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import (
scaled_fp8_quant,
sglang_per_token_quant_fp8,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
@@ -65,10 +68,16 @@ logger = logging.getLogger(__name__)
class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
def __init__(self, device, use_flashinfer: bool = False):
def __init__(
self,
device,
use_flashinfer: bool = False,
use_per_token_if_dynamic: bool = True,
):
super().__init__()
self.device = device
self.use_flashinfer = use_flashinfer
self.use_per_token_if_dynamic = use_per_token_if_dynamic
if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
GroupedGemmRunner._init_flashinfer_wrapper(device)
@@ -124,6 +133,7 @@ class GroupedGemmRunner(torch.nn.Module):
scale_b,
block_shape=block_shape,
c_dtype=c_dtype,
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
return c
@@ -154,6 +164,7 @@ class EPMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
use_per_token_if_dynamic: bool = True,
):
super().__init__()
@@ -184,6 +195,7 @@ class EPMoE(torch.nn.Module):
self.custom_routing_function = custom_routing_function
self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
self.use_per_token_if_dynamic = use_per_token_if_dynamic
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -227,6 +239,7 @@ class EPMoE(torch.nn.Module):
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,
use_flashinfer=False, # TODO: use flashinfer
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
topk_weights, topk_ids = select_experts(
@@ -259,12 +272,16 @@ class EPMoE(torch.nn.Module):
),
)
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
if self.use_per_token_if_dynamic:
max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
else:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
# PreReorder
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
@@ -278,9 +295,27 @@ class EPMoE(torch.nn.Module):
self.top_k,
hidden_states.shape[1],
BLOCK_SIZE=512,
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
dispose_tensor(hidden_states)
if (
self.activation_scheme == "dynamic"
and not self.use_block_quant
and self.use_per_token_if_dynamic
):
scale = torch.empty(
hidden_states_shape[0] * self.top_k,
device=hidden_states_device,
dtype=torch.float32,
)
scale[src2dst] = (
self.w13_input_scale.unsqueeze(1)
.expand(hidden_states_shape[0], self.top_k)
.reshape(-1)
)
self.w13_input_scale = scale
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
@@ -310,21 +345,24 @@ class EPMoE(torch.nn.Module):
del gateup_input
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states_dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states_device,
if self.activation_scheme == "dynamic" and not self.use_block_quant:
self.w2_input_scale = None
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states_dtype,
)
else:
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states_dtype
),
)
if self.activation == "silu":
@@ -353,6 +391,16 @@ class EPMoE(torch.nn.Module):
raise ValueError(f"Unsupported activation: {self.activation=}")
del gateup_output
if self.activation_scheme == "dynamic" and not self.use_block_quant:
if self.use_per_token_if_dynamic:
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
else:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states_device,
)
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],