Expert Parallelism (EP) Support for DeepSeek V3/R1 (#3602)
Co-authored-by: laixin <xielx@shanghaitech.edu.cn> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: laixin <q865809639@gmail.com>
This commit is contained in:
@@ -1,10 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
|
|
||||||
|
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||||
|
if _is_cuda:
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
sglang_per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -218,12 +225,19 @@ def grouped_gemm_triton_kernel(
|
|||||||
seg_indptr,
|
seg_indptr,
|
||||||
weight_indices,
|
weight_indices,
|
||||||
m_num_tiles_indptr,
|
m_num_tiles_indptr,
|
||||||
use_fp8_w8a8,
|
|
||||||
scale_a,
|
scale_a,
|
||||||
scale_b,
|
scale_b,
|
||||||
|
use_fp8_w8a8: tl.constexpr,
|
||||||
|
group_n: tl.constexpr,
|
||||||
|
group_k: tl.constexpr,
|
||||||
a_stride_0: tl.constexpr,
|
a_stride_0: tl.constexpr,
|
||||||
b_stride_0: tl.constexpr,
|
b_stride_0: tl.constexpr,
|
||||||
b_stride_1: tl.constexpr,
|
b_stride_1: tl.constexpr,
|
||||||
|
as_stride_0: tl.constexpr,
|
||||||
|
as_stride_1: tl.constexpr,
|
||||||
|
bs_stride_0: tl.constexpr,
|
||||||
|
bs_stride_2: tl.constexpr,
|
||||||
|
bs_stride_1: tl.constexpr,
|
||||||
BLOCK_SIZE_M: tl.constexpr,
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
BLOCK_SIZE_N: tl.constexpr,
|
BLOCK_SIZE_N: tl.constexpr,
|
||||||
BLOCK_SIZE_K: tl.constexpr,
|
BLOCK_SIZE_K: tl.constexpr,
|
||||||
@@ -260,6 +274,12 @@ def grouped_gemm_triton_kernel(
|
|||||||
+ (n_range_start + offs_bn[:, None]) * b_stride_1
|
+ (n_range_start + offs_bn[:, None]) * b_stride_1
|
||||||
+ offs_k[None, :]
|
+ offs_k[None, :]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if group_k > 0 and group_n > 0:
|
||||||
|
a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
|
||||||
|
offs_bsn = (n_range_start + offs_bn) // group_n
|
||||||
|
b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
|
||||||
|
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||||
a_tile = tl.load(
|
a_tile = tl.load(
|
||||||
@@ -268,14 +288,23 @@ def grouped_gemm_triton_kernel(
|
|||||||
b_tile = tl.load(
|
b_tile = tl.load(
|
||||||
b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
|
b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
|
||||||
)
|
)
|
||||||
accumulator = tl.dot(a_tile, b_tile.T, accumulator)
|
|
||||||
|
if group_k > 0 and group_n > 0:
|
||||||
|
k_start = k * BLOCK_SIZE_K
|
||||||
|
offs_ks = k_start // group_k
|
||||||
|
a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
|
||||||
|
b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
|
||||||
|
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
|
||||||
|
else:
|
||||||
|
accumulator = tl.dot(a_tile, b_tile.T, accumulator)
|
||||||
a_ptr += BLOCK_SIZE_K
|
a_ptr += BLOCK_SIZE_K
|
||||||
b_ptr += BLOCK_SIZE_K
|
b_ptr += BLOCK_SIZE_K
|
||||||
|
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
|
||||||
scale_a_value = tl.load(scale_a + expert_id)
|
scale_a_value = tl.load(scale_a + expert_id)
|
||||||
scale_b_value = tl.load(scale_b + expert_id)
|
scale_b_value = tl.load(scale_b + expert_id)
|
||||||
accumulator *= scale_a_value * scale_b_value
|
accumulator *= scale_a_value * scale_b_value
|
||||||
|
|
||||||
c_tile = accumulator.to(c_dtype)
|
c_tile = accumulator.to(c_dtype)
|
||||||
|
|
||||||
offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
|
offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
|
||||||
@@ -307,14 +336,29 @@ def grouped_gemm_triton(
|
|||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
scale_a: torch.Tensor = None,
|
scale_a: torch.Tensor = None,
|
||||||
scale_b: torch.Tensor = None,
|
scale_b: torch.Tensor = None,
|
||||||
|
block_shape: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
assert weight_column_major == True # TODO: more
|
assert weight_column_major == True # TODO: more
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8 and block_shape is None:
|
||||||
assert scale_a is not None and scale_b is not None
|
assert scale_a is not None and scale_b is not None
|
||||||
|
|
||||||
|
if block_shape is not None:
|
||||||
|
assert len(block_shape) == 2
|
||||||
|
block_n, block_k = block_shape[0], block_shape[1]
|
||||||
|
if _is_cuda:
|
||||||
|
a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
|
||||||
|
else:
|
||||||
|
a, scale_a = per_token_group_quant_fp8(a, block_k)
|
||||||
|
|
||||||
|
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
|
||||||
|
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
|
||||||
|
assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
|
||||||
|
|
||||||
|
# TODO: adjust config or tune kernel
|
||||||
|
# Reduce block size to prevent L40 shared memory overflow.
|
||||||
config = {
|
config = {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 32,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,12 +382,19 @@ def grouped_gemm_triton(
|
|||||||
seg_indptr,
|
seg_indptr,
|
||||||
weight_indices,
|
weight_indices,
|
||||||
m_num_tiles_indptr,
|
m_num_tiles_indptr,
|
||||||
use_fp8_w8a8,
|
|
||||||
scale_a,
|
scale_a,
|
||||||
scale_b,
|
scale_b,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
0 if block_shape is None else block_shape[0],
|
||||||
|
0 if block_shape is None else block_shape[1],
|
||||||
a.stride(0),
|
a.stride(0),
|
||||||
b.stride(0),
|
b.stride(0),
|
||||||
b.stride(1),
|
b.stride(1),
|
||||||
|
scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
|
||||||
|
scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
|
||||||
|
scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
||||||
|
scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
|
||||||
|
scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
return c
|
return c
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
run_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
silu_and_mul_triton_kernel,
|
silu_and_mul_triton_kernel,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -61,6 +62,7 @@ class GroupedGemmRunner(torch.nn.Module):
|
|||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
scale_a: torch.Tensor = None,
|
scale_a: torch.Tensor = None,
|
||||||
scale_b: torch.Tensor = None,
|
scale_b: torch.Tensor = None,
|
||||||
|
block_shape: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if self.use_flashinfer:
|
if self.use_flashinfer:
|
||||||
# TODO: flashinfer
|
# TODO: flashinfer
|
||||||
@@ -87,6 +89,7 @@ class GroupedGemmRunner(torch.nn.Module):
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
scale_a,
|
scale_a,
|
||||||
scale_b,
|
scale_b,
|
||||||
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@@ -147,12 +150,20 @@ class EPMoE(torch.nn.Module):
|
|||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
||||||
self.use_fp8_w8a8 = False
|
self.use_fp8_w8a8 = False
|
||||||
|
self.use_block_quant = False
|
||||||
|
self.block_shape = None
|
||||||
self.activation_scheme = None
|
self.activation_scheme = None
|
||||||
else:
|
else:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
|
||||||
quant_config
|
quant_config
|
||||||
)
|
)
|
||||||
self.use_fp8_w8a8 = True
|
self.use_fp8_w8a8 = True
|
||||||
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
|
self.block_shape = (
|
||||||
|
self.quant_method.quant_config.weight_block_size
|
||||||
|
if self.use_block_quant
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
self.activation_scheme = quant_config.activation_scheme
|
self.activation_scheme = quant_config.activation_scheme
|
||||||
|
|
||||||
@@ -173,7 +184,8 @@ class EPMoE(torch.nn.Module):
|
|||||||
|
|
||||||
if self.grouped_gemm_runner is None:
|
if self.grouped_gemm_runner is None:
|
||||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
self.grouped_gemm_runner = GroupedGemmRunner(
|
||||||
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
|
hidden_states.device,
|
||||||
|
use_flashinfer=False, # TODO: use flashinfer
|
||||||
)
|
)
|
||||||
|
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
@@ -195,9 +207,13 @@ class EPMoE(torch.nn.Module):
|
|||||||
gateup_input = torch.empty(
|
gateup_input = torch.empty(
|
||||||
(int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
|
(int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
|
dtype=(
|
||||||
|
self.fp8_dtype
|
||||||
|
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
||||||
|
else hidden_states.dtype
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if self.activation_scheme == "dynamic":
|
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
||||||
max_value = (
|
max_value = (
|
||||||
torch.max(hidden_states)
|
torch.max(hidden_states)
|
||||||
.repeat(self.num_experts_per_partition)
|
.repeat(self.num_experts_per_partition)
|
||||||
@@ -243,7 +259,12 @@ class EPMoE(torch.nn.Module):
|
|||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||||
scale_a=self.w13_input_scale,
|
scale_a=self.w13_input_scale,
|
||||||
scale_b=self.w13_weight_scale,
|
scale_b=(
|
||||||
|
self.w13_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w13_weight_scale
|
||||||
|
),
|
||||||
|
block_shape=self.block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@@ -251,9 +272,13 @@ class EPMoE(torch.nn.Module):
|
|||||||
gateup_output.shape[0],
|
gateup_output.shape[0],
|
||||||
gateup_output.shape[1] // 2,
|
gateup_output.shape[1] // 2,
|
||||||
device=gateup_output.device,
|
device=gateup_output.device,
|
||||||
dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
|
dtype=(
|
||||||
|
self.fp8_dtype
|
||||||
|
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
||||||
|
else hidden_states.dtype
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if self.w2_input_scale is None:
|
if self.w2_input_scale is None and not self.use_block_quant:
|
||||||
self.w2_input_scale = torch.ones(
|
self.w2_input_scale = torch.ones(
|
||||||
self.num_experts_per_partition,
|
self.num_experts_per_partition,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@@ -291,7 +316,12 @@ class EPMoE(torch.nn.Module):
|
|||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||||
scale_a=self.w2_input_scale,
|
scale_a=self.w2_input_scale,
|
||||||
scale_b=self.w2_weight_scale,
|
scale_b=(
|
||||||
|
self.w2_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w2_weight_scale
|
||||||
|
),
|
||||||
|
block_shape=self.block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
# PostReorder
|
# PostReorder
|
||||||
@@ -358,7 +388,11 @@ class EPMoE(torch.nn.Module):
|
|||||||
# Special case for fp8 scales.
|
# Special case for fp8 scales.
|
||||||
if "scale" in weight_name:
|
if "scale" in weight_name:
|
||||||
self._load_fp8_scale(
|
self._load_fp8_scale(
|
||||||
param.data, loaded_weight, weight_name, shard_id, expert_id
|
param.data,
|
||||||
|
loaded_weight,
|
||||||
|
weight_name,
|
||||||
|
shard_id,
|
||||||
|
expert_id,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -395,18 +429,33 @@ class EPMoE(torch.nn.Module):
|
|||||||
param_data[expert_id] = loaded_weight
|
param_data[expert_id] = loaded_weight
|
||||||
# Weight scales
|
# Weight scales
|
||||||
elif "weight_scale" in weight_name:
|
elif "weight_scale" in weight_name:
|
||||||
|
if self.use_block_quant:
|
||||||
|
block_n, block_k = self.block_shape[0], self.block_shape[1]
|
||||||
|
if shard_id == "w1":
|
||||||
|
param_data[expert_id][
|
||||||
|
: (self.intermediate_size + block_n - 1) // block_n, :
|
||||||
|
] = loaded_weight
|
||||||
|
elif shard_id == "w3":
|
||||||
|
param_data[expert_id][
|
||||||
|
(self.intermediate_size + block_n - 1) // block_n :, :
|
||||||
|
] = loaded_weight
|
||||||
|
else: # w2
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
# If we are in merged column case (gate_up_proj)
|
# If we are in merged column case (gate_up_proj)
|
||||||
if shard_id in ("w1", "w3"):
|
|
||||||
# We have to keep the weight scales of w1 and w3 because
|
|
||||||
# we need to re-quantize w1/w3 weights after weight loading.
|
|
||||||
idx = 0 if shard_id == "w1" else 1
|
|
||||||
param_data[expert_id][idx] = loaded_weight
|
|
||||||
# If we are in the row parallel case (down_proj)
|
|
||||||
else:
|
else:
|
||||||
param_data[expert_id] = loaded_weight
|
if shard_id in ("w1", "w3"):
|
||||||
|
# We have to keep the weight scales of w1 and w3 because
|
||||||
|
# we need to re-quantize w1/w3 weights after weight loading.
|
||||||
|
idx = 0 if shard_id == "w1" else 1
|
||||||
|
param_data[expert_id][idx] = loaded_weight
|
||||||
|
|
||||||
|
# If we are in the row parallel case (down_proj)
|
||||||
|
else:
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -498,6 +547,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
|||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
def __init__(self, quant_config: Fp8Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -512,6 +562,29 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
|||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = torch.float8_e4m3fn
|
params_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
if self.block_quant:
|
||||||
|
block_n, block_k = (
|
||||||
|
self.quant_config.weight_block_size[0],
|
||||||
|
self.quant_config.weight_block_size[1],
|
||||||
|
)
|
||||||
|
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
||||||
|
# Required by column parallel or enabling merged weights
|
||||||
|
if intermediate_size % block_n != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The output_size of gate's and up's weight = "
|
||||||
|
f"{intermediate_size} is not divisible by "
|
||||||
|
f"weight quantization block_n = {block_n}."
|
||||||
|
)
|
||||||
|
if tp_size > 1:
|
||||||
|
# Required by row parallel
|
||||||
|
if intermediate_size % block_k != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The input_size of down's weight = "
|
||||||
|
f"{intermediate_size} is not divisible by "
|
||||||
|
f"weight quantization block_k = {block_k}."
|
||||||
|
)
|
||||||
|
|
||||||
# WEIGHTS
|
# WEIGHTS
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@@ -538,21 +611,49 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
|||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
# WEIGHT_SCALES
|
# WEIGHT_SCALES
|
||||||
# Allocate 2 scales for w1 and w3 respectively.
|
if self.block_quant:
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
|
torch.ones(
|
||||||
requires_grad=False,
|
num_experts_per_partition,
|
||||||
)
|
2 * ((intermediate_size + block_n - 1) // block_n),
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
(hidden_size + block_k - 1) // block_k,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
num_experts_per_partition,
|
||||||
|
(hidden_size + block_n - 1) // block_n,
|
||||||
|
(intermediate_size + block_k - 1) // block_k,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||||
|
assert self.quant_config.activation_scheme == "dynamic"
|
||||||
|
else:
|
||||||
|
# WEIGHT_SCALES
|
||||||
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
# Add the quantization method used (per tensor/grouped/channel)
|
# Add the quantization method used (per tensor/grouped/channel)
|
||||||
# to ensure the weight scales are loaded in properly
|
# to ensure the weight scales are loaded in properly
|
||||||
extra_weight_attrs.update({"quant_method": "tensor"})
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||||
|
if self.block_quant
|
||||||
|
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||||
|
)
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
# If loading fp8 checkpoint, pass the weight loaders.
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||||
# process_weights_after_loading()
|
# process_weights_after_loading()
|
||||||
|
|||||||
361
python/sglang/test/test_block_fp8_ep.py
Normal file
361
python/sglang/test/test_block_fp8_ep.py
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
import itertools
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
|
grouped_gemm_triton,
|
||||||
|
post_reorder_triton_kernel,
|
||||||
|
pre_reorder_triton_kernel,
|
||||||
|
run_moe_ep_preproess,
|
||||||
|
silu_and_mul_triton_kernel,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
|
|
||||||
|
# For test
|
||||||
|
def ep_moe(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
# ep config
|
||||||
|
num_experts: int = 256,
|
||||||
|
fp8_dtype: torch.types = torch.float8_e4m3fn,
|
||||||
|
num_experts_per_partition: int = 128,
|
||||||
|
start_expert_id: int = 0,
|
||||||
|
end_expert_id: int = 127,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
w1_scale_inv: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale_inv: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
|
use_blockwise_fp8 = block_shape is not None
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=top_k,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
# correction_bias=correction_bias, #skip this in test
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
|
||||||
|
|
||||||
|
gateup_input = torch.empty(
|
||||||
|
(int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=(
|
||||||
|
fp8_dtype
|
||||||
|
if (use_fp8_w8a8 and not use_blockwise_fp8)
|
||||||
|
else hidden_states.dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_fp8_w8a8 and not use_blockwise_fp8:
|
||||||
|
max_value = (
|
||||||
|
torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
|
||||||
|
)
|
||||||
|
w1_input_scale = max_value / torch.finfo(fp8_dtype).max
|
||||||
|
else:
|
||||||
|
w1_input_scale = None
|
||||||
|
|
||||||
|
# PreReorder
|
||||||
|
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
|
||||||
|
hidden_states,
|
||||||
|
gateup_input,
|
||||||
|
src2dst,
|
||||||
|
topk_ids,
|
||||||
|
w1_input_scale,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
top_k,
|
||||||
|
hidden_states.shape[1],
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
|
||||||
|
weight_indices_cur_rank = torch.arange(
|
||||||
|
0,
|
||||||
|
num_experts_per_partition,
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
|
||||||
|
# GroupGemm-0
|
||||||
|
gateup_output = torch.empty(
|
||||||
|
gateup_input.shape[0],
|
||||||
|
w1.shape[1],
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
gateup_output = grouped_gemm_triton(
|
||||||
|
a=gateup_input,
|
||||||
|
b=w1,
|
||||||
|
c=gateup_output,
|
||||||
|
batch_size=num_experts_per_partition,
|
||||||
|
weight_column_major=True,
|
||||||
|
seg_indptr=seg_indptr_cur_rank,
|
||||||
|
weight_indices=weight_indices_cur_rank,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
scale_a=w1_input_scale,
|
||||||
|
scale_b=w1_scale_inv,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
down_input = torch.empty(
|
||||||
|
gateup_output.shape[0],
|
||||||
|
gateup_output.shape[1] // 2,
|
||||||
|
device=gateup_output.device,
|
||||||
|
dtype=(
|
||||||
|
fp8_dtype
|
||||||
|
if (use_fp8_w8a8 and not use_blockwise_fp8)
|
||||||
|
else hidden_states.dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if use_fp8_w8a8 and not use_blockwise_fp8:
|
||||||
|
w2_input_scale = torch.ones(
|
||||||
|
num_experts_per_partition,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
w2_input_scale = None
|
||||||
|
|
||||||
|
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
gateup_output.shape[1],
|
||||||
|
reorder_topk_ids,
|
||||||
|
w2_input_scale,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# GroupGemm-1
|
||||||
|
down_output = torch.empty(
|
||||||
|
down_input.shape[0],
|
||||||
|
w2.shape[1],
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_output = grouped_gemm_triton(
|
||||||
|
a=down_input,
|
||||||
|
b=w2,
|
||||||
|
c=down_output,
|
||||||
|
batch_size=num_experts_per_partition,
|
||||||
|
weight_column_major=True,
|
||||||
|
seg_indptr=seg_indptr_cur_rank,
|
||||||
|
weight_indices=weight_indices_cur_rank,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
scale_a=w2_input_scale,
|
||||||
|
scale_b=w2_scale_inv,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# PostReorder
|
||||||
|
output = torch.empty_like(hidden_states)
|
||||||
|
post_reorder_triton_kernel[(hidden_states.size(0),)](
|
||||||
|
down_output,
|
||||||
|
output,
|
||||||
|
src2dst,
|
||||||
|
topk_ids,
|
||||||
|
topk_weights,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
top_k,
|
||||||
|
hidden_states.size(1),
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# test util
|
||||||
|
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized tensor to float32.

    Each ``block_n x block_k`` tile of ``x_q_block`` is multiplied by its
    scale from ``x_s``. Tiles on the trailing edge may be smaller when the
    dimensions are not multiples of the block size.

    Args:
        x_q_block: Quantized values, shape ``(n, k)`` or ``(batch, n, k)``.
        x_s: Per-tile scales, shape ``(ceil(n / block_n), ceil(k / block_k))``
            with a matching leading batch dimension for 3-D input.
        block_size: ``[block_n, block_k]``.

    Returns:
        A new float32 tensor with the same shape as ``x_q_block`` holding the
        dequantized values. The input tensor is never modified.

    Raises:
        AssertionError: If the shape of ``x_s`` does not match the tile grid.
    """
    # process 3D tensor: dequantize each batch entry independently
    if x_q_block.dim() == 3:
        return torch.stack(
            [
                block_dequant(x_q_block[b], x_s[b], block_size)
                for b in range(x_q_block.size(0))
            ]
        )

    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    # copy=True guarantees fresh storage; a plain .to(float32) returns the
    # input itself when it is already float32, and the in-place scaling below
    # would then silently overwrite the caller's tensor.
    x_dq_block = x_q_block.to(torch.float32, copy=True)

    for j in range(n_tiles):
        rows = slice(j * block_n, min((j + 1) * block_n, n))
        for i in range(k_tiles):
            cols = slice(i * block_k, min((i + 1) * block_k, k))
            x_dq_block[rows, cols] *= x_s[j, i]

    return x_dq_block
|
||||||
|
|
||||||
|
|
||||||
|
class TestW8A8BlockFP8EPMoE(unittest.TestCase):
    """Compare block-wise FP8 `ep_moe` output against a dequantized reference."""

    # Parameter grid; the test iterates over the full Cartesian product.
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 222, 1024, 2048]
    N = [128, 1024, 2048]
    K = [256, 4096, 5120]
    E = [8, 16]
    ep_size = [2, 4]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        """Skip the suite without CUDA; otherwise allocate tensors on the GPU by default."""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_ep_moe(
        self, M, N, K, E, ep_size, topk, block_size, dtype, seed
    ):
        """Run one configuration and check the mean relative error stays below 6%."""
        torch.manual_seed(seed)
        random.seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        finfo = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = finfo.max, finfo.min

        def ceil_div(value, divisor):
            return (value + divisor - 1) // divisor

        def random_fp8_weight(shape):
            # Uniform values spanning the FP8 range, clamped and cast to FP8.
            full_range = (torch.rand(shape, dtype=dtype) - 0.5) * 2 * fp8_max
            return full_range.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        def random_scale(n_tiles, k_tiles):
            return (
                torch.rand((E, n_tiles, k_tiles), dtype=torch.float32)
                * factor_for_scale
            )

        # The order of the random draws below is kept fixed so that results
        # for a given seed are reproducible.
        a = torch.randn((M, K), dtype=dtype) / 10
        w1 = random_fp8_weight((E, 2 * N, K))
        w2 = random_fp8_weight((E, K, N))

        block_n, block_k = block_size[0], block_size[1]
        w1_s = random_scale(ceil_div(2 * N, block_n), ceil_div(K, block_k))
        w2_s = random_scale(ceil_div(K, block_n), ceil_div(N, block_k))

        # Reference weights: dequantized copies in the activation dtype.
        w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
        w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)

        score = torch.randn((M, E), dtype=dtype)
        num_experts_per_partition = E // ep_size
        cur_rank = random.randint(0, ep_size - 1)
        start_id = cur_rank * num_experts_per_partition
        end_id = start_id + num_experts_per_partition - 1

        # Arguments shared by the quantized run and the reference run.
        shared_kwargs = dict(
            hidden_states=a,
            router_logits=score,
            top_k=topk,
            renormalize=False,
            num_experts=E,
            num_experts_per_partition=num_experts_per_partition,
            start_expert_id=start_id,
            end_expert_id=end_id,
        )

        with torch.inference_mode():
            out = ep_moe(
                w1=w1,
                w2=w2,
                use_fp8_w8a8=True,
                w1_scale_inv=w1_s,
                w2_scale_inv=w2_s,
                block_shape=block_size,
                **shared_kwargs,
            )
            ref_out = ep_moe(
                w1=w1_ref,
                w2=w2_ref,
                use_fp8_w8a8=False,
                w1_scale_inv=None,
                w2_scale_inv=None,
                block_shape=None,
                **shared_kwargs,
            )

        out_f32 = out.to(torch.float32)
        ref_f32 = ref_out.to(torch.float32)
        mean_abs_diff = torch.mean(torch.abs(out_f32 - ref_f32))
        mean_abs_ref = torch.mean(torch.abs(ref_f32))
        self.assertTrue(mean_abs_diff / (mean_abs_ref + 1e-6) < 0.06)

    def test_w8a8_block_fp8_ep_moe(self):
        """Exercise every combination of the class-level parameter lists."""
        names = (
            "M",
            "N",
            "K",
            "E",
            "ep_size",
            "topk",
            "block_size",
            "dtype",
            "seed",
        )
        grid = itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.ep_size,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        )
        for params in grid:
            with self.subTest(**dict(zip(names, params))):
                self._w8a8_block_fp8_ep_moe(*params)
            torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the suite with verbose per-test output.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||||
Reference in New Issue
Block a user