From 86d10d220f665092f93a3f6e8a31a65a36a4f376 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 23 Aug 2025 05:40:18 -0700 Subject: [PATCH] Update grok.py and tiktoken tokenizer (#9532) --- .../srt/constrained/xgrammar_backend.py | 16 +- python/sglang/srt/hf_transformers_utils.py | 5 + .../srt/layers/attention/triton_backend.py | 18 +- .../attention/triton_ops/decode_attention.py | 31 ++ .../attention/triton_ops/extend_attention.py | 18 + python/sglang/srt/layers/elementwise.py | 94 ++++ python/sglang/srt/layers/moe/router.py | 24 +- python/sglang/srt/layers/radix_attention.py | 6 + python/sglang/srt/models/grok.py | 423 ++++++++++++++++-- .../srt/tokenizer/tiktoken_tokenizer.py | 161 +++++++ 10 files changed, 732 insertions(+), 64 deletions(-) create mode 100644 python/sglang/srt/tokenizer/tiktoken_tokenizer.py diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 6118aa22b..7b101df4f 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -162,12 +162,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ): super().__init__() - # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens - # This ensures consistency between what the model considers EOS and what XGrammar uses - tokenizer_info = TokenizerInfo.from_huggingface( - tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids - ) - override_stop_tokens = None + if hasattr(tokenizer, "init_xgrammar"): + # For special tokenizer + tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar() + else: + # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens + # This ensures consistency between what the model considers EOS and what XGrammar uses + tokenizer_info = TokenizerInfo.from_huggingface( + tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids + ) + override_stop_tokens = None self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) self.vocab_size = vocab_size diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 292c7a7bd..4503a4598 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -263,6 +263,11 @@ def get_tokenizer( **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" + if tokenizer_name.endswith(".json"): + from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer + + return TiktokenTokenizer(tokenizer_name) + if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 2d9b42c8b..26241d849 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -20,6 +20,14 @@ if TYPE_CHECKING: from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +def logit_capping_mod(logit_capping_method, logit_cap): + # positive logit_cap -> tanh cap + if logit_capping_method == "tanh": + return logit_cap + else: + raise ValueError() + + @dataclass class ForwardMetadata: attn_logits: torch.Tensor @@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend): layer, forward_batch.out_cache_loc, k, v ) + logits_soft_cap = 
logit_capping_mod(layer.logit_capping_method, layer.logit_cap) + causal = True if layer.attn_type == AttentionType.ENCODER_ONLY: causal = False @@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend): self.forward_metadata.mask_indptr, self.forward_metadata.max_extend_len, layer.scaling, - layer.logit_cap, + logit_cap=logits_soft_cap, sliding_window_size=sliding_window_size, sinks=sinks, window_kv_offsets=window_kv_offsets, + xai_temperature_len=layer.xai_temperature_len, ) return o @@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend): else: o = torch.empty_like(q) + logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap) + if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( layer, forward_batch.out_cache_loc, k, v @@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend): self.forward_metadata.num_kv_splits, self.max_kv_splits, layer.scaling, - layer.logit_cap, + logit_cap=logits_soft_cap, sinks=sinks, + xai_temperature_len=layer.xai_temperature_len, ) return o diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index d8259be20..1ba5d463d 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -69,6 +69,7 @@ def _fwd_kernel_stage1( logit_cap: tl.constexpr, Lk: tl.constexpr, Lv: tl.constexpr, + xai_temperature_len: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -85,6 +86,12 @@ def _fwd_kernel_stage1( cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx kv_splits = tl.load(num_kv_splits + cur_batch) + if xai_temperature_len > 0: + offs_qidx = cur_batch_seq_len - 1 + xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len)) + _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale + xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0) + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d kv_len_per_split = ( @@ -122,6 +129,9 @@ def _fwd_kernel_stage1( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) + if xai_temperature_len > 0: + qk *= xai_temperature_reg + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) offs_buf_v = ( @@ -181,6 +191,7 @@ def _decode_att_m_fwd( max_kv_splits, sm_scale, logit_cap, + xai_temperature_len=-1, ): BLOCK = 64 # [TODO] work around SGPR limit on MI3xx @@ -230,6 +241,7 @@ def _decode_att_m_fwd( BLOCK_N=BLOCK, MIN_BLOCK_KV=_MIN_BLOCK_KV, logit_cap=logit_cap, + xai_temperature_len=xai_temperature_len, num_warps=num_warps, num_stages=2, Lk=Lk, @@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1( BLOCK_H: tl.constexpr, MIN_BLOCK_KV: tl.constexpr, logit_cap: tl.constexpr, + xai_temperature_len: tl.constexpr, Lk: tl.constexpr, Lv: tl.constexpr, ): @@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1( cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx kv_splits = tl.load(num_kv_splits + cur_batch) + if xai_temperature_len > 0: + offs_qidx = cur_batch_seq_len - 1 + xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len)) + _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale + xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0) + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] if BLOCK_DPE > 0: @@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1( if logit_cap > 0: qk = logit_cap * tanh(qk / 
logit_cap) + if xai_temperature_len > 0: + qk *= xai_temperature_reg[:, None] + qk = tl.where( mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") ) @@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd( max_kv_splits, sm_scale, logit_cap, + xai_temperature_len=-1, ): BLOCK = 32 Lk = k_buffer.shape[-1] @@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd( BLOCK_H=BLOCK_H, MIN_BLOCK_KV=_MIN_BLOCK_KV, logit_cap=logit_cap, + xai_temperature_len=xai_temperature_len, num_warps=4, num_stages=num_stages, Lk=Lk, @@ -620,6 +644,7 @@ def decode_attention_fwd_normal( sm_scale, logit_cap=0.0, sinks=None, + xai_temperature_len=-1, ): _decode_att_m_fwd( q, @@ -633,6 +658,7 @@ def decode_attention_fwd_normal( max_kv_splits, sm_scale, logit_cap, + xai_temperature_len, ) _decode_softmax_reducev_fwd( attn_logits, @@ -661,6 +687,7 @@ def decode_attention_fwd_grouped( sm_scale, logit_cap=0.0, sinks=None, + xai_temperature_len=-1, ): _decode_grouped_att_m_fwd( q, @@ -674,6 +701,7 @@ def decode_attention_fwd_grouped( max_kv_splits, sm_scale, logit_cap, + xai_temperature_len, ) _decode_softmax_reducev_fwd( attn_logits, @@ -702,6 +730,7 @@ def decode_attention_fwd( sm_scale, logit_cap=0.0, sinks=None, + xai_temperature_len=-1, ): assert max_kv_splits == attn_logits.shape[2] assert q.shape[0] <= kv_indptr.shape[0] - 1 @@ -725,6 +754,7 @@ def decode_attention_fwd( sm_scale, logit_cap=logit_cap, sinks=sinks, + xai_temperature_len=xai_temperature_len, ) else: # GQA/MQA/MLA @@ -742,4 +772,5 @@ def decode_attention_fwd( sm_scale, logit_cap=logit_cap, sinks=sinks, + xai_temperature_len=xai_temperature_len, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index b39f1a305..e91467743 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -69,6 +69,7 @@ def _fwd_kernel( stride_buf_vh, SLIDING_WINDOW_SIZE: tl.constexpr, logit_cap: tl.constexpr, + xai_temperature_len: tl.constexpr, Lq: tl.constexpr, Lv: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -109,6 +110,15 @@ def _fwd_kernel( mask_d = offs_d < Lq mask_dv = offs_dv < Lv + if xai_temperature_len > 0: + offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m + xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len)) + xai_temperature_reg = tl.where( + offs_qidx > xai_temperature_len, + tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale, + 1.0, + ) + offs_q = ( (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs @@ -203,6 +213,9 @@ def _fwd_kernel( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) + if xai_temperature_len > 0: + qk *= xai_temperature_reg[:, None] + qk = tl.where(final_mask, qk, float("-inf")) row_max = tl.max(qk, 1) @@ -306,6 +319,9 @@ def _fwd_kernel( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) + if xai_temperature_len > 0: + qk *= xai_temperature_reg[:, None] + qk = tl.where(final_mask, qk, float("-inf")) row_max = tl.max(qk, 1) @@ -373,6 +389,7 @@ def extend_attention_fwd( sliding_window_size=-1, sinks=None, window_kv_offsets=None, + xai_temperature_len=-1, ): """ q_extend, k_extend, v_extend, o_extend: contiguous tensors @@ -477,6 +494,7 @@ def extend_attention_fwd( v_buffer.stride(1), SLIDING_WINDOW_SIZE=sliding_window_size, logit_cap=logit_cap, + xai_temperature_len=xai_temperature_len, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, BLOCK_DV=BLOCK_DV, diff --git 
a/python/sglang/srt/layers/elementwise.py b/python/sglang/srt/layers/elementwise.py index 3134e2bc1..e05d88b32 100644 --- a/python/sglang/srt/layers/elementwise.py +++ b/python/sglang/srt/layers/elementwise.py @@ -486,3 +486,97 @@ def gelu_and_mul_triton( return out_hidden_states, out_scales else: return out_hidden_states, None + + +# silu on first half of vector +@triton.jit +def silu_and_mul_kernel( + out_hidden_states_ptr, # (bs, hidden_dim) + out_scales_ptr, # (bs,) + hidden_states_ptr, # (bs, hidden_dim * 2) + quant_max: tl.constexpr, + static_scale: tl.constexpr, + hidden_dim: tl.constexpr, # the output hidden_dim + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + input_start = pid * hidden_dim * 2 + output_start = pid * hidden_dim + + input1_offs = tl.arange(0, BLOCK_SIZE) + mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output + input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE) + output_offs = tl.arange(0, BLOCK_SIZE) + + x1 = tl.load( + hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0 + ).to(tl.float32) + x3 = tl.load( + hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0 + ).to(tl.float32) + + # silu + # cast down before mul to better match training? + silu_x1 = x1 * tl.sigmoid(x1) + out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty) + + if quant_max is not None: + raise NotImplementedError() + + tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask) + + +def silu_and_mul_triton( + hidden_states, + scales=None, + quantize=None, # dtype to quantize to + out=None, +): + bs, in_hidden_dim = hidden_states.shape + hidden_dim = in_hidden_dim // 2 + + if out is None: + out_hidden_states = torch.empty( + (bs, hidden_dim), + dtype=quantize or hidden_states.dtype, + device=hidden_states.device, + ) + else: + assert out.shape == (bs, hidden_dim) + assert out.dtype == (quantize or hidden_states.dtype) + out_hidden_states = out + out_scales = None + static_scale = False + if quantize is not None: + if scales is None: + out_scales = torch.empty( + (bs,), dtype=torch.float32, device=hidden_states.device + ) + else: + out_scales = scales + static_scale = True + + max_warps = 16 if _is_hip else 32 + config = { + # 8 ele per thread (not tuned) + "num_warps": max( + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4 + ), + } + + silu_and_mul_kernel[(bs,)]( + out_hidden_states, + out_scales, + hidden_states, + quant_max=torch.finfo(quantize).max if quantize is not None else None, + static_scale=static_scale, + hidden_dim=hidden_dim, + BLOCK_SIZE=triton.next_power_of_2(hidden_dim), + **config, + ) + + if quantize is not None: + return out_hidden_states, out_scales + else: + return out_hidden_states, None diff --git a/python/sglang/srt/layers/moe/router.py b/python/sglang/srt/layers/moe/router.py index d78437f7b..0138dcdad 100644 --- a/python/sglang/srt/layers/moe/router.py +++ b/python/sglang/srt/layers/moe/router.py @@ -45,11 +45,14 @@ def fused_moe_router_kernel( logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1) # logit softcap - logits_scaled = logits / moe_softcapping - exped = tl.exp(2 * logits_scaled) - top = exped - 1 - bottom = exped + 1 - logits_softcapped = top / bottom * moe_softcapping + if moe_softcapping == 0: + logits_softcapped = logits + else: + logits_scaled = logits / moe_softcapping + exped = tl.exp(2 * logits_scaled) + top = exped - 1 + bottom = exped + 1 + logits_softcapped = top / bottom * moe_softcapping # Add bias 
after softcapping if is_correction_bias: @@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel( b_ptrs += BLOCK_SIZE_K # 4. logit softcap - logits_scaled = acc / moe_softcapping - exped = tl.exp(2 * logits_scaled) - logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping + if moe_softcapping == 0: + logits_softcapped = acc + else: + logits_scaled = acc / moe_softcapping + exped = tl.exp(2 * logits_scaled) + logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping # 5. top1 arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :] @@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel( # 7. handle topk == 2 if topk == 2: - cond_top2 = (arange_block_size_n < num_experts) and ( + cond_top2 = (arange_block_size_n < num_experts) & ( arange_block_size_n != top1[:, None] ) top2 = tl.argmax( diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 8004fc7c9..0719cdd29 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -52,6 +52,8 @@ class RadixAttention(nn.Module): v_head_dim: int = -1, sliding_window_size: int = -1, is_cross_attention: bool = False, + pos_encoding_mode: str = "NONE", + logit_capping_method: str = "tanh", quant_config: Optional[QuantizationConfig] = None, attn_type: AttentionType = AttentionType.DECODER, use_irope: bool = False, @@ -81,6 +83,10 @@ class RadixAttention(nn.Module): self.quant_method.create_weights(self) self.attn_type = attn_type + self.pos_encoding_mode = pos_encoding_mode + self.logit_capping_method = logit_capping_method + self.xai_temperature_len = -1 + def forward( self, q, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 254d46d7b..8b2554fa3 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -16,7 +16,6 @@ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" import functools -import json import logging import math import os @@ -35,9 +34,16 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm +from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.elementwise import ( + experts_combine_triton, + fused_dual_residual_rmsnorm, + fused_rmsnorm, + gelu_and_mul_triton, +) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, @@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.rotary_embedding import ( + RotaryEmbedding, + _yarn_find_correction_range, + _yarn_get_mscale, + get_rope, +) from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import 
dump_to_file +from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file logger = logging.getLogger(__name__) +# Dump tensors for debugging debug_tensor_dump_output_folder = None +debug_tensor_dump_prefill_only = False +# Skip all the other tensor dumps, only dump the target logits +debug_tensor_dump_only_target_logprobs = False debug_tensor_dump_inject = False +debug_tensor_dump_layers = None +debug_tensor_dump_test = False + + +class Grok1MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + reduce_results=True, + use_presharded_weights: bool = False, + split_gate_up: bool = False, + ) -> None: + super().__init__() + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + use_presharded_weights=use_presharded_weights, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + reduce_results=reduce_results, + use_presharded_weights=use_presharded_weights, + ) + self.act_fn = GeluAndMul(approximate="tanh") + self.layer_id = layer_id + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x, _ = gelu_and_mul_triton(gate_up) + x, _ = self.down_proj(x) + return x class Grok1MoE(nn.Module): @@ -87,10 +145,11 @@ class Grok1MoE(nn.Module): params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, - reduce_results=True, + reduce_results: bool = True, use_presharded_weights: bool = False, inplace: bool = True, no_combine: bool = False, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -145,6 +204,135 @@ class Grok1MoE(nn.Module): return self.experts(hidden_states, topk_output) +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype +) -> torch.Tensor: + if low == high: + low -= 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def get_rope_scaling(config): + rope_type = getattr(config, "rope_type", None) + if rope_type: + original_max_position_embeddings = getattr( + config, "original_max_position_embeddings", None + ) + scaling_factor = getattr(config, "scaling_factor", None) + extrapolation_factor = getattr(config, "extrapolation_factor", 1.0) + attn_factor = getattr(config, "attn_factor", 1.0) + beta_fast = getattr(config, "beta_fast", 32) + beta_slow = getattr(config, "beta_slow", 1) + rope_scaling = { + "extra_method": rope_type, + "max_position_embeddings": original_max_position_embeddings, + "scaling_factor": scaling_factor, + "extrapolation_factor": extrapolation_factor, + "attn_factor": attn_factor, + "beta_fast": beta_fast, + "beta_slow": beta_slow, + "dtype": torch.float, + } + return rope_scaling + else: + return None + + +class ScalingRotaryEmbedding(RotaryEmbedding): + """Scale the RotaryEmbedding in a way similar to YaRN method. 
https://arxiv.org/pdf/2309.00071.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extra_method: str = "yarn_log", + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extra_method = extra_method + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + if self.extra_method in ["original"]: + inv_freq = inv_freq_extrapolation + elif self.extra_method in ["yarn", "yarn_linear"]: + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + elif self.extra_method == "yarn_log": + inv_freq = torch.exp( + torch.log(inv_freq_extrapolation) * inv_freq_mask + + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask) + ) + elif self.extra_method == "theta_scale": + exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float) + theta_scale_exponent = self.base ** ( + math.log( + self.max_position_embeddings * self.scaling_factor / (2 * math.pi) + ) + / math.log(self.max_position_embeddings / (2 * math.pi)) + ) + inv_freq = torch.tensor( + 1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)), + dtype=torch.float32, + ) + else: + raise ValueError(f"Unknown extrapolation method: {self.extra_method}") + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, dtype=torch.float32 + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + # cos = freqs.cos() * self.mscale + # sin = freqs.sin() * self.mscale + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + class Grok1Attention(nn.Module): def __init__( self, @@ -157,7 +345,9 @@ class Grok1Attention(nn.Module): rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + alt_stream: Optional[torch.cuda.Stream] = None, load_presharded_attn: bool = False, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -183,7 +373,9 @@ class Grok1Attention(nn.Module): self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + rope_scaling = get_rope_scaling(config) self.load_presharded_attn = load_presharded_attn + self.alt_stream = alt_stream or torch.cuda.Stream() self.qkv_proj = QKVParallelLinear( hidden_size, @@ 
-195,6 +387,7 @@ class Grok1Attention(nn.Module): tp_rank=attn_tp_rank, tp_size=attn_tp_size, load_presharded_attn=self.load_presharded_attn, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -205,6 +398,7 @@ class Grok1Attention(nn.Module): tp_rank=attn_tp_rank, tp_size=attn_tp_size, use_presharded_weights=self.load_presharded_attn, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -214,7 +408,37 @@ class Grok1Attention(nn.Module): is_neox_style=True, ) + self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False) + + if rope_scaling is not None: + self.rotary_emb = ScalingRotaryEmbedding( + self.head_dim, + rotary_dim=( + self.head_dim + if not self.rope_rotate_half_dims + else self.head_dim // 2 + ), + base=int(self.rope_theta), + is_neox_style=True, + **rope_scaling, + ) + pos_encoding_mode = "NONE" + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=( + self.head_dim + if not self.rope_rotate_half_dims + else self.head_dim // 2 + ), + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + pos_encoding_mode = "NONE" + logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0) + logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh") self.attn = RadixAttention( self.num_heads, @@ -224,7 +448,11 @@ class Grok1Attention(nn.Module): layer_id=layer_id, logit_cap=logit_cap, quant_config=quant_config, + pos_encoding_mode=pos_encoding_mode, + logit_capping_method=logit_capping_method, + prefix=add_prefix("attn", prefix), ) + self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1) def forward( self, @@ -256,6 +484,8 @@ class Grok1Attention(nn.Module): ) qkv, _ = self.qkv_proj(hidden_states) + dispose_tensor(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) @@ -288,6 +518,7 @@ class Grok1Attention(nn.Module): ) attn_output = self.attn(q, k, v, forward_batch) + del q, k, v, qkv if debug_tensor_dump_output_folder: dump_to_file( @@ -312,49 +543,89 @@ class Grok1DecoderLayer(nn.Module): load_presharded_moe: bool = False, load_presharded_attn: bool = False, load_presharded_mlp: bool = False, + alt_stream: Optional[torch.cuda.Stream] = None, + skip_moe: bool = False, + prefix: str = "", ) -> None: super().__init__() self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size + self.residual_moe = getattr(config, "residual_moe", False) self.layer_id = layer_id + self.alt_stream = alt_stream or torch.cuda.Stream() rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = Grok1Attention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, + max_position=( + config.context_len + if hasattr(config, "context_len") + else config.max_position_embeddings + ), num_kv_heads=config.num_key_value_heads, layer_id=layer_id, rope_theta=rope_theta, quant_config=quant_config, reduce_results=False, + alt_stream=self.alt_stream, load_presharded_attn=load_presharded_attn, + prefix=add_prefix("attn", prefix), ) - self.block_sparse_moe = Grok1MoE( - config=config, - layer_id=layer_id, - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=getattr( - config, - "moe_intermediate_size", - getattr(config, "intermediate_size", None), - ), - 
quant_config=quant_config, - reduce_results=True, - use_presharded_weights=load_presharded_moe, - inplace=True, - no_combine=False, # just a suggestion to not combine topk - ) + + split_gate_up = not getattr(config, "merge_gate_up", True) + if self.num_experts > 0: + self.block_sparse_moe = Grok1MoE( + config=config, + layer_id=layer_id, + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=getattr( + config, + "moe_intermediate_size", + getattr(config, "intermediate_size", None), + ), + quant_config=quant_config, + reduce_results=not self.residual_moe, + use_presharded_weights=load_presharded_moe, + inplace=False, # not self.residual_moe, + no_combine=False, # self.residual_moe, # just a suggestion to not combine topk + prefix=add_prefix("block_sparse_moe", prefix), + ) + if self.residual_moe: + self.mlp = Grok1MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + reduce_results=False, + use_presharded_weights=load_presharded_mlp, + layer_id=layer_id, + split_gate_up=split_gate_up, + ) + else: + raise NotImplementedError() self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.ffn = self.block_sparse_moe + if self.num_experts > 0: + if self.residual_moe: + # NOTE: self.block_sparse_moe modifies the input in-place, + # so we have to call it later. Be aware of any possible related errors. + if get_tensor_model_parallel_world_size() > 1: + self.ffn = lambda x: tensor_model_parallel_all_reduce( + self.moe_with_rmoe(x) + ) + else: + self.ffn = self.moe_with_rmoe + else: + self.ffn = self.block_sparse_moe + else: + raise NotImplementedError() def forward( self, @@ -364,6 +635,10 @@ class Grok1DecoderLayer(nn.Module): residual: Optional[torch.Tensor] = None, deferred_norm: Optional[RMSNorm] = None, ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]: + + hidden_states_original = hidden_states + residual_original = residual + # Self Attention if deferred_norm is not None: assert residual is not None @@ -386,6 +661,14 @@ class Grok1DecoderLayer(nn.Module): hidden_states, ) + if residual_original is not None: + dispose_tensor(residual_original) + + dispose_flag = False + if residual is not hidden_states_original: + dispose_flag = True + dispose_tensor(hidden_states_original) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -403,10 +686,23 @@ class Grok1DecoderLayer(nn.Module): self.post_attn_norm.variance_epsilon, ) + if not dispose_flag: + dispose_tensor(hidden_states_original) + # Fully Connected hidden_states = self.ffn(hidden_states) return hidden_states, residual, self.post_moe_norm # defer layernorm + def moe_with_rmoe(self, x): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + mlp_result = self.mlp(x) + with torch.cuda.stream(self.alt_stream): + # moe should not be inplace because of stream race condition + moe_result = self.block_sparse_moe(x) + current_stream.wait_stream(self.alt_stream) + return (mlp_result + moe_result) / 1.4142135623730951 + class Grok1Model(nn.Module): def __init__( @@ -417,6 +713,8 @@ class Grok1Model(nn.Module): load_presharded_embedding: bool = False, load_presharded_attn: bool = False, load_presharded_mlp: 
bool = False, + replicate_embedding: bool = False, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -427,7 +725,11 @@ class Grok1Model(nn.Module): config.vocab_size, config.hidden_size, use_presharded_weights=load_presharded_embedding, + enable_tp=not replicate_embedding, + prefix=add_prefix("embed_tokens", prefix), ) + + self.alt_stream = torch.cuda.Stream() self.layers = nn.ModuleList( [ Grok1DecoderLayer( @@ -437,6 +739,7 @@ class Grok1Model(nn.Module): load_presharded_moe=load_presharded_moe, load_presharded_attn=load_presharded_attn, load_presharded_mlp=load_presharded_mlp, + alt_stream=self.alt_stream, ) for i in range(config.num_hidden_layers) ] @@ -506,6 +809,7 @@ class Grok1ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -514,7 +818,8 @@ class Grok1ForCausalLM(nn.Module): # Get presharded weights. self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False) self.load_presharded_moe = ( - self.config.num_local_experts > 0 + getattr(config, "load_presharded_moe", True) + and self.config.num_local_experts > 0 and get_tensor_model_parallel_world_size() > 1 ) self.load_presharded_attn = getattr(config, "load_presharded_attn", False) @@ -529,6 +834,11 @@ class Grok1ForCausalLM(nn.Module): or self.load_presharded_embedding ) + default_replicate_lm_head = False + self.replicate_lm_head = getattr( + config, "replicate_lm_head", default_replicate_lm_head + ) + if self.is_weights_presharded: setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) @@ -536,6 +846,7 @@ class Grok1ForCausalLM(nn.Module): self.replicate_lm_head = getattr( config, "replicate_lm_head", default_replicate_lm_head ) + self.replicate_embedding = getattr(config, "replicate_embedding", False) self.model = Grok1Model( config, @@ -544,6 +855,8 @@ class Grok1ForCausalLM(nn.Module): load_presharded_embedding=self.load_presharded_embedding, load_presharded_attn=self.load_presharded_attn, load_presharded_mlp=self.load_presharded_mlp, + replicate_embedding=self.replicate_embedding, + prefix=add_prefix("model", prefix), ) lm_head_params_dtype = None @@ -553,6 +866,7 @@ class Grok1ForCausalLM(nn.Module): config.vocab_size, bias=False, params_dtype=lm_head_params_dtype, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config, skip_all_gather=True) else: @@ -561,6 +875,7 @@ class Grok1ForCausalLM(nn.Module): config.hidden_size, use_presharded_weights=self.load_presharded_embedding, params_dtype=lm_head_params_dtype, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) @@ -577,6 +892,7 @@ class Grok1ForCausalLM(nn.Module): f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, " f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B" ) + self.loaded_param_names = set() def forward( self, @@ -596,11 +912,13 @@ class Grok1ForCausalLM(nn.Module): def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], - num_experts: Optional[int] = None, ignore_parent_name: bool = False, + check_hit_names: bool = True, + model_config: PretrainedConfig | None = None, ) -> dict[str, torch.Tensor]: - if num_experts is None: - num_experts = self.config.num_local_experts + if model_config is None: + model_config = self.config + stacked_params_mapping = [] stacked_params_mapping += [ # (param_name, shard_name, shard_id) @@ -616,6 +934,7 @@ class 
Grok1ForCausalLM(nn.Module): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) + num_experts = model_config.num_local_experts expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", @@ -630,23 +949,26 @@ class Grok1ForCausalLM(nn.Module): def load_weight_wrapper( name: str, loaded_weight: torch.Tensor, *args, **kwargs ): - if ignore_parent_name: - name = name.split(".")[-1] - - if name not in params_dict: - return - # Fuse constant multipliers into the weights if "lm_head" in name: loaded_weight = ( loaded_weight.to(torch.float32) - * self.config.output_multiplier_scale + * model_config.output_multiplier_scale ) + original_name = name + if ignore_parent_name: + name = name.split(".")[-1] + + if name not in params_dict: + logger.info(f"Skipping {name=} in load_weights_wrapper") + return + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight, *args, **kwargs) hit_names.add(name) + self.loaded_param_names.add(original_name) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -685,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module): load_weight_wrapper(name=name, loaded_weight=loaded_weight) - if len(hit_names) > 5: - missing = all_names - hit_names - missing_exclude_scales = {x for x in missing if "scale" not in x} - logger.info( - f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}", - ) - if len(missing_exclude_scales) > 0: - raise ValueError( - f"load_weights failed because some weights are missing: {missing_exclude_scales=}." + if check_hit_names: + if len(hit_names) > 5: + missing = all_names - hit_names + missing_exclude_scales = {x for x in missing if "scale" not in x} + logger.info( + f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}", ) + if len(missing_exclude_scales) > 0: + raise ValueError( + f"load_weights failed because some weights are missing: {missing_exclude_scales=}." + ) - elif len(hit_names) == 0: - raise ValueError("load_weights failed because it did not hit any names.") + elif len(hit_names) == 0: + raise ValueError( + f"load_weights failed because it did not hit any names. 
{all_names=} {hit_names=}" + ) return hit_names @@ -708,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module): "moe_intermediate_size", getattr(cfg, "intermediate_size", None), ) - num_experts = cfg.num_local_experts + residual_moe = getattr(cfg, "residual_moe", False) + if cfg.num_local_experts > 0: + num_experts = cfg.num_local_experts + (1 if residual_moe else 0) + else: + num_experts = 1 wq = ( cfg.num_hidden_layers diff --git a/python/sglang/srt/tokenizer/tiktoken_tokenizer.py b/python/sglang/srt/tokenizer/tiktoken_tokenizer.py new file mode 100644 index 000000000..8c4c91263 --- /dev/null +++ b/python/sglang/srt/tokenizer/tiktoken_tokenizer.py @@ -0,0 +1,161 @@ +import functools +import json +from typing import AbstractSet, Collection, List, Literal, Union + + +class TiktokenProcessor: + def __init__(self, name: str): + self.tokenizer = TiktokenTokenizer(name) + + def image_processor(self, image): + return {"pixel_values": [image]} + + +RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)] +CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)] + + +PAD = "<|pad|>" +EOS = "<|eos|>" +SEP = "<|separator|>" + +DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS] +DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP} + +# default + separate each single digit +PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + +class TiktokenTokenizer: + def __init__(self, tokenizer_path): + import tiktoken + from jinja2 import Template + + # Read the JSON + with open(tokenizer_path, "rb") as fin: + xtok_dict = json.load(fin) + + # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict + mergeable_ranks = { + bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"] + } + special_tokens = { + bytes(item["bytes"]).decode(): item["token"] + for item in xtok_dict["special_tokens"] + } + if xtok_dict["word_split"] == "V1": + pad_str = PAT_STR_B + else: + assert False, f"Unknown word_split: {xtok_dict['word_split']}" + pad_str = xtok_dict.get("pat_str", pad_str) + + kwargs = { + "name": tokenizer_path, + "pat_str": pad_str, + "mergeable_ranks": mergeable_ranks, + "special_tokens": special_tokens, + } + if "default_allowed_special" in xtok_dict: + default_allowed_special = set( + [ + bytes(bytes_list).decode() + for bytes_list in xtok_dict["default_allowed_special"] + ] + ) + if "vocab_size" in xtok_dict: + kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"] + + # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__ + default_allowed_special = None + control_tokens = DEFAULT_CONTROL_TOKENS + tokenizer = tiktoken.Encoding(**kwargs) + tokenizer._default_allowed_special = default_allowed_special or set() + tokenizer._control_tokens = control_tokens + + def encode_patched( + self, + text: str, + *, + allowed_special: Union[ + Literal["all"], AbstractSet[str] + ] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> List[int]: + if isinstance(allowed_special, set): + allowed_special |= self._default_allowed_special + return tiktoken.Encoding.encode( + self, + text, + allowed_special=allowed_special, + disallowed_special=(), + ) + + tokenizer.encode = functools.partial(encode_patched, tokenizer) + + # Allow more tokens to prevent crash + tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values()) + tokenizer._default_allowed_special |= set( + CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS + ) + + # Convert to 
HF interface + self.tokenizer = tokenizer + self.bos_token_id = None + self.eos_token_id = tokenizer._special_tokens[EOS] + self.vocab_size = tokenizer.n_vocab + self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" + self.chat_template_jinja = Template(self.chat_template) + self.additional_stop_token_ids = None + + def encode(self, x, add_special_tokens=False): + return self.tokenizer.encode(x) + + def decode(self, x, *args, **kwargs): + return self.tokenizer.decode(x) + + def batch_decode( + self, batch, skip_special_tokens=True, spaces_between_special_tokens=False + ): + if len(batch) > 0 and isinstance(batch[0], int): + batch = [[x] for x in batch] + return self.tokenizer.decode_batch(batch) + + def apply_chat_template( + self, messages, tokenize, add_generation_prompt, tools=None + ): + ret = self.chat_template_jinja.render( + messages=messages, add_generation_prompt=add_generation_prompt + ) + return self.encode(ret) if tokenize else ret + + def __call__(self, text, **kwargs): + return { + "input_ids": self.encode(text), + } + + def init_xgrammar(self): + from xgrammar import TokenizerInfo + + XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>" + + enc = self.tokenizer + encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens} + encoded_vocab = [ + token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1]) + ] + override_stop_tokens = [2] # eos + # These are treated as special tokens in xgrammar; we want to avoid them + # For now, xgrammar treats anything starting with b'\x00' as a special token + xgrammar_special_token_ids = [] + for i, token in enumerate(encoded_vocab): + if isinstance(token, bytes) and token.startswith(b"\x00"): + xgrammar_special_token_ids.append(i) + + for i, id in enumerate(xgrammar_special_token_ids): + encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i) + tokenizer_info = TokenizerInfo( + encoded_vocab, stop_token_ids=override_stop_tokens + ) + assert len(tokenizer_info.special_token_ids) == 0 + + return tokenizer_info, override_stop_tokens
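
A minimal NumPy sketch of the length-dependent attention temperature that the new xai_temperature_len argument enables in the decode/extend kernels above; the helper name and the standalone-script framing are illustrative, not part of the patch:

import numpy as np

def xai_temperature_factor(query_pos: np.ndarray, temperature_len: int) -> np.ndarray:
    # Mirrors the xai_temperature_reg computation in the Triton kernels:
    # positions at or below temperature_len keep a factor of 1.0; beyond that,
    # the qk logits are scaled by log2(pos) / log2(temperature_len).
    scale = 1.0 / np.log2(float(temperature_len))
    factor = np.log2(query_pos.astype(np.float32)) * scale
    return np.where(query_pos > temperature_len, factor, 1.0)

# With temperature_len = 8192: position 4096 stays at 1.0, while position
# 32768 is scaled by log2(32768) / log2(8192) = 15 / 13 = 1.154.
print(xai_temperature_factor(np.array([4096, 8192, 32768]), 8192))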
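
The residual_moe path in Grok1DecoderLayer runs the dense MLP on the current stream and the sparse MoE on an alternate CUDA stream, then combines the two branches with a 1/sqrt(2) factor (the 1.4142135623730951 constant in moe_with_rmoe). A stream-free sketch of that combine, with placeholder linear layers standing in for Grok1MLP / Grok1MoE:

import math
import torch
import torch.nn as nn

class ResidualMoEBlock(nn.Module):
    # Placeholder dense and sparse branches; the real model uses Grok1MLP and
    # Grok1MoE and overlaps them on separate CUDA streams.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)
        self.moe = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sum the two branches, then divide by sqrt(2) so the combined
        # activation keeps roughly unit variance.
        return (self.mlp(x) + self.moe(x)) / math.sqrt(2)

x = torch.randn(4, 16)
print(ResidualMoEBlock(16)(x).shape)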
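
With this patch, get_tokenizer returns a TiktokenTokenizer whenever the tokenizer path ends in ".json", and XGrammarGrammarBackend prefers the tokenizer's own init_xgrammar() hook when it exists. A rough usage sketch; the file path below is hypothetical:

from sglang.srt.hf_transformers_utils import get_tokenizer

tokenizer = get_tokenizer("/path/to/grok_tokenizer.json")  # ".json" -> TiktokenTokenizer
ids = tokenizer.encode("Hello, world")
print(tokenizer.decode(ids))

# Grammar-constrained decoding: the xgrammar backend checks for this hook and
# falls back to TokenizerInfo.from_huggingface(...) otherwise.
if hasattr(tokenizer, "init_xgrammar"):
    tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()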