Add DeepEP to CI PR Test (#5655)

Co-authored-by: Jinyan Chen <jinyanc@nvidia.com>
This commit is contained in:
Jinyan Chen
2025-05-07 08:36:03 +08:00
committed by GitHub
parent aff584fa54
commit 8a828666a3
10 changed files with 1607 additions and 3 deletions

View File

@@ -96,6 +96,9 @@ suites = {
TestFile("test_verl_engine.py", 64),
],
"per-commit-8-gpu": [
TestFile("test_deepep_intranode.py", 50),
TestFile("test_deepep_low_latency.py", 50),
TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
TestFile("test_local_attn.py", 250),
TestFile("test_full_deepseek_v3.py", 250),
TestFile("test_fa3.py", 30),

View File

@@ -0,0 +1,445 @@
# Copy from deepseek-ai/DeepEP/tests/test_internode.py
import os
import time
# noinspection PyUnresolvedReferences
import deep_ep
# Test compatibility with low latency functions
import test_deepep_low_latency
import torch
import torch.distributed as dist
from sglang.test.test_deepep_utils import (
bench,
calc_diff,
create_grouped_scores,
init_dist,
inplace_unique,
per_token_cast_back,
per_token_cast_to_fp8,
)
def test_main(
    num_sms: int,
    local_rank: int,
    num_local_ranks: int,
    num_ranks: int,
    num_nodes: int,
    rank: int,
    buffer: deep_ep.Buffer,
    group: dist.ProcessGroup,
):
    """Exercise DeepEP internode dispatch/combine kernels, then tune chunk sizes.

    Builds a reference dispatch layout on the host, checks the kernel layout
    against it, runs correctness checks for every combination of
    (previous-event capture, async finish, BF16/FP8/pure-random payload,
    with/without top-k), and finally sweeps NVL/RDMA chunk sizes reporting
    the best-performing configs.

    Args:
        num_sms: number of SMs handed to ``deep_ep.Config``.
        local_rank: rank within the node (0 prints the progress log).
        num_local_ranks: GPUs per node (asserted to be 8).
        num_ranks: global world size.
        num_nodes: number of RDMA nodes.
        rank: global rank; also used as the payload value so receivers can
            attribute rows to senders.
        buffer: pre-constructed DeepEP buffer.
        group: the process group used for collectives.
    """
    # Settings
    num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
        4096,
        7168,
        min(num_nodes, 4),
        8,
        (256 // num_ranks) * num_ranks,
    )
    assert num_experts % num_ranks == 0 and num_local_ranks == 8
    if local_rank == 0:
        print(
            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
            flush=True,
        )

    # Random data: `x` is rank-constant so received rows can be attributed to senders.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = (
        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
        + 1
    )
    # Group-limited routing: pick the best `num_topk_groups` nodes, then top-k experts.
    group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
    group_idx = torch.topk(
        group_scores, k=num_topk_groups, dim=-1, sorted=False
    ).indices
    masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
    topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
        1
    ]
    topk_weights = (
        torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
    )
    topk_weights_pure_rand = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="cuda"
    )
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)
    rdma_rank_idx = rank_idx // num_local_ranks
    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
    inplace_unique(rdma_rank_idx, num_nodes)

    # RDMA dispatch counts
    rdma_idx = topk_idx // (num_experts // num_nodes)
    rdma_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rdma_idx, num_nodes)
    num_rdma_token_sent = rdma_idx.ne(-1).sum().item()

    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)

    # Rank layout meta
    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
    token_idx_in_rank = torch.full(
        (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
    )
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(
            count, dtype=torch.long, device="cuda"
        )
    for i in range(num_nodes):
        num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)

    # The kernel-computed layout must match the host-side reference layout.
    (
        ref_num_tokens_per_rank,
        ref_num_tokens_per_rdma_rank,
        ref_num_tokens_per_expert,
        ref_is_token_in_rank,
        _,
    ) = buffer.get_dispatch_layout(topk_idx, num_experts)
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
        print("", flush=True)
    group.barrier()
    time.sleep(1)

    # Config
    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, recv_gbl_rank_prefix_sum):
        # Rows received from rank `i` are contiguous and rank-constant, so
        # each row's min must equal its max, and each prefix segment equals `i`.
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = recv_gbl_rank_prefix_sum[i].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end

    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(
                            f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
                            flush=True,
                            end="",
                        )
                    dispatch_args = {
                        "x": current_x,
                        "num_tokens_per_rank": num_tokens_per_rank,
                        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
                        "is_token_in_rank": is_token_in_rank,
                        "num_tokens_per_expert": num_tokens_per_expert,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        dispatch_args.update(
                            {
                                "topk_idx": topk_idx,
                                "topk_weights": (
                                    topk_weights_pure_rand
                                    if current_x is x_pure_rand
                                    else topk_weights
                                ),
                            }
                        )
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    (
                        recv_x,
                        recv_topk_idx,
                        recv_topk_weights,
                        recv_num_tokens_per_expert_list,
                        handle,
                        event,
                    ) = buffer.dispatch(**dispatch_args)
                    event.current_stream_wait() if async_mode else ()
                    recv_x = (
                        per_token_cast_back(*recv_x)
                        if isinstance(recv_x, tuple)
                        else recv_x
                    )

                    # Checks
                    recv_gbl_rank_prefix_sum = handle[-4]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
                        0
                    ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
                    assert (
                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
                        == recv_num_tokens_per_expert_list
                    )
                    if current_x is not x_pure_rand:
                        check_data(recv_x, recv_gbl_rank_prefix_sum)
                    if with_topk:
                        # Check `topk_idx`
                        assert (
                            recv_topk_idx.eq(-1)
                            | (
                                (recv_topk_idx >= 0)
                                & (recv_topk_idx < (num_experts // num_ranks))
                            )
                        ).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count

                        # Check `topk_weights`
                        if current_x is not x_pure_rand:
                            # Fill masked (-1) slots with the row max so check_data's
                            # min==max invariant still holds.
                            recv_topk_weights[recv_topk_idx.eq(-1)] = (
                                recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
                                    recv_topk_weights
                                )[recv_topk_idx.eq(-1)]
                            )
                            check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)

                    # Test cached dispatch (must be without top-k stuff)
                    if not with_topk:
                        dispatch_args = {
                            "x": current_x,
                            "handle": handle,
                            "config": config,
                            "async_finish": async_mode,
                        }
                        if previous_mode:
                            dispatch_args.update({"previous_event": buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = (
                            per_token_cast_back(*recv_x)
                            if isinstance(recv_x, tuple)
                            else recv_x
                        )
                        if current_x is not x_pure_rand:
                            check_data(recv_x, recv_gbl_rank_prefix_sum)

                    # Test combine
                    combine_args = {
                        "x": recv_x,
                        "handle": handle,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        combine_args.update({"topk_weights": recv_topk_weights})
                    if previous_mode:
                        # BUGFIX: previously updated `dispatch_args` (already consumed
                        # above), so combine never received the captured event and
                        # `previous_mode` was silently untested for combine.
                        combine_args.update({"previous_event": buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(
                        **combine_args
                    )
                    event.current_stream_wait() if async_mode else ()
                    # Combine sums one copy of each token per destination rank, so
                    # normalize by the per-token destination count.
                    check_x = combined_x.float() / is_token_in_rank.sum(
                        dim=1
                    ).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = (
                            combined_topk_weights
                            if (current_x is x_pure_rand)
                            else (
                                combined_topk_weights
                                / is_token_in_rank.sum(dim=1).unsqueeze(1)
                            )
                        )
                        ref_topk_weights = (
                            topk_weights_pure_rand
                            if current_x is x_pure_rand
                            else topk_weights
                        )
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9

                    # For later tuning
                    dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
                    combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes

                    if local_rank == 0:
                        print(" passed", flush=True)
    if local_rank == 0:
        print("", flush=True)

    # Tune dispatch performance
    best_dispatch_results = None
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        rdma_send_bytes = (
            (dispatch_bf16_rdma_send_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_rdma_send_bytes
        )
        nvl_recv_bytes = (
            (dispatch_bf16_nvl_recv_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_nvl_recv_bytes
        )
        for nvl_chunk_size in range(4, 33, 4):
            for rdma_chunk_size in range(4, 33, 4):
                config = deep_ep.Config(
                    num_sms,
                    nvl_chunk_size,
                    nvl_buffer_size,
                    rdma_chunk_size,
                    rdma_buffer_size,
                )
                tune_args = {"x": current_x, "handle": handle, "config": config}
                t = bench(lambda: buffer.dispatch(**tune_args))[0]
                if t < best_time:
                    best_time, best_results = t, (
                        num_sms,
                        nvl_chunk_size,
                        rdma_chunk_size,
                    )
                if local_rank == 0:
                    print(
                        f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                        flush=True,
                    )
        if local_rank == 0:
            print(
                f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
                flush=True,
            )
            print("", flush=True)
        if isinstance(current_x, tuple):
            # Gather FP8 the best config from rank 0
            best_dispatch_results = torch.tensor(
                [best_results[0], best_results[1], best_results[2]],
                dtype=torch.int32,
                device="cuda",
            )
            all_best_fp8_results_list = [
                torch.zeros_like(best_dispatch_results)
                for _ in range(torch.distributed.get_world_size())
            ]
            dist.all_gather(
                all_best_fp8_results_list, best_dispatch_results, group=group
            )
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    dispatch_config = deep_ep.Config(
        best_dispatch_results[0],
        best_dispatch_results[1],
        nvl_buffer_size,
        best_dispatch_results[2],
        rdma_buffer_size,
    )

    dispatch_args = {
        "x": x,
        "num_tokens_per_rank": num_tokens_per_rank,
        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
        "is_token_in_rank": is_token_in_rank,
        "num_tokens_per_expert": num_tokens_per_expert,
        "config": dispatch_config if dispatch_config is not None else config,
    }
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)

    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 5, 1):
        for rdma_chunk_size in range(8, 33, 4):
            config = deep_ep.Config(
                num_sms,
                nvl_chunk_size,
                nvl_buffer_size,
                rdma_chunk_size,
                rdma_buffer_size,
            )
            tune_args = {"x": recv_x, "handle": handle, "config": config}
            t = bench(lambda: buffer.combine(**tune_args))[0]
            if local_rank == 0:
                print(
                    f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                    flush=True,
                )
            if t < best_time:
                best_time, best_results = t, (
                    num_sms,
                    nvl_chunk_size,
                    rdma_chunk_size,
                )
    if local_rank == 0:
        print(
            f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
            flush=True,
        )
        print("", flush=True)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    """Per-process entry: initialize the process group, build the DeepEP
    buffer, and run the internode dispatch/combine test with 24 SMs.

    Spawned once per local GPU by `torch.multiprocessing.spawn`.
    """
    # WORLD_SIZE here is the number of *nodes* (one launcher per node).
    num_nodes = int(os.getenv("WORLD_SIZE", 1))
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    # Flip to True to also exercise the low-latency kernels after the main test.
    test_ll_compatibility = False
    if test_ll_compatibility:
        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
    # NOTE: `ll_num_experts` is only bound when test_ll_compatibility is True;
    # the conditional expression below guards the unbound case (hence the
    # `noinspection PyUnboundLocalVariable` above).
    buffer = deep_ep.Buffer(
        group,
        int(1e9),
        int(1e9),
        low_latency_mode=test_ll_compatibility,
        num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
    )
    # Internode test requires 8 GPUs per node and more than one node.
    assert num_local_ranks == 8 and num_ranks > 8
    torch.manual_seed(rank)
    for i in (24,):
        test_main(
            i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group
        )
        if local_rank == 0:
            print("", flush=True)
    # Test compatibility with low latency functions
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_deepep_low_latency.test_main(
            ll_num_tokens,
            ll_hidden,
            ll_num_experts,
            ll_num_topk,
            rank,
            num_ranks,
            group,
            buffer,
            seed=1,
        )
if __name__ == "__main__":
    # Launch one worker per local GPU; each worker gets (local_rank, world_size).
    world_size = 8
    torch.multiprocessing.spawn(test_loop, args=(world_size,), nprocs=world_size)

View File

@@ -0,0 +1,379 @@
# Copy from deepseek-ai/DeepEP/tests/test_intranode.py
import os
import time
# noinspection PyUnresolvedReferences
import deep_ep
# Test compatibility with low latency functions
import test_deepep_low_latency
import torch
import torch.distributed as dist
from sglang.test.test_deepep_utils import (
bench,
calc_diff,
init_dist,
inplace_unique,
per_token_cast_back,
per_token_cast_to_fp8,
)
def test_main(
    num_sms: int,
    local_rank: int,
    num_ranks: int,
    rank: int,
    buffer: deep_ep.Buffer,
    group: dist.ProcessGroup,
):
    """Exercise DeepEP intranode (NVLink-only) dispatch/combine kernels, then
    tune NVL chunk sizes.

    Builds a reference dispatch layout on the host, checks the kernel layout
    against it, runs correctness checks for every combination of
    (previous-event capture, async finish, BF16/FP8/pure-random payload,
    with/without top-k), and finally sweeps NVL chunk sizes reporting the
    best-performing configs.

    Args:
        num_sms: number of SMs handed to ``deep_ep.Config``.
        local_rank: rank within the node (0 prints the progress log).
        num_ranks: world size (single node).
        rank: global rank; also used as the payload value so receivers can
            attribute rows to senders.
        buffer: pre-constructed DeepEP buffer.
        group: the process group used for collectives.
    """
    # Settings
    num_tokens, hidden, num_topk, num_experts = (
        4096,
        7168,
        8,
        (256 // num_ranks) * num_ranks,
    )
    assert num_experts % num_ranks == 0
    if local_rank == 0:
        print(
            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}",
            flush=True,
        )

    # Random data: `x` is rank-constant so received rows can be attributed to senders.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = (
        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
        + 1
    )
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
    topk_weights = (
        torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
    )
    topk_weights_pure_rand = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="cuda"
    )
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)

    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)

    # Rank layout meta
    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    token_idx_in_rank = torch.full(
        (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
    )
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(
            count, dtype=torch.long, device="cuda"
        )
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)

    # The kernel-computed layout must match the host-side reference layout.
    ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = (
        buffer.get_dispatch_layout(topk_idx, num_experts)
    )
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
        print("", flush=True)
    group.barrier()
    time.sleep(1)

    # Config
    nvl_buffer_size = 256
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, rank_prefix_matrix):
        # Rows received from rank `i` are contiguous and rank-constant, so
        # each row's min must equal its max, and each prefix segment equals `i`.
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = rank_prefix_matrix[i][rank].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end

    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(
                            f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
                            flush=True,
                            end="",
                        )
                    dispatch_args = {
                        "x": current_x,
                        "num_tokens_per_rank": num_tokens_per_rank,
                        "is_token_in_rank": is_token_in_rank,
                        "num_tokens_per_expert": num_tokens_per_expert,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        dispatch_args.update(
                            {
                                "topk_idx": topk_idx,
                                "topk_weights": (
                                    topk_weights_pure_rand
                                    if current_x is x_pure_rand
                                    else topk_weights
                                ),
                            }
                        )
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    (
                        recv_x,
                        recv_topk_idx,
                        recv_topk_weights,
                        recv_num_tokens_per_expert_list,
                        handle,
                        event,
                    ) = buffer.dispatch(**dispatch_args)
                    event.current_stream_wait() if async_mode else ()
                    recv_x = (
                        per_token_cast_back(*recv_x)
                        if isinstance(recv_x, tuple)
                        else recv_x
                    )

                    # Checks
                    rank_prefix_matrix = handle[0]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
                        0
                    ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
                    assert (
                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
                        == recv_num_tokens_per_expert_list
                    )
                    if current_x is not x_pure_rand:
                        check_data(recv_x, rank_prefix_matrix)
                    if with_topk:
                        # Check `topk_idx`
                        assert (
                            recv_topk_idx.eq(-1)
                            | (
                                (recv_topk_idx >= 0)
                                & (recv_topk_idx < (num_experts // num_ranks))
                            )
                        ).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count

                        # Check `topk_weights`
                        if current_x is not x_pure_rand:
                            # Fill masked (-1) slots with the row max so check_data's
                            # min==max invariant still holds.
                            recv_topk_weights[recv_topk_idx.eq(-1)] = (
                                recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
                                    recv_topk_weights
                                )[recv_topk_idx.eq(-1)]
                            )
                            check_data(recv_topk_weights, rank_prefix_matrix)

                    # Test cached dispatch (must be without top-k stuff)
                    if not with_topk:
                        dispatch_args = {
                            "x": current_x,
                            "handle": handle,
                            "config": config,
                            "async_finish": async_mode,
                        }
                        if previous_mode:
                            dispatch_args.update({"previous_event": buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = (
                            per_token_cast_back(*recv_x)
                            if isinstance(recv_x, tuple)
                            else recv_x
                        )
                        if current_x is not x_pure_rand:
                            check_data(recv_x, rank_prefix_matrix)

                    # Test combine
                    combine_args = {
                        "x": recv_x,
                        "handle": handle,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        combine_args.update({"topk_weights": recv_topk_weights})
                    if previous_mode:
                        # BUGFIX: previously updated `dispatch_args` (already consumed
                        # above), so combine never received the captured event and
                        # `previous_mode` was silently untested for combine.
                        combine_args.update({"previous_event": buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(
                        **combine_args
                    )
                    event.current_stream_wait() if async_mode else ()
                    # Combine sums one copy of each token per destination rank, so
                    # normalize by the per-token destination count.
                    check_x = combined_x.float() / is_token_in_rank.sum(
                        dim=1
                    ).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = (
                            combined_topk_weights
                            if (current_x is x_pure_rand)
                            else (
                                combined_topk_weights
                                / is_token_in_rank.sum(dim=1).unsqueeze(1)
                            )
                        )
                        ref_topk_weights = (
                            topk_weights_pure_rand
                            if current_x is x_pure_rand
                            else topk_weights
                        )
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9

                    # For later tuning
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes

                    if local_rank == 0:
                        print(" passed", flush=True)
    if local_rank == 0:
        print("", flush=True)

    # Tune dispatch performance
    best_dispatch_results = None
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        nvl_recv_bytes = (
            (dispatch_bf16_nvl_recv_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_nvl_recv_bytes
        )
        for nvl_chunk_size in range(4, 33, 4):
            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
            tune_args = {"x": current_x, "handle": handle, "config": config}
            t = bench(lambda: buffer.dispatch(**tune_args))[0]
            if t < best_time:
                best_time, best_results = t, (num_sms, nvl_chunk_size)
            if local_rank == 0:
                print(
                    f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                    flush=True,
                )
        if local_rank == 0:
            print(
                f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
                flush=True,
            )
            print("", flush=True)
        if isinstance(current_x, tuple):
            # Gather FP8 the best config from rank 0
            best_dispatch_results = torch.tensor(
                [best_results[0], best_results[1]], dtype=torch.int32, device="cuda"
            )
            all_best_fp8_results_list = [
                torch.zeros_like(best_dispatch_results)
                for _ in range(torch.distributed.get_world_size())
            ]
            dist.all_gather(
                all_best_fp8_results_list, best_dispatch_results, group=group
            )
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    dispatch_config = deep_ep.Config(
        best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size
    )

    dispatch_args = {
        "x": x,
        "num_tokens_per_rank": num_tokens_per_rank,
        "is_token_in_rank": is_token_in_rank,
        "num_tokens_per_expert": num_tokens_per_expert,
        "config": dispatch_config if dispatch_config is not None else config,
    }
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)

    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 7, 1):
        config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
        tune_args = {"x": recv_x, "handle": handle, "config": config}
        t = bench(lambda: buffer.combine(**tune_args))[0]
        if local_rank == 0:
            print(
                f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                flush=True,
            )
        if t < best_time:
            best_time, best_results = t, (num_sms, nvl_chunk_size)
    if local_rank == 0:
        print(
            f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
            flush=True,
        )
        print("", flush=True)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    """Per-process entry: initialize the process group, build the DeepEP
    buffer, and run the intranode dispatch/combine test with 24 SMs.

    Spawned once per local GPU by `torch.multiprocessing.spawn`.
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    # Flip to True to also exercise the low-latency kernels after the main test.
    test_ll_compatibility, num_rdma_bytes = False, 0
    if test_ll_compatibility:
        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            ll_num_tokens, ll_hidden, num_ranks, ll_num_experts
        )
    # NOTE: `ll_num_experts` is only bound when test_ll_compatibility is True;
    # the conditional expression below guards the unbound case (hence the
    # `noinspection PyUnboundLocalVariable` above).
    buffer = deep_ep.Buffer(
        group,
        int(1e9),
        num_rdma_bytes,
        low_latency_mode=test_ll_compatibility,
        num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
    )
    torch.manual_seed(rank)
    for i in (24,):
        test_main(i, local_rank, num_ranks, rank, buffer, group)
        if local_rank == 0:
            print("", flush=True)
    # Test compatibility with low latency functions
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_deepep_low_latency.test_main(
            ll_num_tokens,
            ll_hidden,
            ll_num_experts,
            ll_num_topk,
            rank,
            num_ranks,
            group,
            buffer,
            seed=1,
        )
if __name__ == "__main__":
    # One worker per local GPU; each receives (local_rank, local_world_size).
    local_world = 8
    torch.multiprocessing.spawn(test_loop, args=(local_world,), nprocs=local_world)

View File

@@ -0,0 +1,325 @@
# Copy from deepseek-ai/DeepEP/tests/test_low_latency.py
import random
from functools import partial
import deep_ep
import torch
import torch.distributed as dist
from sglang.test.test_deepep_utils import (
bench,
bench_kineto,
calc_diff,
hash_tensor,
init_dist,
per_token_cast_back,
)
def test_main(
    num_tokens: int,
    hidden: int,
    num_experts: int,
    num_topk: int,
    rank: int,
    num_ranks: int,
    group: dist.ProcessGroup,
    buffer: deep_ep.Buffer,
    seed: int = 0,
):
    """Exercise DeepEP low-latency dispatch/combine kernels.

    Checks dispatch/combine correctness for BF16 and FP8 payloads, with and
    without recv hooks and zero-copy combine, then benchmarks bandwidth.

    Returns:
        An integer XOR hash of received/combined tensors, used by the caller
        as a determinism checksum across repeated runs with the same seed.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: the integers greater than 256 exceeds the BF16 precision limit
    rank_offset = 128
    assert (
        num_ranks - rank_offset < 257
    ), "Too many ranks (exceeding test precision limit)"

    # Payload: rank-constant values so receivers can attribute rows to senders;
    # the last 128 columns carry the token index for source-info checks.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (
        rank - rank_offset
    )
    x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
    scores = (
        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
        + 1
    )
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="cuda"
    ).abs()

    # Randomly mask some positions
    for i in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = (
            -1
        )

    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for return_recv_hook in (False, True):
        for dispatch_use_fp8 in (False, True):
            num_times += 1
            # Repeat dispatch once or twice to also cover back-to-back calls.
            for i in range((num_times % 2) + 1):
                packed_recv_x, packed_recv_count, handle, event, hook = (
                    buffer.low_latency_dispatch(
                        x,
                        topk_idx,
                        num_tokens,
                        num_experts,
                        use_fp8=dispatch_use_fp8,
                        async_finish=not return_recv_hook,
                        return_recv_hook=return_recv_hook,
                    )
                )
                # Either invoke the recv hook or wait on the CUDA event.
                hook() if return_recv_hook else event.current_stream_wait()
            packed_recv_x = (
                (packed_recv_x[0], packed_recv_x[1].contiguous())
                if dispatch_use_fp8
                else packed_recv_x
            )
            # Dequantize FP8 back to BF16 to emulate the GEMM input.
            simulated_gemm_x = (
                per_token_cast_back(
                    packed_recv_x[0].view(-1, hidden),
                    packed_recv_x[1].view(-1, hidden // 128),
                ).view(packed_recv_x[0].shape)
                if dispatch_use_fp8
                else packed_recv_x.clone()
            )
            all_topk_idx = torch.empty(
                (num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda"
            )
            dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
            for i in range(num_local_experts if do_check else 0):
                expert_id = rank * num_local_experts + i
                recv_x = (
                    per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
                    if dispatch_use_fp8
                    else packed_recv_x[i]
                )
                recv_count, recv_src_info, recv_layout_range = (
                    packed_recv_count[i],
                    handle[0][i],
                    handle[1][i],
                )

                # Check expert indices
                # Layout range packs (begin_idx << 32 | count) per source rank.
                int_mask = (2**32) - 1
                num_valid_tokens = recv_count.item()
                assert (
                    num_valid_tokens == (recv_layout_range & int_mask).sum().item()
                ), f"{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()"
                assert (
                    num_valid_tokens == (all_topk_idx == expert_id).sum().item()
                ), f"{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}"

                # Check received data
                recv_x = recv_x[:num_valid_tokens]
                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                recv_src_info = recv_src_info[:num_valid_tokens]
                assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                assert (
                    recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens
                ).sum().item() == 0
                for j in range(num_ranks):
                    begin_idx, count = (recv_layout_range[j] >> 32).item(), (
                        recv_layout_range[j] & int_mask
                    ).item()
                    assert (recv_x_amin == j - rank_offset).sum().item() == (
                        all_topk_idx[j] == expert_id
                    ).sum().item()
                    # NOTE(review): `[:-128]` here slices *rows* of the already
                    # row-sliced segment; upstream DeepEP uses `[:, :-128]` to drop
                    # the index columns instead — confirm intended behavior.
                    assert (
                        recv_x[begin_idx : begin_idx + count][:-128] - j
                    ).sum().item() == 0
                if dispatch_use_fp8:
                    hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                    hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                else:
                    hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

            # Check combine correctness
            for zero_copy in (False, True):
                if zero_copy:
                    # Pre-fill the staging buffer; combine then skips the copy.
                    buffer.get_next_low_latency_combine_buffer(handle)[
                        :, :, :
                    ] = simulated_gemm_x
                out = torch.empty(
                    (num_tokens, hidden), dtype=torch.bfloat16, device="cuda"
                )
                combined_x, event, hook = buffer.low_latency_combine(
                    simulated_gemm_x,
                    topk_idx,
                    topk_weights,
                    handle,
                    async_finish=not return_recv_hook,
                    zero_copy=zero_copy,
                    return_recv_hook=return_recv_hook,
                    out=out,
                )
                hook() if return_recv_hook else event.current_stream_wait()
                if do_check:
                    # Reference: per-token weighted sum with masked slots zeroed.
                    diff = calc_diff(
                        x
                        * topk_weights.masked_fill(topk_idx == -1, 0)
                        .sum(dim=1)
                        .view(-1, 1),
                        combined_x,
                    )
                    assert torch.isnan(combined_x).sum().item() == 0
                    assert diff < 1e-5, f"Error: {diff=}, {zero_copy=}"
                    hash_value ^= hash_tensor(combined_x)

    # Helper for FP8-cast stress inputs (defined here; not invoked in this test).
    def create_test_cast_with_outliers(num_outliers):
        tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
        tmp /= tmp.abs().amax(dim=1).view(-1, 1)
        assert tmp.abs().amax().item() <= 1
        # Create some amax outliers
        for i in range(num_outliers):
            tmp[random.randint(0, num_tokens - 1)] *= 1e3
        return tmp

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        # Overlap a large CPU-float GEMM with the communication, then fire the hook.
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(zero_copy: bool, return_recv_hook: bool):
        # One full dispatch + combine round, used for benchmarking below.
        recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
            x,
            topk_idx,
            num_tokens,
            num_experts,
            async_finish=False,
            return_recv_hook=return_recv_hook,
        )
        large_gemm_with_hook(hook) if return_recv_hook else None
        if zero_copy:
            buffer.get_next_low_latency_combine_buffer(handle)[
                :, :, :
            ] = simulated_gemm_x
        combined_x, event, hook = buffer.low_latency_combine(
            simulated_gemm_x,
            topk_idx,
            topk_weights,
            handle,
            zero_copy=zero_copy,
            return_recv_hook=return_recv_hook,
        )
        large_gemm_with_hook(hook) if return_recv_hook else None

    # Calculate bandwidth
    # FP8 token: payload + per-128 scales (4 B each) + 16 B metadata.
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(
        partial(test_func, zero_copy=False, return_recv_hook=False)
    )
    print(
        f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, "
        f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us",
        flush=True,
    )

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(
            partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
            kernel_names=("dispatch", "combine"),
            barrier_comm_profiling=True,
            suppress_kineto_output=True,
        )
        if not return_recv_hook:
            print(
                f"[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | "
                f"Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us",
                flush=True,
            )
        else:
            print(
                f"[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | "
                f"Combine send/recv time: {combine_t * 2 * 1e6:.2f} us",
                flush=True,
            )
    return hash_value
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    """Per-process entry: build a low-latency DeepEP buffer and run the test.

    Spawned once per local GPU by `torch.multiprocessing.spawn`. Optionally
    pressure-tests determinism by re-running with fixed seeds and comparing
    the returned hash.
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        num_tokens, hidden, num_ranks, num_experts
    )
    if local_rank == 0:
        print(f"Allocating buffer size: {num_rdma_bytes / 1e6} MB ...", flush=True)
    buffer = deep_ep.Buffer(
        group,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=num_experts // num_ranks,
    )
    test_main(
        num_tokens,
        hidden,
        num_experts,
        num_topk,
        rank,
        num_ranks,
        group,
        buffer,
        seed=1,
    )

    # Determinism pressure test (disabled by default: range(0) is empty).
    do_pressure_test = False
    for seed in range(int(1e9) if do_pressure_test else 0):
        if local_rank == 0:
            print(f"Testing with seed {seed} ...", flush=True)
        ref_hash = test_main(
            num_tokens,
            hidden,
            num_experts,
            num_topk,
            rank,
            num_ranks,
            group,
            buffer,
            seed=seed,
        )
        # Re-running with the same seed must reproduce the same hash.
        for i in range(20):
            assert (
                test_main(
                    num_tokens,
                    hidden,
                    num_experts,
                    num_topk,
                    rank,
                    num_ranks,
                    group,
                    buffer,
                    seed=seed,
                )
                == ref_hash
            ), f"Error: seed={seed}"
if __name__ == "__main__":
    # TODO: you may modify NUMA binding for less CPU overhead
    n_procs = 8
    torch.multiprocessing.spawn(test_loop, args=(n_procs,), nprocs=n_procs)

View File

@@ -0,0 +1,74 @@
"""
Usage:
python -m unittest test_moe_deepep_eval_accuracy_large.TestMoEDeepEPEvalAccuracyLarge.test_mmlu
"""
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestMoEDeepEPEvalAccuracyLarge(CustomTestCase):
    """Accuracy evals (GSM8K, MMLU) against a DeepEP-MoE server on 8 GPUs."""

    @classmethod
    def setUpClass(cls):
        # Launch the server once for the whole test class.
        cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_flags = [
            "--trust-remote-code",
            "--tp",
            "8",
            "--enable-deepep-moe",
            "--cuda-graph-max-bs",
            "128",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        # Terminate the launched server and all of its children.
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K accuracy must exceed 0.93."""
        server_port = int(self.base_url.split(":")[-1])
        args = SimpleNamespace(
            num_shots=8,
            data_path=None,
            num_questions=200,
            parallel=64,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.93)

    def test_mmlu(self):
        """MMLU score must exceed 0.87."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        print(f"Eval accuracy of MMLU: {metrics=}")
        self.assertGreater(metrics["score"], 0.87)
if __name__ == "__main__":
    # Run all tests in this module when executed directly.
    unittest.main()