Remove fused_moe_grok (#2223)
This commit is contained in:
2
3rdparty/amd/tuning/benchmark_moe_rocm.py
vendored
2
3rdparty/amd/tuning/benchmark_moe_rocm.py
vendored
@@ -10,7 +10,7 @@ import triton.language as tl
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe, get_config_file_name
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name
|
||||||
|
|
||||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
from sglang.srt.layers.fused_moe_grok.layer import FusedMoE, FusedMoEMethodBase
|
|
||||||
@@ -1,692 +0,0 @@
|
|||||||
# Adapted from
|
|
||||||
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
|
|
||||||
"""Fused MoE kernel."""
|
|
||||||
import functools
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
# Number of trailing elements on the weights' last dimension that are treated
# as padding (it is subtracted from `shape[2]` of the expert weights on the
# FP8 path). Enabled by setting the MOE_PADDING env var to a non-zero integer.
padding_size = 128 if int(os.getenv("MOE_PADDING", "0")) != 0 else 0
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # Strides: how far to advance a pointer to move by one element along the
    # named dimension (e.g. `stride_am` moves `a_ptr` one row down in A).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Fused Mixture-of-Experts GEMM: multiply each routed token by the weight
    matrix of the expert it was assigned to.

    Key Parameters:
    - A: Token activations of shape (*, K); '*' is any batch shape and K is
        the per-token feature dimension.
    - B: Stacked expert weights of shape (E, N, K); E experts, K input
        features, N output features.
    - C: Output cache of shape (M, topk, N); M is the number of tokens post
        padding and topk is how many times each token is repeated.
    - sorted_token_ids: Token indices (repeated topk times) grouped by the
        expert they are routed to.
    - expert_ids: For each block of BLOCK_SIZE_M sorted tokens, the index of
        the expert matrix in B to use.

    Because `sorted_token_ids` is grouped by expert and padded to a multiple
    of BLOCK_SIZE_M, every row-block handled by one program belongs to a
    single expert. Entries whose id is >= `num_valid_tokens` are padding and
    are masked out of loads and stores.
    """
    # -----------------------------------------------------------
    # Map the linear program id onto a (pid_m, pid_n) tile of C using a
    # grouped ordering, which promotes L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m

    # ----------------------------------------------------------
    # Build pointers to the first K-block of A and B; they are advanced
    # along K inside the accumulation loop.
    # `a_ptrs` is a [BLOCK_SIZE_M, BLOCK_SIZE_K] block of pointers,
    # `b_ptrs` is a [BLOCK_SIZE_K, BLOCK_SIZE_N] block of pointers.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        # This row-block lies entirely past the padded token list.
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `offs_token // top_k` maps a routed (token, expert) slot back to its
    # row in A.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )

    if use_fp8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Accumulate one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C in fp32 for
    # accuracy; it is cast to `compute_type` after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if even_Ks:
            # K is a multiple of BLOCK_SIZE_K: no bounds mask needed along K.
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            b = tl.load(b_ptrs)
        else:
            # Mask off the tail of the final, partial K-block.
            k_remaining = K - k * BLOCK_SIZE_K
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                other=0.0,
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

        # Accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Step both pointer blocks to the next K-block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write the finished tile back to C, skipping padding rows and columns
    # beyond N.
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
|
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Group routed tokens by expert and pad each expert's share to a multiple
    of `block_size`, so the block matrix multiplication sees aligned blocks.

    Parameters:
    - topk_ids: Tensor of shape [total_tokens, top_k] holding the top-k
        expert indices chosen for each token.
    - block_size: Block size used in the block matrix multiplication.
    - num_experts: Total number of experts.

    Returns:
    - sorted_token_ids: Token indices ordered by the expert they are
        allocated to.
    - expert_ids: The expert index assigned to each block.
    - num_tokens_post_padded: Total number of tokens after padding, which is
        divisible by block_size.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - There are 12 routed tokens (each token repeated top_k times) and every
        expert receives 3 of them.
    - With block_size 4, one padding token is added per expert.
    - Flattened, topk_ids is [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]; padding
        tokens [12, 12, 12, 12] are appended, one per block.
    - Sorting by expert index yields token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Token 12 does not exist (padding) and is ignored by the subsequent
        matrix multiplication.
    """
    device = topk_ids.device
    num_routed = topk_ids.numel()
    # Upper bound: every expert may need up to (block_size - 1) padding slots.
    max_num_tokens_padded = num_routed + num_experts * (block_size - 1)

    # Pre-fill with `num_routed`, an id one past the last real token, so any
    # slot left unwritten reads as padding.
    sorted_ids = torch.full(
        (max_num_tokens_padded,), num_routed, dtype=torch.int32, device=device
    )
    expert_ids = torch.empty(
        (triton.cdiv(max_num_tokens_padded, block_size),),
        dtype=torch.int32,
        device=device,
    )
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    # The custom op writes its results into the three output tensors.
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
|
|
||||||
|
|
||||||
|
|
||||||
def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8: bool,
) -> None:
    """
    Launch `fused_moe_kernel` for one expert GEMM, writing the result into C.

    On the FP8 path A is quantized with `ops.scaled_fp8_quant` and `B_scale`
    must be provided; on the non-FP8 path both scales must be None. The
    module-level `padding_size` is subtracted from B's last dimension only
    when `use_fp8` is set. `config` supplies the kernel meta-parameters and
    must contain at least "BLOCK_SIZE_K".
    """
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8:
        padded_size = padding_size
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None
        # MOE_PADDING FP8 only
        padded_size = 0

    def grid(META):
        # One program per (row-block, column-block) tile of the output.
        return (
            triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
            * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
        )

    # Effective reduction length, excluding any padding on B's last dim.
    K = B.shape[2] - padded_size
    # When K divides evenly, the kernel skips the K-bounds mask.
    even_ks = K % config["BLOCK_SIZE_K"] == 0

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        K,
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        # B is stored as (E, N, K): the K stride is stride(2), N is stride(1).
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8=use_fp8,
        even_Ks=even_ks,
        **config,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    """
    Build the JSON file name of the tuned config for the current CUDA device.

    The name encodes E, N and the device name (spaces replaced by
    underscores); a `,dtype=<dtype>` part is appended only when `dtype` is
    truthy.
    """
    gpu = torch.cuda.get_device_name().replace(" ", "_")
    dtype_part = f",dtype={dtype}" if dtype else ""
    return f"E={E},N={N},device_name={gpu}{dtype_part}.json"
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Load the tuned fused-MoE kernel configurations, if a file for them exists.

    The file is looked up in the `configs` directory next to this module,
    under the name produced by `get_config_file_name`. The result maps an
    irregular grid of batch sizes to kernel configurations; for a batch size
    bs, callers should pick the entry whose key is closest to bs.

    Returns None when no config file exists, in which case the caller falls
    back to the default configuration. Results are memoized per (E, N, dtype).
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(
        module_dir, "configs", get_config_file_name(E, N, dtype)
    )

    if not os.path.exists(config_file_path):
        # No tuned configuration available.
        return None

    with open(config_file_path) as f:
        logger.info("Using configuration from %s for MoE layer.", config_file_path)
        raw_configs = json.load(f)
    # JSON object keys are strings; convert them back to integer batch sizes.
    return {int(key): val for key, val in raw_configs.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
) -> Dict[str, int]:
    """
    Return the fallback kernel configuration used when no tuned config exists.

    The choice depends only on whether `dtype` is "float8" and whether the
    batch is small (`M <= E`); `N`, `K` and `topk` are accepted but unused.
    Only the float8 configurations carry "num_warps" and "num_stages".
    """
    small_batch = M <= E

    if dtype == "float8":
        if small_batch:
            return {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
                "num_warps": 4,
                "num_stages": 4,
            }
        return {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 32,
            "num_warps": 8,
            "num_stages": 4,
        }

    if small_batch:
        return {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
|
|
||||||
|
|
||||||
|
|
||||||
def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    override_config: Optional[Dict[str, Any]] = None,
):
    """
    Pick the kernel configuration for a batch of M tokens.

    Resolution order:
    1. A truthy `override_config` is returned as-is.
    2. Otherwise the tuned configs for (E, N) taken from `w2_shape` are
       consulted, and the entry whose batch-size key is closest to M wins.
    3. If no tuned configs exist, the default configuration is used.
    """
    if override_config:
        return override_config

    E, _, N = w2_shape
    configs = get_moe_configs(E, N, dtype)
    if not configs:
        # No tuned file (or an empty one): fall back to the defaults.
        return get_default_config(M, E, N, w1_shape[2], top_k, dtype)

    closest_batch = min(configs, key=lambda batch: abs(batch - M))
    return configs[closest_batch]
|
|
||||||
|
|
||||||
|
|
||||||
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """
    Select the top-k experts per token via the `ops.topk_softmax` custom op.

    `hidden_states` must be 2-D and is used only for its token count and
    device. Returns `(topk_weights, topk_ids)`, both of shape (M, topk) with
    dtypes float32 and int32 respectively. When `renormalize` is set, each
    row of weights is divided by its sum.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape
    device = hidden_states.device

    # Output buffers handed to the custom op (presumably filled in place).
    topk_weights = torch.empty(M, topk, dtype=torch.float32, device=device)
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=device)
    token_expert_indices = torch.empty(M, topk, dtype=torch.int32, device=device)

    ops.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del token_expert_indices  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
# This is used by the Deepseek-V2 model
|
|
||||||
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    """
    Two-stage expert selection: pick the best expert groups, then the top-k
    experts within those groups.

    Experts are split into `num_expert_group` equal groups along the last
    dimension of the softmaxed gating scores. Each group is scored by its
    best expert, the `topk_group` highest-scoring groups are kept, and the
    final top-k is taken over the experts of the kept groups only (all other
    scores are zeroed). Returns `(topk_weights, topk_ids)`; rows of
    `topk_weights` are divided by their sum when `renormalize` is set.
    `hidden_states` is only used to validate the token count.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    experts_per_group = scores.shape[-1] // num_expert_group

    # Score every group by its best expert.  [n, n_group]
    grouped_scores = scores.view(num_token, num_expert_group, -1)
    group_scores = grouped_scores.max(dim=-1).values

    # Indices of the groups to keep.  [n, top_k_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices

    # 1 for kept groups, 0 otherwise.  [n, n_group]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)

    # Broadcast the per-group mask to every expert of the group.  [n, e]
    expert_mask = group_mask.unsqueeze(-1)
    expert_mask = expert_mask.expand(num_token, num_expert_group, experts_per_group)
    expert_mask = expert_mask.reshape(num_token, -1)

    # Zero the scores of experts in dropped groups before the final top-k.
    kept_scores = scores.masked_fill(~expert_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(kept_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    """
    Run the two expert GEMMs of a MoE layer for already-routed tokens.

    For each chunk of tokens this performs: w1 GEMM -> `gelu_and_mul`
    activation -> w2 GEMM (scaled by the routing weights) -> sum over the
    top-k experts of every token.

    Parameters:
    - hidden_states: Contiguous 2-D input of dtype float32, float16 or
        bfloat16.
    - w1, w2: Contiguous stacked expert weights for the first and second
        GEMM.
    - topk_weights, topk_ids: Routing weights and expert ids per token; they
        must have identical shapes.
    - inplace: If True the result is written into `hidden_states`.
    - override_config: Optional kernel configuration replacing the tuned or
        default one.
    - use_fp8: Use FP8 arithmetic. The module-level `padding_size` is only
        taken into account in this mode.
    - w1_scale, w2_scale: Weight scales passed to the first and second GEMM.
    - a1_scale, a2_scale: Activation scales passed to the first and second
        GEMM.

    Returns:
    - The output tensor with the same shape as `hidden_states` (it is
        `hidden_states` itself when `inplace` is True).
    """
    # MOE_PADDING FP8 only
    padded_size = padding_size if use_fp8 else 0

    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    top_k = topk_ids.shape[1]
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
        top_k,
        "float8" if use_fp8 else None,
        override_config=override_config,
    )

    config = get_config_func(M)

    # Scratch buffers sized for a full chunk and reused across chunks:
    #   cache1: output of the w1 GEMM, one row per (token, expert) pair
    #   cache2: activation output, flattened to (token * top_k, N // 2)
    #   cache3: output of the w2 GEMM, reduced over top_k at the end
    intermediate_cache1 = torch.empty(
        (M, top_k, N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * top_k, N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, top_k, w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # NOTE(review): float32 inputs pass the dtype assert above but are
    # computed in float16 here; confirm this is intended.
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            # num_tokens is an exact multiple of CHUNK_SIZE (or zero).
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            # cache2 is flattened over (token, expert) pairs, so it needs
            # tokens_in_chunk * top_k rows to match
            # intermediate_cache1.view(-1, N) in the activation below.
            intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], E
        )

        # First GEMM: tokens x w1. Routing weights are not applied yet.
        invoke_fused_moe_kernel(
            curr_hidden_states,
            w1,
            intermediate_cache1,
            a1_scale,
            w1_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,
            top_k,
            config,
            compute_type=compute_type,
            use_fp8=use_fp8,
        )

        ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

        # Second GEMM: activations x w2, multiplied by the routing weights.
        # top_k is 1 because cache2 already has one row per (token, expert).
        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
            intermediate_cache3,
            a2_scale,
            w2_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            True,
            1,
            config,
            compute_type=compute_type,
            use_fp8=use_fp8,
        )

        # Combine the top_k expert outputs of every token.
        torch.sum(
            intermediate_cache3,
            dim=1,
            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
    return out_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute a Mixture of Experts (MoE) layer: route every token to its top-k
    experts, then apply the two expert weight sets w1 and w2.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax). Its second dimension must equal the number of
        experts, `w1.shape[0]`.
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
        (note: the Deepseek-V2 model uses grouped_topk).
    - num_expert_group, topk_group (Optional[int]): Additional parameters
        for grouped_topk; both are required when `use_grouped_topk` is True.
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - w1_scale, w2_scale (Optional[torch.Tensor]): Optional scales for w1
        and w2.
    - a1_scale, a2_scale (Optional[torch.Tensor]): Optional activation
        scales, forwarded unchanged to `fused_experts`.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    # Stage 1: routing.
    if not use_grouped_topk:
        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
        )

    # Stage 2: expert computation.
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        override_config=override_config,
        use_fp8=use_fp8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
|
|
||||||
@@ -1,630 +0,0 @@
|
|||||||
# Adapted from
|
|
||||||
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
|
|
||||||
import os
|
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
|
||||||
QuantizationConfig,
|
|
||||||
QuantizeMethodBase,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_grok.fused_moe import padding_size
|
|
||||||
from sglang.srt.utils import is_hip
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEMethodBase(QuantizeMethodBase):
    """Interface for MoE (quantization) methods: weight creation and apply."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert weights for `layer`; subclasses must override."""
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the MoE computation for `x`; subclasses must override."""
        raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the uninitialized `w13_weight` and `w2_weight` on `layer`."""

        def _add_frozen_param(name: str, *shape: int) -> None:
            # Allocate, register on the layer, then tag with the extra attrs.
            param = torch.nn.Parameter(
                torch.empty(*shape, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel)
        _add_frozen_param(
            "w13_weight", num_experts, 2 * intermediate_size, hidden_size
        )
        # down_proj (row parallel)
        _add_frozen_param("w2_weight", num_experts, hidden_size, intermediate_size)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the MoE on `x` using the weights registered on `layer`."""
        # `forward` is inherited (presumably CustomOp's backend dispatch).
        return self.forward(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            top_k,
            renormalize,
            use_grouped_topk,
            num_expert_group,
            topk_group,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
    ) -> torch.Tensor:
        """CUDA path: delegate to the fused MoE kernel, writing into `x`."""
        # Imported at call time rather than at module import.
        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe

        return fused_moe(
            x,
            w1,
            w2,
            router_logits,
            top_k,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )

    def forward_cpu(self, *args, **kwargs):
        """CPU path: not supported."""
        raise NotImplementedError("The CPU backend currently does not support MoE.")

    def forward_tpu(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
    ) -> torch.Tensor:
        """TPU path: not supported."""
        raise NotImplementedError("The TPU backend currently does not support MoE.")
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        reduce_results: Whether to all_reduce the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        use_grouped_topk: Forwarded to the quantization method's ``apply``.
            When true, ``num_expert_group`` and ``topk_group`` are required.
        num_expert_group: Forwarded to the quantization method's ``apply``.
        topk_group: Forwarded to the quantization method's ``apply``.
        quant_config: Quantization configuration. ``None`` selects
            ``UnquantizedFusedMoEMethod``.
        tp_size: Tensor-parallel world size. Defaults to
            ``get_tensor_model_parallel_world_size()`` when ``None``.
        prefix: Passed to ``quant_config.get_quant_method``.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.top_k = top_k
        self.num_experts = num_experts
        # Each tensor-parallel rank owns an equal slice of the intermediate dim.
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod()
            )
        elif isinstance(quant_config, Fp8Config):
            # Fp8 configs always use the MoE-specific method defined in this file.
            self.quant_method = Fp8MoEMethod(quant_config)
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        # The quantization method registers the parameters on this module and
        # receives ``weight_loader`` as an extra weight attribute.
        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: int,
        expert_id: int,
        use_presharded_weights: bool = False,
    ):
        """Copy one checkpoint tensor into the matching slot of ``param``.

        The destination is chosen from ``weight_name`` (input scale, weight
        scale, or plain weight) and ``shard_id``:

        * ``0`` - gate_proj / w1 (first half of w13)
        * ``1`` - down_proj / w2
        * ``2`` - up_proj / w3 (second half of w13)

        Args:
            param: Destination parameter; its ``.data`` is written in place.
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Name used to pick the branch (substring match on
                ``"input_scale"`` / ``"weight_scale"``).
            shard_id: One of 0, 1, 2 as described above.
            expert_id: Index of the expert being loaded.
            use_presharded_weights: When true, ``loaded_weight`` is used whole
                instead of being sliced by tensor-parallel rank.

        Raises:
            ValueError: If ``shard_id`` is not 0, 1 or 2, or if an input scale
                that was already loaded differs from ``loaded_weight`` by more
                than 1e-5.
        """
        # Validate once for every branch. Previously only the plain-weight
        # branch rejected bad ids; the scale branches silently treated any
        # other value as down_proj.
        if shard_id not in (0, 1, 2):
            raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")

        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            # A value of 1 is treated as "not loaded yet", so only compare
            # against an entry that has already been overwritten.
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            if shard_id == 1:
                # Row parallel case (down_proj / w2): one scale per expert.
                param_data[expert_id] = loaded_weight
            else:
                # Merged column case (gate_up_proj). We have to keep the
                # weight scales of w1 and w3 because we need to re-quantize
                # w1/w3 weights after weight loading.
                idx = 0 if shard_id == 0 else 1
                param_data[expert_id][idx] = loaded_weight
        # Weights
        else:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.intermediate_size_per_partition
            if use_presharded_weights:
                shard = slice(None)
            else:
                shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

            if shard_id == 0:
                # w1, gate_proj case: Load into first shard of w13.
                param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
            elif shard_id == 2:
                # w3, up_proj case: Load into second shard of w13.
                param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                    shard, :
                ]
            else:
                # w2, down_proj case: Load into only shard of w2.
                param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        """Apply the MoE layer and optionally all-reduce the result.

        The computation is delegated to ``self.quant_method.apply``. The
        output is all-reduced across tensor-parallel ranks only when
        ``reduce_results`` is set and ``tp_size`` is greater than one.
        """
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            num_expert_group=self.num_expert_group,
            topk_group=self.topk_group,
        )

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, int]]:
        """Build the checkpoint-name to parameter-name mapping for all experts.

        Returns a list of ``(param_name, weight_name, expert_id, shard_id)``
        tuples: first every weight scale, then every weight, then every input
        scale. Within each group the entries are ordered by expert, then by
        shard id (0 = gate, 1 = down, 2 = up).
        """
        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
        # The position in this list is the shard id.
        gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]

        # (param for gate/up, param for down, checkpoint suffix)
        param_groups = [
            # Weight scales for the experts
            ("experts.w13_scale", "experts.w2_scale", "weight_scale"),
            # Weights for the experts
            ("experts.w13_weight", "experts.w2_weight", "weight"),
            # Input (activation) scales for the experts
            ("experts.a13_scale", "experts.a2_scale", "input_scale"),
        ]

        return [
            (
                w13_param if weight_name in gate_up else w2_param,
                f"experts.{expert_id}.{weight_name}.{suffix}",
                expert_id,
                shard_id,
            )
            for w13_param, w2_param, suffix in param_groups
            for expert_id in range(num_experts)
            for shard_id, weight_name in enumerate(gate_down_up)
        ]
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn import Module
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
all_close_1d,
|
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
|
||||||
per_tensor_dequantize,
|
|
||||||
)
|
|
||||||
from vllm.utils import print_warning_once
|
|
||||||
|
|
||||||
|
|
||||||
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weights, weight scales and input scales on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the expert weights.
            intermediate_size: Intermediate dimension of the expert weights.
            params_dtype: dtype of the weights; overridden with
                ``torch.float8_e4m3fn`` for fp8-serialized checkpoints.
            **extra_weight_attrs: Forwarded as a dict to ``set_weight_attrs``.

        Raises:
            ValueError: If the activation scheme is ``"static"`` but the
                checkpoint is not fp8-serialized.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_scale", w2_scale)

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            a13_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("a13_scale", a13_scale)
            set_weight_attrs(a13_scale, extra_weight_attrs)

            a2_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("a2_scale", a2_scale)
            set_weight_attrs(a2_scale, extra_weight_attrs)
        else:
            # No stored activation scales for non-static schemes.
            layer.a13_scale = None
            layer.a2_scale = None

    @staticmethod
    def _pad_weights_for_rocm(layer: Module) -> None:
        """Zero-pad the last dim of both expert weights on ROCm if MOE_PADDING is set."""
        # If ROCm, apply weight padding (min. Mem channel contention) only if set
        if not (is_hip() and bool(int(os.getenv("MOE_PADDING", "0")))):
            return

        layer.w13_weight = torch.nn.Parameter(
            F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize weights and scales once the checkpoint has been loaded.

        For non-fp8 checkpoints the weights are quantized per expert here.
        For fp8 checkpoints the activation scales are collapsed to one value
        per layer and the two w13 weight scales to one per expert, with the
        w13 shards re-quantized to that shared scale. In both cases the ROCm
        padding step runs last.

        Raises:
            ValueError: If the activation scheme is ``"static"`` but the
                layer's activation scales are ``None``.
        """
        # If checkpoint is fp16 or bfloat16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_scale = torch.nn.Parameter(
                torch.ones(
                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
                ),
                requires_grad=False,
            )
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
                    layer.w2_weight.data[expert, :, :]
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

            self._pad_weights_for_rocm(layer)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.

        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if layer.a13_scale is None or layer.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            if not all_close_1d(layer.a13_scale) or not all_close_1d(
                layer.a2_scale
            ):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. "
                )
            layer.a13_scale = torch.nn.Parameter(
                layer.a13_scale.max(), requires_grad=False
            )
            layer.a2_scale = torch.nn.Parameter(
                layer.a2_scale.max(), requires_grad=False
            )

        # If ROCm, normalize the weights and scales to e4m3fnuz
        if is_hip():
            # Normalize the weights and scales
            w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
                layer.w13_weight, layer.w13_scale, layer.a13_scale
            )
            w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
                layer.w2_weight, layer.w2_scale, layer.a2_scale
            )
            # Reset the parameters
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
            if a13_scale is not None:
                layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
            if a2_scale is not None:
                layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            # Two consecutive shards of ``shard_size`` rows each (w1, then w3).
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start : start + shard_size, :],
                    layer.w13_scale[expert_id][shard_id],
                )
                layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                    ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                )
                start += shard_size

        layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

        self._pad_weights_for_rocm(layer)
        return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        """Call the ``fused_moe`` kernel wrapper in fp8 mode and return its result.

        The weights and the four scale tensors are read from ``layer``;
        ``inplace=True`` and ``use_fp8=True`` are always passed.
        """
        # Imported at call time rather than at module import time.
        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe

        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            top_k,
            renormalize=renormalize,
            inplace=True,
            use_fp8=True,
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale,
            a1_scale=layer.a13_scale,
            a2_scale=layer.a2_scale,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
|
|
||||||
@@ -16,22 +16,17 @@
|
|||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||||
"""Inference-only Grok1 model."""
|
"""Inference-only Grok1 model."""
|
||||||
|
|
||||||
import warnings
|
from typing import Iterable, Optional, Tuple
|
||||||
from typing import Iterable, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_grok import FusedMoE
|
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -41,10 +36,12 @@ from sglang.srt.layers.linear import (
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
|
||||||
@@ -293,17 +290,11 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||||
self.model = Grok1Model(config, quant_config=quant_config)
|
self.model = Grok1Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
|
||||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
|
||||||
|
|
||||||
self.use_presharded_weights = True
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
if self.use_presharded_weights:
|
|
||||||
extra_kwargs = {
|
|
||||||
"use_presharded_weights": self.use_presharded_weights
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
extra_kwargs = {}
|
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(
|
weight_loader(
|
||||||
param,
|
param,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
weight_name,
|
name,
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
expert_id=expert_id,
|
expert_id=expert_id,
|
||||||
**extra_kwargs,
|
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
# Skip loading kv_scale from ckpts towards new design.
|
||||||
|
if name.endswith(".kv_scale") and name not in params_dict:
|
||||||
|
continue
|
||||||
if name is None:
|
if name is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||||
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")


def _prepare_presharded_weights(
    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
) -> Tuple[str, List[str], bool]:
    """Locate the pre-sharded ``.bin`` weight files for this tensor-parallel rank.

    With a tensor-parallel world size of 1 this defers to the loader method
    captured in ``old_prepare_weights``. Otherwise it collects the files in
    ``model_name_or_path`` whose names end in ``-<rank as 3 digits>.bin``.

    Args:
        self: The loader instance (this function is installed as a method).
        model_name_or_path: Directory that is searched for shard files.
        revision: Only used when deferring to ``old_prepare_weights``.
        fall_back_to_pt: Only used when deferring to ``old_prepare_weights``.

    Returns:
        ``(folder, weight_files, use_safetensors)``. On the pre-sharded path
        ``use_safetensors`` is always ``False`` and ``weight_files`` is sorted.

    Raises:
        RuntimeError: If no file matches the pattern for this rank.
    """
    import glob
    import os

    if get_tensor_model_parallel_world_size() == 1:
        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)

    tp_rank = get_tensor_model_parallel_rank()
    allow_patterns = [f"*-{tp_rank:03d}.bin"]

    hf_folder = model_name_or_path

    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        # glob order depends on the filesystem; sort for a stable load order.
        hf_weights_files += sorted(glob.glob(os.path.join(hf_folder, pattern)))

    # An empty list would let loading continue with no weights at all.
    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any pre-sharded weights matching {allow_patterns} "
            f"in `{model_name_or_path}`"
        )

    use_safetensors = False

    return hf_folder, hf_weights_files, use_safetensors
|
|
||||||
|
|
||||||
|
|
||||||
class Grok1ModelForCausalLM(Grok1ForCausalLM):
|
class Grok1ModelForCausalLM(Grok1ForCausalLM):
|
||||||
|
|||||||
Reference in New Issue
Block a user