Update grok.py and tiktoken tokenizer (#9532)

This commit is contained in:
Lianmin Zheng
2025-08-23 05:40:18 -07:00
committed by GitHub
parent 83871aa12d
commit 86d10d220f
10 changed files with 732 additions and 64 deletions

View File

@@ -16,7 +16,6 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
import functools
import json
import logging
import math
import os
@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import (
experts_combine_triton,
fused_dual_residual_rmsnorm,
fused_rmsnorm,
gelu_and_mul_triton,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.rotary_embedding import (
RotaryEmbedding,
_yarn_find_correction_range,
_yarn_get_mscale,
get_rope,
)
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import dump_to_file
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
logger = logging.getLogger(__name__)
# Dump tensors for debugging
debug_tensor_dump_output_folder = None
debug_tensor_dump_prefill_only = False
# Skip all the other tensor dumps, only dump the target logits
debug_tensor_dump_only_target_logprobs = False
debug_tensor_dump_inject = False
debug_tensor_dump_layers = None
debug_tensor_dump_test = False
class Grok1MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results=True,
use_presharded_weights: bool = False,
split_gate_up: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
use_presharded_weights=use_presharded_weights,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
reduce_results=reduce_results,
use_presharded_weights=use_presharded_weights,
)
self.act_fn = GeluAndMul(approximate="tanh")
self.layer_id = layer_id
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x, _ = gelu_and_mul_triton(gate_up)
x, _ = self.down_proj(x)
return x
class Grok1MoE(nn.Module):
@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
reduce_results=True,
reduce_results: bool = True,
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
@@ -145,6 +204,135 @@ class Grok1MoE(nn.Module):
return self.experts(hidden_states, topk_output)
def _yarn_linear_ramp_mask(
low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
if low == high:
low -= 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_rope_scaling(config):
rope_type = getattr(config, "rope_type", None)
if rope_type:
original_max_position_embeddings = getattr(
config, "original_max_position_embeddings", None
)
scaling_factor = getattr(config, "scaling_factor", None)
extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
attn_factor = getattr(config, "attn_factor", 1.0)
beta_fast = getattr(config, "beta_fast", 32)
beta_slow = getattr(config, "beta_slow", 1)
rope_scaling = {
"extra_method": rope_type,
"max_position_embeddings": original_max_position_embeddings,
"scaling_factor": scaling_factor,
"extrapolation_factor": extrapolation_factor,
"attn_factor": attn_factor,
"beta_fast": beta_fast,
"beta_slow": beta_slow,
"dtype": torch.float,
}
return rope_scaling
else:
return None
class ScalingRotaryEmbedding(RotaryEmbedding):
"""Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extra_method: str = "yarn_log",
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extra_method = extra_method
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (
1
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
) * self.extrapolation_factor
if self.extra_method in ["original"]:
inv_freq = inv_freq_extrapolation
elif self.extra_method in ["yarn", "yarn_linear"]:
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
elif self.extra_method == "yarn_log":
inv_freq = torch.exp(
torch.log(inv_freq_extrapolation) * inv_freq_mask
+ torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
)
elif self.extra_method == "theta_scale":
exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
theta_scale_exponent = self.base ** (
math.log(
self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
)
/ math.log(self.max_position_embeddings / (2 * math.pi))
)
inv_freq = torch.tensor(
1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
dtype=torch.float32,
)
else:
raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
# cos = freqs.cos() * self.mscale
# sin = freqs.sin() * self.mscale
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
class Grok1Attention(nn.Module):
def __init__(
self,
@@ -157,7 +345,9 @@ class Grok1Attention(nn.Module):
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
alt_stream: Optional[torch.cuda.Stream] = None,
load_presharded_attn: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -183,7 +373,9 @@ class Grok1Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
rope_scaling = get_rope_scaling(config)
self.load_presharded_attn = load_presharded_attn
self.alt_stream = alt_stream or torch.cuda.Stream()
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -195,6 +387,7 @@ class Grok1Attention(nn.Module):
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
load_presharded_attn=self.load_presharded_attn,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
@@ -205,6 +398,7 @@ class Grok1Attention(nn.Module):
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
use_presharded_weights=self.load_presharded_attn,
prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -214,7 +408,37 @@ class Grok1Attention(nn.Module):
is_neox_style=True,
)
self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
if rope_scaling is not None:
self.rotary_emb = ScalingRotaryEmbedding(
self.head_dim,
rotary_dim=(
self.head_dim
if not self.rope_rotate_half_dims
else self.head_dim // 2
),
base=int(self.rope_theta),
is_neox_style=True,
**rope_scaling,
)
pos_encoding_mode = "NONE"
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=(
self.head_dim
if not self.rope_rotate_half_dims
else self.head_dim // 2
),
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
pos_encoding_mode = "NONE"
logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")
self.attn = RadixAttention(
self.num_heads,
@@ -224,7 +448,11 @@ class Grok1Attention(nn.Module):
layer_id=layer_id,
logit_cap=logit_cap,
quant_config=quant_config,
pos_encoding_mode=pos_encoding_mode,
logit_capping_method=logit_capping_method,
prefix=add_prefix("attn", prefix),
)
self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)
def forward(
self,
@@ -256,6 +484,8 @@ class Grok1Attention(nn.Module):
)
qkv, _ = self.qkv_proj(hidden_states)
dispose_tensor(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
@@ -288,6 +518,7 @@ class Grok1Attention(nn.Module):
)
attn_output = self.attn(q, k, v, forward_batch)
del q, k, v, qkv
if debug_tensor_dump_output_folder:
dump_to_file(
@@ -312,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
load_presharded_moe: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
alt_stream: Optional[torch.cuda.Stream] = None,
skip_moe: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.residual_moe = getattr(config, "residual_moe", False)
self.layer_id = layer_id
self.alt_stream = alt_stream or torch.cuda.Stream()
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = Grok1Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
max_position=(
config.context_len
if hasattr(config, "context_len")
else config.max_position_embeddings
),
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
quant_config=quant_config,
reduce_results=False,
alt_stream=self.alt_stream,
load_presharded_attn=load_presharded_attn,
prefix=add_prefix("attn", prefix),
)
self.block_sparse_moe = Grok1MoE(
config=config,
layer_id=layer_id,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=getattr(
config,
"moe_intermediate_size",
getattr(config, "intermediate_size", None),
),
quant_config=quant_config,
reduce_results=True,
use_presharded_weights=load_presharded_moe,
inplace=True,
no_combine=False, # just a suggestion to not combine topk
)
split_gate_up = not getattr(config, "merge_gate_up", True)
if self.num_experts > 0:
self.block_sparse_moe = Grok1MoE(
config=config,
layer_id=layer_id,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=getattr(
config,
"moe_intermediate_size",
getattr(config, "intermediate_size", None),
),
quant_config=quant_config,
reduce_results=not self.residual_moe,
use_presharded_weights=load_presharded_moe,
inplace=False, # not self.residual_moe,
no_combine=False, # self.residual_moe, # just a suggestion to not combine topk
prefix=add_prefix("block_sparse_moe", prefix),
)
if self.residual_moe:
self.mlp = Grok1MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
reduce_results=False,
use_presharded_weights=load_presharded_mlp,
layer_id=layer_id,
split_gate_up=split_gate_up,
)
else:
raise NotImplementedError()
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ffn = self.block_sparse_moe
if self.num_experts > 0:
if self.residual_moe:
# NOTE: self.block_sparse_moe modifies the input in-place,
# so we have to call it later. Be aware of any possible related errors.
if get_tensor_model_parallel_world_size() > 1:
self.ffn = lambda x: tensor_model_parallel_all_reduce(
self.moe_with_rmoe(x)
)
else:
self.ffn = self.moe_with_rmoe
else:
self.ffn = self.block_sparse_moe
else:
raise NotImplementedError()
def forward(
self,
@@ -364,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
residual: Optional[torch.Tensor] = None,
deferred_norm: Optional[RMSNorm] = None,
) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
hidden_states_original = hidden_states
residual_original = residual
# Self Attention
if deferred_norm is not None:
assert residual is not None
@@ -386,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
hidden_states,
)
if residual_original is not None:
dispose_tensor(residual_original)
dispose_flag = False
if residual is not hidden_states_original:
dispose_flag = True
dispose_tensor(hidden_states_original)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
@@ -403,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
self.post_attn_norm.variance_epsilon,
)
if not dispose_flag:
dispose_tensor(hidden_states_original)
# Fully Connected
hidden_states = self.ffn(hidden_states)
return hidden_states, residual, self.post_moe_norm # defer layernorm
def moe_with_rmoe(self, x):
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
mlp_result = self.mlp(x)
with torch.cuda.stream(self.alt_stream):
# moe should not be inplace because of stream race condition
moe_result = self.block_sparse_moe(x)
current_stream.wait_stream(self.alt_stream)
return (mlp_result + moe_result) / 1.4142135623730951
class Grok1Model(nn.Module):
def __init__(
@@ -417,6 +713,8 @@ class Grok1Model(nn.Module):
load_presharded_embedding: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
replicate_embedding: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -427,7 +725,11 @@ class Grok1Model(nn.Module):
config.vocab_size,
config.hidden_size,
use_presharded_weights=load_presharded_embedding,
enable_tp=not replicate_embedding,
prefix=add_prefix("embed_tokens", prefix),
)
self.alt_stream = torch.cuda.Stream()
self.layers = nn.ModuleList(
[
Grok1DecoderLayer(
@@ -437,6 +739,7 @@ class Grok1Model(nn.Module):
load_presharded_moe=load_presharded_moe,
load_presharded_attn=load_presharded_attn,
load_presharded_mlp=load_presharded_mlp,
alt_stream=self.alt_stream,
)
for i in range(config.num_hidden_layers)
]
@@ -506,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -514,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
# Get presharded weights.
self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
self.load_presharded_moe = (
self.config.num_local_experts > 0
getattr(config, "load_presharded_moe", True)
and self.config.num_local_experts > 0
and get_tensor_model_parallel_world_size() > 1
)
self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
@@ -529,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
or self.load_presharded_embedding
)
default_replicate_lm_head = False
self.replicate_lm_head = getattr(
config, "replicate_lm_head", default_replicate_lm_head
)
if self.is_weights_presharded:
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -536,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
self.replicate_lm_head = getattr(
config, "replicate_lm_head", default_replicate_lm_head
)
self.replicate_embedding = getattr(config, "replicate_embedding", False)
self.model = Grok1Model(
config,
@@ -544,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
load_presharded_embedding=self.load_presharded_embedding,
load_presharded_attn=self.load_presharded_attn,
load_presharded_mlp=self.load_presharded_mlp,
replicate_embedding=self.replicate_embedding,
prefix=add_prefix("model", prefix),
)
lm_head_params_dtype = None
@@ -553,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
config.vocab_size,
bias=False,
params_dtype=lm_head_params_dtype,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
@@ -561,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
config.hidden_size,
use_presharded_weights=self.load_presharded_embedding,
params_dtype=lm_head_params_dtype,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
@@ -577,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
)
self.loaded_param_names = set()
def forward(
self,
@@ -596,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
num_experts: Optional[int] = None,
ignore_parent_name: bool = False,
check_hit_names: bool = True,
model_config: PretrainedConfig | None = None,
) -> dict[str, torch.Tensor]:
if num_experts is None:
num_experts = self.config.num_local_experts
if model_config is None:
model_config = self.config
stacked_params_mapping = []
stacked_params_mapping += [
# (param_name, shard_name, shard_id)
@@ -616,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
num_experts = model_config.num_local_experts
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
@@ -630,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
def load_weight_wrapper(
name: str, loaded_weight: torch.Tensor, *args, **kwargs
):
if ignore_parent_name:
name = name.split(".")[-1]
if name not in params_dict:
return
# Fuse constant multipliers into the weights
if "lm_head" in name:
loaded_weight = (
loaded_weight.to(torch.float32)
* self.config.output_multiplier_scale
* model_config.output_multiplier_scale
)
original_name = name
if ignore_parent_name:
name = name.split(".")[-1]
if name not in params_dict:
logger.info(f"Skipping {name=} in load_weights_wrapper")
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, *args, **kwargs)
hit_names.add(name)
self.loaded_param_names.add(original_name)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
@@ -685,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module):
load_weight_wrapper(name=name, loaded_weight=loaded_weight)
if len(hit_names) > 5:
missing = all_names - hit_names
missing_exclude_scales = {x for x in missing if "scale" not in x}
logger.info(
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
)
if len(missing_exclude_scales) > 0:
raise ValueError(
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
if check_hit_names:
if len(hit_names) > 5:
missing = all_names - hit_names
missing_exclude_scales = {x for x in missing if "scale" not in x}
logger.info(
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
)
if len(missing_exclude_scales) > 0:
raise ValueError(
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
)
elif len(hit_names) == 0:
raise ValueError("load_weights failed because it did not hit any names.")
elif len(hit_names) == 0:
raise ValueError(
f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
)
return hit_names
@@ -708,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
"moe_intermediate_size",
getattr(cfg, "intermediate_size", None),
)
num_experts = cfg.num_local_experts
residual_moe = getattr(cfg, "residual_moe", False)
if cfg.num_local_experts > 0:
num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
else:
num_experts = 1
wq = (
cfg.num_hidden_layers