From a59636bb5e68f36308bb092674429d27c05cf125 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 14 Aug 2024 04:40:44 -0700 Subject: [PATCH] Update grok 1 model (#1095) --- benchmark/gsm8k/bench_sglang.py | 3 + python/sglang/bench_latency.py | 1 + python/sglang/srt/layers/activation.py | 1 - .../sglang/srt/layers/fused_moe/__init__.py | 1 + .../srt/layers/{ => fused_moe}/fused_moe.py | 275 ++++---- python/sglang/srt/layers/fused_moe/layer.py | 587 ++++++++++++++++++ python/sglang/srt/layers/logits_processor.py | 8 +- .../sglang/srt/model_executor/model_runner.py | 4 +- python/sglang/srt/models/grok.py | 444 ++----------- python/sglang/srt/models/mixtral.py | 1 - python/sglang/srt/utils.py | 3 +- 11 files changed, 814 insertions(+), 514 deletions(-) create mode 100644 python/sglang/srt/layers/fused_moe/__init__.py rename python/sglang/srt/layers/{ => fused_moe}/fused_moe.py (78%) create mode 100644 python/sglang/srt/layers/fused_moe/layer.py diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index 298ec11d7..652086f91 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -88,6 +88,9 @@ def main(args): for i in range(len(states)): preds.append(get_answer_value(states[i]["answer"])) + # print(f"{preds=}") + # print(f"{labels=}") + # Compute accuracy acc = np.mean(np.array(preds) == np.array(labels)) invalid = np.mean(np.array(preds) == INVALID) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index ee227849c..e500d30d1 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -221,6 +221,7 @@ def correctness_test( # Prepare inputs input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) + rank_print(f"{input_ids=}") if bench_args.cut_len > 0: # Prefill diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 64d391594..7cd8abb6f 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -14,7 +14,6 @@ limitations under the License. """Fused operators for activation layers.""" import torch -import torch.nn as nn import torch.nn.functional as F from flashinfer.activation import silu_and_mul from vllm.model_executor.custom_op import CustomOp diff --git a/python/sglang/srt/layers/fused_moe/__init__.py b/python/sglang/srt/layers/fused_moe/__init__.py new file mode 100644 index 000000000..5f7691c09 --- /dev/null +++ b/python/sglang/srt/layers/fused_moe/__init__.py @@ -0,0 +1 @@ +from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe/fused_moe.py similarity index 78% rename from python/sglang/srt/layers/fused_moe.py rename to python/sglang/srt/layers/fused_moe/fused_moe.py index c5630fa5d..717be5ce9 100644 --- a/python/sglang/srt/layers/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe/fused_moe.py @@ -1,20 +1,5 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -""" - # Adapted from -# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1 +# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe """Fused MoE kernel.""" import functools import json @@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple import torch import triton import triton.language as tl +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -373,6 +359,31 @@ def get_default_config( return config +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, +): + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + configs = get_moe_configs(E, N, dtype) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) + return config + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -403,6 +414,41 @@ def fused_topk( return topk_weights, topk_ids +# This is used by the Deepseek-V2 model +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +): + + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -425,24 +471,23 @@ def fused_experts( assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - M, _ = hidden_states.shape + num_tokens, _ = hidden_states.shape E, N, _ = w1.shape + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) - if override_config: - config = override_config - else: - # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + ) - if configs: - # If an optimal 
configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config( - M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None - ) + config = get_config_func(M) intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), @@ -460,56 +505,85 @@ def fused_experts( dtype=hidden_states.dtype, ) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - if inplace: - return torch.sum( + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. 
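+            # For example, with num_tokens = 2 * CHUNK_SIZE + 5 the loop runs
+            # three times and only the final 5-token chunk takes this branch,
+            # shrinking the caches and re-querying the config for the smaller M.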
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) + + invoke_fused_moe_kernel( + curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + torch.sum( intermediate_cache3.view(*intermediate_cache3.shape), dim=1, - out=hidden_states, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + return out_hidden_states def fused_moe( @@ -521,6 +595,9 @@ def fused_moe( renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -543,6 +620,10 @@ def fused_moe( Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -556,12 +637,18 @@ def fused_moe( # Check constraints. 
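+    # Shape sanity (illustrative): with E = 8 experts, hidden size H and expert
+    # intermediate size I, hidden_states is (num_tokens, H), w1 is (8, 2*I, H),
+    # w2 is (8, H, I), and gating_output is (num_tokens, 8).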
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - if hasattr(ops, "topk_softmax"): - topk_weights, topk_ids = fused_topk( - hidden_states, gating_output, topk, renormalize + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, ) else: - topk_weights, topk_ids = fused_topk_v0_4_3( + topk_weights, topk_ids = fused_topk( hidden_states, gating_output, topk, renormalize ) @@ -579,33 +666,3 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, ) - - -def fused_topk_v0_4_3( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, -): - import vllm._moe_C as moe_kernels - - M, _ = hidden_states.shape - - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - return topk_weights, topk_ids diff --git a/python/sglang/srt/layers/fused_moe/layer.py b/python/sglang/srt/layers/fused_moe/layer.py new file mode 100644 index 000000000..0b17c14ff --- /dev/null +++ b/python/sglang/srt/layers/fused_moe/layer.py @@ -0,0 +1,587 @@ +# Adapted from +# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe +from abc import abstractmethod +from typing import List, Optional, Tuple + +import torch +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, 
extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + return self.forward( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + ) + + def forward_cuda( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: Optional[int], + topk_group: Optional[int], + ) -> torch.Tensor: + from sglang.srt.layers.fused_moe.fused_moe import fused_moe + + return fused_moe( + x, + w1, + w2, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + + def forward_cpu(self, *args, **kwargs): + raise NotImplementedError("The CPU backend currently does not support MoE.") + + def forward_tpu( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: Optional[int], + topk_group: Optional[int], + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe + + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + return fused_moe(x, w1, w2, router_logits, top_k, renormalize) + + +class FusedMoE(torch.nn.Module): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. 
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod() + ) + else: + if isinstance(quant_config, Fp8Config): + self.quant_method = Fp8MoEMethod(quant_config) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: int, + expert_id: int, + pre_sharded: bool, + ): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if ( + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}" + ) + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + # shard_id 0 == gate_proj / w1 + # shard_id 2 == up_proj / w3 + if shard_id == 0 or shard_id == 2: + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == 0 else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + # shard_id 1 == down_proj / w2 + else: + param_data[expert_id] = loaded_weight + # Weights + else: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.intermediate_size_per_partition + if pre_sharded: + shard = slice(None) + else: + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + + # w1, gate_proj case: Load into first shard of w13. + if shard_id == 0: + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + # w3, up_proj case: Load into second shard of w13. + elif shard_id == 2: + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] + # w2, down_proj case: Load into only shard of w2. 
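+            # (down_proj is row parallel, so the tensor-parallel shard is taken
+            # along the input dimension, i.e. the columns of the checkpoint tensor.)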
+ elif shard_id == 1: + param_data[expert_id, :, :] = loaded_weight[:, shard] + else: + raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + ) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, int]]: + + gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] + gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name] + + return ( + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_scale" + if weight_name in gate_up + else "experts.w2_scale" + ), + f"experts.{expert_id}.{weight_name}.weight_scale", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_weight" + if weight_name in gate_up + else "experts.w2_weight" + ), + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.a13_scale" + if weight_name in gate_up + else "experts.a2_scale" + ), + f"experts.{expert_id}.{weight_name}.input_scale", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + ) + + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, + per_tensor_dequantize, +) +from vllm.utils import print_warning_once + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + a13_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like( + layer.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like( + layer.w2_weight.data, dtype=torch.float8_e4m3fn + ) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :] + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. 
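+        # Below, the per-expert w1/w3 scales loaded from the checkpoint are
+        # collapsed to their max, and each half of w13 is dequantized with its
+        # original scale and re-quantized with that shared scale.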
+ else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.a13_scale) or not all_close_1d( + layer.a2_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.a13_scale = torch.nn.Parameter( + layer.a13_scale.max(), requires_grad=False + ) + layer.a2_scale = torch.nn.Parameter( + layer.a2_scale.max(), requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + + from sglang.srt.layers.fused_moe.fused_moe import fused_moe + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index cf5045fda..541fa0f15 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module): last_logits = last_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): - last_logits /= self.config.final_logit_softcapping + last_logits.div_(self.config.final_logit_softcapping) last_logits = torch.tanh(last_logits) - last_logits *= self.config.final_logit_softcapping + last_logits.mul_(self.config.final_logit_softcapping) # Return only last_logits if logprob is not requested if not logits_metadata.return_logprob: @@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module): all_logits = all_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): - all_logits /= self.config.final_logit_softcapping + all_logits.div_(self.config.final_logit_softcapping) all_logits = torch.tanh(all_logits) - all_logits *= self.config.final_logit_softcapping + all_logits.mul_(self.config.final_logit_softcapping) all_logprobs = all_logits del all_logits, hidden_states diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 
34a40c7d7..9da284da6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, is_generation_model, - is_llama3_405b_fp8, + is_llama3_405b_fp8_head_16, is_multimodal_model, monkey_patch_vllm_dummy_weight_loader, monkey_patch_vllm_p2p_access_check, @@ -158,7 +158,7 @@ class ModelRunner: skip_tokenizer_init=True, ) - if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8: + if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8: # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints self.model_config.hf_config.num_key_value_heads = 8 vllm_model_config.hf_config.num_key_value_heads = 8 diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 13d4330d4..eff746f1d 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -16,20 +16,17 @@ limitations under the License. # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" +import warnings from typing import Iterable, List, Optional, Tuple -import numpy as np import torch import torch.nn.functional as F -import tqdm from torch import nn from transformers import PretrainedConfig -from vllm import _custom_ops as ops from vllm.config import CacheConfig from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.linear import ( QKVParallelLinear, @@ -37,7 +34,6 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -45,141 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ) from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import print_warning_once -from sglang.srt.layers.fused_moe import fused_moe +from sglang.srt.layers.fused_moe import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata -use_fused = True - - -class Grok1MLP(nn.Module): - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - self.w2 = ReplicatedLinear( - self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config - ) - self.w3 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - - self.act_fn = nn.GELU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, 
_ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class Grok1MoEUnfused(nn.Module): - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}." - ) - # Split experts equally between ranks - self.expert_indicies = np.array_split( - range(self.num_total_experts), self.tp_size - )[self.rank].tolist() - if not self.expert_indicies: - raise ValueError(f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList( - [ - ( - Grok1MLP( - self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config, - ) - if idx in self.expert_indicies - else None - ) - for idx in range(self.num_total_experts) - ] - ) - self.gate = ReplicatedLinear( - config.hidden_size, self.num_total_experts, bias=False, quant_config=None - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - router_logits, _ = self.gate(hidden_states) - router_logits = 30 * F.tanh(router_logits / 30) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - routing_weights = routing_weights.to(hidden_states.dtype) - hidden_dim = hidden_states.shape[1] - - final_hidden_states = torch.zeros( - (hidden_states.shape[0], hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_total_experts - ).permute(2, 1, 0) - - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - if top_x.shape[0] == 0: - continue - - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) - * routing_weights[top_x_list, idx_list, None] - ) - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
- final_hidden_states.index_add_(0, top_x, current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states) - class Grok1MoE(nn.Module): """A tensor-parallel MoE implementation for Grok1 that shards each expert @@ -197,221 +65,42 @@ class Grok1MoE(nn.Module): hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, ): super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size - self.quant_config = quant_config - - # FIXME(pcmoritz): Make this more general to support different - # quantization schemes - self.use_fp8 = isinstance(quant_config, Fp8Config) - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear( - self.hidden_size, - self.num_total_experts, + hidden_size, + num_experts, bias=False, - params_dtype=self.params_dtype, + params_dtype=params_dtype, quant_config=None, ) - if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - self.w13_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - dtype=params_dtype, - ) + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, ) - self.w2_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - self.hidden_size, - self.intermediate_size, - dtype=params_dtype, - ) - ) - - set_weight_attrs( - self.w13_weight, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_weight, - { - "weight_loader": self.weight_loader, - }, - ) - - # Used for fp8. - self.w13_scale = None - self.w2_scale = None - self.a13_scale = None - self.a2_scale = None - - if self.use_fp8: - # WEIGHT_SCALE (for fp8) - self.w13_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.w2_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs( - self.w13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - # ACT_SCALE (for fp8) - if quant_config.activation_scheme == "static": - if not quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." 
- ) - self.a13_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.a2_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - set_weight_attrs( - self.a13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.a2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - expert_id: int, - pre_sharded: bool, - ): - param_data = param.data - shard_size = self.intermediate_size - if pre_sharded: - # The weight is already sharded. Readl the full shard - shard = slice(None) - else: - tp_rank = get_tensor_model_parallel_rank() - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ - shard, : - ] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - if "act_scale" in weight_name or "weight_scale" in weight_name: - param_data[expert_id] = loaded_weight - - def process_weights_after_loading(self): - # Fp8 is the only case where we need to process after loading. - if not self.use_fp8: - return - - # If checkpoint is fp16, quantize here. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like( - self.w13_weight.data, dtype=torch.float8_e4m3fn - ) - w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn) - for expert in range(self.num_total_experts): - w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant( - self.w13_weight.data[expert, :, :] - ) - w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant( - self.w2_weight.data[expert, :, :] - ) - self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) - self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) - - # If checkpoint is fp8 + static, cleanup act_scales. - # Since state_dict has an act_scale per expert but our kernels - # are passed one act_scale shared across all experts. - elif self.quant_config.activation_scheme == "static": - if self.a13_scale is None or self.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - - if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale): - print_warning_once( - "Found act_scales that are not equal for fp8 MoE layer. " - "Using the maximum across experts for each layer. " - ) - - self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + # NOTE: hidden_states can have either 1D or 2D shape. 
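+        # The view below flattens either case to (num_tokens, hidden_size) and
+        # the original shape is restored on return.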
+ orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe( - hidden_states, - self.w13_weight, - self.w2_weight, - router_logits, - self.top_k, - renormalize=False, - inplace=True, - use_fp8=self.use_fp8, - w1_scale=self.w13_scale, - w2_scale=self.w2_scale, - a1_scale=self.a13_scale, - a2_scale=self.a2_scale, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + router_logits = 30.0 * F.tanh(router_logits / 30.0) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) class Grok1Attention(nn.Module): @@ -478,6 +167,7 @@ class Grok1Attention(nn.Module): layer_id=layer_id, logit_cap=logit_cap, ) + # TODO(lianmin): load logit cap from config def forward( self, @@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module): ) -> None: super().__init__() self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = Grok1Attention( hidden_size=self.hidden_size, @@ -513,18 +203,13 @@ class Grok1DecoderLayer(nn.Module): rope_theta=rope_theta, quant_config=quant_config, ) - if use_fused: - self.block_sparse_moe = Grok1MoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant_config=quant_config, - ) - else: - self.block_sparse_moe = Grok1MoEUnfused( - config=config, quant_config=quant_config - ) + self.block_sparse_moe = Grok1MoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module): hidden_states: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: + # Self Attention hidden_states = ( self.post_attn_norm( self.self_attn( @@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module): + hidden_states ) + # Fully Connected hidden_states = ( self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states ) - return hidden_states @@ -593,7 +279,6 @@ class Grok1Model(nn.Module): for i in range(len(self.layers)): hidden_states = self.layers[i](positions, hidden_states, input_metadata) - hidden_states = self.norm(hidden_states) hidden_states.mul_(self.config.output_multiplier_scale) return hidden_states @@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module): # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + warnings.filterwarnings("ignore", category=FutureWarning) - @torch.no_grad() def forward( self, input_ids: torch.Tensor, @@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module): ("qkv_proj", "v_proj", "v"), ] - if use_fused: - expert_params_mapping = ( - [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ( - "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", - 
expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ( - "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ( - "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.act_scale", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - ) - else: - expert_params_mapping = [] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) params_dict = dict(self.named_parameters()) - if get_tensor_model_parallel_rank() == 0: - weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4)) for name, loaded_weight in weights: - # print(get_tensor_model_parallel_rank(), name) if "rotary_emb.inv_freq" in name: continue @@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader( param, loaded_weight, weight_name, + shard_id=shard_id, expert_id=expert_id, pre_sharded=get_tensor_model_parallel_world_size() > 1, ) @@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if name is None: + continue + param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module): weight_loader(param, loaded_weight) -def all_close_1d(x: torch.Tensor) -> bool: - assert len(x.shape) == 1 - return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) - - old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index d11f6c951..45de85d87 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2d20881c8..9761c851a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -35,7 +35,6 @@ import torch import torch.distributed as dist from fastapi.responses import JSONResponse from packaging import version as pkg_version -from starlette.middleware.base import BaseHTTPMiddleware from torch.nn.parameter import Parameter from triton.runtime.cache import ( FileCacheManager, @@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535): logger.warn(f"Fail to set RLIMIT_NOFILE: {e}") -def is_llama3_405b_fp8(model_config): +def is_llama3_405b_fp8_head_16(model_config): """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads.""" if ( model_config.hf_config.architectures[0] == "LlamaForCausalLM"