Update grok.py and tiktoken tokenizer (#9532)
@@ -16,7 +16,6 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
import functools
import json
import logging
import math
import os
@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import (
    experts_combine_triton,
    fused_dual_residual_rmsnorm,
    fused_rmsnorm,
    gelu_and_mul_triton,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.rotary_embedding import (
    RotaryEmbedding,
    _yarn_find_correction_range,
    _yarn_get_mscale,
    get_rope,
)
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import dump_to_file
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file

logger = logging.getLogger(__name__)


# Dump tensors for debugging
debug_tensor_dump_output_folder = None
debug_tensor_dump_prefill_only = False
# Skip all the other tensor dumps, only dump the target logits
debug_tensor_dump_only_target_logprobs = False
debug_tensor_dump_inject = False
debug_tensor_dump_layers = None
debug_tensor_dump_test = False


class Grok1MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        reduce_results=True,
        use_presharded_weights: bool = False,
        split_gate_up: bool = False,
    ) -> None:
        super().__init__()

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            use_presharded_weights=use_presharded_weights,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
            reduce_results=reduce_results,
            use_presharded_weights=use_presharded_weights,
        )
        self.act_fn = GeluAndMul(approximate="tanh")
        self.layer_id = layer_id

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x, _ = gelu_and_mul_triton(gate_up)
        x, _ = self.down_proj(x)
        return x
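
For reference, `gelu_and_mul_triton` computes the usual gated GELU used by this MLP: the merged `gate_up_proj` output is split into gate and up halves, the gate passes through a tanh-approximated GELU, and the two halves are multiplied. A minimal eager-mode sketch of that computation (the helper name below is hypothetical, not part of sglang):

import torch
import torch.nn.functional as F

def gelu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # Split the merged projection into its gate and up halves along the last dim.
    gate, up = gate_up.chunk(2, dim=-1)
    # Tanh-approximated GELU on the gate, elementwise product with the up half.
    return F.gelu(gate, approximate="tanh") * up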


class Grok1MoE(nn.Module):
@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        reduce_results=True,
        reduce_results: bool = True,
        use_presharded_weights: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
@@ -145,6 +204,135 @@ class Grok1MoE(nn.Module):
        return self.experts(hidden_states, topk_output)


def _yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
    if low == high:
        low -= 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
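
For intuition, the ramp above is 0 for indices below `low`, 1 above `high`, and linear in between. A self-contained check with arbitrary example values (not taken from the diff):

import torch

low, high, dim = 2.0, 6.0, 8  # arbitrary illustration values
mask = torch.clamp((torch.arange(dim, dtype=torch.float) - low) / (high - low), 0, 1)
print(mask)  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000])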


def get_rope_scaling(config):
    rope_type = getattr(config, "rope_type", None)
    if rope_type:
        original_max_position_embeddings = getattr(
            config, "original_max_position_embeddings", None
        )
        scaling_factor = getattr(config, "scaling_factor", None)
        extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
        attn_factor = getattr(config, "attn_factor", 1.0)
        beta_fast = getattr(config, "beta_fast", 32)
        beta_slow = getattr(config, "beta_slow", 1)
        rope_scaling = {
            "extra_method": rope_type,
            "max_position_embeddings": original_max_position_embeddings,
            "scaling_factor": scaling_factor,
            "extrapolation_factor": extrapolation_factor,
            "attn_factor": attn_factor,
            "beta_fast": beta_fast,
            "beta_slow": beta_slow,
            "dtype": torch.float,
        }
        return rope_scaling
    else:
        return None
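
For illustration, a hypothetical config carrying rope_type="yarn_log", original_max_position_embeddings=8192, and scaling_factor=4.0 (example values only, not taken from any released checkpoint) would make the helper above return the following dict, which is later splatted into ScalingRotaryEmbedding(**rope_scaling):

import torch

rope_scaling_example = {
    "extra_method": "yarn_log",
    "max_position_embeddings": 8192,
    "scaling_factor": 4.0,
    "extrapolation_factor": 1.0,
    "attn_factor": 1.0,
    "beta_fast": 32,
    "beta_slow": 1,
    "dtype": torch.float,
}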


class ScalingRotaryEmbedding(RotaryEmbedding):
    """Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extra_method: str = "yarn_log",
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extra_method = extra_method
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base ** (
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        )
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (
            1
            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
        ) * self.extrapolation_factor
        if self.extra_method in ["original"]:
            inv_freq = inv_freq_extrapolation
        elif self.extra_method in ["yarn", "yarn_linear"]:
            inv_freq = (
                inv_freq_interpolation * (1 - inv_freq_mask)
                + inv_freq_extrapolation * inv_freq_mask
            )
        elif self.extra_method == "yarn_log":
            inv_freq = torch.exp(
                torch.log(inv_freq_extrapolation) * inv_freq_mask
                + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
            )
        elif self.extra_method == "theta_scale":
            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
            theta_scale_exponent = self.base ** (
                math.log(
                    self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
                )
                / math.log(self.max_position_embeddings / (2 * math.pi))
            )
            inv_freq = torch.tensor(
                1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
                dtype=torch.float32,
            )
        else:
            raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(
            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
        )
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        # cos = freqs.cos() * self.mscale
        # sin = freqs.sin() * self.mscale
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
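
The cache built above stores the cos table in the first rotary_dim // 2 columns and the sin table in the remaining columns (mscale is not applied here, per the commented-out lines). A toy-sized sketch of the same layout:

import torch

num_pos, half_dim = 16, 4  # toy sizes for illustration
inv_freq = torch.rand(half_dim)
t = torch.arange(num_pos, dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # shape [num_pos, 2 * half_dim]
cos, sin = cache.chunk(2, dim=-1)                       # each [num_pos, half_dim]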


class Grok1Attention(nn.Module):
    def __init__(
        self,
@@ -157,7 +345,9 @@ class Grok1Attention(nn.Module):
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        alt_stream: Optional[torch.cuda.Stream] = None,
        load_presharded_attn: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
@@ -183,7 +373,9 @@ class Grok1Attention(nn.Module):
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        rope_scaling = get_rope_scaling(config)
        self.load_presharded_attn = load_presharded_attn
        self.alt_stream = alt_stream or torch.cuda.Stream()

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
@@ -195,6 +387,7 @@ class Grok1Attention(nn.Module):
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            load_presharded_attn=self.load_presharded_attn,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
@@ -205,6 +398,7 @@ class Grok1Attention(nn.Module):
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            use_presharded_weights=self.load_presharded_attn,
            prefix=add_prefix("o_proj", prefix),
        )
        self.rotary_emb = get_rope(
            self.head_dim,
@@ -214,7 +408,37 @@ class Grok1Attention(nn.Module):
            is_neox_style=True,
        )

        self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)

        if rope_scaling is not None:
            self.rotary_emb = ScalingRotaryEmbedding(
                self.head_dim,
                rotary_dim=(
                    self.head_dim
                    if not self.rope_rotate_half_dims
                    else self.head_dim // 2
                ),
                base=int(self.rope_theta),
                is_neox_style=True,
                **rope_scaling,
            )
            pos_encoding_mode = "NONE"
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=(
                    self.head_dim
                    if not self.rope_rotate_half_dims
                    else self.head_dim // 2
                ),
                max_position=max_position,
                base=int(self.rope_theta),
                is_neox_style=True,
            )
            pos_encoding_mode = "NONE"

        logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
        logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")

        self.attn = RadixAttention(
            self.num_heads,
@@ -224,7 +448,11 @@ class Grok1Attention(nn.Module):
            layer_id=layer_id,
            logit_cap=logit_cap,
            quant_config=quant_config,
            pos_encoding_mode=pos_encoding_mode,
            logit_capping_method=logit_capping_method,
            prefix=add_prefix("attn", prefix),
        )
        self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)

    def forward(
        self,
@@ -256,6 +484,8 @@ class Grok1Attention(nn.Module):
        )

        qkv, _ = self.qkv_proj(hidden_states)
        dispose_tensor(hidden_states)

        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)

@@ -288,6 +518,7 @@ class Grok1Attention(nn.Module):
        )

        attn_output = self.attn(q, k, v, forward_batch)
        del q, k, v, qkv

        if debug_tensor_dump_output_folder:
            dump_to_file(
@@ -312,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
        load_presharded_moe: bool = False,
        load_presharded_attn: bool = False,
        load_presharded_mlp: bool = False,
        alt_stream: Optional[torch.cuda.Stream] = None,
        skip_moe: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.residual_moe = getattr(config, "residual_moe", False)
        self.layer_id = layer_id
        self.alt_stream = alt_stream or torch.cuda.Stream()

        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = Grok1Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            max_position=(
                config.context_len
                if hasattr(config, "context_len")
                else config.max_position_embeddings
            ),
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            quant_config=quant_config,
            reduce_results=False,
            alt_stream=self.alt_stream,
            load_presharded_attn=load_presharded_attn,
            prefix=add_prefix("attn", prefix),
        )
        self.block_sparse_moe = Grok1MoE(
            config=config,
            layer_id=layer_id,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=getattr(
                config,
                "moe_intermediate_size",
                getattr(config, "intermediate_size", None),
            ),
            quant_config=quant_config,
            reduce_results=True,
            use_presharded_weights=load_presharded_moe,
            inplace=True,
            no_combine=False,  # just a suggestion to not combine topk
        )

        split_gate_up = not getattr(config, "merge_gate_up", True)
        if self.num_experts > 0:
            self.block_sparse_moe = Grok1MoE(
                config=config,
                layer_id=layer_id,
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=getattr(
                    config,
                    "moe_intermediate_size",
                    getattr(config, "intermediate_size", None),
                ),
                quant_config=quant_config,
                reduce_results=not self.residual_moe,
                use_presharded_weights=load_presharded_moe,
                inplace=False,  # not self.residual_moe,
                no_combine=False,  # self.residual_moe, # just a suggestion to not combine topk
                prefix=add_prefix("block_sparse_moe", prefix),
            )
            if self.residual_moe:
                self.mlp = Grok1MLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=config.intermediate_size,
                    quant_config=quant_config,
                    reduce_results=False,
                    use_presharded_weights=load_presharded_mlp,
                    layer_id=layer_id,
                    split_gate_up=split_gate_up,
                )
        else:
            raise NotImplementedError()

        self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.ffn = self.block_sparse_moe
        if self.num_experts > 0:
            if self.residual_moe:
                # NOTE: self.block_sparse_moe modifies the input in-place,
                # so we have to call it later. Be aware of any possible related errors.
                if get_tensor_model_parallel_world_size() > 1:
                    self.ffn = lambda x: tensor_model_parallel_all_reduce(
                        self.moe_with_rmoe(x)
                    )
                else:
                    self.ffn = self.moe_with_rmoe
            else:
                self.ffn = self.block_sparse_moe
        else:
            raise NotImplementedError()

    def forward(
        self,
@@ -364,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
        residual: Optional[torch.Tensor] = None,
        deferred_norm: Optional[RMSNorm] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:

        hidden_states_original = hidden_states
        residual_original = residual

        # Self Attention
        if deferred_norm is not None:
            assert residual is not None
@@ -386,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
            hidden_states,
        )

        if residual_original is not None:
            dispose_tensor(residual_original)

        dispose_flag = False
        if residual is not hidden_states_original:
            dispose_flag = True
            dispose_tensor(hidden_states_original)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
@@ -403,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
            self.post_attn_norm.variance_epsilon,
        )

        if not dispose_flag:
            dispose_tensor(hidden_states_original)

        # Fully Connected
        hidden_states = self.ffn(hidden_states)
        return hidden_states, residual, self.post_moe_norm  # defer layernorm

    def moe_with_rmoe(self, x):
        current_stream = torch.cuda.current_stream()
        self.alt_stream.wait_stream(current_stream)
        mlp_result = self.mlp(x)
        with torch.cuda.stream(self.alt_stream):
            # moe should not be inplace because of stream race condition
            moe_result = self.block_sparse_moe(x)
        current_stream.wait_stream(self.alt_stream)
        return (mlp_result + moe_result) / 1.4142135623730951
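
moe_with_rmoe overlaps the dense MLP and the sparse MoE on two CUDA streams and divides their sum by sqrt(2) (the 1.4142135623730951 constant), presumably to keep the combined residual branch at roughly the scale of a single branch. A generic, standalone sketch of the same stream handshake (function and argument names are placeholders, not sglang APIs):

import math

import torch

def run_two_branches(x, branch_a, branch_b, alt_stream):
    main = torch.cuda.current_stream()
    alt_stream.wait_stream(main)           # make the alt stream see all writes to x
    out_a = branch_a(x)                    # first branch runs on the main stream
    with torch.cuda.stream(alt_stream):
        out_b = branch_b(x)                # second branch overlaps on the alt stream
    main.wait_stream(alt_stream)           # join before reading out_b on the main stream
    return (out_a + out_b) / math.sqrt(2)  # rescale the combined output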


class Grok1Model(nn.Module):
    def __init__(
@@ -417,6 +713,8 @@ class Grok1Model(nn.Module):
        load_presharded_embedding: bool = False,
        load_presharded_attn: bool = False,
        load_presharded_mlp: bool = False,
        replicate_embedding: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
@@ -427,7 +725,11 @@ class Grok1Model(nn.Module):
            config.vocab_size,
            config.hidden_size,
            use_presharded_weights=load_presharded_embedding,
            enable_tp=not replicate_embedding,
            prefix=add_prefix("embed_tokens", prefix),
        )

        self.alt_stream = torch.cuda.Stream()
        self.layers = nn.ModuleList(
            [
                Grok1DecoderLayer(
@@ -437,6 +739,7 @@ class Grok1Model(nn.Module):
                    load_presharded_moe=load_presharded_moe,
                    load_presharded_attn=load_presharded_attn,
                    load_presharded_mlp=load_presharded_mlp,
                    alt_stream=self.alt_stream,
                )
                for i in range(config.num_hidden_layers)
            ]
@@ -506,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
@@ -514,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
        # Get presharded weights.
        self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
        self.load_presharded_moe = (
            self.config.num_local_experts > 0
            getattr(config, "load_presharded_moe", True)
            and self.config.num_local_experts > 0
            and get_tensor_model_parallel_world_size() > 1
        )
        self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
@@ -529,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
            or self.load_presharded_embedding
        )

        default_replicate_lm_head = False
        self.replicate_lm_head = getattr(
            config, "replicate_lm_head", default_replicate_lm_head
        )

        if self.is_weights_presharded:
            setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

@@ -536,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
        self.replicate_lm_head = getattr(
            config, "replicate_lm_head", default_replicate_lm_head
        )
        self.replicate_embedding = getattr(config, "replicate_embedding", False)

        self.model = Grok1Model(
            config,
@@ -544,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
            load_presharded_embedding=self.load_presharded_embedding,
            load_presharded_attn=self.load_presharded_attn,
            load_presharded_mlp=self.load_presharded_mlp,
            replicate_embedding=self.replicate_embedding,
            prefix=add_prefix("model", prefix),
        )

        lm_head_params_dtype = None
@@ -553,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
                config.vocab_size,
                bias=False,
                params_dtype=lm_head_params_dtype,
                prefix=add_prefix("lm_head", prefix),
            )
            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
        else:
@@ -561,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
                config.hidden_size,
                use_presharded_weights=self.load_presharded_embedding,
                params_dtype=lm_head_params_dtype,
                prefix=add_prefix("lm_head", prefix),
            )
            self.logits_processor = LogitsProcessor(config)

@@ -577,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
            f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
            f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
        )
        self.loaded_param_names = set()

    def forward(
        self,
@@ -596,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        num_experts: Optional[int] = None,
        ignore_parent_name: bool = False,
        check_hit_names: bool = True,
        model_config: PretrainedConfig | None = None,
    ) -> dict[str, torch.Tensor]:
        if num_experts is None:
            num_experts = self.config.num_local_experts
        if model_config is None:
            model_config = self.config

        stacked_params_mapping = []
        stacked_params_mapping += [
            # (param_name, shard_name, shard_id)
@@ -616,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        num_experts = model_config.num_local_experts
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
@@ -630,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
        def load_weight_wrapper(
            name: str, loaded_weight: torch.Tensor, *args, **kwargs
        ):
            if ignore_parent_name:
                name = name.split(".")[-1]

            if name not in params_dict:
                return

            # Fuse constant multipliers into the weights
            if "lm_head" in name:
                loaded_weight = (
                    loaded_weight.to(torch.float32)
                    * self.config.output_multiplier_scale
                    * model_config.output_multiplier_scale
                )

            original_name = name
            if ignore_parent_name:
                name = name.split(".")[-1]

            if name not in params_dict:
                logger.info(f"Skipping {name=} in load_weights_wrapper")
                return

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight, *args, **kwargs)
            hit_names.add(name)
            self.loaded_param_names.add(original_name)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
@@ -685,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module):

            load_weight_wrapper(name=name, loaded_weight=loaded_weight)

        if len(hit_names) > 5:
            missing = all_names - hit_names
            missing_exclude_scales = {x for x in missing if "scale" not in x}
            logger.info(
                f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
            )
            if len(missing_exclude_scales) > 0:
                raise ValueError(
                    f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
        if check_hit_names:
            if len(hit_names) > 5:
                missing = all_names - hit_names
                missing_exclude_scales = {x for x in missing if "scale" not in x}
                logger.info(
                    f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
                )
                if len(missing_exclude_scales) > 0:
                    raise ValueError(
                        f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
                    )

        elif len(hit_names) == 0:
            raise ValueError("load_weights failed because it did not hit any names.")
            elif len(hit_names) == 0:
                raise ValueError(
                    f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
                )

        return hit_names

@@ -708,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
            "moe_intermediate_size",
            getattr(cfg, "intermediate_size", None),
        )
        num_experts = cfg.num_local_experts
        residual_moe = getattr(cfg, "residual_moe", False)
        if cfg.num_local_experts > 0:
            num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
        else:
            num_experts = 1

        wq = (
            cfg.num_hidden_layers