################################################################################ # Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ################################################################################ # SPDX-License-Identifier: Apache-2.0 # Adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py # Copyright 2025 The vLLM team. # Copyright 2025 The Qwen Team. # Copyright 2025 The HuggingFace Inc. team. # All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3-VL MOE model compatible with HuggingFace weights.""" import typing from collections.abc import Iterable from typing import Callable, Optional, Union import torch from vllm.distributed import get_pp_group from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.qwen3_vl_moe import Qwen3MoeLLMModel from vllm.model_executor.models.utils import is_pp_missing_parameter from vllm.sequence import IntermediateTensors def Qwen3MoeLLMModel_forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, deepstack_input_embeds: Optional[IntermediateTensors] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] hidden_states = hidden_states.unsqueeze(0) residual = residual.unsqueeze(0) if residual is not None else None for layer_idx, layer in enumerate( self.layers[self.start_layer:self.end_layer]): layer_idx = layer_idx + self.start_layer hidden_states, residual = layer( positions, hidden_states, residual, ) if deepstack_input_embeds is not None and \ layer_idx in range(0, len(deepstack_input_embeds)): hidden_states = hidden_states + deepstack_input_embeds[ f"deepstack_input_embeds_{layer_idx}"].to( hidden_states.device).unsqueeze(0) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states.unsqueeze(0), "residual": residual.unsqueeze(0) if residual is not None else None }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states.squeeze(0) Qwen3MoeLLMModel.forward = Qwen3MoeLLMModel_forward def Qwen3MoeLLMModel_load_weights( self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] # Skip loading extra parameters for GPTQ/modelopt models. ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", ".v_scale", "_v_scale", ".weight_scale", "_weight_scale", ".input_scale", "_input_scale") params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() is_fused_expert = False fused_expert_params_mapping = [ ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), ("experts.w2_weight", "experts.down_proj", 0, "w2"), ] num_experts = self.config.num_experts for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: if ("experts.gate_up_proj" in name or "experts.down_proj" in name): is_fused_expert = True expert_params_mapping = fused_expert_params_mapping # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue # We have mlp.experts[0].gate_proj in the checkpoint. # Since we handle the experts below in expert_params_mapping, # we need to skip here BEFORE we update the name, otherwise # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith(ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue if name.endswith("scale"): # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue if name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: weight_loader(param, loaded_weight, shard_id) break else: is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True name_mapped = name.replace(weight_name, param_name) if is_pp_missing_parameter(name_mapped, self): continue if is_fused_expert: loaded_weight = loaded_weight.transpose(-1, -2) # no bias if "experts.gate_up_proj" in name: loaded_weight = loaded_weight.chunk(2, dim=-2) success_w1 = self.load_fused_expert_weights( name_mapped, params_dict, loaded_weight[0], "w1", num_experts) success_w3 = self.load_fused_expert_weights( name_mapped, params_dict, loaded_weight[1], "w3", num_experts) success = success_w1 and success_w3 else: # down_proj success = self.load_fused_expert_weights( name_mapped, params_dict, loaded_weight, shard_id, num_experts) else: # Skip loading extra parameters for GPTQ/modelopt models if name_mapped.endswith( ignore_suffixes ) and name_mapped not in params_dict: continue param = params_dict[name_mapped] # We should ask the weight loader to return success or # not here since otherwise we may skip experts with # other available replicas. weight_loader = typing.cast(Callable[..., bool], param.weight_loader) success = weight_loader(param, loaded_weight, name_mapped, shard_id=shard_id, expert_id=expert_id, return_success=True) if success: name = name_mapped break else: if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank # So we simply skip it continue # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith(ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") if remapped_kv_scale_name not in params_dict: # logger.warning_once( # "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 # name, # remapped_kv_scale_name, # ) continue else: name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) if name == 'patch_embed.proj.weight': loaded_weight = loaded_weight.reshape( loaded_weight.shape[0], -1).contiguous() weight_loader(param, loaded_weight) if name.find("norm.weight") != -1: param.data = param.data.to(torch.float32) loaded_params.add(name) return loaded_params Qwen3MoeLLMModel.load_weights = Qwen3MoeLLMModel_load_weights