[DPSKv3.2] Rewrite nsa tilelang act_quant kernel to triton (#11450)

Author: Binyao Jiang
Date: 2025-10-10 23:13:46 -07:00
Committed by: GitHub
Parent: c80a96dae9
Commit: 451d15c44b
3 changed files with 420 additions and 1 deletion


@@ -505,8 +505,10 @@ class Indexer(CustomOp):
         forward_batch: ForwardBatch,
         layer_id: int,
     ) -> Optional[torch.Tensor]:
-        if not is_npu():
+        if is_hip():
             from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+        elif not is_npu():
+            from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
         if TYPE_CHECKING:
             assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

@@ -0,0 +1,136 @@
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


# Triton implementation
@triton.jit
def _act_quant_kernel(
    X_ptr,
    Y_ptr,
    S_ptr,
    M,
    N,
    group_size: tl.constexpr,
    round_scale: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Triton kernel for activation quantization.

    Each block processes BLOCK_M rows and one group_size-wide column group,
    producing one float32 scale per (row, group).
    """
    # Get block IDs
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # FP8 constants: +/-448 is the largest finite value of float8_e4m3fn
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1.0 / fp8_max

    # Calculate row and column offsets
    row_start = pid_m * BLOCK_M
    col_start = pid_n * group_size

    # Create offset arrays (the launcher sets BLOCK_N == group_size)
    rows = row_start + tl.arange(0, BLOCK_M)
    cols = col_start + tl.arange(0, BLOCK_N)

    # Mask for valid rows and columns
    row_mask = rows < M
    col_mask = cols < N
    mask = row_mask[:, None] & col_mask[None, :]

    # Load input data
    x_ptrs = X_ptr + rows[:, None] * N + cols[None, :]
    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)

    # Compute the absolute max over each row's group_size columns
    x_abs = tl.abs(x)
    amax = tl.max(x_abs, axis=1)  # Shape: (BLOCK_M,)

    # Clamp amax to avoid division by zero
    amax = tl.maximum(amax, 1e-4)

    # Compute scale
    if round_scale:
        # Round the scale up to a power of two. The tilelang kernel does this
        # with bit manipulation, which is awkward in Triton, so
        # log2 + ceil + exp2 is used as a close approximation.
        # Example: amax = 1.0 -> log2(1/448) ~= -8.81 -> ceil = -8
        # -> scale = 2**-8, so max |x / scale| = 256 <= 448.
        log_val = tl.log2(amax * fp8_max_inv)
        log_ceil = tl.ceil(log_val)
        scale = tl.exp2(log_ceil)
    else:
        scale = amax * fp8_max_inv

    # Quantize: y = clamp(x / scale, fp8_min, fp8_max)
    scale_broadcast = scale[:, None]
    y = x / scale_broadcast
    y = tl.minimum(tl.maximum(y, fp8_min), fp8_max)

    # Store quantized output, cast to the fp8 element type of Y_ptr
    y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :]
    tl.store(y_ptrs, y.to(Y_ptr.dtype.element_ty), mask=mask)

    # Store one scale per (row, group); pid_n indexes the column group
    s_cols = pid_n
    s_ptrs = S_ptr + rows * (N // group_size) + s_cols
    s_mask = row_mask
    tl.store(s_ptrs, scale, mask=s_mask)


def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization with Triton.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous,
            and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for
            quantization. Default is 128.
        scale_fmt (Optional[str], optional): The format of the scale. If not
            None, scales are rounded up to powers of two. Default is None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert (
        x.size(-1) % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"

    # Flatten all dims except the last
    N = x.size(-1)
    x_flat = x.view(-1, N)
    M = x_flat.size(0)

    # Allocate output tensors
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    y_flat = y.view(-1, N)
    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
    s_flat = s.view(-1, N // block_size)

    # Launch kernel: one program per (BLOCK_M rows, block_size-wide group)
    BLOCK_M = 32
    BLOCK_N = block_size
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size))
    round_scale = scale_fmt is not None
    _act_quant_kernel[grid](
        x_flat,
        y_flat,
        s_flat,
        M,
        N,
        group_size=block_size,
        round_scale=round_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_stages=0 if round_scale else 2,
    )
    return y, s
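
As a sanity check, here is a hedged round-trip sketch (assuming a CUDA device; the tolerances are illustrative, sized to float8_e4m3fn's roughly 2**-4 relative precision): dequantizing with the returned scales should approximately recover the input.

import torch

from sglang.srt.layers.attention.nsa.triton_kernel import act_quant

x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
y, s = act_quant(x, block_size=128)

# Reference dequantization: broadcast each per-group scale back over its
# 128-element group and multiply.
x_deq = (y.float().view(4, -1, 128) * s.unsqueeze(-1)).view_as(x)
torch.testing.assert_close(x_deq, x.float(), atol=0.05, rtol=0.1)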