From a6ae3af15e84f8d86d89e24a04493d45815e01d6 Mon Sep 17 00:00:00 2001 From: ryang <38470282+ryang-max@users.noreply.github.com> Date: Fri, 23 May 2025 05:14:49 +0800 Subject: [PATCH] Support XiaomiMiMo inference with mtp (#6059) --- docs/backend/speculative_decoding.ipynb | 54 +++++ python/sglang/srt/configs/model_config.py | 3 + .../sglang/srt/model_executor/model_runner.py | 15 +- .../srt/models/{xiaomi_mimo.py => mimo.py} | 0 python/sglang/srt/models/mimo_mtp.py | 220 ++++++++++++++++++ test/srt/models/test_mtp_models.py | 58 +++++ 6 files changed, 344 insertions(+), 6 deletions(-) rename python/sglang/srt/models/{xiaomi_mimo.py => mimo.py} (100%) create mode 100644 python/sglang/srt/models/mimo_mtp.py create mode 100644 test/srt/models/test_mtp_models.py diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index 7b68d1d26..bff10ff26 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -283,6 +283,60 @@ "terminate_process(server_process)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi Token Prediction\n", + "\n", + "We support [MTP(Multi-Token Prediction)](https://arxiv.org/pdf/2404.19737) in SGLang by using speculative decoding. 
We use the XiaomiMiMo/MiMo-7B-RL model as an example here (for DeepSeek MTP usage, refer to the [deepseek doc](../references/deepseek.md#multi-token-prediction))." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process, port = launch_server_cmd(\n", + " \"\"\"\n", + " python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \\\n", + " --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \\\n", + " --mem-fraction 0.5\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "url = f\"http://localhost:{port}/v1/chat/completions\"\n", + "\n", + "data = {\n", + " \"model\": \"XiaomiMiMo/MiMo-7B-RL\",\n", + " \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n", + "}\n", + "\n", + "response = requests.post(url, json=data)\n", + "print_highlight(response.json())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 043ef7a0f..7d3e4131e 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -73,6 +73,7 @@ class ModelConfig: model_override_args=self.model_override_args, **kwargs, ) + self.hf_text_config = get_hf_text_config(self.hf_config) self.attention_chunk_size = getattr( self.hf_text_config, "attention_chunk_size", None @@ -97,6 +98,8 @@ class ModelConfig: ): self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM": + 
self.hf_config.architectures[0] = "MiMoMTP" # Check model type self.is_generation = is_generation_model( self.hf_config.architectures, is_embedding diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4b4c61a23..8aed6399f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -782,12 +782,15 @@ class ModelRunner: distributed=get_world_group().world_size > 1, cpu_group=get_world_group().cpu_group, ) - if self.use_mla_backend: - num_layers = ( - self.model_config.num_hidden_layers - if not self.is_draft_worker - else self.model_config.hf_config.num_nextn_predict_layers + if self.is_draft_worker: + num_layers = getattr( + self.model_config.hf_config, + "num_nextn_predict_layers", + self.num_effective_layers, ) + else: + num_layers = self.num_effective_layers + if self.use_mla_backend: # FIXME: pipeline parallelism is not compatible with mla backend assert self.pp_size == 1 cell_size = ( @@ -799,7 +802,7 @@ class ModelRunner: cell_size = ( self.model_config.get_num_kv_heads(get_attention_tp_size()) * self.model_config.head_dim - * self.num_effective_layers + * num_layers * 2 * torch._utils._element_size(self.kv_cache_dtype) ) diff --git a/python/sglang/srt/models/xiaomi_mimo.py b/python/sglang/srt/models/mimo.py similarity index 100% rename from python/sglang/srt/models/xiaomi_mimo.py rename to python/sglang/srt/models/mimo.py diff --git a/python/sglang/srt/models/mimo_mtp.py b/python/sglang/srt/models/mimo_mtp.py new file mode 100644 index 000000000..6c81d8d85 --- /dev/null +++ b/python/sglang/srt/models/mimo_mtp.py @@ -0,0 +1,220 @@ +# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py + +from functools import partial +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import ( + 
get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.mimo import MiMoForCausalLM +from sglang.srt.models.qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2MLP, + Qwen2Model, +) +from sglang.srt.utils import add_prefix + + +class MiMoMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_proj = nn.Linear( + config.hidden_size * 2, config.hidden_size, bias=False + ) + self.mtp_block = Qwen2DecoderLayer( + config=config, quant_config=quant_config, prefix=prefix + ) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + + if input_embeds is None: + hidden_states = 
self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + # masking inputs at position 0, as not needed by MTP + hidden_states[positions == 0] = 0 + + hidden_states = self.input_proj( + torch.cat( + ( + self.hidden_layernorm(forward_batch.spec_info.hidden_states), + self.token_layernorm(hidden_states), + ), + dim=-1, + ) + ) + + hidden_states, residual = self.mtp_block( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + residual=None, + ) + hidden_states = residual + hidden_states + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class MiMoMTP(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + + self.model = MiMoMultiTokenPredictorLayer( + config, + prefix, + quant_config, + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or 
"rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + name = self.map_model_name_to_mtp_param_name(name) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mtp_block" not in name: + break + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if "mtp_block" not in name and ( + "embed_tokens" not in name + and "lm_head" not in name + and "token_layernorm" not in name + and "hidden_layernorm" not in name + and "input_proj" not in name + and "final_layernorm" not in name + ): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def map_model_name_to_mtp_param_name(self, name: str) -> str: + import re + + name_without_prefix = [ + "token_layernorm", + "hidden_layernorm", + "input_proj", + "final_layernorm", + ] + pattern = r"model.mtp_layers.(\d+)." 
+ group = re.match(pattern, name) + if group is not None: + for sub_name in name_without_prefix: + if sub_name in name: + name = name.replace(group.group(), "model.") + return name + name = name.replace(group.group(), "model.mtp_block.") + return name + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +EntryClass = MiMoMTP diff --git a/test/srt/models/test_mtp_models.py b/test/srt/models/test_mtp_models.py new file mode 100644 index 000000000..49b53c1e4 --- /dev/null +++ b/test/srt/models/test_mtp_models.py @@ -0,0 +1,58 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestMiMoMTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "XiaomiMiMo/MiMo-7B-RL" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--mem-fraction-static", + "0.5", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + 
self.assertGreater(metrics["accuracy"], 0.7) + + +if __name__ == "__main__": + unittest.main()