From 8a828666a3a994a9b938c0bb60b1a0a4de8fdff8 Mon Sep 17 00:00:00 2001
From: Jinyan Chen <93358689+liz-badada@users.noreply.github.com>
Date: Wed, 7 May 2025 08:36:03 +0800
Subject: [PATCH] Add DeepEP to CI PR Test (#5655)

Co-authored-by: Jinyan Chen
---
 .github/workflows/pr-test.yml                |   6 +-
 .github/workflows/release-docker-deepep.yml  |  36 ++
 python/sglang/test/test_deepep_utils.py      | 219 +++++++++
 python/sglang/test/test_utils.py             |   1 +
 scripts/ci_install_dependency_8_gpu.sh       | 122 +++++
 test/srt/run_suite.py                        |   3 +
 test/srt/test_deepep_internode.py            | 445 ++++++++++++++++++
 test/srt/test_deepep_intranode.py            | 379 +++++++++++++++
 test/srt/test_deepep_low_latency.py          | 325 +++++++++++++
 .../test_moe_deepep_eval_accuracy_large.py   |  74 +++
 10 files changed, 1607 insertions(+), 3 deletions(-)
 create mode 100644 .github/workflows/release-docker-deepep.yml
 create mode 100644 python/sglang/test/test_deepep_utils.py
 create mode 100755 scripts/ci_install_dependency_8_gpu.sh
 create mode 100644 test/srt/test_deepep_internode.py
 create mode 100644 test/srt/test_deepep_intranode.py
 create mode 100644 test/srt/test_deepep_low_latency.py
 create mode 100644 test/srt/test_moe_deepep_eval_accuracy_large.py

diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index d22bc9a88..8e004a791 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -97,7 +97,7 @@ jobs:

       - name: Install dependencies
         run: |
-          bash scripts/ci_install_dependency.sh
+          bash scripts/ci_install_dependency_8_gpu.sh

       - name: Run test
         timeout-minutes: 40
@@ -259,9 +259,9 @@ jobs:
   finish:
     if: always()
     needs: [
-      unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
+      unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unit-test-backend-8-gpu,
       performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
-      accuracy-test-1-gpu, accuracy-test-2-gpu
+      accuracy-test-1-gpu, accuracy-test-2-gpu,
     ]
     runs-on: ubuntu-latest
     steps:
diff --git a/.github/workflows/release-docker-deepep.yml b/.github/workflows/release-docker-deepep.yml
new file mode 100644
index 000000000..9f8607d9a
--- /dev/null
+++ b/.github/workflows/release-docker-deepep.yml
@@ -0,0 +1,36 @@
+name: Build DeepEP Docker Image
+
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 0 * * *'
+
+jobs:
+  build-dev:
+    if: ${{ github.repository == 'sgl-project/sglang' }}
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Free disk space
+        uses: jlumbroso/free-disk-space@main
+        with:
+          tool-cache: false
+          docker-images: false
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          swap-storage: false
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@v2
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Build and Push DeepEP Image
+        run: |
+          docker build . \
-f docker/Dockerfile.deepep -t lmsysorg/sglang:deepep --no-cache + docker push lmsysorg/sglang:deepep diff --git a/python/sglang/test/test_deepep_utils.py b/python/sglang/test/test_deepep_utils.py new file mode 100644 index 000000000..aa15b5a0b --- /dev/null +++ b/python/sglang/test/test_deepep_utils.py @@ -0,0 +1,219 @@ +# Copy from deepseek-ai/DeepEP/tests/test_utils.py + +import os +import sys +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist + + +def init_dist(local_rank: int, num_local_ranks: int): + # NOTES: you may rewrite this function with your own cluster settings + ip = os.getenv("MASTER_ADDR", "127.0.0.1") + port = int(os.getenv("MASTER_PORT", "8361")) + num_nodes = int(os.getenv("WORLD_SIZE", 1)) + node_rank = int(os.getenv("RANK", 0)) + assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 + + dist.init_process_group( + backend="nccl", + init_method=f"tcp://{ip}:{port}", + world_size=num_nodes * num_local_ranks, + rank=node_rank * num_local_ranks + local_rank, + ) + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.cuda.set_device(local_rank) + + return ( + dist.get_rank(), + dist.get_world_size(), + dist.new_group(list(range(num_local_ranks * num_nodes))), + ) + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double() + 1, y.double() + 1 + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return (1 - sim).item() + + +def per_token_cast_to_fp8(x: torch.Tensor): + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n + ), (x_amax / 448.0).view(m, -1) + + +def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): + x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) + x_scales = x_scales.view(x_fp8.size(0), -1, 1) + return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) + + +def inplace_unique(x: torch.Tensor, num_slots: int): + assert x.dim() == 2 + mask = x < 0 + x_padded = x.masked_fill(mask, num_slots) + bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device) + bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded)) + bin_count = bin_count[:, :num_slots] + sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True) + sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1) + sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values + x[:, :].fill_(-1) + valid_len = min(num_slots, x.size(1)) + x[:, :valid_len] = sorted_bin_idx[:, :valid_len] + + +def create_grouped_scores( + scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int +): + num_tokens, num_experts = scores.shape + scores = scores.view(num_tokens, num_groups, -1) + mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device) + mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores) + return (scores * mask).view(num_tokens, num_experts) + + +def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # Warmup + for _ in range(num_warmups): + fn() + + # Flush L2 + cache.zero_() + + # Testing + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = 
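The per-token FP8 helpers above quantize each 128-element block of a row against that block's own amax (scaled into the e4m3 range, whose maximum value is the 448.0 used in the code) and invert the scaling on the way back. A minimal round-trip sketch, assuming a CUDA device and the helpers defined in this file:

import torch

x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")  # width must be a multiple of 128
x_fp8, x_scales = per_token_cast_to_fp8(x)
x_recovered = per_token_cast_back(x_fp8, x_scales)
print(calc_diff(x, x_recovered))  # small but nonzero; bounded by FP8 e4m3 precision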
[torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] + for i in range(num_tests): + # Record + start_events[i].record() + fn() + end_events[i].record() + if post_fn is not None: + post_fn() + torch.cuda.synchronize() + + times = np.array( + [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)] + )[1:] + return np.average(times), np.min(times), np.max(times) + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto( + fn, + kernel_names, + num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: Optional[str] = None, + barrier_comm_profiling: bool = False, +): + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule + ) as prof: + for i in range(2): + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + if barrier_comm_profiling: + lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda") + rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda") + lhs @ rhs + dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda")) + for _ in range(num_tests): + fn() + prof.step() + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tupled = isinstance(kernel_names, tuple) + prof_lines = ( + prof.key_averages() + .table(sort_by="cuda_time_total", max_name_column_width=100) + .split("\n") + ) + kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + for name in kernel_names: + assert ( + sum([name in line for line in prof_lines]) == 1 + ), f"Errors of the kernel {name} in the profiling table" + + # Save chrome traces + if trace_path is not None: + prof.export_chrome_trace(trace_path) + + # Return average kernel times + units = {"ms": 1e3, "us": 1e6} + kernel_times = [] + for name in kernel_names: + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + for unit, scale in units.items(): + if unit in time_str: + kernel_times.append(float(time_str.replace(unit, "")) / scale) + break + break + return tuple(kernel_times) if is_tupled else kernel_times[0] + + +def hash_tensor(t: torch.Tensor): + return t.view(torch.int64).sum().item() diff --git 
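As a usage sketch of the timing helpers in test_deepep_utils.py (assuming a CUDA device): `bench` zeroes a 256 MB scratch buffer to defeat the L2 cache and returns (avg, min, max) seconds from CUDA-event timings, while `suppress_stdout_stderr` redirects both streams at the file-descriptor level, so even native-library prints are silenced.

import torch

from sglang.test.test_deepep_utils import bench, suppress_stdout_stderr

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
with suppress_stdout_stderr():  # hide library chatter while timing
    avg_t, min_t, max_t = bench(lambda: a @ b)
print(f"GEMM: avg {avg_t * 1e6:.1f} us, min {min_t * 1e6:.1f} us, max {max_t * 1e6:.1f} us")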
a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 79c43d5c3..6bcacb427 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -66,6 +66,7 @@ DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
 )
 DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
 DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
+DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-V3-0324"
 DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
     "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
 )
diff --git a/scripts/ci_install_dependency_8_gpu.sh b/scripts/ci_install_dependency_8_gpu.sh
new file mode 100755
index 000000000..8d6ccd51b
--- /dev/null
+++ b/scripts/ci_install_dependency_8_gpu.sh
@@ -0,0 +1,122 @@
+#!/bin/bash
+# Install the dependency in CI.
+set -euxo pipefail
+
+export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
+export NVSHMEM_DIR=/opt/nvshmem/install
+export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
+export PATH="${NVSHMEM_DIR}/bin:$PATH"
+export CUDA_HOME=/usr/local/cuda
+
+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+bash "${SCRIPT_DIR}/killall_sglang.sh"
+
+# Clean up existing installations
+pip uninstall -y flashinfer flashinfer_python sgl-kernel sglang vllm deepep || true
+pip cache purge
+rm -rf /root/.cache/flashinfer
+if [ -d "lmms-eval" ]; then
+    rm -rf lmms-eval
+fi
+rm -rf /root/.cache/deepep
+rm -rf /usr/local/lib/python3.10/dist-packages/flashinfer*
+rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
+rm -rf /usr/local/lib/python3.10/dist-packages/deepep*
+dpkg -r gdrcopy gdrcopy-tests libgdrapi gdrdrv-dkms || true
+rm -rf /opt/gdrcopy
+rm -rf /usr/local/lib/libgdrapi*
+rm -rf /usr/local/include/gdrapi.h
+rm -rf /opt/nvshmem
+rm -rf /usr/local/lib/libnvshmem*
+rm -rf /usr/local/include/nvshmem*
+
+# Update pip
+pip install --upgrade pip
+
+# Install sgl-kernel
+pip install sgl-kernel==0.1.1 --no-cache-dir
+
+# Install the main package
+pip install -e "python[all]"
+
+# Install additional dependencies
+pip install torch_memory_saver
+pip install transformers==4.51.0 sentence_transformers accelerate peft pandas datasets timm torchaudio==2.6.0
+
+# For compiling xgrammar kernels
+pip install cuda-python nvidia-cuda-nvrtc-cu12
+
+# For lmms_evals evaluating MMMU
+git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
+pip install -e lmms-eval/
+
+# Install FlashMLA for attention backend tests
+pip install git+https://github.com/deepseek-ai/FlashMLA.git
+
+# Install system dependencies
+# apt-get update && apt-get install -y libibverbs-dev infiniband-diags libmlx5-1 rdma-core openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 rdma-core-dev infiniband-diags-dev libibverbs-dev libibverbs-utils librdmacm-dev librdmacm-utils ibverbs-utils rdma-core-utils
+apt install curl wget git sudo libibverbs-dev -y
+apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1
+curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py
+
+wget https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4-linux-x86_64.sh
+chmod +x cmake-3.27.4-linux-x86_64.sh
+./cmake-3.27.4-linux-x86_64.sh --skip-license --prefix=/usr/local
+rm cmake-3.27.4-linux-x86_64.sh
+
+# Install GDRCopy
+mkdir -p /opt/gdrcopy
+mkdir -p /opt/nvshmem
+cd /opt/gdrcopy
+git clone 
https://github.com/NVIDIA/gdrcopy.git . +git checkout v2.4.4 +apt update +apt install -y nvidia-dkms-535 +apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms +apt install -y check libsubunit0 libsubunit-dev +cd packages +CUDA=/usr/local/cuda ./build-deb-packages.sh +dpkg -i gdrdrv-dkms_*.deb +dpkg -i libgdrapi_*.deb +dpkg -i gdrcopy-tests_*.deb +dpkg -i gdrcopy_*.deb + +if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then + ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so +fi +apt-get update && apt-get install -y libfabric-dev + +# Clone DeepEP +git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep + +# Install NVSHMEM +cd /opt/nvshmem +wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz +tar -xf nvshmem_src_3.2.5-1.txz +mv nvshmem_src nvshmem +cd nvshmem +git apply /root/.cache/deepep/third-party/nvshmem.patch +NVSHMEM_SHMEM_SUPPORT=0 \ +NVSHMEM_UCX_SUPPORT=0 \ +NVSHMEM_USE_NCCL=0 \ +NVSHMEM_MPI_SUPPORT=0 \ +NVSHMEM_IBGDA_SUPPORT=1 \ +NVSHMEM_PMIX_SUPPORT=0 \ +NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ +NVSHMEM_USE_GDRCOPY=1 \ +cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90 +cd build +make -j$(nproc) install + +# Install DeepEP +cd /root/.cache/deepep && python3 setup.py install + +# Verify configuration +echo "=== NCCL Configuration ===" +nvidia-smi topo -m +nvidia-smi nvlink -s +echo "=== Verify GDRCOPY ===" +gdrcopy_copybw +echo "=== Verify NVSHMEM ===" +nvshmem-info -a +# /opt/nvshmem/bin/perftest/device/pt-to-pt/shmem_put_bw diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 7a5c4310b..e1ce1160a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -96,6 +96,9 @@ suites = { TestFile("test_verl_engine.py", 64), ], "per-commit-8-gpu": [ + TestFile("test_deepep_intranode.py", 50), + TestFile("test_deepep_low_latency.py", 50), + TestFile("test_moe_deepep_eval_accuracy_large.py", 250), TestFile("test_local_attn.py", 250), TestFile("test_full_deepseek_v3.py", 250), TestFile("test_fa3.py", 30), diff --git a/test/srt/test_deepep_internode.py b/test/srt/test_deepep_internode.py new file mode 100644 index 000000000..1e4239606 --- /dev/null +++ b/test/srt/test_deepep_internode.py @@ -0,0 +1,445 @@ +# Copy from deepseek-ai/DeepEP/tests/test_internode.py + +import os +import time + +# noinspection PyUnresolvedReferences +import deep_ep + +# Test compatibility with low latency functions +import test_deepep_low_latency +import torch +import torch.distributed as dist + +from sglang.test.test_deepep_utils import ( + bench, + calc_diff, + create_grouped_scores, + init_dist, + inplace_unique, + per_token_cast_back, + per_token_cast_to_fp8, +) + + +def test_main( + num_sms: int, + local_rank: int, + num_local_ranks: int, + num_ranks: int, + num_nodes: int, + rank: int, + buffer: deep_ep.Buffer, + group: dist.ProcessGroup, +): + # Settings + num_tokens, hidden, num_topk_groups, num_topk, num_experts = ( + 4096, + 7168, + min(num_nodes, 4), + 8, + (256 // num_ranks) * num_ranks, + ) + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print( + f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}", + flush=True, + ) + + # Random data + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank + x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + x_e4m3 = 
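One detail of the internode settings above: the nominal 256 experts are rounded down to a multiple of the world size so that every rank owns the same number of local experts. A plain-Python illustration:

# (256 // num_ranks) * num_ranks guarantees num_experts % num_ranks == 0
for num_ranks in (16, 24, 32):
    num_experts = (256 // num_ranks) * num_ranks
    print(num_ranks, num_experts, num_experts // num_ranks)
    # 16 -> 256 experts (16 per rank); 24 -> 240 (10 per rank); 32 -> 256 (8 per rank)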
per_token_cast_to_fp8(x) + scores = ( + torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + + 1 + ) + group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) + group_idx = torch.topk( + group_scores, k=num_topk_groups, dim=-1, sorted=False + ).indices + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[ + 1 + ] + topk_weights = ( + torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank + ) + topk_weights_pure_rand = torch.randn( + (num_tokens, num_topk), dtype=torch.float32, device="cuda" + ) + rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rank_idx, num_ranks) + rdma_rank_idx = rank_idx // num_local_ranks + rdma_rank_idx.masked_fill_(rank_idx == -1, -1) + inplace_unique(rdma_rank_idx, num_nodes) + + # RDMA dispatch counts + rdma_idx = topk_idx // (num_experts // num_nodes) + rdma_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rdma_idx, num_nodes) + num_rdma_token_sent = rdma_idx.ne(-1).sum().item() + + # Expert meta + num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda") + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + # Rank layout meta + num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda") + num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda") + token_idx_in_rank = torch.full( + (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda" + ) + for i in range(num_ranks): + num_tokens_per_rank[i] = (rank_idx == i).sum() + token_sel = (rank_idx == i).max(dim=-1)[0] + count = token_sel.sum().item() + tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] + tokens[:count] = torch.sort(tokens[:count])[0] + token_idx_in_rank[i][tokens[:count]] = torch.arange( + count, dtype=torch.long, device="cuda" + ) + for i in range(num_nodes): + num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum() + token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) + is_token_in_rank = token_idx_in_rank >= 0 + gbl_num_tokens_per_rank = num_tokens_per_rank.clone() + dist.all_reduce(gbl_num_tokens_per_rank, group=group) + + ( + ref_num_tokens_per_rank, + ref_num_tokens_per_rdma_rank, + ref_num_tokens_per_expert, + ref_is_token_in_rank, + _, + ) = buffer.get_dispatch_layout(topk_idx, num_experts) + assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) + assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank) + assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) + assert torch.allclose(ref_is_token_in_rank, is_token_in_rank) + t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + if local_rank == 0: + print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True) + print("", flush=True) + group.barrier() + time.sleep(1) + + # Config + rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) + config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) + + # Test dispatch + # noinspection PyShadowingNames + def check_data(check_x, recv_gbl_rank_prefix_sum): + assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1)) + check_start = 0 + for i in range(num_ranks): + check_end = 
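The block above derives, for every token, the set of destination ranks implied by its top-k experts and cross-checks the result against `buffer.get_dispatch_layout`. The core mapping is just integer division of expert ids by experts-per-rank; a small CPU-only sketch with a hypothetical routing table:

import torch

num_experts, num_ranks = 8, 4              # 2 experts per rank
topk_idx = torch.tensor([[0, 5], [2, 3], [6, 1]])  # hypothetical top-2 routing for 3 tokens
rank_idx = topk_idx // (num_experts // num_ranks)  # expert id -> owning rank
is_token_in_rank = torch.zeros(3, num_ranks, dtype=torch.bool)
is_token_in_rank.scatter_(1, rank_idx, torch.ones_like(rank_idx, dtype=torch.bool))
print(is_token_in_rank.sum(0).tolist())    # tokens per rank: [2, 1, 1, 1]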
recv_gbl_rank_prefix_sum[i].item() + assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 + check_start = check_end + + for previous_mode in (False, True): + for async_mode in (False, True): + for current_x in (x_pure_rand, x, x_e4m3): + for with_topk in (False, True): + if local_rank == 0: + print( + f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', + flush=True, + end="", + ) + dispatch_args = { + "x": current_x, + "num_tokens_per_rank": num_tokens_per_rank, + "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank, + "is_token_in_rank": is_token_in_rank, + "num_tokens_per_expert": num_tokens_per_expert, + "config": config, + "async_finish": async_mode, + } + if with_topk: + dispatch_args.update( + { + "topk_idx": topk_idx, + "topk_weights": ( + topk_weights_pure_rand + if current_x is x_pure_rand + else topk_weights + ), + } + ) + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_num_tokens_per_expert_list, + handle, + event, + ) = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = ( + per_token_cast_back(*recv_x) + if isinstance(recv_x, tuple) + else recv_x + ) + + # Checks + recv_gbl_rank_prefix_sum = handle[-4] + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size( + 0 + ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}" + assert ( + gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() + == recv_num_tokens_per_expert_list + ) + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + if with_topk: + # Check `topk_idx` + assert ( + recv_topk_idx.eq(-1) + | ( + (recv_topk_idx >= 0) + & (recv_topk_idx < (num_experts // num_ranks)) + ) + ).sum().item() == recv_topk_idx.numel() + for i, count in enumerate(recv_num_tokens_per_expert_list): + assert recv_topk_idx.eq(i).sum().item() == count + + # Check `topk_weights` + if current_x is not x_pure_rand: + recv_topk_weights[recv_topk_idx.eq(-1)] = ( + recv_topk_weights.amax(dim=1, keepdim=True).expand_as( + recv_topk_weights + )[recv_topk_idx.eq(-1)] + ) + check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) + + # Test cached dispatch (must without top-k staffs) + if not with_topk: + dispatch_args = { + "x": current_x, + "handle": handle, + "config": config, + "async_finish": async_mode, + } + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = ( + per_token_cast_back(*recv_x) + if isinstance(recv_x, tuple) + else recv_x + ) + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + + # Test combine + combine_args = { + "x": recv_x, + "handle": handle, + "config": config, + "async_finish": async_mode, + } + if with_topk: + combine_args.update({"topk_weights": recv_topk_weights}) + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + combined_x, combined_topk_weights, event = buffer.combine( + **combine_args + ) + event.current_stream_wait() if async_mode else () + check_x = combined_x.float() / is_token_in_rank.sum( + dim=1 + ).unsqueeze(1) + ref_x = x_pure_rand if current_x is x_pure_rand else x + assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + check_topk_weights = ( + combined_topk_weights + if (current_x 
is x_pure_rand) + else ( + combined_topk_weights + / is_token_in_rank.sum(dim=1).unsqueeze(1) + ) + ) + ref_topk_weights = ( + topk_weights_pure_rand + if current_x is x_pure_rand + else topk_weights + ) + assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + + # For later tuning + dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 + dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 + combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes + combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes + + if local_rank == 0: + print(" passed", flush=True) + if local_rank == 0: + print("", flush=True) + + # Tune dispatch performance + best_dispatch_results = None + fp8_factor = (1 + 4 / 128) / 2 + for current_x in (x_e4m3, x): + best_time, best_results = 1e10, None + rdma_send_bytes = ( + (dispatch_bf16_rdma_send_bytes * fp8_factor) + if isinstance(current_x, tuple) + else dispatch_bf16_rdma_send_bytes + ) + nvl_recv_bytes = ( + (dispatch_bf16_nvl_recv_bytes * fp8_factor) + if isinstance(current_x, tuple) + else dispatch_bf16_nvl_recv_bytes + ) + for nvl_chunk_size in range(4, 33, 4): + for rdma_chunk_size in range(4, 33, 4): + config = deep_ep.Config( + num_sms, + nvl_chunk_size, + nvl_buffer_size, + rdma_chunk_size, + rdma_buffer_size, + ) + tune_args = {"x": current_x, "handle": handle, "config": config} + t = bench(lambda: buffer.dispatch(**tune_args))[0] + if t < best_time: + best_time, best_results = t, ( + num_sms, + nvl_chunk_size, + rdma_chunk_size, + ) + if local_rank == 0: + print( + f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ", + flush=True, + ) + if local_rank == 0: + print( + f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', + flush=True, + ) + print("", flush=True) + + if isinstance(current_x, tuple): + # Gather FP8 the best config from rank 0 + best_dispatch_results = torch.tensor( + [best_results[0], best_results[1], best_results[2]], + dtype=torch.int32, + device="cuda", + ) + all_best_fp8_results_list = [ + torch.zeros_like(best_dispatch_results) + for _ in range(torch.distributed.get_world_size()) + ] + dist.all_gather( + all_best_fp8_results_list, best_dispatch_results, group=group + ) + best_dispatch_results = all_best_fp8_results_list[0].tolist() + dispatch_config = deep_ep.Config( + best_dispatch_results[0], + best_dispatch_results[1], + nvl_buffer_size, + best_dispatch_results[2], + rdma_buffer_size, + ) + + dispatch_args = { + "x": x, + "num_tokens_per_rank": num_tokens_per_rank, + "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank, + "is_token_in_rank": is_token_in_rank, + "num_tokens_per_expert": num_tokens_per_expert, + "config": dispatch_config if dispatch_config is not None else config, + } + recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + # Tune combine performance + best_time, best_results = 1e10, None + for nvl_chunk_size in range(1, 5, 1): + for rdma_chunk_size in range(8, 33, 4): + config = deep_ep.Config( + num_sms, + nvl_chunk_size, + nvl_buffer_size, + rdma_chunk_size, + rdma_buffer_size, + ) + tune_args = {"x": recv_x, "handle": handle, "config": config} + t = bench(lambda: buffer.combine(**tune_args))[0] + if local_rank == 0: + print( + f"[tuning] SMs {num_sms}, NVL 
chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ", + flush=True, + ) + if t < best_time: + best_time, best_results = t, ( + num_sms, + nvl_chunk_size, + rdma_chunk_size, + ) + + if local_rank == 0: + print( + f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)", + flush=True, + ) + print("", flush=True) + + +# noinspection PyUnboundLocalVariable +def test_loop(local_rank: int, num_local_ranks: int): + num_nodes = int(os.getenv("WORLD_SIZE", 1)) + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + test_ll_compatibility = False + if test_ll_compatibility: + ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 + + buffer = deep_ep.Buffer( + group, + int(1e9), + int(1e9), + low_latency_mode=test_ll_compatibility, + num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), + ) + assert num_local_ranks == 8 and num_ranks > 8 + torch.manual_seed(rank) + + for i in (24,): + test_main( + i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group + ) + if local_rank == 0: + print("", flush=True) + + # Test compatibility with low latency functions + if test_ll_compatibility: + buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) + test_deepep_low_latency.test_main( + ll_num_tokens, + ll_hidden, + ll_num_experts, + ll_num_topk, + rank, + num_ranks, + group, + buffer, + seed=1, + ) + + +if __name__ == "__main__": + num_processes = 8 + torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) diff --git a/test/srt/test_deepep_intranode.py b/test/srt/test_deepep_intranode.py new file mode 100644 index 000000000..97acd000c --- /dev/null +++ b/test/srt/test_deepep_intranode.py @@ -0,0 +1,379 @@ +# Copy from deepseek-ai/DeepEP/tests/test_intranode.py + +import os +import time + +# noinspection PyUnresolvedReferences +import deep_ep + +# Test compatibility with low latency functions +import test_deepep_low_latency +import torch +import torch.distributed as dist + +from sglang.test.test_deepep_utils import ( + bench, + calc_diff, + init_dist, + inplace_unique, + per_token_cast_back, + per_token_cast_to_fp8, +) + + +def test_main( + num_sms: int, + local_rank: int, + num_ranks: int, + rank: int, + buffer: deep_ep.Buffer, + group: dist.ProcessGroup, +): + # Settings + num_tokens, hidden, num_topk, num_experts = ( + 4096, + 7168, + 8, + (256 // num_ranks) * num_ranks, + ) + assert num_experts % num_ranks == 0 + if local_rank == 0: + print( + f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}", + flush=True, + ) + + # Random data + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank + x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + x_e4m3 = per_token_cast_to_fp8(x) + scores = ( + torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + + 1 + ) + topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_weights = ( + torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank + ) + topk_weights_pure_rand = torch.randn( + (num_tokens, num_topk), dtype=torch.float32, device="cuda" + ) + rank_idx = topk_idx // (num_experts // num_ranks) + 
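`inplace_unique` is easier to read from its effect than from its bin-count implementation: it rewrites each row to hold that row's distinct non-negative values in descending order, padded with -1, which is how duplicate destination ranks per token are collapsed here. A CPU-only sketch with a made-up input:

import torch

from sglang.test.test_deepep_utils import inplace_unique

x = torch.tensor([[2, 0, 2, -1], [1, 1, 1, 3]])
inplace_unique(x, num_slots=4)
print(x)  # row 0 -> [2, 0, -1, -1]; row 1 -> [3, 1, -1, -1]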
rank_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rank_idx, num_ranks) + + # Expert meta + num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda") + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + # Rank layout meta + num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda") + token_idx_in_rank = torch.full( + (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda" + ) + for i in range(num_ranks): + num_tokens_per_rank[i] = (rank_idx == i).sum() + token_sel = (rank_idx == i).max(dim=-1)[0] + count = token_sel.sum().item() + tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] + tokens[:count] = torch.sort(tokens[:count])[0] + token_idx_in_rank[i][tokens[:count]] = torch.arange( + count, dtype=torch.long, device="cuda" + ) + token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) + is_token_in_rank = token_idx_in_rank >= 0 + gbl_num_tokens_per_rank = num_tokens_per_rank.clone() + dist.all_reduce(gbl_num_tokens_per_rank, group=group) + + ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = ( + buffer.get_dispatch_layout(topk_idx, num_experts) + ) + assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) + assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) + assert torch.allclose(ref_is_token_in_rank, is_token_in_rank) + t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + if local_rank == 0: + print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True) + print("", flush=True) + group.barrier() + time.sleep(1) + + # Config + nvl_buffer_size = 256 + config = deep_ep.Config(num_sms, 8, nvl_buffer_size) + + # Test dispatch + # noinspection PyShadowingNames + def check_data(check_x, rank_prefix_matrix): + assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1)) + check_start = 0 + for i in range(num_ranks): + check_end = rank_prefix_matrix[i][rank].item() + assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 + check_start = check_end + + for previous_mode in (False, True): + for async_mode in (False, True): + for current_x in (x_pure_rand, x, x_e4m3): + for with_topk in (False, True): + if local_rank == 0: + print( + f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', + flush=True, + end="", + ) + dispatch_args = { + "x": current_x, + "num_tokens_per_rank": num_tokens_per_rank, + "is_token_in_rank": is_token_in_rank, + "num_tokens_per_expert": num_tokens_per_expert, + "config": config, + "async_finish": async_mode, + } + if with_topk: + dispatch_args.update( + { + "topk_idx": topk_idx, + "topk_weights": ( + topk_weights_pure_rand + if current_x is x_pure_rand + else topk_weights + ), + } + ) + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_num_tokens_per_expert_list, + handle, + event, + ) = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = ( + per_token_cast_back(*recv_x) + if isinstance(recv_x, tuple) + else recv_x + ) + + # Checks + rank_prefix_matrix = handle[0] + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size( + 0 + ), f"{gbl_num_tokens_per_rank[rank].item()} != 
{recv_x.size(0)}" + assert ( + gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() + == recv_num_tokens_per_expert_list + ) + if current_x is not x_pure_rand: + check_data(recv_x, rank_prefix_matrix) + if with_topk: + # Check `topk_idx` + assert ( + recv_topk_idx.eq(-1) + | ( + (recv_topk_idx >= 0) + & (recv_topk_idx < (num_experts // num_ranks)) + ) + ).sum().item() == recv_topk_idx.numel() + for i, count in enumerate(recv_num_tokens_per_expert_list): + assert recv_topk_idx.eq(i).sum().item() == count + + # Check `topk_weights` + if current_x is not x_pure_rand: + recv_topk_weights[recv_topk_idx.eq(-1)] = ( + recv_topk_weights.amax(dim=1, keepdim=True).expand_as( + recv_topk_weights + )[recv_topk_idx.eq(-1)] + ) + check_data(recv_topk_weights, rank_prefix_matrix) + + # Test cached dispatch (must without top-k staffs) + if not with_topk: + dispatch_args = { + "x": current_x, + "handle": handle, + "config": config, + "async_finish": async_mode, + } + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = ( + per_token_cast_back(*recv_x) + if isinstance(recv_x, tuple) + else recv_x + ) + if current_x is not x_pure_rand: + check_data(recv_x, rank_prefix_matrix) + + # Test combine + combine_args = { + "x": recv_x, + "handle": handle, + "config": config, + "async_finish": async_mode, + } + if with_topk: + combine_args.update({"topk_weights": recv_topk_weights}) + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + combined_x, combined_topk_weights, event = buffer.combine( + **combine_args + ) + event.current_stream_wait() if async_mode else () + check_x = combined_x.float() / is_token_in_rank.sum( + dim=1 + ).unsqueeze(1) + ref_x = x_pure_rand if current_x is x_pure_rand else x + assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + check_topk_weights = ( + combined_topk_weights + if (current_x is x_pure_rand) + else ( + combined_topk_weights + / is_token_in_rank.sum(dim=1).unsqueeze(1) + ) + ) + ref_topk_weights = ( + topk_weights_pure_rand + if current_x is x_pure_rand + else topk_weights + ) + assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + + # For later tuning + dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 + combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes + + if local_rank == 0: + print(" passed", flush=True) + if local_rank == 0: + print("", flush=True) + + # Tune dispatch performance + best_dispatch_results = None + fp8_factor = (1 + 4 / 128) / 2 + for current_x in (x_e4m3, x): + best_time, best_results = 1e10, None + nvl_recv_bytes = ( + (dispatch_bf16_nvl_recv_bytes * fp8_factor) + if isinstance(current_x, tuple) + else dispatch_bf16_nvl_recv_bytes + ) + for nvl_chunk_size in range(4, 33, 4): + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + tune_args = {"x": current_x, "handle": handle, "config": config} + t = bench(lambda: buffer.dispatch(**tune_args))[0] + if t < best_time: + best_time, best_results = t, (num_sms, nvl_chunk_size) + if local_rank == 0: + print( + f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ", + flush=True, + ) + if local_rank == 0: + print( + f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', + flush=True, + ) + print("", flush=True) + + if 
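A note on the combine check above: dispatch sends each token to every rank that owns one of its top-k experts, and combine sums those copies back, so the result must be divided by the per-token multiplicity (`is_token_in_rank.sum(dim=1)`) before comparing against the original input. In miniature:

import torch

x_t = torch.tensor([1.0, 2.0, 3.0])          # one token's hidden state
k = 3                                        # token routed to experts on 3 ranks
combined = torch.stack([x_t] * k).sum(0)     # what summing the copies yields
assert torch.allclose(combined / k, x_t)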
isinstance(current_x, tuple): + # Gather FP8 the best config from rank 0 + best_dispatch_results = torch.tensor( + [best_results[0], best_results[1]], dtype=torch.int32, device="cuda" + ) + all_best_fp8_results_list = [ + torch.zeros_like(best_dispatch_results) + for _ in range(torch.distributed.get_world_size()) + ] + dist.all_gather( + all_best_fp8_results_list, best_dispatch_results, group=group + ) + best_dispatch_results = all_best_fp8_results_list[0].tolist() + dispatch_config = deep_ep.Config( + best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size + ) + + dispatch_args = { + "x": x, + "num_tokens_per_rank": num_tokens_per_rank, + "is_token_in_rank": is_token_in_rank, + "num_tokens_per_expert": num_tokens_per_expert, + "config": dispatch_config if dispatch_config is not None else config, + } + recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + # Tune combine performance + best_time, best_results = 1e10, None + for nvl_chunk_size in range(1, 7, 1): + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + tune_args = {"x": recv_x, "handle": handle, "config": config} + t = bench(lambda: buffer.combine(**tune_args))[0] + if local_rank == 0: + print( + f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ", + flush=True, + ) + if t < best_time: + best_time, best_results = t, (num_sms, nvl_chunk_size) + + if local_rank == 0: + print( + f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)", + flush=True, + ) + print("", flush=True) + + +# noinspection PyUnboundLocalVariable +def test_loop(local_rank: int, num_local_ranks: int): + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + test_ll_compatibility, num_rdma_bytes = False, 0 + if test_ll_compatibility: + ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + ll_num_tokens, ll_hidden, num_ranks, ll_num_experts + ) + + buffer = deep_ep.Buffer( + group, + int(1e9), + num_rdma_bytes, + low_latency_mode=test_ll_compatibility, + num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), + ) + torch.manual_seed(rank) + + for i in (24,): + test_main(i, local_rank, num_ranks, rank, buffer, group) + if local_rank == 0: + print("", flush=True) + + # Test compatibility with low latency functions + if test_ll_compatibility: + buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) + test_deepep_low_latency.test_main( + ll_num_tokens, + ll_hidden, + ll_num_experts, + ll_num_topk, + rank, + num_ranks, + group, + buffer, + seed=1, + ) + + +if __name__ == "__main__": + num_processes = 8 + torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) diff --git a/test/srt/test_deepep_low_latency.py b/test/srt/test_deepep_low_latency.py new file mode 100644 index 000000000..1b3f34da9 --- /dev/null +++ b/test/srt/test_deepep_low_latency.py @@ -0,0 +1,325 @@ +# Copy from deepseek-ai/DeepEP/tests/test_low_latency.py + +import random +from functools import partial + +import deep_ep +import torch +import torch.distributed as dist + +from sglang.test.test_deepep_utils import ( + bench, + bench_kineto, + calc_diff, + hash_tensor, + init_dist, + per_token_cast_back, +) + + +def test_main( + num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + rank: int, + num_ranks: int, + group: dist.ProcessGroup, + 
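All three DeepEP tests follow the same launch pattern: `torch.multiprocessing.spawn` starts one process per local GPU, and `init_dist` from `test_deepep_utils` derives the global rank from the MASTER_ADDR/MASTER_PORT/WORLD_SIZE/RANK environment variables. Schematically (a sketch, not part of the patch):

import torch

def worker(local_rank: int, num_local_ranks: int):
    # rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    # ... allocate a deep_ep.Buffer and run test_main(...) ...
    pass

if __name__ == "__main__":
    # one process per GPU on this node; WORLD_SIZE/RANK select the node
    torch.multiprocessing.spawn(worker, args=(8,), nprocs=8)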
buffer: deep_ep.Buffer, + seed: int = 0, +): + torch.manual_seed(seed + rank) + random.seed(seed + rank) + + assert num_experts % num_ranks == 0 + num_local_experts = num_experts // num_ranks + + # NOTES: the integers greater than 256 exceeds the BF16 precision limit + rank_offset = 128 + assert ( + num_ranks - rank_offset < 257 + ), "Too many ranks (exceeding test precision limit)" + + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * ( + rank - rank_offset + ) + x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1) + scores = ( + torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + + 1 + ) + topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] + topk_weights = torch.randn( + (num_tokens, num_topk), dtype=torch.float32, device="cuda" + ).abs() + + # Randomly mask some positions + for i in range(10): + topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = ( + -1 + ) + + # Check dispatch correctness + do_check = True + hash_value, num_times = 0, 0 + for return_recv_hook in (False, True): + for dispatch_use_fp8 in (False, True): + num_times += 1 + for i in range((num_times % 2) + 1): + packed_recv_x, packed_recv_count, handle, event, hook = ( + buffer.low_latency_dispatch( + x, + topk_idx, + num_tokens, + num_experts, + use_fp8=dispatch_use_fp8, + async_finish=not return_recv_hook, + return_recv_hook=return_recv_hook, + ) + ) + hook() if return_recv_hook else event.current_stream_wait() + packed_recv_x = ( + (packed_recv_x[0], packed_recv_x[1].contiguous()) + if dispatch_use_fp8 + else packed_recv_x + ) + simulated_gemm_x = ( + per_token_cast_back( + packed_recv_x[0].view(-1, hidden), + packed_recv_x[1].view(-1, hidden // 128), + ).view(packed_recv_x[0].shape) + if dispatch_use_fp8 + else packed_recv_x.clone() + ) + all_topk_idx = torch.empty( + (num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda" + ) + dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) + for i in range(num_local_experts if do_check else 0): + expert_id = rank * num_local_experts + i + recv_x = ( + per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) + if dispatch_use_fp8 + else packed_recv_x[i] + ) + recv_count, recv_src_info, recv_layout_range = ( + packed_recv_count[i], + handle[0][i], + handle[1][i], + ) + + # Check expert indices + int_mask = (2**32) - 1 + num_valid_tokens = recv_count.item() + assert ( + num_valid_tokens == (recv_layout_range & int_mask).sum().item() + ), f"{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()" + assert ( + num_valid_tokens == (all_topk_idx == expert_id).sum().item() + ), f"{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}" + + # Check received data + recv_x = recv_x[:num_valid_tokens] + recv_x_amin = recv_x[:, :-128].amin(dim=-1) + recv_src_info = recv_src_info[:num_valid_tokens] + assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) + assert ( + recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens + ).sum().item() == 0 + for j in range(num_ranks): + begin_idx, count = (recv_layout_range[j] >> 32).item(), ( + recv_layout_range[j] & int_mask + ).item() + assert (recv_x_amin == j - rank_offset).sum().item() == ( + all_topk_idx[j] == expert_id + ).sum().item() + assert ( + recv_x[begin_idx : begin_idx + count][:-128] - j + ).sum().item() == 0 + if dispatch_use_fp8: + hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) + hash_value ^= 
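The checks above decode `recv_layout_range` as one 64-bit word per source rank; judging from the decode side, the high 32 bits carry the begin index and the low 32 bits the token count. A worked example of that packing:

begin_idx, count = 96, 5
packed = (begin_idx << 32) | count
int_mask = (2**32) - 1
assert packed >> 32 == begin_idx
assert packed & int_mask == count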
hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) + else: + hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) + + # Check combine correctness + for zero_copy in (False, True): + if zero_copy: + buffer.get_next_low_latency_combine_buffer(handle)[ + :, :, : + ] = simulated_gemm_x + out = torch.empty( + (num_tokens, hidden), dtype=torch.bfloat16, device="cuda" + ) + combined_x, event, hook = buffer.low_latency_combine( + simulated_gemm_x, + topk_idx, + topk_weights, + handle, + async_finish=not return_recv_hook, + zero_copy=zero_copy, + return_recv_hook=return_recv_hook, + out=out, + ) + hook() if return_recv_hook else event.current_stream_wait() + if do_check: + diff = calc_diff( + x + * topk_weights.masked_fill(topk_idx == -1, 0) + .sum(dim=1) + .view(-1, 1), + combined_x, + ) + assert torch.isnan(combined_x).sum().item() == 0 + assert diff < 1e-5, f"Error: {diff=}, {zero_copy=}" + hash_value ^= hash_tensor(combined_x) + + def create_test_cast_with_outliers(num_outliers): + tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + tmp /= tmp.abs().amax(dim=1).view(-1, 1) + assert tmp.abs().amax().item() <= 1 + + # Create some amax outliers + for i in range(num_outliers): + tmp[random.randint(0, num_tokens - 1)] *= 1e3 + return tmp + + # noinspection PyShadowingNames + def large_gemm_with_hook(hook): + mat_0 = torch.randn((8192, 8192), dtype=torch.float) + mat_1 = torch.randn((8192, 8192), dtype=torch.float) + mat_0 @ mat_1 + hook() + + # noinspection PyShadowingNames + def test_func(zero_copy: bool, return_recv_hook: bool): + recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch( + x, + topk_idx, + num_tokens, + num_experts, + async_finish=False, + return_recv_hook=return_recv_hook, + ) + large_gemm_with_hook(hook) if return_recv_hook else None + if zero_copy: + buffer.get_next_low_latency_combine_buffer(handle)[ + :, :, : + ] = simulated_gemm_x + combined_x, event, hook = buffer.low_latency_combine( + simulated_gemm_x, + topk_idx, + topk_weights, + handle, + zero_copy=zero_copy, + return_recv_hook=return_recv_hook, + ) + large_gemm_with_hook(hook) if return_recv_hook else None + + # Calculate bandwidth + num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 + num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 + for i in range(num_tokens): + num_selections = (topk_idx[i] != -1).sum().item() + num_dispatch_comm_bytes += num_fp8_bytes * num_selections + num_combine_comm_bytes += num_bf16_bytes * num_selections + + # Dispatch + combine testing + avg_t, min_t, max_t = bench( + partial(test_func, zero_copy=False, return_recv_hook=False) + ) + print( + f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, " + f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us", + flush=True, + ) + + # Separate profiling + for return_recv_hook in (False, True): + group.barrier() + dispatch_t, combine_t = bench_kineto( + partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook), + kernel_names=("dispatch", "combine"), + barrier_comm_profiling=True, + suppress_kineto_output=True, + ) + if not return_recv_hook: + print( + f"[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | " + f"Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us", + flush=True, + ) + else: + print( + f"[rank {rank}] Dispatch send/recv 
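The bandwidth figures above count, per routed token, `hidden + hidden / 128 * 4 + 16` bytes for FP8 dispatch (payload, plus one float scale per 128-element block, plus a small fixed overhead, per the formula above) and `hidden * 2` bytes for BF16 combine. With the test's hidden size:

hidden = 7168
num_fp8_bytes = hidden + hidden // 128 * 4 + 16   # 7168 + 224 + 16 = 7408
num_bf16_bytes = hidden * 2                       # 14336
print(num_fp8_bytes, num_bf16_bytes)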
time: {dispatch_t * 2 * 1e6:.2f} us | " + f"Combine send/recv time: {combine_t * 2 * 1e6:.2f} us", + flush=True, + ) + + return hash_value + + +# noinspection PyUnboundLocalVariable +def test_loop(local_rank: int, num_local_ranks: int): + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288 + + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + num_tokens, hidden, num_ranks, num_experts + ) + if local_rank == 0: + print(f"Allocating buffer size: {num_rdma_bytes / 1e6} MB ...", flush=True) + buffer = deep_ep.Buffer( + group, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_experts // num_ranks, + ) + test_main( + num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + seed=1, + ) + + do_pressure_test = False + for seed in range(int(1e9) if do_pressure_test else 0): + if local_rank == 0: + print(f"Testing with seed {seed} ...", flush=True) + ref_hash = test_main( + num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + seed=seed, + ) + for i in range(20): + assert ( + test_main( + num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + seed=seed, + ) + == ref_hash + ), f"Error: seed={seed}" + + +if __name__ == "__main__": + # TODO: you may modify NUMA binding for less CPU overhead + num_processes = 8 + torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) diff --git a/test/srt/test_moe_deepep_eval_accuracy_large.py b/test/srt/test_moe_deepep_eval_accuracy_large.py new file mode 100644 index 000000000..618135628 --- /dev/null +++ b/test/srt/test_moe_deepep_eval_accuracy_large.py @@ -0,0 +1,74 @@ +""" +Usage: +python -m unittest test_moe_deepep_eval_accuracy_large.TestMoEDeepEPEvalAccuracyLarge.test_mmlu +""" + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestMoEDeepEPEvalAccuracyLarge(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "8", + "--enable-deepep-moe", + "--cuda-graph-max-bs", + "128", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=8, + data_path=None, + num_questions=200, + parallel=64, + max_new_tokens=512, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Eval accuracy of GSM8K: {metrics=}") + + self.assertGreater(metrics["accuracy"], 0.93) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + print(f"Eval accuracy of MMLU: {metrics=}") + self.assertGreater(metrics["score"], 0.87) + + +if __name__ == "__main__": + unittest.main()
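One practical note: all of the suites added by this patch assume a fully provisioned 8-GPU node. A minimal environment probe (a sketch, not part of the patch; the GPU count matches the CI runner assumed by `ci_install_dependency_8_gpu.sh`) that fails fast when the NVSHMEM/GDRCopy/DeepEP toolchain is incomplete:

import torch
import deep_ep  # raises ImportError if the DeepEP build step failed

assert torch.cuda.is_available() and torch.cuda.device_count() == 8
print(torch.cuda.get_device_name(0))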