diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index b590790be..a91262823 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -107,6 +107,7 @@ class FlashInferAttnBackend(AttentionBackend):
         if (
             "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
diff --git a/python/sglang/srt/models/xiaomi_mimo.py b/python/sglang/srt/models/xiaomi_mimo.py
new file mode 100644
index 000000000..2a89e7706
--- /dev/null
+++ b/python/sglang/srt/models/xiaomi_mimo.py
@@ -0,0 +1,171 @@
+# Adapted from qwen2.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+MiMoConfig = None
+
+
+class MiMoModel(Qwen2Model):
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen2DecoderLayer,
+        )
+
+
+class MiMoForCausalLM(nn.Module):
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MiMoModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if (
+                "rotary_emb.inv_freq" in name
+                or "projector" in name
+                or "mtp_layers" in name
+            ):
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = MiMoForCausalLM
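Note on the weight-loading path in the new file: `load_weights` follows the usual SGLang pattern of folding the per-projection checkpoint tensors (`q_proj`/`k_proj`/`v_proj` and `gate_proj`/`up_proj`) into the fused parallel parameters (`qkv_proj`, `gate_up_proj`) via `stacked_params_mapping`, while `projector` and `mtp_layers` tensors are skipped and never loaded by this class. Below is a minimal, self-contained sketch of that rename-and-stack idea using plain `torch`; the tensor names and sizes are hypothetical, and unlike the real code (where `shard_id` is "q"/"k"/"v" and each parameter's `weight_loader` resolves offsets and tensor-parallel splits) it uses integer shard indices so it can run on its own.

import torch

hidden = 8  # hypothetical hidden size, for illustration only

# A toy checkpoint with separate q/k/v projection weights.
ckpt = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(hidden, hidden),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(hidden, hidden),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(hidden, hidden),
}

# (param_name, shard_name, shard_index) -- mirrors stacked_params_mapping above,
# but with integer indices instead of the "q"/"k"/"v" shard ids.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", 0),
    ("qkv_proj", "k_proj", 1),
    ("qkv_proj", "v_proj", 2),
]

# The fused parameter stacks q, k and v along dim 0.
fused = {"model.layers.0.self_attn.qkv_proj.weight": torch.empty(3 * hidden, hidden)}

for name, loaded_weight in ckpt.items():
    for param_name, shard_name, shard_index in stacked_params_mapping:
        if shard_name not in name:
            continue
        fused_name = name.replace(shard_name, param_name)
        # Copy the checkpoint shard into its slice of the fused weight.
        fused[fused_name][shard_index * hidden : (shard_index + 1) * hidden].copy_(loaded_weight)
        break

Parameters that match no mapping entry (embeddings, norms, `lm_head`) fall through to the `for`/`else` branch of `load_weights` and are copied with `default_weight_loader`.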