Upgrade to vllm 0.17.0 corex v4.1 overlay
New files added in this commit (diffs suppressed because they are too large to display):

    vllm/model_executor/models/AXK1.py                 1168 lines
    vllm/model_executor/models/bailing_moe_linear.py   1246 lines
@@ -112,6 +112,42 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
        model_config.pooler_config.seq_pooling_type = pooling_type


class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    """Config handler for LlamaNemotronVL embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config

        # Set bidirectional attention on the language model config
        hf_config.is_causal = False
        if hasattr(hf_config, "llm_config"):
            hf_config.llm_config.is_causal = False

        if hasattr(hf_config, "vision_config"):
            hf_config.patch_size = hf_config.vision_config.patch_size

        # Set up pooling type
        pooling_type_map: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        # Get pooling type from config (check both top-level and llm_config)
        pooling = getattr(hf_config, "pooling", None)
        if pooling is None and hasattr(hf_config, "llm_config"):
            pooling = getattr(hf_config.llm_config, "pooling", "avg")

        pooling_type = pooling_type_map.get(pooling)
        if pooling_type is None:
            raise ValueError(f"pool_type {pooling!r} not supported")

        model_config.pooler_config.seq_pooling_type = pooling_type


class NomicBertModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
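The lookup added above has a subtle fallback order: top-level `pooling` first, then `llm_config.pooling`, with the `"avg"` default applied only when `llm_config` exists. The following is a minimal standalone sketch of just that lookup, using a dummy config object rather than vLLM types; it is illustrative and not part of this commit.

# Standalone sketch of the pooling lookup above (dummy config objects, not vLLM types).
from types import SimpleNamespace

POOLING_TYPE_MAP = {"avg": "MEAN", "cls": "CLS", "last": "LAST"}

def resolve_pooling_type(hf_config) -> str:
    # Mirror the order: top-level `pooling`, then `llm_config.pooling`, default "avg".
    pooling = getattr(hf_config, "pooling", None)
    if pooling is None and hasattr(hf_config, "llm_config"):
        pooling = getattr(hf_config.llm_config, "pooling", "avg")
    pooling_type = POOLING_TYPE_MAP.get(pooling)
    if pooling_type is None:
        raise ValueError(f"pool_type {pooling!r} not supported")
    return pooling_type

cfg = SimpleNamespace(llm_config=SimpleNamespace(pooling="cls"))
assert resolve_pooling_type(cfg) == "CLS"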
@@ -177,7 +213,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
-               "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
+               "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                max_model_len_before,
                model_config.max_model_len,
            )
@@ -293,6 +329,14 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
        }


class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        # Ernie4.5-VL conditionally executes text/vision MoE branches, so
        # fast_moe_cold_start can silently produce incorrect execution order.
        vllm_config.compilation_config.fast_moe_cold_start = False


class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:

@@ -553,7 +597,7 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
        if cache_config.cache_dtype.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
-       if cache_config.cache_dtype == "bfloat16":
+       if cache_config.cache_dtype == "auto" or cache_config.cache_dtype == "bfloat16":
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
@@ -619,11 +663,14 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "Gemma3TextModel": Gemma3TextModelConfig,
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
    "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig,  # noqa: E501
    "XLMRobertaModel": JinaRobertaModelConfig,
    "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
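For readers unfamiliar with this registry, the map is keyed by Hugging Face architecture name and consulted to pick a config-updater class. The following is only a generic sketch of that dispatch pattern with hypothetical names; the actual vLLM call sites are not shown in this hunk.

# Sketch of architecture-keyed config dispatch (hypothetical names, illustrative only).
class VerifyAndUpdateConfig:
    @staticmethod
    def verify_and_update_model_config(model_config) -> None:
        pass

class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config) -> None:
        model_config["pooling"] = "MEAN"

MODELS_CONFIG_MAP = {
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
}

def apply_arch_overrides(architecture: str, model_config: dict) -> None:
    handler = MODELS_CONFIG_MAP.get(architecture)
    if handler is not None:
        handler.verify_and_update_model_config(model_config)

cfg: dict = {}
apply_arch_overrides("LlamaNemotronVLModel", cfg)
assert cfg == {"pooling": "MEAN"}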
@@ -18,6 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionConfig

from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -263,9 +264,13 @@ class Block(nn.Module):
        return x


-class RelPosAttention(nn.Module):
+# --8<-- [start:rel_pos_attention]
+@PluggableLayer.register("rel_pos_attention")
+class RelPosAttention(PluggableLayer):
    """Multi-head Attention block with relative position embeddings."""

+    # --8<-- [end:rel_pos_attention]

    def __init__(
        self,
        dim: int,
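The decorator above registers RelPosAttention under a string key so an overlay can substitute a different implementation at construction time. PluggableLayer's real API is not shown in this hunk; the following is only a generic sketch of such a string-keyed registry, with hypothetical names.

# Generic sketch of a string-keyed layer registry (hypothetical; not vLLM's PluggableLayer API).
from typing import Callable, Dict, Type

class Registry:
    _layers: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def wrap(layer_cls: Type) -> Type:
            cls._layers[name] = layer_cls   # remember the class under its key
            return layer_cls
        return wrap

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._layers[name]

@Registry.register("rel_pos_attention")
class RelPosAttention:
    pass

assert Registry.get("rel_pos_attention") is RelPosAttention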
@@ -32,6 +32,7 @@ from .deepseek_v2 import (
    DeepseekV2MoE,
    get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP
from .utils import maybe_prefix

logger = init_logger(__name__)
@@ -180,7 +181,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):


@support_torch_compile
-class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
+class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
@@ -415,8 +416,141 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
|
||||
weight_loader(param, loaded_weight)
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
|
||||
# Validate that weights were loaded for each expected MTP layer.
|
||||
loaded_layers: set[int] = set()
|
||||
for param_name in loaded_params:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
|
||||
if spec_layer is not None:
|
||||
loaded_layers.add(spec_layer)
|
||||
for layer_idx in range(
|
||||
self.model.mtp_start_layer_idx,
|
||||
self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
|
||||
):
|
||||
if layer_idx not in loaded_layers:
|
||||
raise ValueError(
|
||||
f"MTP speculative decoding layer {layer_idx} weights "
|
||||
f"missing from checkpoint. The checkpoint may have "
|
||||
f"been quantized without including the MTP layers. "
|
||||
f"Use a checkpoint that includes MTP layer weights, "
|
||||
f"or disable speculative decoding."
|
||||
)
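The validation loop above guards against checkpoints (for example, quantized exports) that silently drop the MTP speculative-decoding layers. Reduced to its essentials, the check is set coverage over layer indices; below is a compact, hedged sketch where the `layer_of` callable stands in for get_spec_layer_idx_from_weight_name and is hypothetical.

# Sketch of the MTP-layer coverage check (hypothetical helper names, illustrative only).
from typing import Callable, Optional, Set

def check_mtp_layers(
    loaded_params: Set[str],
    layer_of: Callable[[str], Optional[int]],
    start_layer: int,
    num_layers: int,
) -> None:
    loaded_layers = {idx for p in loaded_params if (idx := layer_of(p)) is not None}
    missing = [i for i in range(start_layer, start_layer + num_layers) if i not in loaded_layers]
    if missing:
        raise ValueError(f"MTP layer weights missing from checkpoint for layers: {missing}")

# Example: layer index encoded as 'model.layers.<idx>.' in the parameter name.
layer_of = lambda name: int(name.split(".")[2]) if name.startswith("model.layers.") else None
check_mtp_layers({"model.layers.61.eh_proj.weight"}, layer_of, start_layer=61, num_layers=1)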
|
||||
|
||||
# Post-load optimization: fuse q_a_proj and kv_a_proj_with_mqa
|
||||
# into a single GEMM, then monkey-patch forward to forward_opt.
|
||||
# Same logic as DeepseekV2ForCausalLM.load_weights.
|
||||
opt_support_quant_method = [
|
||||
"GGUFLinearMethod", "UnquantizedLinearMethod",
|
||||
"CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod",
|
||||
]
|
||||
|
||||
def inject_layer(layer, quant_method, is_mla):
|
||||
logger.info(
|
||||
"DeepSeekMTP optimization: fused q_a_proj and kv_a_proj_with_mqa for layer '%s' (quant_method=%s, is_mla=%s). Forward replaced with forward_opt.",
|
||||
layer.__class__.__name__, quant_method, is_mla)
|
||||
q_lora_rank = getattr(layer, "q_lora_rank", None)
|
||||
if quant_method in ["UnquantizedLinearMethod",
|
||||
"CompressedTensorsW8A8Int8"]:
|
||||
if q_lora_rank is not None:
|
||||
layer.q_a_proj.weight.data = torch.cat(
|
||||
[layer.q_a_proj.weight,
|
||||
layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_a_proj, "weight_scale"):
|
||||
layer.q_a_proj.weight_scale.data = torch.cat(
|
||||
[layer.q_a_proj.weight_scale,
|
||||
layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat(
|
||||
[layer.q_proj.weight,
|
||||
layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_proj, "weight_scale"):
|
||||
layer.q_proj.weight_scale.data = torch.cat(
|
||||
[layer.q_proj.weight_scale,
|
||||
layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
else:
|
||||
return
|
||||
del layer.kv_a_proj_with_mqa.weight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
elif quant_method == "GGUFLinearMethod":
|
||||
pass
|
||||
elif quant_method == "AWQMarlinLinearMethod":
|
||||
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
|
||||
assert dtype == torch.int32
|
||||
if q_lora_rank is not None:
|
||||
layer.q_a_proj.qweight.data = torch.cat(
|
||||
[layer.q_a_proj.qweight,
|
||||
layer.kv_a_proj_with_mqa.qweight], dim=1)
|
||||
layer.q_a_proj.scales.data = torch.cat(
|
||||
[layer.q_a_proj.scales,
|
||||
layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_a_proj.qzeros.data = torch.cat(
|
||||
[layer.q_a_proj.qzeros,
|
||||
layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat(
|
||||
[layer.q_proj.weight,
|
||||
layer.kv_a_proj_with_mqa.weight], dim=1)
|
||||
layer.q_proj.scales.data = torch.cat(
|
||||
[layer.q_proj.scales,
|
||||
layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_proj.qzeros.data = torch.cat(
|
||||
[layer.q_proj.qzeros,
|
||||
layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
else:
|
||||
return
|
||||
del layer.kv_a_proj_with_mqa.qweight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
|
||||
for _, layer in self.model.named_modules():
|
||||
if layer.__class__.__name__ in [
|
||||
"DeepseekV2Attention", "DeepseekV2MLAAttention"
|
||||
]:
|
||||
if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
|
||||
quant_method = (
|
||||
layer.kv_a_proj_with_mqa.scheme.__class__.__name__)
|
||||
else:
|
||||
quant_method = (
|
||||
layer.kv_a_proj_with_mqa
|
||||
.quant_method.__class__.__name__)
|
||||
if quant_method not in opt_support_quant_method:
|
||||
break
|
||||
inject_layer(
|
||||
layer, quant_method,
|
||||
is_mla=(layer.__class__.__name__
|
||||
== "DeepseekV2MLAAttention"))
|
||||
|
||||
# Check if all parameters have been loaded
|
||||
all_params = set(params_dict.keys())
|
||||
not_loaded = all_params - loaded_params
|
||||
if not_loaded:
|
||||
logger.warning(
|
||||
"DeepSeekMTP weight loading: %d parameters were NOT loaded.\n%s",
|
||||
len(not_loaded),
|
||||
"\n".join(sorted(not_loaded)),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"DeepSeekMTP weight loading: All %d parameters loaded successfully.",
|
||||
len(all_params),
|
||||
)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
"""
|
||||
Rewrite the weight name to match the format of the original model.
|
||||
|
||||
@@ -47,7 +47,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -75,7 +75,9 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerBackend,
|
||||
)
|
||||
@@ -89,6 +91,7 @@ from .utils import (
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -221,73 +224,6 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DeepSeekV2Gate(ReplicatedLinear):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
n_experts: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
assert quant_config is None
|
||||
super().__init__(
|
||||
hidden_size,
|
||||
n_experts,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
# Unquantized only, will be called "weight".
|
||||
assert hasattr(self, "weight")
|
||||
is_hopper_or_blackwell = current_platform.is_device_capability(
|
||||
(9, 0)
|
||||
) or current_platform.is_device_capability_family(100)
|
||||
SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
self.allow_dsv3_router_gemm = (
|
||||
current_platform.is_cuda()
|
||||
and is_hopper_or_blackwell
|
||||
and n_experts in SUPPORTED_NUM_EXPERTS
|
||||
and hidden_size in SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
self._out_dtype: torch.dtype | None = None
|
||||
|
||||
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
|
||||
"""
|
||||
Set out dtype for the router logits. This is needed after
|
||||
__init__, b/c we need to check if the trtllm kernel is
|
||||
selected before we decide between bf16 and fp32.
|
||||
"""
|
||||
|
||||
if self._out_dtype is not None:
|
||||
raise ValueError("out_dtype has already been set")
|
||||
else:
|
||||
self._out_dtype = out_dtype
|
||||
|
||||
@property
|
||||
def out_dtype(self) -> torch.dtype:
|
||||
if self._out_dtype is None:
|
||||
raise ValueError("out_dtype has not been set yet")
|
||||
return self._out_dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
"""
|
||||
Use specialized GEMM for low batch size for DSV3 and KIMI.
|
||||
"""
|
||||
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
|
||||
return ops.dsv3_router_gemm(
|
||||
hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
|
||||
), None
|
||||
else:
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -316,23 +252,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
|
||||
# self.gate = DeepSeekV2Gate(
|
||||
# config.hidden_size,
|
||||
# config.n_routed_experts,
|
||||
# quant_config=None,
|
||||
# prefix=f"{prefix}.gate",
|
||||
# )
|
||||
self.gate = ReplicatedLinear(
|
||||
self.gate = GateLinear(
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
if getattr(config, "topk_method", None) == "noaux_tc":
|
||||
# self.gate.e_score_correction_bias = nn.Parameter(
|
||||
# torch.empty(config.n_routed_experts, dtype=torch.float32)
|
||||
# )
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts)
|
||||
)
|
||||
@@ -401,12 +326,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
else None,
|
||||
)
|
||||
|
||||
# # NOTE(rob): this is a hack until we finish off the PR for
|
||||
# # merging TRTLLM kernels into the MK framework. Then we can
|
||||
# # query the MonolithicMK for the expected router logits.
|
||||
# self.gate.set_out_dtype(
|
||||
# torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
|
||||
# )
|
||||
# NOTE(rob): this is a hack until we finish off the PR for
|
||||
# merging TRTLLM kernels into the MK framework. Then we can
|
||||
# query the MonolithicMK for the expected router logits.
|
||||
self.gate.set_out_dtype(
|
||||
torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
@@ -443,11 +368,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
elif self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
shared_output *= 1.0 / self.routed_scaling_factor
|
||||
|
||||
|
||||
if self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
final_hidden_states += shared_output
|
||||
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0
|
||||
@@ -596,7 +522,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -605,23 +531,20 @@ class DeepseekV2Attention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = latent_cache.unsqueeze(1)
|
||||
q = self.q_proj(hidden_states)[0]
|
||||
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
|
||||
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
|
||||
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
kv_a = self.kv_a_layernorm(kv_a)
|
||||
kv = self.kv_b_proj(kv_a)[0]
|
||||
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = latent_cache[:, :, self.kv_lora_rank :]
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q[..., self.qk_nope_head_dim :] = q_pe
|
||||
k = torch.empty_like(q)
|
||||
@@ -671,7 +594,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
|
||||
|
||||
def get_attn_backend(self) -> AttentionBackend:
|
||||
return DeepseekV32IndexerBackend
|
||||
|
||||
|
||||
|
||||
class Indexer(nn.Module):
|
||||
def __init__(
|
||||
@@ -727,8 +650,8 @@ class Indexer(nn.Module):
|
||||
# where we store value in fp8 and scale in fp32
|
||||
# per self.quant_block_size element
|
||||
self.k_cache = DeepseekV32IndexerCache(
|
||||
head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
|
||||
dtype=torch.uint8,
|
||||
head_dim=self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
prefix=f"{prefix}.k_cache",
|
||||
cache_config=cache_config,
|
||||
)
|
||||
@@ -776,23 +699,61 @@ class Indexer(nn.Module):
|
||||
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
|
||||
|
||||
# we only quant q here since k quant is fused with cache insertion
|
||||
q = q.view(-1, self.head_dim)
|
||||
q_fp8, q_scale = per_token_group_quant_fp8(
|
||||
q,
|
||||
self.quant_block_size,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=self.scale_fmt is not None,
|
||||
)
|
||||
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
|
||||
q_scale = q_scale.view(-1, self.n_head, 1)
|
||||
# q = q.view(-1, self.head_dim)
|
||||
# q_fp8, q_scale = per_token_group_quant_fp8(
|
||||
# q,
|
||||
# self.quant_block_size,
|
||||
# column_major_scales=False,
|
||||
# use_ue8m0=self.scale_fmt is not None,
|
||||
# )
|
||||
# q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
|
||||
# q_scale = q_scale.view(-1, self.n_head, 1)
|
||||
|
||||
weights, _ = self.weights_proj(hidden_states)
|
||||
weights = (
|
||||
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
|
||||
weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5
|
||||
)
|
||||
weights = weights.squeeze(-1)
|
||||
|
||||
return self.indexer_op(hidden_states, q_fp8, k, weights)
|
||||
return self.indexer_op(hidden_states, q, k, weights)
|
||||
|
||||
|
||||
def _min_latency_fused_qkv_a_proj_impl(
|
||||
input_: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Dynamically run min-latency gemm if num_tokens <= 16.
|
||||
This must be wrapped in a custom op because our torch.compile integration
|
||||
does not support runtime dispatching on num_tokens.
|
||||
"""
|
||||
num_tokens = input_.shape[0]
|
||||
if 0 < num_tokens <= 16:
|
||||
output = torch.empty(
|
||||
num_tokens,
|
||||
weight.shape[0],
|
||||
dtype=torch.bfloat16,
|
||||
device=input_.device,
|
||||
)
|
||||
ops.dsv3_fused_a_gemm(output, input_, weight.T)
|
||||
return output
|
||||
else:
|
||||
return torch.nn.functional.linear(input_, weight)
|
||||
|
||||
|
||||
def _min_latency_fused_qkv_a_proj_fake(
|
||||
input_: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return input_.new_empty(input_.shape[0], weight.shape[0])
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="min_latency_fused_qkv_a_proj",
|
||||
op_func=_min_latency_fused_qkv_a_proj_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_min_latency_fused_qkv_a_proj_fake,
|
||||
)
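The registration above goes through vLLM's direct_register_custom_op helper so that the runtime dispatch on num_tokens stays hidden from torch.compile. As a rough standalone analogue (assuming PyTorch 2.4+'s torch.library.custom_op; this is not vLLM's exact helper or kernel), the pattern of pairing an eager implementation with a fake (meta) implementation looks like this:

# Generic sketch of an eager impl plus a fake impl for shape tracing (illustrative only).
import torch

@torch.library.custom_op("demo::low_batch_linear", mutates_args=())
def low_batch_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Runtime dispatch on the token count happens inside the op, invisible to the compiler.
    if 0 < x.shape[0] <= 16:
        # Stand-in for a specialized low-batch GEMM; a real kernel would go here.
        return torch.nn.functional.linear(x, weight)
    return torch.nn.functional.linear(x, weight)

@low_batch_linear.register_fake
def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Only shapes/dtypes matter for tracing; no computation is performed.
    return x.new_empty(x.shape[0], weight.shape[0])

x = torch.randn(4, 8)
w = torch.randn(16, 8)
print(low_batch_linear(x, w).shape)  # torch.Size([4, 16])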
|
||||
|
||||
|
||||
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
|
||||
@@ -830,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
|
||||
self,
|
||||
input_,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
|
||||
num_tokens = input_.shape[0]
|
||||
if self._use_min_latency_gemm and (0 < num_tokens <= 16):
|
||||
output = torch.empty(
|
||||
num_tokens,
|
||||
2112,
|
||||
dtype=torch.bfloat16,
|
||||
device=input_.device,
|
||||
)
|
||||
ops.dsv3_fused_a_gemm(
|
||||
output,
|
||||
input_,
|
||||
self.weight.T,
|
||||
)
|
||||
if self._use_min_latency_gemm:
|
||||
output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
|
||||
if not self.return_bias:
|
||||
return output
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
@@ -898,47 +848,35 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
# self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
|
||||
# self.hidden_size,
|
||||
# [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.fused_qkv_a_proj",
|
||||
# )
|
||||
self.fused_qkv_a_proj = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_qkv_a_proj",
|
||||
disable_tp=True,
|
||||
)
|
||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_a_proj")
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj")
|
||||
else:
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa",
|
||||
)
|
||||
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj",
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
@@ -1005,9 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
rotary_emb=self.rotary_emb,
|
||||
o_proj=self.o_proj,
|
||||
fused_qkv_a_proj=self.fused_qkv_a_proj
|
||||
if self.q_lora_rank is not None
|
||||
else None,
|
||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
|
||||
if self.q_lora_rank is None
|
||||
else None,
|
||||
@@ -1346,14 +1282,14 @@ class DeepseekV2ForCausalLM(
|
||||
# initializing DeepseekV2Model, as it is passed inplace to
|
||||
# quantization config init and may be used to select the
|
||||
# quant_method for relevant layers during initialization.
|
||||
self.fuse_qkv_a_proj = (
|
||||
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
|
||||
)
|
||||
if self.fuse_qkv_a_proj:
|
||||
self.packed_modules_mapping["fused_qkv_a_proj"] = [
|
||||
"q_a_proj",
|
||||
"kv_a_proj_with_mqa",
|
||||
]
|
||||
# self.fuse_qkv_a_proj = (
|
||||
# hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
|
||||
# )
|
||||
# if self.fuse_qkv_a_proj:
|
||||
# self.packed_modules_mapping["fused_qkv_a_proj"] = [
|
||||
# "q_a_proj",
|
||||
# "kv_a_proj_with_mqa",
|
||||
# ]
|
||||
|
||||
self.model = self.model_cls(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
@@ -1385,19 +1321,19 @@ class DeepseekV2ForCausalLM(
|
||||
self.moe_layers = []
|
||||
self.moe_mlp_layers = []
|
||||
example_moe = None
|
||||
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, DeepseekV2DecoderLayer)
|
||||
if isinstance(layer.mlp, DeepseekV2MoE):
|
||||
# Pick last one layer since the first ones may be dense layers.
|
||||
example_moe = layer.mlp
|
||||
self.moe_mlp_layers.append(layer.mlp)
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
@@ -1441,10 +1377,10 @@ class DeepseekV2ForCausalLM(
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
mla_params_mapping = [
|
||||
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
]
|
||||
# mla_params_mapping = [
|
||||
# ("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
# ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
# ]
|
||||
mha_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
@@ -1452,8 +1388,8 @@ class DeepseekV2ForCausalLM(
|
||||
]
|
||||
if self.use_mha:
|
||||
stacked_params_mapping.extend(mha_params_mapping)
|
||||
else:
|
||||
stacked_params_mapping.extend(mla_params_mapping)
|
||||
# else:
|
||||
# stacked_params_mapping.extend(mla_params_mapping)
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
@@ -1474,168 +1410,232 @@ class DeepseekV2ForCausalLM(
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
is_fusion_moe_shared_experts_layer = (
|
||||
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# QKV fusion is optional, fall back to normal
|
||||
# weight loading if it's not enabled
|
||||
# if go with fusion option, then update name
|
||||
if (
|
||||
param_name == "fused_qkv_a_proj"
|
||||
) and name_mapped not in params_dict:
|
||||
continue
|
||||
else:
|
||||
name = name_mapped
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
try:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
|
||||
# Special handling: when AITER fusion_shared_experts is enabled,
|
||||
# checkpoints may provide a single widened shared_experts tensor
|
||||
# without explicit expert indices
|
||||
# (e.g. ...mlp.shared_experts.gate_proj.weight).
|
||||
# For models with multiple shared experts, split that tensor
|
||||
# evenly into per-shared-expert slices and load them into
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
# down: RowParallel → split along dim 1
|
||||
split_dim = (
|
||||
1
|
||||
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
|
||||
else 0
|
||||
)
|
||||
total = loaded_weight.shape[split_dim]
|
||||
assert total % num_chunks == 0, (
|
||||
f"Shared expert weight dim {total} "
|
||||
f"not divisible by num_chunks {num_chunks}"
|
||||
)
|
||||
chunk_size = total // num_chunks
|
||||
|
||||
for j in range(num_chunks):
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
is_fusion_moe_shared_experts_layer = (
|
||||
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
|
||||
if loaded_weight.ndim == 1:
|
||||
weight_to_load = loaded_weight[chunk_slice]
|
||||
elif split_dim == 0:
|
||||
weight_to_load = loaded_weight[chunk_slice, :]
|
||||
else:
|
||||
weight_to_load = loaded_weight[:, chunk_slice]
|
||||
# Synthesize an expert-style name so expert mapping
|
||||
# can route it
|
||||
chunk_name = name.replace(
|
||||
"mlp.shared_experts",
|
||||
f"mlp.experts.{self.config.n_routed_experts + j}",
|
||||
)
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# Use expert_params_mapping to locate the destination
|
||||
# param and delegate to its expert-aware weight_loader
|
||||
# with expert_id.
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in chunk_name:
|
||||
continue
|
||||
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = chunk_name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# other available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
weight_to_load,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
break
|
||||
# QKV fusion is optional, fall back to normal
|
||||
# weight loading if it's not enabled
|
||||
# if go with fusion option, then update name
|
||||
if (
|
||||
param_name == "fused_qkv_a_proj"
|
||||
) and name_mapped not in params_dict:
|
||||
continue
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
name = name_mapped
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
# Special handling: when AITER fusion_shared_experts is enabled,
|
||||
# checkpoints may provide a single widened shared_experts tensor
|
||||
# without explicit expert indices
|
||||
# (e.g. ...mlp.shared_experts.gate_proj.weight).
|
||||
# For models with multiple shared experts, split that tensor
|
||||
# evenly into per-shared-expert slices and load them into
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
# down: RowParallel → split along dim 1
|
||||
split_dim = (
|
||||
1
|
||||
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
|
||||
else 0
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name is not None and not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
total = loaded_weight.shape[split_dim]
|
||||
assert total % num_chunks == 0, (
|
||||
f"Shared expert weight dim {total} "
|
||||
f"not divisible by num_chunks {num_chunks}"
|
||||
)
|
||||
chunk_size = total // num_chunks
|
||||
|
||||
for j in range(num_chunks):
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
|
||||
if loaded_weight.ndim == 1:
|
||||
weight_to_load = loaded_weight[chunk_slice]
|
||||
elif split_dim == 0:
|
||||
weight_to_load = loaded_weight[chunk_slice, :]
|
||||
else:
|
||||
weight_to_load = loaded_weight[:, chunk_slice]
|
||||
# Synthesize an expert-style name so expert mapping
|
||||
# can route it
|
||||
chunk_name = name.replace(
|
||||
"mlp.shared_experts",
|
||||
f"mlp.experts.{self.config.n_routed_experts + j}",
|
||||
)
|
||||
|
||||
# Use expert_params_mapping to locate the destination
|
||||
# param and delegate to its expert-aware weight_loader
|
||||
# with expert_id.
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in chunk_name:
|
||||
continue
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = chunk_name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# other available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
weight_to_load,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name is not None and not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
except:
|
||||
pass
|
||||
|
||||
opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"]
|
||||
# add your opt here..
|
||||
def inject_layer(layer, quant_method, is_mla):
|
||||
q_lora_rank = getattr(layer, "q_lora_rank", None)
|
||||
if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
|
||||
if q_lora_rank is not None:
|
||||
layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_a_proj, "weight_scale"):
|
||||
layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_proj, "weight_scale"):
|
||||
layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
else:
|
||||
return
|
||||
del layer.kv_a_proj_with_mqa.weight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
elif quant_method == "GGUFLinearMethod":
|
||||
pass
|
||||
elif quant_method == "AWQMarlinLinearMethod":
|
||||
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
|
||||
assert dtype == torch.int32
|
||||
if layer.q_lora_rank is not None:
|
||||
layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
|
||||
layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1)
|
||||
layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
else:
|
||||
return
|
||||
|
||||
del layer.kv_a_proj_with_mqa.qweight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
else:
|
||||
pass
|
||||
|
||||
for _, layer in self.model.named_modules():
|
||||
if layer.__class__.__name__ in ["DeepseekV2Attention","DeepseekV2MLAAttention"]:
|
||||
if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
|
||||
quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
|
||||
else:
|
||||
quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
|
||||
if quant_method not in opt_support_quant_method:
|
||||
break
|
||||
|
||||
inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "DeepseekV2MLAAttention")
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
|
||||
@@ -164,7 +164,7 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
config.hidden_size,
|
||||
config.moe_num_experts,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
# params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
@@ -209,7 +209,7 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
@@ -429,7 +429,8 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
@@ -653,11 +654,11 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
|
||||
@@ -829,16 +829,31 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
|
||||
spatial_conv_size = hf_config.spatial_conv_size
|
||||
temporal_conv_size = hf_config.temporal_conv_size
|
||||
|
||||
if self.ctx.model_config.trust_remote_code:
|
||||
# Defined in HF Hub repo
|
||||
min_pixels_key = "min_pixels"
|
||||
max_pixels_key = "max_pixels"
|
||||
else:
|
||||
# Defined in Transformers library (requires v5.0 or above)
|
||||
min_pixels_key = "shortest_edge"
|
||||
max_pixels_key = "longest_edge"
|
||||
|
||||
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
|
||||
size = mm_kwargs.get("size", image_processor.size)
|
||||
size = image_processor.size
|
||||
if override_size := mm_kwargs.get("size"):
|
||||
size = size | override_size
|
||||
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
|
||||
size = size | {min_pixels_key: override_min_pixels}
|
||||
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
|
||||
size = size | {max_pixels_key: override_max_pixels}
|
||||
|
||||
if do_resize:
|
||||
resized_height, resized_width = smart_resize(
|
||||
height=image_height,
|
||||
width=image_width,
|
||||
factor=patch_size * spatial_conv_size,
|
||||
min_pixels=size["min_pixels"],
|
||||
max_pixels=size["max_pixels"],
|
||||
min_pixels=size[min_pixels_key],
|
||||
max_pixels=size[max_pixels_key],
|
||||
)
|
||||
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
|
||||
else:
|
||||
|
||||
New file (394 lines):

    vllm/model_executor/models/extract_hidden_states.py
@@ -0,0 +1,394 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Hidden States Extractor Model.
|
||||
|
||||
This model extracts and caches hidden states from the target model
|
||||
without performing actual token generation. It's used with the
|
||||
extract_hidden_states speculative decoding method.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.attention.attention import set_default_quant_scales
|
||||
from vllm.model_executor.layers.attention.kv_transfer_utils import (
|
||||
maybe_transfer_kv_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheSpec,
|
||||
MLAAttentionSpec,
|
||||
)
|
||||
|
||||
########## Custom Ops ########
|
||||
|
||||
|
||||
def unified_kv_cache_update(
|
||||
to_cache: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns a dummy that is passed to unified_attention to signal a side effect and
|
||||
the data dependency between them to ensure torch.compile preserves ordering.
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
)
|
||||
layer_slot_mapping = slot_mapping.get(layer_name)
|
||||
if layer_slot_mapping is not None:
|
||||
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
|
||||
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
|
||||
)
|
||||
attn_layer.impl.do_kv_cache_update(
|
||||
attn_layer,
|
||||
to_cache,
|
||||
kv_cache,
|
||||
layer_slot_mapping,
|
||||
)
|
||||
|
||||
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
def dummy_attention(layer_name, _placeholder):
|
||||
# Note: layer_name arg required by @maybe_transfer_kv_layer
|
||||
return _placeholder
|
||||
|
||||
|
||||
def basic_cache(
|
||||
to_cache: torch.Tensor,  # shape: [num_tokens, num_heads, head_size]
|
||||
kv_cache: torch.Tensor,  # shape: [num_blocks, block_size, num_heads, head_size]
|
||||
slot_mapping: torch.Tensor,  # shape: [num_tokens]
|
||||
):
|
||||
num_blocks, block_size, num_heads, head_size = kv_cache.shape
|
||||
token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size)
|
||||
token_kv_cache[slot_mapping] = to_cache
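basic_cache flattens the paged cache into one long token dimension and scatters the new entries by physical slot index. A small self-contained demonstration with toy shapes (illustrative only, independent of vLLM):

# Toy demonstration of the slot-mapping write used by basic_cache.
import torch

num_blocks, block_size, num_heads, head_size = 2, 4, 1, 3
kv_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

# Three new tokens land in physical slots 1, 5 and 6 (block 0 slot 1, block 1 slots 1-2).
to_cache = torch.arange(9, dtype=torch.float32).view(3, num_heads, head_size)
slot_mapping = torch.tensor([1, 5, 6])

flat = kv_cache.view(num_blocks * block_size, num_heads, head_size)  # shares storage
flat[slot_mapping] = to_cache  # same indexing as token_kv_cache[slot_mapping] = to_cache

assert torch.equal(kv_cache[0, 1, 0], torch.tensor([0.0, 1.0, 2.0]))
assert torch.equal(kv_cache[1, 1, 0], torch.tensor([3.0, 4.0, 5.0]))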
|
||||
|
||||
|
||||
######### CacheOnlyAttentionBackend ########
|
||||
|
||||
|
||||
class CacheOnlyAttentionBackend(AttentionBackend):
|
||||
"""Attention backend that only caches KV without computing attention."""
|
||||
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CACHE_ONLY_ATTN"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
return attn_type == AttentionType.DECODER
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
|
||||
return CacheOnlyAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
# We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size`
|
||||
# We also don't use a k/v (2) dim
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
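Because num_kv_heads is repurposed here as the number of auxiliary hidden states and head_size as the model hidden size, the per-block footprint follows directly from this shape. A rough sizing sketch with illustrative numbers (not taken from this commit):

# Illustrative sizing arithmetic for the cache layout above:
#   bytes per block = block_size * num_kv_heads * head_size * dtype_bytes
def block_bytes(block_size: int, num_kv_heads: int, head_size: int, dtype_bytes: int = 2) -> int:
    return block_size * num_kv_heads * head_size * dtype_bytes

# e.g. block_size=16, 3 aux hidden states, hidden_size=4096, bfloat16 (2 bytes per element)
print(block_bytes(16, 3, 4096))  # 393216 bytes per block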
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
|
||||
return CacheOnlyAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return []
|
||||
|
||||
|
||||
class CacheOnlyAttentionMetadata:
|
||||
def __init__(self, slot_mapping: torch.Tensor):
|
||||
self.slot_mapping = slot_mapping
|
||||
|
||||
|
||||
class CacheOnlyAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> CacheOnlyAttentionMetadata:
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
raise NotImplementedError(
|
||||
"Cascade attention not supported by CacheOnlyAttention"
|
||||
)
|
||||
causal = common_attn_metadata.causal
|
||||
if not causal:
|
||||
raise NotImplementedError(
|
||||
"Non-causal attention not supported by CacheOnlyAttention"
|
||||
)
|
||||
|
||||
return CacheOnlyAttentionMetadata(
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
)
|
||||
|
||||
|
||||
class CacheOnlyAttentionImpl(AttentionImpl):
|
||||
"""Attention implementation that only caches KV states."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_torch_dtype: torch.dtype,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_cache_torch_dtype = kv_cache_torch_dtype
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(f"Unsupported attention type: {attn_type}")
|
||||
if is_quantized_kv_cache(kv_cache_dtype):
|
||||
raise NotImplementedError("Quantized KV cache not supported")
|
||||
|
||||
self.num_queries_per_kv = 1
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer,
|
||||
to_cache,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
):
|
||||
assert to_cache.dtype == self.kv_cache_torch_dtype, (
|
||||
f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}"
|
||||
)
|
||||
assert kv_cache.dtype == self.kv_cache_torch_dtype, (
|
||||
f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}"
|
||||
)
|
||||
|
||||
basic_cache(to_cache, kv_cache, slot_mapping)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# Empty implementation of abstract method
|
||||
pass
|
||||
|
||||
|
||||
############## CacheOnlyAttentionLayer (replaces Attention) ############
|
||||
|
||||
|
||||
class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
|
||||
"""Attention layer that only caches key/value states without computing attention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.layer_name = prefix
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
# KV cache configuration
|
||||
cache_config = cache_config or vllm_config.cache_config
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
self.block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
self.block_size = 16
|
||||
|
||||
assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
|
||||
"CacheOnlyAttentionLayer doesn't currently support quantized kv cache but"
|
||||
f"kv cache dtype was set to {kv_cache_dtype}"
|
||||
)
|
||||
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
|
||||
kv_cache_dtype, vllm_config.model_config
|
||||
)
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
set_default_quant_scales(self, register_buffer=True)
|
||||
|
||||
# Attention backend
|
||||
self.attn_backend = CacheOnlyAttentionBackend
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
self.kv_cache_torch_dtype,
|
||||
attn_type,
|
||||
)
|
||||
|
||||
assert not self.attn_backend.forward_includes_kv_cache_update, (
|
||||
"KV cache update should be independent of forward"
|
||||
)
|
||||
|
||||
# Placeholder KV cache (replaced by bind_kv_cache)
|
||||
self.kv_cache = [
|
||||
torch.tensor([])
|
||||
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Register in compilation context
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
|
||||
"""Cache hidden states as KV pairs without computing attention.
|
||||
|
||||
Args:
|
||||
to_cache: The tensor to insert into the kv cache.
|
||||
shape [num_tokens, num_heads, head_size]
|
||||
|
||||
Returns:
|
||||
Dummy output tensor (not used)
|
||||
"""
|
||||
# Note: for hidden-state storage, num_heads is set to the number of
# stored hidden states and head_size to hidden_size
|
||||
output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)
|
||||
|
||||
# Note: dummy_out is used to force torch.compile to preserve ordering between
|
||||
# cache update and attention op (which triggers kv_connector transfer)
|
||||
dummy_out = unified_kv_cache_update(to_cache, self.layer_name)
|
||||
|
||||
# Triggers kv_connector transfer via decorator
|
||||
_ = dummy_attention(self.layer_name, dummy_out)
|
||||
|
||||
return output
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
# Note: we use MLAAttentionSpec here because it will
|
||||
# produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
|
||||
# whereas FullAttentionSpec will add an additional factor of 2
|
||||
return MLAAttentionSpec(
|
||||
block_size=self.block_size,
|
||||
num_kv_heads=self.num_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
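# Worked example of the page-size note above (illustrative numbers, not from
# this commit): with block_size=16, num_kv_heads=3 stored hidden states,
# head_size=4096 and bfloat16 (2 bytes), MLAAttentionSpec reports
# 16 * 3 * 4096 * 2 = 393,216 bytes per block, while a FullAttentionSpec-style
# K/V pair layout would report twice that.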
|
||||
|
||||
|
||||
############ ExtractHiddenStatesModel definition ##########
|
||||
|
||||
|
||||
class ExtractHiddenStatesModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
self.hidden_size = vllm_config.model_config.get_hidden_size()
|
||||
self.target_num_hidden_layers = (
|
||||
vllm_config.model_config.get_total_num_hidden_layers()
|
||||
)
|
||||
self.num_hidden_states = len(
|
||||
getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
|
||||
)
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
# Create a single cache-only attention layer
|
||||
# Note: We set num_heads <- self.num_hidden_states
|
||||
# and head_size <- hidden_size so that we can insert
|
||||
# the hidden states directly into the cache without
|
||||
# reshaping
|
||||
self.cache_only_layers = nn.ModuleDict(
|
||||
{
|
||||
str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
|
||||
num_heads=self.num_hidden_states,
|
||||
head_size=self.hidden_size,
|
||||
cache_config=cache_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
|
||||
),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> None:
|
||||
"""Process and cache hidden states.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states from target model
|
||||
shape: [num_tokens, num_hidden_states, hidden_size]
|
||||
|
||||
Returns:
|
||||
None; the call is made only for its KV-cache side effects
|
||||
"""
|
||||
|
||||
# Call dummy attention layer to cache hidden states
|
||||
# Output is ignored - we only care about the KV cache side effects
|
||||
_ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""No weights to load for this dummy model."""
|
||||
return set()
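# Hedged usage sketch (names and sizes are illustrative, not part of this
# commit): if eagle_aux_hidden_state_layer_ids lists 3 layers and
# hidden_size=4096, the target model would stack its captured activations as
#   aux = torch.stack(captured_hidden_states, dim=1)  # [num_tokens, 3, 4096]
#   extract_hidden_states_model(aux)                  # caches only, no output
# so the hidden states land directly in the cache-only layer's KV pages.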
|
||||
829
vllm/model_executor/models/fireredasr2.py
Normal file
@@ -0,0 +1,829 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
BatchFeature,
|
||||
Qwen2Config,
|
||||
)
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.models.whisper_utils import (
|
||||
ISO639_1_SUPPORTED_LANGS,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.processors.fireredasr2_processor import (
|
||||
FireRedASR2FeatureExtractor,
|
||||
)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsMultiModal,
|
||||
SupportsTranscription,
|
||||
_require_is_multimodal,
|
||||
)
|
||||
from .qwen2 import Qwen2ForCausalLM
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
_merge_multimodal_embeddings,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FireRedASR2AudioInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- b: Batch size
|
||||
- nmb: Number of mel bins
|
||||
- t: Time frames (M)
|
||||
"""
|
||||
|
||||
input_features: Annotated[
|
||||
list[torch.Tensor] | None,
|
||||
TensorShape("b", "nmb", "t"),
|
||||
]
|
||||
speech_lengths: Annotated[
|
||||
list[torch.Tensor] | None,
|
||||
TensorShape("b"),
|
||||
]
|
||||
fake_token_lengths: Annotated[
|
||||
list[torch.Tensor] | None,
|
||||
TensorShape("b"),
|
||||
]
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
def __init__(self, idim: int, d_model: int, out_channels: int = 32):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(1, out_channels, 3, 2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(out_channels, out_channels, 3, 2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
subsample_idim = ((idim - 1) // 2 - 1) // 2
|
||||
self.out = ReplicatedLinear(
|
||||
input_size=out_channels * subsample_idim,
|
||||
output_size=d_model,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.subsampling = 4
|
||||
left_context = right_context = 3  # both exclude current frame
|
||||
self.context = left_context + 1 + right_context # 7
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_mask: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
x = x.unsqueeze(1)
|
||||
x = self.conv(x)
|
||||
N, C, T, D = x.size()
|
||||
x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
|
||||
mask = x_mask[:, :, :-2:2][:, :, :-2:2]
|
||||
input_lengths = mask[:, -1, :].sum(dim=-1)
|
||||
return x, input_lengths, mask
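# Illustrative shape check for the 4x subsampling above (idim=80 mel bins is
# an assumed value; the real one comes from hf_config.audio_encoder_conf):
#   freq axis: ((80 - 1) // 2 - 1) // 2 = 19, so the flatten feeds
#   32 * 19 = 608 features into self.out
#   time axis: two stride-2, kernel-3 convs shrink T to roughly T // 4,
#   matching the mask slicing x_mask[:, :, :-2:2][:, :, :-2:2]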
|
||||
|
||||
|
||||
class RelPositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model: int, max_len: int = 5000):
|
||||
super().__init__()
|
||||
pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
|
||||
pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
|
||||
position = torch.arange(0, max_len).unsqueeze(1).float()
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, d_model, 2).float()
|
||||
* -(torch.log(torch.tensor(10000.0)).item() / d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
self.pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Tmax = 2 * max_len - 1
|
||||
Tmax, T = self.pe.size(1), x.size(1)
|
||||
pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
|
||||
return pos_emb
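# Illustrative slice arithmetic (assumed max_len=5000): the concatenated table
# has Tmax = 2 * 5000 - 1 = 9999 rows, and for an input of T frames the slice
# pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T] has length 2 * T - 1, covering
# relative offsets from +(T - 1) down to -(T - 1); e.g. T=10 selects
# rows 4990:5009 (19 positions).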
|
||||
|
||||
|
||||
class ConformerFeedForward(nn.Module):
|
||||
def __init__(self, d_model: int):
|
||||
super().__init__()
|
||||
self.pre_layer_norm = nn.LayerNorm(d_model)
|
||||
self.linear_expand = ReplicatedLinear(
|
||||
input_size=d_model,
|
||||
output_size=d_model * 4,
|
||||
bias=True,
|
||||
)
|
||||
self.nonlinear = Swish()
|
||||
self.linear_project = ReplicatedLinear(
|
||||
input_size=d_model * 4,
|
||||
output_size=d_model,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.pre_layer_norm(x)
|
||||
x, _ = self.linear_expand(x)
|
||||
x = self.nonlinear(x)
|
||||
x, _ = self.linear_project(x)
|
||||
output = x + residual
|
||||
return output
|
||||
|
||||
|
||||
class EncoderMultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_head: int, d_model: int):
|
||||
super().__init__()
|
||||
assert d_model % n_head == 0
|
||||
self.n_head = n_head
|
||||
self.d_k = d_model // n_head
|
||||
self.d_v = self.d_k
|
||||
|
||||
self.w_qs = ReplicatedLinear(
|
||||
input_size=d_model, output_size=n_head * self.d_k, bias=False
|
||||
)
|
||||
self.w_ks = ReplicatedLinear(
|
||||
input_size=d_model, output_size=n_head * self.d_k, bias=False
|
||||
)
|
||||
self.w_vs = ReplicatedLinear(
|
||||
input_size=d_model, output_size=n_head * self.d_v, bias=False
|
||||
)
|
||||
|
||||
self.layer_norm_q = nn.LayerNorm(d_model)
|
||||
self.layer_norm_k = nn.LayerNorm(d_model)
|
||||
self.layer_norm_v = nn.LayerNorm(d_model)
|
||||
|
||||
self.fc = ReplicatedLinear(
|
||||
input_size=n_head * self.d_v, output_size=d_model, bias=False
|
||||
)
|
||||
|
||||
def forward_qkv(
|
||||
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
||||
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
|
||||
|
||||
q = self.layer_norm_q(q)
|
||||
k = self.layer_norm_k(k)
|
||||
v = self.layer_norm_v(v)
|
||||
|
||||
q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k)
|
||||
k = self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k)
|
||||
v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
return q, k, v
|
||||
|
||||
def forward_output(
|
||||
self, output: torch.Tensor, residual: torch.Tensor, sz_b: int, len_q: int
|
||||
) -> torch.Tensor:
|
||||
output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
|
||||
fc_out, _ = self.fc(output)
|
||||
output = fc_out
|
||||
output = output + residual
|
||||
return output
|
||||
|
||||
def forward_attention(
|
||||
self, attn: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1)
|
||||
mask = mask.eq(0)
|
||||
attn = attn.masked_fill(mask, -float("inf"))
|
||||
attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
|
||||
else:
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
d_attn = attn
|
||||
output = torch.matmul(d_attn, v)
|
||||
|
||||
return output, attn
|
||||
|
||||
|
||||
class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
|
||||
def __init__(self, n_head: int, d_model: int):
|
||||
super().__init__(n_head, d_model)
|
||||
d_k = d_model // n_head
|
||||
self.scale = 1.0 / (d_k**0.5)
|
||||
self.linear_pos = ReplicatedLinear(
|
||||
input_size=d_model, output_size=n_head * d_k, bias=False
|
||||
)
|
||||
self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k]))
|
||||
self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k]))
|
||||
|
||||
def _rel_shift(self, x):
|
||||
N, H, T1, T2 = x.size()
|
||||
zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(N, H, T2 + 1, T1)
|
||||
x = x_padded[:, :, 1:].view_as(x)
|
||||
x = x[:, :, :, : x.size(-1) // 2 + 1]
|
||||
return x
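# Illustrative trace of the relative-shift trick above (hypothetical sizes):
# for T1=4 queries scored against T2 = 2 * 4 - 1 = 7 relative positions,
# padding a zero column and re-viewing as (T2 + 1, T1) before dropping the
# first row realigns each query row to its own relative offsets; the final
# slice keeps x.size(-1) // 2 + 1 = 4 columns, giving [N, H, 4, 4] scores.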
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
sz_b, len_q = q.size(0), q.size(1)
|
||||
|
||||
residual = q
|
||||
q, k, v = self.forward_qkv(q, k, v)
|
||||
|
||||
q = q.transpose(1, 2)
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k)
|
||||
p = p.transpose(1, 2)
|
||||
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self._rel_shift(matrix_bd)
|
||||
|
||||
attn_scores = matrix_ac + matrix_bd
|
||||
attn_scores.mul_(self.scale)
|
||||
|
||||
output, attn = self.forward_attention(attn_scores, v, mask=mask)
|
||||
|
||||
output = self.forward_output(output, residual, sz_b, len_q)
|
||||
return output, attn
|
||||
|
||||
|
||||
class ConformerConvolution(nn.Module):
|
||||
def __init__(self, d_model: int, kernel_size: int = 33):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
self.pre_layer_norm = nn.LayerNorm(d_model)
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
d_model, d_model * 4, kernel_size=1, bias=False
|
||||
)
|
||||
self.padding = (kernel_size - 1) // 2
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
d_model * 2,
|
||||
d_model * 2,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=self.padding,
|
||||
groups=d_model * 2,
|
||||
bias=False,
|
||||
)
|
||||
self.batch_norm = nn.LayerNorm(d_model * 2)
|
||||
self.swish = Swish()
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
d_model * 2, d_model, kernel_size=1, bias=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, mask: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
out = self.pre_layer_norm(x)
|
||||
out = out.transpose(1, 2)
|
||||
if mask is not None:
|
||||
out.masked_fill_(mask.ne(1), 0.0)
|
||||
out = self.pointwise_conv1(out)
|
||||
out = F.glu(out, dim=1)
|
||||
out = self.depthwise_conv(out)
|
||||
|
||||
out = out.transpose(1, 2)
|
||||
out = self.swish(self.batch_norm(out))
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = self.pointwise_conv2(out)
|
||||
if mask is not None:
|
||||
out.masked_fill_(mask.ne(1), 0.0)
|
||||
out = out.transpose(1, 2)
|
||||
return out + residual
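# Channel bookkeeping for the block above (illustrative d_model=512):
# pointwise_conv1 expands 512 -> 2048, F.glu(dim=1) halves that to 1024,
# the depthwise conv keeps 1024 channels (groups=1024), and pointwise_conv2
# projects 1024 -> 512 before the residual add.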
|
||||
|
||||
|
||||
class RelPosEmbConformerBlock(nn.Module):
|
||||
def __init__(self, d_model, n_head, kernel_size=33):
|
||||
super().__init__()
|
||||
self.ffn1 = ConformerFeedForward(d_model)
|
||||
self.mhsa = RelPosMultiHeadAttention(n_head, d_model)
|
||||
self.conv = ConformerConvolution(d_model, kernel_size)
|
||||
self.ffn2 = ConformerFeedForward(d_model)
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
slf_attn_mask: torch.Tensor | None = None,
|
||||
pad_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
out = 0.5 * x + 0.5 * self.ffn1(x)
|
||||
out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
|
||||
out = self.conv(out, pad_mask)
|
||||
out = 0.5 * out + 0.5 * self.ffn2(out)
|
||||
out = self.layer_norm(out)
|
||||
return out
|
||||
|
||||
|
||||
class ConformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
idim: int,
|
||||
n_layers_enc: int,
|
||||
n_head: int,
|
||||
d_model: int,
|
||||
kernel_size: int = 33,
|
||||
pe_maxlen: int = 5000,
|
||||
):
|
||||
super().__init__()
|
||||
self.odim = d_model
|
||||
|
||||
self.input_preprocessor = Conv2dSubsampling(idim, d_model)
|
||||
self.positional_encoding = RelPositionalEncoding(d_model)
|
||||
|
||||
self.layer_stack = nn.ModuleList()
|
||||
for _ in range(n_layers_enc):
|
||||
block = RelPosEmbConformerBlock(d_model, n_head, kernel_size)
|
||||
self.layer_stack.append(block)
|
||||
|
||||
def forward(
|
||||
self, padded_input: torch.Tensor, input_lengths: torch.Tensor, pad: bool = True
|
||||
):
|
||||
if pad:
|
||||
padded_input = F.pad(
|
||||
padded_input,
|
||||
(0, 0, 0, self.input_preprocessor.context - 1),
|
||||
"constant",
|
||||
0.0,
|
||||
)
|
||||
src_mask = self.padding_position_is_0(padded_input, input_lengths)
|
||||
|
||||
embed_output, input_lengths, src_mask = self.input_preprocessor(
|
||||
padded_input, src_mask
|
||||
)
|
||||
enc_output = embed_output
|
||||
|
||||
pos_emb = self.positional_encoding(embed_output)
|
||||
|
||||
enc_outputs = []
|
||||
for enc_layer in self.layer_stack:
|
||||
enc_output = enc_layer(
|
||||
enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask
|
||||
)
|
||||
enc_outputs.append(enc_output)
|
||||
|
||||
return enc_output, input_lengths, src_mask
|
||||
|
||||
def padding_position_is_0(
|
||||
self, padded_input: torch.Tensor, input_lengths: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
N, T = padded_input.size()[:2]
|
||||
mask = torch.ones((N, T)).to(padded_input.device)
|
||||
for i in range(N):
|
||||
mask[i, input_lengths[i] :] = 0
|
||||
mask = mask.unsqueeze(dim=1)
|
||||
return mask.to(torch.uint8)
|
||||
|
||||
|
||||
class FireRedASR2Adapter(nn.Module):
|
||||
def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2):
|
||||
super().__init__()
|
||||
self.ds = downsample_rate
|
||||
self.linear1 = ReplicatedLinear(
|
||||
input_size=encoder_dim * downsample_rate,
|
||||
output_size=llm_dim,
|
||||
bias=True,
|
||||
)
|
||||
self.relu = _ACTIVATION_REGISTRY["relu"]
|
||||
self.linear2 = ReplicatedLinear(
|
||||
input_size=llm_dim,
|
||||
output_size=llm_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x, x_lens):
|
||||
batch_size, seq_len, feat_dim = x.size()
|
||||
num_frames_to_discard = seq_len % self.ds
|
||||
if num_frames_to_discard > 0:
|
||||
x = x[:, :-num_frames_to_discard, :]
|
||||
seq_len = x.size(1)
|
||||
|
||||
x = x.contiguous()
|
||||
x = x.view(batch_size, seq_len // self.ds, feat_dim * self.ds)
|
||||
|
||||
x, _ = self.linear1(x)
|
||||
x = self.relu(x)
|
||||
x, _ = self.linear2(x)
|
||||
|
||||
new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds
|
||||
return x, new_x_lens
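# Worked example of the temporal downsampling above (illustrative sizes):
# with downsample_rate=2, encoder_dim=512 and a [B, 101, 512] input, one
# frame is discarded (101 % 2 == 1), the rest is viewed as [B, 50, 1024],
# projected to llm_dim, and new_x_lens = clamp(x_lens, max=100) // 2.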
|
||||
|
||||
|
||||
class FireRedASR2Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.audio_encoder = ConformerEncoder(
|
||||
**vllm_config.model_config.hf_config.audio_encoder_conf
|
||||
)
|
||||
|
||||
|
||||
class FireRedASR2Model(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.encoder = FireRedASR2Encoder(
|
||||
vllm_config=vllm_config,
|
||||
)
|
||||
encoder_dim = self.encoder.audio_encoder.odim
|
||||
llm_dim = vllm_config.model_config.hf_config.hidden_size
|
||||
self.encoder_projector = FireRedASR2Adapter(
|
||||
encoder_dim,
|
||||
llm_dim,
|
||||
vllm_config.model_config.hf_config.encoder_downsample_rate,
|
||||
)
|
||||
|
||||
self.decoder = Qwen2ForCausalLM(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder")
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def get_encoder_outputs(
|
||||
self,
|
||||
speech: torch.Tensor | list[torch.Tensor] | None,
|
||||
speech_lengths: torch.Tensor | list[torch.Tensor] | None,
|
||||
) -> torch.Tensor | None:
|
||||
encoder_outs, enc_lengths, enc_mask = self.encoder.audio_encoder(
|
||||
speech, speech_lengths
|
||||
)
|
||||
speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
|
||||
return speech_features
|
||||
|
||||
|
||||
class FireRedASR2ProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self) -> Qwen2Config:
|
||||
return self.ctx.get_hf_config(Qwen2Config)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": 1}
|
||||
|
||||
def get_feature_extractor(self, **kwargs: object) -> FireRedASR2FeatureExtractor:
|
||||
hf_processor = self.get_hf_processor(**kwargs)
|
||||
feature_extractor = hf_processor.feature_extractor # type: ignore
|
||||
assert isinstance(feature_extractor, FireRedASR2FeatureExtractor)
|
||||
return feature_extractor
|
||||
|
||||
def get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
return MultiModalDataParser(
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
target_channels=self.get_target_channels(),
|
||||
)
|
||||
|
||||
def get_target_channels(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
class FireRedASR2DummyInputsBuilder(BaseDummyInputsBuilder[FireRedASR2ProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
return "<|AUDIO|>" * num_audios
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
) -> MultiModalDataDict:
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
audio_overrides = mm_options.get("audio")
|
||||
|
||||
ret = {
|
||||
"audio": self._get_dummy_audios(
|
||||
length=audio_len, num_audios=num_audios, overrides=audio_overrides
|
||||
)
|
||||
}
|
||||
return ret
|
||||
|
||||
|
||||
class FireRedASR2MultiModalProcessor(
|
||||
BaseMultiModalProcessor[FireRedASR2ProcessingInfo]
|
||||
):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if mm_data:
|
||||
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
|
||||
mm_data = dict(audio=mm_data.pop("audios"))
|
||||
mm_kwargs = dict(
|
||||
**mm_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
)
|
||||
if "labels" in processed_outputs:
|
||||
processed_outputs["input_ids"] = processed_outputs.pop("labels")
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
input_features=MultiModalFieldConfig.batched("audio"),
|
||||
speech_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
fake_token_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
||||
|
||||
audio_token_id = vocab[audio_token]
|
||||
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
|
||||
fake_token_lengths = out_mm_data.get("fake_token_lengths")
|
||||
|
||||
if fake_token_lengths is None:
|
||||
audio_output_lengths = []
|
||||
else:
|
||||
assert isinstance(fake_token_lengths, torch.Tensor)
|
||||
|
||||
audio_output_lengths = fake_token_lengths.tolist()
|
||||
|
||||
def get_replacement_fireredasr2_audio(item_idx: int):
|
||||
num_features = audio_output_lengths[item_idx]
|
||||
|
||||
audio_tokens = [audio_token_id] * int(num_features)
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
audio_tokens,
|
||||
embed_token_id=audio_token_id,
|
||||
)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=[audio_token_id],
|
||||
replacement=get_replacement_fireredasr2_audio,
|
||||
)
|
||||
]
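# Hedged expansion example (values are illustrative): if fake_token_lengths
# for an item is 25, the single <|AUDIO|> placeholder in the prompt is
# replaced by 25 copies of audio_token_id, which later line up one-to-one
# with the projected encoder frames merged in embed_input_ids.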
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
FireRedASR2MultiModalProcessor,
|
||||
info=FireRedASR2ProcessingInfo,
|
||||
dummy_inputs=FireRedASR2DummyInputsBuilder,
|
||||
)
|
||||
class FireRedASR2ForConditionalGeneration(
|
||||
nn.Module, SupportsTranscription, SupportsMultiModal
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"self_attn.qkv_proj": [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
],
|
||||
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
|
||||
}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"llm.": "model.decoder.",
|
||||
"encoder.": "model.encoder.audio_encoder.",
|
||||
"encoder_projector.": "model.encoder_projector.",
|
||||
"net.0": "pre_layer_norm",
|
||||
"net.1": "linear_expand",
|
||||
"net.4": "linear_project",
|
||||
}
|
||||
)
|
||||
|
||||
supports_transcription_only = True
|
||||
supports_segment_timestamp = True
|
||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||
|
||||
@classmethod
|
||||
def validate_language(cls, language: str | None) -> str | None:
|
||||
if language is None:
|
||||
# TODO language should be optional and can be guessed.
|
||||
# For now we default to en. See
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
||||
logger.warning(
|
||||
"Defaulting to language='en'. If you wish to transcribe "
|
||||
"audio in a different language, pass the `language` field "
|
||||
"in the TranscriptionRequest."
|
||||
)
|
||||
language = "en"
|
||||
return super().validate_language(language)
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
model_config: ModelConfig, # not needed here
|
||||
stt_config: SpeechToTextConfig,
|
||||
language: str | None,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: str | None,
|
||||
) -> PromptType:
|
||||
if language is None:
|
||||
raise ValueError(
|
||||
"Language must be specified when creating the fireredasr2 prompt"
|
||||
)
|
||||
|
||||
prompt_str = "<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
|
||||
prompt = {
|
||||
"prompt": prompt_str,
|
||||
"multi_modal_data": {
|
||||
"audio": (audio, stt_config.sample_rate),
|
||||
},
|
||||
}
|
||||
return cast(PromptType, prompt)
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
) -> SpeechToTextConfig:
|
||||
processor = cached_processor_from_config(model_config)
|
||||
|
||||
return SpeechToTextConfig(
|
||||
max_audio_clip_s=processor.feature_extractor.chunk_length,
|
||||
sample_rate=processor.feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_num_audio_tokens(
|
||||
cls,
|
||||
audio_duration_s: float,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
) -> int | None:
|
||||
processor = cached_processor_from_config(model_config)
|
||||
hop_length = processor.feature_extractor.hop_length
|
||||
assert hop_length is not None
|
||||
return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)
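# Worked example (hop_length comes from the processor; 160 is only an
# assumed value here): 5.0 s of 16 kHz audio gives
# ceil(5.0 * 16000 / 160) = 500 placeholder tokens.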
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.config = config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.model = FireRedASR2Model(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
decoder_outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
|
||||
speech = audio_input["input_features"]
|
||||
speech_lengths = audio_input["speech_lengths"].to(torch.int32)
|
||||
enc_output = self.model.get_encoder_outputs(
|
||||
speech=speech, speech_lengths=speech_lengths
|
||||
)
|
||||
|
||||
return enc_output
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.model.decoder.embed_input_ids(input_ids)
|
||||
|
||||
ret = _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=_require_is_multimodal(is_multimodal),
|
||||
)
|
||||
return ret
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object
|
||||
) -> FireRedASR2AudioInputs:
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
speech_lengths = kwargs.pop("speech_lengths", None)
|
||||
fake_token_lengths = kwargs.pop("fake_token_lengths", None)
|
||||
|
||||
return FireRedASR2AudioInputs(
|
||||
input_features=input_features,
|
||||
speech_lengths=speech_lengths,
|
||||
fake_token_lengths=fake_token_lengths,
|
||||
)
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.model.decoder.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self, skip_prefixes=["model.encoder.audio_encoder.positional_encoding.pe"]
|
||||
)
|
||||
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
@@ -13,7 +13,6 @@ positions via `inputs_embeds`, while `position_ids` (RoPE) remains standard 1D.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
@@ -924,53 +923,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor
|
||||
f"sequence of Tensors (got {type(speech_attention_mask)})"
|
||||
)
|
||||
|
||||
debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1"
|
||||
if debug:
|
||||
print(
|
||||
f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} "
|
||||
f"speech_attention_mask={tuple(speech_attention_mask.shape)}",
|
||||
flush=True,
|
||||
)
|
||||
attn_impl = getattr(
|
||||
self.continuous_audio_tower.config, "_attn_implementation", None
|
||||
)
|
||||
print(
|
||||
f"[FunAudioChat] audio_attn_impl={attn_impl}",
|
||||
flush=True,
|
||||
)
|
||||
if hasattr(self.continuous_audio_tower, "conv1"):
|
||||
conv1_w = self.continuous_audio_tower.conv1.weight
|
||||
print(
|
||||
f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}",
|
||||
flush=True,
|
||||
)
|
||||
try:
|
||||
attn0 = self.continuous_audio_tower.layers[0].self_attn
|
||||
q_norm = float(attn0.q_proj.weight.norm().item())
|
||||
k_norm = float(attn0.k_proj.weight.norm().item())
|
||||
v_norm = float(attn0.v_proj.weight.norm().item())
|
||||
o_norm = float(attn0.out_proj.weight.norm().item())
|
||||
print(
|
||||
f"[FunAudioChat] attn0_q_norm={q_norm:.6g} "
|
||||
f"k_norm={k_norm:.6g} "
|
||||
f"v_norm={v_norm:.6g} "
|
||||
f"o_norm={o_norm:.6g}",
|
||||
flush=True,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if isinstance(input_features, torch.Tensor):
|
||||
print(
|
||||
f"[FunAudioChat] input_features={tuple(input_features.shape)}",
|
||||
flush=True,
|
||||
)
|
||||
if isinstance(feature_attention_mask, torch.Tensor):
|
||||
print(
|
||||
"[FunAudioChat] feature_attention_mask="
|
||||
f"{tuple(feature_attention_mask.shape)}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
group_size = int(self.audio_tower.group_size)
|
||||
speech_maxlen = int(speech_ids.shape[-1])
|
||||
|
||||
@@ -1019,38 +971,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor
|
||||
embeds = tuple(
|
||||
audio_features[i, : int(length)] for i, length in enumerate(lengths)
|
||||
)
|
||||
if debug:
|
||||
embed_lens = [int(t.shape[0]) for t in embeds]
|
||||
print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
|
||||
if embeds:
|
||||
t0 = embeds[0]
|
||||
print(
|
||||
f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
|
||||
f"nan={bool(torch.isnan(t0).any())} "
|
||||
f"norm={float(t0.norm().item()):.6g}",
|
||||
flush=True,
|
||||
)
|
||||
dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
|
||||
if (
|
||||
dump_path
|
||||
and speech_ids.shape[0] == 1
|
||||
and len(embeds) == 1
|
||||
and embed_lens[0] > 10
|
||||
):
|
||||
if not os.path.exists(dump_path):
|
||||
np.save(dump_path, embeds[0].detach().float().cpu().numpy())
|
||||
print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
|
||||
cont_path = dump_path.replace(".npy", "_cont.npy")
|
||||
if continuous_audio_features is not None and not os.path.exists(
|
||||
cont_path
|
||||
):
|
||||
np.save(
|
||||
cont_path,
|
||||
continuous_audio_features.detach().float().cpu().numpy(),
|
||||
)
|
||||
print(
|
||||
f"[FunAudioChat] dumped continuous to {cont_path}", flush=True
|
||||
)
|
||||
return embeds
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -409,7 +409,7 @@ class Gemma3nAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
||||
q = q.unflatten(-1, (self.num_heads, self.head_dim))
|
||||
q = self.q_norm(q)
|
||||
q = q.flatten(-2, -1)
|
||||
|
||||
@@ -110,7 +110,12 @@ class Glm4MoeMLP(nn.Module):
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
if self.down_proj.quant_method.__class__.__name__ != "UnquantizedLinearMethod" and x.shape[-1] != self.down_proj.weight.shape[0]:
|
||||
padding = self.down_proj.weight.shape[0] - x.shape[-1]
|
||||
x_align = torch.nn.functional.pad(x, (0, padding), mode='constant', value=0)
|
||||
else:
|
||||
x_align = x
|
||||
x, _ = self.down_proj(x_align)
|
||||
return x
|
||||
|
||||
|
||||
@@ -144,11 +149,10 @@ class Glm4MoE(nn.Module):
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
# dtype=torch.float32,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts)
|
||||
)
|
||||
torch.empty(config.n_routed_experts, dtype=torch.bfloat16))
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
@@ -205,8 +209,7 @@ class Glm4MoE(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
# router_logits = self.gate(hidden_states.to(dtype=torch.float32))
|
||||
router_logits = self.gate(hidden_states)
|
||||
router_logits = self.gate(hidden_states.to(dtype=torch.bfloat16))
|
||||
|
||||
fused_moe_out = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
@@ -312,6 +315,9 @@ class Glm4MoeAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
if self.use_qk_norm:
|
||||
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
|
||||
q.shape
|
||||
|
||||
@@ -127,12 +127,10 @@ class Glm4MoeLiteDecoderLayer(nn.Module):
|
||||
v_head_dim = getattr(config, "v_head_dim", 0)
|
||||
kv_lora_rank = getattr(config, "kv_lora_rank", 0)
|
||||
|
||||
# if model_config.use_mla:
|
||||
# attn_cls = Glm4MoeLiteMLAAttention
|
||||
# else:
|
||||
# attn_cls = Glm4MoeLiteAttention
|
||||
|
||||
attn_cls = Glm4MoeLiteAttention
|
||||
if model_config.use_mla:
|
||||
attn_cls = Glm4MoeLiteMLAAttention
|
||||
else:
|
||||
attn_cls = Glm4MoeLiteAttention
|
||||
|
||||
self.self_attn = attn_cls(
|
||||
vllm_config=vllm_config,
|
||||
@@ -306,7 +304,7 @@ class Glm4MoeLiteModel(nn.Module):
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
@@ -318,6 +316,120 @@ class Glm4MoeLiteModel(nn.Module):
|
||||
num_experts=self.config.n_routed_experts,
|
||||
)
|
||||
|
||||
|
||||
class Glm4MoeLiteForCausalLM(
|
||||
nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
|
||||
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
|
||||
self.use_mha = config.model_type == "deepseek" or all(
|
||||
dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
|
||||
)
|
||||
|
||||
if self.use_mha:
|
||||
self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
|
||||
|
||||
# `packed_modules_mapping` needs to be modified before
|
||||
# initializing DeepseekV2Model, as it is passed inplace to
|
||||
# quantization config init and may be used to select the
|
||||
# quant_method for relevant layers during initialization.
|
||||
self.fuse_qkv_a_proj = (
|
||||
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
|
||||
)
|
||||
if self.fuse_qkv_a_proj:
|
||||
self.packed_modules_mapping["fused_qkv_a_proj"] = [
|
||||
"q_a_proj",
|
||||
"kv_a_proj_with_mqa",
|
||||
]
|
||||
|
||||
self.model = Glm4MoeLiteModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (
|
||||
self.config.num_hidden_layers - self.config.first_k_dense_replace
|
||||
)
|
||||
self.set_moe_parameters()
|
||||
|
||||
def set_moe_parameters(self):
|
||||
self.expert_weights = []
|
||||
|
||||
self.num_expert_groups = getattr(self.config, "n_group", 1)
|
||||
|
||||
self.moe_layers = []
|
||||
self.moe_mlp_layers = []
|
||||
example_moe = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Glm4MoeLiteDecoderLayer)
|
||||
if isinstance(layer.mlp, Glm4MoeLite):
|
||||
# Pick the last MoE layer, since the first layers may be dense.
|
||||
example_moe = layer.mlp
|
||||
self.moe_mlp_layers.append(layer.mlp)
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return SharedFusedMoE.make_expert_params_mapping(
|
||||
self,
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts,
|
||||
num_redundant_experts=0,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
rocm_aiter_moe_shared_expert_enabled = (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
@@ -327,12 +439,13 @@ class Glm4MoeLiteModel(nn.Module):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
mla_params_mapping = [
|
||||
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
mha_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
stacked_params_mapping.extend(mla_params_mapping)
|
||||
if self.use_mha:
|
||||
stacked_params_mapping.extend(mha_params_mapping)
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
@@ -510,128 +623,71 @@ class Glm4MoeLiteModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"]
|
||||
# Add further supported quant methods here as needed.
|
||||
def inject_layer(layer, quant_method, is_mla):
|
||||
q_lora_rank = getattr(layer, "q_lora_rank", None)
|
||||
if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
|
||||
if q_lora_rank is not None:
|
||||
layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_a_proj, "weight_scale"):
|
||||
layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_proj, "weight_scale"):
|
||||
layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
else:
|
||||
return
|
||||
del layer.kv_a_proj_with_mqa.weight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
elif quant_method == "GGUFLinearMethod":
|
||||
pass
|
||||
elif quant_method == "AWQMarlinLinearMethod":
|
||||
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
|
||||
assert dtype == torch.int32
|
||||
if layer.q_lora_rank is not None:
|
||||
layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
|
||||
layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1)
|
||||
layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
else:
|
||||
return
|
||||
|
||||
del layer.kv_a_proj_with_mqa.qweight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
else:
|
||||
pass
|
||||
|
||||
for _, layer in self.model.named_modules():
|
||||
if layer.__class__.__name__ in ["Glm4MoeLiteAttention","Glm4MoeLiteMLAAttention"]:
|
||||
if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
|
||||
quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
|
||||
else:
|
||||
quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
|
||||
if quant_method not in opt_support_quant_method:
|
||||
break
|
||||
|
||||
inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "Glm4MoeLiteMLAAttention")
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Glm4MoeLiteForCausalLM(
|
||||
nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
|
||||
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
|
||||
self.use_mha = config.model_type == "deepseek" or all(
|
||||
dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
|
||||
)
|
||||
|
||||
if self.use_mha:
|
||||
self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
|
||||
|
||||
# `packed_modules_mapping` needs to be modified before
|
||||
# initializing DeepseekV2Model, as it is passed inplace to
|
||||
# quantization config init and may be used to select the
|
||||
# quant_method for relevant layers during initialization.
|
||||
self.fuse_qkv_a_proj = (
|
||||
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
|
||||
)
|
||||
if self.fuse_qkv_a_proj:
|
||||
self.packed_modules_mapping["fused_qkv_a_proj"] = [
|
||||
"q_a_proj",
|
||||
"kv_a_proj_with_mqa",
|
||||
]
|
||||
|
||||
self.model = Glm4MoeLiteModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (
|
||||
self.config.num_hidden_layers - self.config.first_k_dense_replace
|
||||
)
|
||||
self.set_moe_parameters()
|
||||
|
||||
def set_moe_parameters(self):
|
||||
self.expert_weights = []
|
||||
|
||||
self.num_expert_groups = getattr(self.config, "n_group", 1)
|
||||
|
||||
self.moe_layers = []
|
||||
self.moe_mlp_layers = []
|
||||
example_moe = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Glm4MoeLiteDecoderLayer)
|
||||
if isinstance(layer.mlp, Glm4MoeLite):
|
||||
# Pick the last MoE layer, since the first layers may be dense.
|
||||
example_moe = layer.mlp
|
||||
self.moe_mlp_layers.append(layer.mlp)
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return SharedFusedMoE.make_expert_params_mapping(
|
||||
self,
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts,
|
||||
num_redundant_experts=0,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(
|
||||
config: "Glm4MoeLiteConfig", weight_name: str
|
||||
) -> int | None:
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import GptOssConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
@@ -23,7 +23,11 @@ from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
|
||||
@@ -42,6 +46,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
from vllm.model_executor.model_loader import padding_weight_loader
|
||||
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
@@ -107,7 +112,6 @@ class OAIAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.num_attention_heads * self.head_dim,
|
||||
output_size=self.hidden_size,
|
||||
@@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
|
||||
self.router = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=True,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.router",
|
||||
return_bias=False,
|
||||
)
|
||||
assert config.intermediate_size % self.world_size == 0
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
@@ -969,8 +980,18 @@ class GptOssModel(nn.Module):
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
stacked_params_mapping: list[tuple[str, ...]],
|
||||
) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
def handle_weight(name, weight, param_name, permute_dims=None, slice_dims=None, contiguous=True):
|
||||
"""Helper function to handle weight loading with optional slicing and permutation."""
|
||||
param = params_dict[param_name]
|
||||
if slice_dims:
|
||||
weight = weight[slice_dims]
|
||||
if permute_dims:
|
||||
weight = weight.permute(*permute_dims)
|
||||
if contiguous:
|
||||
weight = weight.contiguous()
|
||||
padding_weight_loader(param, weight)
|
||||
loaded_params.add(param_name)
|
||||
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
|
||||
@@ -986,91 +1007,71 @@ class GptOssModel(nn.Module):
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
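# Illustrative slicing bounds (hypothetical sizes, not from this commit):
# intermediate_size=2880 and tp_size=2 give per_rank_intermediate_size=1440,
# so rank 1 loads w2 rows 1440:2880 along the intermediate dim (further
# divided by pack_factor under W4A8) and columns 2 * 1440 : 2 * 2880 of the
# interleaved gate/up tensor.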
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
pack_factor = 2 if envs.VLLM_W8A8_MOE_USE_W4A8 else 1
|
||||
w4a8_flag = envs.VLLM_W8A8_MOE_USE_W4A8
|
||||
gemm_format = envs.VLLM_W8A8_FORMAT
|
||||
|
||||
for name, weight in weights:
|
||||
# Skip layers on other devices.
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
if ".w13_weight" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
# Extract gate and up projection parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
|
||||
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
if ".experts.w13_weight" in name and "scale" not in name and "bias" not in name:
|
||||
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
|
||||
permute_dims = None if gemm_format == "NN" else (0, 2, 1)
|
||||
handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims)
|
||||
elif ".experts.w2_weight" in name and "scale" not in name and "bias" not in name:
|
||||
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(tp_rank_start // pack_factor, tp_rank_end // pack_factor), slice(None))
|
||||
permute_dims = None if gemm_format == "NN" else (0, 2, 1)
|
||||
handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims)
|
||||
elif ".experts.gate_up_proj_scale" in name:
|
||||
new_name = name.replace("gate_up_proj_scale", "w13_weight_scale")
|
||||
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
|
||||
permute_dims = None if w4a8_flag else (0, 2, 1)
|
||||
handle_weight(name, weight, new_name, permute_dims=permute_dims, slice_dims=slice_dims, contiguous=w4a8_flag)
|
||||
elif ".experts.down_proj_scale" in name:
|
||||
new_name = name.replace("down_proj_scale", "w2_weight_scale")
|
||||
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else None
|
||||
permute_dims = None if w4a8_flag else (0, 2, 1)
|
||||
handle_weight(name, weight, new_name, permute_dims=permute_dims, slice_dims=slice_dims, contiguous=w4a8_flag)
|
||||
elif ".experts.w13_bias" in name:
|
||||
slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
|
||||
handle_weight(name, weight, name, slice_dims=slice_dims, contiguous=False)
|
||||
elif ".experts.w2_bias" in name:
|
||||
param = params_dict[name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_weight" in name:
|
||||
# Handle MLP down projection weights
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w13_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
|
||||
|
||||
param = params_dict[name]
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
param = params_dict[name]
|
||||
param.copy_(weight)
|
||||
elif tp_rank != 0:
|
||||
weight.zero_()
|
||||
param.data.copy_(weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
name = name.replace("self_attn", "attn")
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
if weight_loader == default_weight_loader:
|
||||
weight_loader(param, weight)
|
||||
else:
|
||||
weight_loader(param, weight, shard_id)
|
||||
break
|
||||
elif ("q_proj" in name or "k_proj" in name or "v_proj" in name):
|
||||
shard_id = ("q" if "q_proj" in name else "k" if "k_proj" in name else "v")
|
||||
name = name.replace("self_attn", "attn")
|
||||
param_name = name.replace(f"{shard_id}_proj", "qkv_proj")
|
||||
param = params_dict[param_name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, weight, loaded_shard_id=shard_id)
|
||||
loaded_params.add(param_name)
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(name)
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
|
||||
@@ -636,7 +636,13 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
|
||||
spatial_merge_size = vision_config.spatial_merge_size
|
||||
|
||||
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
|
||||
size = mm_kwargs.get("size", image_processor.size)
|
||||
size = image_processor.size
|
||||
if override_size := mm_kwargs.get("size"):
|
||||
size = size | override_size
|
||||
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
|
||||
size = size | {"shortest_edge": override_min_pixels}
|
||||
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
|
||||
size = size | {"longest_edge": override_max_pixels}
|
||||
|
||||
if do_resize:
|
||||
resized_height, resized_width = smart_resize(
|
||||
|
||||
@@ -49,7 +49,6 @@ from .utils import (
|
||||
)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
EOT = "<|endofturn|>"
|
||||
IMAGE_TOKEN: str = "<|dummy3|>"
|
||||
VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
|
||||
|
||||
|
||||
@@ -551,7 +551,7 @@ def process_vision_for_patches(
|
||||
`(num_images, height, width, channels)` for a batch. Channels are
|
||||
expected to be RGB.
|
||||
patch_size (`int`):
|
||||
Edge length of square patches; implictly controls resize grid granularity.
|
||||
Edge length of square patches; implicitly controls resize grid granularity.
|
||||
max_num_patches (`int`):
|
||||
Maximum number of patches allowed after resizing.
|
||||
min_num_patches (`int`, *optional*):
|
||||
|
||||
@@ -1021,7 +1021,13 @@ class KeyeProcessingInfo(BaseProcessingInfo):
|
||||
temporal_patch_size = 1
|
||||
|
||||
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
|
||||
size = mm_kwargs.get("size", image_processor.size)
|
||||
size = image_processor.size
|
||||
if override_size := mm_kwargs.get("size"):
|
||||
size = size | override_size
|
||||
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
|
||||
size = size | {"min_pixels": override_min_pixels}
|
||||
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
|
||||
size = size | {"max_pixels": override_max_pixels}
|
||||
|
||||
if do_resize:
|
||||
resized_height, resized_width = smart_resize(
|
||||
|
||||
@@ -654,4 +654,4 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
self,
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
@@ -230,4 +230,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
|
||||
}
|
||||
|
||||
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
|
||||
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
|
||||
@@ -88,7 +88,7 @@ class MiniMaxM2MoE(nn.Module):
|
||||
self.use_routing_bias = getattr(config, "use_routing_bias", False)
|
||||
if self.use_routing_bias:
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.num_local_experts, dtype=torch.float32)
|
||||
torch.empty(config.num_local_experts, dtype=torch.get_default_dtype())
|
||||
)
|
||||
self.e_score_correction_bias.weight_loader = (
|
||||
MiniMaxM2MoE.ebias_weight_loader
|
||||
@@ -107,13 +107,14 @@ class MiniMaxM2MoE(nn.Module):
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
router_logits_dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=False,
|
||||
# params_dtype=torch.float32,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
@@ -121,7 +122,6 @@ class MiniMaxM2MoE(nn.Module):
|
||||
@staticmethod
|
||||
def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
|
||||
assert param.size() == loaded_weight.size()
|
||||
# param.data.copy_(loaded_weight.to(torch.float32))
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -129,10 +129,9 @@ class MiniMaxM2MoE(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
# router_logits, _ = self.gate(hidden_states.to(torch.float32))
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
router_logits, _ = self.gate(hidden_states.to(torch.float32))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
hidden_states=hidden_states, router_logits=router_logits.to(hidden_states.dtype)
|
||||
)
|
||||
final_hidden_states = final_hidden_states
|
||||
if self.tp_size > 1:
|
||||
|
||||
@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
|
||||
from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
|
||||
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
|
||||
from vllm.model_executor.models.utils import (
|
||||
init_vllm_registered_model,
|
||||
@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
|
||||
compute_retention_mask,
|
||||
)
|
||||
from vllm.multimodal.inputs import (
|
||||
AudioItem,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
VideoItem,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
AudioProcessorItems,
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
||||
# Alternative: Set a specific higher limit
|
||||
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
|
||||
|
||||
|
||||
class NanoNemotronVLAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- b: Number of audio clips
|
||||
- t: Audio feature length
|
||||
- f: Feature size (mel bins)
|
||||
"""
|
||||
|
||||
type: Literal["audio_features"] = "audio_features"
|
||||
input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
|
||||
feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
|
||||
audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
|
||||
|
||||
|
||||
MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
|
||||
|
||||
IMG_START = "<img>"
|
||||
IMG_END = "</img>"
|
||||
IMG_CONTEXT = "<image>"
|
||||
AUDIO_START = "<so_start>"
|
||||
AUDIO_END = "<so_end>"
|
||||
AUDIO_CONTEXT = "<so_embedding>"
|
||||
|
||||
# Profiling
|
||||
# MAX_FRAMES = 16
|
||||
@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
self.video_token = video_token
|
||||
self.video_pruning_rate = video_pruning_rate
|
||||
|
||||
self.audio_extractor: ParakeetExtractor | None = None
|
||||
raw_sound_config = getattr(config, "sound_config", None)
|
||||
if raw_sound_config is not None:
|
||||
self.audio_extractor = ParakeetExtractor(raw_sound_config)
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
self._img_start_token_ids = tokenizer.encode(
|
||||
@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
text = [t.replace("<video>", video_repl_text, 1) for t in text]
|
||||
return text, video_inputs
|
||||
|
||||
def _preprocess_audio(
|
||||
self,
|
||||
text: list[str],
|
||||
audios: list[npt.NDArray],
|
||||
):
|
||||
if len(audios) == 0:
|
||||
return text, {}
|
||||
assert self.audio_extractor is not None
|
||||
|
||||
extractor = self.audio_extractor
|
||||
|
||||
parts = [x for x in re.split(f"({re.escape(AUDIO_CONTEXT)})", text[0]) if x]
|
||||
token_count = parts.count(AUDIO_CONTEXT)
|
||||
if token_count != len(audios):
|
||||
raise ValueError(
|
||||
"Number of audio tokens in text does not match the number "
|
||||
f"of audios (tokens={token_count}, audios={len(audios)})."
|
||||
)
|
||||
audio_index = 0
|
||||
for idx, part in enumerate(parts):
|
||||
if part == AUDIO_CONTEXT:
|
||||
audio_repl = self.get_audio_repl(audios[audio_index])
|
||||
parts[idx] = audio_repl.full
|
||||
audio_index += 1
|
||||
text = ["".join(parts)]
|
||||
audio_inputs = extractor(
|
||||
audios,
|
||||
sampling_rate=extractor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_audio_features = audio_inputs.input_features
|
||||
feature_attention_mask = audio_inputs.attention_mask
|
||||
audio_feature_lengths = feature_attention_mask.sum(dim=1)
|
||||
audio_inputs = {
|
||||
"input_audio_features": input_audio_features,
|
||||
"feature_attention_mask": feature_attention_mask,
|
||||
"audio_feature_lengths": audio_feature_lengths,
|
||||
}
|
||||
|
||||
return text, audio_inputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str] | None = None,
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
|
||||
audios: AudioItem | list[AudioItem] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
max_num_tiles: int | None = None,
|
||||
) -> BatchFeature:
|
||||
@@ -964,8 +1034,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
if max_num_tiles is None:
|
||||
max_num_tiles = self.max_num_tiles
|
||||
|
||||
text, images, videos = [
|
||||
self._make_batch_input(x) for x in (text, images, videos)
|
||||
text, images, videos, audios = [
|
||||
self._make_batch_input(x) for x in (text, images, videos, audios)
|
||||
]
|
||||
|
||||
text, image_inputs = self._preprocess_image(
|
||||
@@ -980,17 +1050,22 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
max_num_tiles=1,
|
||||
)
|
||||
|
||||
text, audio_inputs = self._preprocess_audio(
|
||||
text=text,
|
||||
audios=audios,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer(text, add_special_tokens=False)
|
||||
|
||||
combined_inputs = {**text_inputs, **video_inputs, **audio_inputs}
|
||||
|
||||
if self.dynamic_tiler is None:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs, **image_inputs},
|
||||
{**combined_inputs, **image_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
else:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs}, tensor_type=return_tensors
|
||||
)
|
||||
batch = BatchFeature(combined_inputs, tensor_type=return_tensors)
|
||||
# allow images to be exempt from the BatchFeature validation:
|
||||
# We will .stack() them in _parse_and_validate_image_input
|
||||
batch.update(image_inputs)
|
||||
@@ -1006,6 +1081,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
def get_audio_repl(
|
||||
self,
|
||||
audio: npt.NDArray,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
assert self.audio_extractor is not None
|
||||
num_tokens = self.audio_extractor.audio_token_count(len(audio))
|
||||
repl_full = f"{AUDIO_START}{AUDIO_CONTEXT * num_tokens}{AUDIO_END}"
|
||||
return PromptUpdateDetails.select_text(repl_full, AUDIO_CONTEXT)
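Illustration only (the token count is assumed, not taken from a real extractor): with three audio tokens the replacement string expands as below, and only the context tokens are selected as embedding positions.

AUDIO_START, AUDIO_CONTEXT, AUDIO_END = "<so_start>", "<so_embedding>", "<so_end>"
num_tokens = 3  # assumed return value of audio_token_count(len(audio))
repl_full = f"{AUDIO_START}{AUDIO_CONTEXT * num_tokens}{AUDIO_END}"
# "<so_start><so_embedding><so_embedding><so_embedding><so_end>"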
@classmethod
|
||||
def get_video_repl(
|
||||
cls,
|
||||
@@ -1147,15 +1231,28 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
def supports_video(self):
|
||||
return self.get_hf_processor().supports_video
|
||||
|
||||
@property
|
||||
def audio_extractor(self) -> ParakeetExtractor | None:
|
||||
return self.get_hf_processor().audio_extractor
|
||||
|
||||
def get_data_parser(self):
|
||||
target_sr = None
|
||||
target_channels = None
|
||||
if extractor := self.audio_extractor:
|
||||
target_sr = extractor.sampling_rate
|
||||
target_channels = 1
|
||||
|
||||
return MultiModalDataParser(
|
||||
video_needs_metadata=True,
|
||||
target_sr=target_sr,
|
||||
target_channels=target_channels,
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self):
|
||||
video_limit = {"video": None} if self.supports_video else {}
|
||||
return {**super().get_supported_mm_limits(), **video_limit}
|
||||
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
|
||||
return {**super().get_supported_mm_limits(), **video_limit, **audio_limit}
|
||||
|
||||
def get_video_token(self) -> str | None:
|
||||
return IMG_CONTEXT
|
||||
@@ -1304,7 +1401,16 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
else:
|
||||
video_fields = {}
|
||||
|
||||
return image_fields | video_fields
|
||||
if self.info.audio_extractor is not None:
|
||||
audio_fields = dict(
|
||||
input_audio_features=MultiModalFieldConfig.batched("audio"),
|
||||
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
else:
|
||||
audio_fields = {}
|
||||
|
||||
return image_fields | video_fields | audio_fields
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -1373,6 +1479,20 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
),
|
||||
]
|
||||
|
||||
def get_audio_replacement(item_idx: int):
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
return hf_processor.get_audio_repl(audios.get(item_idx))
|
||||
|
||||
if self.info.audio_extractor is not None:
|
||||
prompt_repl = [
|
||||
*prompt_repl,
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=AUDIO_CONTEXT,
|
||||
replacement=get_audio_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
return prompt_repl
|
||||
|
||||
|
||||
@@ -1422,8 +1542,13 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
|
||||
return (
|
||||
super().get_dummy_text(mm_counts)
|
||||
+ "<video>" * num_videos
|
||||
+ AUDIO_CONTEXT * num_audios
|
||||
)
|
||||
|
||||
def _get_dummy_videos(
|
||||
self,
|
||||
@@ -1482,7 +1607,25 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
}
|
||||
else:
|
||||
dummy_video = {}
|
||||
return {**dummy_image, **dummy_video}
|
||||
|
||||
if extractor := self.info.audio_extractor:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
|
||||
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
|
||||
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
|
||||
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
|
||||
dummy_audio = {
|
||||
"audio": self._get_dummy_audios(
|
||||
length=audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides,
|
||||
)
|
||||
}
|
||||
else:
|
||||
dummy_audio = {}
|
||||
|
||||
return {**dummy_image, **dummy_video, **dummy_audio}
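A rough worked example of the dummy-audio sizing above; all numbers are assumed, and audio_length is taken to be the extractor's tokens-to-samples inverse:

MAX_AUDIO_LEN_S = 10 * 60
sampling_rate = 16_000                                      # assumed extractor.sampling_rate
seq_len, num_audios = 8_192, 2
tokens_per_audio = max(1, seq_len // max(num_audios, 1))    # 4096
max_audio_num_samples = MAX_AUDIO_LEN_S * sampling_rate     # 9_600_000
# audio_len = min(max_audio_num_samples, extractor.audio_length(tokens_per_audio)),
# i.e. the dummy clip never exceeds the 10-minute cap regardless of seq_len.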
@MULTIMODAL_REGISTRY.register_processor(
|
||||
@@ -1499,12 +1642,15 @@ class NemotronH_Nano_VL_V2(
|
||||
return "<image>"
|
||||
if modality.startswith("video"):
|
||||
return "<video>"
|
||||
if modality.startswith("audio"):
|
||||
return AUDIO_CONTEXT
|
||||
return None
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
model_config = vllm_config.model_config
|
||||
config = model_config.hf_config
|
||||
multimodal_config = model_config.multimodal_config
|
||||
image_size = config.force_image_size
|
||||
patch_size = config.patch_size
|
||||
self.patch_size = patch_size
|
||||
@@ -1523,10 +1669,12 @@ class NemotronH_Nano_VL_V2(
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
llm_dtype = self.language_model.config.dtype
|
||||
assert isinstance(llm_dtype, torch.dtype)
|
||||
self.llm_dtype = llm_dtype
|
||||
with self._mark_tower_model(vllm_config, {"image", "video", "audio"}):
|
||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||
self.language_model.config.dtype
|
||||
llm_dtype
|
||||
)
|
||||
|
||||
# Construct the vision projection.
|
||||
@@ -1547,14 +1695,26 @@ class NemotronH_Nano_VL_V2(
|
||||
ReLUSquaredActivation(),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||
)
|
||||
self.mlp1 = mlp1.to(self.language_model.config.dtype)
|
||||
self.mlp1 = mlp1.to(llm_dtype)
|
||||
self.sound_encoder: ProjectedParakeet | None = None
|
||||
if getattr(config, "sound_config", None) is not None:
|
||||
logger.info_once(
|
||||
"Found sound config, initializing sound encoder for Nemotron AVLM",
|
||||
scope="global",
|
||||
)
|
||||
self.sound_encoder = ProjectedParakeet(
|
||||
config.sound_config,
|
||||
dtype=llm_dtype,
|
||||
llm_hidden_size=llm_hidden_size,
|
||||
max_model_len=model_config.max_model_len,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
self._img_start_token_ids = tokenizer.encode(
|
||||
IMG_START, add_special_tokens=False
|
||||
)
|
||||
@@ -1566,7 +1726,10 @@ class NemotronH_Nano_VL_V2(
|
||||
config
|
||||
)
|
||||
if self.dynamic_resolution:
|
||||
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
|
||||
logger.info_once(
|
||||
"Dynamic resolution is enabled for NanoNemotronVLProcessor",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
@@ -1780,6 +1943,51 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
return final_video_embeddings
|
||||
|
||||
def _process_audio_input(
|
||||
self, audio_input: NanoNemotronVLAudioFeatureInputs
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
assert self.sound_encoder is not None
|
||||
input_audio_features = audio_input.input_audio_features
|
||||
feature_attention_mask = audio_input.feature_attention_mask
|
||||
target_device = next(self.sound_encoder.parameters()).device
|
||||
|
||||
# When cross-request batching combines audio clips with different
|
||||
# time dimensions, _reduce_data returns a list instead of a stacked
|
||||
# tensor. Pad to the max time dim and stack; the attention mask
|
||||
# already marks valid positions so zero-padding is safe.
|
||||
if isinstance(input_audio_features, list):
|
||||
feature_sizes = [f.shape[-2] for f in input_audio_features]
|
||||
max_t = max(feature_sizes)
|
||||
padded_feats = [
|
||||
torch.nn.functional.pad(feat, (0, 0, 0, max_t - feat_size))
|
||||
for feat, feat_size in zip(
|
||||
input_audio_features, feature_sizes, strict=True
|
||||
)
|
||||
]
|
||||
padded_masks = [
|
||||
torch.nn.functional.pad(mask, (0, max_t - mask.shape[-1]))
|
||||
for mask in feature_attention_mask
|
||||
]
|
||||
input_audio_features = torch.stack(padded_feats)
|
||||
feature_attention_mask = torch.stack(padded_masks)
|
||||
|
||||
input_audio_features = input_audio_features.to(
|
||||
dtype=self.llm_dtype, device=target_device
|
||||
)
|
||||
feature_attention_mask = feature_attention_mask.to(device=target_device)
|
||||
sound_embeds = self.sound_encoder(input_audio_features, feature_attention_mask)
|
||||
|
||||
valid_input_lens = feature_attention_mask.sum(dim=1)
|
||||
valid_output_lens = self.sound_encoder.encoder._get_subsampling_output_length(
|
||||
valid_input_lens
|
||||
)
|
||||
truncated_embeds = []
|
||||
for i in range(sound_embeds.shape[0]):
|
||||
valid_len = valid_output_lens[i].item()
|
||||
truncated_embeds.append(sound_embeds[i, :valid_len])
|
||||
|
||||
return tuple(truncated_embeds)
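A minimal, self-contained sketch of the pad-and-stack step above, with assumed shapes (two clips, 80 mel bins):

import torch

feats = [torch.randn(120, 80), torch.randn(95, 80)]   # (time, mel) per clip
max_t = max(f.shape[-2] for f in feats)
padded = torch.stack(
    [torch.nn.functional.pad(f, (0, 0, 0, max_t - f.shape[-2])) for f in feats]
)
# padded.shape == (2, 120, 80); the attention mask marks the padded tail
# as invalid, so the zero padding does not affect the encoder output.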
def _create_final_video_embeddings(
|
||||
self,
|
||||
video_embeddings: torch.Tensor,
|
||||
@@ -1887,6 +2095,18 @@ class NemotronH_Nano_VL_V2(
|
||||
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
|
||||
if (
|
||||
input_key
|
||||
in (
|
||||
"input_audio_features",
|
||||
"feature_attention_mask",
|
||||
"audio_feature_lengths",
|
||||
)
|
||||
and "audios" not in modalities
|
||||
):
|
||||
modalities["audios"] = NanoNemotronVLAudioFeatureInputs(
|
||||
**kwargs, validate=False
|
||||
)
|
||||
|
||||
return modalities
|
||||
|
||||
@@ -1917,6 +2137,10 @@ class NemotronH_Nano_VL_V2(
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_video_input(video_input)
|
||||
multimodal_embeddings += tuple(video_embeddings)
|
||||
if modality == "audios":
|
||||
audio_input = modalities["audios"]
|
||||
audio_embeddings = self._process_audio_input(audio_input)
|
||||
multimodal_embeddings += tuple(audio_embeddings)
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
@@ -1947,8 +2171,8 @@ class NemotronH_Nano_VL_V2(
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="mlp1",
|
||||
tower_model="vision_model",
|
||||
connector=["mlp1", "sound_encoder.projection"],
|
||||
tower_model=["vision_model", "sound_encoder.encoder"],
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
@@ -1969,9 +2193,13 @@ class NemotronH_Nano_VL_V2(
|
||||
def is_vision_weights(name: str) -> bool:
|
||||
return name.startswith("vision_model.radio_model.")
|
||||
|
||||
def is_sound_weights(name: str) -> bool:
|
||||
return name.startswith("sound")
|
||||
|
||||
# Separate weights by component
|
||||
llm_weights = []
|
||||
vision_weights = []
|
||||
sound_weights = []
|
||||
|
||||
for name, w in weights:
|
||||
if is_llm(name):
|
||||
@@ -1987,107 +2215,15 @@ class NemotronH_Nano_VL_V2(
|
||||
# Convert: vision_model.radio_model.* → radio_model.*
|
||||
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
|
||||
vision_weights.append((hf_key, w))
|
||||
elif is_sound_weights(name):
|
||||
assert self.sound_encoder is not None
|
||||
sound_weights.append((name, w))
|
||||
|
||||
self.language_model.load_weights(llm_weights)
|
||||
self.vision_model.load_weights(vision_weights)
|
||||
|
||||
def print_architecture(self, detailed: bool = True, save_to_file: str = None):
|
||||
"""
|
||||
Print model architecture with parameter names, shapes, and sizes.
|
||||
|
||||
Args:
|
||||
detailed: If True, show detailed parameter breakdown
|
||||
save_to_file: If provided, save output to this file path
|
||||
"""
|
||||
import sys
|
||||
from io import StringIO
|
||||
|
||||
# Capture output if saving to file
|
||||
original_stdout = sys.stdout
|
||||
if save_to_file:
|
||||
sys.stdout = StringIO()
|
||||
|
||||
try:
|
||||
print("=" * 100)
|
||||
print("NemotronH_Nano_VL_V2 Model Architecture")
|
||||
print("=" * 100)
|
||||
|
||||
total_params = 0
|
||||
param_groups = {
|
||||
"language_model": [],
|
||||
"vision_model": [],
|
||||
"mlp1": [],
|
||||
"other": [],
|
||||
}
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
param_size = param.numel()
|
||||
total_params += param_size
|
||||
|
||||
# Group parameters by main component
|
||||
if name.startswith("language_model"):
|
||||
param_groups["language_model"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
elif name.startswith("vision_model"):
|
||||
param_groups["vision_model"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
elif name.startswith("mlp1"):
|
||||
param_groups["mlp1"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
else:
|
||||
param_groups["other"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
|
||||
if detailed:
|
||||
print(
|
||||
f"{name:<70} | Shape: {str(param.shape):<25} | "
|
||||
f"Size: {param_size:>12,} | Dtype: {param.dtype}"
|
||||
)
|
||||
|
||||
print("=" * 100)
|
||||
print("Summary by Component:")
|
||||
print("-" * 60)
|
||||
|
||||
for component, params in param_groups.items():
|
||||
if params: # Only show components that have parameters
|
||||
component_total = sum(size for _, _, size, _ in params)
|
||||
percentage = (
|
||||
(component_total / total_params) * 100
|
||||
if total_params > 0
|
||||
else 0
|
||||
)
|
||||
print(
|
||||
f"{component:<20} | Parameters: {len(params):>4} | "
|
||||
f"Total Size: {component_total:>15,} | "
|
||||
f"{percentage:>6.2f}%"
|
||||
)
|
||||
|
||||
print("-" * 60)
|
||||
print(f"{'Total Parameters':<20} | {total_params:>15,}")
|
||||
|
||||
# Estimate memory usage (assuming bfloat16 = 2 bytes per parameter)
|
||||
memory_mb = total_params * 2 / (1024**2)
|
||||
memory_gb = memory_mb / 1024
|
||||
print(f"{'Est. Memory (MB)':<20} | {memory_mb:>15.2f}")
|
||||
print(f"{'Est. Memory (GB)':<20} | {memory_gb:>15.2f}")
|
||||
print("=" * 100)
|
||||
|
||||
# Save to file if requested
|
||||
if save_to_file:
|
||||
output = sys.stdout.getvalue()
|
||||
sys.stdout = original_stdout
|
||||
with open(save_to_file, "w") as f:
|
||||
f.write(output)
|
||||
print(f"Architecture saved to: {save_to_file}")
|
||||
print(output) # Also print to console
|
||||
|
||||
finally:
|
||||
if save_to_file and sys.stdout != original_stdout:
|
||||
sys.stdout = original_stdout
|
||||
if self.sound_encoder is not None:
|
||||
assert len(sound_weights) > 0
|
||||
self.sound_encoder.load_weights(sound_weights)
|
||||
|
||||
def get_vit_model_from_radio_config(self, hf_config):
|
||||
hf_config_vision = hf_config.vision_config
|
||||
|
||||
@@ -34,7 +34,7 @@ from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
GateLinear,
|
||||
SharedFusedMoE,
|
||||
activation_without_mul,
|
||||
)
|
||||
@@ -148,13 +148,11 @@ class NemotronHMoE(nn.Module):
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
router_logits_dtype = torch.float32
|
||||
self.gate = ReplicatedLinear(
|
||||
self.gate = GateLinear(
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
params_dtype=router_logits_dtype,
|
||||
quant_config=None,
|
||||
out_dtype=torch.float32,
|
||||
force_fp32_compute=True,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
@@ -232,7 +230,6 @@ class NemotronHMoE(nn.Module):
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
router_logits_dtype=router_logits_dtype,
|
||||
routed_input_transform=self.fc1_latent_proj,
|
||||
)
|
||||
|
||||
@@ -244,7 +241,7 @@ class NemotronHMoE(nn.Module):
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
# SharedFusedMoE handles:
|
||||
# - shared experts (with original hidden_states)
|
||||
@@ -298,6 +295,11 @@ class NemotronHMLPDecoderLayer(nn.Module):
|
||||
|
||||
hybrid_override_pattern = config.hybrid_override_pattern
|
||||
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
|
||||
# Get the per-layer config for heterogeneous models, if one exists
|
||||
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
|
||||
layer_config = get_layer_config(layer_idx) if get_layer_config else config
|
||||
config = layer_config
|
||||
|
||||
if isinstance(config.intermediate_size, list):
|
||||
if len(config.intermediate_size) == 1:
|
||||
intermediate_size = config.intermediate_size[0]
|
||||
@@ -670,7 +672,7 @@ class NemotronHModel(nn.Module):
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
if self.has_moe:
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
|
||||
# - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
|
||||
# what the activation is applied to
|
||||
# - FusedMoe.w3 (aka up_proj) should be ignored since we're
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
|
||||
@@ -18,6 +19,8 @@ from transformers import AutoModel, PretrainedConfig
|
||||
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.models.internvl import (
|
||||
@@ -30,24 +33,29 @@ from vllm.model_executor.models.internvl import (
|
||||
InternVLProcessor,
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.siglip import SiglipVisionModel
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.processing import PromptUpdateDetails
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.processor import cached_image_processor_from_config
|
||||
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsCrossEncoding,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
|
||||
IMG_START = "<img>"
|
||||
IMG_END = "</img>"
|
||||
IMG_CONTEXT = "<image>"
|
||||
from .interfaces_base import VllmModelForPooling
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
def build_transform(input_size: int):
|
||||
@@ -183,10 +191,12 @@ def image_to_pixel_values_nemotron_vl(
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
use_thumbnail: bool,
|
||||
transform: T.Compose | None = None,
|
||||
) -> torch.Tensor:
|
||||
target_ratios = get_nemotron_vl_target_ratios(min_num, max_num)
|
||||
|
||||
transform = build_transform(input_size=input_size)
|
||||
if transform is None:
|
||||
transform = build_transform(input_size=input_size)
|
||||
|
||||
images = dynamic_preprocess_nemotron_vl(
|
||||
image,
|
||||
@@ -200,11 +210,15 @@ def image_to_pixel_values_nemotron_vl(
|
||||
|
||||
|
||||
class NemotronVLProcessor(InternVLProcessor):
|
||||
IMG_START = "<img>"
|
||||
IMG_END = "</img>"
|
||||
IMG_CONTEXT = "<image>"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
image_processor: BaseImageProcessorFast,
|
||||
image_processor: BaseImageProcessorFast | None = None,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
@@ -236,11 +250,18 @@ class NemotronVLProcessor(InternVLProcessor):
|
||||
self.min_dynamic_patch = min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail: bool = self.image_processor.use_thumbnail
|
||||
|
||||
if image_processor is not None:
|
||||
self.use_thumbnail = image_processor.use_thumbnail
|
||||
else:
|
||||
self.use_thumbnail = getattr(config, "use_thumbnail", True)
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
return self.tokenizer.get_vocab()[self.IMG_CONTEXT]
|
||||
|
||||
def _get_transform(self) -> T.Compose:
|
||||
return build_transform(input_size=self.image_size)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
@@ -283,10 +304,26 @@ class NemotronVLProcessor(InternVLProcessor):
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
transform=self._get_transform(),
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
def _replace_image_tokens(
|
||||
self,
|
||||
text: list[str],
|
||||
pixel_values_lst: list[torch.Tensor],
|
||||
) -> list[str]:
|
||||
"""Replace <image> placeholders with image tokens."""
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
# Use temporary placeholder to avoid replacing tokens we just inserted
|
||||
NVL_IMAGE_CONTEXT = image_repl.full.replace("<image>", "<NVL_IMG_CONTEXT>")
|
||||
text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text]
|
||||
return [t.replace("<NVL_IMG_CONTEXT>", self.IMG_CONTEXT) for t in text]
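A toy illustration (strings assumed) of why the temporary placeholder is needed: the expansion itself contains "<image>", so a second naive replace would hit tokens inserted by the first expansion instead of the second original placeholder.

text = "compare <image> with <image>"
repl = "<img><image><image></img>"                   # expansion also contains "<image>"
safe_repl = repl.replace("<image>", "<NVL_IMG_CONTEXT>")
text = text.replace("<image>", safe_repl, 1)         # hits the first placeholder only
text = text.replace("<image>", safe_repl, 1)         # hits the second original placeholder
text = text.replace("<NVL_IMG_CONTEXT>", "<image>")  # restore the real context token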
def _preprocess_image(
|
||||
self,
|
||||
text: list[str],
|
||||
@@ -311,15 +348,7 @@ class NemotronVLProcessor(InternVLProcessor):
|
||||
),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
NVL_IMAGE_CONTEXT = image_repl.full.replace(
|
||||
"<image>", "<NVL_IMG_CONTEXT>"
|
||||
)
|
||||
text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text]
|
||||
text = [t.replace("<NVL_IMG_CONTEXT>", IMG_CONTEXT) for t in text]
|
||||
text = self._replace_image_tokens(text, pixel_values_lst)
|
||||
return text, image_inputs
|
||||
|
||||
def get_image_repl(
|
||||
@@ -327,10 +356,10 @@ class NemotronVLProcessor(InternVLProcessor):
|
||||
feature_size: int,
|
||||
num_patches: int | None,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
repl_features = self.IMG_CONTEXT * feature_size
|
||||
repl_full = self.IMG_START + repl_features + self.IMG_END
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
return PromptUpdateDetails.select_text(repl_full, self.IMG_CONTEXT)
|
||||
|
||||
|
||||
class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
|
||||
@@ -396,7 +425,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
hf_config=config.get_text_config(),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
@@ -413,7 +442,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
# the awq models from OpenGVLab missing `modules_to_not_convert`
|
||||
# patch the quant_config to add `modules_to_not_convert` back
|
||||
if isinstance(quant_config, AWQConfig):
|
||||
text_config = config.text_config
|
||||
text_config = config.get_text_config()
|
||||
llm_quant_config = getattr(text_config, "quantization_config", None)
|
||||
if (not quant_config.modules_to_not_convert) and (
|
||||
llm_quant_config is not None
|
||||
@@ -429,10 +458,17 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
):
|
||||
return AutoModel.from_config(config.vision_config, trust_remote_code=True)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
vision_projection_hidden_size = config.projector_hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
def _init_mlp1(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
vit_hidden_size: int | None = None,
|
||||
vision_projection_hidden_size: int | None = None,
|
||||
) -> nn.Module:
|
||||
if vit_hidden_size is None:
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
if vision_projection_hidden_size is None:
|
||||
vision_projection_hidden_size = config.projector_hidden_size
|
||||
llm_hidden_size = config.get_text_config().hidden_size
|
||||
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(
|
||||
@@ -465,10 +501,18 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
"""Call vision model and return embeddings.
|
||||
|
||||
Override this method in subclasses to handle different vision model
|
||||
interfaces (e.g., SigLIP vs C-RADIO).
|
||||
"""
|
||||
vit_embeds = self.vision_model(x=pixel_values).features
|
||||
return vit_embeds.to(dtype=torch.bfloat16)
|
||||
|
||||
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
# https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177
|
||||
vit_embeds = self.vision_model(x=pixel_values).features
|
||||
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
|
||||
vit_embeds = self._call_vision_model(pixel_values)
|
||||
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
@@ -523,15 +567,16 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
|
||||
|
||||
num_patches = image_input["num_patches"]
|
||||
hidden_size = self.config.get_text_config().hidden_size
|
||||
|
||||
# Only one image in the current batch
|
||||
if len(num_patches) == 1:
|
||||
return (image_embeds.view(-1, self.config.text_config.hidden_size),)
|
||||
return (image_embeds.view(-1, hidden_size),)
|
||||
|
||||
# NOTE: Image embeddings are split into separate tensors for each image
|
||||
# by the size of each embedding.
|
||||
feature_size = image_embeds.shape[1]
|
||||
image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
|
||||
image_embeds = image_embeds.view(-1, hidden_size)
|
||||
image_feature_sizes = [
|
||||
num_patches * feature_size for num_patches in num_patches
|
||||
]
|
||||
@@ -643,3 +688,255 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
|
||||
connector="mlp1",
|
||||
tower_model="vision_model",
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
# LlamaNemotronVL Embedding Model (nvidia/llama-nemotron-embed-vl-1b-v2)
|
||||
# Extends LlamaNemotronVLChatModel for embedding/pooling tasks:
|
||||
# - SigLIP vision encoder (instead of C-RADIO)
|
||||
# - Bidirectional (non-causal) LLaMA language model
|
||||
# - Pooler output instead of generative logits
|
||||
# --------------------------------------------------------
|
||||
|
||||
# SigLIP normalization constants
|
||||
SIGLIP_MEAN = (0.5, 0.5, 0.5)
|
||||
SIGLIP_STD = (0.5, 0.5, 0.5)
|
||||
|
||||
|
||||
def build_siglip_transform(input_size: int):
|
||||
"""Build transform for SigLIP vision encoder with normalization.
|
||||
|
||||
Extends the base transform from nemotron_vl with SigLIP-specific normalization.
|
||||
"""
|
||||
base_transform = build_transform(input_size=input_size)
|
||||
return T.Compose(
|
||||
[
|
||||
base_transform,
|
||||
T.Normalize(mean=SIGLIP_MEAN, std=SIGLIP_STD),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class LlamaNemotronVLEmbedProcessor(NemotronVLProcessor):
|
||||
"""
|
||||
Processor for LlamaNemotronVL embedding model.
|
||||
|
||||
Inherits from NemotronVLProcessor and specializes it for embedding tasks:
|
||||
- Uses SigLIP transform with normalization instead of base transform
|
||||
- Uses different image context token (<IMG_CONTEXT> vs <image>)
|
||||
"""
|
||||
|
||||
IMG_CONTEXT = "<IMG_CONTEXT>"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
processor_config: dict,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> None:
|
||||
if min_dynamic_patch is None:
|
||||
min_dynamic_patch = processor_config.get(
|
||||
"min_input_tiles",
|
||||
getattr(config, "min_dynamic_patch", 1),
|
||||
)
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = processor_config.get(
|
||||
"max_input_tiles",
|
||||
getattr(config, "max_dynamic_patch", 1),
|
||||
)
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = processor_config.get(
|
||||
"dynamic_image_size",
|
||||
getattr(config, "dynamic_image_size", True),
|
||||
)
|
||||
super().__init__(
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
image_processor=None,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
def _get_transform(self) -> T.Compose:
|
||||
"""Override to add SigLIP normalization."""
|
||||
return build_siglip_transform(input_size=self.image_size)
|
||||
|
||||
def _replace_image_tokens(
|
||||
self,
|
||||
text: list[str],
|
||||
pixel_values_lst: list[torch.Tensor],
|
||||
) -> list[str]:
|
||||
"""Override with simpler token replacement for embedding model.
|
||||
|
||||
No temporary placeholder needed because IMG_CONTEXT is <IMG_CONTEXT>,
|
||||
not <image>, so there's no collision risk.
|
||||
"""
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
text = [t.replace("<image>", image_repl.full, 1) for t in text]
|
||||
return text
|
||||
|
||||
|
||||
class LlamaNemotronVLEmbedProcessingInfo(NemotronVLProcessingInfo):
|
||||
"""Processing info for LlamaNemotronVL embedding model."""
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> LlamaNemotronVLEmbedProcessor:
|
||||
"""Override to create embedding-specific processor without image_processor."""
|
||||
model_config = self.ctx.model_config
|
||||
processor_config = {}
|
||||
if model_config.model is not None:
|
||||
processor_config = (
|
||||
get_hf_file_to_dict(
|
||||
"processor_config.json",
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
return self.ctx.init_processor(
|
||||
LlamaNemotronVLEmbedProcessor,
|
||||
config=self.get_hf_config(),
|
||||
tokenizer=self.get_tokenizer(),
|
||||
processor_config=processor_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
BaseInternVLMultiModalProcessor[LlamaNemotronVLEmbedProcessingInfo],
|
||||
info=LlamaNemotronVLEmbedProcessingInfo,
|
||||
dummy_inputs=BaseInternVLDummyInputsBuilder[LlamaNemotronVLEmbedProcessingInfo],
|
||||
)
|
||||
class LlamaNemotronVLForEmbedding(LlamaNemotronVLChatModel, VllmModelForPooling):
|
||||
"""
|
||||
LlamaNemotronVL model for embeddings.
|
||||
|
||||
Inherits from LlamaNemotronVLChatModel and specializes it for embedding tasks:
|
||||
- Uses SigLIP vision encoder instead of C-RADIO
|
||||
- Uses bidirectional LLaMA (via llm_config) instead of causal LLaMA
|
||||
- Adds pooler for embedding output instead of generating logits
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
# Weight mapping from checkpoint format to vLLM format
|
||||
# Different from parent class due to different vision model structure
|
||||
weight_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# Language model mapping
|
||||
"language_model.layers.": "language_model.model.layers.",
|
||||
"language_model.embed_tokens.": "language_model.model.embed_tokens.",
|
||||
"language_model.norm.": "language_model.model.norm.",
|
||||
# Vision model mapping (SiglipVisionModel has nested vision_model)
|
||||
"vision_model.encoder.": "vision_model.vision_model.encoder.",
|
||||
"vision_model.embeddings.": "vision_model.vision_model.embeddings.",
|
||||
"vision_model.post_layernorm.": "vision_model.vision_model.post_layernorm.",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
# Override: get img_context_token_id from config (parent sets None)
|
||||
self.img_context_token_id = getattr(config, "img_context_token_id", None)
|
||||
|
||||
# Initialize pooler for embedding output
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
def _init_vision_model(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config,
|
||||
*,
|
||||
prefix: str,
|
||||
) -> nn.Module:
|
||||
"""Override to use SigLIP instead of C-RADIO."""
|
||||
return SiglipVisionModel(
|
||||
config.vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
use_head=False,
|
||||
)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
|
||||
"""Override to use different MLP structure for embedding model."""
|
||||
return super()._init_mlp1(
|
||||
config,
|
||||
vit_hidden_size=config.vision_config.hidden_size,
|
||||
vision_projection_hidden_size=config.get_text_config().hidden_size,
|
||||
)
|
||||
|
||||
def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
"""Override to handle SigLIP interface."""
|
||||
return self.vision_model(pixel_values)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Override to use different weight mapping for SigLIP."""
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.weight_mapper)
|
||||
|
||||
|
||||
class LlamaNemotronVLForSequenceClassification(
|
||||
LlamaNemotronVLForEmbedding, SupportsCrossEncoding
|
||||
):
|
||||
"""LlamaNemotronVL model variant for sequence classification / reranking."""
|
||||
|
||||
# Reranker checkpoint places base model weights under `model.*`,
|
||||
# while `score.*` remains at the top level.
|
||||
weight_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) | (
|
||||
LlamaNemotronVLForEmbedding.weight_mapper
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
text_config = vllm_config.model_config.hf_config.get_text_config()
|
||||
model_config = vllm_config.model_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.score = ReplicatedLinear(
|
||||
model_config.get_hidden_size(),
|
||||
text_config.num_labels,
|
||||
bias=False,
|
||||
params_dtype=model_config.head_dtype,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=maybe_prefix(prefix, "score"),
|
||||
)
|
||||
|
||||
pooler_config = model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loaded_weights = super().load_weights(weights)
|
||||
|
||||
# The reranker checkpoint omits the inner LM seq-cls head
|
||||
# (`language_model.score.*`). It is unused by this outer model, but
|
||||
# the default loader expects all parameters to be initialized.
|
||||
for name, param in self.named_parameters():
|
||||
if not name.startswith("language_model.score.") or name in loaded_weights:
|
||||
continue
|
||||
|
||||
if name.endswith(".weight"):
|
||||
torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
|
||||
elif name.endswith(".bias"):
|
||||
torch.nn.init.zeros_(param)
|
||||
else:
|
||||
torch.nn.init.normal_(param, mean=0.0, std=0.02)
|
||||
|
||||
loaded_weights.add(name)
|
||||
|
||||
return loaded_weights
|
||||
|
||||
@@ -43,12 +43,9 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
|
||||
IMAGE_TOKEN = "<image>"
|
||||
IMAGE_PLACEHOLDER_ID = 151669
|
||||
VIDEO_TOKEN = "<video>"
|
||||
VIDEO_PLACEHOLDER_ID = 151670
|
||||
INDICATOR_IDS = [151672, 151673, 151674, 151675]
|
||||
IMAGE_PAD_TOKEN_ID = 151655
|
||||
THINK_END_TOKEN_ID = 151668
|
||||
|
||||
|
||||
class Ovis2_5ImagePatchInputs(TensorSchema):
|
||||
|
||||
@@ -155,15 +155,30 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
|
||||
patch_size = vision_config.patch_size
|
||||
merge_size = vision_config.spatial_merge_size
|
||||
|
||||
if self.ctx.model_config.trust_remote_code:
|
||||
# Defined in HF Hub repo
|
||||
min_pixels_key = "min_pixels"
|
||||
max_pixels_key = "max_pixels"
|
||||
else:
|
||||
# Defined in Transformers library (requires v5.0 or above)
|
||||
min_pixels_key = "shortest_edge"
|
||||
max_pixels_key = "longest_edge"
|
||||
|
||||
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
|
||||
size = mm_kwargs.get("size", image_processor.size)
|
||||
size = image_processor.size
|
||||
if override_size := mm_kwargs.get("size"):
|
||||
size = size | override_size
|
||||
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
|
||||
size = size | {min_pixels_key: override_min_pixels}
|
||||
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
|
||||
size = size | {max_pixels_key: override_max_pixels}
|
||||
|
||||
resized_height, resized_width = smart_resize(
|
||||
height=image_height,
|
||||
width=image_width,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=size["min_pixels"],
|
||||
max_pixels=size["max_pixels"],
|
||||
min_pixels=size[min_pixels_key],
|
||||
max_pixels=size[max_pixels_key],
|
||||
)
|
||||
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
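For clarity, a small sketch (values assumed) of how the override merge above behaves on the non-remote-code path, where the keys are "shortest_edge"/"longest_edge":

size = {"shortest_edge": 147_384, "longest_edge": 1_003_520}   # assumed processor defaults
mm_kwargs = {"min_pixels": 256 * 28 * 28}
if (override_min := mm_kwargs.get("min_pixels")) is not None:
    size = size | {"shortest_edge": override_min}
# size == {"shortest_edge": 200704, "longest_edge": 1003520}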
145
vllm/model_executor/models/parakeet.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Modules below are used for the audio encoder component in models/nano_nemotron_vl.py
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import ParakeetEncoder as HFParakeetEncoder
|
||||
from transformers import ParakeetFeatureExtractor, PretrainedConfig
|
||||
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
|
||||
|
||||
|
||||
class ParakeetProjection(nn.Module):
|
||||
def __init__(self, config: ParakeetConfig) -> None:
|
||||
super().__init__()
|
||||
sound_hidden_size = config.hidden_size
|
||||
proj_hidden_size = config.projection_hidden_size
|
||||
llm_hidden_size = config.llm_hidden_size
|
||||
bias = config.projection_bias
|
||||
|
||||
self.norm = nn.LayerNorm(sound_hidden_size, eps=config.projection_eps)
|
||||
self.linear1 = nn.Linear(sound_hidden_size, proj_hidden_size, bias=bias)
|
||||
self.activation = ReLUSquaredActivation()
|
||||
self.linear2 = nn.Linear(proj_hidden_size, llm_hidden_size, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.linear1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.linear2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ProjectedParakeet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
*,
|
||||
dtype: torch.dtype,
|
||||
llm_hidden_size: int,
|
||||
max_model_len: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = ParakeetConfig.from_hf_config(
|
||||
config, llm_hidden_size=llm_hidden_size, max_model_len=max_model_len
|
||||
)
|
||||
self.encoder = HFParakeetEncoder(self.config)
|
||||
self.encoder = self.encoder.to(dtype)
|
||||
self.projection = ParakeetProjection(self.config)
|
||||
self.projection = self.projection.to(dtype)
|
||||
|
||||
def forward(
|
||||
self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
outputs = self.encoder(
|
||||
input_features=input_features, attention_mask=attention_mask
|
||||
)
|
||||
outputs = outputs.last_hidden_state
|
||||
outputs = outputs.to(dtype=torch.bfloat16)
|
||||
outputs = self.projection(outputs)
|
||||
return outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loaded_params: set[str] = set()
|
||||
params_dict = dict(self.named_parameters())
|
||||
buffers_dict = dict(self.named_buffers())
|
||||
|
||||
if isinstance(weights, dict):
|
||||
weights_list = list(weights.items())
|
||||
else:
|
||||
weights_list = list(weights)
|
||||
|
||||
for name, weight in weights_list:
|
||||
if name.startswith("sound_encoder.encoder.feature_extractor."):
|
||||
# Feature extractor buffers are handled outside the encoder.
|
||||
continue
|
||||
if name.startswith("sound_encoder."):
|
||||
target_name = name[len("sound_encoder.") :]
|
||||
elif name.startswith("sound_projection."):
|
||||
target_name = f"projection.{name[len('sound_projection.') :]}"
|
||||
else:
|
||||
continue
|
||||
|
||||
target = params_dict.get(target_name)
|
||||
if target is None:
|
||||
target = buffers_dict.get(target_name)
|
||||
if target is None:
|
||||
raise ValueError(f"Unknown weight: {name}")
|
||||
weight_loader = getattr(target, "weight_loader", default_weight_loader)
|
||||
with torch.no_grad():
|
||||
weight_loader(target, weight)
|
||||
loaded_params.add(target_name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
class ParakeetExtractor(ParakeetFeatureExtractor):
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
self.config = ExtractorConfig.from_hf_config(config)
|
||||
super().__init__(**asdict(self.config))
|
||||
self._clip_target_samples = int(
|
||||
round(self.config.clip_duration_s * self.sampling_rate)
|
||||
)
|
||||
self._tail_min_samples = int(
|
||||
round(self.config.clip_min_duration_s * self.sampling_rate)
|
||||
)
|
||||
|
||||
def _normalize_audio_length(self, audio_len: int) -> int:
|
||||
# Match mcore's compute_params() logic for clip / min-duration handling.
|
||||
target_len = max(audio_len, self._tail_min_samples)
|
||||
tail_remainder = target_len % self._clip_target_samples
|
||||
if 0 < tail_remainder < self._tail_min_samples:
|
||||
padding = self._tail_min_samples - tail_remainder
|
||||
target_len += padding
|
||||
assert isinstance(target_len, int)
|
||||
return target_len
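# Hypothetical worked example (numbers made up, assuming 16 kHz audio with
# clip_duration_s = 30 and clip_min_duration_s = 1, i.e. a clip target of
# 480000 samples and a tail minimum of 16000 samples):
#   audio_len = 485000 -> tail_remainder = 5000, which is < 16000, so the
#       tail is padded by 11000 samples and target_len = 496000.
#   audio_len = 500000 -> tail_remainder = 20000 >= 16000, so no extra
#       padding is added and target_len stays at 500000.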
|
||||
|
||||
def audio_token_count(self, audio_len: int) -> int:
|
||||
audio_len = self._normalize_audio_length(audio_len)
|
||||
num_frames = audio_len // self.hop_length
|
||||
n_tokens = HFParakeetEncoder._get_subsampling_output_length(
|
||||
self, torch.tensor([num_frames], dtype=torch.float)
|
||||
)
|
||||
return max(1, n_tokens.item())
|
||||
|
||||
def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
|
||||
padded = []
|
||||
for p in raw_speech:
|
||||
assert p.ndim == 1
|
||||
audio_len = int(p.shape[0])
|
||||
target_len = self._normalize_audio_length(audio_len)
|
||||
p = np.pad(p, (0, target_len - audio_len))
|
||||
padded.append(p)
|
||||
return super().__call__(padded, *args, **kwargs)
|
||||
|
||||
def audio_length(self, audio_tokens: int) -> int:
|
||||
return int(audio_tokens * self.config.subsampling_factor * self.hop_length)
|
||||
@@ -221,7 +221,8 @@ def sparsemixer(scores, jitter_eps=0.01):
|
||||
|
||||
multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
|
||||
selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
|
||||
|
||||
multiplier = multiplier.to(torch.float32)
|
||||
selected_experts = selected_experts.to(torch.int32)
|
||||
return (
|
||||
multiplier,
|
||||
selected_experts,
|
||||
|
||||
@@ -83,15 +83,16 @@ from .vision import (
|
||||
resolve_visual_encoder_outputs,
|
||||
)
|
||||
|
||||
import ixformer.inference.functions as ixf
|
||||
try:
|
||||
# Note: vLLM does not install xformers by default.
|
||||
from xformers import ops as xops
|
||||
|
||||
if current_platform.is_cuda() and current_platform.has_device_capability(100):
|
||||
if current_platform.is_cuda():
|
||||
# Xformers FA is not compatible with B200
|
||||
USE_XFORMERS_OPS = False
|
||||
else:
|
||||
USE_XFORMERS_OPS = True
|
||||
else:
|
||||
USE_XFORMERS_OPS = False
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
@@ -698,23 +699,21 @@ class Attention(nn.Module):
|
||||
q, k, v = self.wq(x), self.wk(x), self.wv(x)
|
||||
q = q.reshape(batch, patches, self.n_heads, self.head_dim)
|
||||
k = k.reshape(batch, patches, self.n_heads, self.head_dim)
|
||||
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
|
||||
|
||||
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
|
||||
|
||||
if USE_XFORMERS_OPS:
|
||||
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
v = v.reshape(batch * patches, self.n_heads, self.head_dim)
|
||||
|
||||
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
|
||||
q = q.view(batch * patches, self.n_heads, self.head_dim)
|
||||
k = k.view(batch * patches, self.n_heads, self.head_dim)
|
||||
out = ixf.ixinfer_flash_attn_unpad(q, k, v, mask.q_seqinfo.seqstart.to(q.device), mask.k_seqinfo.seqstart.to(q.device), mask.q_seqinfo.max_seqlen, mask.k_seqinfo.max_seqlen)
|
||||
# out = memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
else:
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
assert False, "xformers failed!"
|
||||
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
|
||||
return self.wo(out)
|
||||
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: VisionEncoderArgs):
|
||||
super().__init__()
|
||||
|
||||
@@ -292,7 +292,7 @@ class QWenBaseModel(nn.Module):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.wte(input_ids)
|
||||
|
||||
|
||||
@@ -122,8 +122,17 @@ def check_interleaved_audio_video(
|
||||
"""
|
||||
Check if video and audio positions are interleaved in the multimodal region.
|
||||
|
||||
Returns:
|
||||
True if video and audio tokens are interleaved, False otherwise.
|
||||
Returns True only for the use_audio_in_video=True case, where video and
|
||||
audio tokens alternate within a single contiguous region with no gaps.
|
||||
|
||||
A simple range-overlap check produces false positives when multiple
|
||||
non-interleaved requests are batched together: audio tokens from request N
|
||||
fall between video tokens from request N and request N+1, making the
|
||||
global ranges overlap even though each individual request is non-interleaved.
|
||||
|
||||
To distinguish true interleaving from this batching artefact we require
|
||||
that every position in the combined [first_VA, last_VA] range is occupied
|
||||
by either a video or an audio token (no text/image gaps).
|
||||
"""
|
||||
if num_video == 0 or num_audio == 0:
|
||||
return False
|
||||
@@ -131,10 +140,22 @@ def check_interleaved_audio_video(
|
||||
video_pos = is_video.nonzero(as_tuple=True)[0]
|
||||
audio_pos = is_audio.nonzero(as_tuple=True)[0]
|
||||
|
||||
return (
|
||||
# Quick range-overlap pre-check (necessary but not sufficient).
|
||||
if not (
|
||||
video_pos[0].item() < audio_pos[-1].item()
|
||||
and audio_pos[0].item() < video_pos[-1].item()
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
# Density check: for true use_audio_in_video interleaving every position
|
||||
# in the combined span is a video or audio token. Batched non-interleaved
|
||||
# requests have text/image tokens between the per-request V and A blocks.
|
||||
# combined_start/end encompass all V/A tokens, so num_video + num_audio
|
||||
# equals the number of V/A tokens in range; compare directly to span size.
|
||||
combined_start = min(video_pos[0].item(), audio_pos[0].item())
|
||||
combined_end = max(video_pos[-1].item(), audio_pos[-1].item())
|
||||
total_in_range = combined_end - combined_start + 1
|
||||
return (num_video + num_audio) == total_in_range
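# Illustrative sketch (not part of the diff; token layouts are made up) of how
# the density check above separates the two cases:
#
#   True use_audio_in_video interleaving, positions 0..5 = V A V A V A:
#     is_video = [1, 0, 1, 0, 1, 0], is_audio = [0, 1, 0, 1, 0, 1]
#     span = 6 and num_video + num_audio = 6 -> returns True.
#
#   Two batched non-interleaved requests with a text token at position 4,
#   positions 0..6 = V V A A T V A:
#     is_video = [1, 1, 0, 0, 0, 1, 0], is_audio = [0, 0, 1, 1, 0, 0, 1]
#     the ranges overlap, but span = 7 while num_video + num_audio = 6,
#     so the function returns False.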
|
||||
|
||||
|
||||
def merge_interleaved_embeddings(
|
||||
@@ -332,6 +353,39 @@ class Qwen2_5OmniThinkerProcessingInfo(
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None, "image": None, "video": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int] | None:
|
||||
mm_counts = mm_counts or {}
|
||||
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
|
||||
mm_max_tokens: dict[str, int] = {}
|
||||
|
||||
if requested_modalities & {"image", "video"}:
|
||||
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
mm_max_tokens.update(
|
||||
{
|
||||
m: vl_tokens[m]
|
||||
for m in ["image", "video"]
|
||||
if m in requested_modalities
|
||||
}
|
||||
)
|
||||
|
||||
if "audio" in requested_modalities:
|
||||
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
mm_max_tokens["audio"] = audio_tokens["audio"]
|
||||
|
||||
return mm_max_tokens
|
||||
|
||||
|
||||
class Qwen2_5OmniThinkerDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]
|
||||
@@ -1376,23 +1430,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
from .utils import _merge_multimodal_embeddings
|
||||
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
# Check for audio-in-video: interleaved video and audio tokens
|
||||
# in the multimodal region.
|
||||
# in the multimodal region. Only use the interleaved path when
|
||||
# needed; otherwise fall back to the default parent implementation.
|
||||
video_token_id = self.config.video_token_index
|
||||
audio_token_id = self.config.audio_token_index
|
||||
|
||||
@@ -1403,6 +1446,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
num_audio = is_audio.sum().item()
|
||||
|
||||
if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
return merge_interleaved_embeddings(
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
@@ -1413,9 +1462,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
num_audio,
|
||||
)
|
||||
|
||||
# Default: standard merge (no interleaving)
|
||||
return _merge_multimodal_embeddings(
|
||||
inputs_embeds, multimodal_embeddings, is_multimodal
|
||||
# Default: standard merge (no interleaving), same as parent class
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
|
||||
- second_per_grid_ts: The video time interval (in seconds) for each
|
||||
grid along the temporal dimension in the 3D position IDs. Returned
|
||||
when `videos` is not `None`.
|
||||
- timestamps: List of timestamp values (in seconds) for each frame
|
||||
after merging. Length equals the temporal dimension after merging.
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values_videos"]
|
||||
@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
|
||||
TensorShape("nv"),
|
||||
]
|
||||
|
||||
timestamps: list[list[float]] | None = None
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
- second_per_grid_ts: The video time interval (in seconds) for each
|
||||
grid along the temporal dimension in the 3D position IDs. Returned
|
||||
when `videos` is not `None`.
|
||||
- timestamps: List of timestamp values (in seconds) for each frame
|
||||
after merging. Length equals the temporal dimension after merging.
|
||||
"""
|
||||
|
||||
type: Literal["video_embeds"]
|
||||
@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
torch.Tensor | None,
|
||||
TensorShape("nv"),
|
||||
] = None
|
||||
timestamps: list[list[float]] | None = None
|
||||
|
||||
|
||||
Qwen2_5_VLVideoInputs: TypeAlias = (
|
||||
@@ -289,10 +296,11 @@ class Qwen2_5_VisionMLP(nn.Module):
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.act_fn = act_fn
|
||||
self.hidden_features = hidden_features
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x = self.act_fn(gate_up, self.hidden_features)
|
||||
x_down, _ = self.down_proj(x)
|
||||
return x_down
|
||||
|
||||
@@ -357,6 +365,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
@@ -398,6 +407,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
value=v,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=sequence_lengths,
|
||||
)
|
||||
|
||||
context_layer = einops.rearrange(
|
||||
@@ -463,6 +473,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=None,
|
||||
)
|
||||
x_fused_norm, residual = self.norm2(x, residual=x_attn)
|
||||
x = residual + self.mlp(x_fused_norm)
|
||||
|
||||
@@ -179,6 +179,26 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
mm_counts = mm_counts or {}
|
||||
if mm_counts.get("audio", 0) <= 0:
|
||||
return {}
|
||||
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
chunk_length = min(feature_extractor.chunk_length, 30)
|
||||
audio_len = int(chunk_length * feature_extractor.sampling_rate)
|
||||
hop_length = feature_extractor.hop_length
|
||||
max_mel_seq_len = audio_len // hop_length
|
||||
|
||||
input_lengths = torch.tensor([max_mel_seq_len], dtype=torch.long)
|
||||
_, output_lengths = _get_feat_extract_output_lengths(input_lengths)
|
||||
|
||||
return {"audio": int(output_lengths.item())}
|
||||
|
||||
|
||||
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
|
||||
@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory(
|
||||
"video", video_embed_grid_sizes
|
||||
),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||
timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||
)
|
||||
|
||||
return _qwen2vl_field_config
|
||||
@@ -843,7 +844,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
temporal_patch_size = vision_config.temporal_patch_size
|
||||
|
||||
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
|
||||
size = mm_kwargs.get("size", image_processor.size)
|
||||
size = image_processor.size
|
||||
if override_size := mm_kwargs.get("size"):
|
||||
size = size | override_size
|
||||
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
|
||||
size = size | {"shortest_edge": override_min_pixels}
|
||||
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
|
||||
size = size | {"longest_edge": override_max_pixels}
|
||||
|
||||
if do_resize:
|
||||
resized_height, resized_width = smart_resize(
|
||||
@@ -930,7 +937,14 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
mm_kwargs = self.ctx.get_merged_mm_kwargs({})
|
||||
size = mm_kwargs.get("size", image_processor.size)
|
||||
size = image_processor.size
|
||||
if override_size := mm_kwargs.get("size"):
|
||||
size = size | override_size
|
||||
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
|
||||
size = size | {"shortest_edge": override_min_pixels}
|
||||
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
|
||||
size = size | {"longest_edge": override_max_pixels}
|
||||
|
||||
max_pixels = size["longest_edge"]
|
||||
|
||||
unit = patch_size * merge_size
|
||||
|
||||
@@ -51,7 +51,7 @@ from vllm.v1.attention.backend import AttentionType
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix, reparse_quant_config
|
||||
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -142,7 +142,6 @@ class Qwen3Attention(nn.Module):
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.qk_norm = RMSNormQK(self.head_dim, self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -150,17 +149,19 @@ class Qwen3Attention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
# # Add qk-norm
|
||||
# Add qk-norm
|
||||
# q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
# q_by_head = self.q_norm(q_by_head)
|
||||
# q_by_head = self.q_norm.forward_native(q_by_head) # TODO(gyf) check why
|
||||
# q = q_by_head.view(q.shape)
|
||||
# k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
# k_by_head = self.k_norm(k_by_head)
|
||||
# k_by_head = self.k_norm.forward_native(k_by_head)
|
||||
# k = k_by_head.view(k.shape)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_by_head = k.view(*k.shape[:-1],
|
||||
k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
|
||||
out_q, out_k = self.qk_norm(
|
||||
q_by_head,
|
||||
k_by_head,
|
||||
@@ -170,7 +171,6 @@ class Qwen3Attention(nn.Module):
|
||||
|
||||
q = out_q.view(q.shape)
|
||||
k = out_k.view(k.shape)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
@@ -201,8 +201,6 @@ class Qwen3DecoderLayer(nn.Module):
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
quant_config = reparse_quant_config(prefix, quant_config)
|
||||
|
||||
self.self_attn = Qwen3Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@@ -236,23 +234,25 @@ class Qwen3DecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
ALL_DECODER_LAYER_TYPES = {
|
||||
"attention": Qwen3DecoderLayer,
|
||||
}
|
||||
|
||||
@@ -274,7 +274,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.dtype,
|
||||
),
|
||||
)
|
||||
self.ffn_layer_scale = torch.nn.Parameter(
|
||||
@@ -282,7 +281,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -444,11 +442,7 @@ class Qwen3_5Model(Qwen3NextModel):
|
||||
# qwen3.5 no need to transpose
|
||||
# loaded_weight = loaded_weight.transpose(-1, -2)
|
||||
if "experts.gate_up_proj" in name:
|
||||
if loaded_weight.shape[-2] != 1:
|
||||
chunk_dim = -2
|
||||
else:
|
||||
chunk_dim = -1
|
||||
loaded_weight = loaded_weight.chunk(2, dim=chunk_dim)
|
||||
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
||||
success_w1 = self.load_fused_expert_weights(
|
||||
name_mapped,
|
||||
params_dict,
|
||||
@@ -544,6 +538,7 @@ class Qwen3_5ForCausalLMBase(
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if cache_config.mamba_cache_mode == "all":
|
||||
raise NotImplementedError(
|
||||
@@ -633,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
|
||||
# Qwen3.5 does not support multimodal pruning (EVS).
|
||||
supports_multimodal_pruning = False
|
||||
|
||||
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
|
||||
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
|
||||
"in_proj_ba": ["in_proj_b", "in_proj_a"],
|
||||
@@ -648,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
multimodal_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
# Qwen3.5 does not support multimodal pruning (EVS).
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
@@ -698,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def recompute_mrope_positions(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Qwen3.5 does not support multimodal pruning (EVS). "
|
||||
"recompute_mrope_positions should never be called."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -856,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration(
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
multimodal_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
# Qwen3.5 does not support multimodal pruning (EVS).
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
|
||||
@@ -339,7 +339,7 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["up_proj", "down_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
@@ -77,7 +77,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
|
||||
class Qwen3MoeMLP(nn.Module):
|
||||
def __init__(
|
||||
@@ -170,7 +170,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
@@ -338,13 +338,14 @@ class Qwen3MoeAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
# Add qk-norm
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
out_q, out_k = ixf_ops.rms_norm_qk(q_by_head, k_by_head, self.q_norm.weight.data, self.k_norm.weight.data, self.q_norm.variance_epsilon)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
q = out_q.view(q.shape)
|
||||
k = out_k.view(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
@@ -379,6 +380,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsW8A8Int8
|
||||
if hasattr(self.self_attn.qkv_proj, "scheme") and isinstance(self.self_attn.qkv_proj.scheme, CompressedTensorsW8A8Int8):
|
||||
self.fused_norm_quant = True
|
||||
else:
|
||||
self.fused_norm_quant = False
|
||||
|
||||
# `mlp_only_layers` in the config.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
mlp_only_layers = (
|
||||
@@ -409,12 +416,23 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
if self.fused_norm_quant:
|
||||
origin_input = hidden_states
|
||||
hidden_states_i8, residual, scale = ixf_ops.residual_rms_norm_dynamic_int8(
|
||||
hidden_states, self.input_layernorm.weight.data, residual,
|
||||
eps=self.input_layernorm.variance_epsilon,
|
||||
)
|
||||
hidden_states = (hidden_states_i8, scale, hidden_states.dtype)
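# Note (added, hedged): hidden_states is now an (int8_tensor, scale, orig_dtype)
# tuple rather than a plain tensor; the assumption is that the
# CompressedTensorsW8A8Int8 qkv_proj path checked in __init__ accepts this
# pre-quantized form directly, so the separate activation-quant step is fused
# into the residual RMSNorm above.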
|
||||
if residual is None:
|
||||
residual = origin_input
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
|
||||
@@ -10,6 +10,7 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from vllm import envs
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
@@ -34,7 +35,8 @@ from vllm.model_executor.layers.fla.ops import (
|
||||
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
@@ -114,7 +116,7 @@ def fi_chunk_gated_delta_rule(
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
from flashinfer.gdn_prefill import (
|
||||
@@ -153,21 +155,13 @@ def fi_chunk_gated_delta_rule(
|
||||
class ChunkGatedDeltaRule(CustomOp):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# if current_platform.is_cuda() and current_platform.is_device_capability(90):
|
||||
# logger.info_once(
|
||||
# "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
|
||||
# )
|
||||
# self._forward_method = self.forward_cuda
|
||||
# else:
|
||||
# logger.info_once(
|
||||
# "Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
|
||||
# )
|
||||
# self._forward_method = self.forward_native
|
||||
|
||||
logger.info_once(
|
||||
"Using FlashAttn GDN prefill kernel on CUDA compute capability 90"
|
||||
)
|
||||
self._forward_method = self.forward_native
|
||||
if current_platform.is_cuda() and current_platform.is_device_capability(90):
|
||||
logger.info_once(
|
||||
"Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
|
||||
)
|
||||
self._forward_method = self.forward_cuda
|
||||
else:
|
||||
self._forward_method = self.forward_native
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@@ -178,10 +172,10 @@ class ChunkGatedDeltaRule(CustomOp):
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
return fi_chunk_gated_delta_rule(
|
||||
return self.forward_native(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
@@ -202,7 +196,7 @@ class ChunkGatedDeltaRule(CustomOp):
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
return fla_chunk_gated_delta_rule(
|
||||
@@ -420,6 +414,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
prefix=f"{prefix}.in_proj_qkvz",
|
||||
)
|
||||
# ba_proj doesn't support blockwise fp8 quantization.
|
||||
# in_proj_ba is defined as MergedColumnParallelLinear for
|
||||
# compatibility with Qwen3_5.
|
||||
self.in_proj_ba = MergedColumnParallelLinear(
|
||||
input_size=self.hidden_size,
|
||||
output_sizes=[self.num_v_heads] * 2,
|
||||
@@ -469,7 +465,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
device=current_platform.current_device(),
|
||||
dtype=config.dtype,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
@@ -482,6 +477,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
)
|
||||
|
||||
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
|
||||
self.enable_packed_recurrent_decode = (
|
||||
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
@@ -631,6 +629,106 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
|
||||
output[:num_tokens], _ = self.out_proj(core_attn_out)
|
||||
|
||||
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
|
||||
"""Warm up GDN prefill kernels during V1 profiling.
|
||||
|
||||
During V1 profile runs, ``_forward_core`` returns early because
|
||||
``attn_metadata`` is ``None``, so the autotuned kernels used by
|
||||
``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
|
||||
``chunk_scaled_dot_kkt``) are never invoked. After profiling,
|
||||
vLLM allocates KV cache using most of the remaining GPU memory.
|
||||
When the first real inference triggers the autotuner it OOMs
|
||||
because there is not enough memory left for benchmarking.
|
||||
|
||||
This method runs minimal forward passes through
|
||||
``chunk_gated_delta_rule`` with small dummy tensors to force
|
||||
autotuning while GPU memory is still plentiful. The autotuner
|
||||
results are cached globally, so only the first layer incurs
|
||||
actual benchmarking cost.
|
||||
|
||||
Most kernels use a fixed ``BT = chunk_size`` (64), but
|
||||
``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence
|
||||
length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT``
|
||||
is part of its autotune key, we run warmup passes with T = 16,
|
||||
32, and 64 to cover all possible ``BT`` values.
|
||||
|
||||
The decode path uses ``fused_sigmoid_gating_delta_rule_update``
|
||||
which has fixed kernel parameters (no autotuning), so only the
|
||||
prefill (chunked) path needs warming up.
|
||||
"""
|
||||
if hasattr(self, "_prefill_kernels_warmed_up"):
|
||||
return
|
||||
self._prefill_kernels_warmed_up = True
|
||||
|
||||
device = mixed_qkv.device
|
||||
dtype = mixed_qkv.dtype
|
||||
num_k_heads = self.num_k_heads // self.tp_size
|
||||
num_v_heads = self.num_v_heads // self.tp_size
|
||||
_, state_dtype = self.get_state_dtype()
|
||||
|
||||
# Run warmup for each possible BT value of chunk_fwd_kernel_o:
|
||||
# T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
|
||||
# Other kernels always use BT=chunk_size(64), so their autotune
|
||||
# cache is populated on the first pass and reused thereafter.
|
||||
for T in (16, 32, 64):
|
||||
q = torch.randn(
|
||||
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
|
||||
)
|
||||
k = torch.randn(
|
||||
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
|
||||
)
|
||||
v = torch.randn(
|
||||
1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
|
||||
)
|
||||
# NOTE: g and beta must have the same dtypes as during
|
||||
# inference, so we construct them with the same function
|
||||
# (fused_gdn_gating). dummy_a and dummy_b are throwaway
|
||||
# inputs required by that function.
|
||||
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
|
||||
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
|
||||
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
|
||||
state = torch.zeros(
|
||||
1,
|
||||
num_v_heads,
|
||||
self.head_v_dim,
|
||||
self.head_k_dim,
|
||||
device=device,
|
||||
dtype=state_dtype,
|
||||
)
|
||||
cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
|
||||
|
||||
try:
|
||||
self.chunk_gated_delta_rule(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
initial_state=state,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"GDN prefill kernel warmup (T=%d) failed for "
|
||||
"layer %s. First inference may OOM due to "
|
||||
"autotuner.",
|
||||
T,
|
||||
self.prefix,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"GDN prefill kernel warmup (T=%d) completed for layer %s",
|
||||
T,
|
||||
self.prefix,
|
||||
)
|
||||
finally:
|
||||
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
|
||||
|
||||
torch.cuda.empty_cache()
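# Hypothetical sketch (helper name made up): the BT autotune keys the loop
# above is meant to cover. Per the docstring, chunk_fwd_kernel_o derives
# BT = min(64, max(16, next_power_of_2(T))) from the sequence length, so
# T in (16, 32, 64) exercises every possible BT bucket; longer sequences
# reuse the BT = 64 cache entry.
def _bt_for_seq_len(t: int) -> int:
    next_pow2 = 1 << max(t - 1, 0).bit_length()
    return min(64, max(16, next_pow2))

assert [_bt_for_seq_len(t) for t in (16, 32, 64, 4096)] == [16, 32, 64, 64]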
|
||||
|
||||
def _forward_core(
|
||||
self,
|
||||
mixed_qkv: torch.Tensor,
|
||||
@@ -638,19 +736,34 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Core attention computation (called by custom op).
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
# V1 profile run — warm up prefill kernels so that
|
||||
# autotuning completes before KV cache allocation.
|
||||
self._warmup_prefill_kernels(mixed_qkv)
|
||||
return
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
|
||||
if (
|
||||
self.enable_packed_recurrent_decode
|
||||
and attn_metadata.spec_sequence_masks is None
|
||||
and attn_metadata.num_prefills == 0
|
||||
and attn_metadata.num_decodes > 0
|
||||
):
|
||||
return self._forward_core_decode_non_spec(
|
||||
mixed_qkv=mixed_qkv,
|
||||
b=b,
|
||||
a=a,
|
||||
core_attn_out=core_attn_out,
|
||||
attn_metadata=attn_metadata,
|
||||
virtual_engine=forward_context.virtual_engine,
|
||||
)
|
||||
|
||||
has_initial_state = attn_metadata.has_initial_state
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
@@ -738,41 +851,40 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
mixed_qkv_non_spec
|
||||
)
|
||||
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||
g_spec = g
|
||||
beta_spec = beta
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
else:
|
||||
g_spec = g.index_select(1, spec_token_indx)
|
||||
beta_spec = beta.index_select(1, spec_token_indx)
|
||||
if attn_metadata.num_prefills > 0:
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
if spec_sequence_masks is not None:
|
||||
g_non_spec = g.index_select(1, non_spec_token_indx)
|
||||
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
||||
else:
|
||||
g_non_spec = g
|
||||
beta_non_spec = beta
|
||||
else:
|
||||
g_spec = None
|
||||
beta_spec = None
|
||||
g_non_spec = g
|
||||
beta_non_spec = beta
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
|
||||
# 2. Recurrent attention
|
||||
|
||||
# 2.1: Process the multi-query part
|
||||
if spec_sequence_masks is not None:
|
||||
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
|
||||
q=query_spec,
|
||||
k=key_spec,
|
||||
v=value_spec,
|
||||
g=g_spec,
|
||||
beta=beta_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
core_attn_out_spec, last_recurrent_state = (
|
||||
fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=self.A_log,
|
||||
a=a,
|
||||
b=b,
|
||||
dt_bias=self.dt_bias,
|
||||
q=query_spec,
|
||||
k=key_spec,
|
||||
v=value_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[
|
||||
: attn_metadata.num_spec_decodes + 1
|
||||
],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
core_attn_out_spec, last_recurrent_state = None, None
|
||||
@@ -801,12 +913,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
core_attn_out_non_spec, last_recurrent_state = (
|
||||
fused_recurrent_gated_delta_rule(
|
||||
fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=self.A_log,
|
||||
a=a,
|
||||
b=b,
|
||||
dt_bias=self.dt_bias,
|
||||
q=query_non_spec,
|
||||
k=key_non_spec,
|
||||
v=value_non_spec,
|
||||
g=g_non_spec,
|
||||
beta=beta_non_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc[
|
||||
@@ -834,6 +948,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
else:
|
||||
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
|
||||
|
||||
def _forward_core_decode_non_spec(
|
||||
self,
|
||||
mixed_qkv: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
attn_metadata: GDNAttentionMetadata,
|
||||
virtual_engine: int,
|
||||
):
|
||||
"""
|
||||
Core attention computation with a packed non-spec decode fast path.
|
||||
"""
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
self_kv_cache = self.kv_cache[virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
mixed_qkv = mixed_qkv[:num_actual_tokens]
|
||||
b = b[:num_actual_tokens]
|
||||
a = a[:num_actual_tokens]
|
||||
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
)
|
||||
mixed_qkv_non_spec = causal_conv1d_update(
|
||||
mixed_qkv,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
|
||||
validate_data=False,
|
||||
)
|
||||
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
|
||||
fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv=mixed_qkv_non_spec,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=self.A_log,
|
||||
dt_bias=self.dt_bias,
|
||||
scale=self.head_k_dim**-0.5,
|
||||
initial_state=ssm_state,
|
||||
out=out_buf,
|
||||
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class Qwen3NextAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -961,7 +1124,7 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
@@ -1024,7 +1187,6 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.dtype,
|
||||
),
|
||||
)
|
||||
self.ffn_layer_scale = torch.nn.Parameter(
|
||||
@@ -1032,7 +1194,6 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1299,6 +1460,8 @@ class QwenNextMixtureOfExperts(MixtureOfExperts):
|
||||
self.moe_layers = []
|
||||
example_moe = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
|
||||
layer.mlp, Qwen3NextSparseMoeBlock
|
||||
):
|
||||
@@ -1334,6 +1497,8 @@ class Qwen3NextForCausalLM(
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"in_proj_qkvz": ["in_proj_qkvz"],
|
||||
"in_proj_ba": ["in_proj_ba"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
@@ -648,6 +648,7 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor | None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x),
|
||||
@@ -655,6 +656,7 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=sequence_lengths,
|
||||
)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
@@ -975,6 +977,20 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
|
||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
|
||||
# Recompute cu_seqlens in numpy from grid_thw to avoid GPU->CPU sync
|
||||
grid_thw_np = grid_thw.cpu().numpy()
|
||||
cu_seqlens_np = np.repeat(
|
||||
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
|
||||
).cumsum(axis=0, dtype=np.int32)
|
||||
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
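# Hypothetical worked example (grid values made up): for grid_thw rows
# [[2, 4, 6], [1, 8, 8]] the per-frame patch counts are 4 * 6 = 24 (repeated
# twice) and 8 * 8 = 64 (once), so cu_seqlens_np becomes [0, 24, 48, 112]
# without ever synchronizing with the GPU.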
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||
self.attn_backend, cu_seqlens_np
|
||||
)
|
||||
if sequence_lengths is not None:
|
||||
sequence_lengths = torch.from_numpy(sequence_lengths).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
hidden_states_list = []
|
||||
deepstack_visual_indexes = self.deepstack_visual_indexes
|
||||
|
||||
@@ -985,6 +1001,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=sequence_lengths,
|
||||
)
|
||||
if (
|
||||
deepstack_visual_indexes is not None
|
||||
@@ -1146,6 +1163,39 @@ class Qwen3OmniMoeThinkerProcessingInfo(
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None, "image": None, "video": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int] | None:
|
||||
mm_counts = mm_counts or {}
|
||||
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
|
||||
mm_max_tokens: dict[str, int] = {}
|
||||
|
||||
if requested_modalities & {"image", "video"}:
|
||||
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
mm_max_tokens.update(
|
||||
{
|
||||
m: vl_tokens[m]
|
||||
for m in ["image", "video"]
|
||||
if m in requested_modalities
|
||||
}
|
||||
)
|
||||
|
||||
if "audio" in requested_modalities:
|
||||
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len=seq_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
mm_max_tokens["audio"] = audio_tokens["audio"]
|
||||
|
||||
return mm_max_tokens
|
||||
|
||||
|
||||
Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
|
||||
|
||||
@@ -1904,15 +1954,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
num_audio,
|
||||
)
|
||||
|
||||
# Default: standard merge (no interleaving)
|
||||
inputs_embeds = _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
# Default: standard merge (no interleaving), same as parent class.
|
||||
# multimodal_embeddings may have been updated above (deepstack
|
||||
# main-scale). Use super() to stay consistent with the parent
|
||||
# implementation and avoid issues seen in Qwen2.5-Omni (#34506).
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
File diff suppressed because it is too large
@@ -24,6 +24,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights."""
|
||||
|
||||
from platform import architecture
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
@@ -45,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import MixtureOfExperts
|
||||
from .qwen3_moe import (
|
||||
@@ -415,6 +417,7 @@ class Qwen3VLMoeForConditionalGeneration(
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config = config
|
||||
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
@@ -451,14 +454,14 @@ class Qwen3VLMoeForConditionalGeneration(
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = Qwen3MoeLLMForCausalLM(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config, architectures=["Qwen3MoeForCausalLM"]),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
if not get_pp_group().is_first_rank and hasattr(
|
||||
config.vision_config, "deepstack_visual_indexes"
|
||||
):
|
||||
assert self.language_model.start_layer >= len(
|
||||
assert self.language_model.model.start_layer >= len(
|
||||
config.vision_config.deepstack_visual_indexes
|
||||
), (
|
||||
"start_layer should be greater than or equal to "
|
||||
|
||||
@@ -6,11 +6,9 @@
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""
|
||||
|
||||
import copy
|
||||
import math
|
||||
import unicodedata
|
||||
from collections.abc import Callable, Collection, Mapping, Sequence, Set
|
||||
from functools import lru_cache, partial
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
import regex as re
|
||||
@@ -436,60 +434,6 @@ class QwenVLModel(QWenModel):
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_tokenizer_without_image_pad(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> PreTrainedTokenizer:
|
||||
"""
|
||||
The logic of adding image pad tokens should only be applied in
|
||||
[`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
|
||||
so they are patched out here.
|
||||
|
||||
The definition of the wrapped tokenizer can be found here:
|
||||
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
|
||||
"""
|
||||
new_tokenizer = copy.deepcopy(tokenizer)
|
||||
|
||||
class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore
|
||||
def tokenize(
|
||||
self,
|
||||
text: str,
|
||||
allowed_special: Set[str] | str = "all",
|
||||
disallowed_special: Collection[str] | str = (),
|
||||
**kwargs,
|
||||
) -> list[bytes | str]:
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
|
||||
return [
|
||||
self.decoder[t]
|
||||
for t in self.tokenizer.encode(
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
]
|
||||
|
||||
def _decode(
|
||||
self,
|
||||
token_ids: int | list[int],
|
||||
skip_special_tokens: bool = False,
|
||||
errors: str | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
|
||||
return self.tokenizer.decode(
|
||||
token_ids,
|
||||
errors=errors or self.errors,
|
||||
)
|
||||
|
||||
TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad"
|
||||
|
||||
new_tokenizer.__class__ = TokenizerWithoutImagePad
|
||||
return new_tokenizer
|
||||
|
||||
|
||||
class QwenVLProcessor:
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
@@ -574,12 +518,6 @@ class QwenVLProcessor:
|
||||
|
||||
|
||||
class QwenVLProcessingInfo(BaseProcessingInfo):
|
||||
def get_tokenizer(self) -> PreTrainedTokenizer:
|
||||
tokenizer = self.ctx.get_tokenizer()
|
||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||
|
||||
return _get_tokenizer_without_image_pad(tokenizer)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
|
||||
return self.ctx.init_processor(
|
||||
QwenVLProcessor,
|
||||
@@ -730,6 +668,8 @@ class QwenVLForConditionalGeneration(
|
||||
"w1",
|
||||
],
|
||||
}
|
||||
|
||||
embed_input_ids = SupportsMultiModal.embed_input_ids
|
||||
|
||||
embed_input_ids = SupportsMultiModal.embed_input_ids
|
||||
|
||||
|
||||
@@ -75,12 +75,14 @@ _TEXT_GENERATION_MODELS = {
|
||||
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
||||
"ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
|
||||
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
||||
"AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
|
||||
# baichuan-7b, upper case 'C' in the class name
|
||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
|
||||
# baichuan-13b, lower case 'c' in the class name
|
||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
|
||||
"BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
|
||||
"BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
|
||||
"BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
|
||||
"BambaForCausalLM": ("bamba", "BambaForCausalLM"),
|
||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||
@@ -259,6 +261,10 @@ _EMBEDDING_MODELS = {
|
||||
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
|
||||
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
|
||||
"SiglipModel": ("siglip", "SiglipEmbeddingModel"),
|
||||
"LlamaNemotronVLModel": (
|
||||
"nemotron_vl",
|
||||
"LlamaNemotronVLForEmbedding",
|
||||
),
|
||||
# Technically Terratorch models work on images, both in
|
||||
# input and output. I am adding it here because it piggy-backs on embedding
|
||||
# models for the time being.
|
||||
@@ -278,6 +284,10 @@ _CROSS_ENCODER_MODELS = {
|
||||
"llama",
|
||||
"LlamaBidirectionalForSequenceClassification",
|
||||
),
|
||||
"LlamaNemotronVLForSequenceClassification": (
|
||||
"nemotron_vl",
|
||||
"LlamaNemotronVLForSequenceClassification",
|
||||
),
|
||||
"ModernBertForSequenceClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForSequenceClassification",
|
||||
@@ -331,6 +341,10 @@ _MULTIMODAL_MODELS = {
|
||||
"ernie45_vl",
|
||||
"Ernie4_5_VLMoeForConditionalGeneration",
|
||||
),
|
||||
"FireRedASR2ForConditionalGeneration": (
|
||||
"fireredasr2",
|
||||
"FireRedASR2ForConditionalGeneration",
|
||||
),
|
||||
"FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"), # noqa: E501
|
||||
"FunAudioChatForConditionalGeneration": (
|
||||
"funaudiochat",
|
||||
@@ -506,6 +520,7 @@ _MULTIMODAL_MODELS = {
|
||||
}
|
||||
|
||||
_SPECULATIVE_DECODING_MODELS = {
|
||||
"ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
|
||||
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
||||
|
||||
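Each registry entry maps a Hugging Face architecture string to a (module file, class name) pair under vllm/model_executor/models/. A rough sketch of how such a table can be resolved lazily (simplified for illustration; the real ModelRegistry adds caching, plugin hooks, and error handling):

import importlib

# Subset of the table above; module names are relative to
# vllm.model_executor.models.
_TEXT_GENERATION_MODELS = {
    "AXK1ForCausalLM": ("AXK1", "AXK1ForCausalLM"),
    "BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
}


def resolve_model_cls(architecture: str):
    # Import the module only when the architecture is actually requested.
    module_name, class_name = _TEXT_GENERATION_MODELS[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)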
@@ -346,7 +346,7 @@ class Step3TextModel(nn.Module):
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states"], config.hidden_size
|
||||
["hidden_states","residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -69,7 +69,7 @@ class FP32ReplicatedLinear(ReplicatedLinear):
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
assert self.params_dtype == torch.float32
|
||||
return super().forward(x.to(torch.float32))
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class Step3p5MLP(nn.Module):
|
||||
@@ -148,6 +148,7 @@ class Step3p5Attention(nn.Module):
|
||||
yarn_only_types: list = None,
|
||||
swa_num_attention_heads: int | None = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
total_layer_num: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -250,7 +251,12 @@ class Step3p5Attention(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
per_layer_sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
)
|
||||
extra_cache_para={
    "total_num_kv_heads": self.total_num_kv_heads,
    "total_layer_num": total_layer_num,
    "layer_id": self.layer_idx,
},
|
||||
)
|
||||
|
||||
self.max_position_embeddings = max_position
|
||||
assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
|
||||
@@ -390,6 +396,7 @@ class FusedMoEBlock(nn.Module):
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
router_logits_dtype=torch.float32,
|
||||
fused_shared_output=True,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -413,8 +420,8 @@ class FusedMoEBlock(nn.Module):
|
||||
assert shared_output is None
|
||||
|
||||
if self.share_expert is not None:
|
||||
assert shared_output is not None
|
||||
final_hidden_states += shared_output
|
||||
if shared_output is not None:
|
||||
final_hidden_states += shared_output
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
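With fused_shared_output=True the shared expert's contribution comes back from the fused MoE kernel, so the block now adds it whenever it is present instead of asserting on self.share_expert. A simplified sketch of the combine step (plain tensors, not the vLLM FusedMoE API):

import torch


def combine_moe_outputs(
    routed_out: torch.Tensor,
    shared_out: torch.Tensor | None,
) -> torch.Tensor:
    # Routed experts always contribute; the shared-expert output is optional
    # and, when fused into the kernel, is returned alongside the routed output.
    final = routed_out
    if shared_out is not None:
        final = final + shared_out
    return final


x = torch.randn(4, 8)
assert torch.allclose(combine_moe_outputs(x, None), x)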
@@ -484,6 +491,7 @@ class Step3p5DecoderLayer(nn.Module):
|
||||
if partial_rotary_factors
|
||||
else 1.0,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
total_layer_num=vllm_config.model_config.hf_config.num_hidden_layers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -535,25 +543,28 @@ class Step3p5DecoderLayer(nn.Module):
|
||||
return self.tp_group.all_reduce(in1 + in2)
|
||||
|
||||
def forward(
|
||||
self, positions: torch.Tensor, hidden_states: torch.Tensor
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
if residual is None:
|
||||
residual = hidden_states.clone()
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
if self.use_moe:
|
||||
ffn_output = self.moe(hidden_states)
|
||||
else:
|
||||
ffn_output = self.mlp(hidden_states)
|
||||
hidden_states = ffn_output + residual
|
||||
return hidden_states
|
||||
return ffn_output, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
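The decoder layer now threads residual through the fused RMSNorm call (norm(x, residual) returns both the normalized activations and the updated residual) and returns (ffn_output, residual) so the final add happens once, outside the layer. A minimal sketch of the pattern with a plain, unweighted RMSNorm rather than vLLM's fused kernel:

import torch


def rms_norm(x: torch.Tensor, residual: torch.Tensor | None = None, eps: float = 1e-6):
    # Mimics the fused interface: optionally add the residual first, then
    # normalize, returning (normalized, new_residual).
    if residual is not None:
        x = x + residual
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (normed, x) if residual is not None else normed


def layer(hidden: torch.Tensor, residual: torch.Tensor | None):
    if residual is None:
        residual = hidden.clone()
        hidden = rms_norm(hidden)
    else:
        hidden, residual = rms_norm(hidden, residual)
    hidden = hidden * 2.0          # stand-in for attention
    hidden, residual = rms_norm(hidden, residual)
    ffn_out = hidden + 1.0         # stand-in for the MLP / MoE block
    return ffn_out, residual       # caller adds them when it needs the sum


h, r = layer(torch.randn(2, 4), None)
final = h + r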
@@ -592,7 +603,7 @@ class Step3p5Model(nn.Module):
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states"], config.hidden_size
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
@@ -610,20 +621,26 @@ class Step3p5Model(nn.Module):
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
}
|
||||
)
|
||||
|
||||
hidden_states += residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
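Because each layer now defers the final residual add, both tensors have to cross pipeline-parallel stage boundaries; only the last rank folds them together. A schematic of the hand-off (a plain dict stands in for IntermediateTensors):

import torch


def stage_forward(hidden: torch.Tensor, residual: torch.Tensor, is_last_rank: bool):
    # Non-last ranks ship both tensors downstream; the last rank finishes the
    # deferred residual addition before the final norm / lm_head.
    if not is_last_rank:
        return {"hidden_states": hidden, "residual": residual}
    return hidden + residual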
@@ -891,7 +908,6 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.model.norm(hidden_states)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
@@ -934,7 +950,26 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
all_params = set(dict(self.named_parameters()).keys())
|
||||
missing = all_params - loaded
|
||||
kv_cache_params = {
    p
    for p in missing
    if p.endswith((
        ".k_scale", ".k_zero_point",
        ".q_scale", ".q_zero_point",
        ".v_scale", ".v_zero_point",
    ))
}
|
||||
real_missing = missing - kv_cache_params
|
||||
if real_missing:
|
||||
logger.warning(
|
||||
"[DIAG-WEIGHTS] Missing %d params: %s",
|
||||
len(real_missing), sorted(real_missing),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"[DIAG-WEIGHTS] All params loaded. (%d kv-cache params skipped)",
|
||||
len(kv_cache_params),
|
||||
)
|
||||
return loaded
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(
|
||||
|
||||
@@ -6,9 +6,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name
|
||||
from .utils import maybe_prefix
|
||||
|
||||
@@ -40,9 +43,11 @@ class SharedHead(nn.Module):
|
||||
return self.norm(hidden_states)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
) -> None:
|
||||
@@ -51,8 +56,15 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
quant_config = vllm_config.quant_config
|
||||
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.eh_proj = ReplicatedLinear(
|
||||
config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.eh_proj",
|
||||
)
|
||||
self.lm_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = Step3p5DecoderLayer(
|
||||
vllm_config,
|
||||
prefix=f"{prefix}.mtp_block",
|
||||
@@ -64,9 +76,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
embed_tokens: VocabParallelEmbedding | None = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
if inputs_embeds is None:
|
||||
assert embed_tokens is not None
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
@@ -74,7 +89,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
|
||||
)
|
||||
|
||||
hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states)
|
||||
hidden_states, residual = self.mtp_block(
    positions=positions, hidden_states=hidden_states, residual=None
)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
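For reference, the MTP layer combines the normalized token embedding with the normalized hidden state from the previous step through a 2H-to-H projection before running its decoder block. A compact sketch with plain modules (LayerNorm and the hidden size are illustrative stand-ins, not the Step3p5 configuration):

import torch
import torch.nn as nn


class TinyMTPLayer(nn.Module):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        # LayerNorm stands in for the GemmaRMSNorm used by the real layer.
        self.enorm = nn.LayerNorm(hidden_size)
        self.hnorm = nn.LayerNorm(hidden_size)
        self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    def forward(self, inputs_embeds: torch.Tensor, prev_hidden: torch.Tensor) -> torch.Tensor:
        # Normalize both streams, concatenate on the feature dim, project back
        # to hidden_size; the real model then feeds a decoder layer.
        fused = torch.cat([self.enorm(inputs_embeds), self.hnorm(prev_hidden)], dim=-1)
        return self.eh_proj(fused)


layer = TinyMTPLayer()
out = layer(torch.randn(2, 3, 16), torch.randn(2, 3, 16))
assert out.shape == (2, 3, 16)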
@@ -92,8 +108,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
self.layers = torch.nn.ModuleDict(
|
||||
{
|
||||
str(idx): Step3p5AMultiTokenPredictorLayer(
|
||||
vllm_config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
)
|
||||
for idx in range(
|
||||
self.mtp_start_layer_idx,
|
||||
@@ -112,14 +128,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = spec_step_idx % self.num_mtp_layers
|
||||
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
|
||||
input_ids,
|
||||
positions,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
self.embed_tokens,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
@@ -131,7 +146,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
current_step_idx = spec_step_idx % self.num_mtp_layers
|
||||
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
|
||||
logits = self.logits_processor(
|
||||
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
|
||||
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
|
||||
)
|
||||
return logits
|
||||
|
||||
@@ -139,7 +154,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
|
||||
class Step3p5MTP(nn.Module):
|
||||
class Step3p5MTP(nn.Module, SupportsPP):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
@@ -257,12 +272,15 @@ class Step3p5MTP(nn.Module):
|
||||
name = name.replace(".transformer.", ".")
|
||||
if "shared_head" in name:
|
||||
name = name.replace("shared_head.output", "shared_head.head")
|
||||
name = name.replace("shared_head", "lm_head")
|
||||
if "embed_tokens" in name:
|
||||
assert (
|
||||
hasattr(self.config, "num_nextn_predict_layers")
|
||||
and self.config.num_nextn_predict_layers > 0
|
||||
)
|
||||
name = "model.embed_tokens.weight"
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
@@ -280,6 +298,31 @@ class Step3p5MTP(nn.Module):
|
||||
and getattr(param, "requires_grad", False) is False
|
||||
}
|
||||
params_need_to_load -= optional_params
|
||||
# Some attention quantization metadata may be absent in checkpoints
|
||||
# depending on export path / runtime backend.
|
||||
missing_params = params_need_to_load - loaded_params
|
||||
skip_missing_suffixes = (
|
||||
".attn.k_scale",
|
||||
".attn.q_scale",
|
||||
".attn.v_scale",
|
||||
".attn.q_zero_point",
|
||||
".attn.k_zero_point",
|
||||
".attn.v_zero_point",
|
||||
)
|
||||
skipped_missing_params = sorted(
|
||||
name for name in missing_params if name.endswith(skip_missing_suffixes)
|
||||
)
|
||||
if skipped_missing_params:
|
||||
preview = tuple(skipped_missing_params[:10])
|
||||
logger.warning_once(
|
||||
"Step3p5MTP load_weights: skip %d missing optional attn quant "
|
||||
"params with suffixes %s (showing first %d): %s",
|
||||
len(skipped_missing_params),
|
||||
skip_missing_suffixes,
|
||||
len(preview),
|
||||
preview,
|
||||
)
|
||||
params_need_to_load -= set(skipped_missing_params)
|
||||
if params_need_to_load != loaded_params:
|
||||
missing_params = list(params_need_to_load - loaded_params)
|
||||
param_name_example = missing_params[0]
|
||||
|
||||
@@ -1,500 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import SwinConfig
|
||||
from transformers.models.swin.modeling_swin import SwinEmbeddings, SwinPatchMerging
|
||||
from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer
|
||||
from transformers.pytorch_utils import meshgrid
|
||||
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class SwinSelfAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if dim % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({dim}) is not a multiple of the number of "
|
||||
f"attention heads ({num_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = num_heads
|
||||
self.attention_head_size = int(dim / num_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.window_size = (
|
||||
window_size
|
||||
if isinstance(window_size, Iterable)
|
||||
else (window_size, window_size)
|
||||
)
|
||||
self.scale = self.attention_head_size**-0.5
|
||||
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(
|
||||
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
|
||||
)
|
||||
)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1)
|
||||
|
||||
self.relative_position_index = nn.Parameter(
|
||||
relative_position_index, requires_grad=False
|
||||
)
|
||||
|
||||
self.qkv = QKVParallelLinear(
|
||||
hidden_size=dim,
|
||||
head_size=self.attention_head_size,
|
||||
total_num_heads=self.num_attention_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def _get_rel_pos_bias(self) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
]
|
||||
relative_position_bias = relative_position_bias.view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1],
|
||||
-1,
|
||||
)
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
||||
return relative_position_bias.unsqueeze(0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
|
||||
qkv_output, _ = self.qkv(hidden_states)
|
||||
query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1)
|
||||
|
||||
key_layer = self.transpose_for_scores(key_layer)
|
||||
value_layer = self.transpose_for_scores(value_layer)
|
||||
query_layer = self.transpose_for_scores(query_layer)
|
||||
|
||||
attention_scores = self._get_rel_pos_bias()
|
||||
if attention_mask is not None:
|
||||
mask_shape = attention_mask.shape[0]
|
||||
attention_mask_expanded = attention_mask.view(
|
||||
1, mask_shape, 1, dim, dim
|
||||
).expand(
|
||||
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
|
||||
)
|
||||
attention_scores = attention_scores + attention_mask_expanded.unsqueeze(
|
||||
1
|
||||
).unsqueeze(0)
|
||||
attention_scores = attention_scores.view(
|
||||
-1, self.num_attention_heads, dim, dim
|
||||
)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_mask=attention_scores,
|
||||
dropout_p=0.0,
|
||||
)
|
||||
attention_probs = None
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (
|
||||
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SwinSelfOutput(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dense = RowParallelLinear(
|
||||
input_size=dim,
|
||||
output_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states, _ = self.dense(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.self = SwinSelfAttention(
|
||||
config,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self",
|
||||
)
|
||||
self.output = SwinSelfOutput(
|
||||
config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
|
||||
)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:]
|
||||
return outputs
|
||||
|
||||
|
||||
class SwinIntermediate(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dense = ColumnParallelLinear(
|
||||
dim,
|
||||
int(config.mlp_ratio * dim),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
self.intermediate_act_fn = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinOutput(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dense = RowParallelLinear(
|
||||
int(config.mlp_ratio * dim),
|
||||
dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.dense(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinLayer(HFSwinLayer):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
input_resolution: int,
|
||||
num_heads: int,
|
||||
drop_path_rate: float = 0.0,
|
||||
shift_size: int = 0,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config=config,
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
drop_path_rate=drop_path_rate,
|
||||
shift_size=shift_size,
|
||||
)
|
||||
|
||||
self.attention = SwinAttention(
|
||||
config,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=self.window_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attention",
|
||||
)
|
||||
self.intermediate = SwinIntermediate(
|
||||
config, dim, quant_config=quant_config, prefix=f"{prefix}.intermediate"
|
||||
)
|
||||
self.output = SwinOutput(
|
||||
config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
|
||||
)
|
||||
|
||||
|
||||
class SwinStage(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
dim: int,
|
||||
input_resolution: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
drop_path: list[float],
|
||||
downsample: SwinPatchMerging | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dim = dim
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
SwinLayer(
|
||||
config=config,
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
drop_path_rate=drop_path[layer_idx],
|
||||
shift_size=0 if (layer_idx % 2 == 0) else config.window_size // 2,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
)
|
||||
for layer_idx in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(
|
||||
input_resolution, dim=dim, norm_layer=nn.LayerNorm
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.pointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_dimensions: tuple[int, int],
|
||||
output_attentions: bool | None = False,
|
||||
always_partition: bool | None = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
height, width = input_dimensions
|
||||
for i, layer_module in enumerate(self.blocks):
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
output_attentions,
|
||||
always_partition,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states_before_downsampling = hidden_states
|
||||
if self.downsample is not None:
|
||||
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
|
||||
output_dimensions = (height, width, height_downsampled, width_downsampled)
|
||||
hidden_states = self.downsample(
|
||||
hidden_states_before_downsampling, input_dimensions
|
||||
)
|
||||
else:
|
||||
output_dimensions = (height, width, height, width)
|
||||
|
||||
stage_outputs = (
|
||||
hidden_states,
|
||||
hidden_states_before_downsampling,
|
||||
output_dimensions,
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
stage_outputs += layer_outputs[1:]
|
||||
return stage_outputs
|
||||
|
||||
|
||||
class SwinEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
grid_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_layers = len(config.depths)
|
||||
self.config = config
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(
|
||||
0, config.drop_path_rate, sum(config.depths), device="cpu"
|
||||
)
|
||||
]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
SwinStage(
|
||||
config=config,
|
||||
dim=int(config.embed_dim * 2**layer_idx),
|
||||
input_resolution=(
|
||||
grid_size[0] // (2**layer_idx),
|
||||
grid_size[1] // (2**layer_idx),
|
||||
),
|
||||
depth=config.depths[layer_idx],
|
||||
num_heads=config.num_heads[layer_idx],
|
||||
drop_path=dpr[
|
||||
sum(config.depths[:layer_idx]) : sum(
|
||||
config.depths[: layer_idx + 1]
|
||||
)
|
||||
],
|
||||
downsample=SwinPatchMerging
|
||||
if (layer_idx < self.num_layers - 1)
|
||||
else None,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
)
|
||||
for layer_idx in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_dimensions: tuple[int, int],
|
||||
output_attentions: bool | None = False,
|
||||
always_partition: bool | None = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
output_attentions,
|
||||
always_partition,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
output_dimensions = layer_outputs[2]
|
||||
|
||||
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwinModel(nn.Module):
|
||||
config_class: SwinConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SwinConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_layers = len(config.depths)
|
||||
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
|
||||
|
||||
self.embeddings = SwinEmbeddings(config)
|
||||
self.encoder = SwinEncoder(
|
||||
config,
|
||||
self.embeddings.patch_grid,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
) -> tuple[torch.Tensor]:
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
input_dimensions,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
("qkv", "query", "q"),
|
||||
("qkv", "key", "k"),
|
||||
("qkv", "value", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@@ -27,6 +27,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.models.llama import (
|
||||
LlamaDecoderLayer,
|
||||
@@ -69,7 +70,7 @@ class TeleFLMForCausalLM(LlamaForCausalLM):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
# mup
|
||||
self.use_mup = self.config.use_mup
|
||||
if self.use_mup:
|
||||
if self.use_mup and get_pp_group().is_last_rank:
|
||||
self.mup_scale_factor = self.config.mup_scale_factor
|
||||
self.output_mult = self.config.output_mult / self.mup_scale_factor
|
||||
logit_scale = self.output_mult
|
||||
|
||||
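The muP attributes are only needed where logits are produced, hence the extra get_pp_group().is_last_rank guard. The scaling itself is a constant factor applied to the logits; for example (values are illustrative, not from a real TeleFLM config):

# Illustrative values only.
output_mult = 4.0
mup_scale_factor = 32.0
logit_scale = output_mult / mup_scale_factor  # 0.125, applied to lm_head logits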
@@ -300,14 +300,26 @@ class Base(
|
||||
for child_name, child_module in module.named_children():
|
||||
new_module = child_module
|
||||
qual_name = maybe_prefix(prefix, child_name)
|
||||
# Populate Eagle3 attrs
|
||||
if (
|
||||
isinstance(module, nn.ModuleList)
|
||||
and len(module) == self.text_config.num_hidden_layers
|
||||
):
|
||||
# Populate Eagle3 attrs
|
||||
self._target_class = type(child_module)
|
||||
layer_name = qual_name.removeprefix("model.")
|
||||
self._layer_names[int(child_name)] = layer_name
|
||||
# MTP weights should not be loaded into the base model
|
||||
num_hidden_layers = self.text_config.num_hidden_layers
|
||||
names = (
|
||||
"n_predict", # Override from SpeculativeConfig
|
||||
"num_nextn_predict_layers", # Most models
|
||||
"mtp_num_hidden_layers", # Qwen 3.5
|
||||
)
|
||||
n_predict = getattr_iter(self.text_config, names, 0)
|
||||
for i in range(num_hidden_layers, num_hidden_layers + n_predict):
|
||||
mtp_prefix = f"{prefix}.{i}."
|
||||
if mtp_prefix not in self.ignore_unexpected_prefixes:
|
||||
self.ignore_unexpected_prefixes.append(mtp_prefix)
|
||||
# Replace modules as needed
|
||||
if isinstance(child_module, nn.Linear):
|
||||
generator = (p for p in tp_plan if re.match(p, qual_name))
|
||||
|
||||
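The loop above registers prefixes such as model.layers.&lt;N&gt;. for layer indices past the base model's depth, so MTP checkpoint tensors are skipped instead of being reported as unexpected weights. A small standalone sketch of the prefix construction (assuming 4 base layers and 1 MTP layer):

num_hidden_layers = 4   # base decoder layers
n_predict = 1           # extra MTP layer(s) present in the checkpoint
prefix = "model.layers"

ignore_unexpected_prefixes: list[str] = []
for i in range(num_hidden_layers, num_hidden_layers + n_predict):
    mtp_prefix = f"{prefix}.{i}."
    if mtp_prefix not in ignore_unexpected_prefixes:
        ignore_unexpected_prefixes.append(mtp_prefix)

assert ignore_unexpected_prefixes == ["model.layers.4."]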
@@ -218,7 +218,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||
if "mm_token_type_ids" in processed_data
|
||||
else "token_type_ids"
|
||||
)
|
||||
mm_token_type_ids = processed_data.pop(token_type_key)
|
||||
mm_token_type_ids = processed_data.get(token_type_key)
|
||||
|
||||
# We can infer vLLM style placeholder from token type ids, if we split
|
||||
# it for each input `mm_data`.
|
||||
@@ -353,6 +353,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
|
||||
num_image_patches = kwargs.pop("num_image_patches")
|
||||
kwargs.pop("token_type_ids", None) # used only in `forward`
|
||||
kwargs.pop("mm_token_type_ids", None) # used only in `model.get_rope_index`
|
||||
|
||||
if pixel_values is not None:
|
||||
# ROCm: Force math SDP backend for vision encoder to avoid accuracy issues
|
||||
@@ -443,6 +444,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
{
|
||||
"image_grid_thw",
|
||||
"video_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
"second_per_grid_ts",
|
||||
"audio_feature_lengths",
|
||||
"use_audio_in_video",
|
||||
@@ -451,7 +453,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
if any(
|
||||
v
|
||||
for k, v in kwargs.items()
|
||||
if k not in {"image_grid_thw", "video_grid_thw"}
|
||||
if k not in {"image_grid_thw", "mm_token_type_ids"}
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Transformers modeling backend only supports images."
|
||||
@@ -459,6 +461,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
|
||||
image_grid_thw = kwargs.get("image_grid_thw", [])
|
||||
video_grid_thw = kwargs.get("video_grid_thw", [])
|
||||
mm_token_type_ids = kwargs.get("mm_token_type_ids")
|
||||
|
||||
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
|
||||
image_grid_thw
|
||||
@@ -467,10 +470,29 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
video_grid_thw
|
||||
)
|
||||
|
||||
# In v4 `get_rope_index` doesn't have wildcard `kwargs`, and
|
||||
# can't accept arbitrary args, even if its value is `None`
|
||||
kwargs = {}
|
||||
if mm_token_type_ids:
|
||||
if not hasattr(self, "_get_rope_index_accepts_mm_token_type_ids"):
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(self.model.get_rope_index)
|
||||
params = sig.parameters
|
||||
self._get_rope_index_accepts_mm_token_type_ids = (
|
||||
"mm_token_type_ids" in params
|
||||
or any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
|
||||
)
|
||||
)
|
||||
if self._get_rope_index_accepts_mm_token_type_ids:
|
||||
kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids)
|
||||
|
||||
mrope_positions, mrope_position_delta = self.model.get_rope_index(
|
||||
input_ids=torch.tensor(input_tokens).unsqueeze(0),
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
mrope_positions = mrope_positions[:, 0]
|
||||
|
||||
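Probing get_rope_index once and caching the result keeps the hot path free of repeated inspect calls while staying compatible with signatures that lack mm_token_type_ids. The same guard in isolation (standalone functions used purely for illustration):

import inspect


def accepts_kwarg(fn, name: str) -> bool:
    # True if fn declares the keyword explicitly or takes **kwargs.
    params = inspect.signature(fn).parameters
    return name in params or any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


def old_rope(input_ids, image_grid_thw=None):
    return input_ids


def new_rope(input_ids, image_grid_thw=None, mm_token_type_ids=None):
    return input_ids


assert not accepts_kwarg(old_rope, "mm_token_type_ids")
assert accepts_kwarg(new_rope, "mm_token_type_ids")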
@@ -311,8 +311,9 @@ class AutoWeightsLoader:
|
||||
|
||||
continue
|
||||
|
||||
named_parameters = module.named_parameters(recurse=True)
|
||||
desc_param_keys = {
|
||||
base_prefix + k for k, _ in module.named_parameters(recurse=True)
|
||||
maybe_prefix(base_prefix, k) for k, _ in named_parameters
|
||||
}
|
||||
msg = (
|
||||
f"There is no module or parameter named {prefix!r} "
|
||||
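maybe_prefix is used here so parameter keys are joined with a dot only when base_prefix is non-empty, avoiding keys like ".weight" at the top level. A sketch of the expected behaviour (re-implemented for illustration; the real helper lives in .utils):

def maybe_prefix(prefix: str, name: str) -> str:
    # Join with '.' only when a prefix is present.
    return f"{prefix}.{name}" if prefix else name


assert maybe_prefix("", "weight") == "weight"
assert maybe_prefix("model.layers.0", "weight") == "model.layers.0.weight"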
@@ -874,16 +875,3 @@ def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
|
||||
if feature_layer_index < 0:
|
||||
return num_hidden_layers + feature_layer_index + 1
|
||||
return feature_layer_index
|
||||
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||
|
||||
def reparse_quant_config(prefix: str, quant_config):
|
||||
ignore = getattr(quant_config, "ignore", None)
|
||||
if not ignore:
|
||||
return quant_config
|
||||
|
||||
if should_ignore_layer(prefix, ignore):
|
||||
return None
|
||||
|
||||
return quant_config
|
||||