port fp8 mixtral (#460)
This commit is contained in:
@@ -69,20 +69,13 @@ class ModelRpcServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# For model end global settings
|
# For model end global settings
|
||||||
server_args_dict = {
|
|
||||||
"enable_flashinfer": server_args.enable_flashinfer,
|
|
||||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
mem_fraction_static=server_args.mem_fraction_static,
|
mem_fraction_static=server_args.mem_fraction_static,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
tp_size=server_args.tp_size,
|
tp_size=server_args.tp_size,
|
||||||
nccl_port=port_args.nccl_port,
|
nccl_port=port_args.nccl_port,
|
||||||
load_format=server_args.load_format,
|
server_args=server_args,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
server_args_dict=server_args_dict,
|
|
||||||
)
|
)
|
||||||
if is_multimodal_model(server_args.model_path):
|
if is_multimodal_model(server_args.model_path):
|
||||||
self.processor = get_processor(
|
self.processor = get_processor(
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from vllm.model_executor.models import ModelRegistry
|
|||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
|
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
|
||||||
|
|
||||||
|
|
||||||
@@ -218,22 +219,23 @@ class ModelRunner:
|
|||||||
tp_rank,
|
tp_rank,
|
||||||
tp_size,
|
tp_size,
|
||||||
nccl_port,
|
nccl_port,
|
||||||
load_format="auto",
|
server_args: ServerArgs,
|
||||||
trust_remote_code=True,
|
|
||||||
server_args_dict: dict = {},
|
|
||||||
):
|
):
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.mem_fraction_static = mem_fraction_static
|
self.mem_fraction_static = mem_fraction_static
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.nccl_port = nccl_port
|
self.nccl_port = nccl_port
|
||||||
self.load_format = load_format
|
self.server_args = server_args
|
||||||
self.trust_remote_code = trust_remote_code
|
|
||||||
|
|
||||||
global global_server_args_dict
|
global global_server_args_dict
|
||||||
global_server_args_dict = server_args_dict
|
global_server_args_dict = {
|
||||||
|
"enable_flashinfer": server_args.enable_flashinfer,
|
||||||
|
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||||
|
}
|
||||||
|
|
||||||
# Init torch distributed
|
# Init torch distributed
|
||||||
|
logger.debug("Init torch begin.")
|
||||||
torch.cuda.set_device(self.tp_rank)
|
torch.cuda.set_device(self.tp_rank)
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
backend="nccl",
|
backend="nccl",
|
||||||
@@ -241,13 +243,15 @@ class ModelRunner:
|
|||||||
rank=self.tp_rank,
|
rank=self.tp_rank,
|
||||||
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
||||||
)
|
)
|
||||||
|
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
|
logger.debug("Init torch end.")
|
||||||
|
|
||||||
total_gpu_memory = get_available_gpu_memory(
|
total_gpu_memory = get_available_gpu_memory(
|
||||||
self.tp_rank, distributed=self.tp_size > 1
|
self.tp_rank, distributed=self.tp_size > 1
|
||||||
) * (1 << 30)
|
) * (1 << 30)
|
||||||
|
# logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
# logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
|
||||||
self.init_memory_pool(total_gpu_memory)
|
self.init_memory_pool(total_gpu_memory)
|
||||||
|
|
||||||
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
||||||
@@ -256,15 +260,15 @@ class ModelRunner:
|
|||||||
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
||||||
|
|
||||||
device_config = DeviceConfig()
|
device_config = DeviceConfig()
|
||||||
load_config = LoadConfig()
|
load_config = LoadConfig(load_format=self.server_args.load_format)
|
||||||
vllm_model_config = VllmModelConfig(
|
vllm_model_config = VllmModelConfig(
|
||||||
model=self.model_config.path,
|
model=self.server_args.model_path,
|
||||||
|
quantization=self.server_args.quantization,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
tokenizer_mode=None,
|
tokenizer_mode=None,
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
trust_remote_code=self.server_args.trust_remote_code,
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
seed=42,
|
seed=42,
|
||||||
revision=self.model_config.revision,
|
|
||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
if self.model_config.model_overide_args is not None:
|
if self.model_config.model_overide_args is not None:
|
||||||
@@ -279,7 +283,7 @@ class ModelRunner:
|
|||||||
parallel_config=None,
|
parallel_config=None,
|
||||||
scheduler_config=None,
|
scheduler_config=None,
|
||||||
)
|
)
|
||||||
logger.info(f"Rank {self.tp_rank}: load weight end.")
|
logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory):
|
def profile_max_num_token(self, total_gpu_memory):
|
||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||||
"""Inference-only Mixtral model."""
|
"""Inference-only Mixtral model."""
|
||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
@@ -8,11 +8,13 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -20,12 +22,15 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
from vllm.utils import print_warning_once
|
||||||
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
@@ -33,106 +38,196 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class MixtralMLP(nn.Module):
|
|
||||||
|
class MixtralMoE(nn.Module):
|
||||||
|
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
|
||||||
|
across all ranks.
|
||||||
|
|
||||||
|
Each expert's weights are sharded across all ranks and a fused MoE
|
||||||
|
kernel is used for the forward pass, and finally we reduce the outputs
|
||||||
|
across ranks.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
|
top_k: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
) -> None:
|
tp_size: Optional[int] = None,
|
||||||
super().__init__()
|
|
||||||
self.num_experts = num_experts
|
|
||||||
self.ffn_dim = intermediate_size
|
|
||||||
self.hidden_dim = hidden_size
|
|
||||||
|
|
||||||
self.w1 = ReplicatedLinear(
|
|
||||||
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
|
||||||
)
|
|
||||||
self.w2 = ReplicatedLinear(
|
|
||||||
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
|
|
||||||
)
|
|
||||||
self.w3 = ReplicatedLinear(
|
|
||||||
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Use vllm's SiluAndMul
|
|
||||||
self.act_fn = nn.SiLU()
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
w1_out, _ = self.w1(hidden_states)
|
|
||||||
w1_out = self.act_fn(w1_out)
|
|
||||||
w3_out, _ = self.w3(hidden_states)
|
|
||||||
current_hidden_states = w1_out * w3_out
|
|
||||||
current_hidden_states, _ = self.w2(current_hidden_states)
|
|
||||||
return current_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralMoE(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: MixtralConfig,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||||
self.rank = get_tensor_model_parallel_rank()
|
self.num_total_experts = num_experts
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.top_k = top_k
|
||||||
self.num_total_experts = config.num_local_experts
|
self.hidden_size = hidden_size
|
||||||
self.top_k = config.num_experts_per_tok
|
self.intermediate_size = intermediate_size // self.tp_size
|
||||||
if self.tp_size > self.num_total_experts:
|
self.quant_config = quant_config
|
||||||
raise ValueError(
|
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
|
||||||
f"the number of experts {self.num_total_experts}."
|
|
||||||
)
|
|
||||||
# Split experts equally between ranks
|
|
||||||
self.expert_indicies = np.array_split(
|
|
||||||
range(self.num_total_experts), self.tp_size
|
|
||||||
)[self.rank].tolist()
|
|
||||||
if not self.expert_indicies:
|
|
||||||
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
|
|
||||||
|
|
||||||
self.experts = nn.ModuleList(
|
# FIXME(pcmoritz): Make this more general to support different
|
||||||
[
|
# quantization schemes
|
||||||
(
|
self.use_fp8 = isinstance(quant_config, Fp8Config)
|
||||||
MixtralMLP(
|
|
||||||
self.num_total_experts,
|
if params_dtype is None:
|
||||||
config.hidden_size,
|
params_dtype = torch.get_default_dtype()
|
||||||
config.intermediate_size,
|
self.params_dtype = params_dtype
|
||||||
quant_config=quant_config,
|
|
||||||
)
|
# Gate always runs at half / full precision for now.
|
||||||
if idx in self.expert_indicies
|
self.gate = ReplicatedLinear(self.hidden_size,
|
||||||
else None
|
self.num_total_experts,
|
||||||
)
|
bias=False,
|
||||||
for idx in range(self.num_total_experts)
|
params_dtype=self.params_dtype,
|
||||||
]
|
quant_config=None)
|
||||||
)
|
|
||||||
self.gate = ReplicatedLinear(
|
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
config.hidden_size, self.num_total_experts, bias=False, quant_config=None
|
params_dtype = torch.float8_e4m3fn
|
||||||
)
|
|
||||||
|
self.w13_weight = nn.Parameter(
|
||||||
|
torch.empty(self.num_total_experts,
|
||||||
|
2 * self.intermediate_size,
|
||||||
|
self.hidden_size,
|
||||||
|
dtype=params_dtype))
|
||||||
|
self.w2_weight = nn.Parameter(
|
||||||
|
torch.empty(self.num_total_experts,
|
||||||
|
self.hidden_size,
|
||||||
|
self.intermediate_size,
|
||||||
|
dtype=params_dtype))
|
||||||
|
|
||||||
|
set_weight_attrs(self.w13_weight, {
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
})
|
||||||
|
set_weight_attrs(self.w2_weight, {
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Used for fp8.
|
||||||
|
self.w13_scale = None
|
||||||
|
self.w2_scale = None
|
||||||
|
self.a13_scale = None
|
||||||
|
self.a2_scale = None
|
||||||
|
|
||||||
|
if self.use_fp8:
|
||||||
|
# WEIGHT_SCALE (for fp8)
|
||||||
|
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||||
|
dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||||
|
dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
# If loading fp8 checkpoint, pass the weight loaders.
|
||||||
|
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||||
|
# process_weights_after_loading()
|
||||||
|
if quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
set_weight_attrs(self.w13_scale, {
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
})
|
||||||
|
set_weight_attrs(self.w2_scale, {
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
})
|
||||||
|
|
||||||
|
# ACT_SCALE (for fp8)
|
||||||
|
if quant_config.activation_scheme == "static":
|
||||||
|
if not quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
raise ValueError(
|
||||||
|
"Found static activation scheme for checkpoint that "
|
||||||
|
"was not serialized fp8.")
|
||||||
|
self.a13_scale = nn.Parameter(torch.zeros(
|
||||||
|
self.num_total_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
self.a2_scale = nn.Parameter(torch.zeros(
|
||||||
|
self.num_total_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
set_weight_attrs(self.a13_scale, {
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
})
|
||||||
|
set_weight_attrs(self.a2_scale, {
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
})
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str, expert_id: int):
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
param_data = param.data
|
||||||
|
shard_size = self.intermediate_size
|
||||||
|
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||||
|
if weight_name.endswith("w1.weight"):
|
||||||
|
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||||
|
if weight_name.endswith("w3.weight"):
|
||||||
|
param_data[expert_id,
|
||||||
|
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||||
|
if weight_name.endswith("w2.weight"):
|
||||||
|
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||||
|
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
def process_weights_after_loading(self):
|
||||||
|
# Fp8 is the only case where we need to process after loading.
|
||||||
|
if not self.use_fp8:
|
||||||
|
return
|
||||||
|
|
||||||
|
# If checkpoint is fp16, quantize here.
|
||||||
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
w13_weight = torch.empty_like(self.w13_weight.data,
|
||||||
|
dtype=torch.float8_e4m3fn)
|
||||||
|
w2_weight = torch.empty_like(self.w2_weight.data,
|
||||||
|
dtype=torch.float8_e4m3fn)
|
||||||
|
for expert in range(self.num_total_experts):
|
||||||
|
w13_weight[expert, :, :], self.w13_scale[
|
||||||
|
expert] = ops.scaled_fp8_quant(
|
||||||
|
self.w13_weight.data[expert, :, :])
|
||||||
|
w2_weight[expert, :, :], self.w2_scale[
|
||||||
|
expert] = ops.scaled_fp8_quant(
|
||||||
|
self.w2_weight.data[expert, :, :])
|
||||||
|
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
||||||
|
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
|
# If checkpoint is fp8 + static, cleanup act_scales.
|
||||||
|
# Since state_dict has an act_scale per expert but our kernels
|
||||||
|
# are passed one act_scale shared across all experts.
|
||||||
|
elif self.quant_config.activation_scheme == "static":
|
||||||
|
if self.a13_scale is None or self.a2_scale is None:
|
||||||
|
raise ValueError(
|
||||||
|
"QuantConfig has static quantization, but found "
|
||||||
|
"activation scales are None.")
|
||||||
|
|
||||||
|
if (not all_close_1d(self.a13_scale)
|
||||||
|
or not all_close_1d(self.a2_scale)):
|
||||||
|
print_warning_once(
|
||||||
|
"Found act_scales that are not equal for fp8 MoE layer. "
|
||||||
|
"Using the maximum across experts for each layer. ")
|
||||||
|
|
||||||
|
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
||||||
|
requires_grad=False)
|
||||||
|
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens, hidden_size = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
final_hidden_states = fused_moe(hidden_states,
|
||||||
|
self.w13_weight,
|
||||||
|
self.w2_weight,
|
||||||
|
router_logits,
|
||||||
|
self.top_k,
|
||||||
|
renormalize=True,
|
||||||
|
inplace=True,
|
||||||
|
use_fp8=self.use_fp8,
|
||||||
|
w1_scale=self.w13_scale,
|
||||||
|
w2_scale=self.w2_scale,
|
||||||
|
a1_scale=self.a13_scale,
|
||||||
|
a2_scale=self.a2_scale)
|
||||||
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
if self.tp_size > 1:
|
||||||
routing_weights, selected_experts = torch.topk(
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||||
routing_weights, self.top_k, dim=-1
|
final_hidden_states)
|
||||||
)
|
|
||||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
final_hidden_states = None
|
return final_hidden_states.view(num_tokens, hidden_size)
|
||||||
for expert_idx in self.expert_indicies:
|
|
||||||
expert_layer = self.experts[expert_idx]
|
|
||||||
expert_mask = selected_experts == expert_idx
|
|
||||||
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
|
|
||||||
if final_hidden_states is None:
|
|
||||||
final_hidden_states = current_hidden_states
|
|
||||||
else:
|
|
||||||
final_hidden_states.add_(current_hidden_states)
|
|
||||||
|
|
||||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralAttention(nn.Module):
|
class MixtralAttention(nn.Module):
|
||||||
@@ -234,7 +329,12 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
sliding_window=config.sliding_window,
|
sliding_window=config.sliding_window,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
|
self.block_sparse_moe = MixtralMoE(
|
||||||
|
num_experts=config.num_local_experts,
|
||||||
|
top_k=config.num_experts_per_tok,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
quant_config=quant_config)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(
|
self.post_attention_layernorm = RMSNorm(
|
||||||
config.hidden_size, eps=config.rms_norm_eps
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
@@ -342,11 +442,35 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
expert_params_mapping = [
|
||||||
|
# These are the weight scales for the experts
|
||||||
|
# (param_name, weight_name, expert_id)
|
||||||
|
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||||
|
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
|
||||||
|
for expert_id in range(self.config.num_local_experts)
|
||||||
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
|
] + [
|
||||||
|
# These are the weights for the experts
|
||||||
|
# (param_name, weight_name, expert_id)
|
||||||
|
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||||
|
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
||||||
|
for expert_id in range(self.config.num_local_experts)
|
||||||
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
|
] + [
|
||||||
|
# These are the activation scales for the experts
|
||||||
|
# (param_name, weight_name, expert_id)
|
||||||
|
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
|
||||||
|
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
|
||||||
|
for expert_id in range(self.config.num_local_experts)
|
||||||
|
for weight_name in ["w1", "w2", "w3"]
|
||||||
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
@@ -358,15 +482,30 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
for param_name, weight_name, expert_id in expert_params_mapping:
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
# Skip experts that are not assigned to this worker.
|
name = name.replace(weight_name, param_name)
|
||||||
if "block_sparse_moe.experts." in name and name not in params_dict:
|
param = params_dict[name]
|
||||||
continue
|
weight_loader = param.weight_loader
|
||||||
param = params_dict[name]
|
weight_loader(param,
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
loaded_weight,
|
||||||
weight_loader(param, loaded_weight)
|
weight_name,
|
||||||
|
expert_id=expert_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def all_close_1d(x: torch.Tensor) -> bool:
|
||||||
|
assert len(x.shape) == 1
|
||||||
|
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
EntryClass = MixtralForCausalLM
|
EntryClass = MixtralForCausalLM
|
||||||
|
|||||||
371
python/sglang/srt/models/mixtral_quant.py
Normal file
371
python/sglang/srt/models/mixtral_quant.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
# Adapted from
|
||||||
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
|
||||||
|
"""Inference-only Mixtral model."""
|
||||||
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from transformers import MixtralConfig
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralMLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.ffn_dim = intermediate_size
|
||||||
|
self.hidden_dim = hidden_size
|
||||||
|
|
||||||
|
self.w1 = ReplicatedLinear(
|
||||||
|
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
||||||
|
)
|
||||||
|
self.w2 = ReplicatedLinear(
|
||||||
|
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
|
||||||
|
)
|
||||||
|
self.w3 = ReplicatedLinear(
|
||||||
|
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Use vllm's SiluAndMul
|
||||||
|
self.act_fn = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
w1_out, _ = self.w1(hidden_states)
|
||||||
|
w1_out = self.act_fn(w1_out)
|
||||||
|
w3_out, _ = self.w3(hidden_states)
|
||||||
|
current_hidden_states = w1_out * w3_out
|
||||||
|
current_hidden_states, _ = self.w2(current_hidden_states)
|
||||||
|
return current_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralMoE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MixtralConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.rank = get_tensor_model_parallel_rank()
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.num_total_experts = config.num_local_experts
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
if self.tp_size > self.num_total_experts:
|
||||||
|
raise ValueError(
|
||||||
|
f"Tensor parallel size {self.tp_size} is greater than "
|
||||||
|
f"the number of experts {self.num_total_experts}."
|
||||||
|
)
|
||||||
|
# Split experts equally between ranks
|
||||||
|
self.expert_indicies = np.array_split(
|
||||||
|
range(self.num_total_experts), self.tp_size
|
||||||
|
)[self.rank].tolist()
|
||||||
|
if not self.expert_indicies:
|
||||||
|
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
|
||||||
|
|
||||||
|
self.experts = nn.ModuleList(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
MixtralMLP(
|
||||||
|
self.num_total_experts,
|
||||||
|
config.hidden_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
if idx in self.expert_indicies
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
for idx in range(self.num_total_experts)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.gate = ReplicatedLinear(
|
||||||
|
config.hidden_size, self.num_total_experts, bias=False, quant_config=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|
||||||
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
routing_weights, selected_experts = torch.topk(
|
||||||
|
routing_weights, self.top_k, dim=-1
|
||||||
|
)
|
||||||
|
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
final_hidden_states = None
|
||||||
|
for expert_idx in self.expert_indicies:
|
||||||
|
expert_layer = self.experts[expert_idx]
|
||||||
|
expert_mask = selected_experts == expert_idx
|
||||||
|
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
|
||||||
|
if final_hidden_states is None:
|
||||||
|
final_hidden_states = current_hidden_states
|
||||||
|
else:
|
||||||
|
final_hidden_states.add_(current_hidden_states)
|
||||||
|
|
||||||
|
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
layer_id: int = 0,
|
||||||
|
max_position: int = 4096 * 32,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position,
|
||||||
|
base=int(self.rope_theta),
|
||||||
|
is_neox_style=True,
|
||||||
|
)
|
||||||
|
self.attn = RadixAttention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
attn_output = self.attn(q, k, v, input_metadata)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralDecoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MixtralConfig,
|
||||||
|
layer_id: int = 0,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
# Requires transformers > 4.32.0
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
self.self_attn = MixtralAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
max_position=config.max_position_embeddings,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
layer_id=layer_id,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
sliding_window=config.sliding_window,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(
|
||||||
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.block_sparse_moe(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MixtralConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MixtralDecoderLayer(config, i, quant_config=quant_config)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
input_embeds: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if input_embeds is None:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
else:
|
||||||
|
hidden_states = input_embeds
|
||||||
|
residual = None
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions, hidden_states, input_metadata, residual
|
||||||
|
)
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QuantMixtralForCausalLM(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MixtralConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.model = MixtralModel(config, quant_config=quant_config)
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
input_embeds: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
|
return self.logits_processor(
|
||||||
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Skip experts that are not assigned to this worker.
|
||||||
|
if "block_sparse_moe.experts." in name and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = QuantMixtralForCausalLM
|
||||||
@@ -15,6 +15,7 @@ class ServerArgs:
|
|||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = True
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
|
quantization: Optional[str] = None
|
||||||
|
|
||||||
# Port
|
# Port
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
@@ -135,6 +136,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.context_length,
|
default=ServerArgs.context_length,
|
||||||
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
|
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantization",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.quantization,
|
||||||
|
help="The quantization method.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
type=float,
|
type=float,
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
|
|||||||
"which may cause useless memory allocation for torch CUDA context.",
|
"which may cause useless memory allocation for torch CUDA context.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
||||||
|
|
||||||
if distributed:
|
if distributed:
|
||||||
|
|||||||
Reference in New Issue
Block a user