Update grok.py and tiktoken tokenizer (#9532)
python/sglang/srt/constrained/xgrammar_backend.py
@@ -162,6 +162,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
+        if hasattr(tokenizer, "init_xgrammar"):
+            # For special tokenizer
+            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
+        else:
             # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
             # This ensures consistency between what the model considers EOS and what XGrammar uses
             tokenizer_info = TokenizerInfo.from_huggingface(
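Note on the branch above: the backend duck-types the tokenizer. Any tokenizer exposing an init_xgrammar() method that returns a (TokenizerInfo, override_stop_tokens) pair is used as-is; the concrete implementation added by this commit lives in TiktokenTokenizer.init_xgrammar() below. A minimal sketch of the expected shape (the class name and toy vocabulary are illustrative only):

from xgrammar import TokenizerInfo

class MySpecialTokenizer:  # hypothetical tokenizer opting into the hook
    def init_xgrammar(self):
        # token strings ordered by id; real tokenizers build this from their vocab
        encoded_vocab = ["<|pad|>", "<|separator|>", "<|eos|>", "a", "b"]
        override_stop_tokens = [2]  # ids XGrammar should treat as stop tokens
        tokenizer_info = TokenizerInfo(encoded_vocab, stop_token_ids=override_stop_tokens)
        return tokenizer_info, override_stop_tokens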
python/sglang/srt/hf_transformers_utils.py
@@ -263,6 +263,11 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
+    if tokenizer_name.endswith(".json"):
+        from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
+
+        return TiktokenTokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
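With this hunk, any tokenizer_name ending in .json is routed to the new tiktoken wrapper instead of the Huggingface AutoTokenizer path. A usage sketch (the file path is hypothetical):

from sglang.srt.hf_transformers_utils import get_tokenizer

tokenizer = get_tokenizer("/models/grok/tokenizer.json")  # returns TiktokenTokenizer
ids = tokenizer.encode("Hello world")
text = tokenizer.decode(ids)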
python/sglang/srt/layers/attention/triton_backend.py
@@ -20,6 +20,14 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+def logit_capping_mod(logit_capping_method, logit_cap):
+    # positive logit_cap -> tanh cap
+    if logit_capping_method == "tanh":
+        return logit_cap
+    else:
+        raise ValueError()
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
@@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         causal = True
         if layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False
@@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
             window_kv_offsets=window_kv_offsets,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o
 
@@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v
@@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sinks=sinks,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o
 
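logit_capping_mod only validates the capping method and passes the cap value through; the capping itself is applied inside the Triton kernels as logit_cap * tanh(qk / logit_cap). A small NumPy sketch of that effect, for reference:

import numpy as np

def tanh_softcap(x, cap):
    # Smoothly bounds logits to (-cap, cap); near zero the mapping is ~identity.
    return cap * np.tanh(x / cap)

print(tanh_softcap(np.array([-100.0, 0.0, 5.0, 100.0]), 30.0))
# -> approximately [-29.9, 0.0, 4.95, 29.9]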
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
     kv_len_per_split = (
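The xai_temperature_reg computed above rescales attention logits once a query position passes xai_temperature_len: the multiplier is log2(position) / log2(xai_temperature_len), and 1.0 before that point. A plain NumPy sketch of the schedule (values are illustrative):

import numpy as np

def xai_temperature_reg(position, xai_temperature_len):
    # Mirrors the kernel: 1.0 up to the threshold, then log2(pos) / log2(threshold).
    scale = 1.0 / np.log2(float(xai_temperature_len))
    return np.where(
        position > xai_temperature_len,
        np.log2(position.astype(np.float32)) * scale,
        1.0,
    )

pos = np.array([10, 1024, 4096, 65536])
print(xai_temperature_reg(pos, 1024))  # -> [1.0, 1.0, 1.2, 1.6]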
@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)
 
+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg
+
             qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
 
             offs_buf_v = (
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
    max_kv_splits,
    sm_scale,
    logit_cap,
+   xai_temperature_len=-1,
 ):
    BLOCK = 64
    # [TODO] work around SGPR limit on MI3xx
@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
        BLOCK_N=BLOCK,
        MIN_BLOCK_KV=_MIN_BLOCK_KV,
        logit_cap=logit_cap,
+       xai_temperature_len=xai_temperature_len,
        num_warps=num_warps,
        num_stages=2,
        Lk=Lk,
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_H: tl.constexpr,
     MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
 ):
@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
 
     if BLOCK_DPE > 0:
@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)
 
+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg[:, None]
+
             qk = tl.where(
                 mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
             )
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
    max_kv_splits,
    sm_scale,
    logit_cap,
+   xai_temperature_len=-1,
 ):
    BLOCK = 32
    Lk = k_buffer.shape[-1]
@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
        BLOCK_H=BLOCK_H,
        MIN_BLOCK_KV=_MIN_BLOCK_KV,
        logit_cap=logit_cap,
+       xai_temperature_len=xai_temperature_len,
        num_warps=4,
        num_stages=num_stages,
        Lk=Lk,
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
    sm_scale,
    logit_cap=0.0,
    sinks=None,
+   xai_temperature_len=-1,
 ):
    _decode_att_m_fwd(
        q,
@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
        max_kv_splits,
        sm_scale,
        logit_cap,
+       xai_temperature_len,
    )
    _decode_softmax_reducev_fwd(
        attn_logits,
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
    sm_scale,
    logit_cap=0.0,
    sinks=None,
+   xai_temperature_len=-1,
 ):
    _decode_grouped_att_m_fwd(
        q,
@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
        max_kv_splits,
        sm_scale,
        logit_cap,
+       xai_temperature_len,
    )
    _decode_softmax_reducev_fwd(
        attn_logits,
@@ -702,6 +730,7 @@ def decode_attention_fwd(
    sm_scale,
    logit_cap=0.0,
    sinks=None,
+   xai_temperature_len=-1,
 ):
    assert max_kv_splits == attn_logits.shape[2]
    assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -725,6 +754,7 @@ def decode_attention_fwd(
            sm_scale,
            logit_cap=logit_cap,
            sinks=sinks,
+           xai_temperature_len=xai_temperature_len,
        )
    else:
        # GQA/MQA/MLA
@@ -742,4 +772,5 @@ def decode_attention_fwd(
            sm_scale,
            logit_cap=logit_cap,
            sinks=sinks,
+           xai_temperature_len=xai_temperature_len,
        )
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -69,6 +69,7 @@ def _fwd_kernel(
     stride_buf_vh,
     SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
@@ -109,6 +110,15 @@ def _fwd_kernel(
     mask_d = offs_d < Lq
     mask_dv = offs_dv < Lv
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        xai_temperature_reg = tl.where(
+            offs_qidx > xai_temperature_len,
+            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
+            1.0,
+        )
+
     offs_q = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -203,6 +213,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))
 
         row_max = tl.max(qk, 1)
@@ -306,6 +319,9 @@ def _fwd_kernel(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)
 
+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg[:, None]
+
             qk = tl.where(final_mask, qk, float("-inf"))
 
             row_max = tl.max(qk, 1)
@@ -373,6 +389,7 @@ def extend_attention_fwd(
    sliding_window_size=-1,
    sinks=None,
    window_kv_offsets=None,
+   xai_temperature_len=-1,
 ):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -477,6 +494,7 @@ def extend_attention_fwd(
        v_buffer.stride(1),
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
+       xai_temperature_len=xai_temperature_len,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
python/sglang/srt/layers/elementwise.py
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
         return out_hidden_states, out_scales
     else:
         return out_hidden_states, None
+
+
+# silu on first half of vector
+@triton.jit
+def silu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # silu
+    # cast down before mul to better match training?
+    silu_x1 = x1 * tl.sigmoid(x1)
+    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def silu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    max_warps = 16 if _is_hip else 32
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+        ),
+    }
+
+    silu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
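silu_and_mul_triton mirrors the calling convention of gelu_and_mul_triton above: the input packs both halves along the last dimension, and the kernel returns silu(first_half) * second_half plus an optional quantization scale. A usage sketch (assumes a CUDA device; the quantize= path still raises NotImplementedError in the kernel):

import torch
from sglang.srt.layers.elementwise import silu_and_mul_triton

x = torch.randn(4, 2 * 1024, device="cuda", dtype=torch.bfloat16)  # (bs, 2 * hidden_dim)
out, scales = silu_and_mul_triton(x)  # out: (4, 1024); scales is None when quantize is unset

# reference computation on the two halves
ref = torch.nn.functional.silu(x[:, :1024].float()) * x[:, 1024:].float()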
python/sglang/srt/layers/moe/router.py
@@ -45,6 +45,9 @@ def fused_moe_router_kernel(
     logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
 
     # logit softcap
+    if moe_softcapping == 0:
+        logits_softcapped = logits
+    else:
         logits_scaled = logits / moe_softcapping
         exped = tl.exp(2 * logits_scaled)
         top = exped - 1
@@ -207,6 +210,9 @@ def fused_moe_router_large_bs_kernel(
         b_ptrs += BLOCK_SIZE_K
 
     # 4. logit softcap
+    if moe_softcapping == 0:
+        logits_softcapped = acc
+    else:
         logits_scaled = acc / moe_softcapping
         exped = tl.exp(2 * logits_scaled)
         logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(
 
     # 7. handle topk == 2
     if topk == 2:
-        cond_top2 = (arange_block_size_n < num_experts) and (
+        cond_top2 = (arange_block_size_n < num_experts) & (
             arange_block_size_n != top1[:, None]
         )
         top2 = tl.argmax(
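The added branch skips softcapping entirely when moe_softcapping == 0. When it is enabled, the kernels compute tanh through exponentials; the identity behind the logits_scaled/exped lines, as a quick check:

import math

def softcap(x, cap):
    # cap * tanh(x / cap), written as (e^(2x/cap) - 1) / (e^(2x/cap) + 1) * cap
    exped = math.exp(2 * (x / cap))
    return (exped - 1) / (exped + 1) * cap

assert abs(softcap(3.0, 30.0) - 30.0 * math.tanh(3.0 / 30.0)) < 1e-9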
python/sglang/srt/layers/radix_attention.py
@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        pos_encoding_mode: str = "NONE",
+        logit_capping_method: str = "tanh",
         quant_config: Optional[QuantizationConfig] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
             self.quant_method.create_weights(self)
         self.attn_type = attn_type
 
+        self.pos_encoding_mode = pos_encoding_mode
+        self.logit_capping_method = logit_capping_method
+        self.xai_temperature_len = -1
+
     def forward(
         self,
         q,
python/sglang/srt/models/grok.py
@@ -16,7 +16,6 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 import functools
-import json
 import logging
 import math
 import os
@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
+from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.elementwise import (
+    experts_combine_triton,
+    fused_dual_residual_rmsnorm,
+    fused_rmsnorm,
+    gelu_and_mul_triton,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import (
+    RotaryEmbedding,
+    _yarn_find_correction_range,
+    _yarn_get_mscale,
+    get_rope,
+)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
 
 logger = logging.getLogger(__name__)
 
 
+# Dump tensors for debugging
 debug_tensor_dump_output_folder = None
+debug_tensor_dump_prefill_only = False
+# Skip all the other tensor dumps, only dump the target logits
+debug_tensor_dump_only_target_logprobs = False
 debug_tensor_dump_inject = False
+debug_tensor_dump_layers = None
+debug_tensor_dump_test = False
+
+
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+        use_presharded_weights: bool = False,
+        split_gate_up: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+        self.layer_id = layer_id
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x, _ = gelu_and_mul_triton(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class Grok1MoE(nn.Module):
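Grok1MLP computes both the gate and up projections in one fused matmul, then applies the gating with gelu_and_mul_triton. A plain-PyTorch equivalent of the forward pass, assuming the fused kernel follows the same first-half-gate / second-half-up convention as the silu kernel added in elementwise.py (tensor parallelism omitted):

import torch
import torch.nn.functional as F

def grok1_mlp_reference(x, w_gate, w_up, w_down):
    gate = x @ w_gate.T  # first half of gate_up_proj
    up = x @ w_up.T      # second half of gate_up_proj
    return (F.gelu(gate, approximate="tanh") * up) @ w_down.T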
@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
-        reduce_results=True,
+        reduce_results: bool = True,
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -145,6 +204,135 @@ class Grok1MoE(nn.Module):
         return self.experts(hidden_states, topk_output)
 
 
+def _yarn_linear_ramp_mask(
+    low: float, high: float, dim: int, dtype: torch.dtype
+) -> torch.Tensor:
+    if low == high:
+        low -= 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def get_rope_scaling(config):
+    rope_type = getattr(config, "rope_type", None)
+    if rope_type:
+        original_max_position_embeddings = getattr(
+            config, "original_max_position_embeddings", None
+        )
+        scaling_factor = getattr(config, "scaling_factor", None)
+        extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
+        attn_factor = getattr(config, "attn_factor", 1.0)
+        beta_fast = getattr(config, "beta_fast", 32)
+        beta_slow = getattr(config, "beta_slow", 1)
+        rope_scaling = {
+            "extra_method": rope_type,
+            "max_position_embeddings": original_max_position_embeddings,
+            "scaling_factor": scaling_factor,
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow,
+            "dtype": torch.float,
+        }
+        return rope_scaling
+    else:
+        return None
+
+
+class ScalingRotaryEmbedding(RotaryEmbedding):
+    """Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extra_method: str = "yarn_log",
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extra_method = extra_method
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base ** (
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+        )
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            self.rotary_dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1
+            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+        ) * self.extrapolation_factor
+        if self.extra_method in ["original"]:
+            inv_freq = inv_freq_extrapolation
+        elif self.extra_method in ["yarn", "yarn_linear"]:
+            inv_freq = (
+                inv_freq_interpolation * (1 - inv_freq_mask)
+                + inv_freq_extrapolation * inv_freq_mask
+            )
+        elif self.extra_method == "yarn_log":
+            inv_freq = torch.exp(
+                torch.log(inv_freq_extrapolation) * inv_freq_mask
+                + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
+            )
+        elif self.extra_method == "theta_scale":
+            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+            theta_scale_exponent = self.base ** (
+                math.log(
+                    self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
+                )
+                / math.log(self.max_position_embeddings / (2 * math.pi))
+            )
+            inv_freq = torch.tensor(
+                1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
+                dtype=torch.float32,
+            )
+        else:
+            raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(
+            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
+        )
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        # cos = freqs.cos() * self.mscale
+        # sin = freqs.sin() * self.mscale
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class Grok1Attention(nn.Module):
     def __init__(
         self,
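ScalingRotaryEmbedding supports several extrapolation methods; the yarn_log default blends the interpolated and extrapolated inverse frequencies in log space. A small sketch of just the blending step, factored out of _compute_inv_freq (tensor arguments are assumed positive):

import torch

def blend_inv_freq(inv_freq_extra, inv_freq_inter, inv_freq_mask, method="yarn_log"):
    # inv_freq_mask -> 1 favors extrapolation, 0 favors interpolation
    if method in ("yarn", "yarn_linear"):
        return inv_freq_inter * (1 - inv_freq_mask) + inv_freq_extra * inv_freq_mask
    if method == "yarn_log":
        # geometric (log-space) interpolation between the two frequency sets
        return torch.exp(
            torch.log(inv_freq_extra) * inv_freq_mask
            + torch.log(inv_freq_inter) * (1.0 - inv_freq_mask)
        )
    raise ValueError(f"Unknown extrapolation method: {method}")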
@@ -157,7 +345,9 @@ class Grok1Attention(nn.Module):
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         load_presharded_attn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -183,7 +373,9 @@ class Grok1Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        rope_scaling = get_rope_scaling(config)
         self.load_presharded_attn = load_presharded_attn
+        self.alt_stream = alt_stream or torch.cuda.Stream()
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -195,6 +387,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             load_presharded_attn=self.load_presharded_attn,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -205,6 +398,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             use_presharded_weights=self.load_presharded_attn,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -214,7 +408,37 @@ class Grok1Attention(nn.Module):
             is_neox_style=True,
         )
+
+        self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
+
+        if rope_scaling is not None:
+            self.rotary_emb = ScalingRotaryEmbedding(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                base=int(self.rope_theta),
+                is_neox_style=True,
+                **rope_scaling,
+            )
+            pos_encoding_mode = "NONE"
+        else:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                max_position=max_position,
+                base=int(self.rope_theta),
+                is_neox_style=True,
+            )
+            pos_encoding_mode = "NONE"
 
         logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+        logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")
+
         self.attn = RadixAttention(
             self.num_heads,
@@ -224,7 +448,11 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
             quant_config=quant_config,
+            pos_encoding_mode=pos_encoding_mode,
+            logit_capping_method=logit_capping_method,
+            prefix=add_prefix("attn", prefix),
         )
+        self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)
 
     def forward(
         self,
@@ -256,6 +484,8 @@ class Grok1Attention(nn.Module):
             )
 
         qkv, _ = self.qkv_proj(hidden_states)
+        dispose_tensor(hidden_states)
+
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
 
@@ -288,6 +518,7 @@ class Grok1Attention(nn.Module):
             )
 
         attn_output = self.attn(q, k, v, forward_batch)
+        del q, k, v, qkv
 
         if debug_tensor_dump_output_folder:
             dump_to_file(
@@ -312,25 +543,39 @@ class Grok1DecoderLayer(nn.Module):
         load_presharded_moe: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_moe: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.residual_moe = getattr(config, "residual_moe", False)
         self.layer_id = layer_id
+        self.alt_stream = alt_stream or torch.cuda.Stream()
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            max_position=config.max_position_embeddings,
+            max_position=(
+                config.context_len
+                if hasattr(config, "context_len")
+                else config.max_position_embeddings
+            ),
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
             reduce_results=False,
+            alt_stream=self.alt_stream,
             load_presharded_attn=load_presharded_attn,
+            prefix=add_prefix("attn", prefix),
         )
+
+        split_gate_up = not getattr(config, "merge_gate_up", True)
+        if self.num_experts > 0:
             self.block_sparse_moe = Grok1MoE(
                 config=config,
                 layer_id=layer_id,
@@ -343,18 +588,44 @@ class Grok1DecoderLayer(nn.Module):
                     getattr(config, "intermediate_size", None),
                 ),
                 quant_config=quant_config,
-                reduce_results=True,
+                reduce_results=not self.residual_moe,
                 use_presharded_weights=load_presharded_moe,
-                inplace=True,
-                no_combine=False,  # just a suggestion to not combine topk
+                inplace=False,  # not self.residual_moe,
+                no_combine=False,  # self.residual_moe, # just a suggestion to not combine topk
+                prefix=add_prefix("block_sparse_moe", prefix),
             )
+            if self.residual_moe:
+                self.mlp = Grok1MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.intermediate_size,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    use_presharded_weights=load_presharded_mlp,
+                    layer_id=layer_id,
+                    split_gate_up=split_gate_up,
+                )
+        else:
+            raise NotImplementedError()
 
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        if self.num_experts > 0:
+            if self.residual_moe:
+                # NOTE: self.block_sparse_moe modifies the input in-place,
+                # so we have to call it later. Be aware of any possible related errors.
+                if get_tensor_model_parallel_world_size() > 1:
+                    self.ffn = lambda x: tensor_model_parallel_all_reduce(
+                        self.moe_with_rmoe(x)
+                    )
+                else:
+                    self.ffn = self.moe_with_rmoe
+            else:
                 self.ffn = self.block_sparse_moe
+        else:
+            raise NotImplementedError()
 
     def forward(
         self,
@@ -364,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor] = None,
         deferred_norm: Optional[RMSNorm] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
+
+        hidden_states_original = hidden_states
+        residual_original = residual
+
         # Self Attention
         if deferred_norm is not None:
             assert residual is not None
@@ -386,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
                 hidden_states,
             )
 
+            if residual_original is not None:
+                dispose_tensor(residual_original)
+
+        dispose_flag = False
+        if residual is not hidden_states_original:
+            dispose_flag = True
+            dispose_tensor(hidden_states_original)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -403,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
             self.post_attn_norm.variance_epsilon,
         )
 
+        if not dispose_flag:
+            dispose_tensor(hidden_states_original)
+
         # Fully Connected
         hidden_states = self.ffn(hidden_states)
         return hidden_states, residual, self.post_moe_norm  # defer layernorm
 
+    def moe_with_rmoe(self, x):
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        mlp_result = self.mlp(x)
+        with torch.cuda.stream(self.alt_stream):
+            # moe should not be inplace because of stream race condition
+            moe_result = self.block_sparse_moe(x)
+        current_stream.wait_stream(self.alt_stream)
+        return (mlp_result + moe_result) / 1.4142135623730951
+
 
 class Grok1Model(nn.Module):
     def __init__(
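moe_with_rmoe overlaps the dense MLP and the MoE on two CUDA streams and averages the results with 1/sqrt(2) (the 1.4142... constant). A stripped-down sketch of the stream choreography with generic branches:

import torch

def run_two_branches(x, branch_a, branch_b, alt_stream):
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)          # alt stream sees x fully written
    a = branch_a(x)                          # runs on the current stream
    with torch.cuda.stream(alt_stream):
        b = branch_b(x)                      # must not mutate x in place
    current.wait_stream(alt_stream)          # join before combining
    return (a + b) / (2 ** 0.5)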
@@ -417,6 +713,8 @@ class Grok1Model(nn.Module):
         load_presharded_embedding: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        replicate_embedding: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -427,7 +725,11 @@ class Grok1Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             use_presharded_weights=load_presharded_embedding,
+            enable_tp=not replicate_embedding,
+            prefix=add_prefix("embed_tokens", prefix),
         )
+
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 Grok1DecoderLayer(
@@ -437,6 +739,7 @@ class Grok1Model(nn.Module):
                     load_presharded_moe=load_presharded_moe,
                     load_presharded_attn=load_presharded_attn,
                     load_presharded_mlp=load_presharded_mlp,
+                    alt_stream=self.alt_stream,
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -506,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -514,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
         # Get presharded weights.
         self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
         self.load_presharded_moe = (
-            self.config.num_local_experts > 0
+            getattr(config, "load_presharded_moe", True)
+            and self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
         )
         self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
@@ -529,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
             or self.load_presharded_embedding
         )
 
+        default_replicate_lm_head = False
+        self.replicate_lm_head = getattr(
+            config, "replicate_lm_head", default_replicate_lm_head
+        )
+
         if self.is_weights_presharded:
             setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
 
@@ -536,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
         self.replicate_lm_head = getattr(
             config, "replicate_lm_head", default_replicate_lm_head
         )
+        self.replicate_embedding = getattr(config, "replicate_embedding", False)
 
         self.model = Grok1Model(
             config,
@@ -544,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
             load_presharded_embedding=self.load_presharded_embedding,
             load_presharded_attn=self.load_presharded_attn,
             load_presharded_mlp=self.load_presharded_mlp,
+            replicate_embedding=self.replicate_embedding,
+            prefix=add_prefix("model", prefix),
         )
 
         lm_head_params_dtype = None
@@ -553,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.vocab_size,
                 bias=False,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
@@ -561,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.hidden_size,
                 use_presharded_weights=self.load_presharded_embedding,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)
 
@@ -577,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
             f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
             f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
         )
+        self.loaded_param_names = set()
 
     def forward(
         self,
@@ -596,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        num_experts: Optional[int] = None,
        ignore_parent_name: bool = False,
+        check_hit_names: bool = True,
+        model_config: PretrainedConfig | None = None,
     ) -> dict[str, torch.Tensor]:
-        if num_experts is None:
-            num_experts = self.config.num_local_experts
+        if model_config is None:
+            model_config = self.config
 
         stacked_params_mapping = []
         stacked_params_mapping += [
             # (param_name, shard_name, shard_id)
@@ -616,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
+        num_experts = model_config.num_local_experts
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
@@ -630,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
         def load_weight_wrapper(
             name: str, loaded_weight: torch.Tensor, *args, **kwargs
         ):
-            if ignore_parent_name:
-                name = name.split(".")[-1]
-
-            if name not in params_dict:
-                return
-
             # Fuse constant multipliers into the weights
             if "lm_head" in name:
                 loaded_weight = (
                     loaded_weight.to(torch.float32)
-                    * self.config.output_multiplier_scale
+                    * model_config.output_multiplier_scale
                 )
 
+            original_name = name
+            if ignore_parent_name:
+                name = name.split(".")[-1]
+
+            if name not in params_dict:
+                logger.info(f"Skipping {name=} in load_weights_wrapper")
+                return
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight, *args, **kwargs)
             hit_names.add(name)
+            self.loaded_param_names.add(original_name)
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -685,6 +1007,7 @@ class Grok1ForCausalLM(nn.Module):
 
             load_weight_wrapper(name=name, loaded_weight=loaded_weight)
 
+        if check_hit_names:
             if len(hit_names) > 5:
                 missing = all_names - hit_names
                 missing_exclude_scales = {x for x in missing if "scale" not in x}
@@ -697,7 +1020,9 @@ class Grok1ForCausalLM(nn.Module):
                 )
 
             elif len(hit_names) == 0:
-                raise ValueError("load_weights failed because it did not hit any names.")
+                raise ValueError(
+                    f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
+                )
 
         return hit_names
 
@@ -708,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
             "moe_intermediate_size",
             getattr(cfg, "intermediate_size", None),
         )
-        num_experts = cfg.num_local_experts
+        residual_moe = getattr(cfg, "residual_moe", False)
+        if cfg.num_local_experts > 0:
+            num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
+        else:
+            num_experts = 1
 
         wq = (
             cfg.num_hidden_layers
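Taken together, the load_weights hunks replace the old num_experts argument with a model_config override (from which num_local_experts is read) and add a check_hit_names toggle for the missing-parameter audit. A hedged sketch of the new call shape (model and weights are assumed to exist):

# weights: iterable of (name, tensor) pairs, e.g. from a safetensors reader
hit_names = model.load_weights(
    weights,
    ignore_parent_name=False,
    check_hit_names=True,   # False skips the hit/missing-name checks
    model_config=None,      # defaults to model.config
)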
python/sglang/srt/tokenizer/tiktoken_tokenizer.py (new file, 161 lines)
@@ -0,0 +1,161 @@
+import functools
+import json
+from typing import AbstractSet, Collection, List, Literal, Union
+
+
+class TiktokenProcessor:
+    def __init__(self, name: str):
+        self.tokenizer = TiktokenTokenizer(name)
+
+    def image_processor(self, image):
+        return {"pixel_values": [image]}
+
+
+RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
+CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
+
+
+PAD = "<|pad|>"
+EOS = "<|eos|>"
+SEP = "<|separator|>"
+
+DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
+DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
+# default + separate each single digit
+PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        from jinja2 import Template
+
+        # Read the JSON
+        with open(tokenizer_path, "rb") as fin:
+            xtok_dict = json.load(fin)
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in xtok_dict["special_tokens"]
+        }
+        if xtok_dict["word_split"] == "V1":
+            pad_str = PAT_STR_B
+        else:
+            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
+        pad_str = xtok_dict.get("pat_str", pad_str)
+
+        kwargs = {
+            "name": tokenizer_path,
+            "pat_str": pad_str,
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in xtok_dict:
+            default_allowed_special = set(
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in xtok_dict["default_allowed_special"]
+                ]
+            )
+        if "vocab_size" in xtok_dict:
+            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
+        default_allowed_special = None
+        control_tokens = DEFAULT_CONTROL_TOKENS
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._control_tokens = control_tokens
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> List[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=(),
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Allow more tokens to prevent crash
+        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
+        tokenizer._default_allowed_special |= set(
+            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
+        )
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.bos_token_id = None
+        self.eos_token_id = tokenizer._special_tokens[EOS]
+        self.vocab_size = tokenizer.n_vocab
+        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        self.chat_template_jinja = Template(self.chat_template)
+        self.additional_stop_token_ids = None
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x, *args, **kwargs):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if len(batch) > 0 and isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def apply_chat_template(
+        self, messages, tokenize, add_generation_prompt, tools=None
+    ):
+        ret = self.chat_template_jinja.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
+
+    def __call__(self, text, **kwargs):
+        return {
+            "input_ids": self.encode(text),
+        }
+
+    def init_xgrammar(self):
+        from xgrammar import TokenizerInfo
+
+        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"
+
+        enc = self.tokenizer
+        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
+        encoded_vocab = [
+            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
+        ]
+        override_stop_tokens = [2]  # eos
+        # These are treated as special tokens in xgrammar; we want to avoid them
+        # For now, xgrammar treats anything starting with b'\x00' as a special token
+        xgrammar_special_token_ids = []
+        for i, token in enumerate(encoded_vocab):
+            if isinstance(token, bytes) and token.startswith(b"\x00"):
+                xgrammar_special_token_ids.append(i)
+
+        for i, id in enumerate(xgrammar_special_token_ids):
+            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)
+        tokenizer_info = TokenizerInfo(
+            encoded_vocab, stop_token_ids=override_stop_tokens
+        )
+        assert len(tokenizer_info.special_token_ids) == 0
+
+        return tokenizer_info, override_stop_tokens
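The wrapper implements just the slice of the HF tokenizer interface that sglang uses (encode, decode, batch_decode, apply_chat_template, eos_token_id, vocab_size) plus the init_xgrammar() hook consumed by the grammar backend above. A usage sketch (the JSON path is hypothetical and must point at an xtok-format vocabulary):

from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

tok = TiktokenTokenizer("/models/grok/tokenizer.json")
ids = tok.encode("Hello")
print(tok.decode(ids), tok.eos_token_id, tok.vocab_size)

prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hi"}],
    tokenize=False,
    add_generation_prompt=True,
)  # renders the built-in Human/Assistant template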