diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 3e4cd2688..8cb91894e 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -749,6 +749,8 @@ multimodal_model_archs = [ "Qwen2AudioForConditionalGeneration", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration", + "Qwen3VLForConditionalGeneration", + "Qwen3VLMoeForConditionalGeneration", "KimiVLForConditionalGeneration", "InternVLChatModel", "InternS1ForConditionalGeneration", diff --git a/python/sglang/srt/configs/qwen3_vl.py b/python/sglang/srt/configs/qwen3_vl.py new file mode 100644 index 000000000..4a995c856 --- /dev/null +++ b/python/sglang/srt/configs/qwen3_vl.py @@ -0,0 +1,586 @@ +from typing import Optional, Union + +from transformers import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class Qwen3VLVisionConfig(PretrainedConfig): + model_type = "qwen3_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3VLTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3VLModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + head_dim (`int`, *optional*, defaults to 128): + The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 5000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    ```python
+    >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
+
+    >>> # Initializing a Qwen3VL style configuration
+    >>> configuration = Qwen3VLTextConfig()
+
+    >>> # Initializing a model from the Qwen3-VL-4B style configuration
+    >>> model = Qwen3VLTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen3_vl_text"
+    base_config_key = "text_config"
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=4096,
+        intermediate_size=22016,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        head_dim=128,
+        hidden_act="silu",
+        max_position_embeddings=128000,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=5000000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
+    Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
+            The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig + + >>> # Initializing a Qwen3-VL style configuration + >>> configuration = Qwen3VLConfig() + + >>> # Initializing a model from the Qwen3-VL-4B style configuration + >>> model = Qwen3VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl" + sub_configs = { + "vision_config": Qwen3VLVisionConfig, + "text_config": Qwen3VLTextConfig, + } + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) + + +class Qwen3VLMoeTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2MoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. 
+        num_key_value_heads (`int`, *optional*, defaults to 16):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 128000):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 5000000.0):
+            The base period of the RoPE embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 1408):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 4):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 60):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+            Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock.
+            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        head_dim (`int`, *optional*):
+            The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3VLMoe style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen3VLMoe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + attention_bias=False, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=1408, + num_experts_per_tok=4, + num_experts=60, + norm_topk_prob=True, + mlp_only_layers=None, + rope_scaling=None, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.head_dim = head_dim or hidden_size // num_attention_heads + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLMoeVisionConfig(PretrainedConfig): + model_type = "qwen3_vl_moe" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = 
hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3VLMoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. 
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3-VL-MOE style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe" + sub_configs = { + "vision_config": Qwen3VLMoeVisionConfig, + "text_config": Qwen3VLMoeTextConfig, + } + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) + + +__all__ = [ + "Qwen3VLMoeConfig", + "Qwen3VLMoeVisionConfig", + "Qwen3VLConfig", + "Qwen3VLVisionConfig", +] diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index eacf84c8a..fd69ea727 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1187,7 +1187,7 @@ class MRotaryEmbedding(RotaryEmbedding): time_tensor_long = time_tensor.long() t_index = time_tensor_long.flatten() - elif model_type == "qwen2_vl": + elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"): t_index = ( torch.arange(llm_grid_t) .view(-1, 1) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index f495904d5..41de295af 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -507,6 +507,7 @@ def embed_mm_inputs( Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] ] = None, placeholder_tokens: dict[Modality, List[int]] = None, + use_deepstack: bool = False, ) -> Optional[torch.Tensor]: """ Embed multimodal inputs and integrate them with text token embeddings. @@ -522,7 +523,7 @@ def embed_mm_inputs( Returns: Combined embedding tensor with multimodal content integrated """ - + other_info = {} if mm_inputs_list is None: return None @@ -532,7 +533,7 @@ def embed_mm_inputs( for mm_inputs in mm_inputs_list: item_flatten_list += [item for item in mm_inputs.mm_items if item is not None] - embeddings, masks = [], [] + embeddings, masks, deepstack_embeddings = [], [], [] # 2. 
Get multimodal embedding separately # Try get mm embedding if any for modality in Modality.all(): @@ -578,6 +579,12 @@ def embed_mm_inputs( extend_length=extend_seq_lens, items_offset_list=items_offsets, ) + + if use_deepstack and embedding is not None: + embedding, deepstack_embedding = ( + multimodal_model.separate_deepstack_embeds(embedding) + ) + deepstack_embeddings += [deepstack_embedding] embeddings += [embedding] masks += [mask] @@ -591,13 +598,37 @@ def embed_mm_inputs( inputs_embeds = input_embedding(input_ids) # 4. scatter embeddings into input embedding - for embedding, mask in zip(embeddings, masks): + + # deepstack embedding + if use_deepstack: + num_deepstack_embeddings = ( + len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0 + ) + deepstack_embedding_shape = inputs_embeds.shape[:-1] + ( + inputs_embeds.shape[-1] * num_deepstack_embeddings, + ) + + input_deepstack_embeds = torch.zeros( + deepstack_embedding_shape, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) + + other_info["input_deepstack_embeds"] = input_deepstack_embeds + + for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks): if embedding is None or mask is None: continue # in-place update indices = torch.where(mask.squeeze(dim=-1))[0] inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype) - return inputs_embeds + + if use_deepstack: + input_deepstack_embeds[indices] = deepstack_embeddings[i].to( + inputs_embeds.device, inputs_embeds.dtype + ) + + return inputs_embeds, other_info def general_mm_embed_routine( @@ -609,6 +640,7 @@ def general_mm_embed_routine( Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] ] = None, placeholder_tokens: Optional[dict[Modality, List[int]]] = None, + use_deepstack: bool = False, **kwargs, ) -> torch.Tensor: """ @@ -620,6 +652,7 @@ def general_mm_embed_routine( language_model: Base language model to use data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function. placeholder_tokens: Token IDs for multimodal placeholders + use_deepstack: Whether to use deepstack embeddings **kwargs: Additional arguments passed to language model Returns: @@ -645,16 +678,20 @@ def general_mm_embed_routine( for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu) if forward_batch.mm_inputs[i] is not None ] - inputs_embeds = embed_mm_inputs( + inputs_embeds, other_info = embed_mm_inputs( mm_inputs_list=mm_inputs_list, extend_prefix_lens=extend_prefix_lens, extend_seq_lens=extend_seq_lens, input_ids=input_ids, - input_embedding=embed_tokens, multimodal_model=multimodal_model, + input_embedding=embed_tokens, data_embedding_func_mapping=data_embedding_funcs, placeholder_tokens=placeholder_tokens, + use_deepstack=use_deepstack, ) + # add for qwen3_vl deepstack + if use_deepstack: + kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"] # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models # just being defensive here forward_batch.mm_inputs = None diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py new file mode 100644 index 000000000..a87d21e78 --- /dev/null +++ b/python/sglang/srt/models/qwen3_vl.py @@ -0,0 +1,787 @@ +# Copyright 2025 Qwen Team +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Inference-only Qwen3-VL model compatible with HuggingFace weights.""" +import logging +from functools import lru_cache, partial +from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.activations import ACT2FN +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionRotaryEmbedding, +) + +from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig +from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs +from sglang.srt.models.qwen3 import Qwen3Model +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + +# === Vision Encoder === # + + +class Qwen3_VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = True, + hidden_act="silu", + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear_fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear_fc1", prefix), + ) + self.linear_fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear_fc2", prefix), + ) + self.act = ACT2FN[hidden_act] + + def forward(self, x: torch.Tensor): + x_fc1, _ = self.linear_fc1(x) + mlp_output, _ = self.linear_fc2(self.act(x_fc1)) + return mlp_output + + +class Qwen3VLVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d( + self.in_channels, + self.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + 
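+            # remaining dims (temporal_patch_size, patch_size, patch_size) match the
+            # Conv3d kernel/stride above, so each patch maps to one embedding vector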
self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim + ) + return hidden_states + + +class Qwen3_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + intermediate_dim: int, + hidden_act="silu", + norm_layer: Optional[Callable[[int], nn.Module]] = None, + attn_implementation: Optional[str] = "sdpa", + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + + if attn_implementation == "sdpa": + softmax_in_single_precision = False + qkv_backend = "sdpa" + flatten_batch = True + elif attn_implementation == "flash_attention_2": + softmax_in_single_precision = False + qkv_backend = "triton_attn" + flatten_batch = True + elif attn_implementation == "eager": + softmax_in_single_precision = True + qkv_backend = "sdpa" + flatten_batch = True + elif attn_implementation == "flash_attention_3": + softmax_in_single_precision = False + qkv_backend = "fa3" + flatten_batch = True + + self.attn = VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + use_qkv_parallel=True, + rotary_embed="normal", + proj_bias=True, + qkv_backend=qkv_backend, + softmax_in_single_precision=softmax_in_single_precision, + flatten_batch=flatten_batch, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + self.mlp = Qwen3_VisionMLP( + dim, + intermediate_dim, + hidden_act=hidden_act, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.norm1(x) + hidden_states = rearrange(hidden_states, "s b ... -> b s ...") + attn = self.attn( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + attn = rearrange(attn, "b s ... 
-> s b ...") + x = x + attn + norm2 = self.norm2(x) + mlp = self.mlp(norm2) + x = x + mlp + return x + + +class Qwen3_VisionPatchMerger(nn.Module): + + def __init__( + self, + dim: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm = norm_layer( + self.hidden_size if use_postshuffle_norm else context_dim + ) + self.linear_fc1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=add_prefix("linear_fc1", prefix), + ) + self.act_fn = nn.GELU() + self.linear_fc2 = RowParallelLinear( + self.hidden_size, + dim, + bias=True, + quant_config=quant_config, + prefix=add_prefix("linear_fc2", prefix), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_postshuffle_norm: + x = self.norm(x.view(-1, self.hidden_size)) + else: + x = self.norm(x).view(-1, self.hidden_size) + + x_parallel, _ = self.linear_fc1(x) + x_parallel = self.act_fn(x_parallel) + out, _ = self.linear_fc2(x_parallel) + return out + + +class Qwen3_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen3VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.num_position_embeddings = vision_config.num_position_embeddings + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config) + self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + intermediate_dim=vision_config.intermediate_size, + hidden_act=vision_config.hidden_act, + norm_layer=norm_layer, + attn_implementation="flash_attention_3", + quant_config=quant_config, + prefix=add_prefix(f"blocks.{layer_idx}", prefix), + ) + for layer_idx in range(vision_config.depth) + ] + ) + self.merger = Qwen3_VisionPatchMerger( + dim=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=add_prefix("merger", prefix), + ) + + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + dim=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix), + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) + + @property + def dtype(self) -> torch.dtype: + return 
self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + num_grid_per_side = int(self.num_position_embeddings**0.5) + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + # TODO: use torch instand of np + for t, h, w in grid_thw: + h_idxs = np.linspace(0, num_grid_per_side - 1, h) + w_idxs = np.linspace(0, num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.astype(int) + w_idxs_floor = w_idxs.astype(int) + h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + idx_list[0].extend( + ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None]) + .flatten() + .tolist() + * t + ) + idx_list[1].extend( + ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None]) + .flatten() + .tolist() + * t + ) + idx_list[2].extend( + ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None]) + .flatten() + .tolist() + * t + ) + idx_list[3].extend( + ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None]) + .flatten() + .tolist() + * t + ) + + weight_list[0].extend( + ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t + ) + weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t) + weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t) + weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t) + + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + + p0 = ( + self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device)) + * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None] + ) + p1 = ( + self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device)) + * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None] + ) + p2 = ( + self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device)) + * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None] + ) + p3 = ( + self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device)) + * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None] + ) + + patch_pos_embeds = p0 + p1 + p2 + p3 + patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw]) + patch_pos_embeds_permute = [] + m_size = self.spatial_merge_size + for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): + pos_embed = ( + pos_embed.view(t, h // m_size, 
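+                # split H and W into spatial-merge blocks; the permute below makes each
+                # merge window contiguous, matching the ordering used by rot_pos_emb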
m_size, w // m_size, m_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + x = x + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = x.size() + rotary_pos_emb = rotary_pos_emb.to(x.device) + + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + # compute cu_seqlens + cu_seqlens = torch.cat( + [ + torch.tensor([0], device=grid_thw.device), + (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0), + ] + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + x = x.unsqueeze(1) + + deepstack_feature_lists = [] + num_deepstack_captured = 0 + for layer_num, blk in enumerate(self.blocks): + x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[num_deepstack_captured]( + x + ) + deepstack_feature_lists.append(deepstack_feature) + num_deepstack_captured += 1 + x = self.merger(x) + hidden_states = torch.cat( + [x] + deepstack_feature_lists, dim=1 + ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +cached_get_processor = lru_cache(get_processor) + + +class Qwen3LLMModel(Qwen3Model): + + def __init__( + self, + *, + config: Qwen3VLConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__(config=config, quant_config=quant_config, prefix=prefix) + if not self.pp_group.is_first_rank: + assert self.start_layer >= len( + config.vision_config.deepstack_visual_indexes + ), "start_layer should be greater than or equal to len(deepstack_visual_indexes)" + + self.hidden_size = config.hidden_size + self.deepstack_embed_to_decoder_layer = range( + len(config.vision_config.deepstack_visual_indexes) + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + input_deepstack_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = 
self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + aux_hidden_states = [] + for layer_idx, layer in enumerate( + self.layers[self.start_layer : self.end_layer] + ): + layer_idx = layer_idx + self.start_layer + if layer_idx in self.layers_to_capture: + aux_hidden_states.append( + hidden_states + residual if residual is not None else hidden_states + ) + + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + + # process deepstack + if ( + input_deepstack_embeds is not None + and layer_idx in self.deepstack_embed_to_decoder_layer + ): + sep = self.hidden_size * layer_idx + hidden_states = ( + hidden_states + + input_deepstack_embeds[:, sep : sep + self.hidden_size] + ) + + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + if hidden_states.shape[0] != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states + + +class Qwen3VLForConditionalGeneration(nn.Module): + def __init__( + self, + config: Qwen3VLConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. + # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. 
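+            # If an untested quantization method is configured, passing quant_config=None
+            # here would keep the vision tower unquantized in the checkpoint dtype.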
+ quant_config=quant_config, + prefix=add_prefix("visual", prefix), + ) + + self.model = Qwen3LLMModel( + config=config, + quant_config=quant_config, + prefix=add_prefix("model", prefix), + ) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling + + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on + # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states + + # deepstack + self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes + self.num_deepstack_embeddings = len(self.deepstack_visual_indexes) + + @property + def use_deepstack(self) -> bool: + return hasattr(self, "deepstack_visual_indexes") + + def separate_deepstack_embeds(self, embedding): + assert ( + embedding.shape[-1] % (1 + self.num_deepstack_embeddings) == 0 + ), f"hidden_state of {embedding.shape} should be divisible by ({1 + self.num_deepstack_embeddings})" + + separate_index = self.config.hidden_size + input_embeds = embedding[:, :separate_index] + input_deepstack_embeds = embedding[:, separate_index:] + return input_embeds, input_deepstack_embeds + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # in qwen-vl, last dim is the same + pixel_values = torch.cat([item.feature for item in items], dim=0).type( + self.visual.dtype + ) + image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) + assert pixel_values.dim() == 2, pixel_values.dim() + assert image_grid_thw.dim() == 2, image_grid_thw.dim() + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # in qwen-vl, last dim is the same + pixel_values = torch.cat([item.feature for item in items], dim=0).type( + self.visual.dtype + ) + video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0) + assert pixel_values.dim() == 2, pixel_values.dim() + assert video_grid_thw.dim() == 2, video_grid_thw.dim() + video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) + return video_embeds + + def get_input_embeddings(self): + return self.model.embed_tokens + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + get_embedding: bool = False, + ): + """Run forward pass for Qwen3-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). 
+ (Use input_metadata.mrope_positions to replace it) + """ + if self.is_mrope_enabled: + positions = forward_batch.mrope_positions + + if not ( + forward_batch.forward_mode.is_decode() + or not forward_batch.contains_image_inputs() + ): + if self.is_mrope_enabled: + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}" + ) + + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.model, + multimodal_model=self, + positions=positions, + use_deepstack=self.use_deepstack, + ) + + if not get_embedding: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + else: + return self.pooler(hidden_states, forward_batch) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "language_model" in name: + name = name.replace(r"model.language_model.", r"model.") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "visual" in name: + continue + name = name.replace(weight_name, param_name) + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "visual" in name: + # adapt to VisionAttention + name = name.replace(r"attn.qkv.", r"attn.qkv_proj.") + name = name.replace(r"model.visual.", r"visual.") + + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + print(params_dict.keys()) + raise + + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = Qwen3VLForConditionalGeneration diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py new file mode 100644 index 000000000..a88059916 --- /dev/null +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -0,0 +1,471 @@ +# Copyright 2025 Qwen Team +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Inference-only Qwen3-VL model compatible with HuggingFace weights.""" +import logging +from functools import lru_cache, partial +from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BatchFeature +from transformers.activations import ACT2FN +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionRotaryEmbedding, +) + +from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig +from sglang.srt.distributed import ( + get_moe_expert_parallel_world_size, + get_pp_group, + get_tensor_model_parallel_rank, +) +from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.utils import get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import ( + MultimodalDataItem, + MultimodalInputs, + global_server_args_dict, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from sglang.srt.models.qwen3_vl import ( + Qwen3_VisionTransformer, + Qwen3VLForConditionalGeneration, +) +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + +cached_get_processor = lru_cache(get_processor) + + +class Qwen3MoeLLMModel(Qwen3MoeModel): + def __init__( + self, + *, + config: Qwen3VLMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__(config=config, quant_config=quant_config, prefix=prefix) + + self.hidden_size = config.hidden_size + + def get_input_embeddings(self) -> nn.Embedding: + return self.embed_tokens + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # in qwen-vl, last dim is the same + pixel_values = torch.cat([item.feature for item in items], dim=0).type( + self.visual.dtype + ) + image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) + assert pixel_values.dim() == 2, pixel_values.dim() + assert image_grid_thw.dim() == 2, image_grid_thw.dim() + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + input_deepstack_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + aux_hidden_states = [] + for layer_idx, layer in enumerate( + 
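+            # iterate only over the decoder layers owned by this pipeline-parallel rank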
self.layers[self.start_layer : self.end_layer] + ): + layer_idx = layer_idx + self.start_layer + if layer_idx in self.layers_to_capture: + aux_hidden_states.append( + hidden_states + residual if residual is not None else hidden_states + ) + + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + + # process deepstack + if input_deepstack_embeds is not None and layer_idx in range(3): + sep = self.hidden_size * layer_idx + hidden_states = ( + hidden_states + + input_deepstack_embeds[:, sep : sep + self.hidden_size] + ) + + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + if hidden_states.shape[0] != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states + + +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + def __init__( + self, + *, + config: Qwen3VLMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super(Qwen3VLForConditionalGeneration, self).__init__() + self.config = config + + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. + # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. + quant_config=quant_config, + prefix=add_prefix("visual", prefix), + ) + + self.model = Qwen3MoeLLMModel( + config=config, + quant_config=quant_config, + prefix=add_prefix("model", prefix), + ) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling + + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + # deepstack + self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes + self.num_deepstack_embeddings = len(self.deepstack_visual_indexes) + + @property + def use_deepstack(self) -> bool: + return hasattr(self, "deepstack_visual_indexes") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + get_embedding: bool = False, + ): + """Run forward pass for Qwen3-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). 
+ (Use input_metadata.mrope_positions to replace it) + """ + if self.is_mrope_enabled: + positions = forward_batch.mrope_positions + + if not ( + forward_batch.forward_mode.is_decode() + or not forward_batch.contains_image_inputs() + ): + if self.is_mrope_enabled: + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}" + ) + + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.model, + multimodal_model=self, + positions=positions, + use_deepstack=self.use_deepstack, + ) + + if not get_embedding: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + else: + return self.pooler(hidden_states, forward_batch) + + def load_fused_expert_weights( + self, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ): + param = params_dict[name] + # weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + weight_loader = param.weight_loader + ep_rank = get_tensor_model_parallel_rank() + ep_size = get_moe_expert_parallel_world_size() + if ep_size == 1: + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + weight_loader( + param, + curr_expert_weight, + name, + shard_id, + expert_id, + ) + else: + experts_per_ep = num_experts // ep_size + start_expert = ep_rank * experts_per_ep + end_expert = ( + (ep_rank + 1) * experts_per_ep + if ep_rank != ep_size - 1 + else num_experts + ) + + for idx, expert_id in enumerate(range(start_expert, end_expert)): + curr_expert_weight = loaded_weight[expert_id] + weight_loader( + param, + curr_expert_weight, + name, + shard_id, + idx, + ) + return True + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) + + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + + num_experts = self.config.num_experts + + # Cache params_dict to avoid repeated expensive traversal of model parameters + if not hasattr(self, "_cached_params_dict"): + self._cached_params_dict = dict(self.named_parameters()) + params_dict = self._cached_params_dict + for name, loaded_weight in weights: + if "language_model" in name: + name = name.replace(r"model.language_model.", r"model.") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + # Skip non-stacked layers and experts (experts handled below). 
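+                # Vision-tower weights keep their own fused projection layout
+                # (e.g. attn.qkv) and are renamed/loaded in the fallback branch
+                # below, so they are excluded from the stacked mapping here.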
+ if weight_name not in name: + continue + if "visual" in name: + continue + + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # [TODO] Skip layers that are on other devices (check if sglang has a similar function) + # if is_pp_missing_parameter(name, self): + # continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Track if this is an expert weight to enable early skipping + is_expert_weight = False + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + if "visual" in name: + continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + if is_fused_expert: + loaded_weight = loaded_weight.transpose(-1, -2) # no bias + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[0], + "w1", + num_experts, + ) + self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[1], + "w3", + num_experts, + ) + else: + self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight, + shard_id, + num_experts, + ) + else: + # Skip loading extra parameters for GPTQ/modelopt models. + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # # other available replicas. + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + ) + name = name_mapped + break + else: + if is_expert_weight: + # This is an expert weight but not mapped to this rank, skip all remaining processing + continue + if "visual" in name: + # adapt to VisionAttention + name = name.replace(r"attn.qkv.", r"attn.qkv_proj.") + name = name.replace(r"model.visual.", r"visual.") + + # Skip loading extra parameters for GPTQ/modelopt models. 
+ if name.endswith(ignore_suffixes) and name not in params_dict: + continue + + if name in params_dict.keys(): + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + else: + logger.warning(f"Parameter {name} not found in params_dict") + + # TODO mimic deepseek + # Lazy initialization of expert weights cache to avoid slowing down load_weights + # if not hasattr(self, "routed_experts_weights_of_layer"): + # self.routed_experts_weights_of_layer = { + # layer_id: self.model.layers[layer_id].mlp.get_moe_weights() + # for layer_id in range(self.start_layer, self.end_layer) + # if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock) + # } + + +EntryClass = Qwen3VLMoeForConditionalGeneration diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index facddfea5..ec5e574f4 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -12,6 +12,8 @@ from torchvision.transforms import InterpolationMode from sglang.srt.layers.rotary_embedding import MRotaryEmbedding from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration +from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration +from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) @@ -209,7 +211,12 @@ async def preprocess_video( # Compatible with Qwen2VL and Qwen2_5VL class Qwen2_5VLImageProcessor(SGLangBaseProcessor): - models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration] + models = [ + Qwen2VLForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, + Qwen3VLForConditionalGeneration, + Qwen3VLMoeForConditionalGeneration, + ] def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs)
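Note: the expert-parallel split in load_fused_expert_weights above assigns each rank a contiguous block of num_experts // ep_size experts, with the last rank also taking any remainder. A minimal standalone sketch of that slicing (illustrative names only, not part of the diff):

    def expert_range(num_experts: int, ep_rank: int, ep_size: int) -> range:
        # Contiguous per-rank slice; the last rank absorbs the remainder
        # when num_experts is not divisible by ep_size.
        experts_per_ep = num_experts // ep_size
        start = ep_rank * experts_per_ep
        end = num_experts if ep_rank == ep_size - 1 else start + experts_per_ep
        return range(start, end)

    # Example: 128 experts over ep_size=8 -> rank 3 loads experts 48..63,
    # matching the (start_expert, end_expert) computed in load_fused_expert_weights.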