Update grok 1 model (#1095)
This commit is contained in:
@@ -88,6 +88,9 @@ def main(args):
|
|||||||
for i in range(len(states)):
|
for i in range(len(states)):
|
||||||
preds.append(get_answer_value(states[i]["answer"]))
|
preds.append(get_answer_value(states[i]["answer"]))
|
||||||
|
|
||||||
|
# print(f"{preds=}")
|
||||||
|
# print(f"{labels=}")
|
||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
acc = np.mean(np.array(preds) == np.array(labels))
|
acc = np.mean(np.array(preds) == np.array(labels))
|
||||||
invalid = np.mean(np.array(preds) == INVALID)
|
invalid = np.mean(np.array(preds) == INVALID)
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ def correctness_test(
|
|||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
|
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
|
||||||
|
rank_print(f"{input_ids=}")
|
||||||
|
|
||||||
if bench_args.cut_len > 0:
|
if bench_args.cut_len > 0:
|
||||||
# Prefill
|
# Prefill
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ limitations under the License.
|
|||||||
"""Fused operators for activation layers."""
|
"""Fused operators for activation layers."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from flashinfer.activation import silu_and_mul
|
from flashinfer.activation import silu_and_mul
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|||||||
1
python/sglang/srt/layers/fused_moe/__init__.py
Normal file
1
python/sglang/srt/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
|
||||||
@@ -1,20 +1,5 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
|
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
|
||||||
"""Fused MoE kernel."""
|
"""Fused MoE kernel."""
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
@@ -373,6 +359,31 @@ def get_default_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def try_get_optimal_moe_config(
|
||||||
|
w1_shape: Tuple[int, ...],
|
||||||
|
w2_shape: Tuple[int, ...],
|
||||||
|
top_k: int,
|
||||||
|
dtype: Optional[str],
|
||||||
|
M: int,
|
||||||
|
override_config: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
if override_config:
|
||||||
|
config = override_config
|
||||||
|
else:
|
||||||
|
# First try to load optimal config from the file
|
||||||
|
E, _, N = w2_shape
|
||||||
|
configs = get_moe_configs(E, N, dtype)
|
||||||
|
|
||||||
|
if configs:
|
||||||
|
# If an optimal configuration map has been found, look up the
|
||||||
|
# optimal config
|
||||||
|
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||||
|
else:
|
||||||
|
# Else use the default config
|
||||||
|
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def fused_topk(
|
def fused_topk(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
@@ -403,6 +414,41 @@ def fused_topk(
|
|||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
# This is used by the Deepseek-V2 model
|
||||||
|
def grouped_topk(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
num_expert_group: int = 0,
|
||||||
|
topk_group: int = 0,
|
||||||
|
):
|
||||||
|
|
||||||
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
|
||||||
|
scores = torch.softmax(gating_output, dim=-1)
|
||||||
|
num_token = scores.shape[0]
|
||||||
|
group_scores = (
|
||||||
|
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||||
|
) # [n, n_group]
|
||||||
|
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
||||||
|
1
|
||||||
|
] # [n, top_k_group]
|
||||||
|
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||||
|
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||||
|
score_mask = (
|
||||||
|
group_mask.unsqueeze(-1)
|
||||||
|
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
||||||
|
.reshape(num_token, -1)
|
||||||
|
) # [n, e]
|
||||||
|
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||||
|
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||||
|
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
def fused_experts(
|
def fused_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
@@ -425,25 +471,24 @@ def fused_experts(
|
|||||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
M, _ = hidden_states.shape
|
num_tokens, _ = hidden_states.shape
|
||||||
E, N, _ = w1.shape
|
E, N, _ = w1.shape
|
||||||
|
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||||
|
# https://github.com/vllm-project/vllm/issues/5938
|
||||||
|
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||||
|
M = min(num_tokens, CHUNK_SIZE)
|
||||||
|
|
||||||
if override_config:
|
get_config_func = functools.partial(
|
||||||
config = override_config
|
try_get_optimal_moe_config,
|
||||||
else:
|
w1.shape,
|
||||||
# First try to load optimal config from the file
|
w2.shape,
|
||||||
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
|
topk_ids.shape[1],
|
||||||
|
"float8" if use_fp8 else None,
|
||||||
if configs:
|
override_config=override_config,
|
||||||
# If an optimal configuration map has been found, look up the
|
|
||||||
# optimal config
|
|
||||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
|
||||||
else:
|
|
||||||
# Else use the default config
|
|
||||||
config = get_default_config(
|
|
||||||
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
config = get_config_func(M)
|
||||||
|
|
||||||
intermediate_cache1 = torch.empty(
|
intermediate_cache1 = torch.empty(
|
||||||
(M, topk_ids.shape[1], N),
|
(M, topk_ids.shape[1], N),
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
@@ -460,19 +505,49 @@ def fused_experts(
|
|||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
|
||||||
topk_ids, config["BLOCK_SIZE_M"], E
|
|
||||||
)
|
|
||||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||||
|
|
||||||
|
if inplace:
|
||||||
|
out_hidden_states = hidden_states
|
||||||
|
else:
|
||||||
|
out_hidden_states = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
||||||
|
begin_chunk_idx, end_chunk_idx = (
|
||||||
|
chunk * CHUNK_SIZE,
|
||||||
|
min((chunk + 1) * CHUNK_SIZE, num_tokens),
|
||||||
|
)
|
||||||
|
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||||
|
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||||
|
|
||||||
|
if tokens_in_chunk == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||||
|
# Adjust the intermediate cache size and config for the last
|
||||||
|
# chunk. Note that in most cases we only have one chunk
|
||||||
|
# so the cache size and config are already set correctly and
|
||||||
|
# do not need to be adjusted.
|
||||||
|
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
||||||
|
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
|
||||||
|
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
||||||
|
config = get_config_func(tokens_in_chunk)
|
||||||
|
|
||||||
|
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||||
|
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||||
|
|
||||||
|
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||||
|
curr_topk_ids, config["BLOCK_SIZE_M"], E
|
||||||
|
)
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_kernel(
|
||||||
hidden_states,
|
curr_hidden_states,
|
||||||
w1,
|
w1,
|
||||||
intermediate_cache1,
|
intermediate_cache1,
|
||||||
a1_scale,
|
a1_scale,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
topk_weights,
|
curr_topk_weights,
|
||||||
topk_ids,
|
curr_topk_ids,
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
@@ -491,8 +566,8 @@ def fused_experts(
|
|||||||
intermediate_cache3,
|
intermediate_cache3,
|
||||||
a2_scale,
|
a2_scale,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
topk_weights,
|
curr_topk_weights,
|
||||||
topk_ids,
|
curr_topk_ids,
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
@@ -503,13 +578,12 @@ def fused_experts(
|
|||||||
use_fp8=use_fp8,
|
use_fp8=use_fp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if inplace:
|
torch.sum(
|
||||||
return torch.sum(
|
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
dim=1,
|
dim=1,
|
||||||
out=hidden_states,
|
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
)
|
)
|
||||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
|
return out_hidden_states
|
||||||
|
|
||||||
|
|
||||||
def fused_moe(
|
def fused_moe(
|
||||||
@@ -521,6 +595,9 @@ def fused_moe(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
override_config: Optional[Dict[str, Any]] = None,
|
override_config: Optional[Dict[str, Any]] = None,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
use_fp8: bool = False,
|
use_fp8: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -543,6 +620,10 @@ def fused_moe(
|
|||||||
Defaults to False.
|
Defaults to False.
|
||||||
- override_config (Optional[Dict[str, Any]]): Optional override
|
- override_config (Optional[Dict[str, Any]]): Optional override
|
||||||
for the kernel configuration.
|
for the kernel configuration.
|
||||||
|
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
||||||
|
- topk_group: Optional[int]: additional parameter for grouped_topk
|
||||||
|
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
||||||
|
note: Deepseekv2 model uses grouped_topk
|
||||||
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
|
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||||
products for w1 and w2. Defaults to False.
|
products for w1 and w2. Defaults to False.
|
||||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||||
@@ -556,12 +637,18 @@ def fused_moe(
|
|||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||||
|
|
||||||
if hasattr(ops, "topk_softmax"):
|
if use_grouped_topk:
|
||||||
topk_weights, topk_ids = fused_topk(
|
assert num_expert_group is not None and topk_group is not None
|
||||||
hidden_states, gating_output, topk, renormalize
|
topk_weights, topk_ids = grouped_topk(
|
||||||
|
hidden_states,
|
||||||
|
gating_output,
|
||||||
|
topk,
|
||||||
|
renormalize,
|
||||||
|
num_expert_group,
|
||||||
|
topk_group,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
topk_weights, topk_ids = fused_topk_v0_4_3(
|
topk_weights, topk_ids = fused_topk(
|
||||||
hidden_states, gating_output, topk, renormalize
|
hidden_states, gating_output, topk, renormalize
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -579,33 +666,3 @@ def fused_moe(
|
|||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fused_topk_v0_4_3(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
):
|
|
||||||
import vllm._moe_C as moe_kernels
|
|
||||||
|
|
||||||
M, _ = hidden_states.shape
|
|
||||||
|
|
||||||
topk_weights = torch.empty(
|
|
||||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
|
||||||
)
|
|
||||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
|
||||||
token_expert_indicies = torch.empty(
|
|
||||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
|
||||||
)
|
|
||||||
moe_kernels.topk_softmax(
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
token_expert_indicies,
|
|
||||||
gating_output.float(), # TODO(woosuk): Optimize this.
|
|
||||||
)
|
|
||||||
del token_expert_indicies # Not used. Will be used in the future.
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
587
python/sglang/srt/layers/fused_moe/layer.py
Normal file
587
python/sglang/srt/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
# Adapted from
|
||||||
|
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool = True,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
|
||||||
|
# Fused gate_up_proj (column parallel)
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# down_proj (row parallel)
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool = True,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.forward(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
router_logits,
|
||||||
|
top_k,
|
||||||
|
renormalize,
|
||||||
|
use_grouped_topk,
|
||||||
|
num_expert_group,
|
||||||
|
topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
num_expert_group: Optional[int],
|
||||||
|
topk_group: Optional[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from sglang.srt.layers.fused_moe.fused_moe import fused_moe
|
||||||
|
|
||||||
|
return fused_moe(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
router_logits,
|
||||||
|
top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
inplace=True,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_cpu(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError("The CPU backend currently does not support MoE.")
|
||||||
|
|
||||||
|
def forward_tpu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
num_expert_group: Optional[int],
|
||||||
|
topk_group: Optional[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
|
||||||
|
|
||||||
|
assert not use_grouped_topk
|
||||||
|
assert num_expert_group is None
|
||||||
|
assert topk_group is None
|
||||||
|
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoE(torch.nn.Module):
|
||||||
|
"""FusedMoE layer for MoE models.
|
||||||
|
|
||||||
|
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
||||||
|
w13) and RowParallelLinear weights (down_proj/ w2).
|
||||||
|
|
||||||
|
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
|
||||||
|
copy that naming convention here and handle any remapping in the
|
||||||
|
load_weights function in each model implementation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_experts: Number of experts in the model
|
||||||
|
top_k: Number of experts selected for each token
|
||||||
|
hidden_size: Input hidden state size of the transformer
|
||||||
|
intermediate_size: Intermediate size of the experts
|
||||||
|
params_dtype: Data type for the parameters.
|
||||||
|
reduce_results: Whether to all all_reduce on the output of the layer
|
||||||
|
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
||||||
|
quant_config: Quantization configure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_experts: int,
|
||||||
|
top_k: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
reduce_results: bool = False,
|
||||||
|
renormalize: bool = True,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if params_dtype is None:
|
||||||
|
params_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
self.tp_size = (
|
||||||
|
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
||||||
|
)
|
||||||
|
self.top_k = top_k
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||||
|
self.reduce_results = reduce_results
|
||||||
|
self.renormalize = renormalize
|
||||||
|
self.use_grouped_topk = use_grouped_topk
|
||||||
|
if self.use_grouped_topk:
|
||||||
|
assert num_expert_group is not None and topk_group is not None
|
||||||
|
self.num_expert_group = num_expert_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
|
||||||
|
if quant_config is None:
|
||||||
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||||
|
UnquantizedFusedMoEMethod()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if isinstance(quant_config, Fp8Config):
|
||||||
|
self.quant_method = Fp8MoEMethod(quant_config)
|
||||||
|
else:
|
||||||
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||||
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
self.quant_method.create_weights(
|
||||||
|
layer=self,
|
||||||
|
num_experts=num_experts,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
intermediate_size=self.intermediate_size_per_partition,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
weight_loader=self.weight_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
def weight_loader(
|
||||||
|
self,
|
||||||
|
param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
shard_id: int,
|
||||||
|
expert_id: int,
|
||||||
|
pre_sharded: bool,
|
||||||
|
):
|
||||||
|
param_data = param.data
|
||||||
|
|
||||||
|
# Input scales can be loaded directly and should be equal.
|
||||||
|
if "input_scale" in weight_name:
|
||||||
|
if (
|
||||||
|
param_data[expert_id] != 1
|
||||||
|
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"input_scales of w1 and w3 of a layer "
|
||||||
|
f"must be equal. But got {param_data[expert_id]} "
|
||||||
|
f"vs. {loaded_weight}"
|
||||||
|
)
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
|
# Weight scales
|
||||||
|
elif "weight_scale" in weight_name:
|
||||||
|
# If we are in merged column case (gate_up_proj)
|
||||||
|
# shard_id 0 == gate_proj / w1
|
||||||
|
# shard_id 2 == up_proj / w3
|
||||||
|
if shard_id == 0 or shard_id == 2:
|
||||||
|
# We have to keep the weight scales of w1 and w3 because
|
||||||
|
# we need to re-quantize w1/w3 weights after weight loading.
|
||||||
|
idx = 0 if shard_id == 0 else 1
|
||||||
|
param_data[expert_id][idx] = loaded_weight
|
||||||
|
# If we are in the row parallel case (down_proj)
|
||||||
|
# shard_id 1 == down_proj / w2
|
||||||
|
else:
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
|
# Weights
|
||||||
|
else:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
shard_size = self.intermediate_size_per_partition
|
||||||
|
if pre_sharded:
|
||||||
|
shard = slice(None)
|
||||||
|
else:
|
||||||
|
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||||
|
|
||||||
|
# w1, gate_proj case: Load into first shard of w13.
|
||||||
|
if shard_id == 0:
|
||||||
|
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||||
|
# w3, up_proj case: Load into second shard of w13.
|
||||||
|
elif shard_id == 2:
|
||||||
|
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
|
||||||
|
shard, :
|
||||||
|
]
|
||||||
|
# w2, down_proj case: Load into only shard of w2.
|
||||||
|
elif shard_id == 1:
|
||||||
|
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||||
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
final_hidden_states = self.quant_method.apply(
|
||||||
|
self,
|
||||||
|
x=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=self.top_k,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
|
num_expert_group=self.num_expert_group,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.reduce_results and self.tp_size > 1:
|
||||||
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
|
return final_hidden_states
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_expert_params_mapping(
|
||||||
|
cls,
|
||||||
|
ckpt_gate_proj_name: str,
|
||||||
|
ckpt_down_proj_name: str,
|
||||||
|
ckpt_up_proj_name: str,
|
||||||
|
num_experts: int,
|
||||||
|
) -> List[Tuple[str, str, int, int]]:
|
||||||
|
|
||||||
|
gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
|
||||||
|
gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
|
||||||
|
|
||||||
|
return (
|
||||||
|
[
|
||||||
|
# These are the weight scales for the experts
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"experts.w13_scale"
|
||||||
|
if weight_name in gate_up
|
||||||
|
else "experts.w2_scale"
|
||||||
|
),
|
||||||
|
f"experts.{expert_id}.{weight_name}.weight_scale",
|
||||||
|
expert_id,
|
||||||
|
shard_id,
|
||||||
|
)
|
||||||
|
for expert_id in range(num_experts)
|
||||||
|
for shard_id, weight_name in enumerate(gate_down_up)
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
# These are the weights for the experts
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"experts.w13_weight"
|
||||||
|
if weight_name in gate_up
|
||||||
|
else "experts.w2_weight"
|
||||||
|
),
|
||||||
|
f"experts.{expert_id}.{weight_name}.weight",
|
||||||
|
expert_id,
|
||||||
|
shard_id,
|
||||||
|
)
|
||||||
|
for expert_id in range(num_experts)
|
||||||
|
for shard_id, weight_name in enumerate(gate_down_up)
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
# These are the weight scales for the experts
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"experts.a13_scale"
|
||||||
|
if weight_name in gate_up
|
||||||
|
else "experts.a2_scale"
|
||||||
|
),
|
||||||
|
f"experts.{expert_id}.{weight_name}.input_scale",
|
||||||
|
expert_id,
|
||||||
|
shard_id,
|
||||||
|
)
|
||||||
|
for expert_id in range(num_experts)
|
||||||
|
for shard_id, weight_name in enumerate(gate_down_up)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import Module
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
all_close_1d,
|
||||||
|
per_tensor_dequantize,
|
||||||
|
)
|
||||||
|
from vllm.utils import print_warning_once
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||||
|
"""MoE method for FP8.
|
||||||
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
|
dynamic/static activation scale.
|
||||||
|
|
||||||
|
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||||
|
activation scaling. The weight scaling factor will be initialized after
|
||||||
|
the model weights are loaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quant_config: The quantization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config: Fp8Config):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
|
||||||
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
params_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
# WEIGHTS
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# WEIGHT_SCALES
|
||||||
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
|
# They will be combined to a single scale after weight loading.
|
||||||
|
w13_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_scale", w13_scale)
|
||||||
|
|
||||||
|
w2_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_scale", w2_scale)
|
||||||
|
|
||||||
|
# If loading fp8 checkpoint, pass the weight loaders.
|
||||||
|
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||||
|
# process_weights_after_loading()
|
||||||
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
set_weight_attrs(w13_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
# INPUT_SCALES
|
||||||
|
if self.quant_config.activation_scheme == "static":
|
||||||
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
raise ValueError(
|
||||||
|
"Found static activation scheme for checkpoint that "
|
||||||
|
"was not serialized fp8."
|
||||||
|
)
|
||||||
|
|
||||||
|
a13_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.register_parameter("a13_scale", a13_scale)
|
||||||
|
set_weight_attrs(a13_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
a2_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.register_parameter("a2_scale", a2_scale)
|
||||||
|
set_weight_attrs(a2_scale, extra_weight_attrs)
|
||||||
|
else:
|
||||||
|
layer.a13_scale = None
|
||||||
|
layer.a2_scale = None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
|
||||||
|
# If checkpoint is fp16, quantize in place.
|
||||||
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
w13_weight = torch.empty_like(
|
||||||
|
layer.w13_weight.data, dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
w2_weight = torch.empty_like(
|
||||||
|
layer.w2_weight.data, dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
|
# Re-initialize w13_scale because we directly quantize
|
||||||
|
# merged w13 weights and generate a single scaling factor.
|
||||||
|
layer.w13_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
layer.num_experts, dtype=torch.float32, device=w13_weight.device
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
for expert in range(layer.num_experts):
|
||||||
|
w13_weight[expert, :, :], layer.w13_scale[expert] = (
|
||||||
|
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
|
w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
|
||||||
|
layer.w2_weight.data[expert, :, :]
|
||||||
|
)
|
||||||
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# If checkpoint is fp8, we need to handle that the
|
||||||
|
# MoE kernels require single activation scale and single weight
|
||||||
|
# scale for w13 per expert.
|
||||||
|
else:
|
||||||
|
# Fp8 moe kernels require a single activation scale.
|
||||||
|
# We take the max of all the scales in case they differ.
|
||||||
|
if self.quant_config.activation_scheme == "static":
|
||||||
|
if layer.a13_scale is None or layer.a2_scale is None:
|
||||||
|
raise ValueError(
|
||||||
|
"QuantConfig has static quantization, but found "
|
||||||
|
"activation scales are None."
|
||||||
|
)
|
||||||
|
if not all_close_1d(layer.a13_scale) or not all_close_1d(
|
||||||
|
layer.a2_scale
|
||||||
|
):
|
||||||
|
print_warning_once(
|
||||||
|
"Found input_scales that are not equal for "
|
||||||
|
"fp8 MoE layer. Using the maximum across experts "
|
||||||
|
"for each layer. "
|
||||||
|
)
|
||||||
|
layer.a13_scale = torch.nn.Parameter(
|
||||||
|
layer.a13_scale.max(), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.a2_scale = torch.nn.Parameter(
|
||||||
|
layer.a2_scale.max(), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||||
|
# We take the max then dequant and requant each expert.
|
||||||
|
assert layer.w13_scale is not None
|
||||||
|
shard_size = layer.intermediate_size_per_partition
|
||||||
|
max_w13_scales = layer.w13_scale.max(dim=1).values
|
||||||
|
for expert_id in range(layer.num_experts):
|
||||||
|
start = 0
|
||||||
|
for shard_id in range(2):
|
||||||
|
dq_weight = per_tensor_dequantize(
|
||||||
|
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||||
|
layer.w13_scale[expert_id][shard_id],
|
||||||
|
)
|
||||||
|
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
|
||||||
|
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||||
|
)
|
||||||
|
start += shard_size
|
||||||
|
|
||||||
|
layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool = True,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
from sglang.srt.layers.fused_moe.fused_moe import fused_moe
|
||||||
|
|
||||||
|
return fused_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
router_logits,
|
||||||
|
top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
inplace=True,
|
||||||
|
use_fp8=True,
|
||||||
|
w1_scale=layer.w13_scale,
|
||||||
|
w2_scale=layer.w2_scale,
|
||||||
|
a1_scale=layer.a13_scale,
|
||||||
|
a2_scale=layer.a2_scale,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
)
|
||||||
@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
last_logits = last_logits[:, : self.config.vocab_size].float()
|
last_logits = last_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
if hasattr(self.config, "final_logit_softcapping"):
|
if hasattr(self.config, "final_logit_softcapping"):
|
||||||
last_logits /= self.config.final_logit_softcapping
|
last_logits.div_(self.config.final_logit_softcapping)
|
||||||
last_logits = torch.tanh(last_logits)
|
last_logits = torch.tanh(last_logits)
|
||||||
last_logits *= self.config.final_logit_softcapping
|
last_logits.mul_(self.config.final_logit_softcapping)
|
||||||
|
|
||||||
# Return only last_logits if logprob is not requested
|
# Return only last_logits if logprob is not requested
|
||||||
if not logits_metadata.return_logprob:
|
if not logits_metadata.return_logprob:
|
||||||
@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
if hasattr(self.config, "final_logit_softcapping"):
|
if hasattr(self.config, "final_logit_softcapping"):
|
||||||
all_logits /= self.config.final_logit_softcapping
|
all_logits.div_(self.config.final_logit_softcapping)
|
||||||
all_logits = torch.tanh(all_logits)
|
all_logits = torch.tanh(all_logits)
|
||||||
all_logits *= self.config.final_logit_softcapping
|
all_logits.mul_(self.config.final_logit_softcapping)
|
||||||
|
|
||||||
all_logprobs = all_logits
|
all_logprobs = all_logits
|
||||||
del all_logits, hidden_states
|
del all_logits, hidden_states
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
is_generation_model,
|
is_generation_model,
|
||||||
is_llama3_405b_fp8,
|
is_llama3_405b_fp8_head_16,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
monkey_patch_vllm_dummy_weight_loader,
|
monkey_patch_vllm_dummy_weight_loader,
|
||||||
monkey_patch_vllm_p2p_access_check,
|
monkey_patch_vllm_p2p_access_check,
|
||||||
@@ -158,7 +158,7 @@ class ModelRunner:
|
|||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
|
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
|
||||||
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
|
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
|
||||||
self.model_config.hf_config.num_key_value_heads = 8
|
self.model_config.hf_config.num_key_value_heads = 8
|
||||||
vllm_model_config.hf_config.num_key_value_heads = 8
|
vllm_model_config.hf_config.num_key_value_heads = 8
|
||||||
|
|||||||
@@ -16,20 +16,17 @@ limitations under the License.
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||||
"""Inference-only Grok1 model."""
|
"""Inference-only Grok1 model."""
|
||||||
|
import warnings
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import tqdm
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -37,7 +34,6 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
@@ -45,141 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.utils import print_warning_once
|
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe import fused_moe
|
from sglang.srt.layers.fused_moe import FusedMoE
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
use_fused = True
|
|
||||||
|
|
||||||
|
|
||||||
class Grok1MLP(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_experts: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.num_experts = num_experts
|
|
||||||
self.ffn_dim = intermediate_size
|
|
||||||
self.hidden_dim = hidden_size
|
|
||||||
|
|
||||||
self.w1 = ReplicatedLinear(
|
|
||||||
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
|
||||||
)
|
|
||||||
self.w2 = ReplicatedLinear(
|
|
||||||
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
|
|
||||||
)
|
|
||||||
self.w3 = ReplicatedLinear(
|
|
||||||
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self.act_fn = nn.GELU()
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
w1_out, _ = self.w1(hidden_states)
|
|
||||||
w1_out = self.act_fn(w1_out)
|
|
||||||
w3_out, _ = self.w3(hidden_states)
|
|
||||||
current_hidden_states = w1_out * w3_out
|
|
||||||
current_hidden_states, _ = self.w2(current_hidden_states)
|
|
||||||
return current_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class Grok1MoEUnfused(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: PretrainedConfig,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.rank = get_tensor_model_parallel_rank()
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.num_total_experts = config.num_local_experts
|
|
||||||
self.top_k = config.num_experts_per_tok
|
|
||||||
if self.tp_size > self.num_total_experts:
|
|
||||||
raise ValueError(
|
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
|
||||||
f"the number of experts {self.num_total_experts}."
|
|
||||||
)
|
|
||||||
# Split experts equally between ranks
|
|
||||||
self.expert_indicies = np.array_split(
|
|
||||||
range(self.num_total_experts), self.tp_size
|
|
||||||
)[self.rank].tolist()
|
|
||||||
if not self.expert_indicies:
|
|
||||||
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
|
|
||||||
|
|
||||||
self.experts = nn.ModuleList(
|
|
||||||
[
|
|
||||||
(
|
|
||||||
Grok1MLP(
|
|
||||||
self.num_total_experts,
|
|
||||||
config.hidden_size,
|
|
||||||
config.intermediate_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
)
|
|
||||||
if idx in self.expert_indicies
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
for idx in range(self.num_total_experts)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.gate = ReplicatedLinear(
|
|
||||||
config.hidden_size, self.num_total_experts, bias=False, quant_config=None
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
router_logits, _ = self.gate(hidden_states)
|
|
||||||
router_logits = 30 * F.tanh(router_logits / 30)
|
|
||||||
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
|
||||||
routing_weights, selected_experts = torch.topk(
|
|
||||||
routing_weights, self.top_k, dim=-1
|
|
||||||
)
|
|
||||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
|
||||||
hidden_dim = hidden_states.shape[1]
|
|
||||||
|
|
||||||
final_hidden_states = torch.zeros(
|
|
||||||
(hidden_states.shape[0], hidden_dim),
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device=hidden_states.device,
|
|
||||||
)
|
|
||||||
expert_mask = torch.nn.functional.one_hot(
|
|
||||||
selected_experts, num_classes=self.num_total_experts
|
|
||||||
).permute(2, 1, 0)
|
|
||||||
|
|
||||||
for expert_idx in self.expert_indicies:
|
|
||||||
expert_layer = self.experts[expert_idx]
|
|
||||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
|
||||||
|
|
||||||
if top_x.shape[0] == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# in torch it is faster to index using lists than torch tensors
|
|
||||||
top_x_list = top_x.tolist()
|
|
||||||
idx_list = idx.tolist()
|
|
||||||
|
|
||||||
# Index the correct hidden states and compute the expert hidden state for
|
|
||||||
# the current expert. We need to make sure to multiply the output hidden
|
|
||||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
|
||||||
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
|
||||||
current_hidden_states = (
|
|
||||||
expert_layer(current_state)
|
|
||||||
* routing_weights[top_x_list, idx_list, None]
|
|
||||||
)
|
|
||||||
|
|
||||||
# However `index_add_` only support torch tensors for indexing so we'll use
|
|
||||||
# the `top_x` tensor here.
|
|
||||||
final_hidden_states.index_add_(0, top_x, current_hidden_states)
|
|
||||||
|
|
||||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
|
||||||
|
|
||||||
|
|
||||||
class Grok1MoE(nn.Module):
|
class Grok1MoE(nn.Module):
|
||||||
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
|
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
|
||||||
@@ -197,221 +65,42 @@ class Grok1MoE(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
tp_size: Optional[int] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
|
||||||
self.num_total_experts = num_experts
|
|
||||||
self.top_k = top_k
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size // self.tp_size
|
|
||||||
self.quant_config = quant_config
|
|
||||||
|
|
||||||
# FIXME(pcmoritz): Make this more general to support different
|
|
||||||
# quantization schemes
|
|
||||||
self.use_fp8 = isinstance(quant_config, Fp8Config)
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
self.params_dtype = params_dtype
|
|
||||||
|
|
||||||
# Gate always runs at half / full precision for now.
|
# Gate always runs at half / full precision for now.
|
||||||
self.gate = ReplicatedLinear(
|
self.gate = ReplicatedLinear(
|
||||||
self.hidden_size,
|
hidden_size,
|
||||||
self.num_total_experts,
|
num_experts,
|
||||||
bias=False,
|
bias=False,
|
||||||
params_dtype=self.params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=None,
|
quant_config=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
self.experts = FusedMoE(
|
||||||
params_dtype = torch.float8_e4m3fn
|
num_experts=num_experts,
|
||||||
|
top_k=top_k,
|
||||||
self.w13_weight = nn.Parameter(
|
hidden_size=hidden_size,
|
||||||
torch.empty(
|
intermediate_size=intermediate_size,
|
||||||
self.num_total_experts,
|
params_dtype=params_dtype,
|
||||||
2 * self.intermediate_size,
|
reduce_results=True,
|
||||||
self.hidden_size,
|
renormalize=False,
|
||||||
dtype=params_dtype,
|
quant_config=quant_config,
|
||||||
|
tp_size=tp_size,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
self.w2_weight = nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
self.num_total_experts,
|
|
||||||
self.hidden_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w13_weight,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w2_weight,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Used for fp8.
|
|
||||||
self.w13_scale = None
|
|
||||||
self.w2_scale = None
|
|
||||||
self.a13_scale = None
|
|
||||||
self.a2_scale = None
|
|
||||||
|
|
||||||
if self.use_fp8:
|
|
||||||
# WEIGHT_SCALE (for fp8)
|
|
||||||
self.w13_scale = nn.Parameter(
|
|
||||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
self.w2_scale = nn.Parameter(
|
|
||||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
|
||||||
# process_weights_after_loading()
|
|
||||||
if quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w13_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w2_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# ACT_SCALE (for fp8)
|
|
||||||
if quant_config.activation_scheme == "static":
|
|
||||||
if not quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
raise ValueError(
|
|
||||||
"Found static activation scheme for checkpoint that "
|
|
||||||
"was not serialized fp8."
|
|
||||||
)
|
|
||||||
self.a13_scale = nn.Parameter(
|
|
||||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
self.a2_scale = nn.Parameter(
|
|
||||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
set_weight_attrs(
|
|
||||||
self.a13_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
self.a2_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def weight_loader(
|
|
||||||
self,
|
|
||||||
param: nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
weight_name: str,
|
|
||||||
expert_id: int,
|
|
||||||
pre_sharded: bool,
|
|
||||||
):
|
|
||||||
param_data = param.data
|
|
||||||
shard_size = self.intermediate_size
|
|
||||||
if pre_sharded:
|
|
||||||
# The weight is already sharded. Readl the full shard
|
|
||||||
shard = slice(None)
|
|
||||||
else:
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
|
||||||
if weight_name.endswith("w1.weight"):
|
|
||||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
|
||||||
if weight_name.endswith("w3.weight"):
|
|
||||||
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
|
|
||||||
shard, :
|
|
||||||
]
|
|
||||||
if weight_name.endswith("w2.weight"):
|
|
||||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
|
||||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
|
|
||||||
def process_weights_after_loading(self):
|
|
||||||
# Fp8 is the only case where we need to process after loading.
|
|
||||||
if not self.use_fp8:
|
|
||||||
return
|
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize here.
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
w13_weight = torch.empty_like(
|
|
||||||
self.w13_weight.data, dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
|
|
||||||
for expert in range(self.num_total_experts):
|
|
||||||
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
|
|
||||||
self.w13_weight.data[expert, :, :]
|
|
||||||
)
|
|
||||||
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
|
|
||||||
self.w2_weight.data[expert, :, :]
|
|
||||||
)
|
|
||||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
|
||||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
|
||||||
|
|
||||||
# If checkpoint is fp8 + static, cleanup act_scales.
|
|
||||||
# Since state_dict has an act_scale per expert but our kernels
|
|
||||||
# are passed one act_scale shared across all experts.
|
|
||||||
elif self.quant_config.activation_scheme == "static":
|
|
||||||
if self.a13_scale is None or self.a2_scale is None:
|
|
||||||
raise ValueError(
|
|
||||||
"QuantConfig has static quantization, but found "
|
|
||||||
"activation scales are None."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
|
|
||||||
print_warning_once(
|
|
||||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
|
||||||
"Using the maximum across experts for each layer. "
|
|
||||||
)
|
|
||||||
|
|
||||||
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
|
|
||||||
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_size = hidden_states.shape
|
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = fused_moe(
|
router_logits = 30.0 * F.tanh(router_logits / 30.0)
|
||||||
hidden_states,
|
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||||
self.w13_weight,
|
return final_hidden_states.view(orig_shape)
|
||||||
self.w2_weight,
|
|
||||||
router_logits,
|
|
||||||
self.top_k,
|
|
||||||
renormalize=False,
|
|
||||||
inplace=True,
|
|
||||||
use_fp8=self.use_fp8,
|
|
||||||
w1_scale=self.w13_scale,
|
|
||||||
w2_scale=self.w2_scale,
|
|
||||||
a1_scale=self.a13_scale,
|
|
||||||
a2_scale=self.a2_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.tp_size > 1:
|
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_size)
|
|
||||||
|
|
||||||
|
|
||||||
class Grok1Attention(nn.Module):
|
class Grok1Attention(nn.Module):
|
||||||
@@ -478,6 +167,7 @@ class Grok1Attention(nn.Module):
|
|||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
)
|
)
|
||||||
|
# TODO(lianmin): load logit cap from config
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
# Requires transformers > 4.32.0
|
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
self.self_attn = Grok1Attention(
|
self.self_attn = Grok1Attention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@@ -513,7 +203,6 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if use_fused:
|
|
||||||
self.block_sparse_moe = Grok1MoE(
|
self.block_sparse_moe = Grok1MoE(
|
||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
@@ -521,10 +210,6 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self.block_sparse_moe = Grok1MoEUnfused(
|
|
||||||
config=config, quant_config=quant_config
|
|
||||||
)
|
|
||||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
hidden_states = (
|
hidden_states = (
|
||||||
self.post_attn_norm(
|
self.post_attn_norm(
|
||||||
self.self_attn(
|
self.self_attn(
|
||||||
@@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
+ hidden_states
|
+ hidden_states
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
hidden_states = (
|
hidden_states = (
|
||||||
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
|
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
|
||||||
+ hidden_states
|
+ hidden_states
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -593,7 +279,6 @@ class Grok1Model(nn.Module):
|
|||||||
|
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
hidden_states.mul_(self.config.output_multiplier_scale)
|
hidden_states.mul_(self.config.output_multiplier_scale)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
]
|
||||||
|
|
||||||
if use_fused:
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
expert_params_mapping = (
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
[
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
# These are the weight scales for the experts
|
ckpt_gate_proj_name="w1",
|
||||||
# (param_name, weight_name, expert_id)
|
ckpt_down_proj_name="w2",
|
||||||
(
|
ckpt_up_proj_name="w3",
|
||||||
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
num_experts=self.config.num_local_experts,
|
||||||
f"experts.{expert_id}.{weight_name}.weight_scale",
|
|
||||||
expert_id,
|
|
||||||
)
|
)
|
||||||
for expert_id in range(self.config.num_local_experts)
|
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
|
||||||
]
|
|
||||||
+ [
|
|
||||||
# These are the weights for the experts
|
|
||||||
# (param_name, weight_name, expert_id)
|
|
||||||
(
|
|
||||||
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
|
||||||
f"experts.{expert_id}.{weight_name}.weight",
|
|
||||||
expert_id,
|
|
||||||
)
|
|
||||||
for expert_id in range(self.config.num_local_experts)
|
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
|
||||||
]
|
|
||||||
+ [
|
|
||||||
# These are the activation scales for the experts
|
|
||||||
# (param_name, weight_name, expert_id)
|
|
||||||
(
|
|
||||||
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
|
||||||
f"experts.{expert_id}.{weight_name}.act_scale",
|
|
||||||
expert_id,
|
|
||||||
)
|
|
||||||
for expert_id in range(self.config.num_local_experts)
|
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
expert_params_mapping = []
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
if get_tensor_model_parallel_rank() == 0:
|
|
||||||
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
# print(get_tensor_model_parallel_rank(), name)
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for param_name, weight_name, expert_id in expert_params_mapping:
|
for mapping in expert_params_mapping:
|
||||||
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(
|
weight_loader(
|
||||||
param,
|
param,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
weight_name,
|
weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
expert_id=expert_id,
|
expert_id=expert_id,
|
||||||
pre_sharded=get_tensor_model_parallel_world_size() > 1,
|
pre_sharded=get_tensor_model_parallel_world_size() > 1,
|
||||||
)
|
)
|
||||||
@@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(
|
weight_loader = getattr(
|
||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
@@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
def all_close_1d(x: torch.Tensor) -> bool:
|
|
||||||
assert len(x.shape) == 1
|
|
||||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
|
||||||
|
|
||||||
|
|
||||||
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
|
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE,
|
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from triton.runtime.cache import (
|
from triton.runtime.cache import (
|
||||||
FileCacheManager,
|
FileCacheManager,
|
||||||
@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
|
|||||||
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
|
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
|
||||||
|
|
||||||
|
|
||||||
def is_llama3_405b_fp8(model_config):
|
def is_llama3_405b_fp8_head_16(model_config):
|
||||||
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
|
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
|
||||||
if (
|
if (
|
||||||
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
|
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
|
||||||
|
|||||||
Reference in New Issue
Block a user