Update grok 1 model (#1095)

This commit is contained in:
Lianmin Zheng
2024-08-14 04:40:44 -07:00
committed by GitHub
parent fe5024325b
commit a59636bb5e
11 changed files with 814 additions and 514 deletions

View File

@@ -16,20 +16,17 @@ limitations under the License.
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
import warnings
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
@@ -37,7 +34,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
@@ -45,141 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.fused_moe import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
use_fused = True
class Grok1MLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
)
self.w2 = ReplicatedLinear(
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
)
self.w3 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
)
self.act_fn = nn.GELU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class Grok1MoEUnfused(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}."
)
# Split experts equally between ranks
self.expert_indicies = np.array_split(
range(self.num_total_experts), self.tp_size
)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList(
[
(
Grok1MLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
)
if idx in self.expert_indicies
else None
)
for idx in range(self.num_total_experts)
]
)
self.gate = ReplicatedLinear(
config.hidden_size, self.num_total_experts, bias=False, quant_config=None
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.gate(hidden_states)
router_logits = 30 * F.tanh(router_logits / 30)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
routing_weights = routing_weights.to(hidden_states.dtype)
hidden_dim = hidden_states.shape[1]
final_hidden_states = torch.zeros(
(hidden_states.shape[0], hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_total_experts
).permute(2, 1, 0)
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
if top_x.shape[0] == 0:
continue
# in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist()
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = (
expert_layer(current_state)
* routing_weights[top_x_list, idx_list, None]
)
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states)
class Grok1MoE(nn.Module):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
@@ -197,221 +65,42 @@ class Grok1MoE(nn.Module):
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(quant_config, Fp8Config)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts,
hidden_size,
num_experts,
bias=False,
params_dtype=self.params_dtype,
params_dtype=params_dtype,
quant_config=None,
)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype,
)
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
)
self.w2_weight = nn.Parameter(
torch.empty(
self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype,
)
)
set_weight_attrs(
self.w13_weight,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_weight,
{
"weight_loader": self.weight_loader,
},
)
# Used for fp8.
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.w2_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(
self.w13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2_scale,
{
"weight_loader": self.weight_loader,
},
)
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
self.a13_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
self.a2_scale = nn.Parameter(
torch.zeros(self.num_total_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(
self.a13_scale,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.a2_scale,
{
"weight_loader": self.weight_loader,
},
)
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
expert_id: int,
pre_sharded: bool,
):
param_data = param.data
shard_size = self.intermediate_size
if pre_sharded:
# The weight is already sharded. Readl the full shard
shard = slice(None)
else:
tp_rank = get_tensor_model_parallel_rank()
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
shard, :
]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
param_data[expert_id] = loaded_weight
def process_weights_after_loading(self):
# Fp8 is the only case where we need to process after loading.
if not self.use_fp8:
return
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(
self.w13_weight.data, dtype=torch.float8_e4m3fn
)
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :]
)
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :]
)
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. "
)
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=False,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale,
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
router_logits = 30.0 * F.tanh(router_logits / 30.0)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
class Grok1Attention(nn.Module):
@@ -478,6 +167,7 @@ class Grok1Attention(nn.Module):
layer_id=layer_id,
logit_cap=logit_cap,
)
# TODO(lianmin): load logit cap from config
def forward(
self,
@@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = Grok1Attention(
hidden_size=self.hidden_size,
@@ -513,18 +203,13 @@ class Grok1DecoderLayer(nn.Module):
rope_theta=rope_theta,
quant_config=quant_config,
)
if use_fused:
self.block_sparse_moe = Grok1MoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
)
else:
self.block_sparse_moe = Grok1MoEUnfused(
config=config, quant_config=quant_config
)
self.block_sparse_moe = Grok1MoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
# Self Attention
hidden_states = (
self.post_attn_norm(
self.self_attn(
@@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module):
+ hidden_states
)
# Fully Connected
hidden_states = (
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+ hidden_states
)
return hidden_states
@@ -593,7 +279,6 @@ class Grok1Model(nn.Module):
for i in range(len(self.layers)):
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
hidden_states = self.norm(hidden_states)
hidden_states.mul_(self.config.output_multiplier_scale)
return hidden_states
@@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module):
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
warnings.filterwarnings("ignore", category=FutureWarning)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
@@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
if use_fused:
expert_params_mapping = (
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
(
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
+ [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale",
expert_id,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
)
else:
expert_params_mapping = []
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
)
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
for name, loaded_weight in weights:
# print(get_tensor_model_parallel_rank(), name)
if "rotary_emb.inv_freq" in name:
continue
@@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1,
)
@@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
@@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")