Upgrade to vLLM 0.17.0 Corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -112,6 +112,42 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
model_config.pooler_config.seq_pooling_type = pooling_type
class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
"""Config handler for LlamaNemotronVL embedding models."""
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
from vllm.config.pooler import SequencePoolingType
hf_config = model_config.hf_config
# Set bidirectional attention on the language model config
hf_config.is_causal = False
if hasattr(hf_config, "llm_config"):
hf_config.llm_config.is_causal = False
if hasattr(hf_config, "vision_config"):
hf_config.patch_size = hf_config.vision_config.patch_size
# Set up pooling type
pooling_type_map: dict[str, SequencePoolingType] = {
"avg": "MEAN",
"cls": "CLS",
"last": "LAST",
}
# Get pooling type from config (check both top-level and llm_config)
pooling = getattr(hf_config, "pooling", None)
if pooling is None and hasattr(hf_config, "llm_config"):
pooling = getattr(hf_config.llm_config, "pooling", "avg")
pooling_type = pooling_type_map.get(pooling)
if pooling_type is None:
raise ValueError(f"pool_type {pooling!r} not supported")
model_config.pooler_config.seq_pooling_type = pooling_type
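# For example (hypothetical checkpoint): an hf_config with no top-level
# "pooling" but hf_config.llm_config.pooling == "last" resolves to the
# "LAST" SequencePoolingType; if llm_config exists but defines no pooling,
# the lookup falls back to "avg" and therefore "MEAN".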
class NomicBertModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
@@ -177,7 +213,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
max_model_len_before,
model_config.max_model_len,
)
@@ -293,6 +329,14 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
}
class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
# Ernie4.5-VL conditionally executes text/vision MoE branches, so
# fast_moe_cold_start can silently produce incorrect execution order.
vllm_config.compilation_config.fast_moe_cold_start = False
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -553,7 +597,7 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
if cache_config.cache_dtype.startswith("fp8"):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":
if cache_config.cache_dtype == "auto" or cache_config.cache_dtype == "bfloat16":
cache_config.cache_dtype = "auto"
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
@@ -619,11 +663,14 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Gemma3TextModel": Gemma3TextModelConfig,
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
"LlamaBidirectionalModel": LlamaBidirectionalConfig,
"LlamaNemotronVLModel": LlamaNemotronVLConfig,
"LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
"NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
"Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"XLMRobertaModel": JinaRobertaModelConfig,
"ColBERTJinaRobertaModel": JinaRobertaModelConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,

View File

@@ -18,6 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -263,9 +264,13 @@ class Block(nn.Module):
return x
class RelPosAttention(nn.Module):
# --8<-- [start:rel_pos_attention]
@PluggableLayer.register("rel_pos_attention")
class RelPosAttention(PluggableLayer):
"""Multi-head Attention block with relative position embeddings."""
# --8<-- [end:rel_pos_attention]
def __init__(
self,
dim: int,

View File

@@ -32,6 +32,7 @@ from .deepseek_v2 import (
DeepseekV2MoE,
get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP
from .utils import maybe_prefix
logger = init_logger(__name__)
@@ -180,7 +181,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
@support_torch_compile
class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
@@ -415,8 +416,141 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
# Validate that weights were loaded for each expected MTP layer.
loaded_layers: set[int] = set()
for param_name in loaded_params:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
if spec_layer is not None:
loaded_layers.add(spec_layer)
for layer_idx in range(
self.model.mtp_start_layer_idx,
self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
):
if layer_idx not in loaded_layers:
raise ValueError(
f"MTP speculative decoding layer {layer_idx} weights "
f"missing from checkpoint. The checkpoint may have "
f"been quantized without including the MTP layers. "
f"Use a checkpoint that includes MTP layer weights, "
f"or disable speculative decoding."
)
# Post-load optimization: fuse q_a_proj and kv_a_proj_with_mqa
# into a single GEMM, then monkey-patch forward to forward_opt.
# Same logic as DeepseekV2ForCausalLM.load_weights.
opt_support_quant_method = [
"GGUFLinearMethod", "UnquantizedLinearMethod",
"CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod",
]
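# Why concatenating works: nn.Linear stores weights as [out_features,
# in_features], so stacking two weight matrices along dim 0 turns two
# GEMMs over the same activations into one.  Rough sketch with made-up
# shapes (hidden=8, q_lora_rank=4, kv_lora_rank + rope_dim=3):
#   x     = torch.randn(2, 8)
#   w_q   = torch.randn(4, 8)          # q_a_proj.weight
#   w_kv  = torch.randn(3, 8)          # kv_a_proj_with_mqa.weight
#   fused = torch.cat([w_q, w_kv], 0)  # [7, 8], one GEMM
#   torch.allclose(x @ fused.T,
#                  torch.cat([x @ w_q.T, x @ w_kv.T], dim=-1))  # True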
def inject_layer(layer, quant_method, is_mla):
logger.info(
"DeepSeekMTP optimization: fusing q_a_proj and kv_a_proj_with_mqa "
"for layer '%s' (quant_method=%s, is_mla=%s); forward will be "
"replaced with forward_opt where supported.",
layer.__class__.__name__, quant_method, is_mla)
q_lora_rank = getattr(layer, "q_lora_rank", None)
if quant_method in ["UnquantizedLinearMethod",
"CompressedTensorsW8A8Int8"]:
if q_lora_rank is not None:
layer.q_a_proj.weight.data = torch.cat(
[layer.q_a_proj.weight,
layer.kv_a_proj_with_mqa.weight], dim=0)
if hasattr(layer.q_a_proj, "weight_scale"):
layer.q_a_proj.weight_scale.data = torch.cat(
[layer.q_a_proj.weight_scale,
layer.kv_a_proj_with_mqa.weight_scale], dim=0)
del layer.kv_a_proj_with_mqa.weight_scale
elif not is_mla:
layer.q_proj.weight.data = torch.cat(
[layer.q_proj.weight,
layer.kv_a_proj_with_mqa.weight], dim=0)
if hasattr(layer.q_proj, "weight_scale"):
layer.q_proj.weight_scale.data = torch.cat(
[layer.q_proj.weight_scale,
layer.kv_a_proj_with_mqa.weight_scale], dim=0)
del layer.kv_a_proj_with_mqa.weight_scale
else:
return
del layer.kv_a_proj_with_mqa.weight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
elif quant_method == "GGUFLinearMethod":
pass
elif quant_method == "AWQMarlinLinearMethod":
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
assert dtype == torch.int32
if q_lora_rank is not None:
layer.q_a_proj.qweight.data = torch.cat(
[layer.q_a_proj.qweight,
layer.kv_a_proj_with_mqa.qweight], dim=1)
layer.q_a_proj.scales.data = torch.cat(
[layer.q_a_proj.scales,
layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_a_proj.qzeros.data = torch.cat(
[layer.q_a_proj.qzeros,
layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
elif not is_mla:
layer.q_proj.weight.data = torch.cat(
[layer.q_proj.weight,
layer.kv_a_proj_with_mqa.weight], dim=1)
layer.q_proj.scales.data = torch.cat(
[layer.q_proj.scales,
layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_proj.qzeros.data = torch.cat(
[layer.q_proj.qzeros,
layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
else:
return
del layer.kv_a_proj_with_mqa.qweight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
for _, layer in self.model.named_modules():
if layer.__class__.__name__ in [
"DeepseekV2Attention", "DeepseekV2MLAAttention"
]:
if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
quant_method = (
layer.kv_a_proj_with_mqa.scheme.__class__.__name__)
else:
quant_method = (
layer.kv_a_proj_with_mqa
.quant_method.__class__.__name__)
if quant_method not in opt_support_quant_method:
break
inject_layer(
layer, quant_method,
is_mla=(layer.__class__.__name__
== "DeepseekV2MLAAttention"))
# Check if all parameters have been loaded
all_params = set(params_dict.keys())
not_loaded = all_params - loaded_params
if not_loaded:
logger.warning(
"DeepSeekMTP weight loading: %d parameters were NOT loaded.\n%s",
len(not_loaded),
"\n".join(sorted(not_loaded)),
)
else:
logger.info(
"DeepSeekMTP weight loading: All %d parameters loaded successfully.",
len(all_params),
)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.

View File

@@ -47,7 +47,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -75,7 +75,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
)
@@ -89,6 +91,7 @@ from .utils import (
make_layers,
maybe_prefix,
)
import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
@@ -221,73 +224,6 @@ class DeepseekV2MLP(nn.Module):
return x
class DeepSeekV2Gate(ReplicatedLinear):
def __init__(
self,
hidden_size: int,
n_experts: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
assert quant_config is None
super().__init__(
hidden_size,
n_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate",
)
# Unquantized only, will be called "weight".
assert hasattr(self, "weight")
is_hopper_or_blackwell = current_platform.is_device_capability(
(9, 0)
) or current_platform.is_device_capability_family(100)
SUPPORTED_NUM_EXPERTS = [256, 384]
SUPPORTED_HIDDEN_SIZES = [7168]
self.allow_dsv3_router_gemm = (
current_platform.is_cuda()
and is_hopper_or_blackwell
and n_experts in SUPPORTED_NUM_EXPERTS
and hidden_size in SUPPORTED_HIDDEN_SIZES
)
self._out_dtype: torch.dtype | None = None
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
"""
Set out dtype for the router logits. This is needed after
__init__, b/c we need to check if the trtllm kernel is
selected before we decide between bf16 and fp32.
"""
if self._out_dtype is not None:
raise ValueError("out_dtype has already been set")
else:
self._out_dtype = out_dtype
@property
def out_dtype(self) -> torch.dtype:
if self._out_dtype is None:
raise ValueError("out_dtype has not been set yet")
return self._out_dtype
def forward(
self,
x: torch.Tensor,
) -> tuple[torch.Tensor, None]:
"""
Use specialized GEMM for low batch size for DSV3 and KIMI.
"""
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
return ops.dsv3_router_gemm(
hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
), None
else:
return super().forward(x)
class DeepseekV2MoE(nn.Module):
def __init__(
self,
@@ -316,23 +252,12 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
# self.gate = DeepSeekV2Gate(
# config.hidden_size,
# config.n_routed_experts,
# quant_config=None,
# prefix=f"{prefix}.gate",
# )
self.gate = ReplicatedLinear(
self.gate = GateLinear(
config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
if getattr(config, "topk_method", None) == "noaux_tc":
# self.gate.e_score_correction_bias = nn.Parameter(
# torch.empty(config.n_routed_experts, dtype=torch.float32)
# )
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts)
)
@@ -401,12 +326,12 @@ class DeepseekV2MoE(nn.Module):
else None,
)
# # NOTE(rob): this is a hack until we finish off the PR for
# # merging TRTLLM kernels into the MK framework. Then we can
# # query the MonolithicMK for the expected router logits.
# self.gate.set_out_dtype(
# torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
# )
# NOTE(rob): this is a hack until we finish off the PR for
# merging TRTLLM kernels into the MK framework. Then we can
# query the MonolithicMK for the expected router logits.
self.gate.set_out_dtype(
torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
@@ -443,11 +368,12 @@ class DeepseekV2MoE(nn.Module):
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
@@ -596,7 +522,7 @@ class DeepseekV2Attention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
@@ -605,23 +531,20 @@ class DeepseekV2Attention(nn.Module):
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
q = self.q_proj(hidden_states)[0]
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
kv_a = self.kv_a_layernorm(kv_a)
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
@@ -671,7 +594,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
def get_attn_backend(self) -> AttentionBackend:
return DeepseekV32IndexerBackend
class Indexer(nn.Module):
def __init__(
@@ -727,8 +650,8 @@ class Indexer(nn.Module):
# where we store value in fp8 and scale in fp32
# per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
dtype=torch.uint8,
head_dim=self.head_dim,
dtype=torch.bfloat16,
prefix=f"{prefix}.k_cache",
cache_config=cache_config,
)
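# Packed fp8 layout above (removed lines): with hypothetical head_dim=128
# and quant_block_size=128, each row held 128 fp8 bytes plus one fp32
# scale (4 bytes), i.e. head_dim + head_dim // quant_block_size * 4 = 132
# uint8 entries per head; the bf16 cache stores the raw values instead.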
@@ -776,23 +699,61 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(
q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt is not None,
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
# q = q.view(-1, self.head_dim)
# q_fp8, q_scale = per_token_group_quant_fp8(
# q,
# self.quant_block_size,
# column_major_scales=False,
# use_ue8m0=self.scale_fmt is not None,
# )
# q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
# q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)
return self.indexer_op(hidden_states, q_fp8, k, weights)
return self.indexer_op(hidden_states, q, k, weights)
def _min_latency_fused_qkv_a_proj_impl(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
"""
Dynamically run min-latency gemm if num_tokens <= 16.
This must be wrapped in a custom op because our torch.compile integration
does not support runtime dispatching on num_tokens.
"""
num_tokens = input_.shape[0]
if 0 < num_tokens <= 16:
output = torch.empty(
num_tokens,
weight.shape[0],
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(output, input_, weight.T)
return output
else:
return torch.nn.functional.linear(input_, weight)
def _min_latency_fused_qkv_a_proj_fake(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return input_.new_empty(input_.shape[0], weight.shape[0])
direct_register_custom_op(
op_name="min_latency_fused_qkv_a_proj",
op_func=_min_latency_fused_qkv_a_proj_impl,
mutates_args=[],
fake_impl=_min_latency_fused_qkv_a_proj_fake,
)
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
@@ -830,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
self,
input_,
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
num_tokens = input_.shape[0]
if self._use_min_latency_gemm and (0 < num_tokens <= 16):
output = torch.empty(
num_tokens,
2112,
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(
output,
input_,
self.weight.T,
)
if self._use_min_latency_gemm:
output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
@@ -898,47 +848,35 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
# self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
# self.hidden_size,
# [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
# quant_config=quant_config,
# prefix=f"{prefix}.fused_qkv_a_proj",
# )
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True,
)
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
)
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
self.q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1005,9 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
kv_b_proj=self.kv_b_proj,
rotary_emb=self.rotary_emb,
o_proj=self.o_proj,
fused_qkv_a_proj=self.fused_qkv_a_proj
if self.q_lora_rank is not None
else None,
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
if self.q_lora_rank is None
else None,
@@ -1346,14 +1282,14 @@ class DeepseekV2ForCausalLM(
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
self.fuse_qkv_a_proj = (
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
)
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
# self.fuse_qkv_a_proj = (
# hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
# )
# if self.fuse_qkv_a_proj:
# self.packed_modules_mapping["fused_qkv_a_proj"] = [
# "q_a_proj",
# "kv_a_proj_with_mqa",
# ]
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
@@ -1385,19 +1321,19 @@ class DeepseekV2ForCausalLM(
self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
@@ -1441,10 +1377,10 @@ class DeepseekV2ForCausalLM(
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
mla_params_mapping = [
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
# mla_params_mapping = [
# ("fused_qkv_a_proj", "q_a_proj", 0),
# ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
# ]
mha_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
@@ -1452,8 +1388,8 @@ class DeepseekV2ForCausalLM(
]
if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
else:
stacked_params_mapping.extend(mla_params_mapping)
# else:
# stacked_params_mapping.extend(mla_params_mapping)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -1474,168 +1410,232 @@ class DeepseekV2ForCausalLM(
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
if is_fusion_moe_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# if go with fusion option, then update name
if (
param_name == "fused_qkv_a_proj"
) and name_mapped not in params_dict:
continue
else:
name = name_mapped
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
try:
if "rotary_emb.inv_freq" in name:
continue
if is_pp_missing_parameter(name, self):
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
# Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices
# (e.g. ...mlp.shared_experts.gate_proj.weight).
# For models with multiple shared experts, split that tensor
# evenly into per-shared-expert slices and load them into
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1
split_dim = (
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, (
f"Shared expert weight dim {total} "
f"not divisible by num_chunks {num_chunks}"
)
chunk_size = total // num_chunks
for j in range(num_chunks):
chunk_name = name
weight_to_load = loaded_weight
is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
if is_fusion_moe_shared_experts_layer:
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
if loaded_weight.ndim == 1:
weight_to_load = loaded_weight[chunk_slice]
elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else:
weight_to_load = loaded_weight[:, chunk_slice]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)
continue
name_mapped = name.replace(weight_name, param_name)
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in chunk_name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# if go with fusion option, then update name
if (
param_name == "fused_qkv_a_proj"
) and name_mapped not in params_dict:
continue
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
name = name_mapped
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
# Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices
# (e.g. ...mlp.shared_experts.gate_proj.weight).
# For models with multiple shared experts, split that tensor
# evenly into per-shared-expert slices and load them into
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1
split_dim = (
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
weight_loader(param, loaded_weight)
if name is not None and not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, (
f"Shared expert weight dim {total} "
f"not divisible by num_chunks {num_chunks}"
)
chunk_size = total // num_chunks
for j in range(num_chunks):
chunk_name = name
weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer:
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
if loaded_weight.ndim == 1:
weight_to_load = loaded_weight[chunk_slice]
elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else:
weight_to_load = loaded_weight[:, chunk_slice]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in chunk_name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if name is not None and not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
except Exception:
logger.warning(
"Skipping weight %r due to an error during loading.",
name,
exc_info=True,
)
opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"]
# Add further supported quant methods here.
def inject_layer(layer, quant_method, is_mla):
q_lora_rank = getattr(layer, "q_lora_rank", None)
if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
if q_lora_rank is not None:
layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
if hasattr(layer.q_a_proj, "weight_scale"):
layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
del layer.kv_a_proj_with_mqa.weight_scale
elif not is_mla:
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
if hasattr(layer.q_proj, "weight_scale"):
layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
del layer.kv_a_proj_with_mqa.weight_scale
else:
return
del layer.kv_a_proj_with_mqa.weight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
elif quant_method == "GGUFLinearMethod":
pass
elif quant_method == "AWQMarlinLinearMethod":
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
assert dtype == torch.int32
if q_lora_rank is not None:
layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
elif not is_mla:
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1)
layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
else:
return
del layer.kv_a_proj_with_mqa.qweight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
else:
pass
for _, layer in self.model.named_modules():
if layer.__class__.__name__ in ["DeepseekV2Attention","DeepseekV2MLAAttention"]:
if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
else:
quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
if quant_method not in opt_support_quant_method:
break
inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "DeepseekV2MLAAttention")
return loaded_params

View File

@@ -164,7 +164,7 @@ class Ernie4_5_MoeMoE(nn.Module):
config.hidden_size,
config.moe_num_experts,
bias=False,
params_dtype=torch.float32,
# params_dtype=torch.float32,
quant_config=None,
prefix=f"{prefix}.gate",
)
@@ -209,7 +209,7 @@ class Ernie4_5_MoeMoE(nn.Module):
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
@@ -429,7 +429,8 @@ class Ernie4_5_MoeModel(nn.Module):
self.num_redundant_experts = eplb_config.num_redundant_experts
if get_pp_group().is_first_rank:
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@@ -653,11 +654,11 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
else:
self.lm_head = PPMissingLayer()
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors

View File

@@ -829,16 +829,31 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
if self.ctx.model_config.trust_remote_code:
# Defined in HF Hub repo
min_pixels_key = "min_pixels"
max_pixels_key = "max_pixels"
else:
# Defined in Transformers library (requires v5.0 or above)
min_pixels_key = "shortest_edge"
max_pixels_key = "longest_edge"
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
size = image_processor.size
if override_size := mm_kwargs.get("size"):
size = size | override_size
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
size = size | {min_pixels_key: override_min_pixels}
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
size = size | {max_pixels_key: override_max_pixels}
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * spatial_conv_size,
min_pixels=size["min_pixels"],
max_pixels=size["max_pixels"],
min_pixels=size[min_pixels_key],
max_pixels=size[max_pixels_key],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:

View File

@@ -0,0 +1,394 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Hidden States Extractor Model.
This model extracts and caches hidden states from the target model
without performing actual token generation. It's used with the
extract_hidden_states speculative decoding method.
"""
from collections.abc import Iterable
from typing import ClassVar
import torch
import torch.nn as nn
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention.attention import set_default_quant_scales
from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.utils import maybe_prefix
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
)
########## Custom Ops ########
def unified_kv_cache_update(
to_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
"""
Update the KV cache for `layer_name` and return a dummy tensor. The dummy
is passed to the attention op to make the side effect and the data dependency
between the two ops explicit, so torch.compile preserves their ordering.
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
)
attn_layer.impl.do_kv_cache_update(
attn_layer,
to_cache,
kv_cache,
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
@maybe_transfer_kv_layer
def dummy_attention(layer_name, _placeholder):
# Note: layer_name arg required by @maybe_transfer_kv_layer
return _placeholder
def basic_cache(
to_cache: torch.Tensor, # shape: [num_tokens, num_heads, head_size]
kv_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size]
slot_mapping: torch.Tensor, # shape: [num_tokens]
):
num_blocks, block_size, num_heads, head_size = kv_cache.shape
token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size)
token_kv_cache[slot_mapping] = to_cache
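# Example: with num_blocks=2 and block_size=4 the cache is viewed as 8
# token rows; slot_mapping = tensor([5, 2]) writes token 0 into block 1,
# offset 1 and token 1 into block 0, offset 2
# (slot = block * block_size + offset).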
######### CacheOnlyAttentionBackend ########
class CacheOnlyAttentionBackend(AttentionBackend):
"""Attention backend that only caches KV without computing attention."""
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
forward_includes_kv_cache_update: bool = False
@staticmethod
def get_name() -> str:
return "CACHE_ONLY_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
return attn_type == AttentionType.DECODER
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@staticmethod
def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
return CacheOnlyAttentionImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
# Callers set `num_kv_heads` to the number of cached hidden states per
# token and `head_size` to the model hidden size.
# We also don't use a separate k/v (2) dim.
return (num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
return CacheOnlyAttentionMetadataBuilder
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
class CacheOnlyAttentionMetadata:
def __init__(self, slot_mapping: torch.Tensor):
self.slot_mapping = slot_mapping
class CacheOnlyAttentionMetadataBuilder(
AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> CacheOnlyAttentionMetadata:
use_cascade = common_prefix_len > 0
if use_cascade:
raise NotImplementedError(
"Cascade attention not supported by CacheOnlyAttention"
)
causal = common_attn_metadata.causal
if not causal:
raise NotImplementedError(
"Non-causal attention not supported by CacheOnlyAttention"
)
return CacheOnlyAttentionMetadata(
slot_mapping=common_attn_metadata.slot_mapping,
)
class CacheOnlyAttentionImpl(AttentionImpl):
"""Attention implementation that only caches KV states."""
def __init__(
self,
num_heads: int,
head_size: int,
kv_cache_dtype: str,
kv_cache_torch_dtype: torch.dtype,
attn_type: AttentionType = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.kv_cache_dtype = kv_cache_dtype
self.kv_cache_torch_dtype = kv_cache_torch_dtype
if attn_type != AttentionType.DECODER:
raise NotImplementedError(f"Unsupported attention type: {attn_type}")
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("Quantized KV cache not supported")
self.num_queries_per_kv = 1
def do_kv_cache_update(
self,
layer,
to_cache,
kv_cache,
slot_mapping,
):
assert to_cache.dtype == self.kv_cache_torch_dtype, (
f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}"
)
assert kv_cache.dtype == self.kv_cache_torch_dtype, (
f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}"
)
basic_cache(to_cache, kv_cache, slot_mapping)
def forward(self, *args, **kwargs):
# Empty implementation of abstract method
pass
############## CacheOnlyAttentionLayer (replaces Attention) ############
class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
"""Attention layer that only caches key/value states without computing attention."""
def __init__(
self,
num_heads: int,
head_size: int,
cache_config: CacheConfig | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
):
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.layer_name = prefix
vllm_config = get_current_vllm_config()
# KV cache configuration
cache_config = cache_config or vllm_config.cache_config
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
self.block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
self.block_size = 16
assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
"CacheOnlyAttentionLayer doesn't currently support quantized kv cache but"
f"kv cache dtype was set to {kv_cache_dtype}"
)
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
# Initialize KV cache quantization attributes
set_default_quant_scales(self, register_buffer=True)
# Attention backend
self.attn_backend = CacheOnlyAttentionBackend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
num_heads,
head_size,
kv_cache_dtype,
self.kv_cache_torch_dtype,
attn_type,
)
assert not self.attn_backend.forward_includes_kv_cache_update, (
"KV cache update should be independent of forward"
)
# Placeholder KV cache (replaced by bind_kv_cache)
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Register in compilation context
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
"""Cache hidden states as KV pairs without computing attention.
Args:
to_cache: The tensor to insert into the kv cache.
shape [num_tokens, num_heads, head_size]
Returns:
Dummy output tensor (not used)
"""
# Note: we set num_heads to num_hidden_states and
# head_size to hidden_size for hidden states storage
output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)
# Note: dummy_out is used to force torch.compile to preserve ordering between
# cache update and attention op (which triggers kv_connector transfer)
dummy_out = unified_kv_cache_update(to_cache, self.layer_name)
# Triggers kv_connector transfer via decorator
_ = dummy_attention(self.layer_name, dummy_out)
return output
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Note: we use MLAAttentionSpec here because it will
# produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
# whereas FullAttentionSpec will add an additional factor of 2
return MLAAttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)
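# Rough numbers, assuming block_size=16, 3 aux hidden states and
# hidden_size=4096 with a bf16 cache: page size = 16 * 3 * 4096 * 2 bytes
# = 384 KiB per block; FullAttentionSpec would double this by keeping
# separate K and V planes.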
############ ExtractHiddenStatesModel definition ##########
class ExtractHiddenStatesModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
self.hidden_size = vllm_config.model_config.get_hidden_size()
self.target_num_hidden_layers = (
vllm_config.model_config.get_total_num_hidden_layers()
)
self.num_hidden_states = len(
getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
)
cache_config = vllm_config.cache_config
# Create a single cache-only attention layer
# Note: We set num_heads <- self.num_hidden_states
# and head_size <- hidden_size so that we can insert
# the hidden states directly into the cache without
# reshaping
self.cache_only_layers = nn.ModuleDict(
{
str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
num_heads=self.num_hidden_states,
head_size=self.hidden_size,
cache_config=cache_config,
prefix=maybe_prefix(
prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
),
)
}
)
def forward(self, hidden_states: torch.Tensor) -> None:
"""Process and cache hidden states.
Args:
hidden_states: Hidden states from target model
shape: [num_tokens, num_hidden_states, hidden_size]
Returns:
None; the hidden states are written into the KV cache as a side effect.
"""
# Call dummy attention layer to cache hidden states
# Output is ignored - we only care about the KV cache side effects
_ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""No weights to load for this dummy model."""
return set()

View File

@@ -0,0 +1,829 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
BatchFeature,
Qwen2Config,
)
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.linear import (
ReplicatedLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.models.whisper_utils import (
ISO639_1_SUPPORTED_LANGS,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.fireredasr2_processor import (
FireRedASR2FeatureExtractor,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsTranscription,
_require_is_multimodal,
)
from .qwen2 import Qwen2ForCausalLM
from .utils import (
AutoWeightsLoader,
WeightsMapper,
_merge_multimodal_embeddings,
maybe_prefix,
)
logger = init_logger(__name__)
class FireRedASR2AudioInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- nmb: Number of mel bins
- t: Time frames (M)
"""
input_features: Annotated[
list[torch.Tensor] | None,
TensorShape("b", "nmb", "t"),
]
speech_lengths: Annotated[
list[torch.Tensor] | None,
TensorShape("b"),
]
fake_token_lengths: Annotated[
list[torch.Tensor] | None,
TensorShape("b"),
]
class Swish(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class Conv2dSubsampling(nn.Module):
def __init__(self, idim: int, d_model: int, out_channels: int = 32):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, out_channels, 3, 2),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 2),
nn.ReLU(),
)
subsample_idim = ((idim - 1) // 2 - 1) // 2
self.out = ReplicatedLinear(
input_size=out_channels * subsample_idim,
output_size=d_model,
bias=True,
)
self.subsampling = 4
left_context = right_context = 3 # both exclude the current frame
self.context = left_context + 1 + right_context # 7
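# Worked example with idim=80 mel bins: each 3x2 conv roughly halves both
# axes, so subsample_idim = ((80 - 1) // 2 - 1) // 2 = 19 and the linear
# projection sees out_channels * 19 = 608 features per frame; the time
# axis shrinks by the same factor of 4.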
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = x.unsqueeze(1)
x = self.conv(x)
N, C, T, D = x.size()
x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
mask = x_mask[:, :, :-2:2][:, :, :-2:2]
input_lengths = mask[:, -1, :].sum(dim=-1)
return x, input_lengths, mask
class RelPositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, d_model, 2).float()
* -(torch.log(torch.tensor(10000.0)).item() / d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
self.pe = torch.cat([pe_positive, pe_negative], dim=1)
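# Layout of self.pe along dim 1: index 0 holds relative position
# max_len - 1, index max_len - 1 holds position 0, and the remaining
# max_len - 1 entries hold positions -1 .. -(max_len - 1). forward()
# slices the 2*T - 1 positions centred on 0 for a length-T input.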
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Tmax = 2 * max_len - 1
Tmax, T = self.pe.size(1), x.size(1)
pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
return pos_emb
class ConformerFeedForward(nn.Module):
def __init__(self, d_model: int):
super().__init__()
self.pre_layer_norm = nn.LayerNorm(d_model)
self.linear_expand = ReplicatedLinear(
input_size=d_model,
output_size=d_model * 4,
bias=True,
)
self.nonlinear = Swish()
self.linear_project = ReplicatedLinear(
input_size=d_model * 4,
output_size=d_model,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.pre_layer_norm(x)
x, _ = self.linear_expand(x)
x = self.nonlinear(x)
x, _ = self.linear_project(x)
output = x + residual
return output
class EncoderMultiHeadAttention(nn.Module):
def __init__(self, n_head: int, d_model: int):
super().__init__()
assert d_model % n_head == 0
self.n_head = n_head
self.d_k = d_model // n_head
self.d_v = self.d_k
self.w_qs = ReplicatedLinear(
input_size=d_model, output_size=n_head * self.d_k, bias=False
)
self.w_ks = ReplicatedLinear(
input_size=d_model, output_size=n_head * self.d_k, bias=False
)
self.w_vs = ReplicatedLinear(
input_size=d_model, output_size=n_head * self.d_v, bias=False
)
self.layer_norm_q = nn.LayerNorm(d_model)
self.layer_norm_k = nn.LayerNorm(d_model)
self.layer_norm_v = nn.LayerNorm(d_model)
self.fc = ReplicatedLinear(
input_size=n_head * self.d_v, output_size=d_model, bias=False
)
def forward_qkv(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
q = self.layer_norm_q(q)
k = self.layer_norm_k(k)
v = self.layer_norm_v(v)
q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
return q, k, v
def forward_output(
self, output: torch.Tensor, residual: torch.Tensor, sz_b: int, len_q: int
) -> torch.Tensor:
output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
fc_out, _ = self.fc(output)
output = fc_out
output = output + residual
return output
def forward_attention(
self, attn: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if mask is not None:
mask = mask.unsqueeze(1)
mask = mask.eq(0)
attn = attn.masked_fill(mask, -float("inf"))
attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
else:
attn = torch.softmax(attn, dim=-1)
d_attn = attn
output = torch.matmul(d_attn, v)
return output, attn
class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
def __init__(self, n_head: int, d_model: int):
super().__init__(n_head, d_model)
d_k = d_model // n_head
self.scale = 1.0 / (d_k**0.5)
self.linear_pos = ReplicatedLinear(
input_size=d_model, output_size=n_head * d_k, bias=False
)
self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k]))
self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k]))
def _rel_shift(self, x):
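        # Transformer-XL style relative shift: zero-pad, reshape, and trim so
        # that the (2*T - 1) relative-position scores align with each query's
        # absolute key positions.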
N, H, T1, T2 = x.size()
zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(N, H, T2 + 1, T1)
x = x_padded[:, :, 1:].view_as(x)
x = x[:, :, :, : x.size(-1) // 2 + 1]
return x
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pos_emb: torch.Tensor,
mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
sz_b, len_q = q.size(0), q.size(1)
residual = q
q, k, v = self.forward_qkv(q, k, v)
q = q.transpose(1, 2)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k)
p = p.transpose(1, 2)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
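        # Transformer-XL decomposition: matrix_ac is the content term
        # (query + bias u against keys), matrix_bd is the position term
        # (query + bias v against the relative positional projection).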
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self._rel_shift(matrix_bd)
attn_scores = matrix_ac + matrix_bd
attn_scores.mul_(self.scale)
output, attn = self.forward_attention(attn_scores, v, mask=mask)
output = self.forward_output(output, residual, sz_b, len_q)
return output, attn
class ConformerConvolution(nn.Module):
def __init__(self, d_model: int, kernel_size: int = 33):
super().__init__()
assert kernel_size % 2 == 1
self.pre_layer_norm = nn.LayerNorm(d_model)
self.pointwise_conv1 = nn.Conv1d(
d_model, d_model * 4, kernel_size=1, bias=False
)
self.padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
d_model * 2,
d_model * 2,
kernel_size,
stride=1,
padding=self.padding,
groups=d_model * 2,
bias=False,
)
self.batch_norm = nn.LayerNorm(d_model * 2)
self.swish = Swish()
self.pointwise_conv2 = nn.Conv1d(
d_model * 2, d_model, kernel_size=1, bias=False
)
def forward(
self, x: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:
residual = x
out = self.pre_layer_norm(x)
out = out.transpose(1, 2)
if mask is not None:
out.masked_fill_(mask.ne(1), 0.0)
out = self.pointwise_conv1(out)
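        # GLU halves the channel dim from 4*d_model down to 2*d_model,
        # matching the depthwise conv's channels and groups.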
out = F.glu(out, dim=1)
out = self.depthwise_conv(out)
out = out.transpose(1, 2)
out = self.swish(self.batch_norm(out))
out = out.transpose(1, 2)
out = self.pointwise_conv2(out)
if mask is not None:
out.masked_fill_(mask.ne(1), 0.0)
out = out.transpose(1, 2)
return out + residual
class RelPosEmbConformerBlock(nn.Module):
def __init__(self, d_model, n_head, kernel_size=33):
super().__init__()
self.ffn1 = ConformerFeedForward(d_model)
self.mhsa = RelPosMultiHeadAttention(n_head, d_model)
self.conv = ConformerConvolution(d_model, kernel_size)
self.ffn2 = ConformerFeedForward(d_model)
self.layer_norm = nn.LayerNorm(d_model)
def forward(
self,
x: torch.Tensor,
pos_emb: torch.Tensor,
slf_attn_mask: torch.Tensor | None = None,
pad_mask: torch.Tensor | None = None,
) -> torch.Tensor:
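        # Conformer block: half-weighted feed-forward, rel-pos MHSA, conv
        # module, second half-weighted feed-forward, then a final LayerNorm.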
out = 0.5 * x + 0.5 * self.ffn1(x)
out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
out = self.conv(out, pad_mask)
out = 0.5 * out + 0.5 * self.ffn2(out)
out = self.layer_norm(out)
return out
class ConformerEncoder(nn.Module):
def __init__(
self,
idim: int,
n_layers_enc: int,
n_head: int,
d_model: int,
kernel_size: int = 33,
pe_maxlen: int = 5000,
):
super().__init__()
self.odim = d_model
self.input_preprocessor = Conv2dSubsampling(idim, d_model)
self.positional_encoding = RelPositionalEncoding(d_model)
self.layer_stack = nn.ModuleList()
for _ in range(n_layers_enc):
block = RelPosEmbConformerBlock(d_model, n_head, kernel_size)
self.layer_stack.append(block)
def forward(
self, padded_input: torch.Tensor, input_lengths: torch.Tensor, pad: bool = True
):
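        # Right-pad the time axis so the Conv2dSubsampling frontend has enough
        # frames for its full context window at the end of the sequence.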
if pad:
padded_input = F.pad(
padded_input,
(0, 0, 0, self.input_preprocessor.context - 1),
"constant",
0.0,
)
src_mask = self.padding_position_is_0(padded_input, input_lengths)
embed_output, input_lengths, src_mask = self.input_preprocessor(
padded_input, src_mask
)
enc_output = embed_output
pos_emb = self.positional_encoding(embed_output)
enc_outputs = []
for enc_layer in self.layer_stack:
enc_output = enc_layer(
enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask
)
enc_outputs.append(enc_output)
return enc_output, input_lengths, src_mask
def padding_position_is_0(
self, padded_input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
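        # Build a (N, 1, T) uint8 mask with 1 for valid frames and 0 for
        # padding, matching the convention used by the attention/conv masks.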
N, T = padded_input.size()[:2]
mask = torch.ones((N, T)).to(padded_input.device)
for i in range(N):
mask[i, input_lengths[i] :] = 0
mask = mask.unsqueeze(dim=1)
return mask.to(torch.uint8)
class FireRedASR2Adapter(nn.Module):
def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2):
super().__init__()
self.ds = downsample_rate
self.linear1 = ReplicatedLinear(
input_size=encoder_dim * downsample_rate,
output_size=llm_dim,
bias=True,
)
self.relu = _ACTIVATION_REGISTRY["relu"]
self.linear2 = ReplicatedLinear(
input_size=llm_dim,
output_size=llm_dim,
bias=True,
)
def forward(self, x, x_lens):
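        # Downsample in time by stacking `ds` consecutive frames along the
        # feature dim, then project the stacked features into the LLM width.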
batch_size, seq_len, feat_dim = x.size()
num_frames_to_discard = seq_len % self.ds
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, seq_len // self.ds, feat_dim * self.ds)
x, _ = self.linear1(x)
x = self.relu(x)
x, _ = self.linear2(x)
new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds
return x, new_x_lens
class FireRedASR2Encoder(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
):
super().__init__()
self.audio_encoder = ConformerEncoder(
**vllm_config.model_config.hf_config.audio_encoder_conf
)
class FireRedASR2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.encoder = FireRedASR2Encoder(
vllm_config=vllm_config,
)
encoder_dim = self.encoder.audio_encoder.odim
llm_dim = vllm_config.model_config.hf_config.hidden_size
self.encoder_projector = FireRedASR2Adapter(
encoder_dim,
llm_dim,
vllm_config.model_config.hf_config.encoder_downsample_rate,
)
self.decoder = Qwen2ForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder")
)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return decoder_outputs
def get_encoder_outputs(
self,
speech: torch.Tensor | list[torch.Tensor] | None,
speech_lengths: torch.Tensor | list[torch.Tensor] | None,
) -> torch.Tensor | None:
encoder_outs, enc_lengths, enc_mask = self.encoder.audio_encoder(
speech, speech_lengths
)
speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
return speech_features
class FireRedASR2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> Qwen2Config:
return self.ctx.get_hf_config(Qwen2Config)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1}
def get_feature_extractor(self, **kwargs: object) -> FireRedASR2FeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, FireRedASR2FeatureExtractor)
return feature_extractor
def get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.get_target_channels(),
)
def get_target_channels(self) -> int:
return 1
class FireRedASR2DummyInputsBuilder(BaseDummyInputsBuilder[FireRedASR2ProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
return "<|AUDIO|>" * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio")
ret = {
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
)
}
return ret
class FireRedASR2MultiModalProcessor(
BaseMultiModalProcessor[FireRedASR2ProcessingInfo]
):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
input_features=MultiModalFieldConfig.batched("audio"),
speech_lengths=MultiModalFieldConfig.batched("audio"),
fake_token_lengths=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
audio_token_id = vocab[audio_token]
out_mm_data = out_mm_kwargs.get_data()
fake_token_lengths = out_mm_data.get("fake_token_lengths")
if fake_token_lengths is None:
audio_output_lengths = []
else:
assert isinstance(fake_token_lengths, torch.Tensor)
audio_output_lengths = fake_token_lengths.tolist()
def get_replacement_fireredasr2_audio(item_idx: int):
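            # Expand the single <|AUDIO|> placeholder into one token per
            # projected audio frame so the prompt length matches the encoder
            # output length.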
num_features = audio_output_lengths[item_idx]
audio_tokens = [audio_token_id] * int(num_features)
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [
PromptReplacement(
modality="audio",
target=[audio_token_id],
replacement=get_replacement_fireredasr2_audio,
)
]
@MULTIMODAL_REGISTRY.register_processor(
FireRedASR2MultiModalProcessor,
info=FireRedASR2ProcessingInfo,
dummy_inputs=FireRedASR2DummyInputsBuilder,
)
class FireRedASR2ForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"llm.": "model.decoder.",
"encoder.": "model.encoder.audio_encoder.",
"encoder_projector.": "model.encoder_projector.",
"net.0": "pre_layer_norm",
"net.1": "linear_expand",
"net.4": "linear_project",
}
)
supports_transcription_only = True
supports_segment_timestamp = True
supported_languages = ISO639_1_SUPPORTED_LANGS
@classmethod
def validate_language(cls, language: str | None) -> str | None:
if language is None:
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
logger.warning(
"Defaulting to language='en'. If you wish to transcribe "
"audio in a different language, pass the `language` field "
"in the TranscriptionRequest."
)
language = "en"
return super().validate_language(language)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the fireredasr2 prompt"
)
prompt_str = "<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
prompt = {
"prompt": prompt_str,
"multi_modal_data": {
"audio": (audio, stt_config.sample_rate),
},
}
return cast(PromptType, prompt)
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
processor = cached_processor_from_config(model_config)
return SpeechToTextConfig(
max_audio_clip_s=processor.feature_extractor.chunk_length,
sample_rate=processor.feature_extractor.sampling_rate,
)
@classmethod
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> int | None:
processor = cached_processor_from_config(model_config)
hop_length = processor.feature_extractor.hop_length
assert hop_length is not None
return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.dtype = vllm_config.model_config.dtype
self.model = FireRedASR2Model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
decoder_outputs = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return decoder_outputs
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
speech = audio_input["input_features"]
speech_lengths = audio_input["speech_lengths"].to(torch.int32)
enc_output = self.model.get_encoder_outputs(
speech=speech, speech_lengths=speech_lengths
)
return enc_output
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.model.decoder.embed_input_ids(input_ids)
ret = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=_require_is_multimodal(is_multimodal),
)
return ret
def _parse_and_validate_audio_input(
self, **kwargs: object
) -> FireRedASR2AudioInputs:
input_features = kwargs.pop("input_features", None)
speech_lengths = kwargs.pop("speech_lengths", None)
fake_token_lengths = kwargs.pop("fake_token_lengths", None)
return FireRedASR2AudioInputs(
input_features=input_features,
speech_lengths=speech_lengths,
fake_token_lengths=fake_token_lengths,
)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.model.decoder.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self, skip_prefixes=["model.encoder.audio_encoder.positional_encoding.pe"]
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -13,7 +13,6 @@ positions via `inputs_embeds`, while `position_ids` (RoPE) remains standard 1D.
from __future__ import annotations
import os
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any
@@ -924,53 +923,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor
f"sequence of Tensors (got {type(speech_attention_mask)})"
)
debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1"
if debug:
print(
f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} "
f"speech_attention_mask={tuple(speech_attention_mask.shape)}",
flush=True,
)
attn_impl = getattr(
self.continuous_audio_tower.config, "_attn_implementation", None
)
print(
f"[FunAudioChat] audio_attn_impl={attn_impl}",
flush=True,
)
if hasattr(self.continuous_audio_tower, "conv1"):
conv1_w = self.continuous_audio_tower.conv1.weight
print(
f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}",
flush=True,
)
try:
attn0 = self.continuous_audio_tower.layers[0].self_attn
q_norm = float(attn0.q_proj.weight.norm().item())
k_norm = float(attn0.k_proj.weight.norm().item())
v_norm = float(attn0.v_proj.weight.norm().item())
o_norm = float(attn0.out_proj.weight.norm().item())
print(
f"[FunAudioChat] attn0_q_norm={q_norm:.6g} "
f"k_norm={k_norm:.6g} "
f"v_norm={v_norm:.6g} "
f"o_norm={o_norm:.6g}",
flush=True,
)
except Exception:
pass
if isinstance(input_features, torch.Tensor):
print(
f"[FunAudioChat] input_features={tuple(input_features.shape)}",
flush=True,
)
if isinstance(feature_attention_mask, torch.Tensor):
print(
"[FunAudioChat] feature_attention_mask="
f"{tuple(feature_attention_mask.shape)}",
flush=True,
)
group_size = int(self.audio_tower.group_size)
speech_maxlen = int(speech_ids.shape[-1])
@@ -1019,38 +971,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor
embeds = tuple(
audio_features[i, : int(length)] for i, length in enumerate(lengths)
)
if debug:
embed_lens = [int(t.shape[0]) for t in embeds]
print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
if embeds:
t0 = embeds[0]
print(
f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
f"nan={bool(torch.isnan(t0).any())} "
f"norm={float(t0.norm().item()):.6g}",
flush=True,
)
dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
if (
dump_path
and speech_ids.shape[0] == 1
and len(embeds) == 1
and embed_lens[0] > 10
):
if not os.path.exists(dump_path):
np.save(dump_path, embeds[0].detach().float().cpu().numpy())
print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
cont_path = dump_path.replace(".npy", "_cont.npy")
if continuous_audio_features is not None and not os.path.exists(
cont_path
):
np.save(
cont_path,
continuous_audio_features.detach().float().cpu().numpy(),
)
print(
f"[FunAudioChat] dumped continuous to {cont_path}", flush=True
)
return embeds
def forward(

View File

@@ -409,7 +409,7 @@ class Gemma3nAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
q = q.unflatten(-1, (self.num_heads, self.head_dim))
q = self.q_norm(q)
q = q.flatten(-2, -1)

View File

@@ -110,7 +110,12 @@ class Glm4MoeMLP(nn.Module):
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
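        # Some quantized down_proj kernels pad the weight's input dim for
        # alignment; zero-pad the activation so the shapes line up.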
        if (
            self.down_proj.quant_method.__class__.__name__
            != "UnquantizedLinearMethod"
            and x.shape[-1] != self.down_proj.weight.shape[0]
        ):
padding = self.down_proj.weight.shape[0] - x.shape[-1]
x_align = torch.nn.functional.pad(x, (0, padding), mode='constant', value=0)
else:
x_align = x
x, _ = self.down_proj(x_align)
return x
@@ -144,11 +149,10 @@ class Glm4MoE(nn.Module):
config.hidden_size,
config.n_routed_experts,
bias=False,
# dtype=torch.float32,
dtype=torch.bfloat16,
)
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts)
)
torch.empty(config.n_routed_experts, dtype=torch.bfloat16))
# Load balancing settings.
vllm_config = get_current_vllm_config()
@@ -205,8 +209,7 @@ class Glm4MoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
# router_logits = self.gate(hidden_states.to(dtype=torch.float32))
router_logits = self.gate(hidden_states)
router_logits = self.gate(hidden_states.to(dtype=torch.bfloat16))
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
@@ -312,6 +315,9 @@ class Glm4MoeAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
if self.use_qk_norm:
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
q.shape

View File

@@ -127,12 +127,10 @@ class Glm4MoeLiteDecoderLayer(nn.Module):
v_head_dim = getattr(config, "v_head_dim", 0)
kv_lora_rank = getattr(config, "kv_lora_rank", 0)
# if model_config.use_mla:
# attn_cls = Glm4MoeLiteMLAAttention
# else:
# attn_cls = Glm4MoeLiteAttention
attn_cls = Glm4MoeLiteAttention
if model_config.use_mla:
attn_cls = Glm4MoeLiteMLAAttention
else:
attn_cls = Glm4MoeLiteAttention
self.self_attn = attn_cls(
vllm_config=vllm_config,
@@ -306,7 +304,7 @@ class Glm4MoeLiteModel(nn.Module):
),
}
)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -318,6 +316,120 @@ class Glm4MoeLiteModel(nn.Module):
num_experts=self.config.n_routed_experts,
)
class Glm4MoeLiteForCausalLM(
nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
self.use_mha = config.model_type == "deepseek" or all(
dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
)
if self.use_mha:
self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
self.fuse_qkv_a_proj = (
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
)
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
self.model = Glm4MoeLiteModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
# Set MoE hyperparameters
self.num_moe_layers = (
self.config.num_hidden_layers - self.config.first_k_dense_replace
)
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = []
self.num_expert_groups = getattr(self.config, "n_group", 1)
self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Glm4MoeLiteDecoderLayer)
if isinstance(layer.mlp, Glm4MoeLite):
                # Pick the last MoE layer, since the first layers may be dense.
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
@@ -327,12 +439,13 @@ class Glm4MoeLiteModel(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
mla_params_mapping = [
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
mha_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
stacked_params_mapping.extend(mla_params_mapping)
if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -510,128 +623,71 @@ class Glm4MoeLiteModel(nn.Module):
weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"]
# add your opt here..
def inject_layer(layer, quant_method, is_mla):
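            # Fuse the q(_a)_proj and kv_a_proj_with_mqa weights into a single
            # GEMM for supported quant methods, then switch the attention layer
            # to its optimized forward path.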
q_lora_rank = getattr(layer, "q_lora_rank", None)
if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
if q_lora_rank is not None:
layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
if hasattr(layer.q_a_proj, "weight_scale"):
layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
del layer.kv_a_proj_with_mqa.weight_scale
elif not is_mla:
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
if hasattr(layer.q_proj, "weight_scale"):
layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
del layer.kv_a_proj_with_mqa.weight_scale
else:
return
del layer.kv_a_proj_with_mqa.weight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
elif quant_method == "GGUFLinearMethod":
pass
elif quant_method == "AWQMarlinLinearMethod":
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
assert dtype == torch.int32
if layer.q_lora_rank is not None:
layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
elif not is_mla:
                    layer.q_proj.qweight.data = torch.cat([layer.q_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
else:
return
del layer.kv_a_proj_with_mqa.qweight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
else:
pass
for _, layer in self.model.named_modules():
if layer.__class__.__name__ in ["Glm4MoeLiteAttention","Glm4MoeLiteMLAAttention"]:
if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
else:
quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
if quant_method not in opt_support_quant_method:
break
                inject_layer(layer, quant_method, is_mla=layer.__class__.__name__ == "Glm4MoeLiteMLAAttention")
return loaded_params
class Glm4MoeLiteForCausalLM(
nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
self.use_mha = config.model_type == "deepseek" or all(
dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
)
if self.use_mha:
self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
self.fuse_qkv_a_proj = (
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
)
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
self.model = Glm4MoeLiteModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
# Set MoE hyperparameters
self.num_moe_layers = (
self.config.num_hidden_layers - self.config.first_k_dense_replace
)
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = []
self.num_expert_groups = getattr(self.config, "n_group", 1)
self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Glm4MoeLiteDecoderLayer)
if isinstance(layer.mlp, Glm4MoeLite):
                # Pick the last MoE layer, since the first layers may be dense.
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_spec_layer_idx_from_weight_name(
config: "Glm4MoeLiteConfig", weight_name: str
) -> int | None:

View File

@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
import vllm.envs as envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
@@ -23,7 +23,11 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
@@ -42,6 +46,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionType
from vllm.model_executor.model_loader import padding_weight_loader
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
@@ -107,7 +112,6 @@ class OAIAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.num_attention_heads * self.head_dim,
output_size=self.hidden_size,
@@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module):
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=f"{prefix}.router",
return_bias=False,
)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(
num_experts=config.num_local_experts,
@@ -969,8 +980,18 @@ class GptOssModel(nn.Module):
weights: Iterable[tuple[str, torch.Tensor]],
stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
def handle_weight(name, weight, param_name, permute_dims=None, slice_dims=None, contiguous=True):
"""Helper function to handle weight loading with optional slicing and permutation."""
param = params_dict[param_name]
if slice_dims:
weight = weight[slice_dims]
if permute_dims:
weight = weight.permute(*permute_dims)
if contiguous:
weight = weight.contiguous()
padding_weight_loader(param, weight)
loaded_params.add(param_name)
use_ep = self.parallel_config.enable_expert_parallel
@@ -986,91 +1007,71 @@ class GptOssModel(nn.Module):
intermediate_size = self.config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
pack_factor = 2 if envs.VLLM_W8A8_MOE_USE_W4A8 else 1
w4a8_flag = envs.VLLM_W8A8_MOE_USE_W4A8
gemm_format = envs.VLLM_W8A8_FORMAT
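        # Under W4A8 the intermediate dim of w2 is stored packed, so its TP
        # slice bounds are divided by pack_factor.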
for name, weight in weights:
# Skip layers on other devices.
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ".w13_weight" in name:
# Handle MLP gate and up projection weights
# Extract gate and up projection parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
if ".experts.w13_weight" in name and "scale" not in name and "bias" not in name:
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
permute_dims = None if gemm_format == "NN" else (0, 2, 1)
handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims)
elif ".experts.w2_weight" in name and "scale" not in name and "bias" not in name:
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(tp_rank_start // pack_factor, tp_rank_end // pack_factor), slice(None))
permute_dims = None if gemm_format == "NN" else (0, 2, 1)
handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims)
elif ".experts.gate_up_proj_scale" in name:
new_name = name.replace("gate_up_proj_scale", "w13_weight_scale")
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
permute_dims = None if w4a8_flag else (0, 2, 1)
handle_weight(name, weight, new_name, permute_dims=permute_dims, slice_dims=slice_dims, contiguous=w4a8_flag)
elif ".experts.down_proj_scale" in name:
new_name = name.replace("down_proj_scale", "w2_weight_scale")
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else None
permute_dims = None if w4a8_flag else (0, 2, 1)
handle_weight(name, weight, new_name, permute_dims=permute_dims, slice_dims=slice_dims, contiguous=w4a8_flag)
elif ".experts.w13_bias" in name:
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
handle_weight(name, weight, name, slice_dims=slice_dims, contiguous=False)
elif ".experts.w2_bias" in name:
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w2_weight" in name:
# Handle MLP down projection weights
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w13_bias" in name:
# Handle MLP gate and up projection biases
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w2_bias" in name:
# Handle MLP down projection bias
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
param = params_dict[name]
param.copy_(weight)
elif tp_rank != 0:
weight.zero_()
param.data.copy_(weight)
loaded_params.add(name)
continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, weight)
else:
weight_loader(param, weight, shard_id)
break
elif ("q_proj" in name or "k_proj" in name or "v_proj" in name):
shard_id = ("q" if "q_proj" in name else "k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv_proj")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight)
loaded_params.add(name)
loaded_params.add(name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

View File

@@ -636,7 +636,13 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
spatial_merge_size = vision_config.spatial_merge_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
size = image_processor.size
if override_size := mm_kwargs.get("size"):
size = size | override_size
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
size = size | {"shortest_edge": override_min_pixels}
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
size = size | {"longest_edge": override_max_pixels}
if do_resize:
resized_height, resized_width = smart_resize(

View File

@@ -49,7 +49,6 @@ from .utils import (
)
from .vision import get_vision_encoder_info
EOT = "<|endofturn|>"
IMAGE_TOKEN: str = "<|dummy3|>"
VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"

View File

@@ -551,7 +551,7 @@ def process_vision_for_patches(
`(num_images, height, width, channels)` for a batch. Channels are
expected to be RGB.
patch_size (`int`):
Edge length of square patches; implictly controls resize grid granularity.
Edge length of square patches; implicitly controls resize grid granularity.
max_num_patches (`int`):
Maximum number of patches allowed after resizing.
min_num_patches (`int`, *optional*):

View File

@@ -1021,7 +1021,13 @@ class KeyeProcessingInfo(BaseProcessingInfo):
temporal_patch_size = 1
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
size = image_processor.size
if override_size := mm_kwargs.get("size"):
size = size | override_size
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
size = size | {"min_pixels": override_min_pixels}
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
size = size | {"max_pixels": override_max_pixels}
if do_resize:
resized_height, resized_width = smart_resize(

View File

@@ -654,4 +654,4 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
return loader.load_weights(weights)

View File

@@ -230,4 +230,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
}
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)

View File

@@ -88,7 +88,7 @@ class MiniMaxM2MoE(nn.Module):
self.use_routing_bias = getattr(config, "use_routing_bias", False)
if self.use_routing_bias:
self.e_score_correction_bias = nn.Parameter(
torch.empty(config.num_local_experts, dtype=torch.float32)
torch.empty(config.num_local_experts, dtype=torch.get_default_dtype())
)
self.e_score_correction_bias.weight_loader = (
MiniMaxM2MoE.ebias_weight_loader
@@ -107,13 +107,14 @@ class MiniMaxM2MoE(nn.Module):
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
router_logits_dtype=torch.float32,
)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=False,
# params_dtype=torch.float32,
params_dtype=torch.float32,
quant_config=None,
prefix=f"{prefix}.gate",
)
@@ -121,7 +122,6 @@ class MiniMaxM2MoE(nn.Module):
@staticmethod
def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
# param.data.copy_(loaded_weight.to(torch.float32))
param.data.copy_(loaded_weight)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -129,10 +129,9 @@ class MiniMaxM2MoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
# router_logits, _ = self.gate(hidden_states.to(torch.float32))
router_logits, _ = self.gate(hidden_states)
router_logits, _ = self.gate(hidden_states.to(torch.float32))
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, router_logits=router_logits.to(hidden_states.dtype)
)
final_hidden_states = final_hidden_states
if self.tp_size > 1:

View File

@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
compute_retention_mask,
)
from vllm.multimodal.inputs import (
AudioItem,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
class NanoNemotronVLAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- b: Number of audio clips
- t: Audio feature length
- f: Feature size (mel bins)
"""
type: Literal["audio_features"] = "audio_features"
input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
AUDIO_START = "<so_start>"
AUDIO_END = "<so_end>"
AUDIO_CONTEXT = "<so_embedding>"
# Profiling
# MAX_FRAMES = 16
@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token
self.video_pruning_rate = video_pruning_rate
self.audio_extractor: ParakeetExtractor | None = None
raw_sound_config = getattr(config, "sound_config", None)
if raw_sound_config is not None:
self.audio_extractor = ParakeetExtractor(raw_sound_config)
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
self._img_start_token_ids = tokenizer.encode(
@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text = [t.replace("<video>", video_repl_text, 1) for t in text]
return text, video_inputs
def _preprocess_audio(
self,
text: list[str],
audios: list[npt.NDArray],
):
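        # Expand each AUDIO_CONTEXT placeholder into its per-clip replacement
        # text, then run the Parakeet extractor over all clips in one batch.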
if len(audios) == 0:
return text, {}
assert self.audio_extractor is not None
extractor = self.audio_extractor
parts = [x for x in re.split(f"({re.escape(AUDIO_CONTEXT)})", text[0]) if x]
token_count = parts.count(AUDIO_CONTEXT)
if token_count != len(audios):
raise ValueError(
"Number of audio tokens in text does not match the number "
f"of audios (tokens={token_count}, audios={len(audios)})."
)
audio_index = 0
for idx, part in enumerate(parts):
if part == AUDIO_CONTEXT:
audio_repl = self.get_audio_repl(audios[audio_index])
parts[idx] = audio_repl.full
audio_index += 1
text = ["".join(parts)]
audio_inputs = extractor(
audios,
sampling_rate=extractor.sampling_rate,
return_tensors="pt",
)
input_audio_features = audio_inputs.input_features
feature_attention_mask = audio_inputs.attention_mask
audio_feature_lengths = feature_attention_mask.sum(dim=1)
audio_inputs = {
"input_audio_features": input_audio_features,
"feature_attention_mask": feature_attention_mask,
"audio_feature_lengths": audio_feature_lengths,
}
return text, audio_inputs
def __call__(
self,
text: str | list[str] | None = None,
images: Image.Image | list[Image.Image] | None = None,
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
audios: AudioItem | list[AudioItem] | None = None,
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
) -> BatchFeature:
@@ -964,8 +1034,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
if max_num_tiles is None:
max_num_tiles = self.max_num_tiles
text, images, videos = [
self._make_batch_input(x) for x in (text, images, videos)
text, images, videos, audios = [
self._make_batch_input(x) for x in (text, images, videos, audios)
]
text, image_inputs = self._preprocess_image(
@@ -980,17 +1050,22 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
max_num_tiles=1,
)
text, audio_inputs = self._preprocess_audio(
text=text,
audios=audios,
)
text_inputs = self.tokenizer(text, add_special_tokens=False)
combined_inputs = {**text_inputs, **video_inputs, **audio_inputs}
if self.dynamic_tiler is None:
batch = BatchFeature(
{**text_inputs, **video_inputs, **image_inputs},
{**combined_inputs, **image_inputs},
tensor_type=return_tensors,
)
else:
batch = BatchFeature(
{**text_inputs, **video_inputs}, tensor_type=return_tensors
)
batch = BatchFeature(combined_inputs, tensor_type=return_tensors)
# allow images to be exempt from the BatchFeature validation:
# We will .stack() them in _parse_and_validate_image_input
batch.update(image_inputs)
@@ -1006,6 +1081,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def get_audio_repl(
self,
audio: npt.NDArray,
) -> PromptUpdateDetails[str]:
assert self.audio_extractor is not None
num_tokens = self.audio_extractor.audio_token_count(len(audio))
repl_full = f"{AUDIO_START}{AUDIO_CONTEXT * num_tokens}{AUDIO_END}"
return PromptUpdateDetails.select_text(repl_full, AUDIO_CONTEXT)
@classmethod
def get_video_repl(
cls,
@@ -1147,15 +1231,28 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
def supports_video(self):
return self.get_hf_processor().supports_video
@property
def audio_extractor(self) -> ParakeetExtractor | None:
return self.get_hf_processor().audio_extractor
def get_data_parser(self):
target_sr = None
target_channels = None
if extractor := self.audio_extractor:
target_sr = extractor.sampling_rate
target_channels = 1
return MultiModalDataParser(
video_needs_metadata=True,
target_sr=target_sr,
target_channels=target_channels,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self):
video_limit = {"video": None} if self.supports_video else {}
return {**super().get_supported_mm_limits(), **video_limit}
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
return {**super().get_supported_mm_limits(), **video_limit, **audio_limit}
def get_video_token(self) -> str | None:
return IMG_CONTEXT
@@ -1304,7 +1401,16 @@ class NanoNemotronVLMultiModalProcessor(
else:
video_fields = {}
return image_fields | video_fields
if self.info.audio_extractor is not None:
audio_fields = dict(
input_audio_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
)
else:
audio_fields = {}
return image_fields | video_fields | audio_fields
def _get_prompt_updates(
self,
@@ -1373,6 +1479,20 @@ class NanoNemotronVLMultiModalProcessor(
),
]
def get_audio_replacement(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems)
return hf_processor.get_audio_repl(audios.get(item_idx))
if self.info.audio_extractor is not None:
prompt_repl = [
*prompt_repl,
PromptReplacement(
modality="audio",
target=AUDIO_CONTEXT,
replacement=get_audio_replacement,
),
]
return prompt_repl
@@ -1422,8 +1542,13 @@ class NanoNemotronVLDummyInputsBuilder(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_videos = mm_counts.get("video", 0)
num_audios = mm_counts.get("audio", 0)
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
return (
super().get_dummy_text(mm_counts)
+ "<video>" * num_videos
+ AUDIO_CONTEXT * num_audios
)
def _get_dummy_videos(
self,
@@ -1482,7 +1607,25 @@ class NanoNemotronVLDummyInputsBuilder(
}
else:
dummy_video = {}
return {**dummy_image, **dummy_video}
if extractor := self.info.audio_extractor:
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
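            # Size each dummy clip so its token budget is an equal share of
            # seq_len, capped at MAX_AUDIO_LEN_S worth of samples.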
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
dummy_audio = {
"audio": self._get_dummy_audios(
length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
}
else:
dummy_audio = {}
return {**dummy_image, **dummy_video, **dummy_audio}
@MULTIMODAL_REGISTRY.register_processor(
@@ -1499,12 +1642,15 @@ class NemotronH_Nano_VL_V2(
return "<image>"
if modality.startswith("video"):
return "<video>"
if modality.startswith("audio"):
return AUDIO_CONTEXT
return None
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
model_config = vllm_config.model_config
config = model_config.hf_config
multimodal_config = model_config.multimodal_config
image_size = config.force_image_size
patch_size = config.patch_size
self.patch_size = patch_size
@@ -1523,10 +1669,12 @@ class NemotronH_Nano_VL_V2(
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
llm_dtype = self.language_model.config.dtype
assert isinstance(llm_dtype, torch.dtype)
self.llm_dtype = llm_dtype
with self._mark_tower_model(vllm_config, {"image", "video", "audio"}):
self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.dtype
llm_dtype
)
# Construct the vision projection.
@@ -1547,14 +1695,26 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
self.mlp1 = mlp1.to(self.language_model.config.dtype)
self.mlp1 = mlp1.to(llm_dtype)
self.sound_encoder: ProjectedParakeet | None = None
if getattr(config, "sound_config", None) is not None:
logger.info_once(
"Found sound config, initializing sound encoder for Nemotron AVLM",
scope="global",
)
self.sound_encoder = ProjectedParakeet(
config.sound_config,
dtype=llm_dtype,
llm_hidden_size=llm_hidden_size,
max_model_len=model_config.max_model_len,
)
self.config = config
self.model_config = vllm_config.model_config
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
tokenizer = cached_tokenizer_from_config(model_config)
self._img_start_token_ids = tokenizer.encode(
IMG_START, add_special_tokens=False
)
@@ -1566,7 +1726,10 @@ class NemotronH_Nano_VL_V2(
config
)
if self.dynamic_resolution:
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
logger.info_once(
"Dynamic resolution is enabled for NanoNemotronVLProcessor",
scope="global",
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
@@ -1780,6 +1943,51 @@ class NemotronH_Nano_VL_V2(
return final_video_embeddings
def _process_audio_input(
self, audio_input: NanoNemotronVLAudioFeatureInputs
) -> tuple[torch.Tensor, ...]:
assert self.sound_encoder is not None
input_audio_features = audio_input.input_audio_features
feature_attention_mask = audio_input.feature_attention_mask
target_device = next(self.sound_encoder.parameters()).device
# When cross-request batching combines audio clips with different
# time dimensions, _reduce_data returns a list instead of a stacked
# tensor. Pad to the max time dim and stack; the attention mask
# already marks valid positions so zero-padding is safe.
if isinstance(input_audio_features, list):
feature_sizes = [f.shape[-2] for f in input_audio_features]
max_t = max(feature_sizes)
padded_feats = [
torch.nn.functional.pad(feat, (0, 0, 0, max_t - feat_size))
for feat, feat_size in zip(
input_audio_features, feature_sizes, strict=True
)
]
padded_masks = [
torch.nn.functional.pad(mask, (0, max_t - mask.shape[-1]))
for mask in feature_attention_mask
]
input_audio_features = torch.stack(padded_feats)
feature_attention_mask = torch.stack(padded_masks)
input_audio_features = input_audio_features.to(
dtype=self.llm_dtype, device=target_device
)
feature_attention_mask = feature_attention_mask.to(device=target_device)
sound_embeds = self.sound_encoder(input_audio_features, feature_attention_mask)
valid_input_lens = feature_attention_mask.sum(dim=1)
valid_output_lens = self.sound_encoder.encoder._get_subsampling_output_length(
valid_input_lens
)
truncated_embeds = []
for i in range(sound_embeds.shape[0]):
valid_len = valid_output_lens[i].item()
truncated_embeds.append(sound_embeds[i, :valid_len])
return tuple(truncated_embeds)
def _create_final_video_embeddings(
self,
video_embeddings: torch.Tensor,
@@ -1887,6 +2095,18 @@ class NemotronH_Nano_VL_V2(
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
if (
input_key
in (
"input_audio_features",
"feature_attention_mask",
"audio_feature_lengths",
)
and "audios" not in modalities
):
modalities["audios"] = NanoNemotronVLAudioFeatureInputs(
**kwargs, validate=False
)
return modalities
@@ -1917,6 +2137,10 @@ class NemotronH_Nano_VL_V2(
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += tuple(video_embeddings)
if modality == "audios":
audio_input = modalities["audios"]
audio_embeddings = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings
@@ -1947,8 +2171,8 @@ class NemotronH_Nano_VL_V2(
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="mlp1",
tower_model="vision_model",
connector=["mlp1", "sound_encoder.projection"],
tower_model=["vision_model", "sound_encoder.encoder"],
)
def compute_logits(
@@ -1969,9 +2193,13 @@ class NemotronH_Nano_VL_V2(
def is_vision_weights(name: str) -> bool:
return name.startswith("vision_model.radio_model.")
def is_sound_weights(name: str) -> bool:
return name.startswith("sound")
# Separate weights by component
llm_weights = []
vision_weights = []
sound_weights = []
for name, w in weights:
if is_llm(name):
@@ -1987,107 +2215,15 @@ class NemotronH_Nano_VL_V2(
# Convert: vision_model.radio_model.* → radio_model.*
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
vision_weights.append((hf_key, w))
elif is_sound_weights(name):
assert self.sound_encoder is not None
sound_weights.append((name, w))
self.language_model.load_weights(llm_weights)
self.vision_model.load_weights(vision_weights)
def print_architecture(self, detailed: bool = True, save_to_file: str = None):
"""
Print model architecture with parameter names, shapes, and sizes.
Args:
detailed: If True, show detailed parameter breakdown
save_to_file: If provided, save output to this file path
"""
import sys
from io import StringIO
# Capture output if saving to file
original_stdout = sys.stdout
if save_to_file:
sys.stdout = StringIO()
try:
print("=" * 100)
print("NemotronH_Nano_VL_V2 Model Architecture")
print("=" * 100)
total_params = 0
param_groups = {
"language_model": [],
"vision_model": [],
"mlp1": [],
"other": [],
}
for name, param in self.named_parameters():
param_size = param.numel()
total_params += param_size
# Group parameters by main component
if name.startswith("language_model"):
param_groups["language_model"].append(
(name, param.shape, param_size, param.dtype)
)
elif name.startswith("vision_model"):
param_groups["vision_model"].append(
(name, param.shape, param_size, param.dtype)
)
elif name.startswith("mlp1"):
param_groups["mlp1"].append(
(name, param.shape, param_size, param.dtype)
)
else:
param_groups["other"].append(
(name, param.shape, param_size, param.dtype)
)
if detailed:
print(
f"{name:<70} | Shape: {str(param.shape):<25} | "
f"Size: {param_size:>12,} | Dtype: {param.dtype}"
)
print("=" * 100)
print("Summary by Component:")
print("-" * 60)
for component, params in param_groups.items():
if params: # Only show components that have parameters
component_total = sum(size for _, _, size, _ in params)
percentage = (
(component_total / total_params) * 100
if total_params > 0
else 0
)
print(
f"{component:<20} | Parameters: {len(params):>4} | "
f"Total Size: {component_total:>15,} | "
f"{percentage:>6.2f}%"
)
print("-" * 60)
print(f"{'Total Parameters':<20} | {total_params:>15,}")
# Estimate memory usage (assuming bfloat16 = 2 bytes per parameter)
memory_mb = total_params * 2 / (1024**2)
memory_gb = memory_mb / 1024
print(f"{'Est. Memory (MB)':<20} | {memory_mb:>15.2f}")
print(f"{'Est. Memory (GB)':<20} | {memory_gb:>15.2f}")
print("=" * 100)
# Save to file if requested
if save_to_file:
output = sys.stdout.getvalue()
sys.stdout = original_stdout
with open(save_to_file, "w") as f:
f.write(output)
print(f"Architecture saved to: {save_to_file}")
print(output) # Also print to console
finally:
if save_to_file and sys.stdout != original_stdout:
sys.stdout = original_stdout
if self.sound_encoder is not None:
assert len(sound_weights) > 0
self.sound_encoder.load_weights(sound_weights)
def get_vit_model_from_radio_config(self, hf_config):
hf_config_vision = hf_config.vision_config

View File

@@ -34,7 +34,7 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
GateLinear,
SharedFusedMoE,
activation_without_mul,
)
@@ -148,13 +148,11 @@ class NemotronHMoE(nn.Module):
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
router_logits_dtype = torch.float32
self.gate = ReplicatedLinear(
self.gate = GateLinear(
config.hidden_size,
config.n_routed_experts,
bias=False,
params_dtype=router_logits_dtype,
quant_config=None,
out_dtype=torch.float32,
force_fp32_compute=True,
prefix=f"{prefix}.gate",
)
@@ -232,7 +230,6 @@ class NemotronHMoE(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
router_logits_dtype=router_logits_dtype,
routed_input_transform=self.fc1_latent_proj,
)
@@ -244,7 +241,7 @@ class NemotronHMoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
router_logits, _ = self.gate(hidden_states)
# SharedFusedMoE handles:
# - shared experts (with original hidden_states)
@@ -298,6 +295,11 @@ class NemotronHMLPDecoderLayer(nn.Module):
hybrid_override_pattern = config.hybrid_override_pattern
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
# Get per-layer config for heterogeneous models if exist
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
layer_config = get_layer_config(layer_idx) if get_layer_config else config
config = layer_config
if isinstance(config.intermediate_size, list):
if len(config.intermediate_size) == 1:
intermediate_size = config.intermediate_size[0]
@@ -670,7 +672,7 @@ class NemotronHModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if self.has_moe:
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
# - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
# what the activation is applied to
# - FusedMoe.w3 (aka up_proj) should be ignored since we're

View File

@@ -7,6 +7,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
from abc import ABC
from collections.abc import Iterable
@@ -18,6 +19,8 @@ from transformers import AutoModel, PretrainedConfig
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.internvl import (
@@ -30,24 +33,29 @@ from vllm.model_executor.models.internvl import (
InternVLProcessor,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.siglip import SiglipVisionModel
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_image_processor_from_config
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
from .interfaces import (
MultiModalEmbeddings,
SupportsCrossEncoding,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
from .interfaces_base import VllmModelForPooling
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
def build_transform(input_size: int):
@@ -183,10 +191,12 @@ def image_to_pixel_values_nemotron_vl(
min_num: int,
max_num: int,
use_thumbnail: bool,
transform: T.Compose | None = None,
) -> torch.Tensor:
target_ratios = get_nemotron_vl_target_ratios(min_num, max_num)
transform = build_transform(input_size=input_size)
if transform is None:
transform = build_transform(input_size=input_size)
images = dynamic_preprocess_nemotron_vl(
image,
@@ -200,11 +210,15 @@ def image_to_pixel_values_nemotron_vl(
class NemotronVLProcessor(InternVLProcessor):
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
def __init__(
self,
config: PretrainedConfig,
tokenizer: TokenizerLike,
image_processor: BaseImageProcessorFast,
image_processor: BaseImageProcessorFast | None = None,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
@@ -236,11 +250,18 @@ class NemotronVLProcessor(InternVLProcessor):
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail: bool = self.image_processor.use_thumbnail
if image_processor is not None:
self.use_thumbnail = image_processor.use_thumbnail
else:
self.use_thumbnail = getattr(config, "use_thumbnail", True)
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
return self.tokenizer.get_vocab()[self.IMG_CONTEXT]
def _get_transform(self) -> T.Compose:
return build_transform(input_size=self.image_size)
def get_num_image_tokens(
self,
@@ -283,10 +304,26 @@ class NemotronVLProcessor(InternVLProcessor):
min_num=min_num,
max_num=max_num,
use_thumbnail=self.use_thumbnail,
transform=self._get_transform(),
)
for image in images
]
def _replace_image_tokens(
self,
text: list[str],
pixel_values_lst: list[torch.Tensor],
) -> list[str]:
"""Replace <image> placeholders with image tokens."""
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
# Use temporary placeholder to avoid replacing tokens we just inserted
NVL_IMAGE_CONTEXT = image_repl.full.replace("<image>", "<NVL_IMG_CONTEXT>")
text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text]
return [t.replace("<NVL_IMG_CONTEXT>", self.IMG_CONTEXT) for t in text]
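A small standalone illustration of why the temporary placeholder matters when the context token is literally "<image>"; the patch counts below are toy values, not real feature sizes:
prompt = "A: <image> B: <image>"
first = "<img>" + "<image>" * 2 + "</img>"   # expansion for image A (2 tokens)
second = "<img>" + "<image>" * 3 + "</img>"  # expansion for image B (3 tokens)
# Naive sequential replacement: the second pass matches a context token that
# was just inserted for image A, so image B's placeholder is never expanded.
naive = prompt.replace("<image>", first, 1).replace("<image>", second, 1)
print(naive.endswith("<image>"))  # True -- B's placeholder is still there
# Placeholder-based replacement, mirroring _replace_image_tokens above.
safe = prompt
for repl in (first, second):
    safe = safe.replace("<image>", repl.replace("<image>", "<NVL_IMG_CONTEXT>"), 1)
safe = safe.replace("<NVL_IMG_CONTEXT>", "<image>")
print(safe.endswith("<image>"))  # False -- both placeholders were expanded in order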
def _preprocess_image(
self,
text: list[str],
@@ -311,15 +348,7 @@ class NemotronVLProcessor(InternVLProcessor):
),
}
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
NVL_IMAGE_CONTEXT = image_repl.full.replace(
"<image>", "<NVL_IMG_CONTEXT>"
)
text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text]
text = [t.replace("<NVL_IMG_CONTEXT>", IMG_CONTEXT) for t in text]
text = self._replace_image_tokens(text, pixel_values_lst)
return text, image_inputs
def get_image_repl(
@@ -327,10 +356,10 @@ class NemotronVLProcessor(InternVLProcessor):
feature_size: int,
num_patches: int | None,
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
repl_features = self.IMG_CONTEXT * feature_size
repl_full = self.IMG_START + repl_features + self.IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
return PromptUpdateDetails.select_text(repl_full, self.IMG_CONTEXT)
class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
@@ -396,7 +425,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
hf_config=config.get_text_config(),
prefix=maybe_prefix(prefix, "language_model"),
)
@@ -413,7 +442,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
# the awq models from OpenGVLab missing `modules_to_not_convert`
# patch the quant_config to add `modules_to_not_convert` back
if isinstance(quant_config, AWQConfig):
text_config = config.text_config
text_config = config.get_text_config()
llm_quant_config = getattr(text_config, "quantization_config", None)
if (not quant_config.modules_to_not_convert) and (
llm_quant_config is not None
@@ -429,10 +458,17 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
):
return AutoModel.from_config(config.vision_config, trust_remote_code=True)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
vit_hidden_size = config.vit_hidden_size
vision_projection_hidden_size = config.projector_hidden_size
llm_hidden_size = config.text_config.hidden_size
def _init_mlp1(
self,
config: PretrainedConfig,
vit_hidden_size: int | None = None,
vision_projection_hidden_size: int | None = None,
) -> nn.Module:
if vit_hidden_size is None:
vit_hidden_size = config.vit_hidden_size
if vision_projection_hidden_size is None:
vision_projection_hidden_size = config.projector_hidden_size
llm_hidden_size = config.get_text_config().hidden_size
return nn.Sequential(
nn.LayerNorm(
@@ -465,10 +501,18 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
x = x.permute(0, 2, 1, 3).contiguous()
return x
def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""Call vision model and return embeddings.
Override this method in subclasses to handle different vision model
interfaces (e.g., SigLIP vs C-RADIO).
"""
vit_embeds = self.vision_model(x=pixel_values).features
return vit_embeds.to(dtype=torch.bfloat16)
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
# https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177
vit_embeds = self.vision_model(x=pixel_values).features
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
vit_embeds = self._call_vision_model(pixel_values)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
@@ -523,15 +567,16 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"]
hidden_size = self.config.get_text_config().hidden_size
# Only one image in the current batch
if len(num_patches) == 1:
return (image_embeds.view(-1, self.config.text_config.hidden_size),)
return (image_embeds.view(-1, hidden_size),)
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
image_embeds = image_embeds.view(-1, hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in num_patches
]
@@ -643,3 +688,255 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
connector="mlp1",
tower_model="vision_model",
)
# --------------------------------------------------------
# LlamaNemotronVL Embedding Model (nvidia/llama-nemotron-embed-vl-1b-v2)
# Extends LlamaNemotronVLChatModel for embedding/pooling tasks:
# - SigLIP vision encoder (instead of C-RADIO)
# - Bidirectional (non-causal) LLaMA language model
# - Pooler output instead of generative logits
# --------------------------------------------------------
# SigLIP normalization constants
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)
def build_siglip_transform(input_size: int):
"""Build transform for SigLIP vision encoder with normalization.
Extends the base transform from nemotron_vl with SigLIP-specific normalization.
"""
base_transform = build_transform(input_size=input_size)
return T.Compose(
[
base_transform,
T.Normalize(mean=SIGLIP_MEAN, std=SIGLIP_STD),
]
)
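For reference, Normalize(mean=0.5, std=0.5) maps inputs from [0, 1] to [-1, 1], i.e. x -> (x - 0.5) / 0.5 = 2x - 1. A quick self-contained check, assuming torchvision (which already backs the transforms above):
import torch
from torchvision import transforms as T
norm = T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
x = torch.tensor([0.0, 0.5, 1.0]).view(3, 1, 1)  # one value per channel
print(norm(x).flatten())  # tensor([-1.,  0.,  1.])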
class LlamaNemotronVLEmbedProcessor(NemotronVLProcessor):
"""
Processor for LlamaNemotronVL embedding model.
Inherits from NemotronVLProcessor and specializes it for embedding tasks:
- Uses SigLIP transform with normalization instead of base transform
- Uses different image context token (<IMG_CONTEXT> vs <image>)
"""
IMG_CONTEXT = "<IMG_CONTEXT>"
def __init__(
self,
config: PretrainedConfig,
tokenizer: TokenizerLike,
processor_config: dict,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
) -> None:
if min_dynamic_patch is None:
min_dynamic_patch = processor_config.get(
"min_input_tiles",
getattr(config, "min_dynamic_patch", 1),
)
if max_dynamic_patch is None:
max_dynamic_patch = processor_config.get(
"max_input_tiles",
getattr(config, "max_dynamic_patch", 1),
)
if dynamic_image_size is None:
dynamic_image_size = processor_config.get(
"dynamic_image_size",
getattr(config, "dynamic_image_size", True),
)
super().__init__(
config=config,
tokenizer=tokenizer,
image_processor=None,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
def _get_transform(self) -> T.Compose:
"""Override to add SigLIP normalization."""
return build_siglip_transform(input_size=self.image_size)
def _replace_image_tokens(
self,
text: list[str],
pixel_values_lst: list[torch.Tensor],
) -> list[str]:
"""Override with simpler token replacement for embedding model.
No temporary placeholder needed because IMG_CONTEXT is <IMG_CONTEXT>,
not <image>, so there's no collision risk.
"""
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
text = [t.replace("<image>", image_repl.full, 1) for t in text]
return text
class LlamaNemotronVLEmbedProcessingInfo(NemotronVLProcessingInfo):
"""Processing info for LlamaNemotronVL embedding model."""
def get_hf_processor(self, **kwargs: object) -> LlamaNemotronVLEmbedProcessor:
"""Override to create embedding-specific processor without image_processor."""
model_config = self.ctx.model_config
processor_config = {}
if model_config.model is not None:
processor_config = (
get_hf_file_to_dict(
"processor_config.json",
model_config.model,
model_config.revision,
)
or {}
)
return self.ctx.init_processor(
LlamaNemotronVLEmbedProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
processor_config=processor_config,
**kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
BaseInternVLMultiModalProcessor[LlamaNemotronVLEmbedProcessingInfo],
info=LlamaNemotronVLEmbedProcessingInfo,
dummy_inputs=BaseInternVLDummyInputsBuilder[LlamaNemotronVLEmbedProcessingInfo],
)
class LlamaNemotronVLForEmbedding(LlamaNemotronVLChatModel, VllmModelForPooling):
"""
LlamaNemotronVL model for embeddings.
Inherits from LlamaNemotronVLChatModel and specializes it for embedding tasks:
- Uses SigLIP vision encoder instead of C-RADIO
- Uses bidirectional LLaMA (via llm_config) instead of causal LLaMA
- Adds pooler for embedding output instead of generating logits
"""
is_pooling_model = True
# Weight mapping from checkpoint format to vLLM format
# Different from parent class due to different vision model structure
weight_mapper = WeightsMapper(
orig_to_new_prefix={
# Language model mapping
"language_model.layers.": "language_model.model.layers.",
"language_model.embed_tokens.": "language_model.model.embed_tokens.",
"language_model.norm.": "language_model.model.norm.",
# Vision model mapping (SiglipVisionModel has nested vision_model)
"vision_model.encoder.": "vision_model.vision_model.encoder.",
"vision_model.embeddings.": "vision_model.vision_model.embeddings.",
"vision_model.post_layernorm.": "vision_model.vision_model.post_layernorm.",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
# Override: get img_context_token_id from config (parent sets None)
self.img_context_token_id = getattr(config, "img_context_token_id", None)
# Initialize pooler for embedding output
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_embedding(pooler_config)
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config,
*,
prefix: str,
) -> nn.Module:
"""Override to use SigLIP instead of C-RADIO."""
return SiglipVisionModel(
config.vision_config,
quant_config=quant_config,
prefix=prefix,
use_head=False,
)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
"""Override to use different MLP structure for embedding model."""
return super()._init_mlp1(
config,
vit_hidden_size=config.vision_config.hidden_size,
vision_projection_hidden_size=config.get_text_config().hidden_size,
)
def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""Override to handle SigLIP interface."""
return self.vision_model(pixel_values)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Override to use different weight mapping for SigLIP."""
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.weight_mapper)
class LlamaNemotronVLForSequenceClassification(
LlamaNemotronVLForEmbedding, SupportsCrossEncoding
):
"""LlamaNemotronVL model variant for sequence classification / reranking."""
# Reranker checkpoint places base model weights under `model.*`,
# while `score.*` remains at the top level.
weight_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) | (
LlamaNemotronVLForEmbedding.weight_mapper
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
text_config = vllm_config.model_config.hf_config.get_text_config()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
self.score = ReplicatedLinear(
model_config.get_hidden_size(),
text_config.num_labels,
bias=False,
params_dtype=model_config.head_dtype,
quant_config=quant_config,
return_bias=False,
prefix=maybe_prefix(prefix, "score"),
)
pooler_config = model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_weights = super().load_weights(weights)
# reranker checkpoint omits the inner LM seq-cls head
# (`language_model.score.*`). It is unused by this outer model, but
# the default loader expects all parameters to be initialized.
for name, param in self.named_parameters():
if not name.startswith("language_model.score.") or name in loaded_weights:
continue
if name.endswith(".weight"):
torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
elif name.endswith(".bias"):
torch.nn.init.zeros_(param)
else:
torch.nn.init.normal_(param, mean=0.0, std=0.02)
loaded_weights.add(name)
return loaded_weights

View File

@@ -43,12 +43,9 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
IMAGE_TOKEN = "<image>"
IMAGE_PLACEHOLDER_ID = 151669
VIDEO_TOKEN = "<video>"
VIDEO_PLACEHOLDER_ID = 151670
INDICATOR_IDS = [151672, 151673, 151674, 151675]
IMAGE_PAD_TOKEN_ID = 151655
THINK_END_TOKEN_ID = 151668
class Ovis2_5ImagePatchInputs(TensorSchema):

View File

@@ -155,15 +155,30 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
if self.ctx.model_config.trust_remote_code:
# Defined in HF Hub repo
min_pixels_key = "min_pixels"
max_pixels_key = "max_pixels"
else:
# Defined in Transformers library (requires v5.0 or above)
min_pixels_key = "shortest_edge"
max_pixels_key = "longest_edge"
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
size = image_processor.size
if override_size := mm_kwargs.get("size"):
size = size | override_size
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
size = size | {min_pixels_key: override_min_pixels}
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
size = size | {max_pixels_key: override_max_pixels}
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=size["min_pixels"],
max_pixels=size["max_pixels"],
min_pixels=size[min_pixels_key],
max_pixels=size[max_pixels_key],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
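The override handling above is a plain dict union; which keys matter depends only on whether the remote-code processor (min_pixels/max_pixels) or the Transformers >= 5.0 processor (shortest_edge/longest_edge) is in use. A toy illustration with made-up pixel budgets (the real defaults come from the checkpoint's image processor):
# Keys used by the Transformers >= 5.0 image processor.
min_pixels_key, max_pixels_key = "shortest_edge", "longest_edge"
size = {"shortest_edge": 28 * 28 * 4, "longest_edge": 28 * 28 * 1280}  # processor defaults
mm_kwargs = {"max_pixels": 28 * 28 * 256}  # user override, e.g. via mm_processor_kwargs
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
    size = size | {max_pixels_key: override_max_pixels}
print(size["longest_edge"])  # 200704 -- the user-supplied budget wins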

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Modules below are used by the audio encoder component in models/nano_nemotron_vl.py.
"""
from collections.abc import Iterable
from dataclasses import asdict
import numpy as np
import torch
import torch.nn as nn
from transformers import ParakeetEncoder as HFParakeetEncoder
from transformers import ParakeetFeatureExtractor, PretrainedConfig
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
class ParakeetProjection(nn.Module):
def __init__(self, config: ParakeetConfig) -> None:
super().__init__()
sound_hidden_size = config.hidden_size
proj_hidden_size = config.projection_hidden_size
llm_hidden_size = config.llm_hidden_size
bias = config.projection_bias
self.norm = nn.LayerNorm(sound_hidden_size, eps=config.projection_eps)
self.linear1 = nn.Linear(sound_hidden_size, proj_hidden_size, bias=bias)
self.activation = ReLUSquaredActivation()
self.linear2 = nn.Linear(proj_hidden_size, llm_hidden_size, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
hidden_states = self.linear1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.linear2(hidden_states)
return hidden_states
class ProjectedParakeet(nn.Module):
def __init__(
self,
config: PretrainedConfig,
*,
dtype: torch.dtype,
llm_hidden_size: int,
max_model_len: int,
) -> None:
super().__init__()
self.config = ParakeetConfig.from_hf_config(
config, llm_hidden_size=llm_hidden_size, max_model_len=max_model_len
)
self.encoder = HFParakeetEncoder(self.config)
self.encoder = self.encoder.to(dtype)
self.projection = ParakeetProjection(self.config)
self.projection = self.projection.to(dtype)
def forward(
self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
outputs = self.encoder(
input_features=input_features, attention_mask=attention_mask
)
outputs = outputs.last_hidden_state
outputs = outputs.to(dtype=torch.bfloat16)
outputs = self.projection(outputs)
return outputs
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_params: set[str] = set()
params_dict = dict(self.named_parameters())
buffers_dict = dict(self.named_buffers())
if isinstance(weights, dict):
weights_list = list(weights.items())
else:
weights_list = list(weights)
for name, weight in weights_list:
if name.startswith("sound_encoder.encoder.feature_extractor."):
# Feature extractor buffers are handled outside the encoder.
continue
if name.startswith("sound_encoder."):
target_name = name[len("sound_encoder.") :]
elif name.startswith("sound_projection."):
target_name = f"projection.{name[len('sound_projection.') :]}"
else:
continue
target = params_dict.get(target_name)
if target is None:
target = buffers_dict.get(target_name)
if target is None:
raise ValueError(f"Unknown weight: {name}")
weight_loader = getattr(target, "weight_loader", default_weight_loader)
with torch.no_grad():
weight_loader(target, weight)
loaded_params.add(target_name)
return loaded_params
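The loader above only remaps checkpoint prefixes: sound_encoder.* names land on the encoder module, sound_projection.* names are rerouted under projection., and feature-extractor buffers are skipped. A tiny sketch of that rule on hypothetical checkpoint keys:
def remap(name: str) -> str | None:
    """Mirror the prefix handling in ProjectedParakeet.load_weights."""
    if name.startswith("sound_encoder.encoder.feature_extractor."):
        return None  # feature extractor buffers are handled outside the encoder
    if name.startswith("sound_encoder."):
        return name[len("sound_encoder."):]
    if name.startswith("sound_projection."):
        return "projection." + name[len("sound_projection."):]
    return None  # not a sound-related weight
for key in (
    "sound_encoder.encoder.layers.0.self_attn.q_proj.weight",  # hypothetical key
    "sound_projection.linear1.weight",
    "sound_encoder.encoder.feature_extractor.window",
):
    print(key, "->", remap(key))
# -> encoder.layers.0.self_attn.q_proj.weight
# -> projection.linear1.weight
# -> None (skipped)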
class ParakeetExtractor(ParakeetFeatureExtractor):
def __init__(self, config: PretrainedConfig) -> None:
self.config = ExtractorConfig.from_hf_config(config)
super().__init__(**asdict(self.config))
self._clip_target_samples = int(
round(self.config.clip_duration_s * self.sampling_rate)
)
self._tail_min_samples = int(
round(self.config.clip_min_duration_s * self.sampling_rate)
)
def _normalize_audio_length(self, audio_len: int) -> int:
# Match mcore's compute_params() logic for clip/min-duration handling.
target_len = max(audio_len, self._tail_min_samples)
tail_remainder = target_len % self._clip_target_samples
if 0 < tail_remainder < self._tail_min_samples:
padding = self._tail_min_samples - tail_remainder
target_len += padding
assert isinstance(target_len, int)
return target_len
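A worked example of the normalization above, with illustrative numbers (16 kHz sampling, 30 s clips, 1 s minimum tail); the real values come from ExtractorConfig, so treat these as placeholders:
sampling_rate = 16_000
clip_target = 30 * sampling_rate  # 480_000 samples per full clip
tail_min = 1 * sampling_rate      # 16_000 samples minimum tail
def normalize(audio_len: int) -> int:
    target = max(audio_len, tail_min)
    tail = target % clip_target
    if 0 < tail < tail_min:
        target += tail_min - tail  # bump a too-short tail up to the minimum
    return target
print(normalize(8_000))    # 16_000  -- shorter than the minimum tail, padded up
print(normalize(480_500))  # 496_000 -- a 500-sample tail becomes a full 1 s tail
print(normalize(480_000))  # 480_000 -- exact multiple of the clip length, unchanged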
def audio_token_count(self, audio_len: int) -> int:
audio_len = self._normalize_audio_length(audio_len)
num_frames = audio_len // self.hop_length
n_tokens = HFParakeetEncoder._get_subsampling_output_length(
self, torch.tensor([num_frames], dtype=torch.float)
)
return max(1, int(n_tokens.item()))
def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
padded = []
for p in raw_speech:
assert p.ndim == 1
audio_len = int(p.shape[0])
target_len = self._normalize_audio_length(audio_len)
p = np.pad(p, (0, target_len - audio_len))
padded.append(p)
return super().__call__(padded, *args, **kwargs)
def audio_length(self, audio_tokens: int) -> int:
return int(audio_tokens * self.config.subsampling_factor * self.hop_length)

View File

@@ -221,7 +221,8 @@ def sparsemixer(scores, jitter_eps=0.01):
multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
multiplier = multiplier.to(torch.float32)
selected_experts = selected_experts.to(torch.int32)
return (
multiplier,
selected_experts,

View File

@@ -83,15 +83,16 @@ from .vision import (
resolve_visual_encoder_outputs,
)
import ixformer.inference.functions as ixf
try:
# Note: vLLM does not install xformers by default.
from xformers import ops as xops
if current_platform.is_cuda() and current_platform.has_device_capability(100):
if current_platform.is_cuda():
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS = False
else:
USE_XFORMERS_OPS = True
else:
USE_XFORMERS_OPS = False
except ImportError:
USE_XFORMERS_OPS = False
@@ -698,23 +699,21 @@ class Attention(nn.Module):
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = q.reshape(batch, patches, self.n_heads, self.head_dim)
k = k.reshape(batch, patches, self.n_heads, self.head_dim)
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
if USE_XFORMERS_OPS:
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
v = v.reshape(batch * patches, self.n_heads, self.head_dim)
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
q = q.view(batch * patches, self.n_heads, self.head_dim)
k = k.view(batch * patches, self.n_heads, self.head_dim)
out = ixf.ixinfer_flash_attn_unpad(
q, k, v,
mask.q_seqinfo.seqstart.to(q.device),
mask.k_seqinfo.seqstart.to(q.device),
mask.q_seqinfo.max_seqlen,
mask.k_seqinfo.max_seqlen,
)
# out = memory_efficient_attention(q, k, v, attn_bias=mask)
else:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out = out.transpose(1, 2)
assert False, "xformers failed!"
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out)
class TransformerBlock(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()

View File

@@ -292,7 +292,7 @@ class QWenBaseModel(nn.Module):
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.wte(input_ids)

View File

@@ -122,8 +122,17 @@ def check_interleaved_audio_video(
"""
Check if video and audio positions are interleaved in the multimodal region.
Returns:
True if video and audio tokens are interleaved, False otherwise.
Returns True only for the use_audio_in_video=True case, where video and
audio tokens alternate within a single contiguous region with no gaps.
A simple range-overlap check produces false positives when multiple
non-interleaved requests are batched together: audio tokens from request N
fall between video tokens from request N and request N+1, making the
global ranges overlap even though each individual request is non-interleaved.
To distinguish true interleaving from this batching artefact, we require
that every position in the combined [first_VA, last_VA] range is occupied
by either a video or an audio token (no text/image gaps).
"""
if num_video == 0 or num_audio == 0:
return False
@@ -131,10 +140,22 @@ def check_interleaved_audio_video(
video_pos = is_video.nonzero(as_tuple=True)[0]
audio_pos = is_audio.nonzero(as_tuple=True)[0]
return (
# Quick range-overlap pre-check (necessary but not sufficient).
if not (
video_pos[0].item() < audio_pos[-1].item()
and audio_pos[0].item() < video_pos[-1].item()
)
):
return False
# Density check: for true use_audio_in_video interleaving every position
# in the combined span is a video or audio token. Batched non-interleaved
# requests have text/image tokens between the per-request V and A blocks.
# combined_start/end encompass all V/A tokens, so num_video + num_audio
# equals the number of V/A tokens in range; compare directly to span size.
combined_start = min(video_pos[0].item(), audio_pos[0].item())
combined_end = max(video_pos[-1].item(), audio_pos[-1].item())
total_in_range = combined_end - combined_start + 1
return (num_video + num_audio) == total_in_range
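A compact, self-contained illustration of the density check over boolean masks; the token layouts are fabricated for the example:
import torch
def is_interleaved(is_video: torch.Tensor, is_audio: torch.Tensor) -> bool:
    """Same density rule as check_interleaved_audio_video, for 1-D bool masks."""
    num_video, num_audio = int(is_video.sum()), int(is_audio.sum())
    if num_video == 0 or num_audio == 0:
        return False
    video_pos = is_video.nonzero(as_tuple=True)[0]
    audio_pos = is_audio.nonzero(as_tuple=True)[0]
    if not (video_pos[0] < audio_pos[-1] and audio_pos[0] < video_pos[-1]):
        return False
    start = min(int(video_pos[0]), int(audio_pos[0]))
    end = max(int(video_pos[-1]), int(audio_pos[-1]))
    return (num_video + num_audio) == (end - start + 1)
def masks(seq: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
    return (torch.tensor([t == "V" for t in seq]),
            torch.tensor([t == "A" for t in seq]))
# use_audio_in_video: video and audio alternate with no gaps -> interleaved.
print(is_interleaved(*masks(["T", "V", "A", "V", "A", "V", "A", "T"])))  # True
# Two batched non-interleaved requests: the global V/A ranges overlap, but the
# span contains text tokens, so the density check correctly rejects it.
print(is_interleaved(*masks(["V", "V", "A", "T", "T", "V", "V", "A"])))  # False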
def merge_interleaved_embeddings(
@@ -332,6 +353,39 @@ class Qwen2_5OmniThinkerProcessingInfo(
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None, "image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int] | None:
mm_counts = mm_counts or {}
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
mm_max_tokens: dict[str, int] = {}
if requested_modalities & {"image", "video"}:
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens.update(
{
m: vl_tokens[m]
for m in ["image", "video"]
if m in requested_modalities
}
)
if "audio" in requested_modalities:
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens["audio"] = audio_tokens["audio"]
return mm_max_tokens
class Qwen2_5OmniThinkerDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]
@@ -1376,23 +1430,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
from .utils import _merge_multimodal_embeddings
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if len(multimodal_embeddings) == 0:
return inputs_embeds
# Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region.
# in the multimodal region. Only use the interleaved path when
# needed; otherwise fall back to the default parent implementation.
video_token_id = self.config.video_token_index
audio_token_id = self.config.audio_token_index
@@ -1403,6 +1446,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
num_audio = is_audio.sum().item()
if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return merge_interleaved_embeddings(
inputs_embeds,
multimodal_embeddings,
@@ -1413,9 +1462,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
num_audio,
)
# Default: standard merge (no interleaving)
return _merge_multimodal_embeddings(
inputs_embeds, multimodal_embeddings, is_multimodal
# Default: standard merge (no interleaving), same as parent class
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
"""
type: Literal["pixel_values_videos"]
@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
TensorShape("nv"),
]
timestamps: list[list[float]] | None = None
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
"""
@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
"""
type: Literal["video_embeds"]
@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
torch.Tensor | None,
TensorShape("nv"),
] = None
timestamps: list[list[float]] | None = None
Qwen2_5_VLVideoInputs: TypeAlias = (
@@ -289,10 +296,11 @@ class Qwen2_5_VisionMLP(nn.Module):
disable_tp=use_data_parallel,
)
self.act_fn = act_fn
self.hidden_features = hidden_features
def forward(self, x: torch.Tensor):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.act_fn(gate_up, self.hidden_features)
x_down, _ = self.down_proj(x)
return x_down
@@ -357,6 +365,7 @@ class Qwen2_5_VisionAttention(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -398,6 +407,7 @@ class Qwen2_5_VisionAttention(nn.Module):
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
context_layer = einops.rearrange(
@@ -463,6 +473,7 @@ class Qwen2_5_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=None,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)

View File

@@ -179,6 +179,26 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
mm_counts = mm_counts or {}
if mm_counts.get("audio", 0) <= 0:
return {}
feature_extractor = self.get_feature_extractor()
chunk_length = min(feature_extractor.chunk_length, 30)
audio_len = int(chunk_length * feature_extractor.sampling_rate)
hop_length = feature_extractor.hop_length
max_mel_seq_len = audio_len // hop_length
input_lengths = torch.tensor([max_mel_seq_len], dtype=torch.long)
_, output_lengths = _get_feat_extract_output_lengths(input_lengths)
return {"audio": int(output_lengths.item())}
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:

View File

@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory(
"video", video_embed_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
)
return _qwen2vl_field_config
@@ -843,7 +844,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
temporal_patch_size = vision_config.temporal_patch_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
size = image_processor.size
if override_size := mm_kwargs.get("size"):
size = size | override_size
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
size = size | {"shortest_edge": override_min_pixels}
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
size = size | {"longest_edge": override_max_pixels}
if do_resize:
resized_height, resized_width = smart_resize(
@@ -930,7 +937,14 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_processor = self.get_image_processor()
mm_kwargs = self.ctx.get_merged_mm_kwargs({})
size = mm_kwargs.get("size", image_processor.size)
size = image_processor.size
if override_size := mm_kwargs.get("size"):
size = size | override_size
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
size = size | {"shortest_edge": override_min_pixels}
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
size = size | {"longest_edge": override_max_pixels}
max_pixels = size["longest_edge"]
unit = patch_size * merge_size

View File

@@ -51,7 +51,7 @@ from vllm.v1.attention.backend import AttentionType
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix, reparse_quant_config
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
logger = init_logger(__name__)
@@ -142,7 +142,6 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.qk_norm = RMSNormQK(self.head_dim, self.head_dim, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
@@ -150,17 +149,19 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# # Add qk-norm
# Add qk-norm
# q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
# q_by_head = self.q_norm(q_by_head)
# q_by_head = self.q_norm.forward_native(q_by_head) # TODO(gyf) check why
# q = q_by_head.view(q.shape)
# k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
# k_by_head = self.k_norm(k_by_head)
# k_by_head = self.k_norm.forward_native(k_by_head)
# k = k_by_head.view(k.shape)
q_by_head = q.view(*q.shape[:-1], q.shape[-1]//self.head_dim, self.head_dim)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = k.view(*k.shape[:-1],
k.shape[-1] // self.head_dim,
self.head_dim)
out_q, out_k = self.qk_norm(
q_by_head,
k_by_head,
@@ -170,7 +171,6 @@ class Qwen3Attention(nn.Module):
q = out_q.view(q.shape)
k = out_k.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
@@ -201,8 +201,6 @@ class Qwen3DecoderLayer(nn.Module):
else:
attn_type = AttentionType.ENCODER_ONLY
quant_config = reparse_quant_config(prefix, quant_config)
self.self_attn = Qwen3Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -236,23 +234,25 @@ class Qwen3DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
ALL_DECODER_LAYER_TYPES = {
"attention": Qwen3DecoderLayer,
}

View File

@@ -274,7 +274,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
self.ffn_layer_scale = torch.nn.Parameter(
@@ -282,7 +281,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
@@ -444,11 +442,7 @@ class Qwen3_5Model(Qwen3NextModel):
# qwen3.5 no need to transpose
# loaded_weight = loaded_weight.transpose(-1, -2)
if "experts.gate_up_proj" in name:
if loaded_weight.shape[-2] != 1:
chunk_dim = -2
else:
chunk_dim = -1
loaded_weight = loaded_weight.chunk(2, dim=chunk_dim)
loaded_weight = loaded_weight.chunk(2, dim=-2)
success_w1 = self.load_fused_expert_weights(
name_mapped,
params_dict,
@@ -544,6 +538,7 @@ class Qwen3_5ForCausalLMBase(
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
@@ -633,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
# Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning = False
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
@@ -648,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
# Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = False
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
@@ -698,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
return inputs_embeds
def recompute_mrope_positions(self, *args, **kwargs):
raise NotImplementedError(
"Qwen3.5 does not support multimodal pruning (EVS). "
"recompute_mrope_positions should never be called."
)
def forward(
self,
input_ids: torch.Tensor,
@@ -856,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
# Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = False
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(

View File

@@ -339,7 +339,7 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@@ -77,7 +77,7 @@ from .utils import (
)
logger = init_logger(__name__)
import ixformer.inference.functions as ixf_ops
class Qwen3MoeMLP(nn.Module):
def __init__(
@@ -170,7 +170,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
quant_config=None,
prefix=f"{prefix}.gate",
)
@@ -338,13 +338,14 @@ class Qwen3MoeAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
out_q, out_k = ixf_ops.rms_norm_qk(
q_by_head,
k_by_head,
self.q_norm.weight.data,
self.k_norm.weight.data,
self.q_norm.variance_epsilon,
)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q = out_q.view(q.shape)
k = out_k.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
@@ -379,6 +380,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
dual_chunk_attention_config=dual_chunk_attention_config,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsW8A8Int8,
)
if hasattr(self.self_attn.qkv_proj, "scheme") and isinstance(
self.self_attn.qkv_proj.scheme, CompressedTensorsW8A8Int8
):
self.fused_norm_quant = True
else:
self.fused_norm_quant = False
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = (
@@ -409,12 +416,23 @@ class Qwen3MoeDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.fused_norm_quant:
origin_input = hidden_states
hidden_states_i8, residual, scale = ixf_ops.residual_rms_norm_dynamic_int8(
hidden_states, self.input_layernorm.weight.data, residual,
eps=self.input_layernorm.variance_epsilon,
)
hidden_states = (hidden_states_i8, scale, hidden_states.dtype)
if residual is None:
residual = origin_input
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,

View File

@@ -10,6 +10,7 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
@@ -34,7 +35,8 @@ from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
)
from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
fused_recurrent_gated_delta_rule_packed_decode,
fused_sigmoid_gating_delta_rule_update,
)
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
@@ -114,7 +116,7 @@ def fi_chunk_gated_delta_rule(
beta: torch.Tensor,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
from flashinfer.gdn_prefill import (
@@ -153,21 +155,13 @@ def fi_chunk_gated_delta_rule(
class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None:
super().__init__()
# if current_platform.is_cuda() and current_platform.is_device_capability(90):
# logger.info_once(
# "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
# )
# self._forward_method = self.forward_cuda
# else:
# logger.info_once(
# "Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
# )
# self._forward_method = self.forward_native
logger.info_once(
"Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
)
self._forward_method = self.forward_native
if current_platform.is_cuda() and current_platform.is_device_capability(90):
logger.info_once(
"Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
)
self._forward_method = self.forward_cuda
else:
self._forward_method = self.forward_native
def forward_cuda(
self,
@@ -178,10 +172,10 @@ class ChunkGatedDeltaRule(CustomOp):
beta: torch.Tensor,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
return fi_chunk_gated_delta_rule(
return self.forward_native(
q=q,
k=k,
v=v,
@@ -202,7 +196,7 @@ class ChunkGatedDeltaRule(CustomOp):
beta: torch.Tensor,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
return fla_chunk_gated_delta_rule(
@@ -420,6 +414,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
prefix=f"{prefix}.in_proj_qkvz",
)
# ba_proj doesn't support blockwise fp8 quantization.
# in_proj_ba is defined as MergedColumnParallelLinear for
# compatibility with Qwen3_5.
self.in_proj_ba = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.num_v_heads] * 2,
@@ -469,7 +465,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
@@ -482,6 +477,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
self.enable_packed_recurrent_decode = (
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -631,6 +629,106 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
"""Warm up GDN prefill kernels during V1 profiling.
During V1 profile runs, ``_forward_core`` returns early because
``attn_metadata`` is ``None``, so the autotuned kernels used by
``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
``chunk_scaled_dot_kkt``) are never invoked. After profiling,
vLLM allocates KV cache using most of the remaining GPU memory.
When the first real inference triggers the autotuner it OOMs
because there is not enough memory left for benchmarking.
This method runs minimal forward passes through
``chunk_gated_delta_rule`` with small dummy tensors to force
autotuning while GPU memory is still plentiful. The autotuner
results are cached globally, so only the first layer incurs
actual benchmarking cost.
Most kernels use a fixed ``BT = chunk_size`` (64), but
``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence
length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT``
is part of its autotune key, we run warmup passes with T = 16,
32, and 64 to cover all possible ``BT`` values.
The decode path uses ``fused_sigmoid_gating_delta_rule_update``
which has fixed kernel parameters (no autotuning), so only the
prefill (chunked) path needs warming up.
"""
if hasattr(self, "_prefill_kernels_warmed_up"):
return
self._prefill_kernels_warmed_up = True
device = mixed_qkv.device
dtype = mixed_qkv.dtype
num_k_heads = self.num_k_heads // self.tp_size
num_v_heads = self.num_v_heads // self.tp_size
_, state_dtype = self.get_state_dtype()
# Run warmup for each possible BT value of chunk_fwd_kernel_o:
# T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
# Other kernels always use BT=chunk_size(64), so their autotune
# cache is populated on the first pass and reused thereafter.
for T in (16, 32, 64):
q = torch.randn(
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
)
k = torch.randn(
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
)
v = torch.randn(
1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
)
# NOTE: g and beta must have the same dtypes as during
# inference, so we construct them with the same function
# (fused_gdn_gating). dummy_a and dummy_b are throwaway
# inputs required by that function.
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
state = torch.zeros(
1,
num_v_heads,
self.head_v_dim,
self.head_k_dim,
device=device,
dtype=state_dtype,
)
cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
try:
self.chunk_gated_delta_rule(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=state,
output_final_state=True,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
except Exception:
logger.warning(
"GDN prefill kernel warmup (T=%d) failed for "
"layer %s. First inference may OOM due to "
"autotuner.",
T,
self.prefix,
exc_info=True,
)
else:
logger.debug(
"GDN prefill kernel warmup (T=%d) completed for layer %s",
T,
self.prefix,
)
finally:
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
torch.cuda.empty_cache()
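The three warmup lengths follow directly from the BT rule quoted in the docstring; a quick self-contained check of min(64, max(16, next_power_of_2(T))):
def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()
def bt_for(t: int) -> int:
    return min(64, max(16, next_power_of_2(t)))
for t in (1, 8, 16, 17, 32, 48, 64, 65, 4096):
    print(t, bt_for(t))
# T <= 16 -> BT 16, 17..32 -> BT 32, 33..64 -> BT 64, and longer sequences stay
# at 64, so warming up with T in {16, 32, 64} covers every BT the kernel can see.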
def _forward_core(
self,
mixed_qkv: torch.Tensor,
@@ -638,19 +736,34 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a: torch.Tensor,
core_attn_out: torch.Tensor,
):
"""
Core attention computation (called by custom op).
"""
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
# V1 profile run — warm up prefill kernels so that
# autotuning completes before KV cache allocation.
self._warmup_prefill_kernels(mixed_qkv)
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
if (
self.enable_packed_recurrent_decode
and attn_metadata.spec_sequence_masks is None
and attn_metadata.num_prefills == 0
and attn_metadata.num_decodes > 0
):
return self._forward_core_decode_non_spec(
mixed_qkv=mixed_qkv,
b=b,
a=a,
core_attn_out=core_attn_out,
attn_metadata=attn_metadata,
virtual_engine=forward_context.virtual_engine,
)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
@@ -738,41 +851,40 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_non_spec
)
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
if attn_metadata.num_prefills > 0:
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_non_spec = g
beta_non_spec = beta
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
g_non_spec = None
beta_non_spec = None
# 2. Recurrent attention
# 2.1: Process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
core_attn_out_spec, last_recurrent_state = (
fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log,
a=a,
b=b,
dt_bias=self.dt_bias,
q=query_spec,
k=key_spec,
v=value_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[
: attn_metadata.num_spec_decodes + 1
],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
)
)
else:
core_attn_out_spec, last_recurrent_state = None, None
@@ -801,12 +913,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log,
a=a,
b=b,
dt_bias=self.dt_bias,
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[
@@ -834,6 +948,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
def _forward_core_decode_non_spec(
self,
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
attn_metadata: GDNAttentionMetadata,
virtual_engine: int,
):
"""
Core attention computation with a packed non-spec decode fast path.
"""
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
validate_data=False,
)
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv=mixed_qkv_non_spec,
a=a,
b=b,
A_log=self.A_log,
dt_bias=self.dt_bias,
scale=self.head_k_dim**-0.5,
initial_state=ssm_state,
out=out_buf,
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
use_qk_l2norm_in_kernel=True,
)
return
class Qwen3NextAttention(nn.Module):
def __init__(
@@ -961,7 +1124,7 @@ class Qwen3NextDecoderLayer(nn.Module):
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
@@ -1024,7 +1187,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
self.ffn_layer_scale = torch.nn.Parameter(
@@ -1032,7 +1194,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
@@ -1299,6 +1460,8 @@ class QwenNextMixtureOfExperts(MixtureOfExperts):
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
layer.mlp, Qwen3NextSparseMoeBlock
):
@@ -1334,6 +1497,8 @@ class Qwen3NextForCausalLM(
"v_proj",
],
"gate_up_proj": ["gate_proj", "up_proj"],
"in_proj_qkvz": ["in_proj_qkvz"],
"in_proj_ba": ["in_proj_ba"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@@ -648,6 +648,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor | None, # Only used for Flash Attention
sequence_lengths: torch.Tensor | None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -655,6 +656,7 @@ class Qwen3_VisionBlock(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
x = x + self.mlp(self.norm2(x))
@@ -975,6 +977,20 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
# Recompute cu_seqlens in numpy from grid_thw to avoid GPU->CPU sync
grid_thw_np = grid_thw.cpu().numpy()
cu_seqlens_np = np.repeat(
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
).cumsum(axis=0, dtype=np.int32)
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
self.attn_backend, cu_seqlens_np
)
if sequence_lengths is not None:
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -985,6 +1001,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
if (
deepstack_visual_indexes is not None
@@ -1146,6 +1163,39 @@ class Qwen3OmniMoeThinkerProcessingInfo(
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None, "image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int] | None:
mm_counts = mm_counts or {}
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
mm_max_tokens: dict[str, int] = {}
if requested_modalities & {"image", "video"}:
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens.update(
{
m: vl_tokens[m]
for m in ["image", "video"]
if m in requested_modalities
}
)
if "audio" in requested_modalities:
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens["audio"] = audio_tokens["audio"]
return mm_max_tokens
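# Illustrative sketch of the method above (hypothetical counts, not from the
# diff): with mm_counts={"image": 1, "audio": 2}, only the requested
# modalities appear in the result, e.g. {"image": <max image tokens>,
# "audio": <max audio tokens>}; "video" is omitted because its count is 0.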
Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
@@ -1904,15 +1954,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
num_audio,
)
# Default: standard merge (no interleaving)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
# Default: standard merge (no interleaving), same as parent class.
# multimodal_embeddings may have been updated above (deepstack
# main-scale). Use super() to stay consistent with the parent
# implementation and avoid issues seen in Qwen2.5-Omni (#34506).
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor | None,

File diff suppressed because it is too large Load Diff

View File

@@ -24,6 +24,7 @@
# limitations under the License.
"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights."""
import typing
from collections.abc import Callable, Iterable
from itertools import islice
@@ -45,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.tokenizers.registry import cached_tokenizer_from_config
from .interfaces import MixtureOfExperts
from .qwen3_moe import (
@@ -415,6 +417,7 @@ class Qwen3VLMoeForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
@@ -451,14 +454,14 @@ class Qwen3VLMoeForConditionalGeneration(
with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config.with_hf_config(config.text_config),
vllm_config=vllm_config.with_hf_config(config.text_config, architectures=["Qwen3MoeForCausalLM"]),
prefix=maybe_prefix(prefix, "language_model"),
)
if not get_pp_group().is_first_rank and hasattr(
config.vision_config, "deepstack_visual_indexes"
):
assert self.language_model.start_layer >= len(
assert self.language_model.model.start_layer >= len(
config.vision_config.deepstack_visual_indexes
), (
"start_layer should be greater than or equal to "

View File

@@ -6,11 +6,9 @@
# Copyright (c) Alibaba Cloud.
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""
import copy
import math
import unicodedata
from collections.abc import Callable, Collection, Mapping, Sequence, Set
from functools import lru_cache, partial
from collections.abc import Callable, Mapping, Sequence
from functools import partial
from typing import Annotated, Literal, TypeAlias
import regex as re
@@ -436,60 +434,6 @@ class QwenVLModel(QWenModel):
)
@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
tokenizer: PreTrainedTokenizer,
) -> PreTrainedTokenizer:
"""
The logic of adding image pad tokens should only be applied in
[`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
so they are patched out here.
The definition of the wrapped tokenizer can be found here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
"""
new_tokenizer = copy.deepcopy(tokenizer)
class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore
def tokenize(
self,
text: str,
allowed_special: Set[str] | str = "all",
disallowed_special: Collection[str] | str = (),
**kwargs,
) -> list[bytes | str]:
text = unicodedata.normalize("NFC", text)
return [
self.decoder[t]
for t in self.tokenizer.encode(
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
]
def _decode(
self,
token_ids: int | list[int],
skip_special_tokens: bool = False,
errors: str | None = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
return self.tokenizer.decode(
token_ids,
errors=errors or self.errors,
)
TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad"
new_tokenizer.__class__ = TokenizerWithoutImagePad
return new_tokenizer
class QwenVLProcessor:
"""
This model doesn't define its own HF processor,
@@ -574,12 +518,6 @@ class QwenVLProcessor:
class QwenVLProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> PreTrainedTokenizer:
tokenizer = self.ctx.get_tokenizer()
assert isinstance(tokenizer, PreTrainedTokenizer)
return _get_tokenizer_without_image_pad(tokenizer)
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
return self.ctx.init_processor(
QwenVLProcessor,
@@ -730,6 +668,8 @@ class QwenVLForConditionalGeneration(
"w1",
],
}
embed_input_ids = SupportsMultiModal.embed_input_ids

View File

@@ -75,12 +75,14 @@ _TEXT_GENERATION_MODELS = {
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
# baichuan-7b, upper case 'C' in the class name
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
"BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
"BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
"BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
"BambaForCausalLM": ("bamba", "BambaForCausalLM"),
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
@@ -259,6 +261,10 @@ _EMBEDDING_MODELS = {
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
"SiglipModel": ("siglip", "SiglipEmbeddingModel"),
"LlamaNemotronVLModel": (
"nemotron_vl",
"LlamaNemotronVLForEmbedding",
),
# Technically Terratorch models work on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
@@ -278,6 +284,10 @@ _CROSS_ENCODER_MODELS = {
"llama",
"LlamaBidirectionalForSequenceClassification",
),
"LlamaNemotronVLForSequenceClassification": (
"nemotron_vl",
"LlamaNemotronVLForSequenceClassification",
),
"ModernBertForSequenceClassification": (
"modernbert",
"ModernBertForSequenceClassification",
@@ -331,6 +341,10 @@ _MULTIMODAL_MODELS = {
"ernie45_vl",
"Ernie4_5_VLMoeForConditionalGeneration",
),
"FireRedASR2ForConditionalGeneration": (
"fireredasr2",
"FireRedASR2ForConditionalGeneration",
),
"FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"), # noqa: E501
"FunAudioChatForConditionalGeneration": (
"funaudiochat",
@@ -506,6 +520,7 @@ _MULTIMODAL_MODELS = {
}
_SPECULATIVE_DECODING_MODELS = {
"ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),

View File

@@ -346,7 +346,7 @@ class Step3TextModel(nn.Module):
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.hidden_size
["hidden_states","residual"], config.hidden_size
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:

View File

@@ -69,7 +69,7 @@ class FP32ReplicatedLinear(ReplicatedLinear):
x: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
assert self.params_dtype == torch.float32
return super().forward(x.to(torch.float32))
return super().forward(x)
class Step3p5MLP(nn.Module):
@@ -148,6 +148,7 @@ class Step3p5Attention(nn.Module):
yarn_only_types: list = None,
swa_num_attention_heads: int | None = None,
partial_rotary_factor: float = 1.0,
total_layer_num: int | None = None,
):
super().__init__()
self.hidden_size = hidden_size
@@ -250,7 +251,12 @@ class Step3p5Attention(nn.Module):
prefix=f"{prefix}.attn",
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
)
extra_cache_para={
"total_num_kv_heads": self.total_num_kv_heads,
"total_layer_num": total_layer_num,
"layer_id": self.layer_idx,
},
)
self.max_position_embeddings = max_position
assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
@@ -390,6 +396,7 @@ class FusedMoEBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
fused_shared_output=True,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -413,8 +420,8 @@ class FusedMoEBlock(nn.Module):
assert shared_output is None
if self.share_expert is not None:
assert shared_output is not None
final_hidden_states += shared_output
if shared_output is not None:
final_hidden_states += shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
@@ -484,6 +491,7 @@ class Step3p5DecoderLayer(nn.Module):
if partial_rotary_factors
else 1.0,
prefix=f"{prefix}.self_attn",
total_layer_num=vllm_config.model_config.hf_config.num_hidden_layers,
)
else:
raise ValueError(
@@ -535,25 +543,28 @@ class Step3p5DecoderLayer(nn.Module):
return self.tp_group.all_reduce(in1 + in2)
def forward(
self, positions: torch.Tensor, hidden_states: torch.Tensor
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if residual is None:
residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states += residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if self.use_moe:
ffn_output = self.moe(hidden_states)
else:
ffn_output = self.mlp(hidden_states)
hidden_states = ffn_output + residual
return hidden_states
return ffn_output, residual
@support_torch_compile
@@ -592,7 +603,7 @@ class Step3p5Model(nn.Module):
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.hidden_size
["hidden_states", "residual"], config.hidden_size
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -610,20 +621,26 @@ class Step3p5Model(nn.Module):
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{
"hidden_states": hidden_states,
"residual": residual
}
)
hidden_states += residual
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -891,7 +908,6 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.model.norm(hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
@@ -934,7 +950,26 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
all_params = set(dict(self.named_parameters()).keys())
missing = all_params - loaded
kv_cache_params = {p for p in missing
if p.endswith((".k_scale", ".k_zero_point",
".q_scale", ".q_zero_point",
".v_scale", ".v_zero_point"))}
real_missing = missing - kv_cache_params
if real_missing:
logger.warning(
"[DIAG-WEIGHTS] Missing %d params: %s",
len(real_missing), sorted(real_missing),
)
else:
logger.info(
"[DIAG-WEIGHTS] All params loaded. (%d kv-cache params skipped)",
len(kv_cache_params),
)
return loaded
def get_spec_layer_idx_from_weight_name(

View File

@@ -6,9 +6,11 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name
from .utils import maybe_prefix
@@ -40,9 +43,11 @@ class SharedHead(nn.Module):
return self.norm(hidden_states)
@support_torch_compile
class Step3p5AMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str,
) -> None:
@@ -51,8 +56,15 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
quant_config = vllm_config.quant_config
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.eh_proj = ReplicatedLinear(
config.hidden_size * 2,
config.hidden_size,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.eh_proj",
)
self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer(
vllm_config,
prefix=f"{prefix}.mtp_block",
@@ -64,9 +76,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
@@ -74,7 +89,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
)
hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states)
hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, residual=None)
hidden_states = hidden_states + residual
return hidden_states
@@ -92,8 +108,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict(
{
str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config,
f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
@@ -112,14 +128,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
self.embed_tokens,
current_step_idx,
)
@@ -131,7 +146,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
)
return logits
@@ -139,7 +154,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
return self.embed_tokens(input_ids)
class Step3p5MTP(nn.Module):
class Step3p5MTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
@@ -257,12 +272,15 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".")
if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name:
assert (
hasattr(self.config, "num_nextn_predict_layers")
and self.config.num_nextn_predict_layers > 0
)
name = "model.embed_tokens.weight"
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
@@ -280,6 +298,31 @@ class Step3p5MTP(nn.Module):
and getattr(param, "requires_grad", False) is False
}
params_need_to_load -= optional_params
# Some attention quantization metadata may be absent in checkpoints
# depending on export path / runtime backend.
missing_params = params_need_to_load - loaded_params
skip_missing_suffixes = (
".attn.k_scale",
".attn.q_scale",
".attn.v_scale",
".attn.q_zero_point",
".attn.k_zero_point",
".attn.v_zero_point",
)
skipped_missing_params = sorted(
name for name in missing_params if name.endswith(skip_missing_suffixes)
)
if skipped_missing_params:
preview = tuple(skipped_missing_params[:10])
logger.warning_once(
"Step3p5MTP load_weights: skip %d missing optional attn quant "
"params with suffixes %s (showing first %d): %s",
len(skipped_missing_params),
skip_missing_suffixes,
len(preview),
preview,
)
params_need_to_load -= set(skipped_missing_params)
if params_need_to_load != loaded_params:
missing_params = list(params_need_to_load - loaded_params)
param_name_example = missing_params[0]

View File

@@ -1,500 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import torch.nn as nn
from transformers import SwinConfig
from transformers.models.swin.modeling_swin import SwinEmbeddings, SwinPatchMerging
from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer
from transformers.pytorch_utils import meshgrid
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
class SwinSelfAttention(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
num_heads: int,
window_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of "
f"attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (
window_size
if isinstance(window_size, Iterable)
else (window_size, window_size)
)
self.scale = self.attention_head_size**-0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.relative_position_index = nn.Parameter(
relative_position_index, requires_grad=False
)
self.qkv = QKVParallelLinear(
hidden_size=dim,
head_size=self.attention_head_size,
total_num_heads=self.num_attention_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def _get_rel_pos_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
return relative_position_bias.unsqueeze(0)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape
qkv_output, _ = self.qkv(hidden_states)
query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1)
key_layer = self.transpose_for_scores(key_layer)
value_layer = self.transpose_for_scores(value_layer)
query_layer = self.transpose_for_scores(query_layer)
attention_scores = self._get_rel_pos_bias()
if attention_mask is not None:
mask_shape = attention_mask.shape[0]
attention_mask_expanded = attention_mask.view(
1, mask_shape, 1, dim, dim
).expand(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
)
attention_scores = attention_scores + attention_mask_expanded.unsqueeze(
1
).unsqueeze(0)
attention_scores = attention_scores.view(
-1, self.num_attention_heads, dim, dim
)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_scores,
dropout_p=0.0,
)
attention_probs = None
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (
(context_layer, attention_probs) if output_attentions else (context_layer,)
)
return outputs
class SwinSelfOutput(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.dense = RowParallelLinear(
input_size=dim,
output_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
return hidden_states
class SwinAttention(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
num_heads: int,
window_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.self = SwinSelfAttention(
config,
dim,
num_heads,
window_size,
quant_config=quant_config,
prefix=f"{prefix}.self",
)
self.output = SwinSelfOutput(
config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
)
self.pruned_heads = set()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class SwinIntermediate(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.dense = ColumnParallelLinear(
dim,
int(config.mlp_ratio * dim),
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.intermediate_act_fn = get_act_fn(config.hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class SwinOutput(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.dense = RowParallelLinear(
int(config.mlp_ratio * dim),
dim,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
return hidden_states
class SwinLayer(HFSwinLayer):
def __init__(
self,
config: SwinConfig,
dim: int,
input_resolution: int,
num_heads: int,
drop_path_rate: float = 0.0,
shift_size: int = 0,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
drop_path_rate=drop_path_rate,
shift_size=shift_size,
)
self.attention = SwinAttention(
config,
dim,
num_heads,
window_size=self.window_size,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
self.intermediate = SwinIntermediate(
config, dim, quant_config=quant_config, prefix=f"{prefix}.intermediate"
)
self.output = SwinOutput(
config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
)
class SwinStage(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
input_resolution: int,
depth: int,
num_heads: int,
drop_path: list[float],
downsample: SwinPatchMerging | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
SwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
drop_path_rate=drop_path[layer_idx],
shift_size=0 if (layer_idx % 2 == 0) else config.window_size // 2,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, norm_layer=nn.LayerNorm
)
else:
self.downsample = None
self.pointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_outputs = layer_module(
hidden_states,
input_dimensions,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(
hidden_states_before_downsampling, input_dimensions
)
else:
output_dimensions = (height, width, height, width)
stage_outputs = (
hidden_states,
hidden_states_before_downsampling,
output_dimensions,
)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
class SwinEncoder(nn.Module):
def __init__(
self,
config: SwinConfig,
grid_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [
x.item()
for x in torch.linspace(
0, config.drop_path_rate, sum(config.depths), device="cpu"
)
]
self.layers = nn.ModuleList(
[
SwinStage(
config=config,
dim=int(config.embed_dim * 2**layer_idx),
input_resolution=(
grid_size[0] // (2**layer_idx),
grid_size[1] // (2**layer_idx),
),
depth=config.depths[layer_idx],
num_heads=config.num_heads[layer_idx],
drop_path=dpr[
sum(config.depths[:layer_idx]) : sum(
config.depths[: layer_idx + 1]
)
],
downsample=SwinPatchMerging
if (layer_idx < self.num_layers - 1)
else None,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(self.num_layers)
]
)
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers):
layer_outputs = layer_module(
hidden_states,
input_dimensions,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[2]
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
return hidden_states
class SwinModel(nn.Module):
config_class: SwinConfig
def __init__(
self,
config: SwinConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(
config,
self.embeddings.patch_grid,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
)
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
output_attentions=output_attentions,
)
return encoder_outputs
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv", "query", "q"),
("qkv", "key", "k"),
("qkv", "value", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@@ -27,6 +27,7 @@ import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.models.llama import (
LlamaDecoderLayer,
@@ -69,7 +70,7 @@ class TeleFLMForCausalLM(LlamaForCausalLM):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# mup
self.use_mup = self.config.use_mup
if self.use_mup:
if self.use_mup and get_pp_group().is_last_rank:
self.mup_scale_factor = self.config.mup_scale_factor
self.output_mult = self.config.output_mult / self.mup_scale_factor
logit_scale = self.output_mult

View File

@@ -300,14 +300,26 @@ class Base(
for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name)
# Populate Eagle3 attrs
if (
isinstance(module, nn.ModuleList)
and len(module) == self.text_config.num_hidden_layers
):
# Populate Eagle3 attrs
self._target_class = type(child_module)
layer_name = qual_name.removeprefix("model.")
self._layer_names[int(child_name)] = layer_name
# MTP weights should not be loaded into the base model
num_hidden_layers = self.text_config.num_hidden_layers
names = (
"n_predict", # Override from SpeculativeConfig
"num_nextn_predict_layers", # Most models
"mtp_num_hidden_layers", # Qwen 3.5
)
n_predict = getattr_iter(self.text_config, names, 0)
for i in range(num_hidden_layers, num_hidden_layers + n_predict):
mtp_prefix = f"{prefix}.{i}."
if mtp_prefix not in self.ignore_unexpected_prefixes:
self.ignore_unexpected_prefixes.append(mtp_prefix)
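# For example (hypothetical values): with num_hidden_layers=48, n_predict=1
# and this ModuleList living at prefix "model.layers", the checkpoint prefix
# "model.layers.48." is added to ignore_unexpected_prefixes, so the MTP
# layer's weights are skipped rather than loaded into the base model.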
# Replace modules as needed
if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name))

View File

@@ -218,7 +218,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
if "mm_token_type_ids" in processed_data
else "token_type_ids"
)
mm_token_type_ids = processed_data.pop(token_type_key)
mm_token_type_ids = processed_data.get(token_type_key)
# We can infer vLLM style placeholder from token type ids, if we split
# it for each input `mm_data`.
@@ -353,6 +353,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
num_image_patches = kwargs.pop("num_image_patches")
kwargs.pop("token_type_ids", None) # used only in `forward`
kwargs.pop("mm_token_type_ids", None) # used only in `model.get_rope_index`
if pixel_values is not None:
# ROCm: Force math SDP backend for vision encoder to avoid accuracy issues
@@ -443,6 +444,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
{
"image_grid_thw",
"video_grid_thw",
"mm_token_type_ids",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
@@ -451,7 +453,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
if any(
v
for k, v in kwargs.items()
if k not in {"image_grid_thw", "video_grid_thw"}
if k not in {"image_grid_thw", "mm_token_type_ids"}
):
raise NotImplementedError(
"Transformers modeling backend only supports images."
@@ -459,6 +461,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
mm_token_type_ids = kwargs.get("mm_token_type_ids")
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
@@ -467,10 +470,29 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
video_grid_thw
)
# In v4 `get_rope_index` doesn't take a `**kwargs` wildcard, so it
# can't accept arbitrary extra args, even when their values are `None`.
kwargs = {}
if mm_token_type_ids:
if not hasattr(self, "_get_rope_index_accepts_mm_token_type_ids"):
import inspect
sig = inspect.signature(self.model.get_rope_index)
params = sig.parameters
self._get_rope_index_accepts_mm_token_type_ids = (
"mm_token_type_ids" in params
or any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
)
)
if self._get_rope_index_accepts_mm_token_type_ids:
kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids)
mrope_positions, mrope_position_delta = self.model.get_rope_index(
input_ids=torch.tensor(input_tokens).unsqueeze(0),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
**kwargs,
)
mrope_positions = mrope_positions[:, 0]

View File

@@ -311,8 +311,9 @@ class AutoWeightsLoader:
continue
named_parameters = module.named_parameters(recurse=True)
desc_param_keys = {
base_prefix + k for k, _ in module.named_parameters(recurse=True)
maybe_prefix(base_prefix, k) for k, _ in named_parameters
}
msg = (
f"There is no module or parameter named {prefix!r} "
@@ -874,16 +875,3 @@ def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1
return feature_layer_index
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
def reparse_quant_config(prefix: str, quant_config):
ignore = getattr(quant_config, "ignore", None)
if not ignore:
return quant_config
if should_ignore_layer(prefix, ignore):
return None
return quant_config
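# Minimal usage sketch (an assumption for illustration, not part of this
# commit): a layer builder could call reparse_quant_config before creating a
# quantized linear, e.g.
#   layer_quant = reparse_quant_config("model.layers.0.self_attn.q_proj", quant_config)
#   # -> None if the compressed-tensors "ignore" list matches this prefix,
#   #    so the layer falls back to an unquantized implementation.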