From 0fafc5606b0dc205518002dc2058e7b9a8d5019a Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 21 May 2024 11:46:35 -0700
Subject: [PATCH] port fp8 mixtral (#460)

---
 .../sglang/srt/managers/router/model_rpc.py   |   9 +-
 .../srt/managers/router/model_runner.py       |  28 +-
 python/sglang/srt/models/mixtral.py           | 335 +++++++++++-----
 python/sglang/srt/models/mixtral_quant.py     | 371 ++++++++++++++++++
 python/sglang/srt/server_args.py              |   7 +
 python/sglang/srt/utils.py                    |   1 +
 6 files changed, 633 insertions(+), 118 deletions(-)
 create mode 100644 python/sglang/srt/models/mixtral_quant.py

diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index d5a029f51..2873ef4c5 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -69,20 +69,13 @@ class ModelRpcServer:
         )

         # For model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            load_format=server_args.load_format,
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
+            server_args=server_args,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index a74b1d10f..cea08f5da 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -17,6 +17,7 @@ from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model


@@ -218,22 +219,23 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        load_format="auto",
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.load_format = load_format
-        self.trust_remote_code = trust_remote_code
+        self.server_args = server_args

         global global_server_args_dict
-        global_server_args_dict = server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }

         # Init torch distributed
+        logger.debug("Init torch distributed begin.")
         torch.cuda.set_device(self.tp_rank)
         torch.distributed.init_process_group(
             backend="nccl",
@@ -241,13 +243,16 @@ class ModelRunner:
             rank=self.tp_rank,
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        logger.debug("Init torch distributed end.")

         total_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
         ) * (1 << 30)
+        # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.load_model()
+        # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.init_memory_pool(total_gpu_memory)

         self.is_multimodal_model = is_multimodal_model(self.model_config)

@@ -256,15 +260,15 @@ class ModelRunner:
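(Illustration, not part of the patch.) With this change the router forwards the
whole ServerArgs object, so a new option such as --quantization reaches
ModelRunner without per-field plumbing, and the hunk below can route model_path,
load_format, quantization, and trust_remote_code straight from it. A minimal
sketch of the resulting construction, where model_config and port_args stand in
for objects the router already holds and the model path is a placeholder:

    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder
        quantization="fp8",
        tp_size=2,
    )
    runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        tp_rank=0,
        tp_size=server_args.tp_size,
        nccl_port=port_args.nccl_port,
        server_args=server_args,  # everything else rides along
    )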
        logger.info(f"Rank {self.tp_rank}: load weight begin.")

         device_config = DeviceConfig()
-        load_config = LoadConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
-            model=self.model_config.path,
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
             tokenizer=None,
             tokenizer_mode=None,
-            trust_remote_code=self.model_config.trust_remote_code,
+            trust_remote_code=self.server_args.trust_remote_code,
             dtype=torch.float16,
             seed=42,
-            revision=self.model_config.revision,
             skip_tokenizer_init=True,
         )
         if self.model_config.model_overide_args is not None:
@@ -279,7 +283,7 @@ class ModelRunner:
             parallel_config=None,
             scheduler_config=None,
         )
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")

     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index 94f0ed393..cfe4ab6f8 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -1,5 +1,5 @@
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
 from typing import Iterable, Optional, Tuple

@@ -8,11 +8,13 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
@@ -20,12 +22,15 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once

 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -33,106 +38,196 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata


-class MixtralMLP(nn.Module):
+
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks; a fused MoE kernel
+    handles the forward pass, and the outputs are then reduced across ranks.
+ """ + def __init__( self, num_experts: int, + top_k: int, hidden_size: int, intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - self.w2 = ReplicatedLinear( - self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config - ) - self.w3 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralMoE(nn.Module): - def __init__( - self, - config: MixtralConfig, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}." - ) - # Split experts equally between ranks - self.expert_indicies = np.array_split( - range(self.num_total_experts), self.tp_size - )[self.rank].tolist() - if not self.expert_indicies: - raise ValueError(f"Rank {self.rank} has no experts assigned to it.") + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + self.quant_config = quant_config - self.experts = nn.ModuleList( - [ - ( - MixtralMLP( - self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config, - ) - if idx in self.expert_indicies - else None - ) - for idx in range(self.num_total_experts) - ] - ) - self.gate = ReplicatedLinear( - config.hidden_size, self.num_total_experts, bias=False, quant_config=None - ) + # FIXME(pcmoritz): Make this more general to support different + # quantization schemes + self.use_fp8 = isinstance(quant_config, Fp8Config) + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None) + + if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + self.w13_weight = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + dtype=params_dtype)) + self.w2_weight = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + dtype=params_dtype)) + + set_weight_attrs(self.w13_weight, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2_weight, { + "weight_loader": self.weight_loader, + }) + + # Used for fp8. 
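+        # (Explanatory note, not from upstream: per-tensor fp8 stores, for
+        #  each expert tensor, an fp8 payload plus one float32 scale, with
+        #      w_fp8 = (w / scale).clamp(-448, 448)   # 448 = e4m3 finfo.max
+        #      w     ~= w_fp8.float() * scale
+        #  ops.scaled_fp8_quant, called from process_weights_after_loading()
+        #  below, produces such a (w_fp8, scale) pair when an fp16
+        #  checkpoint is quantized on the fly.)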
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                    dtype=torch.float32),
+                                         requires_grad=False)
+
+            # If loading an fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            #   process_weights_after_loading()).
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8.")
+                self.a13_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                             requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
+
+    def process_weights_after_loading(self):
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If the checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                w13_weight[expert, :, :], self.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If the checkpoint is fp8 + static, clean up the act_scales:
+        # the state_dict has an act_scale per expert, but our kernels
+        # are passed a single act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None.")
+
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. 
") + + self.a13_scale = nn.Parameter(self.a13_scale.max(), + requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), + requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.w13_weight, + self.w2_weight, + router_logits, + self.top_k, + renormalize=True, + inplace=True, + use_fp8=self.use_fp8, + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + a1_scale=self.a13_scale, + a2_scale=self.a2_scale) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = selected_experts == expert_idx - expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_(expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) class MixtralAttention(nn.Module): @@ -234,7 +329,12 @@ class MixtralDecoderLayer(nn.Module): sliding_window=config.sliding_window, quant_config=quant_config, ) - self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) + self.block_sparse_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -342,11 +442,35 @@ class MixtralForCausalLM(nn.Module): ("qkv_proj", "v_proj", "v"), ] + expert_params_mapping = [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id) + ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + f"experts.{expert_id}.{weight_name}.act_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for param_name, weight_name, shard_id in stacked_params_mapping: + + for (param_name, weight_name, shard_id) in 
stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -358,15 +482,30 @@ class MixtralForCausalLM(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if "block_sparse_moe.experts." in name and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) EntryClass = MixtralForCausalLM diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py new file mode 100644 index 000000000..f60b4c277 --- /dev/null +++ b/python/sglang/srt/models/mixtral_quant.py @@ -0,0 +1,371 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1 +"""Inference-only Mixtral model.""" +from typing import Iterable, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import MixtralConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.router.model_runner import InputMetadata + + +class MixtralMLP(nn.Module): + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + ) + self.w2 = ReplicatedLinear( + self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config + ) + self.w3 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + ) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> 
torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralMoE(nn.Module): + def __init__( + self, + config: MixtralConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}." + ) + # Split experts equally between ranks + self.expert_indicies = np.array_split( + range(self.num_total_experts), self.tp_size + )[self.rank].tolist() + if not self.expert_indicies: + raise ValueError(f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList( + [ + ( + MixtralMLP( + self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + if idx in self.expert_indicies + else None + ) + for idx in range(self.num_total_experts) + ] + ) + self.gate = ReplicatedLinear( + config.hidden_size, self.num_total_experts, bias=False, quant_config=None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = selected_experts == expert_idx + expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_(expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states) + + +class MixtralAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + sliding_window: Optional[int] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
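+            # (Worked example, not from upstream: Mixtral 8x7B has 8 KV
+            #  heads, so tp_size=2 gives each rank 8 // 2 = 4 KV heads,
+            #  while tp_size=16 replicates them and each rank keeps
+            #  max(1, 8 // 16) = 1 head.)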
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + def __init__( + self, + config: MixtralConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + quant_config=quant_config, + ) + self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + def __init__( + self, + config: MixtralConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + MixtralDecoderLayer(config, i, quant_config=quant_config) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + 
input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, input_metadata, residual + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class QuantMixtralForCausalLM(nn.Module): + def __init__( + self, + config: MixtralConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = MixtralModel(config, quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if "block_sparse_moe.experts." in name and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = QuantMixtralForCausalLM diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 65608af89..061340ffa 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -15,6 +15,7 @@ class ServerArgs: chat_template: Optional[str] = None trust_remote_code: bool = True context_length: Optional[int] = None + quantization: Optional[str] = None # Port host: str = "127.0.0.1" @@ -135,6 +136,12 @@ class ServerArgs: default=ServerArgs.context_length, help="The model's maximum context length. 
Defaults to None (will use the value from the model's config.json instead).", ) + parser.add_argument( + "--quantization", + type=str, + default=ServerArgs.quantization, + help="The quantization method.", + ) parser.add_argument( "--mem-fraction-static", type=float, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index f8187ad2a..981e82152 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -106,6 +106,7 @@ def get_available_gpu_memory(gpu_id, distributed=True): "which may cause useless memory allocation for torch CUDA context.", ) + torch.cuda.empty_cache() free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) if distributed:
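For reference, the new MixtralMoE.forward in mixtral.py hands routing and the
expert GEMMs to vLLM's fused_moe kernel. The following pure-PyTorch sketch
shows the math that call performs in the unquantized case with
renormalize=True; the token loop and shapes are illustrative only, the real
kernel is fused in Triton, and the fp8 path additionally runs both GEMMs on
fp8 weights with the per-expert scales:

    import torch
    import torch.nn.functional as F

    def moe_reference(x, w13, w2, gate_logits, top_k):
        # x: [T, H], w13: [E, 2I, H], w2: [E, H, I], gate_logits: [T, E]
        weights = F.softmax(gate_logits, dim=-1, dtype=torch.float)
        weights, experts = weights.topk(top_k, dim=-1)          # [T, top_k]
        weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalize
        out = torch.zeros_like(x)
        for t in range(x.shape[0]):                             # per token
            for w, e in zip(weights[t], experts[t]):            # chosen experts
                gate_up = w13[e] @ x[t]                         # fused w1/w3 GEMM
                i = gate_up.shape[0] // 2
                act = F.silu(gate_up[:i]) * gate_up[i:]         # SwiGLU
                out[t] += w.to(x.dtype) * (w2[e] @ act)         # down proj + mix
        return out

Once an fp8-serialized Mixtral checkpoint is available, the new option can be
exercised end to end. A usage sketch (the checkpoint name is a placeholder;
Runtime forwards its keyword arguments to ServerArgs, and the server
equivalently accepts --quantization fp8 on the command line):

    import sglang as sgl

    runtime = sgl.Runtime(
        model_path="org/Mixtral-8x7B-Instruct-FP8",  # placeholder name
        quantization="fp8",
        tp_size=2,
    )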