From d08c77c434981534267d13ef78c22a817ac08775 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Mon, 13 Jan 2025 23:09:00 +0800
Subject: [PATCH] Sampling penalties memory interface (#2870)

---
 ...benchmark_deepseekv3_moe_align_blocks.py} |   3 +-
 python/pyproject.toml                        |   2 +-
 .../penalizers/repetition_penalty.py         |  20 ++-
 .../srt/sampling/sampling_batch_info.py      |  19 ++-
 python/sglang/srt/utils.py                   |   4 +
 .../benchmark_sampling_scaling_penalties.py  | 159 ++++++++++++++++++
 sgl-kernel/tests/test_moe_align.py           |  85 ++++++----
 7 files changed, 251 insertions(+), 41 deletions(-)
 rename benchmark/kernels/fused_moe_triton/{benchmark_moe_align_blocks.py => benchmark_deepseekv3_moe_align_blocks.py} (98%)
 create mode 100644 sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py

diff --git a/benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
similarity index 98%
rename from benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py
rename to benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
index 92547ea95..0a6049a12 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
@@ -222,8 +222,9 @@ configs = list(itertools.product(batch_size_range, seq_length_range))
 def benchmark(batch_size, seq_len, provider):
     num_experts = 256
     block_size = 128
+    topk = 8
     topk_ids = torch.randint(
-        0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
+        0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
     )
 
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
diff --git a/python/pyproject.toml b/python/pyproject.toml
index a236469a1..4b627ae94 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -27,7 +27,7 @@ runtime_common = [
 ]
 srt = [
   "sglang[runtime_common]", "cuda-python",
-  "sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+  "sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
  "flashinfer==0.1.6"
 ]
diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
index 4c293b895..fcd5ff71c 100644
--- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
+++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
@@ -3,6 +3,11 @@ from typing import List
 import torch
 
 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
 
     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.where(
-            logits > 0,
-            logits / self.cumulated_repetition_penalties,
-            logits * self.cumulated_repetition_penalties,
-        )
+        if is_cuda:
+            return sampling_scaling_penalties(
+                logits, self.cumulated_repetition_penalties
+            )
+        else:
+            return torch.where(
+                logits > 0,
+                logits / self.cumulated_repetition_penalties,
+                logits * self.cumulated_repetition_penalties,
+            )
 
     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 9497e53d3..6eda63c70 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
+
 import sglang.srt.sampling.penaltylib as penaltylib
 
 logger = logging.getLogger(__name__)
@@ -245,11 +251,14 @@ class SamplingBatchInfo:
 
         # repetition
         if self.scaling_penalties is not None:
-            logits[:] = torch.where(
-                logits > 0,
-                logits / self.scaling_penalties,
-                logits * self.scaling_penalties,
-            )
+            if is_cuda:
+                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+            else:
+                logits[:] = torch.where(
+                    logits > 0,
+                    logits / self.scaling_penalties,
+                    logits * self.scaling_penalties,
+                )
 
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 51ca91a96..e70e6b425 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -97,6 +97,10 @@ def is_flashinfer_available():
     return torch.cuda.is_available() and torch.version.cuda
 
 
+def is_cuda_available():
+    return torch.cuda.is_available() and torch.version.cuda
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
diff --git a/sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py b/sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py
new file mode 100644
index 000000000..000dab0d8
--- /dev/null
+++ b/sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py
@@ -0,0 +1,159 @@
+import itertools
+
+import torch
+import triton
+from sgl_kernel import sampling_scaling_penalties
+
+
+def sampling_scaling_penalties_naive(logits, scaling_penalties):
+    return torch.where(
+        logits > 0, logits / scaling_penalties, logits * scaling_penalties
+    )
+
+
+def sampling_scaling_penalties_kernel(logits, scaling_penalties):
+    return sampling_scaling_penalties(logits, scaling_penalties)
+
+
+def test_memory(func, _iter):
+    total_mem = []
+
+    for _ in range(_iter):
+        torch.cuda.memory.reset_peak_memory_stats()
+        func()
+        mem = torch.cuda.max_memory_allocated() / (2**20)
+        total_mem.append(mem)
+
+    return sum(total_mem) / len(total_mem)
+
+
+def calculate_diff(batch_size, vocab_size):
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
+    scaling_penalties = (
+        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
+    )
+
+    output_naive = sampling_scaling_penalties_naive(
+        logits.clone(), scaling_penalties.clone()
+    )
+    output_kernel = sampling_scaling_penalties_kernel(
+        logits.clone(), scaling_penalties.clone()
+    )
+
+    print(f"Naive output={output_naive}")
+    print(f"Kernel output={output_kernel}")
+
+    if torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2):
+        print("✅ Both implementations match")
+    else:
+        print("❌ Implementations differ")
+
+
+batch_size_range = [2**i for i in range(0, 12)]
+vocab_size_range = [2**i for i in range(10, 17)]
+configs = list(itertools.product(batch_size_range, vocab_size_range))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size", "vocab_size"],
+        x_vals=[list(_) for _ in configs],
+        line_arg="provider",
+        line_vals=["naive", "kernel"],
+        line_names=["PyTorch Naive", "SGL Kernel"],
+        styles=[("blue", "-"), ("red", "-")],
+        ylabel="us",
+        plot_name="sampling-scaling-penalties-performance",
+        args={},
+    )
+)
+def benchmark(batch_size, vocab_size, provider):
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
+    scaling_penalties = (
+        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
+    )
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "naive":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: sampling_scaling_penalties_naive(
+                logits.clone(),
+                scaling_penalties.clone(),
+            ),
+            quantiles=quantiles,
+        )
+    else:
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: sampling_scaling_penalties_kernel(
+                logits.clone(),
+                scaling_penalties.clone(),
+            ),
+            quantiles=quantiles,
+        )
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size", "vocab_size"],
+        x_vals=[list(_) for _ in configs],
+        line_arg="provider",
+        line_vals=["naive", "kernel"],
+        line_names=["PyTorch Naive", "SGL Kernel"],
+        styles=[("blue", "-"), ("red", "-")],
+        ylabel="GPU memory usage (MB)",
+        plot_name="sampling-scaling-penalties-memory",
+        args={},
+    )
+)
+def benchmark_memory(batch_size, vocab_size, provider):
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    print(
+        f"Running memory benchmark with batch_size={batch_size}, vocab_size={vocab_size}, provider={provider}"
+    )
+
+    def run_kernel():
+        logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
+        scaling_penalties = (
+            torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
+        )
+
+        if provider == "naive":
+            return sampling_scaling_penalties_naive(logits, scaling_penalties)
+        else:
+            return sampling_scaling_penalties_kernel(logits, scaling_penalties)
+
+    mem = test_memory(run_kernel, _iter=10)
+    return mem
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        default="./configs/benchmark_ops/sampling_scaling_penalties/",
+        help="Path to save sampling_scaling_penalties benchmark results",
+    )
+    args = parser.parse_args()
+
+    # Run correctness test
+    calculate_diff(batch_size=4, vocab_size=4096)
+
+    # Run performance benchmark
+    benchmark.run(print_data=True, save_path=args.save_path)
+
+    # Run memory benchmark
+    benchmark_memory.run(print_data=True, save_path=args.save_path)
diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py
index 92596a47e..2fca90b2f 100644
--- a/sgl-kernel/tests/test_moe_align.py
+++ b/sgl-kernel/tests/test_moe_align.py
@@ -3,38 +3,65 @@ from sgl_kernel import moe_align_block_size
 
 
 def test_moe_align_block_size():
+    # For DeepSeek V3, we have 256 experts
     num_experts = 256
-    block_size = 128
-    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-    )
-    sorted_ids.fill_(topk_ids.numel())
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.empty(
-        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-    )
-    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    # Test different combinations of block_size, num_tokens and topk
+    for block_size in [32, 64, 128, 256]:
+        print(f"\nTesting block_size={block_size}")
+        for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
+            for topk in [1, 2, 4, 8, 16, 32, 64]:
+                print(
+                    f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
+                )
 
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum_buffer = torch.empty(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
-    )
+                # Create random topk_ids with shape [num_tokens, topk]
+                topk_ids = torch.randint(
+                    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
+                )
 
-    moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids,
-        expert_ids,
-        num_tokens_post_pad,
-        token_cnts_buffer,
-        cumsum_buffer,
-    )
+                max_num_tokens_padded = topk_ids.numel() + num_experts * (
+                    block_size - 1
+                )
+                sorted_ids = torch.empty(
+                    (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+                )
+                sorted_ids.fill_(topk_ids.numel())
+                max_num_m_blocks = max_num_tokens_padded // block_size
+                expert_ids = torch.empty(
+                    (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+                )
+                num_tokens_post_pad = torch.empty(
+                    (1), dtype=torch.int32, device=topk_ids.device
+                )
+
+                token_cnts_buffer = torch.empty(
+                    (num_experts + 1) * num_experts,
+                    dtype=torch.int32,
+                    device=topk_ids.device,
+                )
+                cumsum_buffer = torch.empty(
+                    num_experts + 1, dtype=torch.int32, device=topk_ids.device
+                )
+
+                try:
+                    moe_align_block_size(
+                        topk_ids,
+                        num_experts,
+                        block_size,
+                        sorted_ids,
+                        expert_ids,
+                        num_tokens_post_pad,
+                        token_cnts_buffer,
+                        cumsum_buffer,
+                    )
+                except Exception as e:
+                    print(
+                        f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
+                    )
+                    print(f"Error message: {str(e)}")
+                    raise e
 
 
-test_moe_align_block_size()
+if __name__ == "__main__":
+    test_moe_align_block_size()
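
Note for reviewers (not part of the patch): a minimal CPU-only sketch of the two pieces of arithmetic above. The helper name apply_scaling_penalties_ref is illustrative; the fused sampling_scaling_penalties kernel computes the same result in one pass instead of materializing both torch.where branch tensors, which is where the memory saving comes from.

import torch

# Same math as the naive fallback kept in the patch: positive logits are
# divided by the penalty, negative logits are multiplied by it, so repeated
# tokens always lose probability mass.
def apply_scaling_penalties_ref(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
    return torch.where(logits > 0, logits / penalties, logits * penalties)

logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
penalties = torch.full_like(logits, 1.2)  # repetition penalty > 1
print(apply_scaling_penalties_ref(logits, penalties))
# tensor([ 1.6667, -1.2000,  0.4167, -3.6000])

# Worst-case buffer sizing used by the moe_align_block_size test: each of the
# num_experts experts can leave up to (block_size - 1) padding slots in its
# last partially filled block. Shown for num_tokens=16, topk=8.
num_experts, block_size, num_tokens, topk = 256, 128, 16, 8
numel = num_tokens * topk                                        # 128 routed token ids
max_num_tokens_padded = numel + num_experts * (block_size - 1)   # 128 + 32512 = 32640
max_num_m_blocks = max_num_tokens_padded // block_size           # 255
print(max_num_tokens_padded, max_num_m_blocks)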