diff --git a/python/sglang/srt/constrained/outlines_jump_forward.py b/python/sglang/srt/constrained/outlines_jump_forward.py
index 5e32c40d7..cfc65f75f 100644
--- a/python/sglang/srt/constrained/outlines_jump_forward.py
+++ b/python/sglang/srt/constrained/outlines_jump_forward.py
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 import dataclasses
 import logging
 from collections import defaultdict
+from typing import Optional
 
 import interegular
 from interegular import InvalidSyntax
-from outlines.caching import cache as disk_cache
+from outlines.caching import cache
+
+from sglang.srt.utils import get_bool_env_var
 
 try:
     # outlines >= 0.1.0
@@ -34,6 +37,9 @@ except ImportError:
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
+# Env var is set in sglang.srt.server_args.ServerArgs.__post_init__
+DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
+
 logger = logging.getLogger(__name__)
 
 
@@ -45,6 +51,13 @@ class JumpEdge:
     byte_next_state: int = None
 
 
+def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
+    if not DISABLE_DISK_CACHE:
+        return cache(expire, typed, ignore)
+    else:
+        return lambda fn: fn  # no-op decorator when the disk cache is disabled
+
+
 @disk_cache()
 def init_state_to_jump_forward(regex_string):
     try:
diff --git a/python/sglang/srt/constrained/triton_ops/bitmask_ops.py b/python/sglang/srt/constrained/triton_ops/bitmask_ops.py
new file mode 100644
index 000000000..9a195c006
--- /dev/null
+++ b/python/sglang/srt/constrained/triton_ops/bitmask_ops.py
@@ -0,0 +1,141 @@
+# Adapted from
+# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+
+from typing import List, Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.utils import get_device_core_count
+
+
+@triton.jit
+def apply_token_bitmask_inplace_kernel(
+    logits_ptr,
+    bitmask_ptr,
+    indices_ptr,
+    num_rows,
+    vocab_size,
+    logits_strides,
+    bitmask_strides,
+    NUM_SMS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Apply a bitmask to logits in-place using Triton. The bitmask is a bit-packed 01 tensor,
+    where 0 means the token is masked and 1 means the token is not masked. After applying the
+    bitmask, the masked logits are set to -inf.
+
+    Parameters
+    ----------
+    logits_ptr : tl.tensor
+        Pointer to the logits tensor to apply the bitmask to.
+
+    bitmask_ptr : tl.tensor
+        Pointer to the bitmask tensor to apply.
+
+    indices_ptr : Optional[tl.tensor]
+        Optional pointer to an indices tensor specifying which rows to apply the mask to.
+
+    num_rows : int
+        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
+
+    vocab_size : int
+        Size of the vocabulary dimension. If the logits are not padded along the vocab dimension,
+        this is the same as the logits' second dimension. Otherwise, it is the actual vocabulary size.
+
+    logits_strides : int
+        Stride between rows in the logits tensor.
+
+    bitmask_strides : int
+        Stride between rows in the bitmask tensor.
+
+    NUM_SMS : int
+        Number of streaming multiprocessors to use.
+
+    BLOCK_SIZE : int
+        Size of processing blocks.
+ """ + + pid = tl.program_id(0) + num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE) + for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS): + row_id = work_id // num_blocks + block_offset = (work_id % num_blocks) * BLOCK_SIZE + batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id) + offsets = block_offset + tl.arange(0, BLOCK_SIZE) + bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32) + vocab_mask = offsets < vocab_size + packed_bitmask_mask = bitmask_offsets < bitmask_strides + packed_bitmask = tl.load( + bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, + packed_bitmask_mask, + ) + bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0 + bitmask = bitmask.reshape(BLOCK_SIZE) + + tl.store( + logits_ptr + batch_id * logits_strides + offsets, + -float("inf"), + vocab_mask & bitmask, + ) + + +def apply_token_bitmask_inplace_triton( + logits: torch.Tensor, + bitmask: torch.Tensor, + indices: Optional[Union[List[int], torch.Tensor]] = None, +): + NUM_SMS = get_device_core_count() + BLOCK_SIZE = 4096 + BITS_PER_BLOCK = 32 + + # Check input dtype + assert bitmask.dtype == torch.int32, "bitmask must be of type int32" + + # Check input tensor shapes. + logits_shape = logits.shape + bitmask_shape = bitmask.shape + if logits.ndim == 1: + logits_shape = (1, logits_shape[0]) + if bitmask.ndim == 1: + bitmask_shape = (1, bitmask_shape[0]) + + required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK + assert required_bitmask_width >= bitmask_shape[1], ( + f"Bitmask width too large: allow at most {required_bitmask_width} int32s for " + f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}" + ) + + vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK) + + num_rows = None + if isinstance(indices, list) or isinstance(indices, torch.Tensor): + indices = torch.tensor(indices, dtype=torch.int32, device=logits.device) + num_rows = indices.shape[0] + else: + assert ( + logits_shape[0] == bitmask_shape[0] + ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}" + num_rows = logits_shape[0] + + if NUM_SMS > 0: + grid = (NUM_SMS,) + else: + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + grid = (num_rows * num_blocks,) + NUM_SMS = triton.next_power_of_2(grid[0]) + + apply_token_bitmask_inplace_kernel[grid]( + logits, + bitmask, + indices, + num_rows, + vocab_size, + logits_shape[1], + bitmask_shape[1], + NUM_SMS, + BLOCK_SIZE, + num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()), + num_stages=3, + ) diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 5aef05f9b..5e15bc744 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -25,13 +25,16 @@ from xgrammar import ( StructuralTagItem, TokenizerInfo, allocate_token_bitmask, - apply_token_bitmask_inplace, ) from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarBackend, BaseGrammarObject, ) +from sglang.srt.constrained.triton_ops.bitmask_ops import ( + apply_token_bitmask_inplace_triton, +) +from sglang.srt.utils import get_bool_env_var logger = logging.getLogger(__name__) @@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject): self.override_stop_tokens = override_stop_tokens self.finished = False + # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the + # class init site to avoid re-initializing CUDA in forked 
+        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+
+        self.use_token_bitmask_triton = get_bool_env_var(
+            "SGLANG_TOKEN_BITMASK_TRITON", "false"
+        )
+        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
+            "cuda", None
+        )
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
 
@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
     def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
         return vocab_mask.to(device, non_blocking=True)
 
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        apply_token_bitmask_inplace(logits, vocab_mask)
+    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if (
+            not self.use_token_bitmask_triton
+            and logits.device.type == "cuda"
+            and self.apply_vocab_mask_cuda
+        ):
+            return self.apply_vocab_mask_cuda(logits, vocab_mask)
+        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            return self.apply_vocab_mask_cpu(logits, vocab_mask)
+        apply_token_bitmask_inplace_triton(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 97f5888ae..6a48a60b7 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -137,11 +137,6 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
 
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
         # Global vars
         global_server_args_dict.update(
             {
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index bb8887168..73d8db5a1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -392,6 +392,10 @@ class ServerArgs:
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
         )
+        # Set the env var before the grammar backends are initialized
+        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
+            "1" if self.disable_outlines_disk_cache else "0"
+        )
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
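Usage note (not part of the patch): a minimal sketch of the bit-packed mask layout that apply_token_bitmask_inplace_triton consumes, i.e. bit i of int32 word w covers token w * 32 + i, with 1 = allowed and 0 = masked to -inf. The batch size, vocab size, and allowed token ids below are made-up illustrations, and a CUDA device with Triton and sglang installed is assumed.

import torch

from sglang.srt.constrained.triton_ops.bitmask_ops import (
    apply_token_bitmask_inplace_triton,
)

# Illustrative sizes only; any batch size / vocab size works.
batch_size, vocab_size = 2, 100
logits = torch.randn(batch_size, vocab_size, device="cuda")

# One int32 packs 32 tokens; start with every token masked (all bits 0).
num_words = (vocab_size + 31) // 32
bitmask = torch.zeros(batch_size, num_words, dtype=torch.int32, device="cuda")

# Hypothetical allowed token ids: set their bits to 1 in every row.
for t in (3, 7, 42):
    bitmask[:, t // 32] |= 1 << (t % 32)

apply_token_bitmask_inplace_triton(logits, bitmask)

# Masked positions are now -inf; allowed positions keep their original logits.
assert torch.isinf(logits[0, 0]).item() and not torch.isinf(logits[0, 3]).item()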