Files
sglang/python/sglang/srt/constrained/triton_ops/bitmask_ops.py
2025-04-20 13:31:25 -07:00

142 lines
4.5 KiB
Python

# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
from typing import List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.srt.utils import get_device_core_count
@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Overwrite logits with -inf wherever the packed token bitmask holds a 0 bit.

    The bitmask is bit-packed: within a row, bit ``j`` of word ``i`` describes
    token ``32 * i + j``. A 1 bit leaves the logit untouched; a 0 bit causes
    ``-inf`` to be stored at that position. Work is split into
    ``num_rows * cdiv(vocab_size, BLOCK_SIZE)`` items; each program starts at
    its own program id and advances by ``NUM_SMS`` until the items run out.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits that are modified in place.
    bitmask_ptr : tl.tensor
        Pointer to the packed bitmask words.
    indices_ptr : Optional[tl.tensor]
        When not None, entry ``r`` of this table is loaded and used as the row
        number for both logits and bitmask. When None, ``r`` itself is used.
    num_rows : int
        Number of rows to process (number of index entries when
        ``indices_ptr`` is given).
    vocab_size : int
        Number of leading columns per row that may be written; columns at or
        beyond this value are never stored to.
    logits_strides : int
        Element distance between consecutive logits rows.
    bitmask_strides : int
        Word distance between consecutive bitmask rows. Also used as the
        upper bound on the word index when loading the bitmask.
    NUM_SMS : int
        Step of the per-program work loop.
    BLOCK_SIZE : int
        Number of logits handled per work item; ``BLOCK_SIZE // 32`` bitmask
        words are loaded for each item.
    """
    program_idx = tl.program_id(0)
    blocks_per_row = tl.cdiv(vocab_size, BLOCK_SIZE)
    total_work = num_rows * blocks_per_row

    for task in tl.range(program_idx, total_work, NUM_SMS):
        # Split the flat work index into (row, first column of the block).
        row = task // blocks_per_row
        col_start = (task % blocks_per_row) * BLOCK_SIZE
        # Resolve the actual row: direct, or via the index table when present.
        target_row = row if indices_ptr is None else tl.load(indices_ptr + row)

        # Column positions of the logits and of the bitmask words for this block.
        token_cols = col_start + tl.arange(0, BLOCK_SIZE)
        word_cols = col_start // 32 + tl.arange(0, BLOCK_SIZE // 32)
        in_vocab = token_cols < vocab_size
        word_in_row = word_cols < bitmask_strides

        word_ptrs = bitmask_ptr + target_row * bitmask_strides + word_cols
        words = tl.load(word_ptrs, word_in_row)

        # Expand every word into 32 flags that are True where the bit is 0,
        # then flatten so the flags line up one-to-one with token_cols.
        bit_positions = tl.arange(0, 32)[None, :]
        rejected = ((words[:, None] >> bit_positions) & 1) == 0
        rejected = rejected.reshape(BLOCK_SIZE)

        logit_ptrs = logits_ptr + target_row * logits_strides + token_cols
        tl.store(logit_ptrs, -float("inf"), in_vocab & rejected)
def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    """Validate the inputs and launch the bitmask kernel on ``logits`` in place.

    Parameters
    ----------
    logits : torch.Tensor
        1-D or 2-D logits; a 1-D tensor is treated as a single row. The kernel
        is handed ``logits.shape[-1]`` as the row stride, so rows are addressed
        as densely packed.
    bitmask : torch.Tensor
        int32 packed bitmask, 1-D or 2-D, 32 tokens per element. Its width may
        be at most ``ceil(logits_width / 32)``. The kernel is handed
        ``bitmask.shape[-1]`` as the row stride.
    indices : Optional[Union[List[int], torch.Tensor]]
        Row numbers to process. When given, it is converted to a contiguous
        int32 tensor on ``logits.device`` and its length is the number of rows
        processed. When None, every row is processed and ``logits`` and
        ``bitmask`` must have the same number of rows.

    Returns
    -------
    None
        ``logits`` is modified in place by the kernel.

    Raises
    ------
    AssertionError
        If ``bitmask`` is not int32, if the bitmask is wider than the logits
        allow, or (without ``indices``) if the batch sizes differ.
    """
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # Check input dtype
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Check input tensor shapes; 1-D inputs are viewed as a single row.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    # Only columns covered by both the logits and the bitmask are processed.
    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    if indices is not None:
        # as_tensor reuses an existing tensor when dtype/device already match
        # (torch.tensor would copy and warn); the kernel indexes the table with
        # a unit stride, hence the explicit contiguous().
        indices = torch.as_tensor(
            indices, dtype=torch.int32, device=logits.device
        ).contiguous()
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    # Nothing to mask. Returning here also keeps the fallback path below from
    # building an empty grid and a zero loop step.
    if num_rows == 0 or vocab_size == 0:
        return

    if NUM_SMS > 0:
        # One program per reported core; each program loops over work items.
        grid = (NUM_SMS,)
    else:
        # No usable core count: launch one program per work item and make the
        # loop step at least as large as the grid so each program runs once.
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )