Expert Parallelism (EP) Support for DeepSeek V3/R1 (#3602)
Co-authored-by: laixin <xielx@shanghaitech.edu.cn> Co-authored-by: HandH1998 <1335248067@qq.com> Co-authored-by: laixin <q865809639@gmail.com>
This commit is contained in:
361
python/sglang/test/test_block_fp8_ep.py
Normal file
361
python/sglang/test/test_block_fp8_ep.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
grouped_gemm_triton,
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel,
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_triton_kernel,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
|
||||
# Test-only driver: runs the EP MoE kernel pipeline for one expert partition.
|
||||
def ep_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    # ep config
    num_experts: int = 256,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    num_experts_per_partition: int = 128,
    start_expert_id: int = 0,
    end_expert_id: int = 127,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    w1_scale_inv: Optional[torch.Tensor] = None,
    w2_scale_inv: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
    """Run the expert-parallel MoE kernel pipeline for one expert partition.

    The steps are, in order: ``select_experts`` -> ``run_moe_ep_preproess`` ->
    pre-reorder kernel -> grouped GEMM with ``w1`` -> SiLU-and-mul kernel ->
    grouped GEMM with ``w2`` -> post-reorder kernel.

    Args:
        hidden_states: 2-D input activations; ``shape[0]`` is used as the
            launch grid size and ``shape[1]`` as the hidden size.
        w1: Weights for the first grouped GEMM; ``w1.shape[1]`` is the width
            of the gate/up output buffer.
        w2: Weights for the second grouped GEMM; ``w2.shape[1]`` is the width
            of the down-projection output buffer.
        router_logits: Router scores forwarded to ``select_experts``.
        top_k: Number of experts selected per row of ``hidden_states``.
        renormalize: Forwarded to ``select_experts``.
        num_experts: Total expert count, forwarded to ``run_moe_ep_preproess``.
        fp8_dtype: FP8 dtype used for activation buffers on the per-tensor
            FP8 path (``use_fp8_w8a8`` set and ``block_shape`` is ``None``).
        num_experts_per_partition: Number of experts handled by this call;
            used as the grouped-GEMM batch size.
        start_expert_id: First expert id (inclusive) of this partition.
        end_expert_id: Last expert id (inclusive) of this partition.
        use_grouped_topk: Forwarded to ``select_experts``.
        num_expert_group: Forwarded to ``select_experts``.
        topk_group: Forwarded to ``select_experts``.
        custom_routing_function: Forwarded to ``select_experts``.
        use_fp8_w8a8: Forwarded to both grouped GEMMs; also selects the FP8
            activation path when ``block_shape`` is ``None``.
        w1_scale_inv: Passed as ``scale_b`` to the first grouped GEMM.
        w2_scale_inv: Passed as ``scale_b`` to the second grouped GEMM.
        block_shape: Block-quantization shape forwarded to both grouped
            GEMMs; a non-``None`` value selects the block-wise FP8 path.

    Returns:
        A tensor allocated with ``torch.empty_like(hidden_states)`` and handed
        to the post-reorder kernel as its output buffer.
    """
    use_blockwise_fp8 = block_shape is not None
    # Per-tensor FP8 activations are only used when FP8 is requested without
    # block-wise scales; hoisted because four decisions below depend on it.
    per_tensor_fp8 = use_fp8_w8a8 and not use_blockwise_fp8
    activation_dtype = fp8_dtype if per_tensor_fp8 else hidden_states.dtype

    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        # correction_bias=correction_bias, #skip this in test
        custom_routing_function=custom_routing_function,
    )

    # NOTE: "preproess" is the spelling of the name imported from the kernels module.
    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

    # One row per (token, selected expert) pair.
    gateup_input = torch.empty(
        (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
        device=hidden_states.device,
        dtype=activation_dtype,
    )

    if per_tensor_fp8:
        # A single scale derived from the global max, replicated per local expert.
        max_value = (
            torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
        )
        w1_input_scale = max_value / torch.finfo(fp8_dtype).max
    else:
        w1_input_scale = None

    # PreReorder
    pre_reorder_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        gateup_input,
        src2dst,
        topk_ids,
        w1_input_scale,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.shape[1],
        BLOCK_SIZE=512,
    )

    # Slice covering this partition's experts; the extra element is the end
    # offset of the last expert in the partition.
    seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
    weight_indices_cur_rank = torch.arange(
        0,
        num_experts_per_partition,
        device=hidden_states.device,
        dtype=torch.int64,
    )

    # GroupGemm-0
    gateup_output = torch.empty(
        gateup_input.shape[0],
        w1.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    gateup_output = grouped_gemm_triton(
        a=gateup_input,
        b=w1,
        c=gateup_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w1_input_scale,
        scale_b=w1_scale_inv,
        block_shape=block_shape,
    )

    # Act
    down_input = torch.empty(
        gateup_output.shape[0],
        gateup_output.shape[1] // 2,
        device=gateup_output.device,
        dtype=activation_dtype,
    )
    if per_tensor_fp8:
        w2_input_scale = torch.ones(
            num_experts_per_partition,
            dtype=torch.float32,
            device=hidden_states.device,
        )
    else:
        w2_input_scale = None

    silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
        gateup_output,
        down_input,
        gateup_output.shape[1],
        reorder_topk_ids,
        w2_input_scale,
        start_expert_id,
        end_expert_id,
        BLOCK_SIZE=512,
    )

    # GroupGemm-1
    down_output = torch.empty(
        down_input.shape[0],
        w2.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    down_output = grouped_gemm_triton(
        a=down_input,
        b=w2,
        c=down_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w2_input_scale,
        scale_b=w2_scale_inv,
        block_shape=block_shape,
    )

    # PostReorder
    output = torch.empty_like(hidden_states)
    post_reorder_triton_kernel[(hidden_states.size(0),)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.size(1),
        BLOCK_SIZE=512,
    )
    return output
|
||||
|
||||
|
||||
# Test utility: builds the dequantized reference weights from block-wise FP8 weights and scales.
|
||||
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized tensor to float32.

    Every ``block_size[0] x block_size[1]`` tile of the last two dimensions of
    ``x_q_block`` is multiplied by its matching entry of ``x_s``. Tiles on the
    trailing edges may be partial when a dimension is not a multiple of the
    block size.

    Args:
        x_q_block: Quantized tensor with at least two dimensions; any leading
            dimensions are treated as batch dimensions.
        x_s: Per-tile scales. Its leading dimensions must equal those of
            ``x_q_block`` and its last two dimensions must equal the number of
            tiles along the corresponding dimensions of ``x_q_block``.
        block_size: ``[block_n, block_k]`` tile shape.

    Returns:
        A new float32 tensor with the same shape as ``x_q_block``. The input is
        never modified, even when it is already float32.

    Raises:
        ValueError: If ``x_q_block`` has fewer than two dimensions.
        AssertionError: If the shape of ``x_s`` does not match the tiling.
    """
    if x_q_block.dim() < 2:
        raise ValueError(
            f"x_q_block must have at least 2 dimensions, got {x_q_block.dim()}"
        )

    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape[-2], x_q_block.shape[-1]
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[-2], "x_s row-tile count does not match x_q_block"
    assert k_tiles == x_s.shape[-1], "x_s column-tile count does not match x_q_block"
    assert (
        x_s.shape[:-2] == x_q_block.shape[:-2]
    ), "x_s batch dimensions do not match x_q_block"

    # Expand each per-tile scale to cover its whole tile, then trim the
    # overhang produced by partial tiles on the trailing edges.
    scale = (
        x_s.to(device=x_q_block.device, dtype=torch.float32)
        .repeat_interleave(block_n, dim=-2)
        .repeat_interleave(block_k, dim=-1)[..., :n, :k]
    )

    # copy=True guarantees a fresh buffer, so the in-place multiply below can
    # never write through to the caller's tensor.
    x_dq_block = x_q_block.to(torch.float32, copy=True)
    return x_dq_block.mul_(scale)
|
||||
|
||||
|
||||
class TestW8A8BlockFP8EPMoE(unittest.TestCase):
    """Compare block-wise FP8 ``ep_moe`` against ``ep_moe`` on dequantized weights.

    Each class-level list below is one axis of the parameter grid; the test
    method runs the full Cartesian product of all axes.
    """

    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 222, 1024, 2048]
    N = [128, 1024, 2048]
    K = [256, 4096, 5120]
    E = [8, 16]
    ep_size = [2, 4]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        """Skip the whole class without CUDA; otherwise create tensors on CUDA by default."""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_ep_moe(
        self, M, N, K, E, ep_size, topk, block_size, dtype, seed
    ):
        """Run one grid point and check the mean relative error of the FP8 path.

        Args:
            M: Number of input rows (tokens).
            N: Intermediate size; ``w1`` has ``2 * N`` rows per expert.
            K: Hidden size.
            E: Total number of experts.
            ep_size: Number of expert partitions; ``E // ep_size`` experts each.
            topk: Experts selected per token.
            block_size: ``[block_n, block_k]`` quantization block shape.
            dtype: Activation / reference-weight dtype.
            seed: Seed for both ``torch`` and ``random``.
        """
        torch.manual_seed(seed)
        random.seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        # Uniform values spanning the FP8 range, generated in `dtype` and then
        # clamped and cast to FP8.
        w1_unquantized = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
        w1 = w1_unquantized.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_unquantized = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
        w2 = w2_unquantized.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # One scale per quantization tile (ceil division along each axis).
        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        # Reference weights: dequantize, then cast to the activation dtype.
        w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
        w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)

        score = torch.randn((M, E), dtype=dtype)

        # Pick a random partition and derive its inclusive expert id range.
        num_experts_per_partition = E // ep_size
        cur_rank = random.randint(0, ep_size - 1)
        start_id = cur_rank * num_experts_per_partition
        end_id = start_id + num_experts_per_partition - 1

        with torch.inference_mode():
            out = ep_moe(
                hidden_states=a,
                w1=w1,
                w2=w2,
                router_logits=score,
                top_k=topk,
                renormalize=False,
                use_fp8_w8a8=True,
                w1_scale_inv=w1_s,
                w2_scale_inv=w2_s,
                block_shape=block_size,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )
            ref_out = ep_moe(
                hidden_states=a,
                w1=w1_ref,
                w2=w2_ref,
                router_logits=score,
                top_k=topk,
                renormalize=False,
                use_fp8_w8a8=False,
                w1_scale_inv=None,
                w2_scale_inv=None,
                block_shape=None,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )

        out_f32 = out.to(torch.float32)
        ref_f32 = ref_out.to(torch.float32)
        # Mean absolute error normalized by the mean reference magnitude; the
        # epsilon guards against division by zero.
        rel_err = torch.mean(torch.abs(out_f32 - ref_f32)) / (
            torch.mean(torch.abs(ref_f32)) + 1e-6
        )
        self.assertLess(
            rel_err.item(),
            0.06,
            msg="mean relative error between FP8 and reference outputs is too large",
        )

    def test_w8a8_block_fp8_ep_moe(self):
        """Run every combination of the class-level parameter lists as a sub-test."""
        param_names = (
            "M",
            "N",
            "K",
            "E",
            "ep_size",
            "topk",
            "block_size",
            "dtype",
            "seed",
        )
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.ep_size,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(**dict(zip(param_names, params))):
                self._w8a8_block_fp8_ep_moe(*params)
            # Runs after every case, including failed ones, because subTest
            # records the failure instead of propagating it.
            torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Run the tests with verbose per-test output when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user