feat: support DeepSeek-R1-W4AFP8 model with ep-moe mode (#7762)

Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
This commit is contained in:
SijiaYang
2025-07-08 05:47:21 +08:00
committed by GitHub
parent 6a6e0bb7fd
commit cb9d91ea8a
10 changed files with 1006 additions and 9 deletions

View File

@@ -0,0 +1,281 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.topk import select_experts
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
if int4_values_interleaved.shape[-1] % 2 != 0:
raise ValueError(
"the last dim size of int4_values_interleaved tensor must be even."
)
input_tensor_int8 = int4_values_interleaved.to(torch.int8)
low_nibbles = input_tensor_int8[..., 0::2]
high_nibbles = input_tensor_int8[..., 1::2]
packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
return packed_tensor.to(torch.int8)
def pack_interleave(num_experts, ref_weight, ref_scale):
n, k = ref_weight.shape[1], ref_weight.shape[2]
weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
w_q = w_q.contiguous()
scale_interleaved = ref_scale.reshape(
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
) # [E, N, K/4, 4]
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4]
scale_interleaved = scale_interleaved.reshape(
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
) # [E, K/4, N*4]
w_scale = scale_interleaved.contiguous()
return w_q, w_scale
@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("N", [2048])
@pytest.mark.parametrize("K", [7168])
@pytest.mark.parametrize("E", [256])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
local_e = E // ep_size
debug = False
if debug:
a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001
ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda")
ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda")
a1_scale = torch.ones(1, dtype=torch.float32, device="cuda")
a2_scale = torch.ones(1, dtype=torch.float32, device="cuda")
scale_1 = torch.ones(
(local_e, N * 2, K // group_size), dtype=dtype, device="cuda"
)
scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda")
else:
a = torch.randn(M, K, dtype=dtype, device="cuda")
ref_weight_1 = torch.randint(
-8, 8, (local_e, N * 2, K), dtype=torch.int8, device="cuda"
)
ref_weight_2 = torch.randint(
-8, 8, (local_e, K, N), dtype=torch.int8, device="cuda"
)
affine_coeff = 0.005
a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
scale_1 = (
torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda")
* affine_coeff
)
scale_2 = (
torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda")
* affine_coeff
)
w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
device = "cuda"
a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64)
a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64)
c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
b_strides1 = a_strides1
s_strides13 = c_strides1
b_strides2 = a_strides2
s_strides2 = c_strides2
score = torch.randn((M, E), dtype=dtype, device=device)
topk_weights, topk_ids = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
use_grouped_topk=False,
renormalize=False,
)
expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E
output = cutlass_moe(
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
a_strides1,
b_strides1,
c_strides1,
a_strides2,
b_strides2,
c_strides2,
s_strides13,
s_strides2,
0,
local_e - 1,
E,
a1_scale,
a2_scale,
expert_map,
)
ref_output = ref(
a,
local_e,
topk_weights,
topk_ids,
ref_weight_1,
ref_weight_2,
scale_1,
scale_2,
has_pre_quant=True,
has_alpha=True,
pre_quant_scale_1=a1_scale,
pre_quant_scale_2=a2_scale,
alpha_1=a1_scale,
alpha_2=a2_scale,
)
# compare
torch.cuda.synchronize()
# compare final output
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
print("SUCCESS: Final output tensors are close.")
def cutlass_moe(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides2: torch.Tensor,
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
start_expert_id: int,
end_expert_id: int,
E: int,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
):
local_topk_ids = topk_ids_
local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
device = a.device
local_num_experts = end_expert_id - start_expert_id + 1
expert_offsets = torch.empty(
(local_num_experts + 1), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
)
return cutlass_w4a8_moe(
start_expert_id,
end_expert_id,
E,
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids_,
local_topk_ids,
a_strides1,
b_strides1,
c_strides1,
a_strides2,
b_strides2,
c_strides2,
s_strides13,
s_strides2,
expert_offsets,
problem_sizes1,
problem_sizes2,
a1_scale,
a2_scale,
apply_router_weight_on_input,
)
def ref(
x: torch.Tensor,
num_experts: int,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ref_weight_1: torch.Tensor,
ref_weight_2: torch.Tensor,
ref_weight_scale_1: torch.Tensor,
ref_weight_scale_2: torch.Tensor,
has_pre_quant: bool = False,
has_alpha: bool = False,
pre_quant_scale_1: Optional[torch.Tensor] = None,
pre_quant_scale_2: Optional[torch.Tensor] = None,
alpha_1: Optional[torch.Tensor] = None,
alpha_2: Optional[torch.Tensor] = None,
):
results = torch.zeros_like(x)
dtype = x.dtype
for e_idx in range(num_experts):
mask = topk_ids == e_idx
activated_tokens = mask.sum(1).bool()
act = x[activated_tokens, :]
if act.shape[0] == 0:
continue
final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
act = (
torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
.to(torch.float8_e4m3fn)
.to(dtype)
)
w3_w1 = ref_weight_1[e_idx]
ref_w_scale_repeat = (
ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
)
w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
gate, fc1 = fc1.chunk(2, dim=-1)
fc1 = fc1 * torch.nn.functional.silu(gate)
act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
act = act.to(dtype)
w2 = ref_weight_2[e_idx]
ref_w_scale_repeat = (
ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
)
w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
return results