Support token-level quantization for EP MoE (#6782)
This commit is contained in:
@@ -178,6 +178,7 @@ def pre_reorder_triton_kernel(
|
|||||||
topk,
|
topk,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
use_per_token_if_dynamic: tl.constexpr,
|
||||||
):
|
):
|
||||||
OutDtype = gateup_input_ptr.dtype.element_ty
|
OutDtype = gateup_input_ptr.dtype.element_ty
|
||||||
|
|
||||||
@@ -188,11 +189,15 @@ def pre_reorder_triton_kernel(
|
|||||||
|
|
||||||
vec = tl.arange(0, BLOCK_SIZE)
|
vec = tl.arange(0, BLOCK_SIZE)
|
||||||
|
|
||||||
|
if a1_scales_ptr is not None and use_per_token_if_dynamic:
|
||||||
|
scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
|
||||||
|
|
||||||
for idx in range(topk):
|
for idx in range(topk):
|
||||||
expert_id = tl.load(topk_ids_ptr + idx)
|
expert_id = tl.load(topk_ids_ptr + idx)
|
||||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||||
if a1_scales_ptr is not None:
|
if a1_scales_ptr is not None:
|
||||||
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
|
if not use_per_token_if_dynamic:
|
||||||
|
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
|
||||||
else:
|
else:
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
@@ -558,6 +563,7 @@ def grouped_gemm_triton_kernel(
|
|||||||
bs_stride_0: tl.constexpr,
|
bs_stride_0: tl.constexpr,
|
||||||
bs_stride_2: tl.constexpr,
|
bs_stride_2: tl.constexpr,
|
||||||
bs_stride_1: tl.constexpr,
|
bs_stride_1: tl.constexpr,
|
||||||
|
use_per_token_if_dynamic: tl.constexpr,
|
||||||
BLOCK_SIZE_M: tl.constexpr,
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
BLOCK_SIZE_N: tl.constexpr,
|
BLOCK_SIZE_N: tl.constexpr,
|
||||||
BLOCK_SIZE_K: tl.constexpr,
|
BLOCK_SIZE_K: tl.constexpr,
|
||||||
@@ -621,7 +627,10 @@ def grouped_gemm_triton_kernel(
|
|||||||
b_ptr += BLOCK_SIZE_K
|
b_ptr += BLOCK_SIZE_K
|
||||||
|
|
||||||
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
|
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
|
||||||
scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
|
if use_per_token_if_dynamic:
|
||||||
|
scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
|
||||||
|
else:
|
||||||
|
scale_a_value = tl.load(scale_a + expert_id)
|
||||||
scale_b_value = tl.load(scale_b + expert_id)
|
scale_b_value = tl.load(scale_b + expert_id)
|
||||||
accumulator *= scale_a_value * scale_b_value
|
accumulator *= scale_a_value * scale_b_value
|
||||||
|
|
||||||
@@ -658,6 +667,7 @@ def grouped_gemm_triton(
|
|||||||
scale_b: torch.Tensor = None,
|
scale_b: torch.Tensor = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
c_dtype=None,
|
c_dtype=None,
|
||||||
|
use_per_token_if_dynamic: bool = True,
|
||||||
):
|
):
|
||||||
assert weight_column_major == True # TODO: more
|
assert weight_column_major == True # TODO: more
|
||||||
if use_fp8_w8a8 and block_shape is None:
|
if use_fp8_w8a8 and block_shape is None:
|
||||||
@@ -698,6 +708,11 @@ def grouped_gemm_triton(
|
|||||||
triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
|
triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
|
||||||
|
assert (
|
||||||
|
scale_a.shape[0] == a.shape[0]
|
||||||
|
), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
|
||||||
|
|
||||||
grouped_gemm_triton_kernel[grid](
|
grouped_gemm_triton_kernel[grid](
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
@@ -721,6 +736,7 @@ def grouped_gemm_triton(
|
|||||||
scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
||||||
scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
|
scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
|
||||||
scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
||||||
|
use_per_token_if_dynamic,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
return c
|
return c
|
||||||
|
|||||||
@@ -50,7 +50,10 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
scaled_fp8_quant,
|
||||||
|
sglang_per_token_quant_fp8,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
|
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
|
||||||
|
|
||||||
@@ -65,10 +68,16 @@ logger = logging.getLogger(__name__)
|
|||||||
class GroupedGemmRunner(torch.nn.Module):
|
class GroupedGemmRunner(torch.nn.Module):
|
||||||
flashinfer_gemm_warpper = None
|
flashinfer_gemm_warpper = None
|
||||||
|
|
||||||
def __init__(self, device, use_flashinfer: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
device,
|
||||||
|
use_flashinfer: bool = False,
|
||||||
|
use_per_token_if_dynamic: bool = True,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.use_flashinfer = use_flashinfer
|
self.use_flashinfer = use_flashinfer
|
||||||
|
self.use_per_token_if_dynamic = use_per_token_if_dynamic
|
||||||
if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
|
if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
|
||||||
GroupedGemmRunner._init_flashinfer_wrapper(device)
|
GroupedGemmRunner._init_flashinfer_wrapper(device)
|
||||||
|
|
||||||
@@ -124,6 +133,7 @@ class GroupedGemmRunner(torch.nn.Module):
|
|||||||
scale_b,
|
scale_b,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
c_dtype=c_dtype,
|
c_dtype=c_dtype,
|
||||||
|
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||||
)
|
)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@@ -154,6 +164,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
use_per_token_if_dynamic: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -184,6 +195,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
self.custom_routing_function = custom_routing_function
|
self.custom_routing_function = custom_routing_function
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.use_per_token_if_dynamic = use_per_token_if_dynamic
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
||||||
@@ -227,6 +239,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
self.grouped_gemm_runner = GroupedGemmRunner(
|
||||||
hidden_states.device,
|
hidden_states.device,
|
||||||
use_flashinfer=False, # TODO: use flashinfer
|
use_flashinfer=False, # TODO: use flashinfer
|
||||||
|
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||||
)
|
)
|
||||||
|
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
@@ -259,12 +272,16 @@ class EPMoE(torch.nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
||||||
max_value = (
|
if self.use_per_token_if_dynamic:
|
||||||
torch.max(hidden_states)
|
max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
|
||||||
.repeat(self.num_experts_per_partition)
|
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
||||||
.to(torch.float32)
|
else:
|
||||||
)
|
max_value = (
|
||||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
torch.max(hidden_states)
|
||||||
|
.repeat(self.num_experts_per_partition)
|
||||||
|
.to(torch.float32)
|
||||||
|
)
|
||||||
|
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
||||||
|
|
||||||
# PreReorder
|
# PreReorder
|
||||||
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
|
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
|
||||||
@@ -278,9 +295,27 @@ class EPMoE(torch.nn.Module):
|
|||||||
self.top_k,
|
self.top_k,
|
||||||
hidden_states.shape[1],
|
hidden_states.shape[1],
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
|
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||||
)
|
)
|
||||||
dispose_tensor(hidden_states)
|
dispose_tensor(hidden_states)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.activation_scheme == "dynamic"
|
||||||
|
and not self.use_block_quant
|
||||||
|
and self.use_per_token_if_dynamic
|
||||||
|
):
|
||||||
|
scale = torch.empty(
|
||||||
|
hidden_states_shape[0] * self.top_k,
|
||||||
|
device=hidden_states_device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
scale[src2dst] = (
|
||||||
|
self.w13_input_scale.unsqueeze(1)
|
||||||
|
.expand(hidden_states_shape[0], self.top_k)
|
||||||
|
.reshape(-1)
|
||||||
|
)
|
||||||
|
self.w13_input_scale = scale
|
||||||
|
|
||||||
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
|
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
|
||||||
weight_indices_cur_rank = torch.arange(
|
weight_indices_cur_rank = torch.arange(
|
||||||
0,
|
0,
|
||||||
@@ -310,21 +345,24 @@ class EPMoE(torch.nn.Module):
|
|||||||
del gateup_input
|
del gateup_input
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
down_input = torch.empty(
|
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
||||||
gateup_output.shape[0],
|
self.w2_input_scale = None
|
||||||
gateup_output.shape[1] // 2,
|
down_input = torch.empty(
|
||||||
device=gateup_output.device,
|
gateup_output.shape[0],
|
||||||
dtype=(
|
gateup_output.shape[1] // 2,
|
||||||
self.fp8_dtype
|
device=gateup_output.device,
|
||||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
dtype=hidden_states_dtype,
|
||||||
else hidden_states_dtype
|
)
|
||||||
),
|
else:
|
||||||
)
|
down_input = torch.empty(
|
||||||
if self.w2_input_scale is None and not self.use_block_quant:
|
gateup_output.shape[0],
|
||||||
self.w2_input_scale = torch.ones(
|
gateup_output.shape[1] // 2,
|
||||||
self.num_experts_per_partition,
|
device=gateup_output.device,
|
||||||
dtype=torch.float32,
|
dtype=(
|
||||||
device=hidden_states_device,
|
self.fp8_dtype
|
||||||
|
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
||||||
|
else hidden_states_dtype
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.activation == "silu":
|
if self.activation == "silu":
|
||||||
@@ -353,6 +391,16 @@ class EPMoE(torch.nn.Module):
|
|||||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||||
del gateup_output
|
del gateup_output
|
||||||
|
|
||||||
|
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
||||||
|
if self.use_per_token_if_dynamic:
|
||||||
|
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
|
||||||
|
else:
|
||||||
|
self.w2_input_scale = torch.ones(
|
||||||
|
self.num_experts_per_partition,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=hidden_states_device,
|
||||||
|
)
|
||||||
|
|
||||||
# GroupGemm-1
|
# GroupGemm-1
|
||||||
down_output = torch.empty(
|
down_output = torch.empty(
|
||||||
down_input.shape[0],
|
down_input.shape[0],
|
||||||
|
|||||||
Reference in New Issue
Block a user