Add DeepEP to CI PR Test (#5655)
Co-authored-by: Jinyan Chen <jinyanc@nvidia.com>
This commit is contained in:
6
.github/workflows/pr-test.yml
vendored
6
.github/workflows/pr-test.yml
vendored
@@ -97,7 +97,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
bash scripts/ci_install_dependency.sh
|
||||
bash scripts/ci_install_dependency_8_gpu.sh
|
||||
|
||||
- name: Run test
|
||||
timeout-minutes: 40
|
||||
@@ -259,9 +259,9 @@ jobs:
|
||||
finish:
|
||||
if: always()
|
||||
needs: [
|
||||
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
|
||||
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unit-test-backend-8-gpu,
|
||||
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
|
||||
accuracy-test-1-gpu, accuracy-test-2-gpu
|
||||
accuracy-test-1-gpu, accuracy-test-2-gpu,
|
||||
]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
36
.github/workflows/release-docker-deepep.yml
vendored
Normal file
36
.github/workflows/release-docker-deepep.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Build DeepEP Docker Image
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
|
||||
jobs:
|
||||
build-dev:
|
||||
if: ${{ github.repository == 'sgl-project/sglang' }}
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free disk space
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
tool-cache: false
|
||||
docker-images: false
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
swap-storage: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and Push DeepEP Image
|
||||
run: |
|
||||
docker build . -f docker/Dockerfile.deepep -t lmsysorg/sglang:deepep --no-cache
|
||||
docker push lmsysorg/sglang:deepep
|
||||
219
python/sglang/test/test_deepep_utils.py
Normal file
219
python/sglang/test/test_deepep_utils.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copy from deepseek-ai/DeepEP/tests/test_utils.py
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_dist(local_rank: int, num_local_ranks: int):
|
||||
# NOTES: you may rewrite this function with your own cluster settings
|
||||
ip = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||
port = int(os.getenv("MASTER_PORT", "8361"))
|
||||
num_nodes = int(os.getenv("WORLD_SIZE", 1))
|
||||
node_rank = int(os.getenv("RANK", 0))
|
||||
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=f"tcp://{ip}:{port}",
|
||||
world_size=num_nodes * num_local_ranks,
|
||||
rank=node_rank * num_local_ranks + local_rank,
|
||||
)
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
return (
|
||||
dist.get_rank(),
|
||||
dist.get_world_size(),
|
||||
dist.new_group(list(range(num_local_ranks * num_nodes))),
|
||||
)
|
||||
|
||||
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double() + 1, y.double() + 1
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return (1 - sim).item()
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor):
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
||||
m, n
|
||||
), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
|
||||
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
|
||||
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
|
||||
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
|
||||
|
||||
|
||||
def inplace_unique(x: torch.Tensor, num_slots: int):
|
||||
assert x.dim() == 2
|
||||
mask = x < 0
|
||||
x_padded = x.masked_fill(mask, num_slots)
|
||||
bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
|
||||
bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
|
||||
bin_count = bin_count[:, :num_slots]
|
||||
sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
|
||||
sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
|
||||
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
|
||||
x[:, :].fill_(-1)
|
||||
valid_len = min(num_slots, x.size(1))
|
||||
x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
|
||||
|
||||
|
||||
def create_grouped_scores(
|
||||
scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
|
||||
):
|
||||
num_tokens, num_experts = scores.shape
|
||||
scores = scores.view(num_tokens, num_groups, -1)
|
||||
mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
|
||||
mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
|
||||
return (scores * mask).view(num_tokens, num_experts)
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
|
||||
# Flush L2 cache with 256 MB data
|
||||
torch.cuda.synchronize()
|
||||
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
|
||||
|
||||
# Warmup
|
||||
for _ in range(num_warmups):
|
||||
fn()
|
||||
|
||||
# Flush L2
|
||||
cache.zero_()
|
||||
|
||||
# Testing
|
||||
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
|
||||
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
|
||||
for i in range(num_tests):
|
||||
# Record
|
||||
start_events[i].record()
|
||||
fn()
|
||||
end_events[i].record()
|
||||
if post_fn is not None:
|
||||
post_fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = np.array(
|
||||
[s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
|
||||
)[1:]
|
||||
return np.average(times), np.min(times), np.max(times)
|
||||
|
||||
|
||||
class empty_suppress:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
pass
|
||||
|
||||
|
||||
class suppress_stdout_stderr:
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, "w")
|
||||
self.errnull_file = open(os.devnull, "w")
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
def bench_kineto(
|
||||
fn,
|
||||
kernel_names,
|
||||
num_tests: int = 30,
|
||||
suppress_kineto_output: bool = False,
|
||||
trace_path: Optional[str] = None,
|
||||
barrier_comm_profiling: bool = False,
|
||||
):
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
|
||||
with suppress():
|
||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
|
||||
) as prof:
|
||||
for i in range(2):
|
||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
||||
if barrier_comm_profiling:
|
||||
lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
|
||||
rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
|
||||
lhs @ rhs
|
||||
dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
|
||||
for _ in range(num_tests):
|
||||
fn()
|
||||
prof.step()
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tupled = isinstance(kernel_names, tuple)
|
||||
prof_lines = (
|
||||
prof.key_averages()
|
||||
.table(sort_by="cuda_time_total", max_name_column_width=100)
|
||||
.split("\n")
|
||||
)
|
||||
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
for name in kernel_names:
|
||||
assert (
|
||||
sum([name in line for line in prof_lines]) == 1
|
||||
), f"Errors of the kernel {name} in the profiling table"
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
prof.export_chrome_trace(trace_path)
|
||||
|
||||
# Return average kernel times
|
||||
units = {"ms": 1e3, "us": 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, "")) / scale)
|
||||
break
|
||||
break
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
def hash_tensor(t: torch.Tensor):
|
||||
return t.view(torch.int64).sum().item()
|
||||
@@ -66,6 +66,7 @@ DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
|
||||
)
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
|
||||
DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
|
||||
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-V3-0324"
|
||||
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
|
||||
"hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
|
||||
)
|
||||
|
||||
122
scripts/ci_install_dependency_8_gpu.sh
Executable file
122
scripts/ci_install_dependency_8_gpu.sh
Executable file
@@ -0,0 +1,122 @@
|
||||
#!/bin/bash
|
||||
# Install the dependency in CI.
|
||||
set -euxo pipefail
|
||||
|
||||
export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
|
||||
export NVSHMEM_DIR=/opt/nvshmem/install
|
||||
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
|
||||
export PATH="${NVSHMEM_DIR}/bin:$PATH"
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
bash "${SCRIPT_DIR}/killall_sglang.sh"
|
||||
|
||||
# Clean up existing installations
|
||||
pip uninstall -y flashinfer flashinfer_python sgl-kernel sglang vllm deepep || true
|
||||
pip cache purge
|
||||
rm -rf /root/.cache/flashinfer
|
||||
if [ -d "lmms-eval" ]; then
|
||||
rm -rf lmms-eval
|
||||
fi
|
||||
rm -rf /root/.cache/deepep
|
||||
rm -rf /usr/local/lib/python3.10/dist-packages/flashinfer*
|
||||
rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
|
||||
rm -rf /usr/local/lib/python3.10/dist-packages/deepep*
|
||||
dpkg -r gdrcopy gdrcopy-tests libgdrapi gdrdrv-dkms || true
|
||||
rm -rf /opt/gdrcopy
|
||||
rm -rf /usr/local/lib/libgdrapi*
|
||||
rm -rf /usr/local/include/gdrapi.h
|
||||
rm -rf /opt/nvshmem
|
||||
rm -rf /usr/local/lib/libnvshmem*
|
||||
rm -rf /usr/local/include/nvshmem*
|
||||
|
||||
# Update pip
|
||||
pip install --upgrade pip
|
||||
|
||||
# Install sgl-kernel
|
||||
pip install sgl-kernel==0.1.1 --no-cache-dir
|
||||
|
||||
# Install the main package
|
||||
pip install -e "python[all]"
|
||||
|
||||
# Install additional dependencies
|
||||
pip install torch_memory_saver
|
||||
pip install transformers==4.51.0 sentence_transformers accelerate peft pandas datasets timm torchaudio==2.6.0
|
||||
|
||||
# For compling xgrammar kernels
|
||||
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
||||
|
||||
# For lmms_evals evaluating MMMU
|
||||
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||
pip install -e lmms-eval/
|
||||
|
||||
# Install FlashMLA for attention backend tests
|
||||
pip install git+https://github.com/deepseek-ai/FlashMLA.git
|
||||
|
||||
# Install system dependencies
|
||||
# apt-get update && apt-get install -y libibverbs-dev infiniband-diags libmlx5-1 rdma-core openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 rdma-core-dev infiniband-diags-dev libibverbs-dev libibverbs-utils librdmacm-dev librdmacm-utils ibverbs-utils rdma-core-utils
|
||||
apt install curl wget git sudo libibverbs-dev -y
|
||||
apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py
|
||||
|
||||
wget https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4-linux-x86_64.sh
|
||||
chmod +x cmake-3.27.4-linux-x86_64.sh
|
||||
./cmake-3.27.4-linux-x86_64.sh --skip-license --prefix=/usr/local
|
||||
rm cmake-3.27.4-linux-x86_64.sh
|
||||
|
||||
# Install GDRCopy
|
||||
mkdir -p /opt/gdrcopy
|
||||
mkdir -p /opt/nvshmem
|
||||
cd /opt/gdrcopy
|
||||
git clone https://github.com/NVIDIA/gdrcopy.git .
|
||||
git checkout v2.4.4
|
||||
apt update
|
||||
apt install -y nvidia-dkms-535
|
||||
apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms
|
||||
apt install -y check libsubunit0 libsubunit-dev
|
||||
cd packages
|
||||
CUDA=/usr/local/cuda ./build-deb-packages.sh
|
||||
dpkg -i gdrdrv-dkms_*.deb
|
||||
dpkg -i libgdrapi_*.deb
|
||||
dpkg -i gdrcopy-tests_*.deb
|
||||
dpkg -i gdrcopy_*.deb
|
||||
|
||||
if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then
|
||||
ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
|
||||
fi
|
||||
apt-get update && apt-get install -y libfabric-dev
|
||||
|
||||
# Clone DeepEP
|
||||
git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep
|
||||
|
||||
# Install NVSHMEM
|
||||
cd /opt/nvshmem
|
||||
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
|
||||
tar -xf nvshmem_src_3.2.5-1.txz
|
||||
mv nvshmem_src nvshmem
|
||||
cd nvshmem
|
||||
git apply /root/.cache/deepep/third-party/nvshmem.patch
|
||||
NVSHMEM_SHMEM_SUPPORT=0 \
|
||||
NVSHMEM_UCX_SUPPORT=0 \
|
||||
NVSHMEM_USE_NCCL=0 \
|
||||
NVSHMEM_MPI_SUPPORT=0 \
|
||||
NVSHMEM_IBGDA_SUPPORT=1 \
|
||||
NVSHMEM_PMIX_SUPPORT=0 \
|
||||
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
|
||||
NVSHMEM_USE_GDRCOPY=1 \
|
||||
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90
|
||||
cd build
|
||||
make -j$(nproc) install
|
||||
|
||||
# Install DeepEP
|
||||
cd /root/.cache/deepep && python3 setup.py install
|
||||
|
||||
# Verify configuration
|
||||
echo "=== NCCL Configuration ==="
|
||||
nvidia-smi topo -m
|
||||
nvidia-smi nvlink -s
|
||||
echo "=== Verify GDRCOPY ==="
|
||||
gdrcopy_copybw
|
||||
echo "=== Verify NVSHMEM ==="
|
||||
nvshmem-info -a
|
||||
# /opt/nvshmem/bin/perftest/device/pt-to-pt/shmem_put_bw
|
||||
@@ -96,6 +96,9 @@ suites = {
|
||||
TestFile("test_verl_engine.py", 64),
|
||||
],
|
||||
"per-commit-8-gpu": [
|
||||
TestFile("test_deepep_intranode.py", 50),
|
||||
TestFile("test_deepep_low_latency.py", 50),
|
||||
TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
|
||||
TestFile("test_local_attn.py", 250),
|
||||
TestFile("test_full_deepseek_v3.py", 250),
|
||||
TestFile("test_fa3.py", 30),
|
||||
|
||||
445
test/srt/test_deepep_internode.py
Normal file
445
test/srt/test_deepep_internode.py
Normal file
@@ -0,0 +1,445 @@
|
||||
# Copy from deepseek-ai/DeepEP/tests/test_internode.py
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
import test_deepep_low_latency
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.test.test_deepep_utils import (
|
||||
bench,
|
||||
calc_diff,
|
||||
create_grouped_scores,
|
||||
init_dist,
|
||||
inplace_unique,
|
||||
per_token_cast_back,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_sms: int,
|
||||
local_rank: int,
|
||||
num_local_ranks: int,
|
||||
num_ranks: int,
|
||||
num_nodes: int,
|
||||
rank: int,
|
||||
buffer: deep_ep.Buffer,
|
||||
group: dist.ProcessGroup,
|
||||
):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
|
||||
4096,
|
||||
7168,
|
||||
min(num_nodes, 4),
|
||||
8,
|
||||
(256 // num_ranks) * num_ranks,
|
||||
)
|
||||
assert num_experts % num_ranks == 0 and num_local_ranks == 8
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
|
||||
group_idx = torch.topk(
|
||||
group_scores, k=num_topk_groups, dim=-1, sorted=False
|
||||
).indices
|
||||
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
|
||||
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
|
||||
1
|
||||
]
|
||||
topk_weights = (
|
||||
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
|
||||
)
|
||||
topk_weights_pure_rand = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
rdma_rank_idx = rank_idx // num_local_ranks
|
||||
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
|
||||
inplace_unique(rdma_rank_idx, num_nodes)
|
||||
|
||||
# RDMA dispatch counts
|
||||
rdma_idx = topk_idx // (num_experts // num_nodes)
|
||||
rdma_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rdma_idx, num_nodes)
|
||||
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
|
||||
num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
|
||||
token_idx_in_rank = torch.full(
|
||||
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(
|
||||
count, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_nodes):
|
||||
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
(
|
||||
ref_num_tokens_per_rank,
|
||||
ref_num_tokens_per_rdma_rank,
|
||||
ref_num_tokens_per_expert,
|
||||
ref_is_token_in_rank,
|
||||
_,
|
||||
) = buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
|
||||
print("", flush=True)
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, recv_gbl_rank_prefix_sum):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = recv_gbl_rank_prefix_sum[i].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
|
||||
flush=True,
|
||||
end="",
|
||||
)
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
dispatch_args.update(
|
||||
{
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
),
|
||||
}
|
||||
)
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
|
||||
# Checks
|
||||
recv_gbl_rank_prefix_sum = handle[-4]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
|
||||
0
|
||||
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
|
||||
assert (
|
||||
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
|
||||
== recv_num_tokens_per_expert_list
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (
|
||||
recv_topk_idx.eq(-1)
|
||||
| (
|
||||
(recv_topk_idx >= 0)
|
||||
& (recv_topk_idx < (num_experts // num_ranks))
|
||||
)
|
||||
).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = (
|
||||
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
|
||||
recv_topk_weights
|
||||
)[recv_topk_idx.eq(-1)]
|
||||
)
|
||||
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test cached dispatch (must without top-k staffs)
|
||||
if not with_topk:
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test combine
|
||||
combine_args = {
|
||||
"x": recv_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
combine_args.update({"topk_weights": recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(
|
||||
**combine_args
|
||||
)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(
|
||||
dim=1
|
||||
).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = (
|
||||
combined_topk_weights
|
||||
if (current_x is x_pure_rand)
|
||||
else (
|
||||
combined_topk_weights
|
||||
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
)
|
||||
)
|
||||
ref_topk_weights = (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
)
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(" passed", flush=True)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
rdma_send_bytes = (
|
||||
(dispatch_bf16_rdma_send_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_rdma_send_bytes
|
||||
)
|
||||
nvl_recv_bytes = (
|
||||
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_nvl_recv_bytes
|
||||
)
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
for rdma_chunk_size in range(4, 33, 4):
|
||||
config = deep_ep.Config(
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
nvl_buffer_size,
|
||||
rdma_chunk_size,
|
||||
rdma_buffer_size,
|
||||
)
|
||||
tune_args = {"x": current_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather FP8 the best config from rank 0
|
||||
best_dispatch_results = torch.tensor(
|
||||
[best_results[0], best_results[1], best_results[2]],
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
all_best_fp8_results_list = [
|
||||
torch.zeros_like(best_dispatch_results)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
dist.all_gather(
|
||||
all_best_fp8_results_list, best_dispatch_results, group=group
|
||||
)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(
|
||||
best_dispatch_results[0],
|
||||
best_dispatch_results[1],
|
||||
nvl_buffer_size,
|
||||
best_dispatch_results[2],
|
||||
rdma_buffer_size,
|
||||
)
|
||||
|
||||
dispatch_args = {
|
||||
"x": x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": dispatch_config if dispatch_config is not None else config,
|
||||
}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 5, 1):
|
||||
for rdma_chunk_size in range(8, 33, 4):
|
||||
config = deep_ep.Config(
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
nvl_buffer_size,
|
||||
rdma_chunk_size,
|
||||
rdma_buffer_size,
|
||||
)
|
||||
tune_args = {"x": recv_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
)
|
||||
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
num_nodes = int(os.getenv("WORLD_SIZE", 1))
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility = False
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
int(1e9),
|
||||
int(1e9),
|
||||
low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
|
||||
)
|
||||
assert num_local_ranks == 8 and num_ranks > 8
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24,):
|
||||
test_main(
|
||||
i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group
|
||||
)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
if test_ll_compatibility:
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_deepep_low_latency.test_main(
|
||||
ll_num_tokens,
|
||||
ll_hidden,
|
||||
ll_num_experts,
|
||||
ll_num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
|
||||
379
test/srt/test_deepep_intranode.py
Normal file
379
test/srt/test_deepep_intranode.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# Copy from deepseek-ai/DeepEP/tests/test_intranode.py
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
import test_deepep_low_latency
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.test.test_deepep_utils import (
|
||||
bench,
|
||||
calc_diff,
|
||||
init_dist,
|
||||
inplace_unique,
|
||||
per_token_cast_back,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_sms: int,
|
||||
local_rank: int,
|
||||
num_ranks: int,
|
||||
rank: int,
|
||||
buffer: deep_ep.Buffer,
|
||||
group: dist.ProcessGroup,
|
||||
):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk, num_experts = (
|
||||
4096,
|
||||
7168,
|
||||
8,
|
||||
(256 // num_ranks) * num_ranks,
|
||||
)
|
||||
assert num_experts % num_ranks == 0
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
|
||||
topk_weights = (
|
||||
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
|
||||
)
|
||||
topk_weights_pure_rand = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
|
||||
token_idx_in_rank = torch.full(
|
||||
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(
|
||||
count, dtype=torch.long, device="cuda"
|
||||
)
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = (
|
||||
buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
|
||||
print("", flush=True)
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
nvl_buffer_size = 256
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, rank_prefix_matrix):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = rank_prefix_matrix[i][rank].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
|
||||
flush=True,
|
||||
end="",
|
||||
)
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
dispatch_args.update(
|
||||
{
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
),
|
||||
}
|
||||
)
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
|
||||
# Checks
|
||||
rank_prefix_matrix = handle[0]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
|
||||
0
|
||||
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
|
||||
assert (
|
||||
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
|
||||
== recv_num_tokens_per_expert_list
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, rank_prefix_matrix)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (
|
||||
recv_topk_idx.eq(-1)
|
||||
| (
|
||||
(recv_topk_idx >= 0)
|
||||
& (recv_topk_idx < (num_experts // num_ranks))
|
||||
)
|
||||
).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = (
|
||||
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
|
||||
recv_topk_weights
|
||||
)[recv_topk_idx.eq(-1)]
|
||||
)
|
||||
check_data(recv_topk_weights, rank_prefix_matrix)
|
||||
|
||||
# Test cached dispatch (must without top-k staffs)
|
||||
if not with_topk:
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, rank_prefix_matrix)
|
||||
|
||||
# Test combine
|
||||
combine_args = {
|
||||
"x": recv_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
combine_args.update({"topk_weights": recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(
|
||||
**combine_args
|
||||
)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(
|
||||
dim=1
|
||||
).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = (
|
||||
combined_topk_weights
|
||||
if (current_x is x_pure_rand)
|
||||
else (
|
||||
combined_topk_weights
|
||||
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
)
|
||||
)
|
||||
ref_topk_weights = (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
)
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(" passed", flush=True)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
nvl_recv_bytes = (
|
||||
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_nvl_recv_bytes
|
||||
)
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
tune_args = {"x": current_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather FP8 the best config from rank 0
|
||||
best_dispatch_results = torch.tensor(
|
||||
[best_results[0], best_results[1]], dtype=torch.int32, device="cuda"
|
||||
)
|
||||
all_best_fp8_results_list = [
|
||||
torch.zeros_like(best_dispatch_results)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
dist.all_gather(
|
||||
all_best_fp8_results_list, best_dispatch_results, group=group
|
||||
)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(
|
||||
best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size
|
||||
)
|
||||
|
||||
dispatch_args = {
|
||||
"x": x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": dispatch_config if dispatch_config is not None else config,
|
||||
}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 7, 1):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
tune_args = {"x": recv_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility, num_rdma_bytes = False, 0
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
ll_num_tokens, ll_hidden, num_ranks, ll_num_experts
|
||||
)
|
||||
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
int(1e9),
|
||||
num_rdma_bytes,
|
||||
low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
|
||||
)
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24,):
|
||||
test_main(i, local_rank, num_ranks, rank, buffer, group)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
if test_ll_compatibility:
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_deepep_low_latency.test_main(
|
||||
ll_num_tokens,
|
||||
ll_hidden,
|
||||
ll_num_experts,
|
||||
ll_num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
|
||||
325
test/srt/test_deepep_low_latency.py
Normal file
325
test/srt/test_deepep_low_latency.py
Normal file
@@ -0,0 +1,325 @@
|
||||
# Copy from deepseek-ai/DeepEP/tests/test_low_latency.py
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.test.test_deepep_utils import (
|
||||
bench,
|
||||
bench_kineto,
|
||||
calc_diff,
|
||||
hash_tensor,
|
||||
init_dist,
|
||||
per_token_cast_back,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_tokens: int,
|
||||
hidden: int,
|
||||
num_experts: int,
|
||||
num_topk: int,
|
||||
rank: int,
|
||||
num_ranks: int,
|
||||
group: dist.ProcessGroup,
|
||||
buffer: deep_ep.Buffer,
|
||||
seed: int = 0,
|
||||
):
|
||||
torch.manual_seed(seed + rank)
|
||||
random.seed(seed + rank)
|
||||
|
||||
assert num_experts % num_ranks == 0
|
||||
num_local_experts = num_experts // num_ranks
|
||||
|
||||
# NOTES: the integers greater than 256 exceeds the BF16 precision limit
|
||||
rank_offset = 128
|
||||
assert (
|
||||
num_ranks - rank_offset < 257
|
||||
), "Too many ranks (exceeding test precision limit)"
|
||||
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (
|
||||
rank - rank_offset
|
||||
)
|
||||
x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
|
||||
topk_weights = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
).abs()
|
||||
|
||||
# Randomly mask some positions
|
||||
for i in range(10):
|
||||
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = (
|
||||
-1
|
||||
)
|
||||
|
||||
# Check dispatch correctness
|
||||
do_check = True
|
||||
hash_value, num_times = 0, 0
|
||||
for return_recv_hook in (False, True):
|
||||
for dispatch_use_fp8 in (False, True):
|
||||
num_times += 1
|
||||
for i in range((num_times % 2) + 1):
|
||||
packed_recv_x, packed_recv_count, handle, event, hook = (
|
||||
buffer.low_latency_dispatch(
|
||||
x,
|
||||
topk_idx,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
use_fp8=dispatch_use_fp8,
|
||||
async_finish=not return_recv_hook,
|
||||
return_recv_hook=return_recv_hook,
|
||||
)
|
||||
)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
packed_recv_x = (
|
||||
(packed_recv_x[0], packed_recv_x[1].contiguous())
|
||||
if dispatch_use_fp8
|
||||
else packed_recv_x
|
||||
)
|
||||
simulated_gemm_x = (
|
||||
per_token_cast_back(
|
||||
packed_recv_x[0].view(-1, hidden),
|
||||
packed_recv_x[1].view(-1, hidden // 128),
|
||||
).view(packed_recv_x[0].shape)
|
||||
if dispatch_use_fp8
|
||||
else packed_recv_x.clone()
|
||||
)
|
||||
all_topk_idx = torch.empty(
|
||||
(num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda"
|
||||
)
|
||||
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
|
||||
for i in range(num_local_experts if do_check else 0):
|
||||
expert_id = rank * num_local_experts + i
|
||||
recv_x = (
|
||||
per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
|
||||
if dispatch_use_fp8
|
||||
else packed_recv_x[i]
|
||||
)
|
||||
recv_count, recv_src_info, recv_layout_range = (
|
||||
packed_recv_count[i],
|
||||
handle[0][i],
|
||||
handle[1][i],
|
||||
)
|
||||
|
||||
# Check expert indices
|
||||
int_mask = (2**32) - 1
|
||||
num_valid_tokens = recv_count.item()
|
||||
assert (
|
||||
num_valid_tokens == (recv_layout_range & int_mask).sum().item()
|
||||
), f"{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()"
|
||||
assert (
|
||||
num_valid_tokens == (all_topk_idx == expert_id).sum().item()
|
||||
), f"{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}"
|
||||
|
||||
# Check received data
|
||||
recv_x = recv_x[:num_valid_tokens]
|
||||
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
|
||||
recv_src_info = recv_src_info[:num_valid_tokens]
|
||||
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
|
||||
assert (
|
||||
recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens
|
||||
).sum().item() == 0
|
||||
for j in range(num_ranks):
|
||||
begin_idx, count = (recv_layout_range[j] >> 32).item(), (
|
||||
recv_layout_range[j] & int_mask
|
||||
).item()
|
||||
assert (recv_x_amin == j - rank_offset).sum().item() == (
|
||||
all_topk_idx[j] == expert_id
|
||||
).sum().item()
|
||||
assert (
|
||||
recv_x[begin_idx : begin_idx + count][:-128] - j
|
||||
).sum().item() == 0
|
||||
if dispatch_use_fp8:
|
||||
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
|
||||
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
|
||||
else:
|
||||
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
|
||||
|
||||
# Check combine correctness
|
||||
for zero_copy in (False, True):
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[
|
||||
:, :, :
|
||||
] = simulated_gemm_x
|
||||
out = torch.empty(
|
||||
(num_tokens, hidden), dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
combined_x, event, hook = buffer.low_latency_combine(
|
||||
simulated_gemm_x,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
handle,
|
||||
async_finish=not return_recv_hook,
|
||||
zero_copy=zero_copy,
|
||||
return_recv_hook=return_recv_hook,
|
||||
out=out,
|
||||
)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(
|
||||
x
|
||||
* topk_weights.masked_fill(topk_idx == -1, 0)
|
||||
.sum(dim=1)
|
||||
.view(-1, 1),
|
||||
combined_x,
|
||||
)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < 1e-5, f"Error: {diff=}, {zero_copy=}"
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
|
||||
def create_test_cast_with_outliers(num_outliers):
|
||||
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
tmp /= tmp.abs().amax(dim=1).view(-1, 1)
|
||||
assert tmp.abs().amax().item() <= 1
|
||||
|
||||
# Create some amax outliers
|
||||
for i in range(num_outliers):
|
||||
tmp[random.randint(0, num_tokens - 1)] *= 1e3
|
||||
return tmp
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def large_gemm_with_hook(hook):
|
||||
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
|
||||
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
|
||||
mat_0 @ mat_1
|
||||
hook()
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func(zero_copy: bool, return_recv_hook: bool):
|
||||
recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
|
||||
x,
|
||||
topk_idx,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
async_finish=False,
|
||||
return_recv_hook=return_recv_hook,
|
||||
)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[
|
||||
:, :, :
|
||||
] = simulated_gemm_x
|
||||
combined_x, event, hook = buffer.low_latency_combine(
|
||||
simulated_gemm_x,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
handle,
|
||||
zero_copy=zero_copy,
|
||||
return_recv_hook=return_recv_hook,
|
||||
)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
|
||||
# Calculate bandwidth
|
||||
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
|
||||
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
|
||||
for i in range(num_tokens):
|
||||
num_selections = (topk_idx[i] != -1).sum().item()
|
||||
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
|
||||
num_combine_comm_bytes += num_bf16_bytes * num_selections
|
||||
|
||||
# Dispatch + combine testing
|
||||
avg_t, min_t, max_t = bench(
|
||||
partial(test_func, zero_copy=False, return_recv_hook=False)
|
||||
)
|
||||
print(
|
||||
f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, "
|
||||
f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Separate profiling
|
||||
for return_recv_hook in (False, True):
|
||||
group.barrier()
|
||||
dispatch_t, combine_t = bench_kineto(
|
||||
partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
|
||||
kernel_names=("dispatch", "combine"),
|
||||
barrier_comm_profiling=True,
|
||||
suppress_kineto_output=True,
|
||||
)
|
||||
if not return_recv_hook:
|
||||
print(
|
||||
f"[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | "
|
||||
f"Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | "
|
||||
f"Combine send/recv time: {combine_t * 2 * 1e6:.2f} us",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return hash_value
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
|
||||
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_tokens, hidden, num_ranks, num_experts
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(f"Allocating buffer size: {num_rdma_bytes / 1e6} MB ...", flush=True)
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_experts // num_ranks,
|
||||
)
|
||||
test_main(
|
||||
num_tokens,
|
||||
hidden,
|
||||
num_experts,
|
||||
num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
do_pressure_test = False
|
||||
for seed in range(int(1e9) if do_pressure_test else 0):
|
||||
if local_rank == 0:
|
||||
print(f"Testing with seed {seed} ...", flush=True)
|
||||
ref_hash = test_main(
|
||||
num_tokens,
|
||||
hidden,
|
||||
num_experts,
|
||||
num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=seed,
|
||||
)
|
||||
for i in range(20):
|
||||
assert (
|
||||
test_main(
|
||||
num_tokens,
|
||||
hidden,
|
||||
num_experts,
|
||||
num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=seed,
|
||||
)
|
||||
== ref_hash
|
||||
), f"Error: seed={seed}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO: you may modify NUMA binding for less CPU overhead
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
|
||||
74
test/srt/test_moe_deepep_eval_accuracy_large.py
Normal file
74
test/srt/test_moe_deepep_eval_accuracy_large.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Usage:
|
||||
python -m unittest test_moe_deepep_eval_accuracy_large.TestMoEDeepEPEvalAccuracyLarge.test_mmlu
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestMoEDeepEPEvalAccuracyLarge(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"8",
|
||||
"--enable-deepep-moe",
|
||||
"--cuda-graph-max-bs",
|
||||
"128",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=8,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
parallel=64,
|
||||
max_new_tokens=512,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(f"Eval accuracy of GSM8K: {metrics=}")
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.93)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
print(f"Eval accuracy of MMLU: {metrics=}")
|
||||
self.assertGreater(metrics["score"], 0.87)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user