diff --git a/3rdparty/amd/tuning/benchmark_moe_rocm.py b/3rdparty/amd/tuning/benchmark_moe_rocm.py index 9b30d8d02..a3f26e8e5 100644 --- a/3rdparty/amd/tuning/benchmark_moe_rocm.py +++ b/3rdparty/amd/tuning/benchmark_moe_rocm.py @@ -10,7 +10,7 @@ import triton.language as tl from tqdm import tqdm from transformers import AutoConfig -from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe, get_config_file_name +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 diff --git a/python/sglang/srt/layers/fused_moe_grok/__init__.py b/python/sglang/srt/layers/fused_moe_grok/__init__.py deleted file mode 100644 index c915c960d..000000000 --- a/python/sglang/srt/layers/fused_moe_grok/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from sglang.srt.layers.fused_moe_grok.layer import FusedMoE, FusedMoEMethodBase diff --git a/python/sglang/srt/layers/fused_moe_grok/fused_moe.py b/python/sglang/srt/layers/fused_moe_grok/fused_moe.py deleted file mode 100644 index 4d1c98c23..000000000 --- a/python/sglang/srt/layers/fused_moe_grok/fused_moe.py +++ /dev/null @@ -1,692 +0,0 @@ -# Adapted from -# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe -"""Fused MoE kernel.""" -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.logger import init_logger - -logger = init_logger(__name__) -padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 - - -@triton.jit -def fused_moe_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - a_scale_ptr, - b_scale_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - use_fp8: tl.constexpr, - even_Ks: tl.constexpr, -): - """ - Implements the fused computation for a Mixture of Experts (MOE) using - token and expert matrices. - - Key Parameters: - - A: The input tensor representing tokens with shape (*, K), where '*' can - be any shape representing batches and K is the feature dimension of - each token. - - B: The stacked MOE weight tensor with shape (E, N, K), where E is - the number of experts, K is the input feature dimension, and N is - the output feature dimension. - - C: The output cache tensor with shape (M, topk, N), where M is the - total number of tokens post padding, topk is the number of times - each token is repeated, and N is the output feature dimension. - - sorted_token_ids: A tensor containing the sorted indices of tokens, - repeated topk times and arranged by the expert index they are - assigned to. - - expert_ids: A tensor containing the indices of the expert for each - block. It determines which expert matrix from B should be used for - each block in A. 
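-    - num_tokens_post_padded: The total number of entries in
-      sorted_token_ids after padding each expert's token count up to a
-      multiple of BLOCK_SIZE_M; program ids at or beyond this bound return
-      early.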
- This kernel performs the multiplication of a token by its corresponding - expert matrix as determined by `expert_ids`. The sorting of - `sorted_token_ids` by expert index and padding ensures divisibility by - BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix - multiplication across different blocks processed by the same expert. - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: - return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - token_mask = offs_token < num_valid_tokens - - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) - - off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - ) - - if use_fp8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the - # K dimension. - if even_Ks: - a = tl.load( - a_ptrs, - mask=token_mask[:, None], - other=0.0, - ) - b = tl.load(b_ptrs) - else: - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - - # We accumulate along the K dimension. - if use_fp8: - accumulator = tl.dot(a, b, acc=accumulator) - else: - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. 
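-        # Both pointer blocks step BLOCK_SIZE_K elements along K; offs_k
-        # itself never changes, which is why the masks above compare offs_k
-        # against K - k * BLOCK_SIZE_K.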
- a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) - accumulator = accumulator * moe_weight[:, None] - - if use_fp8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) - else: - accumulator = accumulator.to(compute_type) - # ----------------------------------------------------------- - # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, num_experts: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Aligns the token distribution across experts to be compatible with block - size for matrix multiplication. - - Parameters: - - topk_ids: A tensor of shape [total_tokens, top_k] representing the - top-k expert indices for each token. - - block_size: The block size used in block matrix multiplication. - - num_experts: The total number of experts. - - Returns: - - sorted_token_ids: A tensor containing the sorted token indices according - to their allocated expert. - - expert_ids: A tensor indicating the assigned expert index for each block. - - num_tokens_post_padded: The total number of tokens after padding, - ensuring divisibility by block_size. - - This function pads the number of tokens that each expert needs to process - so that it is divisible by block_size. - Padding ensures that during block matrix multiplication, the dimensions - align correctly. - - Example: - Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], - block_size = 4, and num_experts = 4: - - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, - with each expert needing to process 3 tokens. - - As block_size is 4, we pad 1 token for each expert. - - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - - Then append padding tokens [12, 12, 12, 12] for each block. - - After sorting by expert index, we obtain token_ids - [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. - Tokens 12 are non-existent (padding) and are ignored in - the subsequent matrix multiplication. - - The padding ensures that the total number of tokens is now divisible - by block_size for proper block matrix operations. 
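-    The sorted ids are indices into the flattened topk_ids (flat position =
-    token_idx * top_k + k), which is why the kernel recovers a token's row in
-    A via `offs_token // top_k`.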
- """ - max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty( - (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device - ) - sorted_ids.fill_(topk_ids.numel()) - max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty( - (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) - return sorted_ids, expert_ids, num_tokens_post_pad - - -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: Dict[str, Any], - compute_type: tl.dtype, - use_fp8: bool, -) -> None: - assert topk_weights.stride(1) == 1 - assert sorted_token_ids.stride(0) == 1 - - padded_size = padding_size - if not use_fp8: - assert A_scale is None - assert B_scale is None - # MOE_PADDING FP8 only - padded_size = 0 - else: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) - assert B_scale is not None - - grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), - ) - - K = B.shape[2] - padded_size - if K % config["BLOCK_SIZE_K"] == 0: - even_ks = True - else: - even_ks = False - - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2] - padded_size, - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8=use_fp8, - even_Ks=even_ks, - **config, - ) - - -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: - device_name = torch.cuda.get_device_name().replace(" ", "_") - dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" - - -@functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: - """ - Return optimized configurations for the fused MoE kernel. - - The return value will be a dictionary that maps an irregular grid of - batch sizes to configurations of the fused_moe kernel. To evaluate the - kernel on a given batch size bs, the closest batch size in the grid should - be picked and the associated configuration chosen to invoke the kernel. 
- """ - - # First look up if an optimized configuration is available in the configs - # directory - json_file_name = get_config_file_name(E, N, dtype) - - config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name - ) - if os.path.exists(config_file_path): - with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", config_file_path) - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} - - # If no optimized configuration is available, we will use the default - # configuration - return None - - -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], -) -> Dict[str, int]: - if dtype == "float8": - config = { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4, - } - if M <= E: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4, - } - else: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - if M <= E: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } - return config - - -def try_get_optimal_moe_config( - w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, Any]] = None, -): - if override_config: - config = override_config - else: - # First try to load optimal config from the file - E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) - - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) - return config - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - ops.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. 
- - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids - - -# This is used by the Deepseek-V2 model -def grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, -): - - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - scores = torch.softmax(gating_output, dim=-1) - num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ - 1 - ] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) - .reshape(num_token, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids - - -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -): - padded_size = padding_size - if not use_fp8: - # MOE_PADDING FP8 only - padded_size = 0 - # Check constraints. 
- assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch" - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - M = min(num_tokens, CHUNK_SIZE) - - get_config_func = functools.partial( - try_get_optimal_moe_config, - w1.shape, - (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size), - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - ) - - config = get_config_func(M) - - intermediate_cache1 = torch.empty( - (M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache3 = torch.empty( - (M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - - if inplace: - out_hidden_states = hidden_states - else: - out_hidden_states = torch.empty_like(hidden_states) - - for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = ( - chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, num_tokens), - ) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape - - if tokens_in_chunk == 0: - break - - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - # Adjust the intermediate cache size and config for the last - # chunk. Note that in most cases we only have one chunk - # so the cache size and config are already set correctly and - # do not need to be adjusted. 
- intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] - intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] - config = get_config_func(tokens_in_chunk) - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, config["BLOCK_SIZE_M"], E - ) - - invoke_fused_moe_kernel( - curr_hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - torch.sum( - intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=out_hidden_states[begin_chunk_idx:end_chunk_idx], - ) - return out_hidden_states - - -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - num_expert_group: Optional[int]: additional parameter for grouped_topk - - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk - note: Deepseekv2 model uses grouped_topk - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. 
- assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states, - gating_output, - topk, - renormalize, - num_expert_group, - topk_group, - ) - else: - topk_weights, topk_ids = fused_topk( - hidden_states, gating_output, topk, renormalize - ) - - return fused_experts( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - override_config=override_config, - use_fp8=use_fp8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - ) diff --git a/python/sglang/srt/layers/fused_moe_grok/layer.py b/python/sglang/srt/layers/fused_moe_grok/layer.py deleted file mode 100644 index 89cc33d11..000000000 --- a/python/sglang/srt/layers/fused_moe_grok/layer.py +++ /dev/null @@ -1,630 +0,0 @@ -# Adapted from -# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe -import os -from abc import abstractmethod -from typing import List, Optional, Tuple - -import torch -import torch.nn.functional as F -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) -from vllm.model_executor.layers.quantization.fp8 import Fp8Config -from vllm.model_executor.utils import set_weight_attrs - -from sglang.srt.layers.fused_moe_grok.fused_moe import padding_size -from sglang.srt.utils import is_hip - -logger = init_logger(__name__) - - -class FusedMoEMethodBase(QuantizeMethodBase): - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - raise NotImplementedError - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - ) -> torch.Tensor: - raise NotImplementedError - - -class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): - """MoE method without quantization.""" - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, hidden_size, intermediate_size, dtype=params_dtype - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - ) -> torch.Tensor: - return self.forward( - x, - layer.w13_weight, - 
layer.w2_weight,
-            router_logits,
-            top_k,
-            renormalize,
-            use_grouped_topk,
-            num_expert_group,
-            topk_group,
-        )
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
-        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
-
-        return fused_moe(
-            x,
-            w1,
-            w2,
-            router_logits,
-            top_k,
-            renormalize=renormalize,
-            inplace=True,
-            use_grouped_topk=use_grouped_topk,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-        )
-
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError("The CPU backend currently does not support MoE.")
-
-    def forward_tpu(
-        self,
-        x: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        num_expert_group: Optional[int],
-        topk_group: Optional[int],
-    ) -> torch.Tensor:
-        raise NotImplementedError("The TPU backend currently does not support MoE.")
-
-
-class FusedMoE(torch.nn.Module):
-    """FusedMoE layer for MoE models.
-
-    This layer contains both MergedColumnParallel weights (gate_up_proj /
-    w13) and RowParallelLinear weights (down_proj / w2).
-
-    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
-    copy that naming convention here and handle any remapping in the
-    load_weights function in each model implementation.
-
-    Args:
-        num_experts: Number of experts in the model
-        top_k: Number of experts selected for each token
-        hidden_size: Input hidden state size of the transformer
-        intermediate_size: Intermediate size of the experts
-        params_dtype: Data type for the parameters.
-        reduce_results: Whether to all_reduce the output of the layer
-        renormalize: Whether to renormalize the logits in the fused_moe kernel
-        quant_config: Quantization config.
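-        use_grouped_topk: Whether to use grouped top-k routing
-            (DeepSeek-V2-style); requires num_expert_group and topk_group.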
- """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = "", - ): - super().__init__() - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - self.tp_size = ( - tp_size if tp_size is not None else get_tensor_model_parallel_world_size() - ) - self.top_k = top_k - self.num_experts = num_experts - self.intermediate_size_per_partition = intermediate_size // self.tp_size - self.reduce_results = reduce_results - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - - if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod() - ) - else: - if isinstance(quant_config, Fp8Config): - self.quant_method = Fp8MoEMethod(quant_config) - else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None - - self.quant_method.create_weights( - layer=self, - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=self.intermediate_size_per_partition, - params_dtype=params_dtype, - weight_loader=self.weight_loader, - ) - - def weight_loader( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: int, - expert_id: int, - use_presharded_weights: bool = False, - ): - param_data = param.data - - # Input scales can be loaded directly and should be equal. - if "input_scale" in weight_name: - if ( - param_data[expert_id] != 1 - and (param_data[expert_id] - loaded_weight).abs() > 1e-5 - ): - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param_data[expert_id]} " - f"vs. {loaded_weight}" - ) - param_data[expert_id] = loaded_weight - # Weight scales - elif "weight_scale" in weight_name: - # If we are in merged column case (gate_up_proj) - # shard_id 0 == gate_proj / w1 - # shard_id 2 == up_proj / w3 - if shard_id == 0 or shard_id == 2: - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == 0 else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) - # shard_id 1 == down_proj / w2 - else: - param_data[expert_id] = loaded_weight - # Weights - else: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.intermediate_size_per_partition - if use_presharded_weights: - shard = slice(None) - else: - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - - # w1, gate_proj case: Load into first shard of w13. - if shard_id == 0: - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - # w3, up_proj case: Load into second shard of w13. - elif shard_id == 2: - param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ - shard, : - ] - # w2, down_proj case: Load into only shard of w2. 
-        elif shard_id == 1:
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        else:
-            raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.quant_method is not None
-
-        # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            self,
-            x=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            renormalize=self.renormalize,
-            use_grouped_topk=self.use_grouped_topk,
-            num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group,
-        )
-
-        if self.reduce_results and self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, int]]:
-
-        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
-        gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
-
-        return (
-            [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id, shard_id)
-                (
-                    (
-                        "experts.w13_scale"
-                        if weight_name in gate_up
-                        else "experts.w2_scale"
-                    ),
-                    f"experts.{expert_id}.{weight_name}.weight_scale",
-                    expert_id,
-                    shard_id,
-                )
-                for expert_id in range(num_experts)
-                for shard_id, weight_name in enumerate(gate_down_up)
-            ]
-            + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id, shard_id)
-                (
-                    (
-                        "experts.w13_weight"
-                        if weight_name in gate_up
-                        else "experts.w2_weight"
-                    ),
-                    f"experts.{expert_id}.{weight_name}.weight",
-                    expert_id,
-                    shard_id,
-                )
-                for expert_id in range(num_experts)
-                for shard_id, weight_name in enumerate(gate_down_up)
-            ]
-            + [
-                # These are the input (activation) scales for the experts
-                # (param_name, weight_name, expert_id, shard_id)
-                (
-                    (
-                        "experts.a13_scale"
-                        if weight_name in gate_up
-                        else "experts.a2_scale"
-                    ),
-                    f"experts.{expert_id}.{weight_name}.input_scale",
-                    expert_id,
-                    shard_id,
-                )
-                for expert_id in range(num_experts)
-                for shard_id, weight_name in enumerate(gate_down_up)
-            ]
-        )
-
-
-import torch
-from torch.nn import Module
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    all_close_1d,
-    normalize_e4m3fn_to_e4m3fnuz,
-    per_tensor_dequantize,
-)
-from vllm.utils import print_warning_once
-
-
-class Fp8MoEMethod(FusedMoEMethodBase):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
-    activation scaling. The weight scaling factor will be initialized after
-    the model weights are loaded.
-
-    Args:
-        quant_config: The quantization config.
- """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - def create_weights( - self, - layer: Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, hidden_size, intermediate_size, dtype=params_dtype - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_scale", w13_scale) - - w2_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_scale", w2_scale) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - - a13_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) - - a2_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) - else: - layer.a13_scale = None - layer.a2_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - - # If checkpoint is fp16 or bfloat16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - # If ROCm, use float8_e4m3fnuz instead (MI300x HW) - fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn - w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter( - torch.ones( - layer.num_experts, dtype=torch.float32, device=w13_weight.device - ), - requires_grad=False, - ) - for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :] - ) - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - - # If ROCm, apply weight padding (min. 
Mem channel contention) only if set - if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): - layer.w13_weight = torch.nn.Parameter( - F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - if not all_close_1d(layer.a13_scale) or not all_close_1d( - layer.a2_scale - ): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. " - ) - layer.a13_scale = torch.nn.Parameter( - layer.a13_scale.max(), requires_grad=False - ) - layer.a2_scale = torch.nn.Parameter( - layer.a2_scale.max(), requires_grad=False - ) - - # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip(): - # Normalize the weights and scales - w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_scale, layer.a13_scale - ) - w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_scale, layer.a2_scale - ) - # Reset the parameters - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False) - if a13_scale is not None: - layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False) - if a2_scale is not None: - layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values - for expert_id in range(layer.num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_scale[expert_id][shard_id], - ) - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size - - layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) - # If ROCm, apply weight padding (min. 
Mem channel contention) only if set - if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): - layer.w13_weight = torch.nn.Parameter( - F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - return - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - ) -> torch.Tensor: - - from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe - - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group, - ) diff --git a/python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json rename to python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json diff --git a/python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json rename to python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index f8326c72d..1e49eb59a 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -16,22 +16,17 @@ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" -import warnings -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.layers.fused_moe_grok import FusedMoE +from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -41,10 +36,12 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from 
sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -293,17 +290,11 @@ class Grok1ForCausalLM(nn.Module): super().__init__() self.config = config self.quant_config = quant_config + self.torchao_config = global_server_args_dict["torchao_config"] self.model = Grok1Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - # Monkey patch _prepare_weights to load pre-sharded weights - setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) - - self.use_presharded_weights = True - - warnings.filterwarnings("ignore", category=FutureWarning) - def forward( self, input_ids: torch.Tensor, @@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module): continue name = name.replace(weight_name, param_name) - if self.use_presharded_weights: - extra_kwargs = { - "use_presharded_weights": self.use_presharded_weights - } - else: - extra_kwargs = {} - param = params_dict[name] weight_loader = param.weight_loader weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, - **extra_kwargs, ) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip loading kv_scale from ckpts towards new design. + if name.endswith(".kv_scale") and name not in params_dict: + continue if name is None: continue @@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module): ) weight_loader(param, loaded_weight) - -old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") - - -def _prepare_presharded_weights( - self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool -) -> Tuple[str, List[str], bool]: - import glob - import os - - if get_tensor_model_parallel_world_size() == 1: - return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt) - - tp_rank = get_tensor_model_parallel_rank() - allow_patterns = [f"*-{tp_rank:03d}.bin"] - - hf_folder = model_name_or_path - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - use_safetensors = False - - return hf_folder, hf_weights_files, use_safetensors + apply_torchao_config_(self, params_dict, set(["proj.weight"])) class Grok1ModelForCausalLM(Grok1ForCausalLM):