"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List import itertools import os import torch from torch import nn from transformers import PretrainedConfig from torch_vacc.vacc.custom_ops_cpu import ( w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu, ) from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) # from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.awq import AWQLinearMethod from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock, Qwen3MoeMLP from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method from ..ops.mrope_op import get_sin_cos_mrope from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize from .vars import * from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size logger = init_logger(__name__) # uniform the params names from different quantize method def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str): if isinstance(quant_method, UnquantizedLinearMethod): fused_params[name + '_weight'] = layer.weight fused_params[name + '_weight_scale'] = None fused_params[name + '_bias'] = None fused_params[name + '_qzeros'] = None if isinstance(quant_method, Fp8LinearMethod): fused_params[name + '_weight'] = layer.weight fused_params[name + '_weight_scale'] = layer.weight_scale_inv fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros elif isinstance(quant_method, GPTQLinearMethod): fused_params[name + '_weight'] = layer.qweight fused_params[name + '_weight_scale'] = layer.scales fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros elif isinstance(quant_method, AWQLinearMethod): fused_params[name + '_weight'] = layer.qweight fused_params[name + '_weight_scale'] = layer.scales fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') 

def apply_w8a8_block_fp8_linear_v2(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    input_scale = None
    # View input as a 2D matrix for the fp8 kernels.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    # Recover the quantization block size from the weight/scale shapes.
    block_size = [
        weight.shape[-2] // weight_scale.shape[-2],
        weight.shape[-1] // weight_scale.shape[-1],
    ]
    if input.device.type == "vacc":
        output = torch.vacc.w8a8_block_fp8_linear(
            input_2d, weight, input_scale, weight_scale, block_size
        )
    else:
        output = w8a8_block_fp8_linear_cpu(
            input_2d, weight, input_scale, weight_scale, block_size
        )
    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)


def vacc_fused_attn_qwen3_naive(
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale: torch.Tensor,
    qkv_proj_bias: Optional[torch.Tensor],
    qkv_proj_qzeros: Optional[torch.Tensor],
    q_layernorm_weight: torch.Tensor,
    k_layernorm_weight: torch.Tensor,
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale: torch.Tensor,
    o_proj_bias: Optional[torch.Tensor],
    o_proj_qzeros: Optional[torch.Tensor],
    seq_lens: List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    block_size: int = 16
):
    if residual is not None:
        hidden_states = hidden_states + residual
    residual_out = hidden_states
    hidden_states = torch.vacc.rms_norm(
        hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)
    # NOTE: for qwen3 and qwen2.5, head_dim is always 128
    head_dim = 128
    # QKV projection
    qkv = apply_w8a8_block_fp8_linear_v2(
        input=hidden_states,
        weight=qkv_proj_weight,
        weight_scale=qkv_proj_weight_scale)
    num_q_heads = num_attention_heads // world_size
    num_kv_heads = num_key_value_heads // world_size
    q_size = head_dim * num_q_heads
    kv_size = head_dim * num_kv_heads
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    # Add qk-norm
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    # q_by_head = self.q_norm.forward_native(q_by_head)
    q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)
    # q = q_by_head.view(q.shape)  # not needed: rope consumes the by-head layout
    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    # k_by_head = self.k_norm.forward_native(k_by_head)
    k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)
    # k = k_by_head.view(k.shape)  # not needed: rope consumes the by-head layout
    v = v.view(-1, num_kv_heads, head_dim)
    # q, k = self.rotary_emb(positions, q, k) is applied per sequence below.
    start = 0
    attn_outs = []
    if is_decode:
        # Convert vLLM block ids into coarser block-group indices.
        block_per_group = block_group_size // block_size
        block_tables = (block_tables // block_per_group).to(torch.int32)
        # logger.warning(f"decode block table: {block_tables}")
    num_blocks = kv_cache.shape[1]
    key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
    value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)
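    # Per-sequence loop: in prefill each sequence contributes seq_lens[i]
    # tokens, in decode exactly one token. Each sequence uses its own cos/sin
    # slice, writes its K/V into the paged cache, and runs attention
    # independently; the per-sequence outputs are concatenated afterwards.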
    # bs loop
    for i in range(len(seq_lens)):
        if not is_decode:
            # prefill
            end = start + seq_lens[i]
        else:
            # decode
            end = start + 1
        cos = cos_cache[i].unsqueeze(-2)
        sin = sin_cache[i].unsqueeze(-2)
        q, k = torch.vacc.RotaryPosEmbedding(
            q_norm[start:end, ...], k_norm[start:end, ...], cos, sin, 0, "neox")
        # cache concat
        torch.vacc.reshape_and_cache_attention(
            k, key_cache_split, slot_mapping[start:end, ...])
        torch.vacc.reshape_and_cache_attention(
            v[start:end, ...], value_cache_split, slot_mapping[start:end, ...])
        # attn_output = self.attn(q, k, v)
        if not is_decode:
            # prefill
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v[start:end, ...],
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True,  # causal_attn and not self.need_mask
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=sm_scale)
        else:
            # decode: gather this sequence's K/V block groups from the cache
            key_cache = key_cache_split.view(
                -1, block_group_size, num_kv_heads, head_dim)
            value_cache = value_cache_split.view(
                -1, block_group_size, num_kv_heads, head_dim)
            k_slices = key_cache[block_tables[i], ...]
            k_cached = torch.cat(
                [k_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            k_cached = k_cached.view(
                -1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
            v_slices = value_cache[block_tables[i], ...]
            v_cached = torch.cat(
                [v_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            v_cached = v_cached.view(
                -1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k_cached,
                value=v_cached,
                attn_mask=None,
                dropout_p=0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,  # flash_attention
                sm_scale=sm_scale)
        attn_outs.append(attn_out)
        # update start
        start = end
    attn_out = torch.cat(attn_outs, dim=0)
    # output, _ = self.o_proj(attn_output)
    o_proj = apply_w8a8_block_fp8_linear_v2(
        input=attn_out.reshape(hidden_states.shape[0], -1),
        weight=o_proj_weight,
        weight_scale=o_proj_weight_scale,
    )
    if reduce_result:
        o_proj = tensor_model_parallel_all_reduce(o_proj)
    if residual is not None:
        return o_proj, residual_out
    return o_proj
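
# The function below is a drop-in replacement for
# Qwen3MoeSparseMoeBlock.__init__ (note the explicit
# super(Qwen3MoeSparseMoeBlock, self) call); it is presumably monkey-patched
# onto the vLLM class elsewhere in this package. Beyond the stock setup it
# re-arranges the w2 weight/scale layout for block-quantized checkpoints.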
def Qwen3MoeSparseMoeBlock__init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    super(Qwen3MoeSparseMoeBlock, self).__init__()
    config = vllm_config.model_config.hf_text_config
    parallel_config = vllm_config.parallel_config
    quant_config = vllm_config.quant_config

    self.tp_size = get_tensor_model_parallel_world_size()
    self.ep_group = get_ep_group().device_group
    self.ep_rank = self.ep_group.rank()
    self.ep_size = self.ep_group.size()
    self.n_routed_experts = config.num_experts
    self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

    if self.tp_size > config.num_experts:
        raise ValueError(
            f"Tensor parallel size {self.tp_size} is greater than "
            f"the number of experts {config.num_experts}.")

    # Load balancing settings.
    vllm_config = get_current_vllm_config()
    eplb_config = vllm_config.parallel_config.eplb_config
    self.enable_eplb = parallel_config.enable_eplb

    self.n_logical_experts = self.n_routed_experts
    self.n_redundant_experts = eplb_config.num_redundant_experts
    self.n_physical_experts = (self.n_logical_experts +
                               self.n_redundant_experts)
    self.n_local_physical_experts = self.n_physical_experts // self.ep_size

    self.physical_expert_start = (self.ep_rank *
                                  self.n_local_physical_experts)
    self.physical_expert_end = (self.physical_expert_start +
                                self.n_local_physical_experts)

    self.experts = FusedMoE(num_experts=self.n_routed_experts,
                            top_k=config.num_experts_per_tok,
                            hidden_size=config.hidden_size,
                            intermediate_size=config.moe_intermediate_size,
                            reduce_results=True,
                            renormalize=config.norm_topk_prob,
                            quant_config=quant_config,
                            prefix=f"{prefix}.experts",
                            enable_eplb=self.enable_eplb,
                            num_redundant_experts=self.n_redundant_experts,
                            is_sequence_parallel=self.is_sequence_parallel)

    self.gate = ReplicatedLinear(config.hidden_size,
                                 config.num_experts,
                                 bias=False,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.gate")

    # Patch: transpose the data layout of w2 / w2_scale (block-quantized
    # models only).
    if hasattr(self.experts.quant_method, 'quant_config') and \
            hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
        self.experts.w2_weight.data = self.experts.w2_weight.data \
            .transpose(-1, -2).contiguous().transpose(-1, -2)
        self.experts.w2_weight_scale_inv.data = \
            self.experts.w2_weight_scale_inv.data \
            .transpose(-1, -2).contiguous().transpose(-1, -2)
        if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
            self.experts.w2_weight_scale_inv_prefill.data = \
                self.experts.w2_weight_scale_inv_prefill.data \
                .transpose(-1, -2).contiguous().transpose(-1, -2)


def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
                      attn_metadata: Union["AttentionMetadata",
                                           dict[str, "AttentionMetadata"]],
                      positions: Union[torch.Tensor, list],
                      is_decode: bool):
    if isinstance(rotary_emb, MRotaryEmbedding):
        # get mrope sin/cos
        cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
        if len(attn_metadata.seq_lens) > 1:
            if is_decode:
                cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
                sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
            else:
                cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
                sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
        else:
            cos_cache = [cos_cache]
            sin_cache = [sin_cache]
    else:
        if is_decode:
            # The current position of each sequence is its length - 1.
            positions = [i - 1 for i in attn_metadata.seq_lens]
            cos_cache = [rotary_emb.cos_cache[i:i + 1, ...] for i in positions]
            sin_cache = [rotary_emb.sin_cache[i:i + 1, ...] for i in positions]
        else:
            cos_cache = [rotary_emb.cos_cache[:i, ...]
                         for i in attn_metadata.seq_lens]
            sin_cache = [rotary_emb.sin_cache[:i, ...]
                         for i in attn_metadata.seq_lens]
    return cos_cache, sin_cache
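
# Shape sketch (illustrative): for two sequences with seq_lens=[5, 3] in
# prefill, get_cos_sin_cache returns lists of per-sequence tensors covering 5
# and 3 positions; in decode it returns one position per sequence (the slice
# at seq_len - 1). The fused attention entry points index these lists by
# batch position.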

class Qwen3MoeDecoderLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        cos_cache: list[torch.Tensor],
        sin_cache: list[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Self Attention
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        # NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
        if USE_FUSED_QWEN_ATTENTION:
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = \
                    self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = \
                    self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = \
                    self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params,
                                 self.self_attn.qkv_proj.quant_method,
                                 self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params,
                                 self.self_attn.o_proj.quant_method,
                                 self.self_attn.o_proj, 'o_proj')
            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                cos_cache=cos_cache,
                sin_cache=sin_cache)
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                cos_cache=cos_cache,
                sin_cache=sin_cache,
            )

        # # Fully Connected
        # hidden_states, residual = self.post_attention_layernorm(
        #     hidden_states, residual)
        # hidden_states = self.mlp(hidden_states)
        # return hidden_states, residual

        # TODO: handle unquantized and non-block-quantized models
        if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
                not hasattr(self.mlp.experts.quant_method.quant_config,
                            'weight_block_size'):
            if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
                logger.warning('TODO for noquant or other quant')
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual)
                hidden_states = self.mlp(hidden_states)
                return hidden_states, residual

        if isinstance(attn_metadata, dict):
            attn_metadata_0 = next(iter(attn_metadata.values()))
            is_prefill = attn_metadata_0.prefill_metadata
        else:
            is_prefill = attn_metadata.prefill_metadata

        quant_method = self.mlp.experts.quant_method \
            if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
            else self.mlp.down_proj.quant_method
        if is_prefill is not None:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import \
                        vacc_fused_prefill_moe_gptq_int4
                    return vacc_fused_prefill_moe_gptq_int4(
                        hidden_states, residual,
                        self.post_attention_layernorm,
                        self.mlp.gate, self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_prefill_moe_gptq_int4 fail: %s', e)
            else:
                recompute_moe_layer_blocksize(self.mlp.experts)
                try:
                    return vacc_fused_prefill_moe_fp8(
                        hidden_states, residual,
                        self.post_attention_layernorm,
                        self.mlp.gate, self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_prefill_moe_fp8 fail: %s', e)
        else:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import \
                        vacc_fused_decode_moe_gptq_int4
                    return vacc_fused_decode_moe_gptq_int4(
                        hidden_states, residual,
                        self.post_attention_layernorm,
                        self.mlp.gate, self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_decode_moe_gptq_int4 fail: %s', e)
            else:
                try:
                    return vacc_fused_decode_moe_fp8(
                        hidden_states, residual,
                        self.post_attention_layernorm,
                        self.mlp.gate, self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_decode_moe_fp8 fail: %s', e)

        # Fallback: reference (non-fused) MoE path.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
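
# NOTE: when USE_FUSED_QWEN_ATTENTION is set, the forward below runs
# input_layernorm + QKV projection + qk-norm + RoPE + paged attention +
# o_proj in a single torch.vacc.fuse_atten_qwen3 custom op; otherwise it
# follows the stock vLLM attention path.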

class Qwen3MoeAttention(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        # new added params
        cos_cache: Optional[list[torch.Tensor]] = None,
        sin_cache: Optional[list[torch.Tensor]] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        # Reshape the KV cache; 16 tokens per vLLM cache block is assumed here.
        num_kv_heads = max(
            1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
        else:
            attn_metadata = attn_metadata_all
        is_decode = attn_metadata.prefill_metadata is None

        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() \
        #     * get_tp_group().world_size
        # # only support 4M now
        # if total_bytes < 4194304:
        #     reduce_result = True

        if USE_FUSED_QWEN_ATTENTION:
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = get_cos_sin_cache(
                    self.rotary_emb, attn_metadata, positions, is_decode)
            if residual is None:
                res_out = hidden_states
            # from torch_vacc.vacc import fuse_atten_qwen3
            attn_outs = None
            if not is_decode:
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import \
                    memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER
            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size
            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode uses flash attention by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)
            # debug_qwen3_moe_attention_prefill(
            #     hidden_states=hidden_states,
            #     residual=residual,
            #     attn_outs=attn_outs,
            #     fused_params=self.fused_params,
            #     attn_metadata=attn_metadata,
            #     is_decode=is_decode,
            #     sin_cache=sin_cache,
            #     cos_cache=cos_cache,
            #     kv_cache=kv_cache,
            #     env_blk_grp_size=env_blk_grp_size,
            #     scaling=self.scaling,
            #     total_num_heads=self.total_num_heads,
            #     total_num_kv_heads=self.total_num_kv_heads,
            #     world_size=get_tp_group().world_size,
            #     rank=get_tp_group().rank_in_group,
            #     group_id=get_tp_group().group_id,
            #     dev_info=get_tp_group().rank_device_infos)
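            # The custom op already all-reduces in decode (reduce_result=True);
            # in prefill the reduction is deferred and performed here instead.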
            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) \
                    if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) \
                    if not reduce_result else attn_outs[0]
            return attn_out, res_out
        else:
            # orig code
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
            # Add qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm.forward_native(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm.forward_native(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output


class Qwen3MoeModel(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import \
                Qwen3Moe_WeightCapture
            self.weight_capture = Qwen3Moe_WeightCapture(
                self.layers, self.start_layer, self.end_layer)
            self.layer_nums = self.end_layer - self.start_layer

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # The fused decoder layer currently supports fp8-quantized models only.
        use_fused_decoder_layer = (self.weight_capture.support_fused_weights
                                   and USE_DECODER_LAYER_FUSE_MODE)
        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
        else:
            attn_metadata = attn_metadata_all
        is_decode = attn_metadata.prefill_metadata is None

        if use_fused_decoder_layer and is_decode:
            from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(
                layer0.self_attn.rotary_emb, attn_metadata, positions,
                is_decode=True)
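            # Decode fast path: each iteration below executes attention + MoE
            # for one layer in a single fused custom op, reading weights from
            # the tensors pre-captured by Qwen3Moe_WeightCapture.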
            for i in range(0, self.layer_nums):
                layer = self.layers[i + self.start_layer]
                kv_cache = layer.self_attn.attn.kv_cache[
                    forward_context.virtual_engine]
                num_kv_heads = max(
                    1,
                    layer.self_attn.total_num_kv_heads //
                    get_tp_group().world_size)
                kv_cache = kv_cache.view(2, -1, 16, num_kv_heads,
                                         layer.self_attn.head_dim)
                total_num_kv_heads = layer.self_attn.total_num_kv_heads
                if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
                    assert get_tp_group().world_size % \
                        layer.self_attn.total_num_kv_heads == 0
                    total_num_kv_heads = get_tp_group().world_size
                hidden_states, residual = qwen3_fuse_attention_moe_decode(
                    hidden_states,
                    residual,
                    hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
                    qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
                    qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
                    qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
                    qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
                    q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
                    k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
                    o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
                    o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
                    o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
                    seq_lens_num=attn_metadata.seq_lens,
                    sm_scale=layer.self_attn.scaling,
                    num_attention_heads=layer.self_attn.total_num_heads,
                    num_key_value_heads=total_num_kv_heads,
                    # NOTE: keyword spelling follows the custom op's signature.
                    flash_attentiton=True,
                    is_decode=True,
                    reduce_result=True,
                    # moe
                    rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
                    moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
                    moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
                    moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
                    moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
                    gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
                    block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
                    block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
                    # dist
                    world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
                    rank=self.weight_capture.layer_mapper.dist_args._1_rank,
                    group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
                    dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
        else:
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(
                layer0.self_attn.rotary_emb, attn_metadata, positions,
                is_decode)
            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states,
                                                residual, cos_cache, sin_cache)
                if deepstack_input_embeds is not None and \
                        i < len(deepstack_input_embeds):
                    if isinstance(deepstack_input_embeds, IntermediateTensors):
                        hidden_states = hidden_states + deepstack_input_embeds[
                            f"deepstack_input_embeds_{i}"]
                    elif isinstance(deepstack_input_embeds, torch.Tensor):
                        hidden_states = hidden_states + deepstack_input_embeds[i]
                    else:
                        raise ValueError(
                            f'unsupported type: {type(deepstack_input_embeds)}')

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states, residual)
        return hidden_states
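
# Before a prefill run, the forward below sizes the shared memory recycler
# for the number of prefill tokens in this batch; allocation failure only
# costs performance (a warning is logged), not correctness.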
class Qwen3MoeForCausalLM(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds=None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
        if attn_metadata.prefill_metadata is not None:
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens
            vllm_model_mode = "qwen3_moe"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                vllm_model_mode = config_infos
            if get_tp_group().rank_in_group == 0:
                logger.info('[MemoryRecycler] enable: %s', vllm_model_mode)
            if not alloc_memory_recycler(tokens,
                                         vllm_model=vllm_model_mode,
                                         world_size=get_tp_group().world_size,
                                         dtype=self.lm_head.weight.dtype):
                logger.warning(
                    "memory recycler allocation failed; the current request "
                    "may run inefficiently (%s prefill tokens)", tokens)
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, deepstack_input_embeds)
        return hidden_states

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
        # Default model name is "qwen3_moe"; the config manager may override it.
        model_name = "qwen3_moe"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            model_name = config_infos
        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN,
                                          self.config.hidden_size,
                                          vllm_model=model_name):
            logger.warning("init huge memory allocator failed; prefill memory "
                           "recycling will be disabled")
        from vllm.model_executor.models.utils import AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)