feat: support DeepSeek-R1-W4AFP8 model with ep-moe mode (#7762)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
This commit is contained in:
281
python/sglang/test/test_cutlass_w4a8_moe.py
Normal file
281
python/sglang/test/test_cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
|
||||
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
    """Pack pairs of int4 values (stored one-per-element) into single int8 bytes.

    Element 2i of the last dimension becomes the low nibble and element 2i+1
    the high nibble of output byte i, halving the last dimension's size.

    Raises:
        ValueError: if the last dimension has odd length (a dangling nibble).
    """
    if int4_values_interleaved.shape[-1] % 2:
        raise ValueError(
            "the last dim size of int4_values_interleaved tensor must be even."
        )

    as_int8 = int4_values_interleaved.to(torch.int8)

    # Even positions supply the low nibble (masked to 4 bits), odd positions
    # the high nibble.
    lo = as_int8[..., ::2] & 0x0F
    hi = as_int8[..., 1::2] << 4

    return (hi | lo).to(torch.int8)
|
||||
|
||||
|
||||
def pack_interleave(num_experts, ref_weight, ref_scale):
    """Convert unpacked int4 expert weights and per-group scales into the
    packed, interleaved layout consumed by the cutlass w4a8 kernel.

    Returns:
        (w_q, w_scale): int8-packed weights of shape [E, N, K/2], and scales
        rearranged from [E, N, G] to [E, G/4, N*4] (groups of 4 interleaved).
    """
    n, k = ref_weight.shape[1], ref_weight.shape[2]

    # Nibble-packing runs on CPU; move the packed bytes back to the GPU.
    packed = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
    w_q = packed.view((num_experts, n, k // 2)).view(torch.int8).contiguous()

    e_s, n_s, g = ref_scale.shape[0], ref_scale.shape[1], ref_scale.shape[2]
    # [E, N, G] -> [E, N, G/4, 4] -> [E, G/4, N, 4] -> [E, G/4, N*4]
    w_scale = (
        ref_scale.reshape(e_s, n_s, g // 4, 4)
        .permute(0, 2, 1, 3)
        .reshape(e_s, g // 4, n_s * 4)
        .contiguous()
    )

    return w_q, w_scale
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("N", [2048])
@pytest.mark.parametrize("K", [7168])
@pytest.mark.parametrize("E", [256])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
    """Compare the cutlass w4a8 MoE kernel against a pure-PyTorch reference
    for a single expert-parallel rank (this rank owns experts 0..E/ep_size-1).
    """
    # Number of experts resident on this (simulated) EP rank.
    local_e = E // ep_size

    # debug=True replaces random inputs with all-ones tensors, which makes
    # kernel/layout mismatches easy to eyeball; keep False for a real check.
    debug = False
    if debug:
        a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001
        ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda")
        ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda")
        a1_scale = torch.ones(1, dtype=torch.float32, device="cuda")
        a2_scale = torch.ones(1, dtype=torch.float32, device="cuda")
        scale_1 = torch.ones(
            (local_e, N * 2, K // group_size), dtype=dtype, device="cuda"
        )
        scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda")
    else:
        a = torch.randn(M, K, dtype=dtype, device="cuda")
        # Weights are unpacked int4 values in [-8, 8); w13 is [E, 2N, K]
        # (gate and up projections stacked), w2 is [E, K, N].
        ref_weight_1 = torch.randint(
            -8, 8, (local_e, N * 2, K), dtype=torch.int8, device="cuda"
        )
        ref_weight_2 = torch.randint(
            -8, 8, (local_e, K, N), dtype=torch.int8, device="cuda"
        )
        # Small scale magnitude keeps dequantized values in a well-behaved range.
        # NOTE(review): randn can produce negative or near-zero activation
        # scales (a1_scale/a2_scale), which may make the fp8 round-trip
        # degenerate on unlucky seeds — consider rand()+eps; confirm intent.
        affine_coeff = 0.005
        a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
        a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
        scale_1 = (
            torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda")
            * affine_coeff
        )
        scale_2 = (
            torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda")
            * affine_coeff
        )

    # Pack int4 weights two-per-byte and interleave scales into kernel layout.
    w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)

    device = "cuda"
    # Per-expert stride descriptors (3 entries per expert) for the two grouped
    # GEMMs. b/s strides alias a/c strides here — presumably the kernel reads
    # only the columns it needs from each row; TODO confirm against the kernel.
    a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
    c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64)
    a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64)
    c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
    b_strides1 = a_strides1
    s_strides13 = c_strides1
    b_strides2 = a_strides2
    s_strides2 = c_strides2

    # Route tokens with random router logits (no grouping, no renormalization).
    score = torch.randn((M, E), dtype=dtype, device=device)
    topk_weights, topk_ids = select_experts(
        hidden_states=a,
        router_logits=score,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=False,
    )
    # Global->local expert map: identity for the local experts, sentinel E for
    # experts owned by other ranks.
    expert_map = torch.arange(E, dtype=torch.int32, device=device)
    expert_map[local_e:] = E

    output = cutlass_moe(
        a,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        a_strides1,
        b_strides1,
        c_strides1,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides13,
        s_strides2,
        0,
        local_e - 1,
        E,
        a1_scale,
        a2_scale,
        expert_map,
    )

    # Reference path dequantizes in plain PyTorch; a1/a2_scale double as both
    # the pre-quant divisor and the post-GEMM alpha, mirroring the kernel.
    ref_output = ref(
        a,
        local_e,
        topk_weights,
        topk_ids,
        ref_weight_1,
        ref_weight_2,
        scale_1,
        scale_2,
        has_pre_quant=True,
        has_alpha=True,
        pre_quant_scale_1=a1_scale,
        pre_quant_scale_2=a2_scale,
        alpha_1=a1_scale,
        alpha_2=a2_scale,
    )

    # compare
    torch.cuda.synchronize()

    # compare final output
    # Loose tolerances: int4 weights + fp8 activation round-trips accumulate
    # quantization error.
    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
    print("SUCCESS: Final output tensors are close.")
|
||||
|
||||
|
||||
def cutlass_moe(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    E: int,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
):
    """Invoke the cutlass w4a8 MoE kernel for the experts owned by this rank.

    Allocates the scratch buffers (expert offsets, per-GEMM problem sizes)
    the kernel fills during preprocessing, translates global expert ids to
    local ones via ``expert_map``, and forwards everything to
    ``cutlass_w4a8_moe``. Returns the kernel's output tensor.
    """
    # Translate global expert ids to rank-local ids; experts not owned by this
    # rank were mapped to the sentinel value E by the caller. (The previous
    # `torch.where(m != E, m, E)` was an identity over the mapped ids, so a
    # plain lookup is equivalent.) Without a map, ids are assumed local.
    if expert_map is not None:
        local_topk_ids = expert_map[topk_ids_]
    else:
        local_topk_ids = topk_ids_
    device = a.device

    # Inclusive expert-id range -> number of experts handled locally.
    local_num_experts = end_expert_id - start_expert_id + 1
    # Scratch buffers populated by the kernel's preprocessing step:
    # prefix offsets into the token-permuted buffer, and the (m, n, k)
    # problem shape of each expert's GEMM1/GEMM2.
    expert_offsets = torch.empty(
        (local_num_experts + 1), dtype=torch.int32, device=device
    )
    problem_sizes1 = torch.empty(
        (local_num_experts, 3), dtype=torch.int32, device=device
    )
    problem_sizes2 = torch.empty(
        (local_num_experts, 3), dtype=torch.int32, device=device
    )
    return cutlass_w4a8_moe(
        start_expert_id,
        end_expert_id,
        E,
        a,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids_,
        local_topk_ids,
        a_strides1,
        b_strides1,
        c_strides1,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides13,
        s_strides2,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a1_scale,
        a2_scale,
        apply_router_weight_on_input,
    )
|
||||
|
||||
|
||||
def ref(
|
||||
x: torch.Tensor,
|
||||
num_experts: int,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ref_weight_1: torch.Tensor,
|
||||
ref_weight_2: torch.Tensor,
|
||||
ref_weight_scale_1: torch.Tensor,
|
||||
ref_weight_scale_2: torch.Tensor,
|
||||
has_pre_quant: bool = False,
|
||||
has_alpha: bool = False,
|
||||
pre_quant_scale_1: Optional[torch.Tensor] = None,
|
||||
pre_quant_scale_2: Optional[torch.Tensor] = None,
|
||||
alpha_1: Optional[torch.Tensor] = None,
|
||||
alpha_2: Optional[torch.Tensor] = None,
|
||||
):
|
||||
results = torch.zeros_like(x)
|
||||
dtype = x.dtype
|
||||
for e_idx in range(num_experts):
|
||||
mask = topk_ids == e_idx
|
||||
activated_tokens = mask.sum(1).bool()
|
||||
act = x[activated_tokens, :]
|
||||
if act.shape[0] == 0:
|
||||
continue
|
||||
final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
|
||||
|
||||
act = (
|
||||
torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
|
||||
.to(torch.float8_e4m3fn)
|
||||
.to(dtype)
|
||||
)
|
||||
w3_w1 = ref_weight_1[e_idx]
|
||||
ref_w_scale_repeat = (
|
||||
ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
|
||||
)
|
||||
w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
|
||||
fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
|
||||
|
||||
gate, fc1 = fc1.chunk(2, dim=-1)
|
||||
fc1 = fc1 * torch.nn.functional.silu(gate)
|
||||
act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
|
||||
act = act.to(dtype)
|
||||
|
||||
w2 = ref_weight_2[e_idx]
|
||||
ref_w_scale_repeat = (
|
||||
ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
|
||||
)
|
||||
w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
|
||||
fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
|
||||
|
||||
results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
|
||||
|
||||
return results
|
||||
Reference in New Issue
Block a user