Adds a Triton implementation of `act_quant` for the NSA indexer (used on
non-HIP, non-NPU devices) together with a benchmark/accuracy test against the
TileLang implementation. The two newly added files carry no `index` line;
`git apply` does not need one for plain-text patches.

diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index b37e5ffac..798e1c0a8 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -505,8 +505,10 @@ class Indexer(CustomOp):
         forward_batch: ForwardBatch,
         layer_id: int,
     ) -> Optional[torch.Tensor]:
-        if not is_npu():
+        if is_hip():
             from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+        elif not is_npu():
+            from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
 
         if TYPE_CHECKING:
             assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
diff --git a/python/sglang/srt/layers/attention/nsa/triton_kernel.py b/python/sglang/srt/layers/attention/nsa/triton_kernel.py
new file mode 100644
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/triton_kernel.py
@@ -0,0 +1,136 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _act_quant_kernel(
+    X_ptr,
+    Y_ptr,
+    S_ptr,
+    M,
+    N,
+    group_size: tl.constexpr,
+    round_scale: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """
+    Triton kernel for group-wise FP8 activation quantization.
+
+    Program (pid_m, pid_n) handles rows [pid_m * BLOCK_M, (pid_m + 1) * BLOCK_M)
+    and the BLOCK_N columns starting at pid_n * group_size of a row-major
+    (M, N) input. It writes the scaled, clamped values to Y_ptr and one scale
+    per (row, group) to S_ptr, which is addressed as a row-major
+    (M, N // group_size) array.
+
+    The host wrapper launches this with BLOCK_N == group_size; the abs-max
+    reduction below runs over all BLOCK_N loaded columns.
+    """
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    # Clamp range of the quantized values and its reciprocal.
+    fp8_min = -448.0
+    fp8_max = 448.0
+    fp8_max_inv = 1.0 / fp8_max
+
+    row_start = pid_m * BLOCK_M
+    col_start = pid_n * group_size
+
+    # Widen row indices to int64 so that `rows * N` cannot wrap around when
+    # the input holds 2**31 or more elements.
+    rows = (row_start + tl.arange(0, BLOCK_M)).to(tl.int64)
+    cols = col_start + tl.arange(0, BLOCK_N)
+
+    row_mask = rows < M
+    col_mask = cols < N
+    mask = row_mask[:, None] & col_mask[None, :]
+
+    # Masked-out elements load as 0.0, so they never raise the abs-max.
+    x_ptrs = X_ptr + rows[:, None] * N + cols[None, :]
+    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)
+
+    # Per-row absolute maximum over the loaded columns, shape (BLOCK_M,).
+    amax = tl.max(tl.abs(x), axis=1)
+
+    # Floor amax so that an all-zero group does not yield a zero scale.
+    amax = tl.maximum(amax, 1e-4)
+
+    if round_scale:
+        # Round the scale up to a power of two via log2 + ceil + exp2. This is
+        # a simplified stand-in for exact bit manipulation, which is harder to
+        # express in Triton.
+        scale = tl.exp2(tl.ceil(tl.log2(amax * fp8_max_inv)))
+    else:
+        scale = amax * fp8_max_inv
+
+    # Quantize: y = clamp(x / scale, fp8_min, fp8_max)
+    y = x / scale[:, None]
+    y = tl.minimum(tl.maximum(y, fp8_min), fp8_max)
+
+    y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :]
+    tl.store(y_ptrs, y, mask=mask)
+
+    # One scale per (row, group); only the row bound needs masking.
+    s_ptrs = S_ptr + rows * (N // group_size) + pid_n
+    tl.store(s_ptrs, scale, mask=row_mask)
+
+
+def act_quant(
+    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes the input tensor `x` using block-wise quantization with Triton.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be contiguous
+            and its last dimension size must be divisible by `block_size`.
+        block_size (int, optional): Number of consecutive elements along the
+            last dimension that share one scale. Default is 128.
+        scale_fmt (Optional[str], optional): The format of the scale. Only
+            whether it is None is inspected: any non-None value makes the
+            kernel round each scale up to a power of two. Default is None.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - The quantized tensor with dtype `torch.float8_e4m3fn` and the
+              same shape as `x`.
+            - A tensor of scaling factors with dtype `torch.float32` and
+              shape `(*x.shape[:-1], x.size(-1) // block_size)`.
+    """
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    assert (
+        x.size(-1) % block_size == 0
+    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
+
+    # Flatten all dims except last
+    N = x.size(-1)
+    num_groups = N // block_size
+    x_flat = x.view(-1, N)
+    M = x_flat.size(0)
+
+    # Allocate output tensors
+    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    s = x.new_empty(*x.size()[:-1], num_groups, dtype=torch.float32)
+
+    # One program per (BLOCK_M rows, one quantization group).
+    BLOCK_M = 32
+    grid = (triton.cdiv(M, BLOCK_M), num_groups)
+    round_scale = scale_fmt is not None
+
+    _act_quant_kernel[grid](
+        x_flat,
+        y.view(-1, N),
+        s.view(-1, num_groups),
+        M,
+        N,
+        group_size=block_size,
+        round_scale=round_scale,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=block_size,
+        num_stages=0 if round_scale else 2,
+    )
+
+    return y, s
diff --git a/test/srt/layers/attention/nsa/test_act_quant_triton.py b/test/srt/layers/attention/nsa/test_act_quant_triton.py
new file mode 100644
--- /dev/null
+++ b/test/srt/layers/attention/nsa/test_act_quant_triton.py
@@ -0,0 +1,287 @@
+"""
+Unit tests comparing TileLang and Triton implementations of activation quantization.
+Tests both accuracy and performance.
+"""
+
+import time
+from typing import List, Sequence, Tuple
+
+import pytest
+import torch
+
+from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+from sglang.srt.layers.attention.nsa.triton_kernel import act_quant as act_quant_triton
+
+
+def benchmark_kernel(
+    fn,
+    x: torch.Tensor,
+    block_size: int,
+    scale_fmt,
+    warmup: int = 10,
+    repeat: int = 100,
+    use_cuda_graph: bool = True,
+) -> Tuple[float, torch.Tensor, torch.Tensor]:
+    """
+    Benchmark a kernel function.
+
+    Args:
+        fn: Function to benchmark, called as
+            `fn(x, block_size=block_size, scale_fmt=scale_fmt)`
+        x: Input tensor
+        block_size: Block size for quantization
+        scale_fmt: Scale format
+        warmup: Number of warmup iterations
+        repeat: Number of repeat iterations
+        use_cuda_graph: Whether to use CUDA graphs for more accurate timing;
+            has no effect when `x` is not a CUDA tensor
+
+    Returns:
+        Tuple of (avg_time_ms, quantized_output, scales)
+    """
+    # Warmup
+    for _ in range(warmup):
+        y, s = fn(x, block_size=block_size, scale_fmt=scale_fmt)
+
+    if not x.is_cuda or not use_cuda_graph:
+        # Fallback to regular timing
+        if x.is_cuda:
+            torch.cuda.synchronize()
+
+        start = time.perf_counter()
+        for _ in range(repeat):
+            y, s = fn(x, block_size=block_size, scale_fmt=scale_fmt)
+
+        if x.is_cuda:
+            torch.cuda.synchronize()
+
+        end = time.perf_counter()
+        avg_time_ms = (end - start) / repeat * 1000
+
+        return avg_time_ms, y, s
+
+    # Use CUDA graph for more accurate timing
+    torch.cuda.synchronize()
+
+    # Capture CUDA graph; the outputs returned during capture are the ones
+    # handed back to the caller after the timed replays.
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        y_cap, s_cap = fn(x, block_size=block_size, scale_fmt=scale_fmt)
+
+    # Warmup with graph
+    for _ in range(warmup):
+        graph.replay()
+
+    torch.cuda.synchronize()
+
+    # Timing with CUDA graph
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    start_event.record()
+    for _ in range(repeat):
+        graph.replay()
+    end_event.record()
+
+    torch.cuda.synchronize()
+
+    avg_time_ms = start_event.elapsed_time(end_event) / repeat
+
+    return avg_time_ms, y_cap, s_cap
+
+
+def check_accuracy(
+    y_ref: torch.Tensor,
+    s_ref: torch.Tensor,
+    y_test: torch.Tensor,
+    s_test: torch.Tensor,
+    rtol: float = 1e-2,
+    atol: float = 1e-2,
+) -> Tuple[bool, dict]:
+    """
+    Check accuracy between reference and test outputs.
+
+    Args:
+        y_ref: Reference quantized output
+        s_ref: Reference scales
+        y_test: Test quantized output
+        s_test: Test scales
+        rtol: Relative tolerance
+        atol: Absolute tolerance
+
+    Returns:
+        Tuple of (passed, metrics_dict)
+    """
+    # Convert FP8 to float for comparison
+    y_ref_float = y_ref.float()
+    y_test_float = y_test.float()
+
+    # Compute differences
+    y_diff = torch.abs(y_ref_float - y_test_float)
+    s_diff = torch.abs(s_ref - s_test)
+
+    # Compute metrics
+    y_max_diff = y_diff.max().item()
+    y_mean_diff = y_diff.mean().item()
+    s_max_diff = s_diff.max().item()
+    s_mean_diff = s_diff.mean().item()
+
+    # Check relative and absolute tolerance
+    y_close = torch.allclose(y_ref_float, y_test_float, rtol=rtol, atol=atol)
+    s_close = torch.allclose(s_ref, s_test, rtol=rtol, atol=atol)
+
+    # Compute percentage of matching elements
+    y_match_pct = (y_ref_float == y_test_float).float().mean().item() * 100
+
+    metrics = {
+        "y_max_diff": y_max_diff,
+        "y_mean_diff": y_mean_diff,
+        "y_match_pct": y_match_pct,
+        "s_max_diff": s_max_diff,
+        "s_mean_diff": s_mean_diff,
+        "y_close": y_close,
+        "s_close": s_close,
+    }
+
+    passed = y_close and s_close
+
+    return passed, metrics
+
+
+def _run_benchmark_table(
+    title: str,
+    shapes: Sequence[Tuple[int, ...]],
+    block_size: int,
+    scale_fmt,
+    use_cuda_graph: bool,
+) -> List[str]:
+    """
+    Benchmark both implementations on every shape and print a result table.
+
+    Args:
+        title: Heading printed above the table
+        shapes: Input shapes to benchmark
+        block_size: Block size for quantization
+        scale_fmt: Scale format
+        use_cuda_graph: Forwarded to `benchmark_kernel`
+
+    Returns:
+        One message per shape that raised or failed the accuracy check;
+        empty when every shape passed
+    """
+    device = torch.device("cuda")
+    dtype = torch.bfloat16
+    failures: List[str] = []
+
+    print("\n" + "=" * 100)
+    print(title)
+    print("=" * 100)
+    print(
+        f"{'Shape':<20} {'TileLang (ms)':<15} {'Triton (ms)':<15} {'Speedup':<10} {'Status'}"
+    )
+    print("-" * 100)
+
+    for shape in shapes:
+        torch.manual_seed(42)
+        x = torch.randn(shape, dtype=dtype, device=device)
+
+        try:
+            time_tilelang, y_ref, s_ref = benchmark_kernel(
+                act_quant,
+                x,
+                block_size,
+                scale_fmt,
+                warmup=5,
+                repeat=50,
+                use_cuda_graph=use_cuda_graph,
+            )
+            time_triton, y_triton, s_triton = benchmark_kernel(
+                act_quant_triton,
+                x,
+                block_size,
+                scale_fmt,
+                warmup=5,
+                repeat=50,
+                use_cuda_graph=use_cuda_graph,
+            )
+
+            # Check accuracy
+            passed, metrics = check_accuracy(y_ref, s_ref, y_triton, s_triton)
+        except Exception as e:
+            # Keep going so the table covers every shape; the error is
+            # returned to the caller instead of being dropped.
+            print(f"{str(shape):<20} ERROR: {str(e)}")
+            failures.append(f"{shape} (use_cuda_graph={use_cuda_graph}): {e!r}")
+            continue
+
+        speedup = time_tilelang / time_triton if time_triton > 0 else 0
+        status = "✓ PASS" if passed else "✗ FAIL"
+
+        print(
+            f"{str(shape):<20} {time_tilelang:<15.4f} {time_triton:<15.4f} "
+            f"{speedup:<10.2f} {status}"
+        )
+        if not passed:
+            failures.append(f"{shape} (use_cuda_graph={use_cuda_graph}): {metrics}")
+
+    print("=" * 100)
+    return failures
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_act_quant_comprehensive_benchmark(scale_fmt=None):
+    """
+    Benchmark Triton against TileLang across sizes and check that they agree.
+
+    Every shape is timed with and without CUDA graphs. The test fails after
+    both tables have been printed if any run raised or any Triton output fell
+    outside the `check_accuracy` tolerances.
+    """
+    block_size = 128
+
+    shapes = [
+        (128, 512),
+        (256, 1024),
+        (512, 2048),
+        (1024, 4096),
+        (2048, 8192),
+        (4096, 16384),
+    ]
+
+    failures = _run_benchmark_table(
+        "Comprehensive Performance Benchmark with CUDA Graphs",
+        shapes,
+        block_size,
+        scale_fmt,
+        use_cuda_graph=True,
+    )
+
+    # Also run without CUDA graphs for comparison
+    failures += _run_benchmark_table(
+        "Performance Benchmark WITHOUT CUDA Graphs (for comparison)",
+        shapes,
+        block_size,
+        scale_fmt,
+        use_cuda_graph=False,
+    )
+
+    assert not failures, "Triton act_quant does not match TileLang:\n" + "\n".join(
+        failures
+    )
+
+
+if __name__ == "__main__":
+    # Run comprehensive benchmark
+    if torch.cuda.is_available():
+        print("\n" + "=" * 80)
+        print("Running Comprehensive Benchmark with scale_fmt=None")
+        print("=" * 80)
+        test_act_quant_comprehensive_benchmark(scale_fmt=None)
+
+        print("\n" + "=" * 80)
+        print("Running Comprehensive Benchmark with scale_fmt!=None")
+        print("=" * 80)
+        test_act_quant_comprehensive_benchmark(scale_fmt="any")
+    else:
+        print("CUDA not available. Skipping tests.")