From d383e6616e5193d207a6f8b2d703066482e57a0f Mon Sep 17 00:00:00 2001 From: Shane A Date: Sun, 19 Oct 2025 23:59:16 -0700 Subject: [PATCH] [Model] Add Olmo 3 model support (#11396) --- docs/supported_models/generative_models.md | 1 + python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/olmo3.py | 105 ++++++++++++++++++ python/sglang/srt/models/olmo2.py | 35 +++++- python/sglang/srt/server_args.py | 26 +++++ python/sglang/srt/utils/common.py | 1 + .../sglang/srt/utils/hf_transformers_utils.py | 2 + test/srt/models/test_generation_models.py | 1 + 8 files changed, 169 insertions(+), 4 deletions(-) create mode 100644 python/sglang/srt/configs/olmo3.py diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index 7733bedc8..fdb18b845 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -33,6 +33,7 @@ in the GitHub search bar. | **Gemma** (v1, v2, v3) | `google/gemma-3-1b-it` | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. | | **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4, Phi-MoE series) | `microsoft/Phi-4-multimodal-instruct`, `microsoft/Phi-3.5-MoE-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-multimodal (5.6B) processes text, images, and speech, Phi-4-mini is a high-accuracy text model and Phi-3.5-MoE is a mixture-of-experts model. | | **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. | +| **OLMo** (2, 3) | `allenai/OLMo-2-1124-7B-Instruct` | Allen AI’s series of Open Language Models designed to enable the science of language models. | | **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. | | **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. | | **Command-R** (Cohere) | `CohereForAI/c4ai-command-r-v01` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. | diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index fb5a4d6d2..b3ae8b5c7 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -10,6 +10,7 @@ from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig from sglang.srt.configs.longcat_flash import LongcatFlashConfig from sglang.srt.configs.nemotron_h import NemotronHConfig +from sglang.srt.configs.olmo3 import Olmo3Config from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.configs.step3_vl import ( Step3TextConfig, @@ -29,6 +30,7 @@ __all__ = [ "Step3VLConfig", "Step3TextConfig", "Step3VisionEncoderConfig", + "Olmo3Config", "Qwen3NextConfig", "DotsVLMConfig", "DotsOCRConfig", diff --git a/python/sglang/srt/configs/olmo3.py b/python/sglang/srt/configs/olmo3.py new file mode 100644 index 000000000..95e7c2537 --- /dev/null +++ b/python/sglang/srt/configs/olmo3.py @@ -0,0 +1,105 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Olmo3 model configuration""" + +import enum + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Olmo3LayerType(enum.Enum): + full_attention = "full_attention" + sliding_attention = "sliding_attention" + + +class Olmo3Config(PretrainedConfig): + + model_type = "olmo3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + sliding_window=4096, + layer_types=None, + **kwargs, + ): + # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM + # in sglang. + if "architectures" not in kwargs: + kwargs["architectures"] = ["Olmo2ForCausalLM"] + elif "Olmo3ForCausalLM" in kwargs["architectures"]: + kwargs["architectures"].remove("Olmo3ForCausalLM") + kwargs["architectures"].append("Olmo2ForCausalLM") + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + rope_config_validation(self) + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + self.sliding_window = sliding_window + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" + for i in range(self.num_hidden_layers) + ] diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index 75834e6fb..de5087f31 100644 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -48,6 +48,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, make_layers +# Aligned with HF's implementation, using sliding window inclusive with the last token +# SGLang assumes exclusive +def 
get_attention_sliding_window_size(config): + return config.sliding_window - 1 if hasattr(config, "sliding_window") else None + + class Olmo2Attention(nn.Module): """ This is the attention block where the output is computed as @@ -85,6 +91,8 @@ class Olmo2Attention(nn.Module): self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta @@ -104,12 +112,26 @@ class Olmo2Attention(nn.Module): eps=self.config.rms_norm_eps, ) self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - # Rotary embeddings. + + sliding_window = None + if ( + layer_types := getattr(self.config, "layer_types", None) + ) is not None and layer_types[layer_id] == "sliding_attention": + sliding_window = get_attention_sliding_window_size(self.config) + + # Rotary embeddings. Rope scaling is only applied on full attention + # layers. + self.rope_scaling = ( + self.config.rope_scaling + if sliding_window is None + else {"rope_type": "default"} + ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=self.max_position_embeddings, base=self.rope_theta, + rope_scaling=self.rope_scaling, ) self.scaling = self.head_dim**-0.5 self.attn = RadixAttention( @@ -118,6 +140,7 @@ class Olmo2Attention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + sliding_window_size=sliding_window, quant_config=quant_config, prefix=add_prefix("attn", prefix), ) @@ -152,7 +175,7 @@ class Olmo2Attention(nn.Module): forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, forward_batch) @@ -224,6 +247,7 @@ class Olmo2DecoderLayer(nn.Module): prefix: str = "", ): super().__init__() + self.layer_id = layer_id # Attention block. self.self_attn = Olmo2Attention( config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix) @@ -280,8 +304,8 @@ class Olmo2Model(nn.Module): self.layers = make_layers( config.num_hidden_layers, lambda idx, prefix: Olmo2DecoderLayer( - layer_id=idx, config=config, + layer_id=idx, quant_config=quant_config, prefix=prefix, ), @@ -294,7 +318,7 @@ class Olmo2Model(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, + input_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. @@ -351,6 +375,9 @@ class Olmo2ForCausalLM(nn.Module): ) self.logits_processor = LogitsProcessor(config) + def get_attention_sliding_window_size(self): + return get_attention_sliding_window_size(self.config) + def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4dfb2660c..939e502fc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -36,6 +36,7 @@ from sglang.srt.utils import ( configure_ipv6, get_device, get_device_memory_capacity, + get_device_sm, is_cuda, is_flashinfer_available, is_hip, @@ -942,6 +943,31 @@ class ServerArgs: f"Disable hybrid SWA memory for {model_arch} as it is not yet supported." 
) self.disable_hybrid_swa_memory = True + elif model_arch in ["Olmo2ForCausalLM"]: + # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with Olmo3 model. + logger.warning( + f"Disabling hybrid SWA memory for {model_arch} as it is not yet supported." + ) + self.disable_hybrid_swa_memory = True + + if self.attention_backend is None: + if is_cuda() and is_sm100_supported(): + self.attention_backend = "trtllm_mha" + elif is_cuda() and get_device_sm() >= 80: + self.attention_backend = "fa3" + else: + self.attention_backend = "triton" + + # Flashinfer appears to degrade performance when sliding window attention + # is used for the Olmo2 architecture. Olmo2 does not use sliding window attention + # but Olmo3 does. + assert ( + self.attention_backend != "flashinfer" + ), "FlashInfer backend can significantly degrade the performance of Olmo3 models." + + logger.info( + f"Using {self.attention_backend} as attention backend for {model_arch}." + ) if is_deepseek_nsa(hf_config): if ( diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 1c50e7fb9..2264168e2 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2530,6 +2530,7 @@ def is_fa3_default_architecture(hf_config): "Qwen2ForCausalLM", "Llama4ForConditionalGeneration", "LlamaForCausalLM", + "Olmo2ForCausalLM", "Gemma2ForCausalLM", "Gemma3ForConditionalGeneration", "Qwen3ForCausalLM", diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 527d6bd04..b20fcd605 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -47,6 +47,7 @@ from sglang.srt.configs import ( LongcatFlashConfig, MultiModalityConfig, NemotronHConfig, + Olmo3Config, Qwen3NextConfig, Step3VLConfig, ) @@ -64,6 +65,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { InternVLChatConfig.model_type: InternVLChatConfig, Step3VLConfig.model_type: Step3VLConfig, LongcatFlashConfig.model_type: LongcatFlashConfig, + Olmo3Config.model_type: Olmo3Config, Qwen3NextConfig.model_type: Qwen3NextConfig, FalconH1Config.model_type: FalconH1Config, DotsVLMConfig.model_type: DotsVLMConfig, diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index d6c576471..4aab2246e 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -61,6 +61,7 @@ ALL_MODELS = [ ModelCase("Qwen/Qwen2.5-14B-Instruct"), ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True), ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True), + ModelCase("shanearora/2025-sep-a-base-model"), ModelCase( "THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True ),
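
A minimal standalone sketch of the two conventions this patch relies on, assuming the Olmo3Config defaults above (sliding_window=4096): the default layer_types pattern puts full attention on every 4th layer and sliding attention elsewhere, and get_attention_sliding_window_size() converts HF's inclusive window size to SGLang's exclusive convention by subtracting one.

    # Sketch only (not the patched modules themselves): reproduces the default
    # layer_types pattern and the HF -> SGLang sliding-window conversion.
    def default_layer_types(num_hidden_layers: int) -> list:
        # Every 4th layer (1-indexed) uses full attention; the rest use sliding attention.
        return [
            "sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
            for i in range(num_hidden_layers)
        ]

    def sglang_sliding_window_size(hf_sliding_window: int) -> int:
        # HF treats the window as inclusive of the last token; SGLang assumes
        # exclusive, hence the -1 in get_attention_sliding_window_size().
        return hf_sliding_window - 1

    print(default_layer_types(8))
    # ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
    #  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
    print(sglang_sliding_window_size(4096))  # 4095

Once the patch is applied, serving a supported checkpoint goes through the usual entry point, e.g. `python -m sglang.launch_server --model-path allenai/OLMo-2-1124-7B-Instruct`. Note that the assertion added to server_args.py rejects `--attention-backend flashinfer` for this architecture, so fa3, triton, or trtllm_mha are the expected backends.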