[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)
Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
@@ -27,12 +27,18 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
|
||||
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
|
||||
w_q = w_q.contiguous()
|
||||
|
||||
alignment = 4 if k % 512 == 0 else 1
|
||||
scale_interleaved = ref_scale.reshape(
|
||||
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
|
||||
ref_scale.shape[0],
|
||||
ref_scale.shape[1],
|
||||
(ref_scale.shape[2] // alignment),
|
||||
alignment,
|
||||
) # [E, N, K/alignment, alignment]
|
||||
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/alignment, N, alignment]
|
||||
scale_interleaved = scale_interleaved.reshape(
|
||||
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
|
||||
ref_scale.shape[0],
|
||||
ref_scale.shape[2] // alignment,
|
||||
ref_scale.shape[1] * alignment,
|
||||
) # [E, K/alignment, N*alignment]
|
||||
w_scale = scale_interleaved.contiguous()
|
||||
|
||||
@@ -137,8 +143,8 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
|
||||
reason="cutlass_w4a8_moe_mm is only supported on sm90",
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
|
||||
@pytest.mark.parametrize("k", [512, 1024])
|
||||
@pytest.mark.parametrize("n", [1024, 2048])
|
||||
@pytest.mark.parametrize("k", [256, 512, 1024])
|
||||
@pytest.mark.parametrize("n", [1024, 2048, 7168])
|
||||
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
|
||||
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user