# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable

import regex as re
import torch
import torch.nn as nn

from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3 import Qwen3Model
from vllm.model_executor.models.utils import WeightsMapper

WeightItem = tuple[str, torch.Tensor]

_LAYER_RE = re.compile(r"^layers\.(\d+)\.(.+)$")


class VoyageQwen3BidirectionalEmbedModel(Qwen3Model):
    """Qwen3Model + Voyage embedding head + bidirectional attention.

    Checkpoint conventions (HF):
    - MLP: gate_proj + up_proj (unfused)
    - Attn: q_proj + k_proj + v_proj (unfused)
    - Linear head: linear.weight
    - Weights prefixed with "model." (e.g., model.layers.0...)

    vLLM Qwen3Model expects:
    - mlp.gate_up_proj (fused)
    - self_attn.qkv_proj (fused)
    - No "model." prefix

    We remap and fuse weights with a generator pipeline and load them
    directly, bypassing the parent's stacked_params_mapping, which would
    otherwise transform already-fused names a second time (e.g.,
    qkv_proj -> qkqkv_proj).
    """

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Embedding head (hidden_size -> num_labels, bias=False).
        self.linear = nn.Linear(
            self.config.hidden_size,
            self.config.num_labels,
            bias=False,
        )

    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs)
        return self.linear(out)

    def _fuse_qkv_proj(
        self, weights: Iterable[WeightItem]
    ) -> Iterable[WeightItem]:
        """Fuse q_proj, k_proj, and v_proj into qkv_proj."""
        qkv_buf: dict[int, dict[str, torch.Tensor]] = defaultdict(dict)
        qkv_suffixes = {
            "self_attn.q_proj.weight": "q",
            "self_attn.k_proj.weight": "k",
            "self_attn.v_proj.weight": "v",
        }
        for name, tensor in weights:
            m = _LAYER_RE.match(name)
            if m and m.group(2) in qkv_suffixes:
                # Buffer Q/K/V shards per layer; fuse once all have arrived.
                layer_idx = int(m.group(1))
                qkv_buf[layer_idx][qkv_suffixes[m.group(2)]] = tensor
            else:
                yield name, tensor
        # Yield fused QKV weights.
        for layer_idx in sorted(qkv_buf.keys()):
            parts = qkv_buf[layer_idx]
            if all(p in parts for p in ("q", "k", "v")):
                # Concatenate along the output dim to match the fused layout.
                fused = torch.cat([parts["q"], parts["k"], parts["v"]], dim=0)
                yield f"layers.{layer_idx}.self_attn.qkv_proj.weight", fused
            elif parts:
                missing = [p for p in ("q", "k", "v") if p not in parts]
                raise ValueError(
                    f"Layer {layer_idx} missing QKV parts: {missing}")

    def _fuse_gate_up_proj(
        self, weights: Iterable[WeightItem]
    ) -> Iterable[WeightItem]:
        """Fuse gate_proj and up_proj into gate_up_proj."""
        mlp_buf: dict[int, dict[str, torch.Tensor]] = defaultdict(dict)
        mlp_suffixes = {
            "mlp.gate_proj.weight": "gate",
            "mlp.up_proj.weight": "up",
        }
        for name, tensor in weights:
            m = _LAYER_RE.match(name)
            if m and m.group(2) in mlp_suffixes:
                layer_idx = int(m.group(1))
                mlp_buf[layer_idx][mlp_suffixes[m.group(2)]] = tensor
            else:
                yield name, tensor
        # Yield fused gate_up weights.
        for layer_idx in sorted(mlp_buf.keys()):
            parts = mlp_buf[layer_idx]
            if all(p in parts for p in ("gate", "up")):
                fused = torch.cat([parts["gate"], parts["up"]], dim=0)
                yield f"layers.{layer_idx}.mlp.gate_up_proj.weight", fused
            elif parts:
                missing = [p for p in ("gate", "up") if p not in parts]
                raise ValueError(
                    f"Layer {layer_idx} missing MLP parts: {missing}")
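    # Net effect of the load pipeline below (illustrative, layer 0):
    #
    #   model.layers.0.self_attn.q_proj.weight --+
    #   model.layers.0.self_attn.k_proj.weight --+--> layers.0.self_attn.qkv_proj.weight
    #   model.layers.0.self_attn.v_proj.weight --+
    #
    #   model.layers.0.mlp.gate_proj.weight --+
    #   model.layers.0.mlp.up_proj.weight   --+--> layers.0.mlp.gate_up_proj.weight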
    def load_weights(self, weights: Iterable[WeightItem]) -> set[str]:
        """Remap, fuse, and load weights using a generator pipeline."""
        # Chain weight transformations lazily: strip the "model." prefix,
        # then fuse QKV and gate/up projections.
        weights = self.hf_to_vllm_mapper.apply(weights)
        weights = self._fuse_qkv_proj(weights)
        weights = self._fuse_gate_up_proj(weights)
        # Load weights directly into model parameters, bypassing the
        # parent's stacked_params_mapping (see class docstring).
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if name not in params_dict:
                # Skip checkpoint tensors with no matching parameter.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
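
if __name__ == "__main__":
    # Minimal smoke test for the fusing generators; a sketch only. The
    # shapes below are illustrative dummies, not tied to any real
    # checkpoint, and `None` stands in for `self`, which neither
    # generator actually uses.
    dummy_weights = [
        ("layers.0.self_attn.q_proj.weight", torch.zeros(8, 4)),
        ("layers.0.self_attn.k_proj.weight", torch.zeros(2, 4)),
        ("layers.0.self_attn.v_proj.weight", torch.zeros(2, 4)),
        ("layers.0.mlp.gate_proj.weight", torch.zeros(16, 4)),
        ("layers.0.mlp.up_proj.weight", torch.zeros(16, 4)),
        ("norm.weight", torch.zeros(4)),  # passes through untouched
    ]
    cls = VoyageQwen3BidirectionalEmbedModel
    fused = dict(
        cls._fuse_gate_up_proj(None, cls._fuse_qkv_proj(None, dummy_weights))
    )
    assert fused["layers.0.self_attn.qkv_proj.weight"].shape == (12, 4)
    assert fused["layers.0.mlp.gate_up_proj.weight"].shape == (32, 4)
    assert fused["norm.weight"].shape == (4,)
    print("fuse pipeline smoke test passed")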