# Source: enginex-vllm-bi100-qwen36/qwen3_6_scripts/qwen3_5.py
# (921 lines, 38 KiB, Python)
# Inference-only Qwen3.6-27B (Qwen3_5 architecture) for Iluvatar BI-V100.
# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency).
# Text-only (no VL, no MTP).
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA
# Module-level logger created through vLLM's init_logger helper.
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0)
# ---------------------------------------------------------------------------
def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """Scale ``x`` to unit L2 norm along ``dim``.

    ``eps`` is added to the squared norm before the reciprocal square root,
    so an all-zero slice comes back as zeros rather than NaN.
    """
    squared_norm = (x * x).sum(dim=dim, keepdim=True)
    inv_norm = torch.rsqrt(squared_norm + eps)
    return x * inv_norm
def _torch_causal_conv1d_update(
    hidden_states: torch.Tensor,  # (batch, channels, seq)
    conv_state: torch.Tensor,  # (batch, channels, state_len), modified in place
    weight: torch.Tensor,  # (channels, kernel_size)
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Run a depthwise causal conv over new tokens using a rolling input window.

    The new tokens are appended to ``conv_state``, the last ``state_len``
    columns of that concatenation are written back into ``conv_state``, and
    a depthwise convolution (one filter per channel) is evaluated over the
    concatenation. Only the outputs for the new positions are kept.

    Args:
        hidden_states: New inputs, ``(batch, channels, seq)``.
        conv_state: Previous inputs, ``(batch, channels, state_len)``;
            overwritten in place with the updated window.
        weight: Depthwise filters, ``(channels, kernel_size)``.
        bias: Optional per-channel bias passed to ``F.conv1d``.
        activation: ``None`` for no activation, or ``"silu"`` / ``"swish"``.

    Returns:
        ``(batch, channels, seq)`` in the dtype of ``hidden_states``.

    Raises:
        ValueError: if ``activation`` is not one of the supported values.
    """
    # Only SiLU is implemented. Reject anything else instead of silently
    # applying SiLU, and do it before conv_state is touched.
    if activation not in (None, "silu", "swish"):
        raise ValueError(
            f"Unsupported activation for causal conv1d: {activation!r}")
    _, channels, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]
    window = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    conv_state.copy_(window[:, :, -state_len:])
    out = F.conv1d(window, weight.unsqueeze(1), bias, padding=0, groups=channels)
    # Keep only the outputs aligned with the new tokens.
    out = out[:, :, -seq_len:]
    if activation is not None:
        out = F.silu(out)
    return out.to(hidden_states.dtype)
def _torch_chunk_gated_delta_rule(
    query: torch.Tensor,  # (batch, seq, num_heads, head_k_dim)
    key: torch.Tensor,  # (batch, seq, num_heads, head_k_dim)
    value: torch.Tensor,  # (batch, seq, num_heads, head_v_dim)
    g: torch.Tensor,  # (batch, seq, num_heads)
    beta: torch.Tensor,  # (batch, seq, num_heads)
    chunk_size: int = 64,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Chunked gated delta rule in pure PyTorch (used for prefill).

    The sequence is zero-padded to a multiple of ``chunk_size`` and cut into
    chunks. Interactions inside a chunk are computed with dense matmuls, and
    a recurrent state of shape (batch, num_heads, head_k_dim, head_v_dim) is
    carried from one chunk to the next. All math runs in float32.

    Args:
        query, key: (batch, seq, num_heads, head_k_dim).
        value: (batch, seq, num_heads, head_v_dim).
        g: Per-token log decay, (batch, seq, num_heads). It is accumulated
            with cumsum inside each chunk and applied through exp().
        beta: Per-token write strength, (batch, seq, num_heads). It scales
            both keys and values.
        chunk_size: Chunk length for the intra-chunk matmuls.
        initial_state: Optional starting state; zeros when None.
        output_final_state: When False, the returned state is None.
        use_qk_l2norm_in_kernel: L2-normalise query and key over head_k_dim
            first.

    Returns:
        (core_out, final_state): ``core_out`` is
        (batch, seq, num_heads, head_v_dim) in the dtype of ``query``.
        ``final_state`` is the float32 state after the last token, or None.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = _l2norm(query)
        key = _l2norm(key)
    # Transpose to (batch, num_heads, seq, dim)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    batch, num_heads, seq_len, k_dim = key.shape
    v_dim = value.shape[-1]
    # Zero-pad the sequence axis to a multiple of chunk_size. Padded positions
    # have key = value = beta = g = 0. Zero keys contribute nothing to the
    # state update, and g = 0 adds no extra decay, so the final state is
    # unaffected. Padded outputs are sliced off at the end.
    pad = (chunk_size - seq_len % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad))
    key = F.pad(key, (0, 0, 0, pad))
    value = F.pad(value, (0, 0, 0, pad))
    beta = F.pad(beta, (0, pad))
    g = F.pad(g, (0, pad))
    total_len = seq_len + pad
    # Fold the 1/sqrt(head_k_dim) scale into the queries once.
    scale = 1.0 / (query.shape[-1] ** 0.5)
    query = query * scale
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # Split the sequence axis into chunks:
    # (batch, heads, n_chunks, chunk_size, dim) and g -> (batch, heads, n_chunks, chunk_size).
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper triangle INCLUDING the diagonal.
    mask_upper = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
        diagonal=0)
    # Cumulative log decay within each chunk.
    g = g.cumsum(dim=-1)
    # decay_mask[..., i, j] = exp(g_i - g_j) for j <= i, and 0 above the diagonal.
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    # attn = -L, where L is the strictly-lower part of the decayed,
    # beta-weighted key Gram matrix (diagonal and above masked to 0).
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask_upper, 0)
    # Row-by-row forward substitution. Together with the identity added
    # below, this turns attn into (I + L)^-1.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    # Intra-chunk corrected values, and decayed beta-weighted keys, both
    # passed through (I + L)^-1.
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_state = (
        torch.zeros(batch, num_heads, k_dim, v_dim, dtype=value.dtype, device=value.device)
        if initial_state is None
        else initial_state.to(value)
    )
    core_out = torch.zeros_like(value)
    # Strictly-upper mask: intra-chunk q.k attention is causal and includes
    # the diagonal.
    mask_upper2 = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
        diagonal=1)
    for i in range(total_len // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        # Decayed causal attention scores inside chunk i.
        attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask_upper2, 0)
        # Part of the values already predicted by the carried state.
        v_prime = k_cumdecay[:, :, i] @ last_state
        v_new = v_i - v_prime
        # Contribution of the carried state, with queries decayed by exp(g).
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state
        core_out[:, :, i] = attn_inter + attn_i @ v_new
        # Decay the state to the end of the chunk, then add this chunk's keys
        # (decayed to the chunk end) outer v_new.
        last_state = (
            last_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None])
            .transpose(-1, -2) @ v_new
        )
    if not output_final_state:
        last_state = None
    # Merge chunks, drop padding, back to (batch, seq, heads, v_dim).
    core_out = core_out.reshape(batch, num_heads, -1, v_dim)[:, :, :seq_len]
    core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_out, last_state
def _torch_recurrent_gated_delta_rule(
    query: torch.Tensor,  # (batch, seq, num_heads, head_k_dim); seq is 1 for decode
    key: torch.Tensor,
    value: torch.Tensor,  # (batch, seq, num_heads, head_v_dim)
    g: torch.Tensor,  # (batch, seq, num_heads)
    beta: torch.Tensor,  # (batch, seq, num_heads)
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Token-by-token gated delta rule (used for decode).

    For every step t, with S the (head_k_dim, head_v_dim) state per head:

        S   <- S * exp(g_t)
        S   <- S + k_t outer (beta_t * (v_t - S^T k_t))
        o_t  = S^T q_t        (q pre-scaled by 1 / sqrt(head_k_dim))

    All math is done in float32.

    Returns:
        (out, state): ``out`` is (batch, seq, num_heads, head_v_dim) in the
        dtype of ``query``. ``state`` is the float32
        (batch, num_heads, head_k_dim, head_v_dim) state after the last step,
        or None when ``output_final_state`` is False.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query, key = _l2norm(query), _l2norm(key)
    # (batch, seq, heads, ...) -> (batch, heads, seq, ...) in float32.
    q, k, v, write, log_decay = (
        t.transpose(1, 2).contiguous().to(torch.float32)
        for t in (query, key, value, beta, g)
    )
    batch, num_heads, seq_len, k_dim = k.shape
    v_dim = v.shape[-1]
    q = q * (1.0 / (q.shape[-1] ** 0.5))

    out = torch.zeros(batch, num_heads, seq_len, v_dim,
                      dtype=v.dtype, device=v.device)
    if initial_state is None:
        state = torch.zeros(batch, num_heads, k_dim, v_dim,
                            dtype=v.dtype, device=v.device)
    else:
        state = initial_state.to(v)

    for step in range(seq_len):
        k_col = k[:, :, step].unsqueeze(-1)  # (batch, heads, k_dim, 1)
        state = state * log_decay[:, :, step].exp()[..., None, None]
        recalled = (state * k_col).sum(dim=-2)
        correction = (v[:, :, step] - recalled) * write[:, :, step].unsqueeze(-1)
        state = state + k_col * correction.unsqueeze(-2)
        out[:, :, step] = (state * q[:, :, step].unsqueeze(-1)).sum(dim=-2)

    final_state = state if output_final_state else None
    return out.transpose(1, 2).contiguous().to(out_dtype), final_state
# ---------------------------------------------------------------------------
# Gated RMSNorm (for DeltaNet output normalisation)
# ---------------------------------------------------------------------------
class Qwen3_5RMSNormGated(nn.Module):
    """RMSNorm over the last dim followed by a SiLU gate.

    ``out = weight * rmsnorm(x) * silu(gate)``. Normalisation and gating are
    evaluated in float32 and the result is cast back to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor,
                gate: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(
            x32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        # Normalise in float32, then scale by the weight in the input dtype.
        scaled = self.weight * (x32 * inv_rms).to(orig_dtype)
        gated = scaled * F.silu(gate.to(torch.float32))
        return gated.to(orig_dtype)
# ---------------------------------------------------------------------------
# Gated DeltaNet (linear_attention layers)
# ---------------------------------------------------------------------------
class GatedDeltaNet(nn.Module):
    """Gated DeltaNet mixer used by the ``linear_attention`` layers.

    Pure-PyTorch implementation: the depthwise causal conv and the gated
    delta rule run through the ``_torch_*`` helpers of this module.
    Projections, the conv weight and the per-head ``A_log`` / ``dt_bias``
    parameters are sharded across tensor-parallel ranks. The recurrent state
    is owned by the caller and passed to :meth:`forward` as ``conv_state`` and
    ``temporal_state``; both are updated in place.
    """

    # Prefill feeds the chunked delta rule at most this many tokens at a time
    # and chains the recurrent state between pieces, to cap peak memory.
    # Full 18K prompt: [1,6,282,64,64] tensors ~220 MB each, ~990 MB/call.
    # 4096-token pieces: [1,6,64,64,64] ~6 MB each, ~137 MB/call.
    # Keeping it a multiple of the kernel's chunk size (64) means only the
    # last piece needs padding.
    _DNN_CHUNK = 4096

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = text_cfg.hidden_size
        self.num_v_heads = text_cfg.linear_num_value_heads  # 48
        self.num_k_heads = text_cfg.linear_num_key_heads  # 16
        self.head_k_dim = text_cfg.linear_key_head_dim  # 128
        self.head_v_dim = text_cfg.linear_value_head_dim  # 128
        self.key_dim = self.num_k_heads * self.head_k_dim  # 2048
        self.value_dim = self.num_v_heads * self.head_v_dim  # 6144
        self.conv_dim = self.key_dim * 2 + self.value_dim  # 10240
        self.conv_kernel_size = text_cfg.linear_conv_kernel_dim  # 4
        self.head_expand_ratio = self.num_v_heads // self.num_k_heads  # 3

        tp_size = get_tensor_model_parallel_world_size()
        # Per-rank sizes, computed once instead of on every forward call.
        self.local_key_dim = self.key_dim // tp_size
        self.local_val_dim = self.value_dim // tp_size
        self.local_num_k = self.num_k_heads // tp_size
        self.local_num_v = self.num_v_heads // tp_size
        self.local_conv_dim = self.conv_dim // tp_size

        # MergedColumnParallelLinear shards each of q/k/v independently, so
        # each TP rank gets [q_shard, k_shard, v_shard]. A plain
        # ColumnParallelLinear would shard contiguously, giving rank 0
        # [q_all, k_partial], and the q/k/v split in _split_heads would be
        # wrong.
        self.in_proj_qkv = MergedColumnParallelLinear(
            self.hidden_size, [self.key_dim, self.key_dim, self.value_dim],
            bias=False, quant_config=quant_config)
        self.in_proj_z = ColumnParallelLinear(
            self.hidden_size, self.value_dim,
            bias=False, quant_config=quant_config)
        self.in_proj_b = ColumnParallelLinear(
            self.hidden_size, self.num_v_heads,
            bias=False, quant_config=quant_config)
        self.in_proj_a = ColumnParallelLinear(
            self.hidden_size, self.num_v_heads,
            bias=False, quant_config=quant_config)
        self.out_proj = RowParallelLinear(
            self.value_dim, self.hidden_size,
            bias=False, quant_config=quant_config)
        # Depthwise conv weight, sharded along the channel dim (dim 0) with the
        # same q/k/v pattern as in_proj_qkv (see _conv1d_weight_loader).
        self.conv1d_weight = nn.Parameter(
            torch.empty(self.local_conv_dim, 1, self.conv_kernel_size))
        set_weight_attrs(self.conv1d_weight, {
            "weight_loader": self._conv1d_weight_loader})
        # Per-head scalar parameters, sharded along dim 0.
        self.A_log = nn.Parameter(torch.zeros(self.local_num_v))
        self.dt_bias = nn.Parameter(torch.zeros(self.local_num_v))
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
        # Gated RMSNorm over head_v_dim, replicated on every rank.
        self.norm = Qwen3_5RMSNormGated(self.head_v_dim,
                                        eps=text_cfg.rms_norm_eps)

    def _conv1d_weight_loader(self, param: torch.Tensor,
                              loaded_weight: torch.Tensor) -> None:
        """Load this rank's slice of the (conv_dim, 1, kernel) conv weight.

        Checkpoint channels are ordered [q, k, v]. Each rank must take its
        q, k and v slices separately (the same pattern
        MergedColumnParallelLinear uses for in_proj_qkv), so that
        conv1d_weight[i] lines up with in_proj_qkv output channel i.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        key_local = self.key_dim // tp_size  # 512 with TP=4
        val_local = self.value_dim // tp_size  # 1536 with TP=4
        q_s = loaded_weight[tp_rank * key_local:(tp_rank + 1) * key_local]
        k_s = loaded_weight[self.key_dim + tp_rank * key_local:
                            self.key_dim + (tp_rank + 1) * key_local]
        v_s = loaded_weight[2 * self.key_dim + tp_rank * val_local:
                            2 * self.key_dim + (tp_rank + 1) * val_local]
        param.data.copy_(torch.cat([q_s, k_s, v_s], dim=0))

    def forward(
        self,
        hidden_states: torch.Tensor,  # (total_tokens, hidden_size)
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,  # (batch, local_conv_dim, kernel-1), in place
        temporal_state: torch.Tensor,  # (batch, local_v_heads, k_dim, v_dim), in place
    ) -> torch.Tensor:
        """Apply the mixer to a flattened batch of tokens.

        Args:
            hidden_states: ``(total_tokens, hidden_size)``.
            attn_metadata: Batch metadata. ``num_prefill_tokens > 0`` selects
                the per-sequence chunked prefill path (sequence boundaries
                from ``query_start_loc``); otherwise every row is one decode
                token of a different sequence.
            conv_state: Rolling conv inputs, updated in place.
            temporal_state: Recurrent delta-rule state, updated in place.

        Returns:
            ``(total_tokens, hidden_size)`` with NaNs replaced by zeros.
        """
        # All projections are computed for every token in one batched call.
        mixed_qkv_all, _ = self.in_proj_qkv(hidden_states)  # (total, local_conv_dim)
        z_all, _ = self.in_proj_z(hidden_states)  # (total, local_val_dim)
        b_all, _ = self.in_proj_b(hidden_states)  # (total, local_num_v)
        a_all, _ = self.in_proj_a(hidden_states)  # (total, local_num_v)
        if attn_metadata.num_prefill_tokens > 0:
            out = self._forward_prefill(mixed_qkv_all, z_all, b_all, a_all,
                                        attn_metadata, conv_state,
                                        temporal_state)
            return self._replace_nan(out, "prefill")
        out = self._forward_decode(mixed_qkv_all, z_all, b_all, a_all,
                                   conv_state, temporal_state)
        return self._replace_nan(out, "decode")

    @staticmethod
    def _has_initial_state(attn_metadata: AttentionMetadata,
                           num_seqs: int) -> List[bool]:
        """Per-sequence flag: does this prefill continue from cached state?

        Reads ``context_lens_tensor`` (tokens already processed per sequence,
        as filled in by vLLM's attention metadata builders). If it is absent,
        every sequence is treated as a fresh prompt.
        """
        context_lens = getattr(attn_metadata, "context_lens_tensor", None)
        if context_lens is None:
            return [False] * num_seqs
        flags = [int(c) > 0 for c in context_lens.tolist()[:num_seqs]]
        return flags + [False] * (num_seqs - len(flags))

    def _gates(self, b: torch.Tensor,
               a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``beta = sigmoid(b)`` and the float32 log decay ``g``."""
        beta = b.sigmoid()
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        return beta, g

    def _split_heads(
        self, mixed: torch.Tensor, batch: int, seq_len: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split conv output (batch, seq_len, local_conv_dim) into q, k, v heads.

        q and k are repeated so that each key head serves
        ``head_expand_ratio`` consecutive value heads.
        """
        q, k, v = torch.split(
            mixed, [self.local_key_dim, self.local_key_dim, self.local_val_dim],
            dim=-1)
        q = q.reshape(batch, seq_len, self.local_num_k, self.head_k_dim)
        k = k.reshape(batch, seq_len, self.local_num_k, self.head_k_dim)
        v = v.reshape(batch, seq_len, self.local_num_v, self.head_v_dim)
        q = q.repeat_interleave(self.head_expand_ratio, dim=2)
        k = k.repeat_interleave(self.head_expand_ratio, dim=2)
        return q, k, v

    def _norm_and_project(self, core_out: torch.Tensor,
                          z: torch.Tensor) -> torch.Tensor:
        """Gated RMSNorm per value head, then the output projection."""
        num_tokens = z.shape[0]
        normed = self.norm(core_out.reshape(-1, self.head_v_dim),
                           z.reshape(-1, self.head_v_dim))
        out, _ = self.out_proj(normed.reshape(num_tokens, -1))
        return out

    def _forward_prefill(self, mixed_qkv_all, z_all, b_all, a_all,
                         attn_metadata, conv_state, temporal_state):
        """Run every sequence of the batch through conv + chunked delta rule."""
        seq_starts = attn_metadata.query_start_loc.tolist()
        num_seqs = len(seq_starts) - 1
        has_initial_state = self._has_initial_state(attn_metadata, num_seqs)
        state_len = self.conv_kernel_size - 1
        conv_dtype = self.conv1d_weight.dtype
        outputs = []
        for si in range(num_seqs):
            s, e = int(seq_starts[si]), int(seq_starts[si + 1])
            seq_len = e - s
            # (1, local_conv_dim, seq_len)
            mixed_qkv = (mixed_qkv_all[s:e].transpose(0, 1).unsqueeze(0)
                         .to(conv_dtype))
            # Cache slots are reused across requests (and may never have been
            # initialised), so a fresh prompt must start from zeros, not from
            # whatever the slot holds. Only a continuation reads the cache.
            if has_initial_state[si]:
                left = conv_state[si].unsqueeze(0).to(conv_dtype)
            else:
                left = mixed_qkv.new_zeros(1, mixed_qkv.shape[1], state_len)
            window = torch.cat([left, mixed_qkv], dim=-1)
            conv_state[si].copy_(window[0, :, -state_len:])
            conv_out = F.silu(F.conv1d(window, self.conv1d_weight, bias=None,
                                       padding=0, groups=self.local_conv_dim))
            # (1, local_conv_dim, seq_len) -> (1, seq_len, local_conv_dim)
            q, k, v = self._split_heads(conv_out.transpose(1, 2), 1, seq_len)
            beta, g = self._gates(b_all[s:e], a_all[s:e])
            beta = beta.unsqueeze(0)  # (1, seq_len, local_num_v)
            g = g.unsqueeze(0)  # (1, seq_len, local_num_v)

            # None makes the kernel start from a zero state.
            cur_state = (temporal_state[si:si + 1]
                         if has_initial_state[si] else None)
            core_out_parts = []
            for start in range(0, seq_len, self._DNN_CHUNK):
                end = min(start + self._DNN_CHUNK, seq_len)
                part, cur_state = _torch_chunk_gated_delta_rule(
                    q[:, start:end],
                    k[:, start:end],
                    v[:, start:end],
                    g[:, start:end],
                    beta[:, start:end],
                    initial_state=cur_state,
                    output_final_state=True,
                    use_qk_l2norm_in_kernel=True,
                )
                core_out_parts.append(part)
            if cur_state is not None:
                temporal_state[si].copy_(cur_state[0])
            # (1, seq_len, local_num_v, head_v_dim)
            core_out = torch.cat(core_out_parts, dim=1)
            outputs.append(self._norm_and_project(core_out, z_all[s:e]))
        return torch.cat(outputs, dim=0)

    def _forward_decode(self, mixed_qkv_all, z_all, b_all, a_all,
                        conv_state, temporal_state):
        """One token per sequence: rolling conv update + recurrent delta rule."""
        num_seqs = mixed_qkv_all.shape[0]
        weight_2d = self.conv1d_weight.squeeze(1)  # (local_conv_dim, kernel)
        # (num_seqs, local_conv_dim, 1)
        mixed_qkv = mixed_qkv_all.to(weight_2d.dtype).unsqueeze(-1)
        conv_out = _torch_causal_conv1d_update(
            mixed_qkv, conv_state, weight_2d, bias=None, activation="silu")
        # (num_seqs, local_conv_dim, 1) -> (num_seqs, 1, local_conv_dim)
        q, k, v = self._split_heads(conv_out.squeeze(-1).unsqueeze(1),
                                    num_seqs, 1)
        beta, g = self._gates(b_all, a_all)
        core_out, last_state = _torch_recurrent_gated_delta_rule(
            q, k, v, g.unsqueeze(1), beta.unsqueeze(1),
            initial_state=temporal_state,
            output_final_state=True,
            use_qk_l2norm_in_kernel=True,
        )
        if last_state is not None:
            temporal_state.copy_(last_state)
        return self._norm_and_project(core_out, z_all)

    def _replace_nan(self, out: torch.Tensor, phase: str) -> torch.Tensor:
        """Replace NaNs with zeros, logging a warning in eager mode."""
        if out.is_cuda and torch.cuda.is_current_stream_capturing():
            # Host-side checks (.any() / .item()) are not allowed while a
            # CUDA graph is being captured, so scrub on-device instead.
            return out.masked_fill(torch.isnan(out), 0.0)
        nan_mask = torch.isnan(out)
        if nan_mask.any():
            logger.warning(
                "NaN in %s GatedDeltaNet layer %d (frac=%.4f), "
                "replacing with zeros",
                phase, self.layer_idx, nan_mask.float().mean().item())
            out = torch.nan_to_num(out, nan=0.0)
        return out
# ---------------------------------------------------------------------------
# Full Attention (with gated q — unique to Qwen3.5)
# ---------------------------------------------------------------------------
class Qwen3_5FullAttention(nn.Module):
    """Softmax attention layer of Qwen3.5 with an output gate.

    ``q_proj`` emits, per head, the query followed by a gate of the same
    width. The attention output is multiplied by ``sigmoid(gate)`` before
    ``o_proj``. Queries and keys get a per-head RMSNorm and partial RoPE.
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = text_cfg.hidden_size  # 5120
        self.num_heads = text_cfg.num_attention_heads  # 24
        self.num_kv_heads = text_cfg.num_key_value_heads  # 4
        self.head_dim = text_cfg.head_dim  # 256
        self.rms_norm_eps = text_cfg.rms_norm_eps
        tp_size = get_tensor_model_parallel_world_size()
        # q/k/v are plain ColumnParallelLinear layers. They split their
        # outputs evenly across ranks and cannot replicate KV heads, so each
        # rank must receive whole heads. Otherwise forward() would fail later
        # with an opaque view() shape error.
        if self.num_heads % tp_size or self.num_kv_heads % tp_size:
            raise ValueError(
                f"Qwen3.5 full attention requires num_attention_heads "
                f"({self.num_heads}) and num_key_value_heads "
                f"({self.num_kv_heads}) to be divisible by the tensor "
                f"parallel size ({tp_size}).")
        self.local_num_heads = self.num_heads // tp_size
        self.local_num_kv_heads = self.num_kv_heads // tp_size
        self.local_q_dim = self.local_num_heads * self.head_dim
        self.local_kv_dim = self.local_num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5
        # q_proj includes the gate: output = num_heads * head_dim * 2
        self.q_proj = ColumnParallelLinear(
            self.hidden_size, self.num_heads * self.head_dim * 2,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.q_proj")
        self.k_proj = ColumnParallelLinear(
            self.hidden_size, self.num_kv_heads * self.head_dim,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.k_proj")
        self.v_proj = ColumnParallelLinear(
            self.hidden_size, self.num_kv_heads * self.head_dim,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.v_proj")
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim, self.hidden_size,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.o_proj")
        self.q_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)
        # Partial RoPE: rotary_dim = head_dim * partial_rotary_factor
        # (256 * 0.25 = 64).
        rope_theta = self._rope_setting(text_cfg, "rope_theta", 10_000_000)
        partial_factor = self._rope_setting(
            text_cfg, "partial_rotary_factor", 0.25)
        rotary_dim = int(self.head_dim * partial_factor)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=rotary_dim,
            max_position=text_cfg.max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(
            self.local_num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.local_num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    @staticmethod
    def _rope_setting(text_cfg, key: str, default):
        """Look up a RoPE setting.

        Checks ``text_cfg.rope_parameters[key]`` first, then a top-level
        ``text_cfg.<key>`` attribute, and finally falls back to ``default``.
        """
        rope_params = getattr(text_cfg, "rope_parameters", None) or {}
        value = rope_params.get(key)
        if value is None:
            value = getattr(text_cfg, key, None)
        return default if value is None else value

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Gated attention over ``hidden_states`` of shape (total_tokens, hidden)."""
        total_tokens = hidden_states.shape[0]
        # q_proj output interleaves [query, gate] per head (width doubled).
        qg, _ = self.q_proj(hidden_states)  # (total, local_num_heads * head_dim * 2)
        qg = qg.view(total_tokens, self.local_num_heads, self.head_dim * 2)
        q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
        gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)
        k, _ = self.k_proj(hidden_states)  # (total, local_kv_dim)
        v, _ = self.v_proj(hidden_states)
        # Per-head RMSNorm on queries and keys.
        q = self.q_norm.forward_cuda(
            q.view(total_tokens, self.local_num_heads, self.head_dim)
            .contiguous()).view(total_tokens, -1)
        k = self.k_norm.forward_cuda(
            k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
            .contiguous()).view(total_tokens, -1)
        q, k = self.rotary_emb(positions, q, k)
        attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
        # Sigmoid gate (computed in float32) before the output projection.
        attn_out = attn_out * torch.sigmoid(gate.float()).to(attn_out.dtype)
        output, _ = self.o_proj(attn_out)
        return output
# ---------------------------------------------------------------------------
# MLP (SwiGLU, same as Qwen2/Qwen3)
# ---------------------------------------------------------------------------
class Qwen3_5MLP(nn.Module):
    """SwiGLU feed-forward block.

    A fused gate/up column-parallel projection feeds ``SiluAndMul``, whose
    output is reduced back to ``hidden_size`` by a row-parallel projection.
    Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size, intermediate_size],
            bias=False, quant_config=quant_config)
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size,
            bias=False, quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fused, _ = self.gate_up_proj(x)
        reduced, _ = self.down_proj(self.act_fn(fused))
        return reduced
# ---------------------------------------------------------------------------
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
# ---------------------------------------------------------------------------
class Qwen3_5DecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer depends on ``layer_type``.

    ``"linear_attention"`` layers use :class:`GatedDeltaNet` (stored as
    ``linear_attn``). Every other layer type uses
    :class:`Qwen3_5FullAttention` (stored as ``self_attn``). Both are
    followed by the SwiGLU MLP. Residuals are carried through the fused
    add-and-norm calls of the two GemmaRMSNorm layers.
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        layer_type: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.layer_type = layer_type
        hidden = text_cfg.hidden_size
        eps = text_cfg.rms_norm_eps
        self.input_layernorm = GemmaRMSNorm(hidden, eps=eps)
        self.post_attention_layernorm = GemmaRMSNorm(hidden, eps=eps)
        if layer_type != "linear_attention":
            self.self_attn = Qwen3_5FullAttention(
                text_cfg, layer_idx,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"layers.{layer_idx}.self_attn",
            )
        else:
            self.linear_attn = GatedDeltaNet(text_cfg, layer_idx,
                                             quant_config=quant_config)
        self.mlp = Qwen3_5MLP(
            hidden_size=hidden,
            intermediate_size=text_cfg.intermediate_size,
            hidden_act=text_cfg.hidden_act,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        # Used only by linear_attention layers:
        conv_state: Optional[torch.Tensor] = None,
        temporal_state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states,
                                                           residual)
        else:
            # First layer: the un-normalised input becomes the residual.
            residual, hidden_states = (hidden_states,
                                       self.input_layernorm(hidden_states))

        if self.layer_type == "linear_attention":
            mixed = self.linear_attn(hidden_states, attn_metadata,
                                     conv_state, temporal_state)
        else:
            mixed = self.self_attn(positions, hidden_states, kv_cache,
                                   attn_metadata)

        normed, residual = self.post_attention_layernorm(mixed, residual)
        return self.mlp(normed), residual
# ---------------------------------------------------------------------------
# Full transformer model
# ---------------------------------------------------------------------------
class Qwen3_5Model(nn.Module):
    """Token embedding, stack of hybrid decoder layers, and final norm.

    ``text_cfg.layer_types[i]`` picks the mixer for layer ``i``. Attention
    layers consume ``kv_caches`` in order. Linear-attention layers consume
    the per-layer slices of ``conv_states`` / ``temporal_states`` in order.
    """

    def __init__(
        self,
        text_cfg,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.text_cfg = text_cfg
        self.embed_tokens = VocabParallelEmbedding(
            text_cfg.vocab_size, text_cfg.hidden_size)
        decoder_layers = []
        for idx in range(text_cfg.num_hidden_layers):
            decoder_layers.append(
                Qwen3_5DecoderLayer(
                    text_cfg, idx, text_cfg.layer_types[idx],
                    cache_config=cache_config, quant_config=quant_config))
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = GemmaRMSNorm(text_cfg.hidden_size, eps=text_cfg.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        conv_states: torch.Tensor,  # (num_linear_layers, batch, ...)
        temporal_states: torch.Tensor,  # (num_linear_layers, batch, ...)
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        next_attn = 0
        next_linear = 0
        for layer in self.layers:
            state_kwargs = {}
            if layer.layer_type == "linear_attention":
                layer_cache = None
                state_kwargs = {
                    "conv_state": conv_states[next_linear],
                    "temporal_state": temporal_states[next_linear],
                }
                next_linear += 1
            else:
                layer_cache = kv_caches[next_attn]
                next_attn += 1
            hidden_states, residual = layer(
                positions, hidden_states,
                kv_cache=layer_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                **state_kwargs,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
# ---------------------------------------------------------------------------
# Top-level CausalLM wrapper with MambaCacheManager
# ---------------------------------------------------------------------------
class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
    """Text-only Qwen3.5 causal LM for vLLM.

    Wraps :class:`Qwen3_5Model` with an LM head, logits processor and
    sampler. Owns the ``MambaCacheManager`` that holds the float32
    per-sequence DeltaNet conv / temporal states; the manager is created
    lazily on the first forward call.
    """

    has_inner_state = True
    supports_lora = True
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    supported_lora_modules = [
        "gate_up_proj",
        "down_proj",
        "o_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config,  # Qwen3_5Config (top-level)
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        # The text config holds all architecture parameters.
        text_cfg = config.text_config
        self.text_cfg = text_cfg
        self.num_linear_layers = sum(
            1 for lt in text_cfg.layer_types if lt == "linear_attention")
        self.num_attn_layers = sum(
            1 for lt in text_cfg.layer_types if lt == "full_attention")
        # Global (un-sharded) DeltaNet state sizes. _get_mamba_cache_shape
        # divides them by the TP size.
        self.conv_dim = (text_cfg.linear_num_key_heads * text_cfg.linear_key_head_dim * 2
                         + text_cfg.linear_num_value_heads * text_cfg.linear_value_head_dim)
        self.num_v_heads = text_cfg.linear_num_value_heads
        self.head_k_dim = text_cfg.linear_key_head_dim
        self.head_v_dim = text_cfg.linear_value_head_dim
        self.conv_kernel_size = text_cfg.linear_conv_kernel_dim
        self.model = Qwen3_5Model(
            text_cfg,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.lm_head = ParallelLMHead(
            text_cfg.vocab_size, text_cfg.hidden_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(text_cfg.vocab_size)
        self.sampler = Sampler()
        # Lazily initialised in the first forward call.
        self.mamba_cache: Optional[MambaCacheManager] = None

    def _get_mamba_cache_shape(self):
        """Per-sequence (conv_state, temporal_state) shapes on this TP rank."""
        tp_size = get_tensor_model_parallel_world_size()
        conv_state_shape = (self.conv_dim // tp_size, self.conv_kernel_size - 1)
        temporal_state_shape = (
            self.num_v_heads // tp_size, self.head_k_dim, self.head_v_dim)
        return conv_state_shape, temporal_state_shape

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        if self.mamba_cache is None:
            if self.scheduler_config is not None:
                max_batch_size = _get_graph_batch_size(
                    self.scheduler_config.max_num_seqs)
            else:
                max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + 2
            # Recurrent states are kept in float32.
            self.mamba_cache = MambaCacheManager(
                torch.float32,
                self.num_linear_layers,
                max_batch_size,
                *self._get_mamba_cache_shape(),
            )
        mamba_tensors = self.mamba_cache.current_run_tensors(
            input_ids, attn_metadata, **kwargs)
        # conv_states: (num_linear_layers, batch, local_conv_dim, kernel-1)
        # temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim)
        conv_states, temporal_states = mamba_tensors
        hidden_states = self.model(
            input_ids, positions, kv_caches, attn_metadata,
            conv_states, temporal_states)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this text-only model.

        Vision and MTP tensors are skipped. The ``model.language_model.``
        prefix is mapped to ``model.``, and ``conv1d.weight`` is renamed to
        ``conv1d_weight``. ``gate_proj`` / ``up_proj`` are loaded as shards
        of ``gate_up_proj``. Tensors without a matching parameter are
        ignored. Parameters that never receive a tensor are reported with a
        warning, because they keep their initial (possibly uninitialised)
        values.
        """
        stacked_params_mapping = [
            # (param_name, weight_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            # Skip vision and MTP branches.
            if name.startswith(("model.visual", "mtp.", "model.mtp")):
                continue
            # Checkpoint "model.language_model.{rest}" -> module "model.{rest}".
            # "lm_head.weight" already matches the module path.
            if name.startswith("model.language_model."):
                name = "model." + name[len("model.language_model."):]
            # Skip positional embedding caches.
            if "rotary_emb.inv_freq" in name:
                continue
            # The depthwise conv is stored as a bare parameter here.
            if ".linear_attn.conv1d.weight" in name:
                name = name.replace(".linear_attn.conv1d.weight",
                                    ".linear_attn.conv1d_weight")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name not in params_dict:
                    break
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        missing = sorted(set(params_dict) - loaded_params)
        if missing:
            shown = ", ".join(missing[:20])
            if len(missing) > 20:
                shown += ", ..."
            logger.warning(
                "%d parameters were not found in the checkpoint and keep "
                "their initial values: %s", len(missing), shown)