Sampling penalties memory interface (#2870)
This commit is contained in:
@@ -222,8 +222,9 @@ configs = list(itertools.product(batch_size_range, seq_length_range))
|
|||||||
def benchmark(batch_size, seq_len, provider):
|
def benchmark(batch_size, seq_len, provider):
|
||||||
num_experts = 256
|
num_experts = 256
|
||||||
block_size = 128
|
block_size = 128
|
||||||
|
topk = 8
|
||||||
topk_ids = torch.randint(
|
topk_ids = torch.randint(
|
||||||
0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
|
0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
@@ -27,7 +27,7 @@ runtime_common = [
|
|||||||
]
|
]
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]", "cuda-python",
|
"sglang[runtime_common]", "cuda-python",
|
||||||
"sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
|
"sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
|
||||||
"flashinfer==0.1.6"
|
"flashinfer==0.1.6"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,11 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||||
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
|
is_cuda = is_cuda_available()
|
||||||
|
if is_cuda:
|
||||||
|
from sgl_kernel import sampling_scaling_penalties
|
||||||
|
|
||||||
|
|
||||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||||
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|||||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||||
|
|
||||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
return torch.where(
|
if is_cuda:
|
||||||
logits > 0,
|
return sampling_scaling_penalties(
|
||||||
logits / self.cumulated_repetition_penalties,
|
logits, self.cumulated_repetition_penalties
|
||||||
logits * self.cumulated_repetition_penalties,
|
)
|
||||||
)
|
else:
|
||||||
|
return torch.where(
|
||||||
|
logits > 0,
|
||||||
|
logits / self.cumulated_repetition_penalties,
|
||||||
|
logits * self.cumulated_repetition_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||||
|
|||||||
@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
|
is_cuda = is_cuda_available()
|
||||||
|
if is_cuda:
|
||||||
|
from sgl_kernel import sampling_scaling_penalties
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -245,11 +251,14 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
# repetition
|
# repetition
|
||||||
if self.scaling_penalties is not None:
|
if self.scaling_penalties is not None:
|
||||||
logits[:] = torch.where(
|
if is_cuda:
|
||||||
logits > 0,
|
logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
|
||||||
logits / self.scaling_penalties,
|
else:
|
||||||
logits * self.scaling_penalties,
|
logits[:] = torch.where(
|
||||||
)
|
logits > 0,
|
||||||
|
logits / self.scaling_penalties,
|
||||||
|
logits * self.scaling_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
# Apply regex vocab_mask
|
# Apply regex vocab_mask
|
||||||
if self.vocab_mask is not None:
|
if self.vocab_mask is not None:
|
||||||
|
|||||||
@@ -97,6 +97,10 @@ def is_flashinfer_available():
|
|||||||
return torch.cuda.is_available() and torch.version.cuda
|
return torch.cuda.is_available() and torch.version.cuda
|
||||||
|
|
||||||
|
|
||||||
|
def is_cuda_available():
|
||||||
|
return torch.cuda.is_available() and torch.version.cuda
|
||||||
|
|
||||||
|
|
||||||
def is_ipv6(address):
|
def is_ipv6(address):
|
||||||
try:
|
try:
|
||||||
ipaddress.IPv6Address(address)
|
ipaddress.IPv6Address(address)
|
||||||
|
|||||||
159
sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py
Normal file
159
sgl-kernel/benchmark/benchmark_sampling_scaling_penalties.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from sgl_kernel import sampling_scaling_penalties
|
||||||
|
|
||||||
|
|
||||||
|
def sampling_scaling_penalties_naive(logits, scaling_penalties):
|
||||||
|
return torch.where(
|
||||||
|
logits > 0, logits / scaling_penalties, logits * scaling_penalties
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sampling_scaling_penalties_kernel(logits, scaling_penalties):
|
||||||
|
return sampling_scaling_penalties(logits, scaling_penalties)
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory(func, _iter):
|
||||||
|
total_mem = []
|
||||||
|
|
||||||
|
for _ in range(_iter):
|
||||||
|
torch.cuda.memory.reset_peak_memory_stats()
|
||||||
|
func()
|
||||||
|
mem = torch.cuda.max_memory_allocated() / (2**20)
|
||||||
|
total_mem.append(mem)
|
||||||
|
|
||||||
|
return sum(total_mem) / len(total_mem)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(batch_size, vocab_size):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
|
||||||
|
scaling_penalties = (
|
||||||
|
torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
output_naive = sampling_scaling_penalties_naive(
|
||||||
|
logits.clone(), scaling_penalties.clone()
|
||||||
|
)
|
||||||
|
output_kernel = sampling_scaling_penalties_kernel(
|
||||||
|
logits.clone(), scaling_penalties.clone()
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Naive output={output_naive}")
|
||||||
|
print(f"Kernel output={output_kernel}")
|
||||||
|
|
||||||
|
if torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2):
|
||||||
|
print("✅ Both implementations match")
|
||||||
|
else:
|
||||||
|
print("❌ Implementations differ")
|
||||||
|
|
||||||
|
|
||||||
|
batch_size_range = [2**i for i in range(0, 12)]
|
||||||
|
vocab_size_range = [2**i for i in range(10, 17)]
|
||||||
|
configs = list(itertools.product(batch_size_range, vocab_size_range))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "vocab_size"],
|
||||||
|
x_vals=[list(_) for _ in configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["naive", "kernel"],
|
||||||
|
line_names=["PyTorch Naive", "SGL Kernel"],
|
||||||
|
styles=[("blue", "-"), ("red", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name="sampling-scaling-penalties-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, vocab_size, provider):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
|
||||||
|
scaling_penalties = (
|
||||||
|
torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "naive":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: sampling_scaling_penalties_naive(
|
||||||
|
logits.clone(),
|
||||||
|
scaling_penalties.clone(),
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: sampling_scaling_penalties_kernel(
|
||||||
|
logits.clone(),
|
||||||
|
scaling_penalties.clone(),
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "vocab_size"],
|
||||||
|
x_vals=[list(_) for _ in configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["naive", "kernel"],
|
||||||
|
line_names=["PyTorch Naive", "SGL Kernel"],
|
||||||
|
styles=[("blue", "-"), ("red", "-")],
|
||||||
|
ylabel="GPU memory usage (MB)",
|
||||||
|
plot_name="sampling-scaling-penalties-memory",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark_memory(batch_size, vocab_size, provider):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Running memory benchmark with batch_size={batch_size}, vocab_size={vocab_size}, provider={provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_kernel():
|
||||||
|
logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
|
||||||
|
scaling_penalties = (
|
||||||
|
torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider == "naive":
|
||||||
|
return sampling_scaling_penalties_naive(logits, scaling_penalties)
|
||||||
|
else:
|
||||||
|
return sampling_scaling_penalties_kernel(logits, scaling_penalties)
|
||||||
|
|
||||||
|
mem = test_memory(run_kernel, _iter=10)
|
||||||
|
return mem
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/sampling_scaling_penalties/",
|
||||||
|
help="Path to save sampling_scaling_penalties benchmark results",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Run correctness test
|
||||||
|
calculate_diff(batch_size=4, vocab_size=4096)
|
||||||
|
|
||||||
|
# Run performance benchmark
|
||||||
|
benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
|
|
||||||
|
# Run memory benchmark
|
||||||
|
benchmark_memory.run(print_data=True, save_path=args.save_path)
|
||||||
@@ -3,38 +3,65 @@ from sgl_kernel import moe_align_block_size
|
|||||||
|
|
||||||
|
|
||||||
def test_moe_align_block_size():
|
def test_moe_align_block_size():
|
||||||
|
# For DeepSeek V3, we have 256 experts
|
||||||
num_experts = 256
|
num_experts = 256
|
||||||
block_size = 128
|
|
||||||
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
# Test different combinations of block_size, num_tokens and topk
|
||||||
sorted_ids = torch.empty(
|
for block_size in [32, 64, 128, 256]:
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
print(f"\nTesting block_size={block_size}")
|
||||||
)
|
for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
for topk in [1, 2, 4, 8, 16, 32, 64]:
|
||||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
print(
|
||||||
expert_ids = torch.empty(
|
f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
)
|
||||||
)
|
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
|
||||||
|
|
||||||
token_cnts_buffer = torch.empty(
|
# Create random topk_ids with shape [num_tokens, topk]
|
||||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
topk_ids = torch.randint(
|
||||||
)
|
0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
|
||||||
cumsum_buffer = torch.empty(
|
)
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
|
||||||
)
|
|
||||||
|
|
||||||
moe_align_block_size(
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (
|
||||||
topk_ids,
|
block_size - 1
|
||||||
num_experts,
|
)
|
||||||
block_size,
|
sorted_ids = torch.empty(
|
||||||
sorted_ids,
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
expert_ids,
|
)
|
||||||
num_tokens_post_pad,
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
token_cnts_buffer,
|
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||||
cumsum_buffer,
|
expert_ids = torch.empty(
|
||||||
)
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
num_tokens_post_pad = torch.empty(
|
||||||
|
(1), dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
|
token_cnts_buffer = torch.empty(
|
||||||
|
(num_experts + 1) * num_experts,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=topk_ids.device,
|
||||||
|
)
|
||||||
|
cumsum_buffer = torch.empty(
|
||||||
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
moe_align_block_size(
|
||||||
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_pad,
|
||||||
|
token_cnts_buffer,
|
||||||
|
cumsum_buffer,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}"
|
||||||
|
)
|
||||||
|
print(f"Error message: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
test_moe_align_block_size()
|
if __name__ == "__main__":
|
||||||
|
test_moe_align_block_size()
|
||||||
|
|||||||
Reference in New Issue
Block a user