2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions


@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *
class BertLayer(nn.Module):
def forward(self, hidden_states: torch.Tensor):
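        # Two paths: the vacc fused-op path below runs the whole BERT layer
        # (attention, add & norm, FFN) through torch.vacc kernels; otherwise the
        # reference attention/intermediate/output modules are used.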
if USE_FUSED_BERT_ATTENTION:
tp_group = get_tp_group()
world_size = tp_group.world_size
rank = tp_group.rank_in_group
total_bytes = hidden_states.numel() * hidden_states.element_size() * world_size
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
if isinstance(attn_metadata_all, dict):
                attn_metadata = next(iter(attn_metadata_all.values()))
else:
attn_metadata = attn_metadata_all
            # In the (matmul + bias_add) with TP + all_reduce structure, the bias is dispatched only on rank 0 so it is not added repeatedly.
            # In a BERT layer this structure appears in the BertSelfOutput and BertOutput modules; the corresponding bias parameters are self_bias and output_bias below.
if total_bytes < 4194304 or world_size == 1:
                # 1. With TP, when the all_reduce input is smaller than 4 MB the fused op below performs the all_reduce via dsp; at 4 MB or larger, a limitation requires calling the vccl all_reduce outside the op.
                # 2. Without TP, the fused op below is also used.
output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
qkv_weight=self.attention.self.qkv_proj.weight,
qkv_bias=self.attention.self.qkv_proj.bias,
self_weight=self.attention.output.dense.weight,
self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
self_norm_weight=self.attention.output.LayerNorm.weight,
self_norm_bias=self.attention.output.LayerNorm.bias,
intermediate_weight=self.intermediate.dense.weight,
intermediate_bias=self.intermediate.dense.bias,
output_weight=self.output.dense.weight,
output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
output_norm_weight=self.output.LayerNorm.weight,
output_norm_bias=self.output.LayerNorm.bias,
dense_out=torch.Tensor(),
seqs=attn_metadata.seq_lens,
vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
sm_scale=self.attention.self.scaling,
num_q_heads=self.attention.self.num_heads * world_size,
num_kv_heads=self.attention.self.num_kv_heads * world_size,
flash_attention=False,
reduce_result=True if world_size > 1 else False,
world_size=world_size,
rank=rank,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
else:
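                # Large all_reduce input (>= 4 MB): split the layer into fused stages so the
                # all_reduce can be performed externally via vccl between the fused calls.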
attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
qkv_weight=self.attention.self.qkv_proj.weight,
qkv_bias=self.attention.self.qkv_proj.bias,
self_weight=self.attention.output.dense.weight,
self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
self_norm_weight=torch.Tensor(),
self_norm_bias=torch.Tensor(),
intermediate_weight=torch.Tensor(),
intermediate_bias=torch.Tensor(),
output_weight=torch.Tensor(),
output_bias=torch.Tensor(),
output_norm_weight=torch.Tensor(),
output_norm_bias=torch.Tensor(),
dense_out=torch.Tensor(),
seqs=attn_metadata.seq_lens,
vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
sm_scale=self.attention.self.scaling,
num_q_heads=self.attention.self.num_heads * world_size,
num_kv_heads=self.attention.self.num_kv_heads * world_size,
flash_attention=False,
reduce_result=False,
world_size=world_size,
rank=rank,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
if world_size > 1:
attn_out_stage_output = tensor_model_parallel_all_reduce(attn_out_stage_output)
if USE_FUSED_MLP_VISION:
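                    # Apply add & LayerNorm eagerly, then run the FFN through the
                    # fused vision-MLP op (GELU activation).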
attn_output = self.attention.output.LayerNorm(attn_out_stage_output + hidden_states)
inter_out_stage_output = torch.vacc.fuse_mlp_vision(src=attn_output,
weights_13=self.intermediate.dense.weight,
weights_2=self.output.dense.weight,
weights_13_bias=self.intermediate.dense.bias,
weights_2_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
act_type=0 # gelu
)
else:
inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
qkv_weight=torch.Tensor(),
qkv_bias=torch.Tensor(),
self_weight=torch.Tensor(),
self_bias=torch.Tensor(),
self_norm_weight=self.attention.output.LayerNorm.weight,
self_norm_bias=self.attention.output.LayerNorm.bias,
intermediate_weight=self.intermediate.dense.weight,
intermediate_bias=self.intermediate.dense.bias,
output_weight=self.output.dense.weight,
output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
output_norm_weight=torch.Tensor(),
output_norm_bias=torch.Tensor(),
dense_out=attn_out_stage_output,
seqs=attn_metadata.seq_lens,
vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
sm_scale=self.attention.self.scaling,
num_q_heads=self.attention.self.num_heads * world_size,
num_kv_heads=self.attention.self.num_kv_heads * world_size,
flash_attention=False,
reduce_result=False,
world_size=world_size,
rank=rank,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
if world_size > 1:
inter_out_stage_output = tensor_model_parallel_all_reduce(inter_out_stage_output)
if USE_FUSED_MLP_VISION:
output = self.output.LayerNorm(inter_out_stage_output + attn_output)
else:
output = self.output.LayerNorm(inter_out_stage_output + attn_out_stage_output)
else:
attn_output = self.attention(hidden_states)
intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output)
return output


@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple, Optional
import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig
from vllm.distributed import get_tp_group
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig, prefix: str) -> None:
super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if USE_PARALLEL_MTP_EH_PROJ:
self.eh_proj = RowParallelLinear(config.hidden_size * 2,
config.hidden_size,
bias=False,
return_bias=False)
else:
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
from vllm.model_executor.models.deepseek_mtp import SharedHead
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
if inputs_embeds.shape[0] > 256:
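            # Large-batch path: run the fused MTP stage-0 op (enorm/hnorm on the
            # embeddings and previous hidden states), optionally writing into the
            # memory-recycler buffer, then apply eh_proj separately.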
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
deepseek_mtp_layer_input_buffer = None
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
deepseek_mtp_layer_input_buffer = memory_recycler.DEEPSEEK_MTP_LAYER_INPUT
from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
hidden_states = fuse_mtp_stage0(
inputs_embeds,
previous_hidden_states,
positions,
self.enorm.weight,
self.hnorm.weight,
self.enorm.variance_epsilon,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos,
output=deepseek_mtp_layer_input_buffer,
)
# if USE_PARALLEL_MTP_EH_PROJ:
# tp_size = get_tensor_model_parallel_world_size()
# rank_id = get_tensor_model_parallel_rank()
# last_dim = hidden_states.shape[-1]
# if tp_size > 1:
# hiddens_tp = last_dim//tp_size
# hidden_states = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp]
hidden_states = self.eh_proj(hidden_states)
else:
hidden_states = torch.vacc.fuse_mtp_allreduce(
inputs_embeds,
previous_hidden_states,
positions,
self.enorm.weight,
self.hnorm.weight,
self.eh_proj.weight,
self.enorm.variance_epsilon,
world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
rank = self.weight_capture.layer_moe.dist_args._1_rank,
group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
        if attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE:
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
else:
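            # Decode path with layer fusion enabled: run the whole MTP block
            # (MLA attention + MoE + all_reduce) as a single fused op.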
from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode
layer = self.mtp_block
layer_id = 0
kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]
            # For the MTP layer, residual is None on input and the fused op must still return a residual.
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
hidden_states = hidden_states,
residual = None,
hidden_states_norm_weight = self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[layer_id],
q_a_proj_weight = self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[layer_id],
q_a_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[layer_id],
q_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[layer_id],
w_q = self.weight_capture.layer_moe.attn_args._3_W_Q[layer_id],
w_q_scale = self.weight_capture.layer_moe.attn_args._4_W_Q_scales[layer_id],
w_uk = self.weight_capture.layer_moe.attn_args._5_W_UK[layer_id],
w_uk_scale = self.weight_capture.layer_moe.attn_args._6_W_UK_scales[layer_id],
w_qr = self.weight_capture.layer_moe.attn_args._7_W_QR[layer_id],
w_qr_scale = self.weight_capture.layer_moe.attn_args._8_W_QR_scales[layer_id],
kv_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[layer_id],
sin_cache = sin_cache,
cos_cache = cos_cache,
slot_mapping = attn_metadata.slot_mapping,
kv_cache = kv_cache,
block_tables = attn_metadata.decode_metadata.block_tables,
block_group_size = self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_moe.attn_args._16_W_UV[layer_id],
w_uv_scale = self.weight_capture.layer_moe.attn_args._17_W_UV_scales[layer_id],
o_proj_weight = self.weight_capture.layer_moe.attn_args._18_o_proj_weight[layer_id],
o_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[layer_id],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_moe.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
rms_weight = self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[layer_id],
mlp_weight_13 = self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[layer_id],
mlp_weight_2 = self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[layer_id],
mlp_weight_scale_13 = self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[layer_id],
mlp_weight_scale_2 = self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[layer_id],
moe_weight_13 = self.weight_capture.layer_moe.moe_args._5_moe_w13[layer_id],
moe_weight_2 = self.weight_capture.layer_moe.moe_args._6_moe_w2[layer_id],
moe_weight_scale_13 = self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[layer_id],
moe_weight_scale_2 = self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[layer_id],
mm_weight = self.weight_capture.layer_moe.moe_args._9_gate_weight[layer_id],
moe_bias = self.weight_capture.layer_moe.moe_args._10_moe_bias[layer_id],
# moe params
mlp_block_size_w13 = self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
rank = self.weight_capture.layer_moe.dist_args._1_rank,
group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
#hidden_states = residual + hidden_states
hidden_states = residual.add_(hidden_states)
return hidden_states
class DeepSeekMTP(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
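            # (param_name, shard_name, shard_id)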
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
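        # (param_name, weight_name, expert_id, shard_id)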
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if USE_MERGE_Q_KV_GEN_AND_Q_QR:
from vllm.model_executor.models.utils import PPMissingLayer
for layer_id in self.model.layers:
layer = self.model.layers[layer_id]
                if isinstance(layer, PPMissingLayer):
continue
layer.mtp_block.self_attn.merge_qkv_weights()
return loaded_params
class SharedHead(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
try:
from torch_vacc.vacc.custom_ops import rms_norm
return rms_norm(hidden_states, self.norm.weight, output=hidden_states)
except Exception as e:
print(f"fuse rms_norm run fail, now use unfused ops: {e}")
return self.norm(hidden_states)


@@ -0,0 +1,658 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tp_group,
get_tensor_model_parallel_world_size,get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.model_executor.models.deepseek_v2 import yarn_get_mscale, DeepseekV2MLAAttention, Indexer
from vllm.logger import init_logger
logger = init_logger(__name__)
from .vars import *
from ..ops.deepseek_fused_mlp_moe import (vacc_fused_decode_moe_fp8,
vacc_fused_prefill_moe_fp8,
vacc_fused_mlp_fp8)
from .fused_forward import *
import os
test_layer_en = os.getenv("test_layer_en", "0")
# class DeepseekV2MLAAttention(nn.Module):
# def __init__(
# self,
# vllm_config: VllmConfig,
# config: Union[DeepseekV2Config, DeepseekV3Config],
# hidden_size: int,
# num_heads: int,
# qk_nope_head_dim: int,
# qk_rope_head_dim: int,
# v_head_dim: int,
# q_lora_rank: Optional[int],
# kv_lora_rank: int,
# rope_theta: float = 10000,
# rope_scaling: Optional[dict[str, Any]] = None,
# max_position_embeddings: int = 8192,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# prefix: str = "",
# topk_indices_buffer: Optional[torch.Tensor] = None,
# ) -> None:
# super(DeepseekV2MLAAttention,self).__init__()
# self.hidden_size = hidden_size
# self.qk_nope_head_dim = qk_nope_head_dim
# self.qk_rope_head_dim = qk_rope_head_dim
# self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
# self.v_head_dim = v_head_dim
# self.q_lora_rank = q_lora_rank
# self.kv_lora_rank = kv_lora_rank
# self.num_heads = num_heads
# tp_size = get_tensor_model_parallel_world_size()
# assert num_heads % tp_size == 0
# self.num_local_heads = num_heads // tp_size
# self.scaling = self.qk_head_dim**-0.5
# self.rope_theta = rope_theta
# self.max_position_embeddings = max_position_embeddings
# if self.q_lora_rank is not None:
# if USE_PARALLEL_Q_KV_GEN:
# self.q_a_proj = RowParallelLinear(self.hidden_size,
# self.q_lora_rank,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_a_proj")
# else:
# self.q_a_proj = ReplicatedLinear(self.hidden_size,
# self.q_lora_rank,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_a_proj")
# self.q_a_layernorm = RMSNorm(self.q_lora_rank,
# eps=config.rms_norm_eps)
# self.q_b_proj = ColumnParallelLinear(q_lora_rank,
# self.num_heads *
# self.qk_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_b_proj")
# else:
# self.q_proj = ColumnParallelLinear(self.hidden_size,
# self.num_heads *
# self.qk_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.q_proj")
# if USE_PARALLEL_Q_KV_GEN:
# self.kv_a_proj_with_mqa = RowParallelLinear(
# self.hidden_size,
# self.kv_lora_rank + self.qk_rope_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_a_proj_with_mqa")
# else:
# self.kv_a_proj_with_mqa = ReplicatedLinear(
# self.hidden_size,
# self.kv_lora_rank + self.qk_rope_head_dim,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_a_proj_with_mqa")
# self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
# eps=config.rms_norm_eps)
# self.kv_b_proj = ColumnParallelLinear(
# self.kv_lora_rank,
# self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.kv_b_proj")
# self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
# self.hidden_size,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.o_proj")
# rope_scaling["rope_type"] = 'deepseek_yarn'
# self.rotary_emb = get_rope(qk_rope_head_dim,
# rotary_dim=qk_rope_head_dim,
# max_position=max_position_embeddings,
# base=rope_theta,
# rope_scaling=rope_scaling,
# is_neox_style=False)
# if rope_scaling:
# mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
# scaling_factor = rope_scaling["factor"]
# mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
# self.scaling = self.scaling * mscale * mscale
# self.is_v32 = hasattr(config, "index_topk")
# if self.is_v32:
# self.indexer = Indexer(vllm_config, config, hidden_size,
# q_lora_rank, quant_config, cache_config,
# topk_indices_buffer, f"{prefix}.indexer")
# else:
# self.indexer = None
# self.mla_attn = Attention(
# num_heads=self.num_local_heads,
# head_size=self.kv_lora_rank,
# scale=self.scaling,
# num_kv_heads=1,
# cache_config=cache_config,
# quant_config=quant_config,
# prefix=f"{prefix}.attn",
# use_mla=True,
# # MLA Args
# q_lora_rank=self.q_lora_rank,
# kv_lora_rank=self.kv_lora_rank,
# qk_nope_head_dim=self.qk_nope_head_dim,
# qk_rope_head_dim=self.qk_rope_head_dim,
# qk_head_dim=self.qk_head_dim,
# v_head_dim=self.v_head_dim,
# rotary_emb=self.rotary_emb,
# q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
# kv_b_proj=self.kv_b_proj,
# o_proj=self.o_proj,
# )
# self.prefix = prefix
# self.debug_layer_idx = int(self.prefix.split(".")[-2])
# def forward(
# self,
# positions: torch.Tensor,
# hidden_states: torch.Tensor,
# kv_cache: torch.Tensor,
# attn_metadata: AttentionMetadata,
# ) -> torch.Tensor:
# tp_size = get_tensor_model_parallel_world_size()
# rank_id = get_tensor_model_parallel_rank()
# last_dim = hidden_states.shape[-1]
# if USE_PARALLEL_Q_KV_GEN: #tp qa and kva
# hidden_states_split = hidden_states
# if tp_size > 1:
# hiddens_tp = last_dim//tp_size
# hidden_states_split = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp].contiguous()
# if self.q_lora_rank is not None:
# ckq = self.q_a_proj(hidden_states_split)[0]
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
# else:
# hidden_states_or_q_c = hidden_states
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states_split)[0].split(
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
# attn_metadata)
# if self.q_lora_rank is not None:
# ckq = self.q_a_proj(hidden_states)[0]
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
# else:
# hidden_states_or_q_c = hidden_states
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
# attn_metadata)
class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, residual = None, rms_norm = None):
        # MoE layer: supports both prefill and decode via vacc fused ops
if residual is not None:
try:
reduce_result = self.tp_size > 1
# decode moe, first seq
if self.is_decode:
hidden_states, residual = vacc_fused_decode_moe_fp8(self, self.shared_experts,
hidden_states, residual,
rms_norm, self.gate, self.experts,
self.routed_scaling_factor,
reduce_result)
return hidden_states, residual
# prefill moe, first expert
else:
hidden_states, residual = vacc_fused_prefill_moe_fp8(self, self.shared_experts,
hidden_states, residual,
rms_norm, self.gate, self.experts,
self.routed_scaling_factor,
reduce_result)
return hidden_states, residual
except Exception as e:
logger.warning("vacc fused moe run fail, now use unfused ops %s", e)
hidden_states, residual = rms_norm(hidden_states, residual)
self.experts.is_decode = self.is_decode
# 1. fuse_prefill_pre_moe
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
try:
shared_output = vacc_fused_mlp_fp8(self.shared_experts, hidden_states, moe_share=True)
except Exception as e:
logger.warning("fused mlp is Error, now use Default:%s", e)
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
# 2. fused_moe
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
        # 3. add + all_reduce
        # The shared-expert output is now added inside the fused experts op, so the explicit add below stays commented out.
# if shared_output is not None:
# # out = input + other * alpha
# final_hidden_states = shared_output.add_(final_hidden_states, alpha=self.routed_scaling_factor)
# else:
# final_hidden_states = final_hidden_states * self.routed_scaling_factor
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if residual is not None:
return final_hidden_states.view(num_tokens, hidden_dim), residual
return final_hidden_states.view(num_tokens, hidden_dim)
class DeepseekV2MLP(nn.Module):
def forward(self, x, residual = None, rms_norm = None):
# use all fused ops
if residual is not None:
reduce_result = self.down_proj.reduce_results and self.down_proj.tp_size > 1
hidden_states, residual = vacc_fused_mlp_fp8(self,
x, residual,
rms_norm,
reduce_result)
return hidden_states, residual
# use default fuse ops
try:
output_parallel = vacc_fused_mlp_fp8(self, x, residual, rms_norm)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(output_parallel)
else:
x = output_parallel
except Exception as e:
logger.warning("fuse_mlp run fail, now use default: %s", e)
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
        first_k_dense_replace = getattr(self.config, "first_k_dense_replace", 3)
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekWeightCapture
self.weight_capture = DeepseekWeightCapture(self.layers, self.start_layer, self.end_layer)
self.cached_weights_state = True
self.cached_batch = 1
self.layer_nums = self.end_layer - self.start_layer
self.is_pipeline_first = get_pp_group().is_first_rank
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
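        # Prefill requests (or fusion disabled) run the layers one by one;
        # pure-decode batches go through the fused decode kernels below.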
        if attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE:
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
else:
            # Update the global sequence lengths (used for serving statistics)
# update_seqence_length(attn_metadata.decode_metadata.seq_lens)
if FUSE_ALL_DECODER_LAYERS:
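                # Run the first_k_dense (MLA + MLP) layers as individual fused calls,
                # then all MoE layers through one multi-layer fused op. Weights are
                # cached on the first call for a given batch size and reused afterwards.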
self.weight_capture.update_attn_args(attn_metadata.decode_metadata.seq_lens,
attn_metadata.slot_mapping,
[self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(self.start_layer, first_k_dense_replace)],
[self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(first_k_dense_replace, self.end_layer)],
attn_metadata.decode_metadata.block_tables)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 0)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 1)
hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 2)
if hidden_states.shape[0] != self.cached_batch:
                    # Batch size changed: re-run the weight-caching pass.
self.cached_weights_state = True
self.cached_batch = hidden_states.shape[0]
if self.cached_weights_state:
self.cached_weights_state = False
hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, self.weight_capture)
else:
hidden_states, residual = forward_mla_moe_layers_without_weights(hidden_states, residual, self.weight_capture)
else:
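                # Per-layer fused decode: each layer is a single fused
                # MLA + MLP (dense layers) or MLA + MoE (remaining layers) call.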
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode,fuse_mla_moe_v2_allreduce_decode
for i in range(0, self.layer_nums):
layer_id = i + self.start_layer
layer = self.layers[layer_id]
kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]
if layer_id < first_k_dense_replace:
hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = self.weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[i],
q_a_proj_weight = self.weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[i],
q_a_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[i],
q_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[i],
w_q = self.weight_capture.layer_mlp.attn_args._3_W_Q[i],
w_q_scale = self.weight_capture.layer_mlp.attn_args._4_W_Q_scales[i],
w_uk = self.weight_capture.layer_mlp.attn_args._5_W_UK[i],
w_uk_scale = self.weight_capture.layer_mlp.attn_args._6_W_UK_scales[i],
w_qr = self.weight_capture.layer_mlp.attn_args._7_W_QR[i],
w_qr_scale = self.weight_capture.layer_mlp.attn_args._8_W_QR_scales[i],
kv_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[i],
sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
block_tables = attn_metadata.decode_metadata.block_tables,#self.weight_capture.layer_mlp.attn_args._14_block_tables[i],
block_group_size = self.weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_mlp.attn_args._16_W_UV[i],
w_uv_scale = self.weight_capture.layer_mlp.attn_args._17_W_UV_scales[i],
o_proj_weight = self.weight_capture.layer_mlp.attn_args._18_o_proj_weight[i],
o_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[i],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_mlp.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_mlp.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# mlp weight
rms_weight = self.weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[i],
mlp_weight_13 = self.weight_capture.layer_mlp.mlp_args._1_mlp_w13[i],
mlp_weight_2 = self.weight_capture.layer_mlp.mlp_args._2_mlp_w2[i],
mlp_weight_scale_13 = self.weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[i],
mlp_weight_scale_2 = self.weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[i],
# mlp params
mlp_block_size_w13 = self.weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_mlp.dist_args._0_world_size,
rank = self.weight_capture.layer_mlp.dist_args._1_rank,
group_id = self.weight_capture.layer_mlp.dist_args._2_group_id,
dev_info = self.weight_capture.layer_mlp.dist_args._3_dev_info)
else:
wid = i - first_k_dense_replace if self.is_pipeline_first else i
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[wid],
q_a_proj_weight = self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[wid],
q_a_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[wid],
q_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[wid],
w_q = self.weight_capture.layer_moe.attn_args._3_W_Q[wid],
w_q_scale = self.weight_capture.layer_moe.attn_args._4_W_Q_scales[wid],
w_uk = self.weight_capture.layer_moe.attn_args._5_W_UK[wid],
w_uk_scale = self.weight_capture.layer_moe.attn_args._6_W_UK_scales[wid],
w_qr = self.weight_capture.layer_moe.attn_args._7_W_QR[wid],
w_qr_scale = self.weight_capture.layer_moe.attn_args._8_W_QR_scales[wid],
kv_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[wid],
sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
block_tables = attn_metadata.decode_metadata.block_tables,
block_group_size = self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
w_uv = self.weight_capture.layer_moe.attn_args._16_W_UV[wid],
w_uv_scale = self.weight_capture.layer_moe.attn_args._17_W_UV_scales[wid],
o_proj_weight = self.weight_capture.layer_moe.attn_args._18_o_proj_weight[wid],
o_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[wid],
# mla params
seq_lens = attn_metadata.decode_metadata.seq_lens,
sm_scale = self.weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = self.weight_capture.layer_moe.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
rms_weight = self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[wid],
mlp_weight_13 = self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[wid],
mlp_weight_2 = self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[wid],
mlp_weight_scale_13 = self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[wid],
mlp_weight_scale_2 = self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[wid],
moe_weight_13 = self.weight_capture.layer_moe.moe_args._5_moe_w13[wid],
moe_weight_2 = self.weight_capture.layer_moe.moe_args._6_moe_w2[wid],
moe_weight_scale_13 = self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[wid],
moe_weight_scale_2 = self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[wid],
mm_weight = self.weight_capture.layer_moe.moe_args._9_gate_weight[wid],
moe_bias = self.weight_capture.layer_moe.moe_args._10_moe_bias[wid],
# moe params
mlp_block_size_w13 = self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
rank = self.weight_capture.layer_moe.dist_args._1_rank,
group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
from .memory.memory_recycling import init_huge_memory_allocator
from .vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
        # Default model name is "deepseek"; the config can switch it (e.g. to 'deepseek_mtp').
model_name = "deepseek"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
if config_infos in ['mtp']:
model_name = "deepseek_mtp"
else:
model_name = config_infos
if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
logger.warning("init huge memory allocator fail. prefill memory recycling will disable")
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if test_layer_en == "1":
test_layer = 5
if name not in ['model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight']:
if int(name.split(".")[2]) > test_layer:
continue
# TODO(simon): support nextn predict layers
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if USE_MERGE_Q_KV_GEN_AND_Q_QR:
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
layer.self_attn.merge_qkv_weights()
return loaded_params
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
if attn_metadata.prefill_metadata is not None:
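            # Prefill path: size the vacc memory recycler to the number of
            # prefill tokens for the selected model mode.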
from .memory.memory_recycling import alloc_memory_recycler
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
if hasattr(attn_metadata, 'num_prefill_tokens'):
tokens = attn_metadata.num_prefill_tokens
else:
tokens = attn_metadata.prefill_metadata.num_prefill_tokens
vllm_model_mode = "deepseek"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
if config_infos in ['mtp']:
vllm_model_mode = "deepseek_mtp"
else:
vllm_model_mode = config_infos
if get_tp_group().rank_in_group == 0:
memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
logger.info(memory_infos)
if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size):
logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states

File diff suppressed because it is too large


@@ -0,0 +1,216 @@
import torch
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode_layers,fuse_mla_moe_v2_allreduce_decode_layers,fuse_mla_moe_v2_allreduce_decode_layers_v2
from typing import Optional
from .weight_capture.deepseek_weight_capture import DeepseekWeightCapture
import time
from .vars import *
# Single decoder layer: fused MLA + MLP
def forward_mla_mlp_single_layer(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture,
layer_id: int):
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode
hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[layer_id],
q_a_proj_weight = weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[layer_id],
q_a_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[layer_id],
q_a_layernorm_weight = weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[layer_id],
w_q = weight_capture.layer_mlp.attn_args._3_W_Q[layer_id],
w_q_scale = weight_capture.layer_mlp.attn_args._4_W_Q_scales[layer_id],
w_uk = weight_capture.layer_mlp.attn_args._5_W_UK[layer_id],
w_uk_scale = weight_capture.layer_mlp.attn_args._6_W_UK_scales[layer_id],
w_qr = weight_capture.layer_mlp.attn_args._7_W_QR[layer_id],
w_qr_scale = weight_capture.layer_mlp.attn_args._8_W_QR_scales[layer_id],
kv_a_layernorm_weight = weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[layer_id],
sin_cache = weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_mlp.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_mlp.attn_args._13_kv_cache[layer_id],
block_tables = weight_capture.layer_mlp.attn_args._14_block_tables,
block_group_size = weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
w_uv = weight_capture.layer_mlp.attn_args._16_W_UV[layer_id],
w_uv_scale = weight_capture.layer_mlp.attn_args._17_W_UV_scales[layer_id],
o_proj_weight = weight_capture.layer_mlp.attn_args._18_o_proj_weight[layer_id],
o_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[layer_id],
# mla params
seq_lens = weight_capture.layer_mlp.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_mlp.attn_args._21_sm_scale,
head_num = weight_capture.layer_mlp.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# mlp weight
rms_weight = weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[layer_id],
mlp_weight_13 = weight_capture.layer_mlp.mlp_args._1_mlp_w13[layer_id],
mlp_weight_2 = weight_capture.layer_mlp.mlp_args._2_mlp_w2[layer_id],
mlp_weight_scale_13 = weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[layer_id],
mlp_weight_scale_2 = weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[layer_id],
# mlp params
mlp_block_size_w13 = weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
# vccl info
world_size = weight_capture.layer_mlp.dist_args._0_world_size,
rank = weight_capture.layer_mlp.dist_args._1_rank,
group_id = weight_capture.layer_mlp.dist_args._2_group_id,
dev_info = weight_capture.layer_mlp.dist_args._3_dev_info)
return hidden_states, residual
# Multiple decoder layers: fused MLA + MLP
def forward_mla_mlp_layers(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
    if residual is None:
        # The fused multi-layer op expects a residual tensor; use zeros when none is provided.
        residual = torch.zeros_like(hidden_states)
hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode_layers(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight,
q_a_proj_weight = weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights,
q_a_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv,
q_a_layernorm_weight = weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight,
w_q = weight_capture.layer_mlp.attn_args._3_W_Q,
w_q_scale = weight_capture.layer_mlp.attn_args._4_W_Q_scales,
w_uk = weight_capture.layer_mlp.attn_args._5_W_UK,
w_uk_scale = weight_capture.layer_mlp.attn_args._6_W_UK_scales,
w_qr = weight_capture.layer_mlp.attn_args._7_W_QR,
w_qr_scale = weight_capture.layer_mlp.attn_args._8_W_QR_scales,
kv_a_layernorm_weight = weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight,
sin_cache = weight_capture.layer_mlp.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_mlp.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_mlp.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_mlp.attn_args._13_kv_cache,
block_tables = weight_capture.layer_mlp.attn_args._14_block_tables,
block_group_size = weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
w_uv = weight_capture.layer_mlp.attn_args._16_W_UV,
w_uv_scale = weight_capture.layer_mlp.attn_args._17_W_UV_scales,
o_proj_weight = weight_capture.layer_mlp.attn_args._18_o_proj_weight,
o_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv,
# mla params
seq_lens = weight_capture.layer_mlp.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_mlp.attn_args._21_sm_scale,
head_num = weight_capture.layer_mlp.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# mlp weight
rms_weight = weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight,
mlp_weight_13 = weight_capture.layer_mlp.mlp_args._1_mlp_w13,
mlp_weight_2 = weight_capture.layer_mlp.mlp_args._2_mlp_w2,
mlp_weight_scale_13 = weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale,
mlp_weight_scale_2 = weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale,
# mlp params
mlp_block_size_w13 = weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
# vccl info
world_size = weight_capture.layer_mlp.dist_args._0_world_size,
rank = weight_capture.layer_mlp.dist_args._1_rank,
group_id = weight_capture.layer_mlp.dist_args._2_group_id,
dev_info = weight_capture.layer_mlp.dist_args._3_dev_info)
return hidden_states, residual
# Multiple decoder layers: fused MLA + MoE (weights not yet cached on device)
def forward_mla_moe_layers_with_weights(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers(
hidden_states = hidden_states,
residual = residual,
hidden_states_norm_weight = weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight,
q_a_proj_weight = weight_capture.layer_moe.attn_args._0_merge_q_kv_weights,
q_a_proj_weight_scale_inv = weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv,
q_a_layernorm_weight = weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight,
w_q = weight_capture.layer_moe.attn_args._3_W_Q,
w_q_scale = weight_capture.layer_moe.attn_args._4_W_Q_scales,
w_uk = weight_capture.layer_moe.attn_args._5_W_UK,
w_uk_scale = weight_capture.layer_moe.attn_args._6_W_UK_scales,
w_qr = weight_capture.layer_moe.attn_args._7_W_QR,
w_qr_scale = weight_capture.layer_moe.attn_args._8_W_QR_scales,
kv_a_layernorm_weight = weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight,
sin_cache = weight_capture.layer_moe.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_moe.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_moe.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_moe.attn_args._13_kv_cache,
block_tables = weight_capture.layer_moe.attn_args._14_block_tables,
block_group_size = weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
w_uv = weight_capture.layer_moe.attn_args._16_W_UV,
w_uv_scale = weight_capture.layer_moe.attn_args._17_W_UV_scales,
o_proj_weight = weight_capture.layer_moe.attn_args._18_o_proj_weight,
o_proj_weight_scale_inv = weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv,
# mla params
seq_lens = weight_capture.layer_moe.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = weight_capture.layer_moe.attn_args._22_head_num,
# flash attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
rms_weight = weight_capture.layer_moe.moe_args._0_moe_rms_weight,
mlp_weight_13 = weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13,
mlp_weight_2 = weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2,
mlp_weight_scale_13 = weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale,
mlp_weight_scale_2 = weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale,
moe_weight_13 = weight_capture.layer_moe.moe_args._5_moe_w13,
moe_weight_2 = weight_capture.layer_moe.moe_args._6_moe_w2,
moe_weight_scale_13 = weight_capture.layer_moe.moe_args._7_moe_w13_scale,
moe_weight_scale_2 = weight_capture.layer_moe.moe_args._8_moe_w2_scale,
mm_weight = weight_capture.layer_moe.moe_args._9_gate_weight,
moe_bias = weight_capture.layer_moe.moe_args._10_moe_bias,
# moe params
mlp_block_size_w13 = weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = weight_capture.layer_moe.dist_args._0_world_size,
rank = weight_capture.layer_moe.dist_args._1_rank,
group_id = weight_capture.layer_moe.dist_args._2_group_id,
dev_info = weight_capture.layer_moe.dist_args._3_dev_info)
return hidden_states, residual
# Multiple decoder layers: fused MLA + MoE with cached weights; must only be called after the uncached-weights variant has run once
def forward_mla_moe_layers_without_weights(hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers_v2(
hidden_states = hidden_states,
residual = residual,
sin_cache = weight_capture.layer_moe.attn_args._10_sin_cache,
cos_cache = weight_capture.layer_moe.attn_args._11_cos_cache,
slot_mapping = weight_capture.layer_moe.attn_args._12_slot_mapping,
kv_cache = weight_capture.layer_moe.attn_args._13_kv_cache,
block_tables = weight_capture.layer_moe.attn_args._14_block_tables,
block_group_size = weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
# mla params
seq_lens = weight_capture.layer_moe.attn_args._20_seq_lens,
sm_scale = weight_capture.layer_moe.attn_args._21_sm_scale,
head_num = weight_capture.layer_moe.attn_args._22_head_num,
# flash_attention
flash_attention = (USE_FLASH_ATTENTION==1),
# moe weight
# moe params
mlp_block_size_w13 = weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
mlp_block_size_w2 = weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
moe_block_size_w13 = weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
moe_block_size_w2 = weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
# vccl info
world_size = weight_capture.layer_moe.dist_args._0_world_size,
rank = weight_capture.layer_moe.dist_args._1_rank,
group_id = weight_capture.layer_moe.dist_args._2_group_id,
dev_info = weight_capture.layer_moe.dist_args._3_dev_info)
return hidden_states, residual
def forward_deepseekv3(model: torch.nn.Module,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
weight_capture : DeepseekWeightCapture):
hidden_states, residual = forward_mla_mlp_layers(hidden_states, residual, weight_capture)
hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, weight_capture)
return hidden_states, residual


@@ -0,0 +1,239 @@
import warnings
from transformers.utils import (
CONFIG_NAME,
IMAGE_PROCESSOR_NAME,
cached_file,
is_timm_config_dict,
is_timm_local_checkpoint,
is_torchvision_available,
is_vision_available,
logging,
)
from transformers.models.auto.configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from transformers.image_processing_utils import ImageProcessingMixin
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.image_processing_auto import (
AutoImageProcessor,
logger,
FORCE_FAST_IMAGE_PROCESSOR,
IMAGE_PROCESSOR_MAPPING_NAMES,
IMAGE_PROCESSOR_MAPPING,
get_image_processor_class_from_name,
resolve_trust_remote_code,
_warning_fast_image_processor_available,
get_class_from_dynamic_module
)
def check_vacc_support_module(module_class):
if module_class.__name__ == "Qwen2VLImageProcessorFast":
from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
return Qwen2VLImageProcessorFastWithVacc
return module_class
"""AutoImageProcessor class."""
class AutoImageProcessorWithVacc(AutoImageProcessor):
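    # Variant of AutoImageProcessor whose from_pretrained substitutes
    # vacc-accelerated fast image processors (see check_vacc_support_module above).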
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if kwargs.get("token") is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
kwargs["token"] = use_auth_token
config = kwargs.pop("config", None)
# TODO: @yoni, change in v4.48 (use_fast set to True by default)
use_fast = kwargs.pop("use_fast", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
# Resolve the image processor config filename
if "image_processor_filename" in kwargs:
image_processor_filename = kwargs.pop("image_processor_filename")
elif is_timm_local_checkpoint(pretrained_model_name_or_path):
image_processor_filename = CONFIG_NAME
else:
image_processor_filename = IMAGE_PROCESSOR_NAME
# Load the image processor config
try:
# Main path for all transformers models and local TimmWrapper checkpoints
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
)
except Exception as initial_exception:
# Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
# instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
# except the model name, the only way to check if a remote checkpoint is a timm model is to try to
# load `config.json` and if it fails with some error, we raise the initial exception.
try:
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
)
except Exception:
raise initial_exception
# In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
# because only timm models have image processing in `config.json`.
if not is_timm_config_dict(config_dict):
raise initial_exception
image_processor_type = config_dict.get("image_processor_type", None)
        # Redirect to the VACC preprocessing-op replacement (now handled via check_vacc_support_module below)
# if image_processor_type == "Qwen2VLImageProcessorFast":
# from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
# return Qwen2VLImageProcessorFastWithVacc.from_dict(config_dict, **kwargs)
image_processor_auto_map = None
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
# If we still don't have the image processor class, check if we're loading from a previous feature extractor config
# and if so, infer the image processor class from there.
if image_processor_type is None and image_processor_auto_map is None:
feature_extractor_class = config_dict.pop("feature_extractor_type", None)
if feature_extractor_class is not None:
image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
# If we don't find the image processor class in the image processor config, let's try the model config.
if image_processor_type is None and image_processor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
# It could be in `config.image_processor_type``
image_processor_type = getattr(config, "image_processor_type", None)
if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
image_processor_auto_map = config.auto_map["AutoImageProcessor"]
image_processor_class = None
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
if image_processor_type is not None:
# if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
if use_fast is None:
use_fast = image_processor_type.endswith("Fast")
if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
use_fast = True
logger.warning_once(
f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
"This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
"Note that this behavior will be extended to all models in a future release."
)
if not use_fast:
logger.warning_once(
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
)
if use_fast and not image_processor_type.endswith("Fast"):
image_processor_type += "Fast"
if use_fast and not is_torchvision_available():
# check if there is a slow image processor class to fallback to
image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
if image_processor_class is None:
raise ValueError(
f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
)
logger.warning_once(
"Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
)
use_fast = False
if use_fast:
for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
if image_processor_type in image_processors:
break
else:
image_processor_type = image_processor_type[:-4]
use_fast = False
logger.warning_once(
"`use_fast` is set to `True` but the image processor class does not have a fast version. "
" Falling back to the slow version."
)
image_processor_class = get_image_processor_class_from_name(image_processor_type)
else:
image_processor_type_slow = image_processor_type.removesuffix("Fast")
image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
if image_processor_class is None and image_processor_type.endswith("Fast"):
raise ValueError(
f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
)
has_remote_code = image_processor_auto_map is not None
has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
if has_remote_code:
if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
# In some configs, only the slow image processor class is stored
image_processor_auto_map = (image_processor_auto_map, None)
if use_fast and image_processor_auto_map[1] is not None:
class_ref = image_processor_auto_map[1]
else:
class_ref = image_processor_auto_map[0]
if "--" in class_ref:
upstream_repo = class_ref.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
)
if has_remote_code and trust_remote_code:
if not use_fast and image_processor_auto_map[1] is not None:
_warning_fast_image_processor_available(image_processor_auto_map[1])
image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
image_processor_class.register_for_auto_class()
# check preprocess module supported by vacc
image_processor_class = check_vacc_support_module(image_processor_class)
return image_processor_class.from_dict(config_dict, **kwargs)
elif image_processor_class is not None:
# check preprocess module supported by vacc
image_processor_class = check_vacc_support_module(image_processor_class)
return image_processor_class.from_dict(config_dict, **kwargs)
# Last try: we use the IMAGE_PROCESSOR_MAPPING.
elif type(config) in IMAGE_PROCESSOR_MAPPING:
image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
image_processor_class_py, image_processor_class_fast = image_processor_tuple
if not use_fast and image_processor_class_fast is not None:
_warning_fast_image_processor_available(image_processor_class_fast)
if image_processor_class_fast and (use_fast or image_processor_class_py is None):
# check preprocess module supported by vacc
image_processor_class_fast = check_vacc_support_module(image_processor_class_fast)
return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
if image_processor_class_py is not None:
# check preprocess module supported by vacc
image_processor_class_py = check_vacc_support_module(image_processor_class_py)
return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
raise ValueError(
"This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
)
raise ValueError(
f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
)
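# --- Illustrative usage (hypothetical checkpoint name, not executed on import) ---
# Minimal sketch of the intended entry point: for a checkpoint whose
# preprocessor config declares "Qwen2VLImageProcessorFast", the call below
# returns Qwen2VLImageProcessorFastWithVacc; any other processor type falls
# back to the stock transformers class unchanged.
def _example_load_vacc_image_processor(checkpoint="Qwen/Qwen2-VL-7B-Instruct"):
    return AutoImageProcessorWithVacc.from_pretrained(checkpoint, use_fast=True)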

View File

@@ -0,0 +1,402 @@
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Qwen2-VL."""
from typing import Optional, Union
import torch
from torchvision.transforms.v2 import functional as F
from transformers.image_processing_utils import BatchFeature
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
auto_docstring,
logging,
)
from transformers.video_utils import VideoInput, make_batched_videos
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import (
Qwen2VLFastImageProcessorKwargs,
logger
)
# resize + normalize + repeat + transpose + reshape, fused into one CPU reference path
def fuse_qwen2_vl_preprocess_img_cpu(
image: "torch.Tensor",
do_resize: bool,
min_pixels: int,
max_pixels: int,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
resized_height: int,
resized_width: int,
interpolation: Optional["F.InterpolationMode"],
patch_size: int,
temporal_patch_size: int,
merge_size: int,
image_mean_0: float,
image_mean_1: float,
image_mean_2: float,
image_std_0: float,
image_std_1: float,
image_std_2: float,
batch_size: int = 1,
grid_t: int = 1,
channel: int = 3,
) -> "torch.Tensor":
def resize(
image: "torch.Tensor",
size_h: int,
size_w: int,
interpolation: Optional["F.InterpolationMode"] = None,
antialias: bool = True,
) -> "torch.Tensor":
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if size_h and size_w:
new_size = (size_h, size_w)
else:
raise ValueError(
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size_h} and {size_w}."
)
return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
def fuse_mean_std_and_rescale_factor(
image_mean_0: float,
image_mean_1: float,
image_mean_2: float,
image_std_0: float,
image_std_1: float,
image_std_2: float,
do_normalize: Optional[bool] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
device: Optional["torch.device"] = None,
) -> tuple:
if do_rescale and do_normalize:
image_mean = torch.tensor([image_mean_0, image_mean_1, image_mean_2], device=device)
image_std = torch.tensor([image_std_0, image_std_1, image_std_2], device=device)
            image_mean = image_mean * (1.0 / rescale_factor)
            image_std = image_std * (1.0 / rescale_factor)
return image_mean, image_std
def rescale_and_normalize(
images: "torch.Tensor",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean_0: float,
image_mean_1: float,
image_mean_2: float,
image_std_0: float,
image_std_1: float,
image_std_2: float
) -> "torch.Tensor":
image_mean, image_std = fuse_mean_std_and_rescale_factor(
image_mean_0, image_mean_1, image_mean_2,
image_std_0, image_std_1, image_std_2,
do_normalize,
do_rescale,
rescale_factor,
device=images.device,
)
if do_normalize:
images = F.normalize(images.to(dtype=torch.float32), image_mean, image_std)
return images
if image.dim() == 3:
image = image.unsqueeze(0)
stacked_images = resize(
image=image,
size_h=resized_height,
size_w=resized_width,
        interpolation=interpolation,  # BICUBIC interpolation
)
patches = rescale_and_normalize(
stacked_images,
do_rescale,
rescale_factor,
do_normalize,
image_mean_0, image_mean_1, image_mean_2,
image_std_0, image_std_1, image_std_2
)
if patches.ndim == 4:
patches = patches.unsqueeze(1)
if patches.shape[1] % temporal_patch_size != 0:
repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
patches = torch.cat([patches, repeats], dim=1)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
patches = patches.view(
        batch_size,                # 1 - batch size
        grid_t,                    # 1 - temporal grid count
        temporal_patch_size,       # 2 - temporal patch size
        channel,                   # 3 - number of channels (RGB)
        grid_h // merge_size,      # 12 // 2 = 6 - merged grid count along height
        merge_size,                # 2 - merge size along height
        patch_size,                # 14 - patch size along height
        grid_w // merge_size,      # 38 // 2 = 19 - merged grid count along width
        merge_size,                # 2 - merge size along width
        patch_size,                # 14 - patch size along width
)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
flatten_patches = patches.reshape(
# batch_size, # 1
        grid_t * grid_h * grid_w,  # 1 * 12 * 38 = 456 (total number of grid cells)
        channel * temporal_patch_size * patch_size * patch_size,  # 3 * 2 * 14 * 14 = 1176 (feature dim per grid cell)
)
return flatten_patches
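# --- Illustrative sketch (hypothetical helper, not called anywhere) ----------
# Exercises the pure-PyTorch reference path above on a random image so the
# expected output layout is easy to see; the 1280x720 size and 1/255 rescale
# factor are arbitrary example values.
def _demo_fuse_qwen2_vl_preprocess_img_cpu():
    patch_size, temporal_patch_size, merge_size = 14, 2, 2
    min_pixels, max_pixels = 56 * 56, 28 * 28 * 1280
    image = torch.rand(3, 720, 1280) * 255  # CHW, pretend 8-bit pixel values
    resized_height, resized_width = smart_resize(
        720, 1280, factor=patch_size * merge_size,
        min_pixels=min_pixels, max_pixels=max_pixels)
    flat = fuse_qwen2_vl_preprocess_img_cpu(
        image, True, min_pixels, max_pixels, True, 1 / 255.0, True,
        resized_height, resized_width, F.InterpolationMode.BICUBIC,
        patch_size, temporal_patch_size, merge_size,
        *OPENAI_CLIP_MEAN, *OPENAI_CLIP_STD)
    grid_h = resized_height // patch_size
    grid_w = resized_width // patch_size
    # One row per (t, h, w) grid cell, 3 * 2 * 14 * 14 = 1176 features per row.
    assert flat.shape == (grid_h * grid_w, 3 * temporal_patch_size * patch_size ** 2)
    return flat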
class Qwen2VLImageProcessorFastWithVacc(BaseImageProcessorFast):
do_resize = True
resample = PILImageResampling.BICUBIC
size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
do_rescale = True
do_normalize = True
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
do_convert_rgb = True
patch_size = 14
temporal_patch_size = 2
merge_size = 2
min_pixels = None
max_pixels = None
valid_kwargs = Qwen2VLFastImageProcessorKwargs
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
def __init__(self, **kwargs: Unpack[Qwen2VLFastImageProcessorKwargs]):
size = kwargs.pop("size", None)
min_pixels = kwargs.pop("min_pixels", None)
max_pixels = kwargs.pop("max_pixels", None)
# backward compatibility: override size with min_pixels and max_pixels if they are provided
size = self.size if size is None else size
if min_pixels is not None:
size["shortest_edge"] = min_pixels
size.pop("min_pixels", None)
if max_pixels is not None:
size["longest_edge"] = max_pixels
size.pop("max_pixels", None)
if "shortest_edge" not in size or "longest_edge" not in size:
raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
def _further_process_kwargs(
self,
size: Optional[SizeDict] = None,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
**kwargs,
) -> dict:
"""
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if min_pixels is not None and max_pixels is not None:
size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
elif size is not None:
if "shortest_edge" not in size or "longest_edge" not in size:
raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
min_pixels = size["shortest_edge"]
max_pixels = size["longest_edge"]
else:
size = {**self.size}
return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
def preprocess(
self,
images: ImageInput,
videos: Optional[VideoInput] = None,
**kwargs: Unpack[Qwen2VLFastImageProcessorKwargs],
) -> BatchFeature:
return super().preprocess(images, videos, **kwargs)
def _preprocess_image_like_inputs(
self,
images: ImageInput,
videos: VideoInput,
do_convert_rgb: bool,
input_data_format: ChannelDimension,
device: Optional[Union[str, "torch.device"]] = None,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> BatchFeature:
"""
Preprocess image-like inputs.
To be overridden by subclasses when image-like inputs other than images should be processed.
It can be used for segmentation maps, depth maps, etc.
"""
# Prepare input images
batch_feature = BatchFeature()
if images is not None:
images = self._prepare_image_like_inputs(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
batch_feature = self._preprocess(images, **kwargs)
if videos is not None:
logger.warning(
"`Qwen2VLImageProcessorFast` works only with image inputs and doesn't process videos anymore. "
"This is a deprecated behavior and will be removed in v5.0. "
"Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
)
# Can't change _prepare_images_structure to work with videos because it also needs to work with images.
videos = make_batched_videos(videos)
videos = [
torch.stack(self._prepare_image_like_inputs(video, do_convert_rgb, input_data_format, device))
for video in videos
]
video_outputs = self._preprocess(videos, **kwargs)
batch_feature.update(
{"pixel_values_videos": video_outputs.pixel_values, "video_grid_thw": video_outputs.image_grid_thw}
)
return batch_feature
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
patch_size: int,
temporal_patch_size: int,
merge_size: int,
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
):
        min_pixels = size["shortest_edge"]
        max_pixels = size["longest_edge"]
processed_images = []
processed_grids = []
for img in images:
height, width = img.shape[-2:]
if do_resize:
resized_height, resized_width = smart_resize(
height,
width,
factor=patch_size * merge_size,
                    min_pixels=min_pixels,
                    max_pixels=max_pixels,
)
# reseize_normalize_repeat = fuse_qwen2_vl_preprocess_img_cpu(
# img,
# do_resize,
# min_pixels,
# max_pixels,
# do_rescale,
# rescale_factor,
# do_normalize,
# resized_height,
# resized_width,
# interpolation,
# patch_size,
# temporal_patch_size,
# merge_size,
# image_mean[0], image_mean[1], image_mean[2],
# image_std[0], image_std[1], image_std[2],
# 1,1,3
# )
            import torch_vacc
            if img.device.type != "vacc":
                img = img.to("vacc")
            reseize_normalize_repeat = torch_vacc.vacc.custom_qwen3_ops.qwen2vl_img_preprocess(
                img,
                do_resize,
                min_pixels,
                max_pixels,
                do_rescale,
                rescale_factor,
                do_normalize,
                resized_height,
                resized_width,
                0x1003,  # interpolation
                patch_size,
                temporal_patch_size,
                merge_size,
                image_mean[0], image_mean[1], image_mean[2],
                image_std[0], image_std[1], image_std[2])
processed_images.append(reseize_normalize_repeat)
grid_t = 1
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
grid_thw_ = torch.tensor([[grid_t, grid_h, grid_w]])
processed_grids.append(grid_thw_)
pixel_values = torch.cat(processed_images, dim=0)
image_grid_thw = torch.cat(processed_grids, dim=0)
return BatchFeature(
data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Note: Do not remove this method! It is used by vLLM to infer the number of patches and placeholders
without an image input.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
patch_size = images_kwargs.get("patch_size", self.patch_size)
merge_size = images_kwargs.get("merge_size", self.merge_size)
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * grid_w
__all__ = ["Qwen2VLImageProcessorFastWithVacc"]
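# --- Illustrative usage of get_number_of_image_patches (example sizes) -------
# Infers the number of patches for a hypothetical 1280x720 image with the
# default 14-pixel patches and 2x2 merge window, without touching pixel data.
def _example_count_patches():
    processor = Qwen2VLImageProcessorFastWithVacc()
    return processor.get_number_of_image_patches(height=720, width=1280, images_kwargs={})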

View File

@@ -0,0 +1,91 @@
from transformers.models.qwen3_vl import Qwen3VLProcessor
# from transformers.models.auto.image_processing_auto import AutoImageProcessor
from transformers.models.qwen2_vl import Qwen2VLProcessor
class Qwen3VLProcessorWithVacc(Qwen3VLProcessor):
def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
@classmethod
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Identify and instantiate the subcomponents of Processor classes, like image processors and
tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
will be unable to find the relevant subcomponent class and will raise an error.
"""
args = []
for attribute_name in cls.attributes:
class_name = getattr(cls, f"{attribute_name}_class")
if isinstance(class_name, tuple):
classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
if attribute_name == "image_processor":
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
use_fast = kwargs.get("use_fast")
if use_fast is None:
print(
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
)
else:
use_fast = kwargs.get("use_fast", True)
if use_fast and classes[1] is not None:
attribute_class = classes[1]
else:
attribute_class = classes[0]
else:
attribute_class = cls.get_possibly_dynamic_module(class_name)
if attribute_class.__name__ == "AutoImageProcessor":
from .auto_image_preprocessor import AutoImageProcessorWithVacc
attribute_class = AutoImageProcessorWithVacc
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args
class Qwen2VLProcessorWithVacc(Qwen2VLProcessor):
def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
@classmethod
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Identify and instantiate the subcomponents of Processor classes, like image processors and
tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
will be unable to find the relevant subcomponent class and will raise an error.
"""
args = []
for attribute_name in cls.attributes:
class_name = getattr(cls, f"{attribute_name}_class")
if isinstance(class_name, tuple):
classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
if attribute_name == "image_processor":
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
use_fast = kwargs.get("use_fast")
if use_fast is None:
print(
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
)
else:
use_fast = kwargs.get("use_fast", True)
if use_fast and classes[1] is not None:
attribute_class = classes[1]
else:
attribute_class = classes[0]
else:
attribute_class = cls.get_possibly_dynamic_module(class_name)
if attribute_class.__name__ == "AutoImageProcessor":
from .auto_image_preprocessor import AutoImageProcessorWithVacc
attribute_class = AutoImageProcessorWithVacc
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args
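# --- Illustrative usage (hypothetical checkpoint name) -----------------------
# Loading a processor through these subclasses routes the image-processor
# subcomponent through AutoImageProcessorWithVacc, so a Qwen2-VL checkpoint
# ends up with the VACC-accelerated fast image processor.
def _example_load_qwen2vl_processor(checkpoint="Qwen/Qwen2-VL-7B-Instruct"):
    return Qwen2VLProcessorWithVacc.from_pretrained(checkpoint, use_fast=True)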

View File

@@ -0,0 +1,235 @@
import torch
import torch.nn as nn
from typing import List
class VaccHugeMemoryAllocator(nn.Module):
# self._active_bytes means the real tensor buffers used bytes.
# you can use this value to slice the src buffer
# you can not free the self._src_buffer_array, because the src buffer is the max buffer
# self._block_bytes means the part max buffer size
    def __init__(self, blocks, dtype=torch.bfloat16, use_contiguous=False):
        super().__init__()
        self._total_blocks = blocks
self._dtype = dtype
self._enable = False
self._src_buffer_array = None
self._block_bytes = 0
self._max_tokens = 0
self._hiddens = 0
self._active_bytes = 0
self._use_contiguous_buffer = use_contiguous
# max tokens for dynamic buffer size
# dynamic buffer size is bigger than normal buffer usually
self._dynamic_max_tokens = 0
self._dynamic_block_bytes = 0
# malloc the max buffer, and not free
def init_buffers(self, max_tokens, hiddens):
self._max_tokens = max_tokens
self._hiddens = hiddens
try:
import torch_vacc
self._block_bytes = self._max_tokens * self._hiddens \
* self.get_dtype_bytes(self._dtype)
self._all_bytes = self._block_bytes * self._total_blocks
if self._use_contiguous_buffer:
                # Allocate one contiguous [N*3]-sized byte buffer in a single call
self._src_buffer = torch.zeros(self._all_bytes,
dtype = torch.uint8,
device = "vacc")
tmp_buffer_array = self._src_buffer.view(self._total_blocks, -1)
self._src_buffer_array = [tmp_buffer_array[i]
for i in range(self._total_blocks)]
else:
                # Allocate 3 separate [N]-sized byte buffers
self._src_buffer_array = [torch.zeros(self.block_bytes,
dtype = torch.uint8,
device = "vacc")
for i in range(self._total_blocks)]
self._enable = True
except Exception as e:
print(f"vacc huge buffer alloc fail: {e}")
    # Designed for models that need a dynamic buffer; the dynamic buffer is usually somewhat larger than the regular max_tokens buffers
def init_buffers_with_dynamic(self, max_tokens, dynamic_tokens, hiddens, dynamic_buffers_mask: List):
self._max_tokens = max_tokens
self._dynamic_max_tokens = dynamic_tokens
self._hiddens = hiddens
dynamic_buffers_count = sum(dynamic_buffers_mask)
normal_buffers_count = self._total_blocks - dynamic_buffers_count
self._block_bytes = self._max_tokens * self._hiddens \
* self.get_dtype_bytes(self._dtype)
self._dynamic_block_bytes = self._dynamic_max_tokens * self._hiddens \
* self.get_dtype_bytes(self._dtype)
self._all_bytes = self._block_bytes * normal_buffers_count + \
self._dynamic_block_bytes * dynamic_buffers_count
        # print("creating reusable buffers: dynamic count ->", dynamic_buffers_count,
        #       " normal count ->", normal_buffers_count,
        #       " dynamic buffer bytes ->", self._dynamic_block_bytes,
        #       " normal buffer bytes ->", self._block_bytes)
try:
assert self._use_contiguous_buffer is False, "malloc dynamic recycle memory buffers only support separation buffer now"
self._src_buffer_array = []
for i in range(self._total_blocks):
if dynamic_buffers_mask[i]:
self._src_buffer_array.append(torch.zeros(self._dynamic_block_bytes,
dtype = torch.uint8,
device = "vacc"))
else:
                    # allocate a regular [N]-sized byte buffer for this block
self._src_buffer_array.append(torch.zeros(self._block_bytes,
dtype = torch.uint8,
device = "vacc"))
self._enable = True
except Exception as e:
print("vacc huge buffer alloc fail.", e)
    # Slice per-request views out of the source buffers (e.g. 48K tokens * blocks).
    # The views are used for the target tensors; which tensors can be recycled has to be
    # analysed per model (DeepSeek, for example, has several such buffers).
    # Call this with the real input token count whenever a new request arrives.
    # Note: mind the dtype, because the backing buffers are created as uint8.
    # (See the usage sketch at the end of this file.)
def alloc_memory_buffers(self, tokens, dtype=torch.bfloat16):
if tokens > self._max_tokens:
            print("alloc memory buffer failed: tokens is larger than max_tokens.", self._max_tokens)
return None
self._active_bytes = tokens * self._hiddens * self.get_dtype_bytes(dtype)
return [sub_array[:self._active_bytes]
for sub_array in self._src_buffer_array]
@property
def memory_buffers(self):
return self._src_buffer_array
    # Allocate 1-of-N buffers.
    # @param part: total number of partitions to split the contiguous buffer into
    # @param return_buffer_list: partition indices to return; if empty, all partitions are returned
    # Views the contiguous source buffer as `part` slices and returns the requested ones.
def alloc_1_div_N_buffers(self, part = 2,
return_buffer_list = [], ):
if not hasattr(self, "_src_buffer"):
            print("1-div-N allocation needs a contiguous buffer")
            return None
        assert isinstance(return_buffer_list, list), "return_buffer_list must be a list"
        # View the data as int8 and split it into `part` partitions
        tmp_buffer_array = self._src_buffer.view(part, -1)
        # If return_buffer_list is not given, return every partition buffer
if len(return_buffer_list) == 0:
return [tmp_buffer_array[i] for i in range(part)]
return [tmp_buffer_array[i] for i in return_buffer_list]
def get_dtype_bytes(self, dtype):
if isinstance(dtype, torch.dtype):
if dtype in [torch.float16, torch.bfloat16, torch.half]:
return 2
elif dtype in [torch.float32, torch.float, torch.int32]:
return 4
elif dtype in [torch.float64, torch.double, torch.int64]:
return 8
elif dtype in [torch.int8, torch.uint8, torch.bool]:
return 1
else:
return 1
elif dtype == int:
return 8
elif dtype == float:
return 8
elif dtype == bool:
return 1
return 0
@property
def enable(self):
return self._enable
@property
def blocks(self):
return self._total_blocks
@property
def max_tokens(self):
return self._max_tokens
@property
def hiddens(self):
return self._hiddens
@property
def active_bytes(self):
return self._active_bytes
@property
def block_bytes(self):
return self._block_bytes
@property
def dynamic_block_bytes(self):
return self._dynamic_block_bytes
class LLMMemoryRecycler:
def __init__(self):
self.count = 3
self.embedding_output = None
self.moe_shared_mlp_output = None
self.mla_oproj_output = None
#self.moe_expert_output = None
def clear(self):
self.embedding_output = None
self.moe_shared_mlp_output = None
self.mla_oproj_output = None
#self.moe_expert_output = None
@property
def EMBEDDING_OUT_BUFFER(self):
return self.embedding_output
@property
def MOE_SHARED_MLP_OUT_BUFFER(self):
return self.moe_shared_mlp_output
@property
def MLA_OPROJ_OUT_BUFFER(self):
return self.mla_oproj_output
def alloc_memory_recycler_llm(tokens,
alloctor:VaccHugeMemoryAllocator,
recycler:LLMMemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
        print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
        print("llm memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match the llm layout.")
return False
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
#recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
return True
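# --- Illustrative usage sketch (assumes a reachable "vacc" device) -----------
# Intended life cycle: allocate the maximum-size backing buffers once at
# start-up, then re-slice them per request with the real token count. The
# sizes below are example values only.
def _demo_allocator_usage(max_tokens=1024, hidden_size=7168, tokens=256):
    allocator = VaccHugeMemoryAllocator(3, dtype=torch.bfloat16)
    allocator.init_buffers(max_tokens, hidden_size)  # one-time device allocation
    if not allocator.enable:
        return None
    recycler = LLMMemoryRecycler()
    # Re-slice the same backing memory for this request's token count.
    if alloc_memory_recycler_llm(tokens, allocator, recycler):
        # Each recycled view is a (tokens, hidden_size) bfloat16 tensor.
        return recycler.EMBEDDING_OUT_BUFFER.shape
    return None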

View File

@@ -0,0 +1,116 @@
import torch
from .allocator import VaccHugeMemoryAllocator,LLMMemoryRecycler
'''
DeepSeek supports 48K input tokens; there are 3 buffers that can be recycled:
1. parallel_embedding output buffer
2. mla_oproject output buffer
3. moe_shared_mlp output buffer
# 4. moe_expert output buffer
each buffer size is 48K * 7168 * 2 bytes
'''
class DeepseekV3MemoryRecycler(LLMMemoryRecycler):
def __init__(self):
super().__init__()
#self.moe_expert_output = None
# @property
# def MOE_EXPERT_OUT_BUFFER(self):
# return self.moe_expert_output
class DeepseekMTPMemoryRecycler(LLMMemoryRecycler):
def __init__(self):
super().__init__()
self.dynamic_output = None
self.deepseek_mtp_layer_input = None
@property
def DYNAMIC_OUTPUT_BUFFER(self):
return self.dynamic_output
@property
def DEEPSEEK_MTP_LAYER_INPUT(self):
return self.deepseek_mtp_layer_input
def alloc_memory_recycler_deepseek_v3(tokens,
alloctor:VaccHugeMemoryAllocator,
recycler:DeepseekV3MemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
        print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
        print("deepseek_v3 memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match the deepseek_v3 layout.")
return False
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
#recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
return True
def alloc_memory_recycler_deepseek_mtp(tokens,
alloctor:VaccHugeMemoryAllocator,
world_size:int,
recycler:DeepseekMTPMemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
        print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
        print("deepseek_mtp memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match the deepseek_mtp layout.")
return False
    # MTP memory layout:
    # 1. DeepSeek main model and the decoder_layer inside the draft model:
    #    a. embedding_output
    #    b. mla_oproj_output
    #    c. moe_shared_mlp_output
    #    Reason: the MoE output gets re-packed once as previous_hidden_states, and the
    #    re-packed buffer is placed into buffer0 [a.]. Putting embedding_output first would
    #    also work, but then the related addresses would all have to be shifted back; for
    #    clarity the embedding buffer is kept last and does not take part in the buffer
    #    re-partitioning / reuse.
    # 2. deepseek_mtp draft model (this strategy is currently not enabled):
    #    a. dynamic_output takes 1/2 of the buffer
    #    b. mtp_input takes 1/6 and sits right after the dynamic_output buffer
    #    dynamic_buffer = alloctor.alloc_1_div_N_buffers(2, [0,])
    #    mtp_input_buffer = alloctor.alloc_1_div_N_buffers(6, [3,])
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    # 1/2 of the buffer is freely assigned for broadcast; MTP broadcasts previous_hidden_states
    memory_buffers = alloctor.memory_buffers
    recycler.dynamic_output = memory_buffers[1]  # shared with the mla_oproj buffer
    # 1/6 is used as the MTP decoder-layer input; the op handles this specially and only needs 1/tp of the output buffer
# recycler.deepseek_mtp_layer_input = mtp_input_buffer[0]
mtp_layer_input_dims = alloctor.hiddens * 2 // world_size
mtp_layer_input_numels = tokens * mtp_layer_input_dims
from vllm import envs
if envs.VLLM_USE_V1:
        # V1 does not need to re-cache previous_hidden_states, so the moe buffer is still occupied; stage the MTP pre-processing in the attention o-proj buffer first instead
recycler.deepseek_mtp_layer_input = memory_buffers[1].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
else:
        # shared with the moe output buffer
recycler.deepseek_mtp_layer_input = memory_buffers[2].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
return True
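# --- Illustrative usage sketch (assumes a pre-initialised allocator on a "vacc"
# device); mirrors the three-buffer layout described at the top of this file. --
def _demo_deepseek_v3_recycler(allocator: VaccHugeMemoryAllocator, tokens: int):
    recycler = DeepseekV3MemoryRecycler()
    if not alloc_memory_recycler_deepseek_v3(tokens, allocator, recycler):
        return None
    # embedding / o_proj / shared-MLP outputs now alias the recycled buffers
    return (recycler.EMBEDDING_OUT_BUFFER.shape,
            recycler.MLA_OPROJ_OUT_BUFFER.shape,
            recycler.MOE_SHARED_MLP_OUT_BUFFER.shape)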

View File

@@ -0,0 +1,132 @@
import os
import torch
from .allocator import VaccHugeMemoryAllocator
from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler
VLLM_MODEL_MODE = os.environ.get("VLLM_MODEL_MODE", "deepseek")
global huge_memory_alloctor
global memory_recycler
huge_memory_alloctor = None
memory_recycler = None
# Call this function whenever a new request comes in (see the usage sketch at the end of this file).
def alloc_memory_recycler(tokens, dtype=torch.bfloat16, **argv):
global huge_memory_alloctor
global memory_recycler
vllm_model = argv.get('vllm_model')
if vllm_model is None:
        print("model info is empty; falling back to VLLM_MODEL_MODE")
vllm_model = VLLM_MODEL_MODE
# TODO: use default memory-recycle schedule
# if vllm_model in ['xxx']:
# vllm_model = "llm_default"
memory_recycler = None
if vllm_model == "deepseek":
from .deepseek_v3_memory_recycler import DeepseekV3MemoryRecycler, alloc_memory_recycler_deepseek_v3
memory_recycler = DeepseekV3MemoryRecycler()
state = alloc_memory_recycler_deepseek_v3(tokens, huge_memory_alloctor, memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
return state
if vllm_model == "deepseek_mtp":
from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler, alloc_memory_recycler_deepseek_mtp
memory_recycler = DeepseekMTPMemoryRecycler()
if argv.get('world_size') is None:
            print("mtp requires a TP world size; memory recycler allocation failed")
return False
state = alloc_memory_recycler_deepseek_mtp(tokens, huge_memory_alloctor, argv['world_size'], memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
return state
if vllm_model == "qwen3_moe":
from .qwen3_moe_memory_recycler import QWen3MoeMemoryRecycler, alloc_memory_recycler_qwen3_moe
memory_recycler = QWen3MoeMemoryRecycler()
state = alloc_memory_recycler_qwen3_moe(tokens, huge_memory_alloctor, memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
return state
if vllm_model == "llm_default":
from .allocator import LLMMemoryRecycler, alloc_memory_recycler_llm
memory_recycler = LLMMemoryRecycler()
state = alloc_memory_recycler_llm(tokens, huge_memory_alloctor, memory_recycler, dtype)
if not state:
del memory_recycler
memory_recycler = None
        return state
# Under the LLM pipeline-parallel scheme, a PART other than stage 0 receives the
# hidden_states / residual tensors from the preceding PART, and those transfers also
# follow the memory-reuse rules. This path is independent of the normal llm forward
# pass and is maintained separately: hidden_states maps to the moe mlp buffer used
# during llm forward, and residual maps to the embedding buffer.
def alloc_pipeline_parallel_recycler_buffer(size:torch.Size, dtype:torch.dtype, key:str):
global huge_memory_alloctor
if huge_memory_alloctor is None:
return None
intermize_tensor_dict = {
"hidden_states": 2,
"attention": 1,
"residual": 0
}
if not key in intermize_tensor_dict:
return None
src_tensors = huge_memory_alloctor.memory_buffers[intermize_tensor_dict[key]]
all_bytes = size.numel() * huge_memory_alloctor.get_dtype_bytes(dtype)
return src_tensors[:all_bytes].view(dtype).view(size)
# Call this function once when the server starts and each worker is launched.
def init_huge_memory_allocator(max_tokens, hidden_size, vllm_model = None):
global huge_memory_alloctor
if vllm_model is None:
        print("model info is empty; falling back to VLLM_MODEL_MODE")
vllm_model = VLLM_MODEL_MODE
if huge_memory_alloctor is not None:
del huge_memory_alloctor
torch.vacc.empty_cache()
huge_memory_alloctor = None
if vllm_model == "deepseek":
huge_memory_alloctor = VaccHugeMemoryAllocator(3)
huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
return True
    # deepseek_mtp sets use_contiguous = True
# deepseek_mtp buffer recycler:
# buffer[0]: normal_buffer -> embedding_output
# buffer[1]: dynamic_buffer -> mla_oproj_output, dynamic_output
# buffer[2]: normal_buffer -> moe_shared_mlp_output, deepseek_mtp_layer_input
if vllm_model == "deepseek_mtp":
        # the last block of the deepseek dynamic tokens is reserved for the mtp weights
        # dynamic_buffer_max_tokens = max_tokens + 128; for now MTP only supports 48K
        dynamic_buffer_max_tokens = max_tokens
        # the dynamic buffer would use 128 extra tokens for broadcasting positions and input tokens
deepseek_mtp_max_tokens = max_tokens
huge_memory_alloctor = VaccHugeMemoryAllocator(3)
# huge_memory_alloctor.init_buffers(max_tokens, 7168)
huge_memory_alloctor.init_buffers_with_dynamic(deepseek_mtp_max_tokens, dynamic_buffer_max_tokens, hidden_size, [False, True, False])
return True
if vllm_model == "qwen3_moe":
huge_memory_alloctor = VaccHugeMemoryAllocator(3)
huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
return True
return False
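# --- Illustrative end-to-end sketch (assumes a reachable "vacc" device) ------
# Start-up: size the backing buffers once per worker; per request: re-slice
# them to the actual token count. The model name and sizes are example values.
def _demo_memory_recycling(max_tokens=49152, hidden_size=7168, tokens=4096):
    if not init_huge_memory_allocator(max_tokens, hidden_size, vllm_model="deepseek"):
        return None
    if not alloc_memory_recycler(tokens, torch.bfloat16, vllm_model="deepseek"):
        return None
    # `memory_recycler` is the module-level handle that the model code reads.
    return memory_recycler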

View File

@@ -0,0 +1,38 @@
import torch
from .allocator import VaccHugeMemoryAllocator, LLMMemoryRecycler
'''
QWen3-MoE supports 56K input tokens; there are 3 buffers that can be recycled:
1. parallel_embedding output buffer
2. mla_oproject output buffer
3. moe_shared_mlp output buffer
'''
class QWen3MoeMemoryRecycler(LLMMemoryRecycler):
def __init__(self):
super().__init__()
def alloc_memory_recycler_qwen3_moe(tokens,
alloctor:VaccHugeMemoryAllocator,
recycler:QWen3MoeMemoryRecycler,
dtype:torch.dtype = torch.bfloat16):
if not alloctor.enable:
        print("memory allocator is not initialized.")
return False
recycler.clear()
out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
if out_buffers is None:
        print("qwen3_moe memory recycler buffer allocation failed; disabling it")
return False
if len(out_buffers) != recycler.count:
        print("memory recycler buffer count does not match the qwen3_moe layout.")
return False
recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
return True
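# --- Illustrative usage sketch (assumes a pre-initialised allocator) ---------
def _demo_qwen3_moe_recycler(allocator: VaccHugeMemoryAllocator, tokens: int):
    recycler = QWen3MoeMemoryRecycler()
    return recycler if alloc_memory_recycler_qwen3_moe(tokens, allocator, recycler) else None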

File diff suppressed because it is too large

View File

@@ -0,0 +1,33 @@
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vacc_tools.trace_logger import get_trace_api
trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
get_trace_api("deepseek")
)
logger = init_logger(__name__)
class Qwen2_5_VisionAttention(nn.Module):
# @trace_time('Qwen2_5_VisionAttention_vacc_split_qkv')
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
q1, k1, v1 = qkv.chunk(3, dim=-1)
q1, k1, v1 = (x.view(*new_shape) for x in (q1, k1, v1))
return q1, k1, v1
class Qwen2_5_VisionPatchEmbed(nn.Module):
# convert conv3d to matmul
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.matmul(x, self.proj.weight.view(self.hidden_size, -1).T)
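# --- Equivalence sketch for the conv3d -> matmul rewrite above ---------------
# Standalone check in plain PyTorch: a bias-free Conv3d whose stride equals its
# kernel size, applied to one non-overlapping patch per sample, matches a
# matmul between the flattened patch and the flattened kernel. The sizes are
# illustrative (Qwen2.5-VL uses C=3, T=2, P=14, hidden=1280).
def _check_patch_embed_equivalence(hidden=1280, c=3, t=2, p=14, n=5):
    conv = nn.Conv3d(c, hidden, kernel_size=(t, p, p), stride=(t, p, p), bias=False)
    patches = torch.randn(n, c * t * p * p)                       # flattened patches
    via_matmul = torch.matmul(patches, conv.weight.view(hidden, -1).T)
    via_conv = conv(patches.view(n, c, t, p, p)).view(n, hidden)  # one patch per sample
    return torch.allclose(via_matmul, via_conv, atol=1e-4)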

View File

@@ -0,0 +1,285 @@
"""Inference-only Qwen2VL model compatible with HuggingFace weights."""
import torch
import torch.nn as nn
from typing import Optional
from vllm.attention.layer import check_upstream_fa_availability
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention as Qwen2VisionAttentionOrg
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend
from vllm.logger import init_logger
from .hf_processor.qwenvl_processor import Qwen2VLProcessorWithVacc
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm_vacc.vllm.model_executor.models.vars import USE_FUSED_QWEN_ATTENTION
logger = init_logger(__name__)
class Qwen2VisionPatchEmbed(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self.proj, 'bias') and self.proj.bias is not None:
return torch.nn.functional.linear(x, self.proj.weight.view(self.hidden_size, -1), self.proj.bias)
return torch.matmul(x, self.proj.weight.view(self.embed_dim, -1).T)
class Qwen2VLProcessingInfo():
def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessorWithVacc:
return self.ctx.get_hf_processor(
Qwen2VLProcessorWithVacc,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
return self.get_hf_processor(**kwargs).image_processor
import torch.nn.functional as F
class Qwen2VisionTransformer(nn.Module):
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
if USE_FUSED_QWEN_ATTENTION:
try:
from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
sin_cache, cos_cache = rot_pos_emb_qwenvl(grid_thw, self.embed_dim, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
except Exception as e:
logger.error(f"rot_pos_emb fused ops run fail, e:{e}")
rotary_pos_emb = None
else:
rotary_pos_emb = self.rot_pos_emb(grid_thw)
sin_cache, cos_cache = None, None
# tmp_rotary_pos_emb = self.transformer_rot_pos_emb(grid_thw)
# qwen3_rotary_pos_emb = self.qwen3_rot_pos_emb(grid_thw)
# compute cu_seqlens
grid_thw_ = torch.tensor(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
grid_thw_[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
x = x.unsqueeze(1)
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
if USE_FUSED_QWEN_ATTENTION:
cu_seqlens = cu_seqlens.tolist()
max_seqlen, seqlens = None, None
else:
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
sin_cache=sin_cache,
cos_cache=cos_cache,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# adapter
x = self.merger(x)
return x
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
projection_size: int,
quant_config: Optional["QuantizationConfig"] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super(Qwen2VisionAttentionOrg, self).__init__()
# Per attention head and per partition values.
self.tp_size = (1 if use_data_parallel else
parallel_state.get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)
# self.qkv = ColumnParallelLinear(input_size=embed_dim,
# output_size=3 * projection_size,
# quant_config=quant_config,
# prefix=f"{prefix}.qkv",
# disable_tp=use_data_parallel)
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now.")
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
q1, k1, v1 = qkv.chunk(3, dim=-1)
q1, k1, v1 = (x.view(*new_shape) for x in (q1, k1, v1))
return q1, k1, v1
class Qwen2VisionBlock(nn.Module):
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
sin_cache: torch.Tensor,
cos_cache: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
if USE_FUSED_QWEN_ATTENTION:
total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304
# hidden_states = self.norm1(x)
attn_outs = torch.vacc.fuse_atten_vit(
hidden_states=x.view(-1, x.shape[-1]),
hidden_states_norm_weight = self.norm1.weight,
hidden_states_norm_bias = self.norm1.bias,
# hidden_states_norm_weight = torch.Tensor(),
# hidden_states_norm_bias = torch.Tensor(),
qkv_proj_weight=self.attn.qkv.weight,
qkv_proj_bias=self.attn.qkv.bias,
sin_cache=sin_cache,
cos_cache=cos_cache,
o_proj_weight=self.attn.proj.weight,
o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
seq_lens=cu_seqlens,
sm_scale=-1,
num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
flash_attention=True,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos
)
attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
attn_out = attn_out.view(x.shape)
x = x + attn_out
else:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x = x + self.mlp(self.norm2(x))
return x
class Qwen2VisionMLP():
def forward(self, x: torch.Tensor):
try:
from torch_vacc.vacc import fuse_mlp_vision
hiddens_shape = x.shape
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else self.fc2.bias
hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
self.fc1.weight, # nk
self.fc2.weight, # nk
self.fc1.bias,
fc2_bias,
2) # 0 is gelu, 1 is relu, 2 is quick_gelu
vacc_res = tensor_model_parallel_all_reduce(hidden_states).view(hiddens_shape)
return vacc_res
except Exception as e:
logger.error(f"mlp fused ops run fail, e:{e}")
x_parallel, _ = self.fc1(x)
x_parallel = self.act(x_parallel)
x, _ = self.fc2(x_parallel)
return x
class Qwen2VisionPatchMerger():
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x)
x = x.view(-1, self.hidden_size)
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
try:
from torch_vacc.vacc import patch_merger_vision
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else mlp_fc2.bias
hidden_states = patch_merger_vision(x,
mlp_fc1.weight,
mlp_fc2.weight,
mlp_fc1.bias,
fc2_bias,
0) #0 is gelu, 1 is silu
vacc_res = tensor_model_parallel_all_reduce(hidden_states)
return vacc_res
except Exception as e:
logger.error(f"merge patch fused vision mlp run fail, cased by:{e}")
x_parallel, _ = mlp_fc1(x)
x_parallel = mlp_act(x_parallel)
out, _ = mlp_fc2(x_parallel)
return out

View File

@@ -0,0 +1,194 @@
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union, Any, Dict
import torch
from torch import nn
from vllm.logger import init_logger
from .vars import *
from vllm.model_executor.layers.linear import UnquantizedLinearMethod as UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
# uniform the params names from different quantize method
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
if isinstance(quant_method, UnquantizedLinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = torch.Tensor()
fused_params[name + '_bias'] = None
fused_params[name + '_qzeros'] = None
elif isinstance(quant_method, Fp8LinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = layer.weight_scale_inv
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, GPTQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, AWQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
else:
raise ValueError(f"Unsupported quant_method: {quant_method}")
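# --- Illustrative sketch of the resulting key layout (hypothetical helper) ---
# For an unquantized linear layer registered under the name "o_proj" the dict
# ends up with o_proj_weight (the tensor), o_proj_weight_scale (empty tensor),
# o_proj_bias = None and o_proj_qzeros = None; quantized methods fill in the
# scale / zero-point entries instead.
def _demo_fused_param_keys() -> list:
    layer = nn.Linear(16, 16, bias=False)  # any module exposing .weight
    params: Dict[str, Any] = {}
    set_fused_params(params, UnquantizedLinearMethod(), layer, "o_proj")
    return sorted(params.keys())  # ['o_proj_bias', 'o_proj_qzeros', 'o_proj_weight', 'o_proj_weight_scale']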
class Qwen3Attention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None # new added params
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
# reshape kvcache
num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)
if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
is_decode = attn_metadata.prefill_metadata is None
else:
is_decode = attn_metadata_all.prefill_metadata is None
attn_metadata = attn_metadata_all
reduce_result = is_decode
# total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
# # only support 4M now
# if total_bytes < 4194304:
# reduce_result = True
if USE_FUSED_QWEN_ATTENTION:
if is_decode:
positions = [i - 1 for i in attn_metadata.seq_lens]
cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
else:
cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
if residual is None:
res_out = hidden_states
#from torch_vacc.vacc import fuse_atten_qwen3
attn_outs = None
if not is_decode:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
if memory_recycler is not None:
attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER
total_num_kv_heads = self.total_num_kv_heads
if self.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % self.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
attn_outs = torch.vacc.fuse_atten_qwen3(
# attn_outs = vacc_fused_attn_qwen3_naive(
hidden_states=hidden_states,
residual=residual,
hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
qkv_proj_weight=self.fused_params['qkv_proj_weight'],
qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
qkv_proj_bias=self.fused_params['qkv_proj_bias'],
qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
q_layernorm_weight=self.fused_params['q_norm_weight'],
k_layernorm_weight=self.fused_params['k_norm_weight'],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables, # tensor
block_group_size=env_blk_grp_size,
o_proj_weight=self.fused_params['o_proj_weight'],
o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
o_proj_bias=self.fused_params['o_proj_bias'],
o_proj_qzeros=self.fused_params['o_proj_qzeros'],
seq_lens=attn_metadata.seq_lens,
sm_scale=self.scaling,
num_attention_heads=self.total_num_heads,
num_key_value_heads=total_num_kv_heads,
flash_attention=is_decode, # decode use flash_atten by default
is_decode=is_decode,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos,
output_opt=attn_outs,
res_opt=residual)
if residual is None:
attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
else:
res_out = attn_outs[1]
attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
return attn_out, res_out
else:
# orig code
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen3DecoderLayer(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
# NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
if USE_FUSED_QWEN_ATTENTION:
if not hasattr(self.self_attn, "fused_params"):
self.self_attn.fused_params = {}
self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')
hidden_states, residual = self.self_attn(
positions=positions,
hidden_states=hidden_states,
residual=residual)
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual

View File

@@ -0,0 +1,790 @@
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List
import itertools
import os
import torch
from torch import nn
from transformers import PretrainedConfig
from torch_vacc.vacc.custom_ops_cpu import (
w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu,
)
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock, Qwen3MoeMLP
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from ..ops.mrope_op import get_sin_cos_mrope
from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize
from .vars import *
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
# unify parameter names across the different quantization methods
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
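# Unquantized layers expose only .weight; FP8 block-quant layers expose .weight plus .weight_scale_inv;
# GPTQ/AWQ layers pack weights as .qweight with .scales/.qzeros. The fused kernels consume one uniform
# naming scheme, so the differences are flattened here.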
if isinstance(quant_method, UnquantizedLinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = None
fused_params[name + '_bias'] = None
fused_params[name + '_qzeros'] = None
elif isinstance(quant_method, Fp8LinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = layer.weight_scale_inv
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, GPTQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
elif isinstance(quant_method, AWQLinearMethod):
fused_params[name + '_weight'] = layer.qweight
fused_params[name + '_weight_scale'] = layer.scales
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
else:
raise ValueError(f"Unsupported quant_method: {quant_method}")
def apply_w8a8_block_fp8_linear_v2(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_scale = None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
block_size = [
weight.shape[-2] // weight_scale.shape[-2],
weight.shape[-1] // weight_scale.shape[-1],
]
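# Each weight_scale entry covers one (block_n, block_k) tile of the fp8 weight, so the tile size
# is recovered from the weight-to-scale shape ratio.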
if input.device.type == "vacc":
output = torch.vacc.w8a8_block_fp8_linear(
input_2d, weight, input_scale, weight_scale, block_size
)
else:
output = w8a8_block_fp8_linear_cpu(
input_2d, weight, input_scale, weight_scale, block_size
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def vacc_fused_attn_qwen3_naive(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
hidden_states_norm_weight: torch.Tensor,
qkv_proj_weight: torch.Tensor,
qkv_proj_weight_scale: torch.Tensor,
qkv_proj_bias: Optional[torch.Tensor],
qkv_proj_qzeros: Optional[torch.Tensor],
q_layernorm_weight: torch.Tensor,
k_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
o_proj_weight: torch.Tensor,
o_proj_weight_scale: torch.Tensor,
o_proj_bias: Optional[torch.Tensor],
o_proj_qzeros: Optional[torch.Tensor],
seq_lens: List[int],
sm_scale: float,
num_attention_heads: int,
num_key_value_heads: int,
flash_attention: bool,
is_decode: bool,
reduce_result: bool,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] | Tuple[int],
block_size: int = 16
):
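# Reference (non-fused) implementation kept as a drop-in debugging aid for torch.vacc.fuse_atten_qwen3:
# residual add + RMSNorm, fp8 block-quant QKV projection, per-head q/k RMSNorm, RoPE, paged
# KV-cache update, per-sequence SDPA, then o_proj with an optional TP all-reduce.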
if residual is not None:
hidden_states = hidden_states + residual
residual_out = hidden_states
hidden_states = torch.vacc.rms_norm(
hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)
# NOTE: for qwen3 and qwen2.5, head_dim is always 128
head_dim = 128
# qkv gen
qkv = apply_w8a8_block_fp8_linear_v2(
input=hidden_states,
weight=qkv_proj_weight,
weight_scale=qkv_proj_weight_scale)
num_q_heads = num_attention_heads // world_size
num_kv_heads = num_key_value_heads // world_size
q_size = head_dim * num_q_heads
kv_size = head_dim * num_kv_heads
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
# q_by_head = self.q_norm.forward_native(q_by_head)
q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)
# q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
# k_by_head = k_norm.forward_native(k_by_head)
k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)
# k = k_by_head.view(k.shape)
v = v.view(-1, num_kv_heads, head_dim)
# q, k = self.rotary_emb(positions, q, k)
start = 0
attn_outs = []
if is_decode:
# convert block_tables to 8K group index
block_per_group = block_group_size // block_size
block_tables = (block_tables // block_per_group).to(torch.int32)
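# The cache is written in 16-token blocks but read back in block_group_size-token groups
# (BLOCK_GROUP_SIZE, 8192 by default), so block_tables entries are converted to group indices here.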
# logger.warning(f"decode block table: {block_tables}")
num_blocks = kv_cache.shape[1]
key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)
# bs loop
for i in range(len(seq_lens)):
if not is_decode:
# prefill
end = start + seq_lens[i]
else:
# decode
end = start + 1
cos = cos_cache[i].unsqueeze(-2)
sin = sin_cache[i].unsqueeze(-2)
q, k = torch.vacc.RotaryPosEmbedding(
q_norm[start : end, ...], k_norm[start : end, ...], cos, sin, 0, "neox")
# cache concat
torch.vacc.reshape_and_cache_attention(k, key_cache_split, slot_mapping[start : end, ...])
torch.vacc.reshape_and_cache_attention(v[start : end, ...], value_cache_split, slot_mapping[start : end, ...])
# attn_output = self.attn(q, k, v)
if not is_decode:
# prefill
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v[start : end, ...],
attn_mask = None,
dropout_p = 0.0,
is_causal = True, #causal_attn and not self.need_mask,
is_train = False,
recompute = False,
flash_attention = False,
sm_scale=sm_scale)
else:
# decode
key_cache = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
value_cache = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
k_slices = key_cache[block_tables[i], ...]
k_cached = torch.cat(
[k_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
dim=0,
)
k_cached = k_cached.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
v_slices = value_cache[block_tables[i], ...]
v_cached = torch.cat(
[v_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
dim=0,
)
v_cached = v_cached.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k_cached,
value=v_cached,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,#flash_attention,
sm_scale=sm_scale)
attn_outs.append(attn_out)
# update start
start = end
attn_out = torch.cat(attn_outs, dim=0)
# output, _ = self.o_proj(attn_output)
o_proj = apply_w8a8_block_fp8_linear_v2(
input = attn_out.reshape(hidden_states.shape[0], -1),
weight = o_proj_weight,
weight_scale = o_proj_weight_scale,
)
if reduce_result:
o_proj = tensor_model_parallel_all_reduce(o_proj)
if residual is not None:
return o_proj, residual_out
return o_proj
def Qwen3MoeSparseMoeBlock__init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super(Qwen3MoeSparseMoeBlock, self).__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
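# With EPLB, redundant expert replicas are appended to the logical experts to form the physical
# expert pool, which is split evenly across EP ranks; this rank owns physical experts
# [physical_expert_start, physical_expert_end).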
self.experts = FusedMoE(num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate")
# patch: transpose the data layout of w2_weight / w2_weight_scale_inv (block-quantized weights only);
# the double transpose keeps the logical shape but stores the underlying data in transposed order
if hasattr(self.experts.quant_method, 'quant_config') and hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
self.experts.w2_weight.data = self.experts.w2_weight.data.transpose(-1,-2).contiguous().transpose(-1,-2)
self.experts.w2_weight_scale_inv.data = self.experts.w2_weight_scale_inv.data.transpose(-1,-2).contiguous().transpose(-1,-2)
if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
self.experts.w2_weight_scale_inv_prefill.data = self.experts.w2_weight_scale_inv_prefill.data.transpose(-1,-2).contiguous().transpose(-1,-2)
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
positions: Union[torch.Tensor, list],
is_decode: bool):
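# Decode needs only the last position of each sequence (seq_len - 1); prefill slices the full
# [0, seq_len) range. For MRoPE the caches are first built from the multimodal positions and
# then split/chunked per sequence.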
if isinstance(rotary_emb, MRotaryEmbedding):
# get mrope sin/cos
cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
if len(attn_metadata.seq_lens) > 1:
if is_decode:
cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
else:
cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
else:
cos_cache = [cos_cache]
sin_cache = [sin_cache]
else:
if is_decode:
positions = [i - 1 for i in attn_metadata.seq_lens]
cos_cache = [rotary_emb.cos_cache[i:i+1, ...] for i in positions]
sin_cache = [rotary_emb.sin_cache[i:i+1, ...] for i in positions]
else:
cos_cache = [rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
sin_cache = [rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
return cos_cache, sin_cache
class Qwen3MoeDecoderLayer(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
cos_cache: list[torch.Tensor],
sin_cache: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
# NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
if USE_FUSED_QWEN_ATTENTION:
if not hasattr(self.self_attn, "fused_params"):
self.self_attn.fused_params = {}
self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')
hidden_states, residual = self.self_attn(
positions=positions,
hidden_states=hidden_states,
residual=residual,
cos_cache=cos_cache,
sin_cache=sin_cache)
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
cos_cache=cos_cache,
sin_cache=sin_cache
)
# # Fully Connected
# hidden_states, residual = self.post_attention_layernorm(
# hidden_states, residual)
# hidden_states = self.mlp(hidden_states)
# return hidden_states, residual
# TODO: the fused MoE path below only handles block-quantized (fp8) or WNA16 experts; fall back otherwise
if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
not hasattr(self.mlp.experts.quant_method.quant_config, 'weight_block_size'):
if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
logger.warning('fused MoE path does not support unquantized or non-block-quantized experts; falling back to the default MLP path')
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
if isinstance(attn_metadata, dict):
# is_prefill = get_forward_context().attn_metadata['test'].prefill_metadata
attn_metadata_0 = next(iter(attn_metadata.values()))
is_prefill = attn_metadata_0.prefill_metadata
else:
is_prefill = attn_metadata.prefill_metadata
quant_method = self.mlp.experts.quant_method if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
else self.mlp.down_proj.quant_method
if is_prefill is not None:
if isinstance(quant_method, MoeWNA16Method):
try:
from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_prefill_moe_gptq_int4
return vacc_fused_prefill_moe_gptq_int4(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
logger.warning(f'vacc_fused_prefill_moe_gptq_int4 failed: {e}')
else:
recompute_moe_layer_blocksize(self.mlp.experts)
try:
return vacc_fused_prefill_moe_fp8(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
logger.warning(f'vacc_fused_prefill_moe_fp8 failed: {e}')
else:
if isinstance(quant_method, MoeWNA16Method):
try:
from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_decode_moe_gptq_int4
return vacc_fused_decode_moe_gptq_int4(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
logger.warning(f'vacc_fused_decode_moe_gptq_int4 failed: {e}')
else:
try:
return vacc_fused_decode_moe_fp8(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
except Exception as e:
logger.warning(f'vacc_fused_decode_moe_fp8 failed: {e}')
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen3MoeAttention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None, # new added params
cos_cache: Optional[list[torch.Tensor]] = None,
sin_cache: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
# reshape kvcache
num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)
if isinstance(attn_metadata_all, dict):
attn_metadata = next(iter(attn_metadata_all.values()))
is_decode = attn_metadata.prefill_metadata is None
else:
is_decode = attn_metadata_all.prefill_metadata is None
attn_metadata = attn_metadata_all
reduce_result = is_decode
# total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
# # only support 4M now
# if total_bytes < 4194304:
# reduce_result = True
if USE_FUSED_QWEN_ATTENTION:
if cos_cache is None or sin_cache is None:
cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)
if residual is None:
res_out = hidden_states
#from torch_vacc.vacc import fuse_atten_qwen3
attn_outs = None
if not is_decode:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
if memory_recycler is not None:
attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER
total_num_kv_heads = self.total_num_kv_heads
if self.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % self.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
attn_outs = torch.vacc.fuse_atten_qwen3(
# attn_outs = vacc_fused_attn_qwen3_naive(
hidden_states=hidden_states,
residual=residual,
hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
qkv_proj_weight=self.fused_params['qkv_proj_weight'],
qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
qkv_proj_bias=self.fused_params['qkv_proj_bias'],
qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
q_layernorm_weight=self.fused_params['q_norm_weight'],
k_layernorm_weight=self.fused_params['k_norm_weight'],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables,
block_group_size=env_blk_grp_size,
o_proj_weight=self.fused_params['o_proj_weight'],
o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
o_proj_bias=self.fused_params['o_proj_bias'],
o_proj_qzeros=self.fused_params['o_proj_qzeros'],
seq_lens=attn_metadata.seq_lens,
sm_scale=self.scaling,
num_attention_heads=self.total_num_heads,
num_key_value_heads=total_num_kv_heads,
flash_attention=is_decode, # decode use flash_atten by default
is_decode=is_decode,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos,
output_opt=attn_outs,
res_opt=residual)
# debug_qwen3_moe_attention_prefill(hidden_states=hidden_states,
# residual=residual,
# attn_outs=attn_outs,
# fused_params=self.fused_params,
# attn_metadata=attn_metadata,
# is_decode=is_decode,
# sin_cache=sin_cache,
# cos_cache=cos_cache,
# kv_cache=kv_cache,
# env_blk_grp_size=env_blk_grp_size,
# scaling=self.scaling,
# total_num_heads=self.total_num_heads,
# total_num_kv_heads=self.total_num_kv_heads,
# world_size=get_tp_group().world_size,
# rank=get_tp_group().rank_in_group,
# group_id=get_tp_group().group_id,
# dev_info=get_tp_group().rank_device_infos)
if residual is None:
attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
else:
res_out = attn_outs[1]
attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
return attn_out, res_out
else:
# orig code
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm.forward_native(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm.forward_native(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen3MoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import Qwen3Moe_WeightCapture
self.weight_capture = Qwen3Moe_WeightCapture(self.layers, self.start_layer, self.end_layer)
self.layer_nums = self.end_layer - self.start_layer
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# the fused decoder-layer path currently supports only fp8 block-quantized models
use_default_layer = self.weight_capture.support_fused_weights and USE_DECODER_LAYER_FUSE_MODE
# print('Qwen3MoeModel attn_metadata', attn_metadata)
if isinstance(attn_metadata_all, dict):
# is_decode = attn_metadata_all['test'].prefill_metadata is None
# attn_metadata = attn_metadata_all['test']
attn_metadata = next(iter(attn_metadata_all.values()))
is_decode = attn_metadata.prefill_metadata is None
else:
is_decode = attn_metadata_all.prefill_metadata is None
attn_metadata = attn_metadata_all
if use_default_layer and is_decode:
from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode
layer0 = self.layers[self.start_layer]
cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode=True)
for i in range(0, self.layer_nums):
layer = self.layers[i + self.start_layer]
kv_cache = layer.self_attn.attn.kv_cache[forward_context.virtual_engine]
num_kv_heads = max(1, layer.self_attn.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, layer.self_attn.head_dim)
total_num_kv_heads = layer.self_attn.total_num_kv_heads
if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % layer.self_attn.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
hidden_states, residual = qwen3_fuse_attention_moe_decode(hidden_states, residual,
hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables,
block_group_size=env_blk_grp_size,
o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
seq_lens_num=attn_metadata.seq_lens,
sm_scale=layer.self_attn.scaling,
num_attention_heads=layer.self_attn.total_num_heads,
num_key_value_heads=total_num_kv_heads,
flash_attentiton=True,
is_decode=True,
reduce_result=True,
# moe
rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
# dist
world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
rank=self.weight_capture.layer_mapper.dist_args._1_rank,
group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
else:
layer0 = self.layers[self.start_layer]
cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode)
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual, cos_cache, sin_cache )
if deepstack_input_embeds is not None and i in range(0, len(deepstack_input_embeds)):
if isinstance(deepstack_input_embeds, IntermediateTensors):
hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{i}"]
elif isinstance(deepstack_input_embeds, torch.Tensor):
hidden_states = hidden_states + deepstack_input_embeds[i]
else:
raise ValueError(f'unsupported type: {type(deepstack_input_embeds)}')
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states, residual)
return hidden_states
class Qwen3MoeForCausalLM(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds = None,
) -> Union[torch.Tensor, IntermediateTensors]:
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
if attn_metadata.prefill_metadata is not None:
from .memory.memory_recycling import alloc_memory_recycler
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
if hasattr(attn_metadata, 'num_prefill_tokens'):
tokens = attn_metadata.num_prefill_tokens
else:
tokens = attn_metadata.prefill_metadata.num_prefill_tokens
vllm_model_mode = "qwen3_moe"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
vllm_model_mode = config_infos
if get_tp_group().rank_in_group == 0:
memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
logger.info(memory_infos)
if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size, dtype=self.lm_head.weight.dtype):
logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, deepstack_input_embeds)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
from .memory.memory_recycling import init_huge_memory_allocator
from .vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
# default is "qwen3_moe"; the config manager can override the model name
model_name = "qwen3_moe"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
model_name = config_infos
if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
logger.warning("init huge memory allocator fail. prefill memory recycling will disable")
from vllm.model_executor.models.utils import AutoWeightsLoader
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

View File

@@ -0,0 +1,362 @@
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
from vllm.logger import init_logger
from .hf_processor.qwenvl_processor import Qwen3VLProcessorWithVacc
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
from .vars import USE_FUSED_QWEN_ATTENTION
# from vacc_tools.trace_logger import get_trace_api
# trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
# get_trace_api("Qwen3vl")
# )
logger = init_logger(__name__)
class Qwen3_VisionPatchEmbed(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self.proj, 'bias') and self.proj.bias is not None:
return torch.nn.functional.linear(x, self.proj.weight.view(self.hidden_size, -1), self.proj.bias)
return torch.matmul(x, self.proj.weight.view(self.hidden_size, -1).T)
class Qwen3_VisionBlock(nn.Module):
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | list[torch.Tensor],
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
if USE_FUSED_QWEN_ATTENTION:
assert isinstance(rotary_pos_emb, list), "qwen3vl vit-attention need rotary_pos_emb is list[torch.Tensor]"
total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304
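# Ask the fused kernel to fold the TP all-reduce in only when the per-step payload is under 4 MB;
# larger activations are reduced with an external all_reduce after the kernel (see below).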
# hidden_states = self.norm1(x)
attn_outs = torch.vacc.fuse_atten_vit(
hidden_states=x.view(-1, x.shape[-1]),
hidden_states_norm_weight = self.norm1.weight,
hidden_states_norm_bias = self.norm1.bias,
# hidden_states_norm_weight = torch.Tensor(),
# hidden_states_norm_bias = torch.Tensor(),
qkv_proj_weight=self.attn.qkv.weight,
qkv_proj_bias=self.attn.qkv.bias,
sin_cache=rotary_pos_emb[0],
cos_cache=rotary_pos_emb[1],
o_proj_weight=self.attn.proj.weight,
o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
seq_lens=cu_seqlens,
sm_scale=-1,
num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
flash_attention=True,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos
)
attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
attn_out = attn_out.view(x.shape)
x = x + attn_out
else:
x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens)
x = x + self.mlp(self.norm2(x))
return x
class Qwen3_VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw):
if USE_FUSED_QWEN_ATTENTION:
try:
from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
return rot_pos_emb_qwenvl(grid_thw, self.hidden_size, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
except Exception as e:
logger.error(f"rot_pos_emb fused ops run fail, e:{e}")
pos_ids = []
# Support both Tensor and list inputs for DP path
if isinstance(grid_thw, list):
grid_list = grid_thw
max_grid_size = max(max(h, w) for _, h, w in grid_list)
else:
grid_list = grid_thw.tolist()
max_grid_size = int(grid_thw[:, 1:].max().item())
for t, h, w in grid_list:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def fast_pos_embed_interpolate(self,
grid_thw: list[list[int]]) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
m_size = self.spatial_merge_size
hidden_dim = self.pos_embed.embedding_dim
try:
from torch_vacc.vacc.custom_qwen3_ops import fast_pos_embed_interpolate_qwenvl
return fast_pos_embed_interpolate_qwenvl(self.pos_embed.weight, grid_thw, num_grid_per_side, m_size, hidden_dim)
except Exception as e:
logger.error(f"fast_pos_embed_interpolate fused ops run fail, e:{e}")
outputs = []
for t, h, w in grid_thw:
h_idxs = torch.linspace(0,
num_grid_per_side - 1,
h,
dtype=torch.float32,
device=self.device)
w_idxs = torch.linspace(0,
num_grid_per_side - 1,
w,
dtype=torch.float32,
device=self.device)
h_floor = h_idxs.to(torch.long)
w_floor = w_idxs.to(torch.long)
h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
dh = h_idxs - h_floor
dw = w_idxs - w_floor
# Create meshgrid view for all h, w vars
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor,
w_floor,
indexing='ij')
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil,
w_ceil,
indexing='ij')
h_floor_grid_idx = h_floor_grid * num_grid_per_side
h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11 = dh_grid * dw_grid
w10 = dh_grid - w11
w01 = dw_grid - w11
w00 = 1 - dh_grid - dw_grid + w11
idx00 = h_floor_grid_idx + w_floor_grid
idx01 = h_floor_grid_idx + w_ceil_grid
idx10 = h_ceil_grid_idx + w_floor_grid
idx11 = h_ceil_grid_idx + w_ceil_grid
indices = torch.stack([idx00, idx01, idx10, idx11],
dim=0).reshape(4, -1)
weights = torch.stack([w00, w01, w10, w11],
dim=0).reshape(4, -1, 1)
weights = weights.to(dtype=self.dtype, device=self.device)
embeds = self.pos_embed(indices)
weighted_embeds = embeds * weights
p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
combined = p0 + p1 + p2 + p3
combined = combined.view(h * w, hidden_dim)
repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
repeated = repeated.view(t, h // m_size, m_size, w // m_size,
m_size, hidden_dim)
repeated = repeated.permute(0, 1, 3, 2, 4,
5).reshape(-1, hidden_dim)
outputs.append(repeated)
return torch.cat(outputs, dim=0)
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
grid_thw_tensor = torch.tensor(grid_thw,
dtype=torch.int32)
cu_seqlens = torch.repeat_interleave(
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
grid_thw_tensor[:, 0]).cumsum(
dim=0,
dtype=grid_thw_tensor.dtype
if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
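# cu_seqlens holds cumulative patch counts per frame (h*w repeated t times for each grid entry),
# with a leading zero so consecutive entries delimit each frame's token range.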
hidden_states = hidden_states.unsqueeze(1)
if isinstance(rotary_pos_emb, torch.Tensor):
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
if USE_FUSED_QWEN_ATTENTION:
max_seqlen, seqlens = None, None
cu_seqlens = cu_seqlens.tolist()
else:
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(
layer_num)
deepstack_feature = self.deepstack_merger_list[
deepstack_merger_idx](hidden_states)
deepstack_feature_lists.append(deepstack_feature)
hidden_states = self.merger(hidden_states)
hidden_states = torch.cat(
[hidden_states] + deepstack_feature_lists,
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states
class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
deepstack_input_embeds = None
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if self.use_deepstack:
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501
input_ids, inputs_embeds, multimodal_embeddings)
self._set_deepstack_input_embeds(deepstack_input_embeds)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
# commented out here to avoid the deepstack_input_embeds copy
# if self.use_deepstack:
# if deepstack_input_embeds is None:
# deepstack_input_embeds = torch.zeros_like(
# inputs_embeds).unsqueeze(0).repeat(
# self.deepstack_num_level, 1, 1).contiguous()
# self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
return  # patched: skip clearing deepstack_input_embeds as an optimization
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
class Qwen3VLProcessingInfo():
def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessorWithVacc:
processor = self.ctx.get_hf_processor(
Qwen3VLProcessorWithVacc,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
return processor
def get_image_processor(self,
**kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
return self.get_hf_processor(**kwargs).image_processor
# def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
# return self.get_hf_processor(**kwargs).video_processor
class Qwen3_VisionPatchMerger():
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
x = self.norm(x.view(-1, self.hidden_size))
else:
x = self.norm(x).view(-1, self.hidden_size)
try:
from torch_vacc.vacc import patch_merger_vision
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias
hidden_states = patch_merger_vision(x,
self.linear_fc1.weight, self.linear_fc2.weight,
self.linear_fc1.bias, fc2_bias,
0) #0 is gelu, 1 is silu
return tensor_model_parallel_all_reduce(hidden_states)
except Exception as e:
logger.error(f"merge patch fused vision mlp run fail, cased by:{e}")
x_parallel, _ = self.linear_fc1(x)
x_parallel = self.act_fn(x_parallel)
out, _ = self.linear_fc2(x_parallel)
return out
class Qwen3_VisionMLP():
def forward(self, x: torch.Tensor):
try:
from torch_vacc.vacc import fuse_mlp_vision
hiddens_shape = x.shape
tp_rank_id = get_tp_group().rank_in_group
fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias
hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
self.linear_fc1.weight, self.linear_fc2.weight,
self.linear_fc1.bias, fc2_bias,
0) #0 is gelu, 1 is silu
return tensor_model_parallel_all_reduce(hidden_states).view(hiddens_shape)
except Exception as e:
logger.error(f"qwen3vl fused vision mlp run fail, cased by:{e}")
return self.linear_fc2(self.act_fn(self.linear_fc1(x)))

View File

@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from vllm.model_executor.models.bert import _decode_token_type_ids
class RobertaEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
) -> torch.Tensor:
token_type_ids = _decode_token_type_ids(input_ids)
inputs_embeds = self.word_embeddings(input_ids)
# position_embeddings = self.position_embeddings(position_ids)
# token_type_embeddings = self.token_type_embeddings(token_type_ids)
# embeddings = inputs_embeds + token_type_embeddings + position_embeddings
# embeddings = self.LayerNorm(embeddings)
embeddings = torch.vacc.fuse_bge_embedding_stage1(inputs_embeds, position_ids, self.position_embeddings.weight, token_type_ids, self.token_type_embeddings.weight, self.LayerNorm.weight, self.LayerNorm.bias, self.LayerNorm.eps)
return embeddings

View File

@@ -0,0 +1,58 @@
import os
# support Q,KV Gen with TP
USE_PARALLEL_Q_KV_GEN = True
# support Merge Q,KV Gen, Q,QR weights Merge
USE_MERGE_Q_KV_GEN_AND_Q_QR = True
# Support FP8 Weights for WQ,QR
W_Q_W_QR_WUV_WUK_USE_FP8 = True
# fused prefill
USE_FUSED_PREFILL = True
# fused prefill stage1
USE_FUSED_PREFILL_STAGE1 = True
# All Request Seq Lens
DO_SEQ_LENS = 0
def update_seqence_length(seq_num):
global DO_SEQ_LENS
DO_SEQ_LENS = seq_num
USE_DS3_SAMPLER = int(os.getenv("USE_DS3_SAMPLER", 1))
USE_DS3_SAMPLER_OP = int(os.getenv("USE_DS3_SAMPLER_OP", 1))
# cut prefill seq len
CUT_PREFILL_SEQ_LEN = int(os.getenv("CUT_PREFILL_SEQ_LEN", -1))
# llm max prefill seq len
LLM_MAX_PREFILL_SEQ_LEN = int(os.getenv("LLM_MAX_PREFILL_SEQ_LEN", 56 * 1024))
# All Fused Decode, default is cpu loop
USE_DECODER_LAYER_FUSE_MODE = int(os.getenv("USE_DECODER_LAYER_FUSE_MODE", 1))
# Fused all layers, use cmcu loop
FUSE_ALL_DECODER_LAYERS = int(os.getenv("FUSE_ALL_DECODER_LAYERS", 1))
# whether to use flash attention (default: 1)
USE_FLASH_ATTENTION = int(os.getenv("USE_FLASH_ATTENTION", 1))
# transpose gptq weight KN => NK
TRANSPOSE_GPTQ_WEIGHT = True
# qwen fused attention
USE_FUSED_QWEN_ATTENTION = int(os.getenv("USE_FUSED_QWEN_ATTENTION", 1))
# support MTP eh_proj with TP
USE_PARALLEL_MTP_EH_PROJ = int(os.getenv("USE_PARALLEL_MTP_EH_PROJ", 1))
# kv_cache group size
BLOCK_GROUP_SIZE = int(os.getenv("BLOCK_GROUP_SIZE", 8192))
# bert fused attention
USE_FUSED_BERT_ATTENTION = int(os.getenv("USE_FUSED_BERT_ATTENTION", 1))
# fused mlp vision
USE_FUSED_MLP_VISION = int(os.getenv("USE_FUSED_MLP_VISION", 1))

View File

@@ -0,0 +1,491 @@
import torch
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
DeepseekV2MLAAttention,
DeepseekV2MLP,
DeepseekV2MoE)
from ..vars import *
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
OUTPUT_ARGS_LOGS = False
class DistributedArgs():
def __init__(self):
self._0_world_size = 32
self._1_rank = -1
self._2_group_id = 0
self._3_dev_info = []
def logs(self):
print("dist self._0_world_size = " , self._0_world_size)
print("dist self._1_rank = " , self._1_rank)
print("dist self._2_group_id = " , self._2_group_id)
print("dist self._3_dev_info = " , self._3_dev_info)
class AttenArgs():
def __init__(self):
self._a_hidden_states_norm_weight = []
self._0_merge_q_kv_weights = []  # fused Q,KV weights
self._1_merge_q_kv_scale_inv = []  # fused Q,KV scales
self._2_q_a_layernorm_weight = []
self._3_W_Q = []
self._4_W_Q_scales = []
self._5_W_UK = []
self._6_W_UK_scales = []
self._7_W_QR = []
self._8_W_QR_scales = []
self._9_kv_a_layernorm_weight = []
self._10_sin_cache = None
self._11_cos_cache = None
self._12_slot_mapping = None
self._13_kv_cache = None
self._14_block_tables = None
self._15_env_blk_grp_size = env_blk_grp_size
self._16_W_UV = []
self._17_W_UV_scales = []
self._18_o_proj_weight =[]
self._19_o_proj_weight_scale_inv = []
# mla params
self._20_seq_lens = []
self._21_sm_scale = 0.0
self._22_head_num = 128
def logs(self):
print("mla _20_seq_lens block size is:", self._20_seq_lens)
print("mla _21_sm_scale block size is:", self._21_sm_scale)
print("mla _22_head_num block size is:", self._22_head_num)
class MlpArgs():
def __init__(self):
#mlp params
self._0_mlp_rms_weight = []
self._1_mlp_w13 = []
self._2_mlp_w2 = []
self._3_mlp_w13_scale = []
self._4_mlp_w2_scale = []
self._5_mlp_w13_block_size = []
self._6_mlp_w2_block_size = []
def logs(self):
print("mlp _5_mlp_w13_block_size block size is:", self._5_mlp_w13_block_size)
print("mlp _6_mlp_w2_block_size block size is:", self._6_mlp_w2_block_size)
class MoeArgs():
def __init__(self):
#moe params
self._0_moe_rms_weight = []
self._1_moe_share_mlp_w13 = []
self._2_moe_share_mlp_w2 = []
self._3_moe_share_mlp_w13_scale = []
self._4_moe_share_mlp_w2_scale = []
self._5_moe_w13 = []
self._6_moe_w2 = []
self._7_moe_w13_scale = []
self._8_moe_w2_scale = []
self._9_gate_weight = []
self._10_moe_bias = []
self._11_moe_mlp_w13_block_size = []
self._12_moe_mlp_w2_block_size = []
self._13_moe_w13_block_size = []
self._14_moe_w2_block_size = []
def logs(self):
print("moe _11_moe_mlp_w13_block_size block size is:", self._11_moe_mlp_w13_block_size)
print("moe _12_moe_mlp_w2_block_size block size is:", self._12_moe_mlp_w2_block_size)
print("moe _13_moe_w13_block_size block size is:", self._13_moe_w13_block_size)
print("moe _14_moe_w2_block_size block size is:", self._14_moe_w2_block_size)
class WeightMapper():
def __init__(self):
self.attn_args = AttenArgs()
self.mlp_args = MlpArgs()  # captured from the 3 dense (MLA + MLP) layers
self.moe_args = MoeArgs()  # captured from the 58 (MLA + MoE) layers
self.dist_args = DistributedArgs()
# 1. load the weights
# 2. precompute the dequant block sizes
# 3. cache & extract the parameters
class DeepseekWeightCapture():
def __init__(self, layer: torch.nn.ModuleList,
start: int,
end: int):
self.layer_mlp = WeightMapper()
self.layer_moe = WeightMapper()
self.sin_cache_all = None
self.cos_cache_all = None
self.mlp_nums = 3
self.moe_nums = end - self.mlp_nums
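# The first mlp_nums (3) layers are dense MLA+MLP layers and the remaining layers are MLA+MoE,
# so their weights are captured into separate mappers.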
self.start_idx = start
self.end_idx = end
for i in range(start, end):
if i < self.mlp_nums:
self.capture_deepseek_mla_attn_weights(layer[i], self.layer_mlp.attn_args)
self.capture_deepseek_mlp_weights(layer[i])
else:
self.capture_deepseek_mla_attn_weights(layer[i], self.layer_moe.attn_args)
self.capture_deepseek_moe_weights(layer[i])
if OUTPUT_ARGS_LOGS:
self.layer_mlp.attn_args.logs()
self.layer_mlp.mlp_args.logs()
self.layer_moe.attn_args.logs()
self.layer_moe.moe_args.logs()
from vllm.distributed import get_tp_group
tp_group = get_tp_group()
self.layer_mlp.dist_args._0_world_size = tp_group.world_size
self.layer_mlp.dist_args._1_rank = tp_group.rank_in_group
self.layer_mlp.dist_args._2_group_id = tp_group.group_id
self.layer_mlp.dist_args._3_dev_info = tp_group.rank_device_infos
self.layer_moe.dist_args._0_world_size = tp_group.world_size
self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
self.layer_moe.dist_args._2_group_id = tp_group.group_id
self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos
def capture_deepseek_mlp_weights(self, module: DeepseekV2DecoderLayer):
assert isinstance(module.mlp, DeepseekV2MLP)
mlp = module.mlp
rms_norm = module.post_attention_layernorm
w13_weight = mlp.gate_up_proj.weight
w2_weight = mlp.down_proj.weight
w13_wscale = mlp.gate_up_proj.weight_scale_inv
w13_iscale = mlp.gate_up_proj.input_scale
w2_wscale = mlp.down_proj.weight_scale_inv
w2_iscale = mlp.down_proj.input_scale
w13_block_size0, w13_block_size1 = mlp.gate_up_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = mlp.gate_up_proj.quant_method.scale_n, mlp.gate_up_proj.quant_method.scale_k
assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
w13_block_size0 = w13_block_size0 // scale_n
w13_block_size1 = w13_block_size1 // scale_k
w2_block_size0, w2_block_size1 = mlp.down_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = mlp.down_proj.quant_method.scale_n, mlp.down_proj.quant_method.scale_k
assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
w2_block_size0 = w2_block_size0 // scale_n
w2_block_size1 = w2_block_size1 // scale_k
self.layer_mlp.mlp_args._0_mlp_rms_weight.append(rms_norm.weight)
self.layer_mlp.mlp_args._1_mlp_w13.append(w13_weight)
self.layer_mlp.mlp_args._2_mlp_w2.append(w2_weight)
self.layer_mlp.mlp_args._3_mlp_w13_scale.append(w13_wscale)
self.layer_mlp.mlp_args._4_mlp_w2_scale.append(w2_wscale)
self.layer_mlp.mlp_args._5_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
self.layer_mlp.mlp_args._6_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
assert isinstance(module.mlp, DeepseekV2MoE)
share_expert_layer = module.mlp.shared_experts
experts_layer = module.mlp.experts
rms_norm = module.post_attention_layernorm
gate = module.mlp.gate
w13_weight = share_expert_layer.gate_up_proj.weight
w2_weight = share_expert_layer.down_proj.weight
w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
# w13_iscale = share_expert_layer.gate_up_proj.input_scale
w2_wscale = share_expert_layer.down_proj.weight_scale_inv
# w2_iscale = share_expert_layer.down_proj.input_scale
w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
w13_block_size0 = w13_block_size0 // scale_n
w13_block_size1 = w13_block_size1 // scale_k
w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
w2_block_size0 = w2_block_size0 // scale_n
w2_block_size1 = w2_block_size1 // scale_k
hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
block_size0, block_size1 = (
hidden_dims // hidden_blocks,
inter_dims // inter_blocks,
)
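# Routed-expert block sizes are recovered from the ratio between the expert weight dims and their scale dims.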
self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]
def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
weight_mapper: AttenArgs):
if(self.sin_cache_all is None):
self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache
weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)
fused_params = {}
if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
for name, param in module.self_attn.q_a_proj.named_parameters():
fused_params['q_a_proj_' + name] = param
for name, param in module.self_attn.q_a_layernorm.named_parameters():
fused_params['q_a_layernorm_' + name] = param
if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
fused_params['kv_a_proj_' + name] = param
for name, param in module.self_attn.kv_a_layernorm.named_parameters():
fused_params['kv_a_layernorm_' + name] = param
for name, param in module.self_attn.o_proj.named_parameters():
fused_params['o_proj_' + name] = param
weight_mapper._15_env_blk_grp_size = env_blk_grp_size
# init sin,cos cache
mla_params = module.self_attn.mla_attn.impl.extract_weights()
fused_params = {**fused_params, **mla_params}
weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
weight_mapper._3_W_Q.append(fused_params['W_Q'])
weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
weight_mapper._5_W_UK.append(fused_params['W_UK'])
weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
weight_mapper._7_W_QR.append(fused_params['W_QR'])
weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
#weight_mapper._10_sin_cache.append(None)
#weight_mapper._11_cos_cache.append(None)
#weight_mapper._12_slot_mapping.append(None)
#weight_mapper._13_kv_cache.append(None)
#weight_mapper._14_block_tables.append(None)
# weight_mapper._15_env_blk_grp_size.append(None)
weight_mapper._16_W_UV.append(fused_params['W_UV'])
weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
weight_mapper._20_seq_lens = None
weight_mapper._21_sm_scale = module.self_attn.scaling
weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size
# could be optimized: the C++ side only needs Tensors here
def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
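# Refresh the per-step runtime inputs (RoPE caches for the current positions, slot mapping,
# per-layer-group KV caches, block tables) shared by both the dense and MoE weight mappers.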
positions = [i - 1 for i in seq_lens]
cos_cache = [self.cos_cache_all[i] for i in positions]
sin_cache = [self.sin_cache_all[i] for i in positions]
self.layer_mlp.attn_args._10_sin_cache = sin_cache
self.layer_mlp.attn_args._11_cos_cache = cos_cache
self.layer_moe.attn_args._10_sin_cache = sin_cache
self.layer_moe.attn_args._11_cos_cache = cos_cache
self.layer_mlp.attn_args._20_seq_lens = seq_lens
self.layer_moe.attn_args._20_seq_lens = seq_lens
self.layer_mlp.attn_args._13_kv_cache = kv_caches_dense_layer
self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer
self.layer_mlp.attn_args._12_slot_mapping = slot_mapping
self.layer_mlp.attn_args._14_block_tables = block_tables
self.layer_moe.attn_args._12_slot_mapping = slot_mapping
self.layer_moe.attn_args._14_block_tables = block_tables
# for i in range(self.mlp_nums):
# if i < self.end_idx:
# self.layer_mlp.attn_args._12_slot_mapping[i] = slot_mapping
# self.layer_mlp.attn_args._14_block_tables[i] = block_tables
# for i in range(self.moe_nums):
# if i < self.end_idx:
# self.layer_moe.attn_args._12_slot_mapping[i] = slot_mapping
# self.layer_moe.attn_args._14_block_tables[i] = block_tables

    def logs(self):
        print("current layer mlp attn: \n")
        self.layer_mlp.attn_args.logs()
        self.layer_mlp.dist_args.logs()
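
    # A minimal per-step usage sketch for an instance of this capture class
    # (hypothetical names `capture`, `seq_lens`, `slot_mapping`, `kv_caches_dense`,
    # `kv_caches_moe` and `block_tables` are assumptions, not symbols defined here):
    #
    #   capture.update_attn_args(seq_lens, slot_mapping, kv_caches_dense,
    #                            kv_caches_moe, block_tables)
    #   capture.logs()  # dump the captured attention / distributed arguments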

class DeepseekMTPWegitCapture():
    # Unlike the full DeepSeek weight capture, MTP has only one DeepseekDecoderLayer,
    # and that layer is an MoE layer.
    def __init__(self, layer: torch.nn.Module):
        self.layer_moe = WeightMapper()
        self.sin_cache_all = None
        self.cos_cache_all = None
        self.capture_deepseek_mla_attn_weights(layer, self.layer_moe.attn_args)
        self.capture_deepseek_moe_weights(layer)
        if OUTPUT_ARGS_LOGS:
            self.layer_moe.attn_args.logs()
            self.layer_moe.moe_args.logs()
        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_moe.dist_args._0_world_size = tp_group.world_size
        self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
        self.layer_moe.dist_args._2_group_id = tp_group.group_id
        self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
        assert isinstance(module.mlp, DeepseekV2MoE)
        share_expert_layer = module.mlp.shared_experts
        experts_layer = module.mlp.experts
        rms_norm = module.post_attention_layernorm
        gate = module.mlp.gate
        w13_weight = share_expert_layer.gate_up_proj.weight
        w2_weight = share_expert_layer.down_proj.weight
        w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
        w2_wscale = share_expert_layer.down_proj.weight_scale_inv
        w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k
        w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k
        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )
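        # A worked example with hypothetical shapes (not taken from this module):
        # if experts_layer.w13_weight were [n_experts, 4096, 512] and
        # experts_layer.w13_weight_scale_inv were [n_experts, 32, 4], the recovered
        # per-block quantization granularity would be (4096 // 32, 512 // 4) = (128, 128).
        # Likewise, with weight_block_size = [128, 128] and scale_n = scale_k = 1,
        # the shared-expert block sizes computed above stay [128, 128].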
        self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
        self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
        self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
        self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
        self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
        self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
        self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
        self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
        self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
        self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
        self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
        self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
        self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
        self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]

    def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
                                          weight_mapper: AttenArgs):
        if self.sin_cache_all is None:
            self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
            self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache
        weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)
        fused_params = {}
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.q_a_proj.named_parameters():
                fused_params['q_a_proj_' + name] = param
        for name, param in module.self_attn.q_a_layernorm.named_parameters():
            fused_params['q_a_layernorm_' + name] = param
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
                fused_params['kv_a_proj_' + name] = param
        for name, param in module.self_attn.kv_a_layernorm.named_parameters():
            fused_params['kv_a_layernorm_' + name] = param
        for name, param in module.self_attn.o_proj.named_parameters():
            fused_params['o_proj_' + name] = param
        import os
        self._15_env_blk_grp_size = env_blk_grp_size
        # init sin, cos cache
        mla_params = module.self_attn.mla_attn.impl.extract_weights()
        fused_params = {**fused_params, **mla_params}
        weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
        weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
        weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
        weight_mapper._3_W_Q.append(fused_params['W_Q'])
        weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
        weight_mapper._5_W_UK.append(fused_params['W_UK'])
        weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
        weight_mapper._7_W_QR.append(fused_params['W_QR'])
        weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
        weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
        # weight_mapper._10_sin_cache.append(None)
        # weight_mapper._11_cos_cache.append(None)
        # weight_mapper._12_slot_mapping.append(None)
        # weight_mapper._13_kv_cache.append(None)
        # weight_mapper._14_block_tables.append(None)
        # weight_mapper._15_env_blk_grp_size.append(None)
        weight_mapper._16_W_UV.append(fused_params['W_UV'])
        weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
        weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
        weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
        weight_mapper._20_seq_lens = None
        weight_mapper._21_sm_scale = module.self_attn.scaling
        weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size

    # Could be optimized: on the C++ side only Tensors are needed.
    def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
        positions = [i - 1 for i in seq_lens]
        cos_cache = [self.cos_cache_all[i] for i in positions]
        sin_cache = [self.sin_cache_all[i] for i in positions]
        self.layer_moe.attn_args._10_sin_cache = sin_cache
        self.layer_moe.attn_args._11_cos_cache = cos_cache
        self.layer_moe.attn_args._20_seq_lens = seq_lens
        self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer
        self.layer_moe.attn_args._12_slot_mapping = slot_mapping
        self.layer_moe.attn_args._14_block_tables = block_tables

    def logs(self):
        print("current layer moe attn: \n")
        self.layer_moe.attn_args.logs()
        self.layer_moe.dist_args.logs()
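
# A minimal construction sketch for the MTP capture above (hypothetical names such
# as `mtp_layer`, `seq_lens`, `slot_mapping`, `kv_caches` and `block_tables` are
# assumptions, not symbols defined in this file):
#
#   capture = DeepseekMTPWegitCapture(mtp_layer)   # single MoE DeepseekDecoderLayer
#   capture.update_attn_args(seq_lens, slot_mapping,
#                            kv_caches_dense_layer=None,
#                            kv_caches_moe_layer=kv_caches,
#                            block_tables=block_tables)
#   capture.logs()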

View File

@@ -0,0 +1,154 @@
import torch

from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer,
                                                  Qwen3MoeMLP)

class Qwen3Moe_DistributedArgs():

    def __init__(self):
        self._0_world_size = 32
        self._1_rank = -1
        self._2_group_id = 0
        self._3_dev_info = []

    def __repr__(self):
        dist_infos = f"[dist] world_size = {self._0_world_size} \n" \
                     + f"[dist] rank = {self._1_rank} \n" \
                     + f"[dist] group_id = {self._2_group_id} \n" \
                     + f"[dist] dev_info = {self._3_dev_info}"
        return dist_infos

class Qwen3Moe_AttenArgs():

    def __init__(self):
        self._0_input_layernorm_weight = []
        self._1_qkv_proj_weight = []
        self._2_qkv_proj_weight_scale = []
        self._3_qkv_proj_bias = []
        self._4_qkv_proj_qzeros = []
        self._5_q_norm_weight = []
        self._6_k_norm_weight = []
        self._7_sin_cache = None
        self._8_cos_cache = None
        self._9_slot_mapping = None
        self._10_kv_cache = None
        self._11_block_tables = None
        self._12_block_group_size = None
        self._13_o_proj_weight = []
        self._14_o_proj_weight_scale = []
        self._15_o_proj_bias = []
        self._16_o_proj_qzeros = []
        self._17_seq_lens = None
        self._18_sm_scale = None
        self._19_num_attention_heads = None
        self._20_num_key_value_heads = None

    def __repr__(self):
        attn_infos = "[qwen attn] 21 args \n" \
                     + f"[qwen attn] weight counts: {len(self._0_input_layernorm_weight)}"
        return attn_infos

class Qwen3Moe_MoeArgs():

    def __init__(self):
        # moe params
        self._0_rms_norm_weight = []
        self._1_w13_weight = []
        self._2_w2_weight = []
        self._3_w13_weight_scale_inv = []
        self._4_w2_weight_scale_inv = []
        self._5_gate_weight = []
        self._6_w13_block_size = None
        self._7_w2_block_size = None

    def __repr__(self):
        moe_infos = f"[moe] w13_block_size: {self._6_w13_block_size} \n" \
                    + f"[moe] w2_block_size: {self._7_w2_block_size} \n" \
                    + f"[moe] weight counts: {len(self._1_w13_weight)}"
        return moe_infos

class Qwen3Moe_WeightMapper:

    def __init__(self):
        self.attn_args = Qwen3Moe_AttenArgs()
        self.moe_args = Qwen3Moe_MoeArgs()
        self.dist_args = Qwen3Moe_DistributedArgs()

class Qwen3Moe_WeightCapture():

    def __init__(self, layers: torch.nn.ModuleList,
                 start: int,
                 end: int):
        self.layer_mapper = Qwen3Moe_WeightMapper()
        # qwen3 only supports fp8 for now
        self.support_fused_weights = False
        for i in range(start, end):
            layer = layers[i]
            self.capture_attn_weights(layer)
            self.capture_moe_weights(layer)
        # Register the multi-device (tensor-parallel) environment info.
        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_mapper.dist_args._0_world_size = tp_group.world_size
        self.layer_mapper.dist_args._1_rank = tp_group.rank_in_group
        self.layer_mapper.dist_args._2_group_id = tp_group.group_id
        self.layer_mapper.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_attn_weights(self, layer):
        from vllm_vacc.vllm.model_executor.models.qwen3_moe import set_fused_params
        # Register the fused-operator parameters.
        fused_params = {}
        fused_params['input_layernorm_weight'] = layer.input_layernorm.weight
        fused_params['q_norm_weight'] = layer.self_attn.q_norm.weight
        fused_params['k_norm_weight'] = layer.self_attn.k_norm.weight
        set_fused_params(fused_params, layer.self_attn.qkv_proj.quant_method, layer.self_attn.qkv_proj, 'qkv_proj')
        set_fused_params(fused_params, layer.self_attn.o_proj.quant_method, layer.self_attn.o_proj, 'o_proj')
        self.support_fused_weights = hasattr(layer.mlp.experts.quant_method, 'quant_config') and \
            hasattr(layer.mlp.experts.quant_method.quant_config, 'weight_block_size')
        if not hasattr(layer.self_attn, "fused_params"):
            layer.self_attn.fused_params = fused_params
        self.layer_mapper.attn_args._0_input_layernorm_weight.append(fused_params['input_layernorm_weight'])
        self.layer_mapper.attn_args._1_qkv_proj_weight.append(fused_params['qkv_proj_weight'])
        self.layer_mapper.attn_args._2_qkv_proj_weight_scale.append(fused_params['qkv_proj_weight_scale'])
        self.layer_mapper.attn_args._3_qkv_proj_bias.append(fused_params['qkv_proj_bias'])
        self.layer_mapper.attn_args._4_qkv_proj_qzeros.append(fused_params['qkv_proj_qzeros'])
        self.layer_mapper.attn_args._5_q_norm_weight.append(fused_params['q_norm_weight'])
        self.layer_mapper.attn_args._6_k_norm_weight.append(fused_params['k_norm_weight'])
        # self.layer_mapper.attn_args._7_sin_cache
        # self.layer_mapper.attn_args._8_cos_cache
        # self.layer_mapper.attn_args._9_slot_mapping
        # self.layer_mapper.attn_args._10_kv_cache
        # self.layer_mapper.attn_args._11_block_tables
        # self.layer_mapper.attn_args._12_block_group_size
        self.layer_mapper.attn_args._13_o_proj_weight.append(fused_params['o_proj_weight'])
        self.layer_mapper.attn_args._14_o_proj_weight_scale.append(fused_params['o_proj_weight_scale'])
        self.layer_mapper.attn_args._15_o_proj_bias.append(fused_params['o_proj_bias'])
        self.layer_mapper.attn_args._16_o_proj_qzeros.append(fused_params['o_proj_qzeros'])
        # self.layer_mapper.attn_args._17_seq_lens
        self.layer_mapper.attn_args._18_sm_scale = layer.self_attn.scaling
        self.layer_mapper.attn_args._19_num_attention_heads = layer.self_attn.total_num_heads
        self.layer_mapper.attn_args._20_num_key_value_heads = layer.self_attn.total_num_kv_heads

    def capture_moe_weights(self, layer: Qwen3MoeDecoderLayer):
        from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
        quant_method = layer.mlp.experts.quant_method if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock) \
            else layer.mlp.down_proj.quant_method
        if not isinstance(quant_method, MoeWNA16Method):
            from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import recompute_moe_layer_blocksize
            recompute_moe_layer_blocksize(layer.mlp.experts)
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_weight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_weight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_weight_scale_inv)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_weight_scale_inv)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
        else:
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_qweight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_qweight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_scales)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_scales)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
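
    # A minimal usage sketch for this capture class (hypothetical names
    # `model_layers`, `start` and `end` are assumptions, and an initialized vLLM
    # tensor-parallel process group is required because __init__ calls get_tp_group()):
    #
    #   capture = Qwen3Moe_WeightCapture(model_layers, start=0, end=len(model_layers))
    #   print(capture.layer_mapper.attn_args)   # "[qwen attn] ..." summary
    #   print(capture.layer_mapper.moe_args)    # "[moe] ..." summary
    #   print(capture.layer_mapper.dist_args)   # "[dist] ..." summary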