# Inference-only Qwen3.6-27B (Qwen3_5 architecture) for Iluvatar BI-V100.
# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency).
# Text-only (no VL, no MTP).
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                      _get_graph_batch_size)
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA

logger = init_logger(__name__)

# Maximum number of tokens handed to _torch_chunk_gated_delta_rule per call
# during prefill.  Longer prompts are processed in consecutive slices with the
# recurrent state chained between them, which caps peak memory.
# Full 18K: tensors [1,6,282,64,64]=220 MB each -> ~990 MB/call.
# With _DNN_CHUNK=4096: [1,6,64,64,64]=6 MB each -> ~137 MB/call.
_DNN_CHUNK = 4096


# ---------------------------------------------------------------------------
# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0)
# ---------------------------------------------------------------------------

def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """Scale ``x`` by ``rsqrt(sum(x^2) + eps)`` along ``dim``."""
    return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)


def _torch_causal_conv1d_update(
    hidden_states: torch.Tensor,  # (batch, channels, seq=1)
    conv_state: torch.Tensor,     # (batch, channels, state_len) modified in-place
    weight: torch.Tensor,         # (channels, kernel_size)
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
) -> torch.Tensor:
    """Advance a depthwise causal convolution by the new input step(s).

    The stored ``conv_state`` is concatenated with ``hidden_states`` along the
    time axis, the last ``state_len`` columns of the result are written back
    into ``conv_state`` (in place), and the depthwise convolution of the
    concatenation is evaluated for the new positions only.

    Args:
        hidden_states: New inputs, ``(batch, channels, seq)``.
        conv_state: Rolling history, ``(batch, channels, state_len)``;
            overwritten with the updated history.
        weight: Depthwise kernel, ``(channels, kernel_size)``.
        bias: Optional per-channel bias passed to ``F.conv1d``.
        activation: When not ``None`` SiLU is applied to the output.  The
            string value itself is not inspected.

    Returns:
        Tensor of shape ``(batch, channels, seq)`` in ``hidden_states.dtype``.
    """
    _, channels, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]
    cat = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    conv_state.copy_(cat[:, :, -state_len:])
    out = F.conv1d(cat, weight.unsqueeze(1), bias, padding=0, groups=channels)
    # Keep only the outputs that correspond to the newly supplied steps.
    out = out[:, :, -seq_len:]
    if activation is not None:
        out = F.silu(out)
    return out.to(hidden_states.dtype)


def _torch_chunk_gated_delta_rule(
    query: torch.Tensor,   # (batch, seq, num_heads, head_k_dim)
    key: torch.Tensor,
    value: torch.Tensor,   # (batch, seq, num_heads, head_v_dim)
    g: torch.Tensor,       # (batch, seq, num_heads)
    beta: torch.Tensor,    # (batch, seq, num_heads)
    chunk_size: int = 64,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Chunked evaluation of the gated delta rule over a whole sequence.

    All arithmetic is done in float32.  The sequence is right-padded with
    zeros to a multiple of ``chunk_size``; the padding is sliced off the
    output before returning.

    Args:
        query, key: ``(batch, seq, num_heads, head_k_dim)``.
        value: ``(batch, seq, num_heads, head_v_dim)``.
        g: Per-token log-decay, ``(batch, seq, num_heads)``; it is cumulatively
            summed inside each chunk and exponentiated.
        beta: Per-token mixing factor, ``(batch, seq, num_heads)``.
        chunk_size: Number of tokens processed per intra-chunk block.
        initial_state: Optional ``(batch, num_heads, head_k_dim, head_v_dim)``
            recurrent state; zeros are used when ``None``.
        output_final_state: Return the state after the last chunk when true,
            otherwise ``None``.
        use_qk_l2norm_in_kernel: L2-normalise ``query`` and ``key`` first.

    Returns:
        ``(core_out, last_state)`` where ``core_out`` has shape
        ``(batch, seq, num_heads, head_v_dim)`` in the dtype of ``query``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = _l2norm(query)
        key = _l2norm(key)

    # Transpose to (batch, num_heads, seq, dim)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    batch, num_heads, seq_len, k_dim = key.shape
    v_dim = value.shape[-1]

    # Right-pad the time axis up to the next multiple of chunk_size.
    pad = (chunk_size - seq_len % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad))
    key = F.pad(key, (0, 0, 0, pad))
    value = F.pad(value, (0, 0, 0, pad))
    beta = F.pad(beta, (0, pad))
    g = F.pad(g, (0, pad))
    total_len = seq_len + pad

    scale = 1.0 / (query.shape[-1] ** 0.5)
    query = query * scale
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # Fold the time axis into (num_chunks, chunk_size).
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    mask_upper = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool,
                   device=query.device),
        diagonal=0)
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(
        mask_upper, 0)
    # Row-by-row update of the strictly lower-triangular block; each row
    # depends on the rows already updated above it, so the order matters.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    last_state = (
        torch.zeros(batch, num_heads, k_dim, v_dim,
                    dtype=value.dtype, device=value.device)
        if initial_state is None else initial_state.to(value)
    )
    core_out = torch.zeros_like(value)
    mask_upper2 = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool,
                   device=query.device),
        diagonal=1)

    # Sequential pass over chunks, carrying the recurrent state forward.
    for i in range(total_len // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn_i = (q_i @ k_i.transpose(-1, -2)
                  * decay_mask[:, :, i]).masked_fill_(mask_upper2, 0)
        v_prime = k_cumdecay[:, :, i] @ last_state
        v_new = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state
        core_out[:, :, i] = attn_inter + attn_i @ v_new
        last_state = (
            last_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None])
            .transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        last_state = None
    # Unfold chunks and drop the padded tail.
    core_out = core_out.reshape(batch, num_heads, -1, v_dim)[:, :, :seq_len]
    core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_out, last_state


def _torch_recurrent_gated_delta_rule(
    query: torch.Tensor,   # (batch, 1, num_heads, head_k_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    g: torch.Tensor,       # (batch, 1, num_heads)
    beta: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Token-by-token evaluation of the gated delta rule.

    For every time step the state is decayed by ``exp(g_t)``, corrected with
    the ``beta_t``-scaled difference between ``v_t`` and the state's current
    read-out for ``k_t``, and then read out with ``q_t``.  All arithmetic is
    done in float32.

    Args:
        query, key: ``(batch, seq, num_heads, head_k_dim)``.
        value: ``(batch, seq, num_heads, head_v_dim)``.
        g: Per-token log-decay, ``(batch, seq, num_heads)``.
        beta: Per-token mixing factor, ``(batch, seq, num_heads)``.
        initial_state: Optional ``(batch, num_heads, head_k_dim, head_v_dim)``
            state; zeros are used when ``None``.
        output_final_state: Return the final state when true, else ``None``.
        use_qk_l2norm_in_kernel: L2-normalise ``query`` and ``key`` first.

    Returns:
        ``(core_out, last_state)`` where ``core_out`` has shape
        ``(batch, seq, num_heads, head_v_dim)`` in the dtype of ``query``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = _l2norm(query)
        key = _l2norm(key)

    # Transpose to (batch, num_heads, seq, dim)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    batch, num_heads, seq_len, k_dim = key.shape
    v_dim = value.shape[-1]
    scale = 1.0 / (query.shape[-1] ** 0.5)
    query = query * scale

    core_out = torch.zeros(batch, num_heads, seq_len, v_dim,
                           dtype=value.dtype, device=value.device)
    last_state = (
        torch.zeros(batch, num_heads, k_dim, v_dim,
                    dtype=value.dtype, device=value.device)
        if initial_state is None else initial_state.to(value)
    )

    for t in range(seq_len):
        q_t = query[:, :, t]
        k_t = key[:, :, t]
        v_t = value[:, :, t]
        g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, t].unsqueeze(-1)

        last_state = last_state * g_t
        kv_mem = (last_state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - kv_mem) * beta_t
        last_state = last_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        core_out[:, :, t] = (last_state * q_t.unsqueeze(-1)).sum(dim=-2)

    if not output_final_state:
        last_state = None
    core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_out, last_state


# ---------------------------------------------------------------------------
# Gated RMSNorm (for DeltaNet output normalisation)
# ---------------------------------------------------------------------------

class Qwen3_5RMSNormGated(nn.Module):
    """RMSNorm over the last dimension followed by a SiLU gate."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor,
                gate: torch.Tensor) -> torch.Tensor:
        """Return ``weight * rmsnorm(hidden_states) * silu(gate)``.

        The variance is computed in float32; the normalised tensor is cast
        back to the input dtype before the weight multiply, and the gate is
        applied in float32.  The result has the dtype of ``hidden_states``.
        """
        input_dtype = hidden_states.dtype
        hs = hidden_states.to(torch.float32)
        variance = hs.pow(2).mean(-1, keepdim=True)
        hs = hs * torch.rsqrt(variance + self.variance_epsilon)
        hs = self.weight * hs.to(input_dtype)
        return (hs * F.silu(gate.to(torch.float32))).to(input_dtype)


# ---------------------------------------------------------------------------
# Gated DeltaNet (linear_attention layers)
# ---------------------------------------------------------------------------

class GatedDeltaNet(nn.Module):
    """Tensor-parallel Gated DeltaNet mixer used by ``linear_attention`` layers.

    Per-sequence state lives in two caller-owned buffers that are updated in
    place: the causal-conv history (``conv_state``) and the recurrent
    key/value memory (``temporal_state``).
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = text_cfg.hidden_size
        self.num_v_heads = text_cfg.linear_num_value_heads              # 48
        self.num_k_heads = text_cfg.linear_num_key_heads                # 16
        self.head_k_dim = text_cfg.linear_key_head_dim                  # 128
        self.head_v_dim = text_cfg.linear_value_head_dim                # 128
        self.key_dim = self.num_k_heads * self.head_k_dim               # 2048
        self.value_dim = self.num_v_heads * self.head_v_dim             # 6144
        self.conv_dim = self.key_dim * 2 + self.value_dim               # 10240
        self.conv_kernel_size = text_cfg.linear_conv_kernel_dim         # 4
        self.head_expand_ratio = self.num_v_heads // self.num_k_heads   # 3

        # Per-rank (tensor-parallel) sizes, computed once.
        tp_size = get_tensor_model_parallel_world_size()
        self.local_key_dim = self.key_dim // tp_size
        self.local_val_dim = self.value_dim // tp_size
        self.local_num_v = self.num_v_heads // tp_size
        self.local_num_k = self.num_k_heads // tp_size
        self.local_conv_dim = self.conv_dim // tp_size

        # Sharded projections — MergedColumnParallelLinear shards each of q/k/v
        # independently so each TP rank gets [q_shard, k_shard, v_shard].
        # Plain ColumnParallelLinear would shard contiguously, giving rank 0
        # [q_all, k_partial] — completely wrong Q/K/V after the split below.
        self.in_proj_qkv = MergedColumnParallelLinear(
            self.hidden_size,
            [self.key_dim, self.key_dim, self.value_dim],
            bias=False,
            quant_config=quant_config)
        self.in_proj_z = ColumnParallelLinear(
            self.hidden_size, self.value_dim,
            bias=False, quant_config=quant_config)
        self.in_proj_b = ColumnParallelLinear(
            self.hidden_size, self.num_v_heads,
            bias=False, quant_config=quant_config)
        self.in_proj_a = ColumnParallelLinear(
            self.hidden_size, self.num_v_heads,
            bias=False, quant_config=quant_config)
        self.out_proj = RowParallelLinear(
            self.value_dim, self.hidden_size,
            bias=False, quant_config=quant_config)

        # Depthwise conv weight — sharded along channel dim (dim 0)
        self.conv1d_weight = nn.Parameter(
            torch.empty(self.local_conv_dim, 1, self.conv_kernel_size))
        set_weight_attrs(self.conv1d_weight, {
            "weight_loader": self._conv1d_weight_loader})

        # Per-head scalar parameters — sharded along dim 0
        self.A_log = nn.Parameter(torch.zeros(self.local_num_v))
        self.dt_bias = nn.Parameter(torch.zeros(self.local_num_v))
        set_weight_attrs(self.A_log,
                         {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias,
                         {"weight_loader": sharded_weight_loader(0)})

        # Gated RMSNorm on head_v_dim — replicated (head_v_dim=128 is small)
        self.norm = Qwen3_5RMSNormGated(self.head_v_dim,
                                        eps=text_cfg.rms_norm_eps)

    def _conv1d_weight_loader(self, param: torch.Tensor,
                              loaded_weight: torch.Tensor) -> None:
        """Copy this rank's [q, k, v] channel slices of the conv kernel.

        ``loaded_weight`` is ``(conv_dim, 1, kernel)`` with channels ordered
        as [q, k, v].  The channels must be gathered in the same
        non-contiguous pattern that MergedColumnParallelLinear uses for
        ``in_proj_qkv`` so that each rank's ``conv1d_weight[i]`` applies to
        the matching ``in_proj_qkv`` output channel.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        key_local = self.key_dim // tp_size    # 512 with TP=4
        val_local = self.value_dim // tp_size  # 1536 with TP=4

        q_start = tp_rank * key_local
        k_start = self.key_dim + tp_rank * key_local
        v_start = 2 * self.key_dim + tp_rank * val_local
        q_s = loaded_weight[q_start:q_start + key_local]
        k_s = loaded_weight[k_start:k_start + key_local]
        v_s = loaded_weight[v_start:v_start + val_local]
        param.data.copy_(torch.cat([q_s, k_s, v_s], dim=0))

    # -- small shared helpers ----------------------------------------------

    def _log_decay(self, a: torch.Tensor) -> torch.Tensor:
        """Return ``-exp(A_log) * softplus(a + dt_bias)`` in float32."""
        return -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    def _split_heads(
        self, mixed_qkv_conv: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split ``(batch, seq, local_conv_dim)`` into per-head q, k, v.

        q and k are repeated ``head_expand_ratio`` times along the head axis
        so that they have as many heads as v.
        """
        batch, seq_len = mixed_qkv_conv.shape[0], mixed_qkv_conv.shape[1]
        q, k, v = torch.split(
            mixed_qkv_conv,
            [self.local_key_dim, self.local_key_dim, self.local_val_dim],
            dim=-1)
        q = q.reshape(batch, seq_len, self.local_num_k, self.head_k_dim)
        k = k.reshape(batch, seq_len, self.local_num_k, self.head_k_dim)
        v = v.reshape(batch, seq_len, self.local_num_v, self.head_v_dim)
        # Expand k/q to match num_v_heads
        q = q.repeat_interleave(self.head_expand_ratio, dim=2)
        k = k.repeat_interleave(self.head_expand_ratio, dim=2)
        return q, k, v

    def _gated_norm_out_proj(self, core_out: torch.Tensor,
                             z: torch.Tensor) -> torch.Tensor:
        """Apply the gated per-head RMSNorm and the output projection.

        ``z`` is ``(num_tokens, local_val_dim)``; ``core_out`` holds the same
        number of elements and is flattened to rows of ``head_v_dim``.
        """
        num_tokens = z.shape[0]
        normed = self.norm(core_out.reshape(-1, self.head_v_dim),
                           z.reshape(-1, self.head_v_dim))
        out, _ = self.out_proj(normed.reshape(num_tokens, -1))
        return out

    def _replace_nan(self, out: torch.Tensor, phase: str) -> torch.Tensor:
        """Log and zero out NaNs in ``out`` (returns ``out`` if there are none)."""
        if torch.isnan(out).any():
            logger.warning(
                "NaN in %s GatedDeltaNet layer %d (frac=%.4f), "
                "replacing with zeros",
                phase, self.layer_idx,
                torch.isnan(out).float().mean().item())
            out = torch.nan_to_num(out, nan=0.0)
        return out

    # -- forward -------------------------------------------------------------

    def forward(
        self,
        hidden_states: torch.Tensor,   # (total_tokens, hidden_size)
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,      # (batch, local_conv_dim, kernel-1) in-place
        temporal_state: torch.Tensor,  # (batch, local_v_heads, k_dim, v_dim) in-place
    ) -> torch.Tensor:
        """Run the mixer on a flattened batch of tokens.

        When ``attn_metadata.num_prefill_tokens > 0`` the whole batch is
        handled as prefill, split into sequences by
        ``attn_metadata.query_start_loc``; otherwise every row of
        ``hidden_states`` is treated as one decode token of a distinct
        sequence.  NOTE(review): a batch that mixes prefill and decode tokens
        is not handled separately here — confirm the scheduler never sends
        one.

        Both state buffers are updated in place.  Returns a tensor with one
        row per input token.
        """
        # Compute all projections for every token at once (batched, efficient)
        mixed_qkv_all, _ = self.in_proj_qkv(hidden_states)  # (total, local_conv_dim)
        z_all, _ = self.in_proj_z(hidden_states)            # (total, local_val_dim)
        b_all, _ = self.in_proj_b(hidden_states)            # (total, local_num_v)
        a_all, _ = self.in_proj_a(hidden_states)            # (total, local_num_v)

        if attn_metadata.num_prefill_tokens > 0:
            return self._forward_prefill(
                attn_metadata, mixed_qkv_all, z_all, b_all, a_all,
                conv_state, temporal_state)
        return self._forward_decode(
            mixed_qkv_all, z_all, b_all, a_all, conv_state, temporal_state)

    def _forward_prefill(
        self,
        attn_metadata: AttentionMetadata,
        mixed_qkv_all: torch.Tensor,
        z_all: torch.Tensor,
        b_all: torch.Tensor,
        a_all: torch.Tensor,
        conv_state: torch.Tensor,
        temporal_state: torch.Tensor,
    ) -> torch.Tensor:
        """Process each prompt from scratch and store its final states."""
        seq_starts = attn_metadata.query_start_loc.tolist()
        state_len = self.conv_kernel_size - 1
        conv_dtype = self.conv1d_weight.dtype
        outputs = []

        for si, (s, e) in enumerate(zip(seq_starts, seq_starts[1:])):
            s, e = int(s), int(e)
            seq_len = e - s

            # Shape: (1, local_conv_dim, seq_len)
            mixed_qkv = (mixed_qkv_all[s:e]
                         .transpose(0, 1).unsqueeze(0)
                         .to(conv_dtype))

            # Causal conv: left-pad with zeros, then convolve.
            padded = F.pad(mixed_qkv, (state_len, 0))
            # The last state_len columns of the padded input are exactly the
            # conv history to keep: the tail of the sequence, or the whole
            # (short) sequence preceded by zeros.
            conv_state[si].copy_(padded[0, :, seq_len:])

            mixed_qkv_conv = F.conv1d(
                padded, self.conv1d_weight, bias=None,
                padding=0, groups=self.local_conv_dim)
            mixed_qkv_conv = F.silu(mixed_qkv_conv)
            # (1, local_conv_dim, seq_len) -> (1, seq_len, local_conv_dim)
            mixed_qkv_conv = mixed_qkv_conv.transpose(1, 2)

            q, k, v = self._split_heads(mixed_qkv_conv)
            beta = b_all[s:e].sigmoid().unsqueeze(0)      # (1, seq_len, local_num_v)
            g = self._log_decay(a_all[s:e]).unsqueeze(0)  # (1, seq_len, local_num_v)

            # A prefill starts a new sequence, so the recurrent memory is
            # seeded with zeros — consistent with the zero-padded conv above.
            # Reading temporal_state[si] here would carry over whatever the
            # cache slot held before (presumably the slot is not cleared when
            # it is handed to a new request — confirm against the
            # MambaCacheManager in use).
            cur_state = torch.zeros_like(temporal_state[si:si + 1])

            # Sub-sequence chunking: feed _DNN_CHUNK tokens at a time, with the
            # state chained via initial_state / output_final_state.
            core_out_parts = []
            for sc_start in range(0, seq_len, _DNN_CHUNK):
                sc_end = min(sc_start + _DNN_CHUNK, seq_len)
                c_out, cur_state = _torch_chunk_gated_delta_rule(
                    q[:, sc_start:sc_end],
                    k[:, sc_start:sc_end],
                    v[:, sc_start:sc_end],
                    g[:, sc_start:sc_end],
                    beta[:, sc_start:sc_end],
                    initial_state=cur_state,
                    output_final_state=True,
                    use_qk_l2norm_in_kernel=True,
                )
                core_out_parts.append(c_out)
            temporal_state[si].copy_(cur_state[0])

            # [1, seq_len, num_v_heads, head_v_dim]
            core_out = torch.cat(core_out_parts, dim=1)

            # Gate + norm + output proj
            outputs.append(self._gated_norm_out_proj(core_out, z_all[s:e]))

        result = torch.cat(outputs, dim=0)
        return self._replace_nan(result, "prefill")

    def _forward_decode(
        self,
        mixed_qkv_all: torch.Tensor,
        z_all: torch.Tensor,
        b_all: torch.Tensor,
        a_all: torch.Tensor,
        conv_state: torch.Tensor,
        temporal_state: torch.Tensor,
    ) -> torch.Tensor:
        """Decode: one token per sequence, continuing from the stored states."""
        weight_2d = self.conv1d_weight.squeeze(1)  # (local_conv_dim, kernel)

        # (num_seqs, local_conv_dim, 1)
        mixed_qkv = mixed_qkv_all.to(weight_2d.dtype).unsqueeze(-1)
        mixed_qkv_conv = _torch_causal_conv1d_update(
            mixed_qkv, conv_state, weight_2d, bias=None, activation='silu')
        # (num_seqs, local_conv_dim, 1) -> (num_seqs, 1, local_conv_dim)
        mixed_qkv_conv = mixed_qkv_conv.squeeze(-1).unsqueeze(1)

        q, k, v = self._split_heads(mixed_qkv_conv)
        beta = b_all.sigmoid().unsqueeze(1)      # (num_seqs, 1, local_num_v)
        g = self._log_decay(a_all).unsqueeze(1)  # (num_seqs, 1, local_num_v)

        core_out, last_state = _torch_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=temporal_state,
            output_final_state=True,
            use_qk_l2norm_in_kernel=True,
        )
        if last_state is not None:
            temporal_state.copy_(last_state)

        out = self._gated_norm_out_proj(core_out, z_all)
        return self._replace_nan(out, "decode")


# ---------------------------------------------------------------------------
# Full Attention (with gated q — unique to Qwen3.5)
# ---------------------------------------------------------------------------

class Qwen3_5FullAttention(nn.Module):
    """Softmax attention whose output is multiplied by a sigmoid gate.

    ``q_proj`` produces, for every head, ``head_dim`` query features followed
    by ``head_dim`` gate features.
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = text_cfg.hidden_size           # 5120
        self.num_heads = text_cfg.num_attention_heads     # 24
        self.num_kv_heads = text_cfg.num_key_value_heads  # 4
        self.head_dim = text_cfg.head_dim                 # 256
        self.rms_norm_eps = text_cfg.rms_norm_eps

        tp_size = get_tensor_model_parallel_world_size()
        self.local_num_heads = self.num_heads // tp_size
        self.local_num_kv_heads = max(1, self.num_kv_heads // tp_size)
        self.local_q_dim = self.local_num_heads * self.head_dim
        self.local_kv_dim = self.local_num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5

        # q_proj includes gate: output = num_heads * head_dim * 2
        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_heads * self.head_dim * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj")
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj")
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj")
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj")

        self.q_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)

        # Partial RoPE: rotary_dim = head_dim * partial_rotary_factor
        #             = 256 * 0.25 = 64
        rope_params = getattr(text_cfg, "rope_parameters", {}) or {}
        rope_theta = rope_params.get("rope_theta", 10_000_000)
        partial_factor = rope_params.get("partial_rotary_factor", 0.25)
        rotary_dim = int(self.head_dim * partial_factor)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=rotary_dim,
            max_position=text_cfg.max_position_embeddings,
            base=rope_theta,
        )

        self.attn = Attention(
            self.local_num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.local_num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project, normalise per head, apply RoPE, attend, gate and project out."""
        total_tokens = hidden_states.shape[0]

        # q_proj output includes gate (dim doubled)
        qg, _ = self.q_proj(hidden_states)  # (total, local_num_heads * head_dim * 2)
        qg = qg.view(total_tokens, self.local_num_heads, self.head_dim * 2)
        q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
        gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)

        k, _ = self.k_proj(hidden_states)  # (total, local_kv_dim)
        v, _ = self.v_proj(hidden_states)

        # Per-head RMSNorm
        q = self.q_norm.forward_cuda(
            q.view(total_tokens, self.local_num_heads, self.head_dim)
            .contiguous()).view(total_tokens, -1)
        k = self.k_norm.forward_cuda(
            k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
            .contiguous()).view(total_tokens, -1)

        q, k = self.rotary_emb(positions, q, k)
        attn_out = self.attn(q, k, v, kv_cache, attn_metadata)

        # Multiply by sigmoid gate before output projection
        attn_out = attn_out * torch.sigmoid(gate.float()).to(attn_out.dtype)
        output, _ = self.o_proj(attn_out)
        return output


# ---------------------------------------------------------------------------
# MLP (SwiGLU, same as Qwen2/Qwen3)
# ---------------------------------------------------------------------------

class Qwen3_5MLP(nn.Module):
    """Gate/up projection, SiLU-and-multiply, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projections.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False, quant_config=quant_config)
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size,
            bias=False, quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


# ---------------------------------------------------------------------------
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
# ---------------------------------------------------------------------------

class Qwen3_5DecoderLayer(nn.Module):
    """Pre-norm decoder layer with either a DeltaNet or a full-attention mixer.

    ``layer_type == "linear_attention"`` selects ``GatedDeltaNet`` (stored as
    ``self.linear_attn``); any other value selects ``Qwen3_5FullAttention``
    (stored as ``self.self_attn``).
    """

    def __init__(
        self,
        text_cfg,
        layer_idx: int,
        layer_type: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.layer_type = layer_type
        self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
                                            eps=text_cfg.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            text_cfg.hidden_size, eps=text_cfg.rms_norm_eps)

        if layer_type == "linear_attention":
            self.linear_attn = GatedDeltaNet(text_cfg, layer_idx,
                                             quant_config=quant_config)
        else:
            self.self_attn = Qwen3_5FullAttention(
                text_cfg, layer_idx,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"layers.{layer_idx}.self_attn",
            )

        self.mlp = Qwen3_5MLP(
            hidden_size=text_cfg.hidden_size,
            intermediate_size=text_cfg.intermediate_size,
            hidden_act=text_cfg.hidden_act,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        # Only for linear_attention layers:
        conv_state: Optional[torch.Tensor] = None,
        temporal_state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(hidden_states, residual)`` after the mixer and the MLP.

        ``residual`` is ``None`` for the first layer, in which case the
        incoming ``hidden_states`` become the residual.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states,
                                                           residual)

        if self.layer_type == "linear_attention":
            hidden_states = self.linear_attn(
                hidden_states, attn_metadata, conv_state, temporal_state)
        else:
            hidden_states = self.self_attn(
                positions, hidden_states, kv_cache, attn_metadata)

        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


# ---------------------------------------------------------------------------
# Full transformer model
# ---------------------------------------------------------------------------

class Qwen3_5Model(nn.Module):
    """Token embedding, the stack of decoder layers, and the final norm."""

    def __init__(
        self,
        text_cfg,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.text_cfg = text_cfg
        self.embed_tokens = VocabParallelEmbedding(
            text_cfg.vocab_size, text_cfg.hidden_size)
        self.layers = nn.ModuleList([
            Qwen3_5DecoderLayer(
                text_cfg, i, text_cfg.layer_types[i],
                cache_config=cache_config,
                quant_config=quant_config)
            for i in range(text_cfg.num_hidden_layers)
        ])
        self.norm = GemmaRMSNorm(text_cfg.hidden_size,
                                 eps=text_cfg.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        conv_states: torch.Tensor,      # (num_linear_layers, batch, ...)
        temporal_states: torch.Tensor,  # (num_linear_layers, batch, ...)
    ) -> torch.Tensor:
        """Run all layers and return the normalised hidden states.

        ``kv_caches`` is indexed by the running count of full-attention
        layers, and ``conv_states`` / ``temporal_states`` by the running count
        of linear-attention layers.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        attn_idx = 0
        linear_idx = 0

        for layer in self.layers:
            if layer.layer_type == "linear_attention":
                hidden_states, residual = layer(
                    positions, hidden_states,
                    kv_cache=None,
                    attn_metadata=attn_metadata,
                    residual=residual,
                    conv_state=conv_states[linear_idx],
                    temporal_state=temporal_states[linear_idx],
                )
                linear_idx += 1
            else:
                hidden_states, residual = layer(
                    positions, hidden_states,
                    kv_cache=kv_caches[attn_idx],
                    attn_metadata=attn_metadata,
                    residual=residual,
                )
                attn_idx += 1

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


# ---------------------------------------------------------------------------
# Top-level CausalLM wrapper with MambaCacheManager
# ---------------------------------------------------------------------------

class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
    """Causal-LM wrapper that owns the per-sequence DeltaNet state cache."""

    has_inner_state = True
    supports_lora = True

    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    supported_lora_modules = [
        "gate_up_proj",
        "down_proj",
        "o_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config,  # Qwen3_5Config (top-level)
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the model from ``config.text_config``.

        ``lora_config`` and ``prefix`` are accepted but not used here.
        """
        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config

        # The text config holds all architecture parameters
        text_cfg = config.text_config
        self.text_cfg = text_cfg

        # Pre-compute counts
        self.num_linear_layers = sum(
            1 for lt in text_cfg.layer_types if lt == "linear_attention")
        self.num_attn_layers = sum(
            1 for lt in text_cfg.layer_types if lt == "full_attention")

        # DeltaNet state dimensions (per layer, per sequence); TP sharding is
        # applied in _get_mamba_cache_shape.
        self.conv_dim = (
            text_cfg.linear_num_key_heads * text_cfg.linear_key_head_dim * 2
            + text_cfg.linear_num_value_heads * text_cfg.linear_value_head_dim)
        self.num_v_heads = text_cfg.linear_num_value_heads
        self.head_k_dim = text_cfg.linear_key_head_dim
        self.head_v_dim = text_cfg.linear_value_head_dim
        self.conv_kernel_size = text_cfg.linear_conv_kernel_dim

        self.model = Qwen3_5Model(
            text_cfg,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.lm_head = ParallelLMHead(
            text_cfg.vocab_size,
            text_cfg.hidden_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(text_cfg.vocab_size)
        self.sampler = Sampler()

        # Lazy initialised in first forward call
        self.mamba_cache: Optional[MambaCacheManager] = None

    def _get_mamba_cache_shape(self):
        """Return the per-sequence ``(conv_state_shape, temporal_state_shape)``.

        Both shapes are for one layer on this tensor-parallel rank.
        """
        tp_size = get_tensor_model_parallel_world_size()
        # Each sequence's state is stored in float32
        conv_state_shape = (self.conv_dim // tp_size,
                            self.conv_kernel_size - 1)
        temporal_state_shape = (
            self.num_v_heads // tp_size, self.head_k_dim, self.head_v_dim)
        return conv_state_shape, temporal_state_shape

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return final hidden states; creates the state cache on first use.

        ``kwargs`` are forwarded to ``MambaCacheManager.current_run_tensors``.
        ``intermediate_tensors`` is accepted but not used.
        """
        if self.mamba_cache is None:
            if self.scheduler_config is not None:
                max_batch_size = _get_graph_batch_size(
                    self.scheduler_config.max_num_seqs)
            else:
                max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + 2
            self.mamba_cache = MambaCacheManager(
                torch.float32,
                self.num_linear_layers,
                max_batch_size,
                *self._get_mamba_cache_shape(),
            )

        # conv_states:     (num_linear_layers, batch, local_conv_dim, kernel-1)
        # temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim)
        conv_states, temporal_states = self.mamba_cache.current_run_tensors(
            input_ids, attn_metadata, **kwargs)

        return self.model(
            input_ids, positions, kv_caches, attn_metadata,
            conv_states, temporal_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the state cache (which must already exist)."""
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the state cache (which must already exist)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Vision and MTP tensors and rotary ``inv_freq`` buffers are skipped,
        the ``model.language_model.`` prefix is mapped to ``model.``,
        ``gate_proj`` / ``up_proj`` are loaded as shards 0 / 1 of
        ``gate_up_proj``, and any tensor without a matching parameter is
        ignored (logged at debug level).
        """
        stacked_params_mapping = [
            # (param_name, weight_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        language_prefix = "model.language_model."
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # Skip vision and MTP branches
            if name.startswith(("model.visual", "mtp.", "model.mtp")):
                continue

            # Remap checkpoint prefix -> module path
            #   Checkpoint: "model.language_model.{rest}" -> "model.{rest}"
            #   Checkpoint: "lm_head.weight" is already at top level.
            if name.startswith(language_prefix):
                name = "model." + name[len(language_prefix):]

            # Skip positional embedding caches
            if "rotary_emb.inv_freq" in name:
                continue

            # Remap conv1d.weight -> conv1d_weight (a bare nn.Parameter here)
            if ".linear_attn.conv1d.weight" in name:
                name = name.replace(".linear_attn.conv1d.weight",
                                    ".linear_attn.conv1d_weight")

            # Stacked param loading (gate_up_proj)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                stacked_name = name.replace(weight_name, param_name)
                param = params_dict.get(stacked_name)
                if param is None:
                    logger.debug("Skipping checkpoint tensor %s: no "
                                 "parameter %s", name, stacked_name)
                else:
                    param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict.get(name)
                if param is None:
                    logger.debug("Skipping checkpoint tensor %s: no "
                                 "matching parameter", name)
                    continue
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)