Update grok.py and tiktoken tokenizer (#9532)
This commit is contained in:
@@ -162,12 +162,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
|
||||
# This ensures consistency between what the model considers EOS and what XGrammar uses
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||
tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
|
||||
)
|
||||
override_stop_tokens = None
|
||||
if hasattr(tokenizer, "init_xgrammar"):
|
||||
# For special tokenizer
|
||||
tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
|
||||
else:
|
||||
# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
|
||||
# This ensures consistency between what the model considers EOS and what XGrammar uses
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||
tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
|
||||
)
|
||||
override_stop_tokens = None
|
||||
|
||||
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
@@ -263,6 +263,11 @@ def get_tokenizer(
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
"""Gets a tokenizer for the given model name via Huggingface."""
|
||||
if tokenizer_name.endswith(".json"):
|
||||
from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
|
||||
|
||||
return TiktokenTokenizer(tokenizer_name)
|
||||
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
|
||||
@@ -20,6 +20,14 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
|
||||
def logit_capping_mod(logit_capping_method, logit_cap):
|
||||
# positive logit_cap -> tanh cap
|
||||
if logit_capping_method == "tanh":
|
||||
return logit_cap
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardMetadata:
|
||||
attn_logits: torch.Tensor
|
||||
@@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
|
||||
|
||||
causal = True
|
||||
if layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||
causal = False
|
||||
@@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.forward_metadata.mask_indptr,
|
||||
self.forward_metadata.max_extend_len,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
logit_cap=logits_soft_cap,
|
||||
sliding_window_size=sliding_window_size,
|
||||
sinks=sinks,
|
||||
window_kv_offsets=window_kv_offsets,
|
||||
xai_temperature_len=layer.xai_temperature_len,
|
||||
)
|
||||
return o
|
||||
|
||||
@@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
@@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.forward_metadata.num_kv_splits,
|
||||
self.max_kv_splits,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
logit_cap=logits_soft_cap,
|
||||
sinks=sinks,
|
||||
xai_temperature_len=layer.xai_temperature_len,
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
|
||||
logit_cap: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
xai_temperature_len: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
|
||||
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
||||
kv_splits = tl.load(num_kv_splits + cur_batch)
|
||||
|
||||
if xai_temperature_len > 0:
|
||||
offs_qidx = cur_batch_seq_len - 1
|
||||
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
|
||||
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
|
||||
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
|
||||
|
||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||
|
||||
kv_len_per_split = (
|
||||
@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
if xai_temperature_len > 0:
|
||||
qk *= xai_temperature_reg
|
||||
|
||||
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
|
||||
|
||||
offs_buf_v = (
|
||||
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
xai_temperature_len=-1,
|
||||
):
|
||||
BLOCK = 64
|
||||
# [TODO] work around SGPR limit on MI3xx
|
||||
@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
|
||||
BLOCK_N=BLOCK,
|
||||
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||
logit_cap=logit_cap,
|
||||
xai_temperature_len=xai_temperature_len,
|
||||
num_warps=num_warps,
|
||||
num_stages=2,
|
||||
Lk=Lk,
|
||||
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
|
||||
BLOCK_H: tl.constexpr,
|
||||
MIN_BLOCK_KV: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
xai_temperature_len: tl.constexpr,
|
||||
Lk: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
|
||||
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
||||
kv_splits = tl.load(num_kv_splits + cur_batch)
|
||||
|
||||
if xai_temperature_len > 0:
|
||||
offs_qidx = cur_batch_seq_len - 1
|
||||
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
|
||||
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
|
||||
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
|
||||
|
||||
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
if xai_temperature_len > 0:
|
||||
qk *= xai_temperature_reg[:, None]
|
||||
|
||||
qk = tl.where(
|
||||
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
|
||||
)
|
||||
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
xai_temperature_len=-1,
|
||||
):
|
||||
BLOCK = 32
|
||||
Lk = k_buffer.shape[-1]
|
||||
@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
|
||||
BLOCK_H=BLOCK_H,
|
||||
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||
logit_cap=logit_cap,
|
||||
xai_temperature_len=xai_temperature_len,
|
||||
num_warps=4,
|
||||
num_stages=num_stages,
|
||||
Lk=Lk,
|
||||
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sinks=None,
|
||||
xai_temperature_len=-1,
|
||||
):
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
xai_temperature_len,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits,
|
||||
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sinks=None,
|
||||
xai_temperature_len=-1,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
xai_temperature_len,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits,
|
||||
@@ -702,6 +730,7 @@ def decode_attention_fwd(
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sinks=None,
|
||||
xai_temperature_len=-1,
|
||||
):
|
||||
assert max_kv_splits == attn_logits.shape[2]
|
||||
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
||||
@@ -725,6 +754,7 @@ def decode_attention_fwd(
|
||||
sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
sinks=sinks,
|
||||
xai_temperature_len=xai_temperature_len,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
@@ -742,4 +772,5 @@ def decode_attention_fwd(
|
||||
sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
sinks=sinks,
|
||||
xai_temperature_len=xai_temperature_len,
|
||||
)
|
||||
|
||||
@@ -69,6 +69,7 @@ def _fwd_kernel(
|
||||
stride_buf_vh,
|
||||
SLIDING_WINDOW_SIZE: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
xai_temperature_len: tl.constexpr,
|
||||
Lq: tl.constexpr,
|
||||
Lv: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
@@ -109,6 +110,15 @@ def _fwd_kernel(
|
||||
mask_d = offs_d < Lq
|
||||
mask_dv = offs_dv < Lv
|
||||
|
||||
if xai_temperature_len > 0:
|
||||
offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
|
||||
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
|
||||
xai_temperature_reg = tl.where(
|
||||
offs_qidx > xai_temperature_len,
|
||||
tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
|
||||
1.0,
|
||||
)
|
||||
|
||||
offs_q = (
|
||||
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
@@ -203,6 +213,9 @@ def _fwd_kernel(
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
if xai_temperature_len > 0:
|
||||
qk *= xai_temperature_reg[:, None]
|
||||
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
|
||||
row_max = tl.max(qk, 1)
|
||||
@@ -306,6 +319,9 @@ def _fwd_kernel(
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
if xai_temperature_len > 0:
|
||||
qk *= xai_temperature_reg[:, None]
|
||||
|
||||
qk = tl.where(final_mask, qk, float("-inf"))
|
||||
|
||||
row_max = tl.max(qk, 1)
|
||||
@@ -373,6 +389,7 @@ def extend_attention_fwd(
|
||||
sliding_window_size=-1,
|
||||
sinks=None,
|
||||
window_kv_offsets=None,
|
||||
xai_temperature_len=-1,
|
||||
):
|
||||
"""
|
||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
@@ -477,6 +494,7 @@ def extend_attention_fwd(
|
||||
v_buffer.stride(1),
|
||||
SLIDING_WINDOW_SIZE=sliding_window_size,
|
||||
logit_cap=logit_cap,
|
||||
xai_temperature_len=xai_temperature_len,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
|
||||
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
|
||||
return out_hidden_states, out_scales
|
||||
else:
|
||||
return out_hidden_states, None
|
||||
|
||||
|
||||
# silu on first half of vector
|
||||
@triton.jit
|
||||
def silu_and_mul_kernel(
|
||||
out_hidden_states_ptr, # (bs, hidden_dim)
|
||||
out_scales_ptr, # (bs,)
|
||||
hidden_states_ptr, # (bs, hidden_dim * 2)
|
||||
quant_max: tl.constexpr,
|
||||
static_scale: tl.constexpr,
|
||||
hidden_dim: tl.constexpr, # the output hidden_dim
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
input_start = pid * hidden_dim * 2
|
||||
output_start = pid * hidden_dim
|
||||
|
||||
input1_offs = tl.arange(0, BLOCK_SIZE)
|
||||
mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
|
||||
input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
|
||||
output_offs = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
x1 = tl.load(
|
||||
hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
|
||||
).to(tl.float32)
|
||||
x3 = tl.load(
|
||||
hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
|
||||
).to(tl.float32)
|
||||
|
||||
# silu
|
||||
# cast down before mul to better match training?
|
||||
silu_x1 = x1 * tl.sigmoid(x1)
|
||||
out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
|
||||
|
||||
if quant_max is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
|
||||
|
||||
|
||||
def silu_and_mul_triton(
|
||||
hidden_states,
|
||||
scales=None,
|
||||
quantize=None, # dtype to quantize to
|
||||
out=None,
|
||||
):
|
||||
bs, in_hidden_dim = hidden_states.shape
|
||||
hidden_dim = in_hidden_dim // 2
|
||||
|
||||
if out is None:
|
||||
out_hidden_states = torch.empty(
|
||||
(bs, hidden_dim),
|
||||
dtype=quantize or hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
else:
|
||||
assert out.shape == (bs, hidden_dim)
|
||||
assert out.dtype == (quantize or hidden_states.dtype)
|
||||
out_hidden_states = out
|
||||
out_scales = None
|
||||
static_scale = False
|
||||
if quantize is not None:
|
||||
if scales is None:
|
||||
out_scales = torch.empty(
|
||||
(bs,), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
else:
|
||||
out_scales = scales
|
||||
static_scale = True
|
||||
|
||||
max_warps = 16 if _is_hip else 32
|
||||
config = {
|
||||
# 8 ele per thread (not tuned)
|
||||
"num_warps": max(
|
||||
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
|
||||
),
|
||||
}
|
||||
|
||||
silu_and_mul_kernel[(bs,)](
|
||||
out_hidden_states,
|
||||
out_scales,
|
||||
hidden_states,
|
||||
quant_max=torch.finfo(quantize).max if quantize is not None else None,
|
||||
static_scale=static_scale,
|
||||
hidden_dim=hidden_dim,
|
||||
BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
|
||||
**config,
|
||||
)
|
||||
|
||||
if quantize is not None:
|
||||
return out_hidden_states, out_scales
|
||||
else:
|
||||
return out_hidden_states, None
|
||||
|
||||
@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
|
||||
logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
|
||||
|
||||
# logit softcap
|
||||
logits_scaled = logits / moe_softcapping
|
||||
exped = tl.exp(2 * logits_scaled)
|
||||
top = exped - 1
|
||||
bottom = exped + 1
|
||||
logits_softcapped = top / bottom * moe_softcapping
|
||||
if moe_softcapping == 0:
|
||||
logits_softcapped = logits
|
||||
else:
|
||||
logits_scaled = logits / moe_softcapping
|
||||
exped = tl.exp(2 * logits_scaled)
|
||||
top = exped - 1
|
||||
bottom = exped + 1
|
||||
logits_softcapped = top / bottom * moe_softcapping
|
||||
|
||||
# Add bias after softcapping
|
||||
if is_correction_bias:
|
||||
@@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel(
|
||||
b_ptrs += BLOCK_SIZE_K
|
||||
|
||||
# 4. logit softcap
|
||||
logits_scaled = acc / moe_softcapping
|
||||
exped = tl.exp(2 * logits_scaled)
|
||||
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
|
||||
if moe_softcapping == 0:
|
||||
logits_softcapped = acc
|
||||
else:
|
||||
logits_scaled = acc / moe_softcapping
|
||||
exped = tl.exp(2 * logits_scaled)
|
||||
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
|
||||
|
||||
# 5. top1
|
||||
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
|
||||
@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(
|
||||
|
||||
# 7. handle topk == 2
|
||||
if topk == 2:
|
||||
cond_top2 = (arange_block_size_n < num_experts) and (
|
||||
cond_top2 = (arange_block_size_n < num_experts) & (
|
||||
arange_block_size_n != top1[:, None]
|
||||
)
|
||||
top2 = tl.argmax(
|
||||
|
||||
@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
|
||||
v_head_dim: int = -1,
|
||||
sliding_window_size: int = -1,
|
||||
is_cross_attention: bool = False,
|
||||
pos_encoding_mode: str = "NONE",
|
||||
logit_capping_method: str = "tanh",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
|
||||
self.quant_method.create_weights(self)
|
||||
self.attn_type = attn_type
|
||||
|
||||
self.pos_encoding_mode = pos_encoding_mode
|
||||
self.logit_capping_method = logit_capping_method
|
||||
self.xai_temperature_len = -1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q,
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||
"""Inference-only Grok1 model."""
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.srt.layers.elementwise import (
|
||||
experts_combine_triton,
|
||||
fused_dual_residual_rmsnorm,
|
||||
fused_rmsnorm,
|
||||
gelu_and_mul_triton,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.rotary_embedding import (
|
||||
RotaryEmbedding,
|
||||
_yarn_find_correction_range,
|
||||
_yarn_get_mscale,
|
||||
get_rope,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import dump_to_file
|
||||
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Dump tensors for debugging
|
||||
debug_tensor_dump_output_folder = None
|
||||
debug_tensor_dump_prefill_only = False
|
||||
# Skip all the other tensor dumps, only dump the target logits
|
||||
debug_tensor_dump_only_target_logprobs = False
|
||||
debug_tensor_dump_inject = False
|
||||
debug_tensor_dump_layers = None
|
||||
debug_tensor_dump_test = False
|
||||
|
||||
|
||||
class Grok1MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
reduce_results=True,
|
||||
use_presharded_weights: bool = False,
|
||||
split_gate_up: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
reduce_results=reduce_results,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.act_fn = GeluAndMul(approximate="tanh")
|
||||
self.layer_id = layer_id
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x, _ = gelu_and_mul_triton(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Grok1MoE(nn.Module):
|
||||
@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
reduce_results=True,
|
||||
reduce_results: bool = True,
|
||||
use_presharded_weights: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -145,6 +204,135 @@ class Grok1MoE(nn.Module):
|
||||
return self.experts(hidden_states, topk_output)
|
||||
|
||||
|
||||
def _yarn_linear_ramp_mask(
|
||||
low: float, high: float, dim: int, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
if low == high:
|
||||
low -= 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
def get_rope_scaling(config):
|
||||
rope_type = getattr(config, "rope_type", None)
|
||||
if rope_type:
|
||||
original_max_position_embeddings = getattr(
|
||||
config, "original_max_position_embeddings", None
|
||||
)
|
||||
scaling_factor = getattr(config, "scaling_factor", None)
|
||||
extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
|
||||
attn_factor = getattr(config, "attn_factor", 1.0)
|
||||
beta_fast = getattr(config, "beta_fast", 32)
|
||||
beta_slow = getattr(config, "beta_slow", 1)
|
||||
rope_scaling = {
|
||||
"extra_method": rope_type,
|
||||
"max_position_embeddings": original_max_position_embeddings,
|
||||
"scaling_factor": scaling_factor,
|
||||
"extrapolation_factor": extrapolation_factor,
|
||||
"attn_factor": attn_factor,
|
||||
"beta_fast": beta_fast,
|
||||
"beta_slow": beta_slow,
|
||||
"dtype": torch.float,
|
||||
}
|
||||
return rope_scaling
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extra_method: str = "yarn_log",
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extra_method = extra_method
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation
|
||||
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
pos_freqs = self.base ** (
|
||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
|
||||
)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
||||
|
||||
low, high = _yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
self.rotary_dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = (
|
||||
1
|
||||
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
|
||||
) * self.extrapolation_factor
|
||||
if self.extra_method in ["original"]:
|
||||
inv_freq = inv_freq_extrapolation
|
||||
elif self.extra_method in ["yarn", "yarn_linear"]:
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_mask)
|
||||
+ inv_freq_extrapolation * inv_freq_mask
|
||||
)
|
||||
elif self.extra_method == "yarn_log":
|
||||
inv_freq = torch.exp(
|
||||
torch.log(inv_freq_extrapolation) * inv_freq_mask
|
||||
+ torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
|
||||
)
|
||||
elif self.extra_method == "theta_scale":
|
||||
exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
|
||||
theta_scale_exponent = self.base ** (
|
||||
math.log(
|
||||
self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
|
||||
)
|
||||
/ math.log(self.max_position_embeddings / (2 * math.pi))
|
||||
)
|
||||
inv_freq = torch.tensor(
|
||||
1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
t = torch.arange(
|
||||
self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
|
||||
)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
# cos = freqs.cos() * self.mscale
|
||||
# sin = freqs.sin() * self.mscale
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
|
||||
class Grok1Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -157,7 +345,9 @@ class Grok1Attention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
load_presharded_attn: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -183,7 +373,9 @@ class Grok1Attention(nn.Module):
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
rope_scaling = get_rope_scaling(config)
|
||||
self.load_presharded_attn = load_presharded_attn
|
||||
self.alt_stream = alt_stream or torch.cuda.Stream()
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -195,6 +387,7 @@ class Grok1Attention(nn.Module):
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
load_presharded_attn=self.load_presharded_attn,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
@@ -205,6 +398,7 @@ class Grok1Attention(nn.Module):
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
use_presharded_weights=self.load_presharded_attn,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -214,7 +408,37 @@ class Grok1Attention(nn.Module):
|
||||
is_neox_style=True,
|
||||
)
|
||||
|
||||
self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
|
||||
|
||||
if rope_scaling is not None:
|
||||
self.rotary_emb = ScalingRotaryEmbedding(
|
||||
self.head_dim,
|
||||
rotary_dim=(
|
||||
self.head_dim
|
||||
if not self.rope_rotate_half_dims
|
||||
else self.head_dim // 2
|
||||
),
|
||||
base=int(self.rope_theta),
|
||||
is_neox_style=True,
|
||||
**rope_scaling,
|
||||
)
|
||||
pos_encoding_mode = "NONE"
|
||||
else:
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=(
|
||||
self.head_dim
|
||||
if not self.rope_rotate_half_dims
|
||||
else self.head_dim // 2
|
||||
),
|
||||
max_position=max_position,
|
||||
base=int(self.rope_theta),
|
||||
is_neox_style=True,
|
||||
)
|
||||
pos_encoding_mode = "NONE"
|
||||
|
||||
logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
|
||||
logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")
|
||||
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
@@ -224,7 +448,11 @@ class Grok1Attention(nn.Module):
|
||||
layer_id=layer_id,
|
||||
logit_cap=logit_cap,
|
||||
quant_config=quant_config,
|
||||
pos_encoding_mode=pos_encoding_mode,
|
||||
logit_capping_method=logit_capping_method,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -256,6 +484,8 @@ class Grok1Attention(nn.Module):
|
||||
)
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
dispose_tensor(hidden_states)
|
||||
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
@@ -288,6 +518,7 @@ class Grok1Attention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
del q, k, v, qkv
|
||||
|
||||
if debug_tensor_dump_output_folder:
|
||||
dump_to_file(
|
||||
@@ -312,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
|
||||
load_presharded_moe: bool = False,
|
||||
load_presharded_attn: bool = False,
|
||||
load_presharded_mlp: bool = False,
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
skip_moe: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
self.residual_moe = getattr(config, "residual_moe", False)
|
||||
self.layer_id = layer_id
|
||||
self.alt_stream = alt_stream or torch.cuda.Stream()
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = Grok1Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
max_position=(
|
||||
config.context_len
|
||||
if hasattr(config, "context_len")
|
||||
else config.max_position_embeddings
|
||||
),
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
alt_stream=self.alt_stream,
|
||||
load_presharded_attn=load_presharded_attn,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
self.block_sparse_moe = Grok1MoE(
|
||||
config=config,
|
||||
layer_id=layer_id,
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=getattr(
|
||||
config,
|
||||
"moe_intermediate_size",
|
||||
getattr(config, "intermediate_size", None),
|
||||
),
|
||||
quant_config=quant_config,
|
||||
reduce_results=True,
|
||||
use_presharded_weights=load_presharded_moe,
|
||||
inplace=True,
|
||||
no_combine=False, # just a suggestion to not combine topk
|
||||
)
|
||||
|
||||
split_gate_up = not getattr(config, "merge_gate_up", True)
|
||||
if self.num_experts > 0:
|
||||
self.block_sparse_moe = Grok1MoE(
|
||||
config=config,
|
||||
layer_id=layer_id,
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=getattr(
|
||||
config,
|
||||
"moe_intermediate_size",
|
||||
getattr(config, "intermediate_size", None),
|
||||
),
|
||||
quant_config=quant_config,
|
||||
reduce_results=not self.residual_moe,
|
||||
use_presharded_weights=load_presharded_moe,
|
||||
inplace=False, # not self.residual_moe,
|
||||
no_combine=False, # self.residual_moe, # just a suggestion to not combine topk
|
||||
prefix=add_prefix("block_sparse_moe", prefix),
|
||||
)
|
||||
if self.residual_moe:
|
||||
self.mlp = Grok1MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
use_presharded_weights=load_presharded_mlp,
|
||||
layer_id=layer_id,
|
||||
split_gate_up=split_gate_up,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.ffn = self.block_sparse_moe
|
||||
if self.num_experts > 0:
|
||||
if self.residual_moe:
|
||||
# NOTE: self.block_sparse_moe modifies the input in-place,
|
||||
# so we have to call it later. Be aware of any possible related errors.
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
self.ffn = lambda x: tensor_model_parallel_all_reduce(
|
||||
self.moe_with_rmoe(x)
|
||||
)
|
||||
else:
|
||||
self.ffn = self.moe_with_rmoe
|
||||
else:
|
||||
self.ffn = self.block_sparse_moe
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -364,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
deferred_norm: Optional[RMSNorm] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
|
||||
|
||||
hidden_states_original = hidden_states
|
||||
residual_original = residual
|
||||
|
||||
# Self Attention
|
||||
if deferred_norm is not None:
|
||||
assert residual is not None
|
||||
@@ -386,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
if residual_original is not None:
|
||||
dispose_tensor(residual_original)
|
||||
|
||||
dispose_flag = False
|
||||
if residual is not hidden_states_original:
|
||||
dispose_flag = True
|
||||
dispose_tensor(hidden_states_original)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
@@ -403,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
|
||||
self.post_attn_norm.variance_epsilon,
|
||||
)
|
||||
|
||||
if not dispose_flag:
|
||||
dispose_tensor(hidden_states_original)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states = self.ffn(hidden_states)
|
||||
return hidden_states, residual, self.post_moe_norm # defer layernorm
|
||||
|
||||
def moe_with_rmoe(self, x):
|
||||
current_stream = torch.cuda.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
mlp_result = self.mlp(x)
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
# moe should not be inplace because of stream race condition
|
||||
moe_result = self.block_sparse_moe(x)
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
return (mlp_result + moe_result) / 1.4142135623730951
|
||||
|
||||
|
||||
class Grok1Model(nn.Module):
|
||||
def __init__(
|
||||
@@ -417,6 +713,8 @@ class Grok1Model(nn.Module):
|
||||
load_presharded_embedding: bool = False,
|
||||
load_presharded_attn: bool = False,
|
||||
load_presharded_mlp: bool = False,
|
||||
replicate_embedding: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -427,7 +725,11 @@ class Grok1Model(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
use_presharded_weights=load_presharded_embedding,
|
||||
enable_tp=not replicate_embedding,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
|
||||
self.alt_stream = torch.cuda.Stream()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Grok1DecoderLayer(
|
||||
@@ -437,6 +739,7 @@ class Grok1Model(nn.Module):
|
||||
load_presharded_moe=load_presharded_moe,
|
||||
load_presharded_attn=load_presharded_attn,
|
||||
load_presharded_mlp=load_presharded_mlp,
|
||||
alt_stream=self.alt_stream,
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -506,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -514,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
|
||||
# Get presharded weights.
|
||||
self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
|
||||
self.load_presharded_moe = (
|
||||
self.config.num_local_experts > 0
|
||||
getattr(config, "load_presharded_moe", True)
|
||||
and self.config.num_local_experts > 0
|
||||
and get_tensor_model_parallel_world_size() > 1
|
||||
)
|
||||
self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
|
||||
@@ -529,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
|
||||
or self.load_presharded_embedding
|
||||
)
|
||||
|
||||
default_replicate_lm_head = False
|
||||
self.replicate_lm_head = getattr(
|
||||
config, "replicate_lm_head", default_replicate_lm_head
|
||||
)
|
||||
|
||||
if self.is_weights_presharded:
|
||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||
|
||||
@@ -536,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
self.replicate_lm_head = getattr(
|
||||
config, "replicate_lm_head", default_replicate_lm_head
|
||||
)
|
||||
self.replicate_embedding = getattr(config, "replicate_embedding", False)
|
||||
|
||||
self.model = Grok1Model(
|
||||
config,
|
||||
@@ -544,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
|
||||
load_presharded_embedding=self.load_presharded_embedding,
|
||||
load_presharded_attn=self.load_presharded_attn,
|
||||
load_presharded_mlp=self.load_presharded_mlp,
|
||||
replicate_embedding=self.replicate_embedding,
|
||||
prefix=add_prefix("model", prefix),
|
||||
)
|
||||
|
||||
lm_head_params_dtype = None
|
||||
@@ -553,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
params_dtype=lm_head_params_dtype,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||
else:
|
||||
@@ -561,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
use_presharded_weights=self.load_presharded_embedding,
|
||||
params_dtype=lm_head_params_dtype,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@@ -577,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
|
||||
f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
|
||||
)
|
||||
self.loaded_param_names = set()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -596,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
num_experts: Optional[int] = None,
|
||||
ignore_parent_name: bool = False,
|
||||
check_hit_names: bool = True,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if num_experts is None:
|
||||
num_experts = self.config.num_local_experts
|
||||
if model_config is None:
|
||||
model_config = self.config
|
||||
|
||||
stacked_params_mapping = []
|
||||
stacked_params_mapping += [
|
||||
# (param_name, shard_name, shard_id)
|
||||
@@ -616,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
num_experts = model_config.num_local_experts
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="w1",
|
||||
ckpt_down_proj_name="w2",
|
||||
@@ -630,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
|
||||
def load_weight_wrapper(
|
||||
name: str, loaded_weight: torch.Tensor, *args, **kwargs
|
||||
):
|
||||
if ignore_parent_name:
|
||||
name = name.split(".")[-1]
|
||||
|
||||
if name not in params_dict:
|
||||
return
|
||||
|
||||
# Fuse constant multipliers into the weights
|
||||
if "lm_head" in name:
|
||||
loaded_weight = (
|
||||
loaded_weight.to(torch.float32)
|
||||
* self.config.output_multiplier_scale
|
||||
* model_config.output_multiplier_scale
|
||||
)
|
||||
|
||||
original_name = name
|
||||
if ignore_parent_name:
|
||||
name = name.split(".")[-1]
|
||||
|
||||
if name not in params_dict:
|
||||
logger.info(f"Skipping {name=} in load_weights_wrapper")
|
||||
return
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight, *args, **kwargs)
|
||||
hit_names.add(name)
|
||||
self.loaded_param_names.add(original_name)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
@@ -685,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module):
|
||||
|
||||
load_weight_wrapper(name=name, loaded_weight=loaded_weight)
|
||||
|
||||
if len(hit_names) > 5:
|
||||
missing = all_names - hit_names
|
||||
missing_exclude_scales = {x for x in missing if "scale" not in x}
|
||||
logger.info(
|
||||
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
|
||||
)
|
||||
if len(missing_exclude_scales) > 0:
|
||||
raise ValueError(
|
||||
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
|
||||
if check_hit_names:
|
||||
if len(hit_names) > 5:
|
||||
missing = all_names - hit_names
|
||||
missing_exclude_scales = {x for x in missing if "scale" not in x}
|
||||
logger.info(
|
||||
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
|
||||
)
|
||||
if len(missing_exclude_scales) > 0:
|
||||
raise ValueError(
|
||||
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
|
||||
)
|
||||
|
||||
elif len(hit_names) == 0:
|
||||
raise ValueError("load_weights failed because it did not hit any names.")
|
||||
elif len(hit_names) == 0:
|
||||
raise ValueError(
|
||||
f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
|
||||
)
|
||||
|
||||
return hit_names
|
||||
|
||||
@@ -708,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
|
||||
"moe_intermediate_size",
|
||||
getattr(cfg, "intermediate_size", None),
|
||||
)
|
||||
num_experts = cfg.num_local_experts
|
||||
residual_moe = getattr(cfg, "residual_moe", False)
|
||||
if cfg.num_local_experts > 0:
|
||||
num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
|
||||
else:
|
||||
num_experts = 1
|
||||
|
||||
wq = (
|
||||
cfg.num_hidden_layers
|
||||
|
||||
161
python/sglang/srt/tokenizer/tiktoken_tokenizer.py
Normal file
161
python/sglang/srt/tokenizer/tiktoken_tokenizer.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import functools
|
||||
import json
|
||||
from typing import AbstractSet, Collection, List, Literal, Union
|
||||
|
||||
|
||||
class TiktokenProcessor:
|
||||
def __init__(self, name: str):
|
||||
self.tokenizer = TiktokenTokenizer(name)
|
||||
|
||||
def image_processor(self, image):
|
||||
return {"pixel_values": [image]}
|
||||
|
||||
|
||||
RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
|
||||
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
|
||||
|
||||
|
||||
PAD = "<|pad|>"
|
||||
EOS = "<|eos|>"
|
||||
SEP = "<|separator|>"
|
||||
|
||||
DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
|
||||
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
|
||||
|
||||
# default + separate each single digit
|
||||
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||
|
||||
|
||||
class TiktokenTokenizer:
|
||||
def __init__(self, tokenizer_path):
|
||||
import tiktoken
|
||||
from jinja2 import Template
|
||||
|
||||
# Read the JSON
|
||||
with open(tokenizer_path, "rb") as fin:
|
||||
xtok_dict = json.load(fin)
|
||||
|
||||
# Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
|
||||
mergeable_ranks = {
|
||||
bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
|
||||
}
|
||||
special_tokens = {
|
||||
bytes(item["bytes"]).decode(): item["token"]
|
||||
for item in xtok_dict["special_tokens"]
|
||||
}
|
||||
if xtok_dict["word_split"] == "V1":
|
||||
pad_str = PAT_STR_B
|
||||
else:
|
||||
assert False, f"Unknown word_split: {xtok_dict['word_split']}"
|
||||
pad_str = xtok_dict.get("pat_str", pad_str)
|
||||
|
||||
kwargs = {
|
||||
"name": tokenizer_path,
|
||||
"pat_str": pad_str,
|
||||
"mergeable_ranks": mergeable_ranks,
|
||||
"special_tokens": special_tokens,
|
||||
}
|
||||
if "default_allowed_special" in xtok_dict:
|
||||
default_allowed_special = set(
|
||||
[
|
||||
bytes(bytes_list).decode()
|
||||
for bytes_list in xtok_dict["default_allowed_special"]
|
||||
]
|
||||
)
|
||||
if "vocab_size" in xtok_dict:
|
||||
kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]
|
||||
|
||||
# Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
|
||||
default_allowed_special = None
|
||||
control_tokens = DEFAULT_CONTROL_TOKENS
|
||||
tokenizer = tiktoken.Encoding(**kwargs)
|
||||
tokenizer._default_allowed_special = default_allowed_special or set()
|
||||
tokenizer._control_tokens = control_tokens
|
||||
|
||||
def encode_patched(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
allowed_special: Union[
|
||||
Literal["all"], AbstractSet[str]
|
||||
] = set(), # noqa: B006
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
) -> List[int]:
|
||||
if isinstance(allowed_special, set):
|
||||
allowed_special |= self._default_allowed_special
|
||||
return tiktoken.Encoding.encode(
|
||||
self,
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=(),
|
||||
)
|
||||
|
||||
tokenizer.encode = functools.partial(encode_patched, tokenizer)
|
||||
|
||||
# Allow more tokens to prevent crash
|
||||
tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
|
||||
tokenizer._default_allowed_special |= set(
|
||||
CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
|
||||
)
|
||||
|
||||
# Convert to HF interface
|
||||
self.tokenizer = tokenizer
|
||||
self.bos_token_id = None
|
||||
self.eos_token_id = tokenizer._special_tokens[EOS]
|
||||
self.vocab_size = tokenizer.n_vocab
|
||||
self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
|
||||
self.chat_template_jinja = Template(self.chat_template)
|
||||
self.additional_stop_token_ids = None
|
||||
|
||||
def encode(self, x, add_special_tokens=False):
|
||||
return self.tokenizer.encode(x)
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return self.tokenizer.decode(x)
|
||||
|
||||
def batch_decode(
|
||||
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
|
||||
):
|
||||
if len(batch) > 0 and isinstance(batch[0], int):
|
||||
batch = [[x] for x in batch]
|
||||
return self.tokenizer.decode_batch(batch)
|
||||
|
||||
def apply_chat_template(
|
||||
self, messages, tokenize, add_generation_prompt, tools=None
|
||||
):
|
||||
ret = self.chat_template_jinja.render(
|
||||
messages=messages, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return self.encode(ret) if tokenize else ret
|
||||
|
||||
def __call__(self, text, **kwargs):
|
||||
return {
|
||||
"input_ids": self.encode(text),
|
||||
}
|
||||
|
||||
def init_xgrammar(self):
|
||||
from xgrammar import TokenizerInfo
|
||||
|
||||
XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"
|
||||
|
||||
enc = self.tokenizer
|
||||
encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
|
||||
encoded_vocab = [
|
||||
token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
|
||||
]
|
||||
override_stop_tokens = [2] # eos
|
||||
# These are treated as special tokens in xgrammar; we want to avoid them
|
||||
# For now, xgrammar treats anything starting with b'\x00' as a special token
|
||||
xgrammar_special_token_ids = []
|
||||
for i, token in enumerate(encoded_vocab):
|
||||
if isinstance(token, bytes) and token.startswith(b"\x00"):
|
||||
xgrammar_special_token_ids.append(i)
|
||||
|
||||
for i, id in enumerate(xgrammar_special_token_ids):
|
||||
encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)
|
||||
tokenizer_info = TokenizerInfo(
|
||||
encoded_vocab, stop_token_ids=override_stop_tokens
|
||||
)
|
||||
assert len(tokenizer_info.special_token_ids) == 0
|
||||
|
||||
return tokenizer_info, override_stop_tokens
|
||||
Reference in New Issue
Block a user