[Fix] fix outlines and xgrammar (#4947)
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
import dataclasses
import logging
from collections import defaultdict
from typing import Optional

import interegular
from interegular import InvalidSyntax
from outlines.caching import cache as disk_cache
from outlines.caching import cache

from sglang.srt.utils import get_bool_env_var

try:
    # outlines >= 0.1.0
@@ -34,6 +37,9 @@ except ImportError:

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")

logger = logging.getLogger(__name__)


@@ -45,6 +51,13 @@ class JumpEdge:
    byte_next_state: int = None


def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
    if not DISABLE_DISK_CACHE:
        return cache(expire, typed, ignore)
    else:
        return lambda fn: fn


@disk_cache()
def init_state_to_jump_forward(regex_string):
    try:

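A note on the disk_cache shim above: @disk_cache() is applied at import time, so whatever the disabled branch returns becomes the decorated function, and it therefore has to hand back an identity decorator rather than None. Below is a minimal standalone sketch of the same pattern; the names disk_cache_shim and compile_regex are illustrative and not part of this commit.

import os

DISABLE_DISK_CACHE = os.environ.get(
    "SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true"
).lower() in ("true", "1")


def disk_cache_shim(expire=None, typed=False, ignore=()):
    # When caching is disabled, return the function unchanged; otherwise
    # delegate to a real on-disk cache decorator (outlines.caching.cache here).
    if DISABLE_DISK_CACHE:
        return lambda fn: fn
    from outlines.caching import cache
    return cache(expire, typed, ignore)


@disk_cache_shim()
def compile_regex(regex_string):
    return len(regex_string)  # placeholder for the expensive FSM compilation


print(compile_regex("a|b"))  # still callable whether or not the cache is enabled

When the env var is unset, get_bool_env_var falls back to "true" here, so the outlines disk cache stays off; setting it to "false" routes the decorated functions through outlines' on-disk cache.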
python/sglang/srt/constrained/triton_ops/bitmask_ops.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py

from typing import List, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.utils import get_device_core_count


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
"""Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
|
||||
where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
|
||||
the masked logits will be set to -inf.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits_ptr : tl.tensor
|
||||
Pointer to the logits tensor to apply the bitmask to.
|
||||
|
||||
bitmask_ptr : tl.tensor
|
||||
Pointer to the bitmask tensor to apply.
|
||||
|
||||
indices_ptr : Optional[tl.tensor]
|
||||
Optional pointer to indices tensor specifying which rows to apply the mask to.
|
||||
|
||||
num_rows : int
|
||||
Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
|
||||
|
||||
vocab_size : int
|
||||
Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
|
||||
same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
|
||||
|
||||
logits_strides : int
|
||||
Stride between rows in the logits tensor.
|
||||
|
||||
bitmask_strides : int
|
||||
Stride between rows in the bitmask tensor.
|
||||
|
||||
NUM_SMS : int
|
||||
Number of streaming multiprocessors to use.
|
||||
|
||||
BLOCK_SIZE : int
|
||||
Size of processing blocks.
|
||||
"""
|
||||
|
||||
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )


def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # Check input dtype
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Check input tensor shapes.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    num_rows = None
    if isinstance(indices, list) or isinstance(indices, torch.Tensor):
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    if NUM_SMS > 0:
        grid = (NUM_SMS,)
    else:
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
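For readers unfamiliar with the packing convention the new kernel relies on: each int32 word of the bitmask covers 32 consecutive token ids, bit i of word w belongs to token w * 32 + i, and a 0 bit means the token is masked. The sketch below packs a small allow-list by hand and reproduces the kernel's masking in plain PyTorch; it is illustrative only (the vocabulary size and token ids are made up), and in SGLang the mask comes from xgrammar's allocate_token_bitmask and the grammar matcher rather than being packed manually. The commented-out call shows how the Triton wrapper from this file would be invoked on a CUDA device.

import torch

VOCAB = 70                      # deliberately not a multiple of 32
BITS = 32

# Allow only tokens 3 and 64; everything else gets masked to -inf.
allowed = torch.zeros(VOCAB, dtype=torch.bool)
allowed[[3, 64]] = True

# Pack the allow-mask into int32 words, least-significant bit first:
# bit i of word w corresponds to token w * 32 + i, 1 = keep, 0 = mask.
words = (VOCAB + BITS - 1) // BITS
bitmask = torch.zeros(words, dtype=torch.int32)
for tok in range(VOCAB):
    if allowed[tok]:
        bitmask[tok // BITS] |= 1 << (tok % BITS)

logits = torch.randn(1, VOCAB)

# Pure-PyTorch reference of the kernel's unpacking logic:
# ((packed >> bit) & 1) == 0 marks the tokens whose logits are set to -inf.
bits = (bitmask.unsqueeze(-1) >> torch.arange(BITS)) & 1
masked = (bits == 0).reshape(-1)[:VOCAB]
ref = logits.clone()
ref[:, masked] = float("-inf")

# On a CUDA machine, the Triton path added above should produce the same result:
# from sglang.srt.constrained.triton_ops.bitmask_ops import apply_token_bitmask_inplace_triton
# apply_token_bitmask_inplace_triton(logits.cuda(), bitmask.unsqueeze(0).cuda())

print(ref[0, 3].item(), ref[0, 64].item())   # finite: allowed tokens keep their logits
print(torch.isinf(ref[0, 0]).item())         # True: everything else is masked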
@@ -25,13 +25,16 @@ from xgrammar import (
    StructuralTagItem,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)
from sglang.srt.constrained.triton_ops.bitmask_ops import (
    apply_token_bitmask_inplace_triton,
)
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)

@@ -55,6 +58,18 @@ class XGrammarGrammar(BaseGrammarObject):
        self.override_stop_tokens = override_stop_tokens
        self.finished = False

        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
        # class init site to avoid re-initializing CUDA in a forked subprocess.
        from xgrammar.kernels import apply_token_bitmask_inplace_kernels

        self.use_token_bitmask_triton = get_bool_env_var(
            "SGLANG_TOKEN_BITMASK_TRITON", "false"
        )
        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
            "cuda", None
        )
        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)

    def accept_token(self, token: int):
        assert self.matcher.accept_token(token)

@@ -97,9 +112,16 @@ class XGrammarGrammar(BaseGrammarObject):
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        apply_token_bitmask_inplace(logits, vocab_mask)
    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        if (
            not self.use_token_bitmask_triton
            and logits.device.type == "cuda"
            and self.apply_vocab_mask_cuda
        ):
            return self.apply_vocab_mask_cuda(logits, vocab_mask)
        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
            return self.apply_vocab_mask_cpu(logits, vocab_mask)
        apply_token_bitmask_inplace_triton(logits, vocab_mask)

    def copy(self):
        matcher = GrammarMatcher(

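The new apply_vocab_mask above replaces the old static method with an instance method that picks a backend per call: xgrammar's CUDA kernel unless SGLANG_TOKEN_BITMASK_TRITON forces the Triton path, then xgrammar's CPU kernel, and finally the Triton kernel added in bitmask_ops.py. A condensed standalone sketch of that selection order; pick_bitmask_backend is a hypothetical helper written for illustration, not part of the commit.

import os

import torch


def pick_bitmask_backend(logits: torch.Tensor, cuda_kernel=None, cpu_kernel=None):
    # Mirrors the dispatch order of XGrammarGrammar.apply_vocab_mask above.
    force_triton = os.environ.get(
        "SGLANG_TOKEN_BITMASK_TRITON", "false"
    ).lower() in ("true", "1")
    if not force_triton and logits.device.type == "cuda" and cuda_kernel:
        return "cuda"
    if logits.device.type == "cpu" and cpu_kernel:
        return "cpu"
    return "triton"


print(pick_bitmask_backend(torch.empty(4)))                       # "triton" (no CPU kernel available)
print(pick_bitmask_backend(torch.empty(4), cpu_kernel=object()))  # "cpu"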
@@ -137,11 +137,6 @@ class ModelRunner:
        if server_args.show_time_cost:
            enable_show_time_cost()

        if server_args.disable_outlines_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        # Global vars
        global_server_args_dict.update(
            {

@@ -392,6 +392,10 @@ class ServerArgs:
        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
        )

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
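The ordering comment above matters because the outlines backend reads SGLANG_DISABLE_OUTLINES_DISK_CACHE once, at module import time (DISABLE_DISK_CACHE in the first hunk). A tiny illustrative sketch of that pattern, not taken from the repository:

import os

# ServerArgs.__post_init__ exports the flag first ...
os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = "1"

# ... so that a module-level read like this one (the pattern the outlines
# backend uses via get_bool_env_var) observes the value the user asked for.
flag = os.environ.get(
    "SGLANG_DISABLE_OUTLINES_DISK_CACHE", "false"
).lower() in ("true", "1")
print(flag)  # True

If the backend module were imported before __post_init__ ran, the module-level read would have latched the stale value, which is exactly what setting the env var here avoids.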