Update the mixtral to use the better FusedMoE layer (#1081)
This commit is contained in:
@@ -5,7 +5,7 @@ To support a new model in SGLang, you only need to add a single file under [SGLa
|
|||||||
Another valuable resource is the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). vLLM has extensive coverage of models, and SGLang has reused vLLM for most parts of the model implementations. This similarity makes it easy to port many models from vLLM to SGLang.
|
Another valuable resource is the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). vLLM has extensive coverage of models, and SGLang has reused vLLM for most parts of the model implementations. This similarity makes it easy to port many models from vLLM to SGLang.
|
||||||
|
|
||||||
To port a model from vLLM to SGLang, you can compare these two files [SGLang LLaMA Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py) and [vLLM LLaMA Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of PagedAttention with RadixAttention. The other parts are almost identical. Specifically,
|
To port a model from vLLM to SGLang, you can compare these two files [SGLang LLaMA Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py) and [vLLM LLaMA Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of PagedAttention with RadixAttention. The other parts are almost identical. Specifically,
|
||||||
- Replace vllm's `Attention` with `RadixAttention`.
|
- Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
|
||||||
- Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
|
- Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
|
||||||
- Remove `Sample`.
|
- Remove `Sample`.
|
||||||
- Change `forward()` functions, and add `input_metadata`.
|
- Change `forward()` functions, and add `input_metadata`.
|
||||||
|
|||||||
@@ -18,34 +18,25 @@ limitations under the License.
|
|||||||
"""Inference-only Mixtral model."""
|
"""Inference-only Mixtral model."""
|
||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
get_tensor_model_parallel_rank,
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.utils import print_warning_once
|
|
||||||
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
@@ -69,216 +60,44 @@ class MixtralMoE(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
tp_size: Optional[int] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
|
||||||
self.num_total_experts = num_experts
|
|
||||||
self.top_k = top_k
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size // self.tp_size
|
|
||||||
self.quant_config = quant_config
|
|
||||||
|
|
||||||
# FIXME(pcmoritz): Make this more general to support different
|
|
||||||
# quantization schemes
|
|
||||||
self.use_fp8 = isinstance(quant_config, Fp8Config)
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
self.params_dtype = params_dtype
|
|
||||||
|
|
||||||
# Gate always runs at half / full precision for now.
|
# Gate always runs at half / full precision for now.
|
||||||
self.gate = ReplicatedLinear(
|
self.gate = ReplicatedLinear(
|
||||||
self.hidden_size,
|
hidden_size,
|
||||||
self.num_total_experts,
|
num_experts,
|
||||||
bias=False,
|
bias=False,
|
||||||
params_dtype=self.params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=None,
|
quant_config=None,
|
||||||
|
prefix=f"{prefix}.gate",
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
self.experts = FusedMoE(
|
||||||
params_dtype = torch.float8_e4m3fn
|
num_experts=num_experts,
|
||||||
|
top_k=top_k,
|
||||||
self.w13_weight = nn.Parameter(
|
hidden_size=hidden_size,
|
||||||
torch.empty(
|
intermediate_size=intermediate_size,
|
||||||
self.num_total_experts,
|
params_dtype=params_dtype,
|
||||||
2 * self.intermediate_size,
|
reduce_results=True,
|
||||||
self.hidden_size,
|
renormalize=True,
|
||||||
dtype=params_dtype,
|
quant_config=quant_config,
|
||||||
)
|
tp_size=tp_size,
|
||||||
|
prefix=f"{prefix}.experts",
|
||||||
)
|
)
|
||||||
self.w2_weight = nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
self.num_total_experts,
|
|
||||||
self.hidden_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w13_weight,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w2_weight,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Used for fp8.
|
|
||||||
self.w13_scale = None
|
|
||||||
self.w2_scale = None
|
|
||||||
self.a13_scale = None
|
|
||||||
self.a2_scale = None
|
|
||||||
|
|
||||||
if self.use_fp8:
|
|
||||||
# WEIGHT_SCALE (for fp8)
|
|
||||||
self.w13_scale = nn.Parameter(
|
|
||||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
self.w2_scale = nn.Parameter(
|
|
||||||
torch.ones(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
|
||||||
# process_weights_after_loading()
|
|
||||||
if quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w13_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
self.w2_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# ACT_SCALE (for fp8)
|
|
||||||
if quant_config.activation_scheme == "static":
|
|
||||||
if not quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
raise ValueError(
|
|
||||||
"Found static activation scheme for checkpoint that "
|
|
||||||
"was not serialized fp8."
|
|
||||||
)
|
|
||||||
self.a13_scale = nn.Parameter(
|
|
||||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
self.a2_scale = nn.Parameter(
|
|
||||||
torch.zeros(self.num_total_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
set_weight_attrs(
|
|
||||||
self.a13_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
self.a2_scale,
|
|
||||||
{
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def weight_loader(
|
|
||||||
self,
|
|
||||||
param: nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
weight_name: str,
|
|
||||||
expert_id: int,
|
|
||||||
):
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
param_data = param.data
|
|
||||||
shard_size = self.intermediate_size
|
|
||||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
|
||||||
if weight_name.endswith("w1.weight"):
|
|
||||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
|
||||||
if weight_name.endswith("w3.weight"):
|
|
||||||
param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
|
|
||||||
shard, :
|
|
||||||
]
|
|
||||||
if weight_name.endswith("w2.weight"):
|
|
||||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
|
||||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
|
|
||||||
def process_weights_after_loading(self):
|
|
||||||
# Fp8 is the only case where we need to process after loading.
|
|
||||||
if not self.use_fp8:
|
|
||||||
return
|
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize here.
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
w13_weight = torch.empty_like(
|
|
||||||
self.w13_weight.data, dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
|
|
||||||
for expert in range(self.num_total_experts):
|
|
||||||
w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
|
|
||||||
self.w13_weight.data[expert, :, :]
|
|
||||||
)
|
|
||||||
w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
|
|
||||||
self.w2_weight.data[expert, :, :]
|
|
||||||
)
|
|
||||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
|
||||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
|
||||||
|
|
||||||
# If checkpoint is fp8 + static, cleanup act_scales.
|
|
||||||
# Since state_dict has an act_scale per expert but our kernels
|
|
||||||
# are passed one act_scale shared across all experts.
|
|
||||||
elif self.quant_config.activation_scheme == "static":
|
|
||||||
if self.a13_scale is None or self.a2_scale is None:
|
|
||||||
raise ValueError(
|
|
||||||
"QuantConfig has static quantization, but found "
|
|
||||||
"activation scales are None."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
|
|
||||||
print_warning_once(
|
|
||||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
|
||||||
"Using the maximum across experts for each layer. "
|
|
||||||
)
|
|
||||||
|
|
||||||
self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
|
|
||||||
self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_size = hidden_states.shape
|
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = fused_moe(
|
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||||
hidden_states,
|
return final_hidden_states.view(orig_shape)
|
||||||
self.w13_weight,
|
|
||||||
self.w2_weight,
|
|
||||||
router_logits,
|
|
||||||
self.top_k,
|
|
||||||
renormalize=True,
|
|
||||||
inplace=True,
|
|
||||||
use_fp8=self.use_fp8,
|
|
||||||
w1_scale=self.w13_scale,
|
|
||||||
w2_scale=self.w2_scale,
|
|
||||||
a1_scale=self.a13_scale,
|
|
||||||
a2_scale=self.a2_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.tp_size > 1:
|
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_size)
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralAttention(nn.Module):
|
class MixtralAttention(nn.Module):
|
||||||
@@ -291,7 +110,7 @@ class MixtralAttention(nn.Module):
|
|||||||
max_position: int = 4096 * 32,
|
max_position: int = 4096 * 32,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
sliding_window: Optional[int] = None,
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -314,7 +133,6 @@ class MixtralAttention(nn.Module):
|
|||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
self.sliding_window = sliding_window
|
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -323,12 +141,14 @@ class MixtralAttention(nn.Module):
|
|||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -365,6 +185,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -377,8 +198,8 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
sliding_window=config.sliding_window,
|
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
self.block_sparse_moe = MixtralMoE(
|
self.block_sparse_moe = MixtralMoE(
|
||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
@@ -386,6 +207,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.block_sparse_moe",
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(
|
self.post_attention_layernorm = RMSNorm(
|
||||||
@@ -422,6 +244,7 @@ class MixtralModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
@@ -431,10 +254,11 @@ class MixtralModel(nn.Module):
|
|||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
# config.num_hidden_layers=16
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
MixtralDecoderLayer(config, i, quant_config=quant_config)
|
MixtralDecoderLayer(
|
||||||
|
config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
|
||||||
|
)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -462,6 +286,7 @@ class MixtralModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MixtralForCausalLM(nn.Module):
|
class MixtralForCausalLM(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MixtralConfig,
|
config: MixtralConfig,
|
||||||
@@ -471,11 +296,10 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = MixtralModel(config, quant_config=quant_config)
|
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@@ -496,40 +320,13 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
]
|
||||||
|
|
||||||
expert_params_mapping = (
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
[
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
# These are the weight scales for the experts
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
# (param_name, weight_name, expert_id)
|
ckpt_gate_proj_name="w1",
|
||||||
(
|
ckpt_down_proj_name="w2",
|
||||||
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
ckpt_up_proj_name="w3",
|
||||||
f"experts.{expert_id}.{weight_name}.weight_scale",
|
num_experts=self.config.num_local_experts,
|
||||||
expert_id,
|
|
||||||
)
|
|
||||||
for expert_id in range(self.config.num_local_experts)
|
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
|
||||||
]
|
|
||||||
+ [
|
|
||||||
# These are the weights for the experts
|
|
||||||
# (param_name, weight_name, expert_id)
|
|
||||||
(
|
|
||||||
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
|
||||||
f"experts.{expert_id}.{weight_name}.weight",
|
|
||||||
expert_id,
|
|
||||||
)
|
|
||||||
for expert_id in range(self.config.num_local_experts)
|
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
|
||||||
]
|
|
||||||
+ [
|
|
||||||
# These are the activation scales for the experts
|
|
||||||
# (param_name, weight_name, expert_id)
|
|
||||||
(
|
|
||||||
"a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
|
||||||
f"experts.{expert_id}.{weight_name}.act_scale",
|
|
||||||
expert_id,
|
|
||||||
)
|
|
||||||
for expert_id in range(self.config.num_local_experts)
|
|
||||||
for weight_name in ["w1", "w2", "w3"]
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
@@ -544,25 +341,35 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for param_name, weight_name, expert_id in expert_params_mapping:
|
for mapping in expert_params_mapping:
|
||||||
|
param_name, weight_name, expert_id, shard_id = mapping
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(
|
weight_loader(
|
||||||
param, loaded_weight, weight_name, expert_id=expert_id
|
param,
|
||||||
|
loaded_weight,
|
||||||
|
weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=expert_id,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(
|
weight_loader = getattr(
|
||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
@@ -570,9 +377,4 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
def all_close_1d(x: torch.Tensor) -> bool:
|
|
||||||
assert len(x.shape) == 1
|
|
||||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = MixtralForCausalLM
|
EntryClass = MixtralForCausalLM
|
||||||
|
|||||||
@@ -160,7 +160,6 @@ class MixtralAttention(nn.Module):
|
|||||||
max_position: int = 4096 * 32,
|
max_position: int = 4096 * 32,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
sliding_window: Optional[int] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -183,7 +182,6 @@ class MixtralAttention(nn.Module):
|
|||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
self.sliding_window = sliding_window
|
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -246,7 +244,6 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
sliding_window=config.sliding_window,
|
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
|
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class TestServingThroughput(unittest.TestCase):
|
|||||||
|
|
||||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||||
# A100 (PCIE) performance
|
# A100 (PCIE) performance
|
||||||
assert res["output_throughput"] > 950
|
assert res["output_throughput"] > 940
|
||||||
|
|
||||||
def test_default_with_chunked_prefill(self):
|
def test_default_with_chunked_prefill(self):
|
||||||
res = self.run_test(
|
res = self.run_test(
|
||||||
|
|||||||
Reference in New Issue
Block a user