diff --git a/python/pyproject.toml b/python/pyproject.toml
index fa5e29689..d07254cc6 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -42,7 +42,7 @@ runtime_common = [
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
-    "xgrammar==0.1.17",
+    "xgrammar==0.1.19",
     "blobfile==3.0.0"
 ]
 
diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py
index 400ab421f..8e715b3d8 100644
--- a/python/sglang/srt/constrained/xgrammar_backend.py
+++ b/python/sglang/srt/constrained/xgrammar_backend.py
@@ -18,6 +18,7 @@ import logging
 from typing import List, Optional, Tuple, Union
 
 import torch
+import xgrammar
 from xgrammar import (
     CompiledGrammar,
     GrammarCompiler,
@@ -58,17 +59,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
 
-        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
-        # class init site to avoid re-initializing CUDA in forked subprocess.
-        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+        from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
+            apply_token_bitmask_inplace_cpu,
+        )
 
-        self.use_token_bitmask_triton = get_bool_env_var(
-            "SGLANG_TOKEN_BITMASK_TRITON", "false"
-        )
-        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
-            "cuda", None
-        )
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -113,15 +108,12 @@ class XGrammarGrammar(BaseGrammarObject):
         return vocab_mask.to(device, non_blocking=True)
 
     def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if (
-            not self.use_token_bitmask_triton
-            and logits.device.type == "cuda"
-            and self.apply_vocab_mask_cuda
-        ):
-            return self.apply_vocab_mask_cuda(logits, vocab_mask)
-        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
-            return self.apply_vocab_mask_cpu(logits, vocab_mask)
-        apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        if logits.device.type == "cuda":
+            apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            self.apply_vocab_mask_cpu(logits, vocab_mask)
+        else:
+            raise RuntimeError(f"Unsupported device: {logits.device.type}")
 
     def copy(self):
         matcher = GrammarMatcher(
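For reference, a minimal standalone sketch (not part of the patch) of the CPU branch that the new `apply_vocab_mask` dispatches to. It builds a token bitmask in xgrammar's layout (int32 words, where bit `i` of word `w` gates token `w * 32 + i`, and a set bit means the token stays allowed) and applies it with the same `apply_token_bitmask_inplace_cpu` kernel the patch imports. The vocabulary size and allowed-token choices are illustrative:

```python
# Sketch of the CPU path taken by the new apply_vocab_mask dispatch.
# Assumes xgrammar >= 0.1.19, as pinned in pyproject.toml above.
import torch
from xgrammar import allocate_token_bitmask
from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
    apply_token_bitmask_inplace_cpu,
)

vocab_size = 8  # illustrative; a real model uses the tokenizer's vocab size
logits = torch.zeros(1, vocab_size)

# Start from an all-masked bitmask, then allow only tokens 1 and 3.
# Bit i of word w gates token w * 32 + i; a set bit keeps the token.
bitmask = allocate_token_bitmask(1, vocab_size)
bitmask.zero_()
bitmask[0, 0] = (1 << 1) | (1 << 3)

# In the backend, GrammarMatcher.fill_next_token_bitmask() fills this mask;
# here we apply it directly to CPU logits, in place.
apply_token_bitmask_inplace_cpu(logits, bitmask)
print(logits)  # tokens 1 and 3 keep their logits; all others are masked to -inf
```

On the CUDA side, the patch always routes through `apply_token_bitmask_inplace_triton`, dropping both the `SGLANG_TOKEN_BITMASK_TRITON` environment toggle and the per-device kernel lookup, and any other device type now raises a `RuntimeError` instead of silently falling through to the Triton kernel.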