"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
|
|
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List
|
|
import itertools
|
|
import os
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
from torch_vacc.vacc.custom_ops_cpu import (
|
|
w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu,
|
|
)
|
|
|
|
from vllm.attention import Attention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
|
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
|
|
get_tensor_model_parallel_world_size,
|
|
get_tensor_model_parallel_rank,
|
|
tensor_model_parallel_all_reduce)
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
from vllm.logger import init_logger
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear)
|
|
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
|
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
|
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
|
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
|
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock, Qwen3MoeMLP
|
|
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
|
|
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
|
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
|
|
|
|
from ..ops.mrope_op import get_sin_cos_mrope
|
|
from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize
|
|
from .vars import *
|
|
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
|
|
|
logger = init_logger(__name__)


# Unify the parameter names coming from the different quantization methods so
# the fused attention kernels can consume them through a single dict.
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    if isinstance(quant_method, UnquantizedLinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = None
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, GPTQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, AWQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")


def apply_w8a8_block_fp8_linear_v2(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    input_scale = None
    # View input as a 2D matrix for the fp8 kernels.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    # One scale per (block_n, block_k) tile of the weight.
    block_size = [
        weight.shape[-2] // weight_scale.shape[-2],
        weight.shape[-1] // weight_scale.shape[-1],
    ]

    if input.device.type == "vacc":
        output = torch.vacc.w8a8_block_fp8_linear(
            input_2d, weight, input_scale, weight_scale, block_size
        )
    else:
        output = w8a8_block_fp8_linear_cpu(
            input_2d, weight, input_scale, weight_scale, block_size
        )

    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)
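

# Illustrative, CPU-only reference of what a block-wise W8A8/FP8 GEMM computes,
# assuming DeepSeek-style per-(block_n, block_k) weight scales. The exact
# rounding/accumulation behaviour of torch.vacc.w8a8_block_fp8_linear is
# hardware specific; this sketch only documents the weight/scale shape
# relationship used by apply_w8a8_block_fp8_linear_v2 above. It is not called
# anywhere in this module.
def _reference_block_fp8_linear(input_2d: torch.Tensor,
                                weight: torch.Tensor,
                                weight_scale: torch.Tensor,
                                block_size: List[int]) -> torch.Tensor:
    block_n, block_k = block_size
    # Expand every per-block scale to the elements it covers, then crop to the
    # real (N, K) weight shape (the last block along each dim may be partial).
    scale_full = weight_scale.repeat_interleave(block_n, dim=-2) \
                             .repeat_interleave(block_k, dim=-1)
    scale_full = scale_full[..., :weight.shape[-2], :weight.shape[-1]]
    w_dequant = weight.to(torch.float32) * scale_full.to(torch.float32)
    # y = x @ W^T on the dequantized weight.
    return input_2d.to(torch.float32) @ w_dequant.transpose(-1, -2)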


def vacc_fused_attn_qwen3_naive(
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale: torch.Tensor,
    qkv_proj_bias: Optional[torch.Tensor],
    qkv_proj_qzeros: Optional[torch.Tensor],
    q_layernorm_weight: torch.Tensor,
    k_layernorm_weight: torch.Tensor,
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale: torch.Tensor,
    o_proj_bias: Optional[torch.Tensor],
    o_proj_qzeros: Optional[torch.Tensor],
    seq_lens: List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    block_size: int = 16
):
    if residual is not None:
        hidden_states = hidden_states + residual
    residual_out = hidden_states

    hidden_states = torch.vacc.rms_norm(
        hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)

    # NOTE: for qwen3 and qwen2.5, head_dim is always 128
    head_dim = 128

    # QKV projection
    qkv = apply_w8a8_block_fp8_linear_v2(
        input=hidden_states,
        weight=qkv_proj_weight,
        weight_scale=qkv_proj_weight_scale)

    num_q_heads = num_attention_heads // world_size
    num_kv_heads = num_key_value_heads // world_size

    q_size = head_dim * num_q_heads
    kv_size = head_dim * num_kv_heads

    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # qk-norm (per-head RMSNorm)
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    # q_by_head = self.q_norm.forward_native(q_by_head)
    q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)
    # q = q_by_head.view(q.shape)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    # k_by_head = k_norm.forward_native(k_by_head)
    k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)
    # k = k_by_head.view(k.shape)

    v = v.view(-1, num_kv_heads, head_dim)

    # q, k = self.rotary_emb(positions, q, k)
    start = 0
    attn_outs = []

    if is_decode:
        # Convert block_tables from 16-token block ids to block_group_size
        # (e.g. 8K-token) group indices.
        block_per_group = block_group_size // block_size
        block_tables = (block_tables // block_per_group).to(torch.int32)
        # logger.warning(f"decode block table: {block_tables}")

    num_blocks = kv_cache.shape[1]
    key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
    value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)

    # Batch loop
    for i in range(len(seq_lens)):
        if not is_decode:
            # prefill
            end = start + seq_lens[i]
        else:
            # decode
            end = start + 1

        cos = cos_cache[i].unsqueeze(-2)
        sin = sin_cache[i].unsqueeze(-2)

        q, k = torch.vacc.RotaryPosEmbedding(
            q_norm[start:end, ...], k_norm[start:end, ...], cos, sin, 0, "neox")

        # Write the new keys/values into the paged cache.
        torch.vacc.reshape_and_cache_attention(k, key_cache_split, slot_mapping[start:end, ...])
        torch.vacc.reshape_and_cache_attention(v[start:end, ...], value_cache_split, slot_mapping[start:end, ...])

        # attn_output = self.attn(q, k, v)
        if not is_decode:
            # prefill
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v[start:end, ...],
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True,  # causal_attn and not self.need_mask
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=sm_scale)
        else:
            # decode: gather this sequence's cached K/V from the grouped cache
            key_cache = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            value_cache = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)

            k_slices = key_cache[block_tables[i], ...]
            k_cached = torch.cat(
                [k_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            k_cached = k_cached.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]

            v_slices = value_cache[block_tables[i], ...]
            v_cached = torch.cat(
                [v_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            v_cached = v_cached.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k_cached,
                value=v_cached,
                attn_mask=None,
                dropout_p=0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,  # flash_attention
                sm_scale=sm_scale)

        attn_outs.append(attn_out)
        # update start
        start = end
    attn_out = torch.cat(attn_outs, dim=0)

    # output, _ = self.o_proj(attn_output)
    o_proj = apply_w8a8_block_fp8_linear_v2(
        input=attn_out.reshape(hidden_states.shape[0], -1),
        weight=o_proj_weight,
        weight_scale=o_proj_weight_scale,
    )

    if reduce_result:
        o_proj = tensor_model_parallel_all_reduce(o_proj)

    if residual is not None:
        return o_proj, residual_out
    return o_proj
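

# How the decode path above regroups the paged-KV block table: vLLM hands us
# per-16-token block ids, while the vacc cache is addressed in larger groups of
# block_group_size tokens (the "8K group index" in the comment above). A
# minimal sketch of that index arithmetic, with illustrative values only:
#
#   block_size = 16, block_group_size = 8192  ->  block_per_group = 512
#   block_tables = torch.tensor([[0, 1, 2, 511, 512, 1023]])
#   block_tables // block_per_group  ->  tensor([[0, 0, 0, 0, 1, 1]])
#
# i.e. every 512 consecutive 16-token blocks collapse into one group index into
# the reshaped key/value cache.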


def Qwen3MoeSparseMoeBlock__init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    super(Qwen3MoeSparseMoeBlock, self).__init__()
    config = vllm_config.model_config.hf_text_config
    parallel_config = vllm_config.parallel_config
    quant_config = vllm_config.quant_config

    self.tp_size = get_tensor_model_parallel_world_size()

    self.ep_group = get_ep_group().device_group
    self.ep_rank = self.ep_group.rank()
    self.ep_size = self.ep_group.size()
    self.n_routed_experts = config.num_experts

    self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

    if self.tp_size > config.num_experts:
        raise ValueError(
            f"Tensor parallel size {self.tp_size} is greater than "
            f"the number of experts {config.num_experts}.")

    # Load balancing settings.
    vllm_config = get_current_vllm_config()
    eplb_config = vllm_config.parallel_config.eplb_config
    self.enable_eplb = parallel_config.enable_eplb

    self.n_logical_experts = self.n_routed_experts
    self.n_redundant_experts = eplb_config.num_redundant_experts
    self.n_physical_experts = (self.n_logical_experts +
                               self.n_redundant_experts)
    self.n_local_physical_experts = self.n_physical_experts // self.ep_size

    self.physical_expert_start = (self.ep_rank *
                                  self.n_local_physical_experts)
    self.physical_expert_end = (self.physical_expert_start +
                                self.n_local_physical_experts)

    self.experts = FusedMoE(num_experts=self.n_routed_experts,
                            top_k=config.num_experts_per_tok,
                            hidden_size=config.hidden_size,
                            intermediate_size=config.moe_intermediate_size,
                            reduce_results=True,
                            renormalize=config.norm_topk_prob,
                            quant_config=quant_config,
                            prefix=f"{prefix}.experts",
                            enable_eplb=self.enable_eplb,
                            num_redundant_experts=self.n_redundant_experts,
                            is_sequence_parallel=self.is_sequence_parallel)

    self.gate = ReplicatedLinear(config.hidden_size,
                                 config.num_experts,
                                 bias=False,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.gate")

    # Patch: re-arrange the storage of w2 / w2_scale into a transposed memory
    # layout (same logical shape); only needed for block-quantized weights.
    if hasattr(self.experts.quant_method, 'quant_config') and hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
        self.experts.w2_weight.data = self.experts.w2_weight.data.transpose(-1, -2).contiguous().transpose(-1, -2)
        self.experts.w2_weight_scale_inv.data = self.experts.w2_weight_scale_inv.data.transpose(-1, -2).contiguous().transpose(-1, -2)
        if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
            self.experts.w2_weight_scale_inv_prefill.data = self.experts.w2_weight_scale_inv_prefill.data.transpose(-1, -2).contiguous().transpose(-1, -2)
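

# What the transpose/contiguous/transpose patch above achieves: the logical
# shape of w2 is unchanged, only the underlying storage becomes column-major,
# presumably the stride order the vacc MoE kernels prefer. A quick sketch
# (shapes are arbitrary, just to show the stride change):
#
#   w = torch.randn(4, 8)                                    # strides (8, 1)
#   w_cm = w.transpose(-1, -2).contiguous().transpose(-1, -2)
#   w_cm.shape    ->  torch.Size([4, 8])                     # same logical shape
#   w_cm.stride() ->  (1, 4)                                 # column-major storage
#   torch.equal(w, w_cm)  ->  True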


def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
                      attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
                      positions: Union[torch.Tensor, list],
                      is_decode: bool):
    if isinstance(rotary_emb, MRotaryEmbedding):
        # mrope: sin/cos come from the (multimodal) positions directly.
        cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
        if len(attn_metadata.seq_lens) > 1:
            if is_decode:
                # One new position per sequence: split evenly across the batch.
                cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
                sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
            else:
                # Prefill: split by each sequence's prompt length.
                cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
                sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
        else:
            cos_cache = [cos_cache]
            sin_cache = [sin_cache]
    else:
        if is_decode:
            # Decode only needs the entry for the last position of each sequence.
            positions = [i - 1 for i in attn_metadata.seq_lens]
            cos_cache = [rotary_emb.cos_cache[i:i + 1, ...] for i in positions]
            sin_cache = [rotary_emb.sin_cache[i:i + 1, ...] for i in positions]
        else:
            cos_cache = [rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
            sin_cache = [rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
    return cos_cache, sin_cache
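

# The per-sequence slicing used by get_cos_sin_cache, in isolation (dummy
# shapes, not the real rope tables): decode has exactly one new token per
# sequence, so an even torch.chunk suffices, while prefill must honour each
# prompt length via torch.split.
#
#   cache = torch.arange(12).reshape(6, 2)           # 6 positions
#   torch.split(cache, [4, 2])                       # prefill -> pieces of 4 and 2 rows
#   torch.chunk(torch.arange(4).reshape(2, 2), 2)    # decode, 2 sequences -> 1 row each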


class Qwen3MoeDecoderLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        cos_cache: list[torch.Tensor],
        sin_cache: list[torch.Tensor]
    ) -> torch.Tensor:
        # Self Attention
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        # NOTE: input_layernorm is fused into vacc_fused_attn_qwen3.
        if USE_FUSED_QWEN_ATTENTION:
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')

            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                cos_cache=cos_cache,
                sin_cache=sin_cache)
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                cos_cache=cos_cache,
                sin_cache=sin_cache
            )

        # # Fully Connected
        # hidden_states, residual = self.post_attention_layernorm(
        #     hidden_states, residual)
        # hidden_states = self.mlp(hidden_states)
        # return hidden_states, residual

        # TODO: unquantized and non-block-quantized experts fall back to the
        # stock MLP path for now.
        if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
                not hasattr(self.mlp.experts.quant_method.quant_config, 'weight_block_size'):
            if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
                logger.warning('TODO for noquant or other quant')
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual)
                hidden_states = self.mlp(hidden_states)
                return hidden_states, residual

        if isinstance(attn_metadata, dict):
            # Take the metadata of the first (and only) KV-cache group.
            # is_prefill = get_forward_context().attn_metadata['test'].prefill_metadata
            attn_metadata_0 = next(iter(attn_metadata.values()))
            is_prefill = attn_metadata_0.prefill_metadata
        else:
            is_prefill = attn_metadata.prefill_metadata

        quant_method = self.mlp.experts.quant_method if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
            else self.mlp.down_proj.quant_method

        if is_prefill is not None:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_prefill_moe_gptq_int4
                    return vacc_fused_prefill_moe_gptq_int4(hidden_states,
                                                            residual,
                                                            self.post_attention_layernorm,
                                                            self.mlp.gate,
                                                            self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_prefill_moe_gptq_int4 fail: %s', e)
            else:
                recompute_moe_layer_blocksize(self.mlp.experts)
                try:
                    return vacc_fused_prefill_moe_fp8(hidden_states,
                                                      residual,
                                                      self.post_attention_layernorm,
                                                      self.mlp.gate,
                                                      self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_prefill_moe_fp8 fail: %s', e)
        else:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_decode_moe_gptq_int4
                    return vacc_fused_decode_moe_gptq_int4(hidden_states,
                                                           residual,
                                                           self.post_attention_layernorm,
                                                           self.mlp.gate,
                                                           self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_decode_moe_gptq_int4 fail: %s', e)
            else:
                try:
                    return vacc_fused_decode_moe_fp8(hidden_states,
                                                     residual,
                                                     self.post_attention_layernorm,
                                                     self.mlp.gate,
                                                     self.mlp.experts)
                except Exception as e:
                    logger.warning('vacc_fused_decode_moe_fp8 fail: %s', e)

        # Fallback: unfused post-attention layernorm + MoE MLP.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
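

# Dispatch summary for the MoE block above (assuming the fused vacc ops are
# importable; any exception falls back to the unfused layernorm + mlp path):
#
#   phase    | fp8 block-quant               | GPTQ int4 (MoeWNA16Method)
#   ---------+-------------------------------+----------------------------------
#   prefill  | vacc_fused_prefill_moe_fp8    | vacc_fused_prefill_moe_gptq_int4
#   decode   | vacc_fused_decode_moe_fp8     | vacc_fused_decode_moe_gptq_int4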


class Qwen3MoeAttention(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,  # newly added parameter
        cos_cache: list[torch.Tensor] = None,
        sin_cache: list[torch.Tensor] = None,
    ) -> torch.Tensor:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]

        # Reshape the paged KV cache to (2, num_blocks, block_size=16, kv_heads, head_dim).
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
        # # only support 4M now
        # if total_bytes < 4194304:
        #     reduce_result = True

        if USE_FUSED_QWEN_ATTENTION:
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)

            if residual is None:
                res_out = hidden_states
            # from torch_vacc.vacc import fuse_atten_qwen3
            attn_outs = None
            if not is_decode:
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                # KV heads are replicated when TP is wider than the KV-head count.
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size
            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode uses flash attention by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)
            # debug_qwen3_moe_attention_prefill(hidden_states=hidden_states,
            #                                   residual=residual,
            #                                   attn_outs=attn_outs,
            #                                   fused_params=self.fused_params,
            #                                   attn_metadata=attn_metadata,
            #                                   is_decode=is_decode,
            #                                   sin_cache=sin_cache,
            #                                   cos_cache=cos_cache,
            #                                   kv_cache=kv_cache,
            #                                   env_blk_grp_size=env_blk_grp_size,
            #                                   scaling=self.scaling,
            #                                   total_num_heads=self.total_num_heads,
            #                                   total_num_kv_heads=self.total_num_kv_heads,
            #                                   world_size=get_tp_group().world_size,
            #                                   rank=get_tp_group().rank_in_group,
            #                                   group_id=get_tp_group().group_id,
            #                                   dev_info=get_tp_group().rank_device_infos)

            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]

            return attn_out, res_out
        else:
            # original (unfused) path
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm.forward_native(q_by_head)
            q = q_by_head.view(q.shape)

            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm.forward_native(k_by_head)
            k = k_by_head.view(k.shape)

            q, k = self.rotary_emb(positions, q, k)

            attn_output = self.attn(q, k, v)

            output, _ = self.o_proj(attn_output)
            return output


class Qwen3MoeModel(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import Qwen3Moe_WeightCapture
            self.weight_capture = Qwen3Moe_WeightCapture(self.layers, self.start_layer, self.end_layer)
            self.layer_nums = self.end_layer - self.start_layer

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # The fused decoder layer currently only supports fp8 block-quantized models.
        use_fused_decoder_layer = self.weight_capture.support_fused_weights and USE_DECODER_LAYER_FUSE_MODE
        # print('Qwen3MoeModel attn_metadata', attn_metadata)
        if isinstance(attn_metadata_all, dict):
            # is_decode = attn_metadata_all['test'].prefill_metadata is None
            # attn_metadata = attn_metadata_all['test']
            attn_metadata = next(iter(attn_metadata_all.values()))
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        if use_fused_decoder_layer and is_decode:
            from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode

            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode=True)

            for i in range(0, self.layer_nums):
                layer = self.layers[i + self.start_layer]
                kv_cache = layer.self_attn.attn.kv_cache[forward_context.virtual_engine]
                num_kv_heads = max(1, layer.self_attn.total_num_kv_heads // get_tp_group().world_size)
                kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, layer.self_attn.head_dim)
                total_num_kv_heads = layer.self_attn.total_num_kv_heads
                if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
                    assert get_tp_group().world_size % layer.self_attn.total_num_kv_heads == 0
                    total_num_kv_heads = get_tp_group().world_size

                hidden_states, residual = qwen3_fuse_attention_moe_decode(
                    hidden_states, residual,
                    hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
                    qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
                    qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
                    qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
                    qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
                    q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
                    k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
                    o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
                    o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
                    o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
                    seq_lens_num=attn_metadata.seq_lens,
                    sm_scale=layer.self_attn.scaling,
                    num_attention_heads=layer.self_attn.total_num_heads,
                    num_key_value_heads=total_num_kv_heads,
                    flash_attentiton=True,
                    is_decode=True,
                    reduce_result=True,
                    # moe
                    rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
                    moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
                    moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
                    moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
                    moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
                    gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
                    block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
                    block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
                    # dist
                    world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
                    rank=self.weight_capture.layer_mapper.dist_args._1_rank,
                    group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
                    dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
        else:
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode)

            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states, residual, cos_cache, sin_cache)
                if deepstack_input_embeds is not None and i < len(deepstack_input_embeds):
                    if isinstance(deepstack_input_embeds, IntermediateTensors):
                        hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{i}"]
                    elif isinstance(deepstack_input_embeds, torch.Tensor):
                        hidden_states = hidden_states + deepstack_input_embeds[i]
                    else:
                        raise ValueError(f'unsupported type: {type(deepstack_input_embeds)}')

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states, residual)
        return hidden_states


class Qwen3MoeForCausalLM(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds=None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
        if attn_metadata.prefill_metadata is not None:
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens

            vllm_model_mode = "qwen3_moe"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                vllm_model_mode = config_infos

            if get_tp_group().rank_in_group == 0:
                memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
                logger.info(memory_infos)

            if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size, dtype=self.lm_head.weight.dtype):
                logger.warning("memory recycler alloc failed; the current request may run inefficiently (%s prefill tokens)", tokens)

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, deepstack_input_embeds)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:

        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

        # Default model tag is "qwen3_moe"; the vacc config manager's
        # model_infos setting may override it.
        model_name = "qwen3_moe"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            model_name = config_infos

        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
            logger.warning("init_huge_memory_allocator failed; prefill memory recycling will be disabled")

        from vllm.model_executor.models.utils import AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)