[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)

Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
chenxj
2025-09-02 13:17:26 +08:00
committed by GitHub
parent 21e1bc475c
commit d4a938417d
11 changed files with 291 additions and 120 deletions

View File

@@ -27,12 +27,18 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
w_q = w_q.contiguous()
alignment = 4 if k % 512 == 0 else 1
scale_interleaved = ref_scale.reshape(
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
ref_scale.shape[0],
ref_scale.shape[1],
(ref_scale.shape[2] // alignment),
alignment,
) # [E, N, K/4, 4]
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4]
scale_interleaved = scale_interleaved.reshape(
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
ref_scale.shape[0],
ref_scale.shape[2] // alignment,
ref_scale.shape[1] * alignment,
) # [E, K/4, N*4]
w_scale = scale_interleaved.contiguous()
@@ -137,8 +143,8 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("n", [1024, 2048])
@pytest.mark.parametrize("k", [256, 512, 1024])
@pytest.mark.parametrize("n", [1024, 2048, 7168])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
torch.manual_seed(0)