init
This commit is contained in:
0
vllm_vacc/vllm/model_executor/models/__init__.py
Normal file
0
vllm_vacc/vllm/model_executor/models/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
133
vllm_vacc/vllm/model_executor/models/bert.py
Normal file
133
vllm_vacc/vllm/model_executor/models/bert.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
from .vars import *
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
    """BERT encoder layer with an optional fully fused vendor (vacc) path.

    When ``USE_FUSED_BERT_ATTENTION`` is enabled, the whole layer
    (QKV projection, self-attention, attention output + LayerNorm,
    intermediate MLP, output + LayerNorm) runs through the vendor op
    ``torch.vacc.fused_attn_bert_allreduce``. Depending on the size of the
    tensor-parallel all-reduce payload, the op is either invoked once
    (``FullStage``) or split into ``AttnOutStage`` / ``InterOutStage`` with
    explicit vccl all-reduces in between. Otherwise the standard
    attention/intermediate/output submodules are used.
    """

    def forward(self, hidden_states: torch.Tensor):
        if USE_FUSED_BERT_ATTENTION:
            tp_group = get_tp_group()
            world_size = tp_group.world_size
            rank = tp_group.rank_in_group
            # Total all-reduce payload across all TP ranks, in bytes.
            total_bytes = hidden_states.numel() * hidden_states.element_size() * world_size

            forward_context: ForwardContext = get_forward_context()
            attn_metadata_all = forward_context.attn_metadata

            # attn_metadata may be a per-layer dict; take the first entry's value.
            if isinstance(attn_metadata_all, dict):
                attn_metadata = attn_metadata_all.items().__iter__().__next__()[1]
            else:
                attn_metadata = attn_metadata_all

            # For the (matmul + bias_add) with TP + all_reduce structure, the bias
            # is passed only to rank 0 so it is not added once per rank after the
            # reduce. In a BERT layer the BertSelfOutput and BertOutput modules
            # have this structure; the corresponding bias parameters are
            # self_bias and output_bias below.
            if total_bytes < 4194304 or world_size == 1:
                # 1. With TP, when the all_reduce input is under 4MB the fused op
                #    below performs the dsp all_reduce internally; at >= 4MB a
                #    limitation requires calling the vccl all_reduce outside.
                # 2. The non-TP case also goes through the fused op below.
                output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
                                                              qkv_weight=self.attention.self.qkv_proj.weight,
                                                              qkv_bias=self.attention.self.qkv_proj.bias,
                                                              self_weight=self.attention.output.dense.weight,
                                                              self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
                                                              self_norm_weight=self.attention.output.LayerNorm.weight,
                                                              self_norm_bias=self.attention.output.LayerNorm.bias,
                                                              intermediate_weight=self.intermediate.dense.weight,
                                                              intermediate_bias=self.intermediate.dense.bias,
                                                              output_weight=self.output.dense.weight,
                                                              output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
                                                              output_norm_weight=self.output.LayerNorm.weight,
                                                              output_norm_bias=self.output.LayerNorm.bias,
                                                              dense_out=torch.Tensor(),
                                                              seqs=attn_metadata.seq_lens,
                                                              vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
                                                              sm_scale=self.attention.self.scaling,
                                                              num_q_heads=self.attention.self.num_heads * world_size,
                                                              num_kv_heads=self.attention.self.num_kv_heads * world_size,
                                                              flash_attention=False,
                                                              reduce_result=True if world_size > 1 else False,
                                                              world_size=world_size,
                                                              rank=rank,
                                                              group_id=tp_group.group_id,
                                                              dev_info=tp_group.rank_device_infos)
            else:
                # Large payload: run only the attention-output stage fused;
                # empty torch.Tensor() placeholders disable the later stages.
                attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
                                                                             qkv_weight=self.attention.self.qkv_proj.weight,
                                                                             qkv_bias=self.attention.self.qkv_proj.bias,
                                                                             self_weight=self.attention.output.dense.weight,
                                                                             self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
                                                                             self_norm_weight=torch.Tensor(),
                                                                             self_norm_bias=torch.Tensor(),
                                                                             intermediate_weight=torch.Tensor(),
                                                                             intermediate_bias=torch.Tensor(),
                                                                             output_weight=torch.Tensor(),
                                                                             output_bias=torch.Tensor(),
                                                                             output_norm_weight=torch.Tensor(),
                                                                             output_norm_bias=torch.Tensor(),
                                                                             dense_out=torch.Tensor(),
                                                                             seqs=attn_metadata.seq_lens,
                                                                             vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
                                                                             sm_scale=self.attention.self.scaling,
                                                                             num_q_heads=self.attention.self.num_heads * world_size,
                                                                             num_kv_heads=self.attention.self.num_kv_heads * world_size,
                                                                             flash_attention=False,
                                                                             reduce_result=False,
                                                                             world_size=world_size,
                                                                             rank=rank,
                                                                             group_id=tp_group.group_id,
                                                                             dev_info=tp_group.rank_device_infos)
                # External (vccl) all-reduce, required for the >= 4MB payload.
                if world_size > 1:
                    attn_out_stage_output = tensor_model_parallel_all_reduce(attn_out_stage_output)
                if USE_FUSED_MLP_VISION:
                    # Residual add + LayerNorm done eagerly, then the fused MLP op.
                    attn_output = self.attention.output.LayerNorm(attn_out_stage_output + hidden_states)
                    inter_out_stage_output = torch.vacc.fuse_mlp_vision(src=attn_output,
                                                                        weights_13=self.intermediate.dense.weight,
                                                                        weights_2=self.output.dense.weight,
                                                                        weights_13_bias=self.intermediate.dense.bias,
                                                                        weights_2_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
                                                                        act_type=0  # gelu
                                                                        )
                else:
                    # Intermediate/output stage of the fused BERT op; dense_out
                    # carries the attention-stage result as the residual input.
                    inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(hidden_states=hidden_states,
                                                                                 qkv_weight=torch.Tensor(),
                                                                                 qkv_bias=torch.Tensor(),
                                                                                 self_weight=torch.Tensor(),
                                                                                 self_bias=torch.Tensor(),
                                                                                 self_norm_weight=self.attention.output.LayerNorm.weight,
                                                                                 self_norm_bias=self.attention.output.LayerNorm.bias,
                                                                                 intermediate_weight=self.intermediate.dense.weight,
                                                                                 intermediate_bias=self.intermediate.dense.bias,
                                                                                 output_weight=self.output.dense.weight,
                                                                                 output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
                                                                                 output_norm_weight=torch.Tensor(),
                                                                                 output_norm_bias=torch.Tensor(),
                                                                                 dense_out=attn_out_stage_output,
                                                                                 seqs=attn_metadata.seq_lens,
                                                                                 vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
                                                                                 sm_scale=self.attention.self.scaling,
                                                                                 num_q_heads=self.attention.self.num_heads * world_size,
                                                                                 num_kv_heads=self.attention.self.num_kv_heads * world_size,
                                                                                 flash_attention=False,
                                                                                 reduce_result=False,
                                                                                 world_size=world_size,
                                                                                 rank=rank,
                                                                                 group_id=tp_group.group_id,
                                                                                 dev_info=tp_group.rank_device_infos)
                if world_size > 1:
                    inter_out_stage_output = tensor_model_parallel_all_reduce(inter_out_stage_output)
                # Final residual add + LayerNorm; the residual source depends on
                # which branch produced the intermediate output above.
                if USE_FUSED_MLP_VISION:
                    output = self.output.LayerNorm(inter_out_stage_output + attn_output)
                else:
                    output = self.output.LayerNorm(inter_out_stage_output + attn_out_stage_output)
        else:
            # Unfused reference path through the regular submodules.
            attn_output = self.attention(hidden_states)
            intermediate_output = self.intermediate(attn_output)
            output = self.output(intermediate_output, attn_output)
        return output
292
vllm_vacc/vllm/model_executor/models/deepseek_mtp.py
Normal file
292
vllm_vacc/vllm/model_executor/models/deepseek_mtp.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Iterable, Set, Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
from .vars import *
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictorLayer as DeepSeekMultiTokenPredictorLayerOrig
|
||||
|
||||
from vllm.distributed import get_tp_group
|
||||
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||||
|
||||
def DeepSeekMultiTokenPredictorLayer__init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    """Replacement ``__init__`` for DeepSeekMultiTokenPredictorLayer.

    Mirrors the upstream vLLM constructor but optionally builds ``eh_proj``
    as a tensor-parallel RowParallelLinear (``USE_PARALLEL_MTP_EH_PROJ``)
    instead of a plain ``nn.Linear``.

    Args:
        vllm_config: Global vLLM configuration (hf_config / quant_config used).
        prefix: Module name prefix forwarded to the decoder layer.
    """
    # Call nn.Module.__init__ via the original class's MRO on purpose:
    # this function is patched in as the original class's __init__.
    super(DeepSeekMultiTokenPredictorLayerOrig, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.embed_tokens = VocabParallelEmbedding(
        config.vocab_size,
        config.hidden_size,
    )

    # Norms applied to the token embedding (enorm) and the previous hidden
    # state (hnorm) before they are concatenated and projected by eh_proj.
    self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    if USE_PARALLEL_MTP_EH_PROJ:
        self.eh_proj = RowParallelLinear(config.hidden_size * 2,
                                         config.hidden_size,
                                         bias=False,
                                         return_bias=False)
    else:
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)

    from vllm.model_executor.models.deepseek_mtp import SharedHead
    # DeepSeek-V3.2 exposes index_topk and needs a top-k indices buffer.
    self.is_v32 = hasattr(config, "index_topk")
    if self.is_v32:
        topk_tokens = config.index_topk
        # NOTE(review): buffer is allocated on "cuda"; confirm this is the
        # intended device string on the vacc backend.
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            topk_tokens,
            dtype=torch.int32,
            device="cuda")
    else:
        topk_indices_buffer = None
    self.shared_head = SharedHead(config=config, quant_config=quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """MTP (multi-token prediction) layer with vacc-fused forward paths.

    The forward pass fuses the enorm/hnorm + concat + eh_proj preamble
    and, in decode mode, the entire MLA-attention + MoE decoder layer into
    single vendor kernels. Prefill (or non-fused mode) falls back to the
    regular ``mtp_block`` decoder layer.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        # attn_metadata may be a per-layer dict; take the first entry's value.
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata.items().__iter__().__next__()[1]

        # Lazily capture flattened weight/dist argument bundles for the
        # fused kernels on first use.
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekMTPWegitCapture
            self.weight_capture = DeepseekMTPWegitCapture(self.mtp_block)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None

        # Large batches (> 256 tokens): fuse only the stage-0 preamble and
        # run eh_proj separately; optionally reuse a recycled input buffer.
        if inputs_embeds.shape[0] > 256:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
            deepseek_mtp_layer_input_buffer = None
            if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
                deepseek_mtp_layer_input_buffer = memory_recycler.DEEPSEEK_MTP_LAYER_INPUT

            from torch_vacc.vacc.custom_ops import fuse_mtp_stage0
            hidden_states = fuse_mtp_stage0(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output=deepseek_mtp_layer_input_buffer,
            )

            # if USE_PARALLEL_MTP_EH_PROJ:
            #     tp_size = get_tensor_model_parallel_world_size()
            #     rank_id = get_tensor_model_parallel_rank()
            #     last_dim = hidden_states.shape[-1]
            #     if tp_size > 1:
            #         hiddens_tp = last_dim//tp_size
            #         hidden_states = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp]

            hidden_states = self.eh_proj(hidden_states)
        else:
            # Small batches: preamble + eh_proj + all-reduce in one kernel.
            hidden_states = torch.vacc.fuse_mtp_allreduce(
                inputs_embeds,
                previous_hidden_states,
                positions,
                self.enorm.weight,
                self.hnorm.weight,
                self.eh_proj.weight,
                self.enorm.variance_epsilon,
                world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
                rank = self.weight_capture.layer_moe.dist_args._1_rank,
                group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
                dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)

        # Prefill (or fused decoder disabled): regular decoder layer path.
        if(attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE):
            hidden_states, residual = self.mtp_block(positions=positions,
                                                     hidden_states=hidden_states,
                                                     residual=None)
        else:
            from torch_vacc.vacc.custom_ops import fuse_mla_moe_v2_allreduce_decode
            layer = self.mtp_block
            layer_id = 0

            kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
            # Last token position per sequence (seq_len - 1), used to gather
            # per-token rotary embedding rows.
            positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
            cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
            sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]

            # For the MTP layer, residual is None on entry and the fused op
            # must return a residual.
            hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                hidden_states = hidden_states,
                residual = None,
                hidden_states_norm_weight = self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[layer_id],
                q_a_proj_weight = self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[layer_id],
                q_a_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[layer_id],
                q_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[layer_id],
                w_q = self.weight_capture.layer_moe.attn_args._3_W_Q[layer_id],
                w_q_scale = self.weight_capture.layer_moe.attn_args._4_W_Q_scales[layer_id],
                w_uk = self.weight_capture.layer_moe.attn_args._5_W_UK[layer_id],
                w_uk_scale = self.weight_capture.layer_moe.attn_args._6_W_UK_scales[layer_id],
                w_qr = self.weight_capture.layer_moe.attn_args._7_W_QR[layer_id],
                w_qr_scale = self.weight_capture.layer_moe.attn_args._8_W_QR_scales[layer_id],
                kv_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[layer_id],
                sin_cache = sin_cache,
                cos_cache = cos_cache,
                slot_mapping = attn_metadata.slot_mapping,
                kv_cache = kv_cache,
                block_tables = attn_metadata.decode_metadata.block_tables,
                block_group_size = self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
                w_uv = self.weight_capture.layer_moe.attn_args._16_W_UV[layer_id],
                w_uv_scale = self.weight_capture.layer_moe.attn_args._17_W_UV_scales[layer_id],
                o_proj_weight = self.weight_capture.layer_moe.attn_args._18_o_proj_weight[layer_id],
                o_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[layer_id],
                # mla params
                seq_lens = attn_metadata.decode_metadata.seq_lens,
                sm_scale = self.weight_capture.layer_moe.attn_args._21_sm_scale,
                head_num = self.weight_capture.layer_moe.attn_args._22_head_num,
                # flash attention
                flash_attention = (USE_FLASH_ATTENTION==1),
                # moe weight
                rms_weight = self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[layer_id],
                mlp_weight_13 = self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[layer_id],
                mlp_weight_2 = self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[layer_id],
                mlp_weight_scale_13 = self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[layer_id],
                mlp_weight_scale_2 = self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[layer_id],
                moe_weight_13 = self.weight_capture.layer_moe.moe_args._5_moe_w13[layer_id],
                moe_weight_2 = self.weight_capture.layer_moe.moe_args._6_moe_w2[layer_id],
                moe_weight_scale_13 = self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[layer_id],
                moe_weight_scale_2 = self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[layer_id],
                mm_weight = self.weight_capture.layer_moe.moe_args._9_gate_weight[layer_id],
                moe_bias = self.weight_capture.layer_moe.moe_args._10_moe_bias[layer_id],
                # moe params
                mlp_block_size_w13 = self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
                mlp_block_size_w2 = self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
                moe_block_size_w13 = self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
                moe_block_size_w2 = self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
                # vccl info
                world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
                rank = self.weight_capture.layer_moe.dist_args._1_rank,
                group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
                dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)
        #hidden_states = residual + hidden_states
        # In-place residual add: residual is clobbered and returned.
        hidden_states = residual.add_(hidden_states)
        return hidden_states
class DeepSeekMTP(nn.Module):
    """DeepSeek multi-token-prediction model wrapper (weight loading only here).

    Only ``load_weights`` is visible in this override; it follows the
    upstream vLLM convention of stacked-parameter and expert-parameter
    remapping, restricted to the speculative (MTP) layers.
    """

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into the MTP layer parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were actually loaded.
        """
        # (target fused param, checkpoint shard name, shard index)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # Only weights belonging to a speculative (MTP) layer are loaded.
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            from vllm.model_executor.models.utils import PPMissingLayer
            for layer_id in self.model.layers:
                layer = self.model.layers[layer_id]
                # BUGFIX: previously this tested ``layer_id`` (the container
                # key) instead of ``layer``, so PPMissingLayer placeholders
                # were never skipped and merge_qkv_weights() could be called
                # on a layer that lives on another pipeline stage.
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.mtp_block.self_attn.merge_qkv_weights()
        return loaded_params
class SharedHead(nn.Module):
    """Shared LM head whose forward applies RMSNorm to the hidden states.

    The fused vendor kernel (in-place, writes back into ``hidden_states``)
    is preferred; any failure — including the kernel module being absent —
    falls back to the module's own ``norm``.
    """

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        try:
            from torch_vacc.vacc.custom_ops import rms_norm
            # In-place fused normalization: output aliases hidden_states.
            normalized = rms_norm(hidden_states, self.norm.weight, output=hidden_states)
        except Exception as e:
            print(f"fuse rms_norm run fail, now use unfused ops: {e}")
        else:
            return normalized

        # Unfused fallback path.
        return self.norm(hidden_states)
658
vllm_vacc/vllm/model_executor/models/deepseek_v2.py
Normal file
658
vllm_vacc/vllm/model_executor/models/deepseek_v2.py
Normal file
@@ -0,0 +1,658 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tp_group,
|
||||
get_tensor_model_parallel_world_size,get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsPP
|
||||
from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
from vllm.model_executor.models.deepseek_v2 import yarn_get_mscale, DeepseekV2MLAAttention, Indexer
|
||||
|
||||
|
||||
from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
from .vars import *
|
||||
from ..ops.deepseek_fused_mlp_moe import (vacc_fused_decode_moe_fp8,
|
||||
vacc_fused_prefill_moe_fp8,
|
||||
vacc_fused_mlp_fp8)
|
||||
|
||||
from .fused_forward import *
|
||||
import os
|
||||
test_layer_en = os.getenv("test_layer_en", "0")
|
||||
|
||||
# class DeepseekV2MLAAttention(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# vllm_config: VllmConfig,
|
||||
# config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
# hidden_size: int,
|
||||
# num_heads: int,
|
||||
# qk_nope_head_dim: int,
|
||||
# qk_rope_head_dim: int,
|
||||
# v_head_dim: int,
|
||||
# q_lora_rank: Optional[int],
|
||||
# kv_lora_rank: int,
|
||||
# rope_theta: float = 10000,
|
||||
# rope_scaling: Optional[dict[str, Any]] = None,
|
||||
# max_position_embeddings: int = 8192,
|
||||
# cache_config: Optional[CacheConfig] = None,
|
||||
# quant_config: Optional[QuantizationConfig] = None,
|
||||
# prefix: str = "",
|
||||
# topk_indices_buffer: Optional[torch.Tensor] = None,
|
||||
# ) -> None:
|
||||
# super(DeepseekV2MLAAttention,self).__init__()
|
||||
# self.hidden_size = hidden_size
|
||||
# self.qk_nope_head_dim = qk_nope_head_dim
|
||||
# self.qk_rope_head_dim = qk_rope_head_dim
|
||||
# self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
# self.v_head_dim = v_head_dim
|
||||
|
||||
# self.q_lora_rank = q_lora_rank
|
||||
# self.kv_lora_rank = kv_lora_rank
|
||||
|
||||
# self.num_heads = num_heads
|
||||
# tp_size = get_tensor_model_parallel_world_size()
|
||||
# assert num_heads % tp_size == 0
|
||||
# self.num_local_heads = num_heads // tp_size
|
||||
|
||||
# self.scaling = self.qk_head_dim**-0.5
|
||||
# self.rope_theta = rope_theta
|
||||
# self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# if self.q_lora_rank is not None:
|
||||
# if USE_PARALLEL_Q_KV_GEN:
|
||||
# self.q_a_proj = RowParallelLinear(self.hidden_size,
|
||||
# self.q_lora_rank,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.q_a_proj")
|
||||
# else:
|
||||
# self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
# self.q_lora_rank,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.q_a_proj")
|
||||
# self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
# eps=config.rms_norm_eps)
|
||||
# self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
||||
# self.num_heads *
|
||||
# self.qk_head_dim,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.q_b_proj")
|
||||
# else:
|
||||
# self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||
# self.num_heads *
|
||||
# self.qk_head_dim,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.q_proj")
|
||||
# if USE_PARALLEL_Q_KV_GEN:
|
||||
# self.kv_a_proj_with_mqa = RowParallelLinear(
|
||||
# self.hidden_size,
|
||||
# self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||
# else:
|
||||
# self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
# self.hidden_size,
|
||||
# self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||
# self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
# eps=config.rms_norm_eps)
|
||||
# self.kv_b_proj = ColumnParallelLinear(
|
||||
# self.kv_lora_rank,
|
||||
# self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.kv_b_proj")
|
||||
# self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
||||
# self.hidden_size,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.o_proj")
|
||||
|
||||
# rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||
# self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||
# rotary_dim=qk_rope_head_dim,
|
||||
# max_position=max_position_embeddings,
|
||||
# base=rope_theta,
|
||||
# rope_scaling=rope_scaling,
|
||||
# is_neox_style=False)
|
||||
# if rope_scaling:
|
||||
# mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||
# scaling_factor = rope_scaling["factor"]
|
||||
# mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
# self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
# self.is_v32 = hasattr(config, "index_topk")
|
||||
|
||||
# if self.is_v32:
|
||||
# self.indexer = Indexer(vllm_config, config, hidden_size,
|
||||
# q_lora_rank, quant_config, cache_config,
|
||||
# topk_indices_buffer, f"{prefix}.indexer")
|
||||
# else:
|
||||
# self.indexer = None
|
||||
|
||||
# self.mla_attn = Attention(
|
||||
# num_heads=self.num_local_heads,
|
||||
# head_size=self.kv_lora_rank,
|
||||
# scale=self.scaling,
|
||||
# num_kv_heads=1,
|
||||
# cache_config=cache_config,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.attn",
|
||||
# use_mla=True,
|
||||
# # MLA Args
|
||||
# q_lora_rank=self.q_lora_rank,
|
||||
# kv_lora_rank=self.kv_lora_rank,
|
||||
# qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
# qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
# qk_head_dim=self.qk_head_dim,
|
||||
# v_head_dim=self.v_head_dim,
|
||||
# rotary_emb=self.rotary_emb,
|
||||
# q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
# kv_b_proj=self.kv_b_proj,
|
||||
# o_proj=self.o_proj,
|
||||
# )
|
||||
|
||||
# self.prefix = prefix
|
||||
# self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
|
||||
# def forward(
|
||||
# self,
|
||||
# positions: torch.Tensor,
|
||||
# hidden_states: torch.Tensor,
|
||||
# kv_cache: torch.Tensor,
|
||||
# attn_metadata: AttentionMetadata,
|
||||
# ) -> torch.Tensor:
|
||||
|
||||
# tp_size = get_tensor_model_parallel_world_size()
|
||||
# rank_id = get_tensor_model_parallel_rank()
|
||||
# last_dim = hidden_states.shape[-1]
|
||||
|
||||
# if USE_PARALLEL_Q_KV_GEN: #tp qa and kva
|
||||
# hidden_states_split = hidden_states
|
||||
# if tp_size > 1:
|
||||
# hiddens_tp = last_dim//tp_size
|
||||
# hidden_states_split = hidden_states[...,rank_id*hiddens_tp : (rank_id+1)*hiddens_tp].contiguous()
|
||||
# if self.q_lora_rank is not None:
|
||||
# ckq = self.q_a_proj(hidden_states_split)[0]
|
||||
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
# else:
|
||||
# hidden_states_or_q_c = hidden_states
|
||||
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states_split)[0].split(
|
||||
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
|
||||
# attn_metadata)
|
||||
|
||||
# if self.q_lora_rank is not None:
|
||||
# ckq = self.q_a_proj(hidden_states)[0]
|
||||
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
# else:
|
||||
# hidden_states_or_q_c = hidden_states
|
||||
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
|
||||
# attn_metadata)
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2 MoE block with vacc-fused prefill/decode fast paths.

    When a ``residual`` is provided, the block first tries the fully fused
    fp8 MoE kernels (decode or prefill variant); on failure it logs a
    warning, applies ``rms_norm`` itself, and continues on the unfused path
    (shared-experts MLP + router gate + FusedMoE + optional TP all-reduce).
    """

    def forward(self, hidden_states: torch.Tensor, residual = None, rms_norm = None):
        # moe layer support prefill&decode vacc ops
        if residual is not None:
            try:
                reduce_result = self.tp_size > 1
                # decode moe, first seq
                if self.is_decode:
                    hidden_states, residual = vacc_fused_decode_moe_fp8(self, self.shared_experts,
                                                                        hidden_states, residual,
                                                                        rms_norm, self.gate, self.experts,
                                                                        self.routed_scaling_factor,
                                                                        reduce_result)
                    return hidden_states, residual
                # prefill moe, first expert
                else:
                    hidden_states, residual = vacc_fused_prefill_moe_fp8(self, self.shared_experts,
                                                                         hidden_states, residual,
                                                                         rms_norm, self.gate, self.experts,
                                                                         self.routed_scaling_factor,
                                                                         reduce_result)
                    return hidden_states, residual
            except Exception as e:
                # Fallback: the fused kernel did not consume the residual, so
                # apply the rms_norm here before the unfused path below.
                logger.warning("vacc fused moe run fail, now use unfused ops %s", e)
                hidden_states, residual = rms_norm(hidden_states, residual)

        # Propagate the decode flag so FusedMoE picks the right kernel.
        self.experts.is_decode = self.is_decode

        # 1. fuse_prefill_pre_moe
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.n_shared_experts is not None:
            try:
                shared_output = vacc_fused_mlp_fp8(self.shared_experts, hidden_states, moe_share=True)
            except Exception as e:
                logger.warning("fused mlp is Error, now use Default:%s", e)
                shared_output = self.shared_experts(hidden_states)

        router_logits, _ = self.gate(hidden_states)

        # 2. fused_moe
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits)

        # 3. add_reduce
        # now fuse share_mlp add to experts
        # if shared_output is not None:
        #     # out = input + other * alpha
        #     final_hidden_states = shared_output.add_(final_hidden_states, alpha=self.routed_scaling_factor)
        # else:
        #     final_hidden_states = final_hidden_states * self.routed_scaling_factor

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        # Keep the (hidden_states, residual) tuple shape when a residual was
        # passed in (fused path failed and fell through to here).
        if residual is not None:
            return final_hidden_states.view(num_tokens, hidden_dim), residual
        return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
class DeepseekV2MLP(nn.Module):
    """Dense MLP block that routes through the VACC fused FP8 MLP kernel.

    Two call modes are visible in ``forward``:
      * ``residual is not None``: fully-fused path — the fused op is expected
        to apply ``rms_norm``, the MLP and (optionally) the all-reduce, and to
        return ``(hidden_states, residual)``. No fallback exists on this path.
      * ``residual is None``: fused op returns a single parallel output; the
        all-reduce (if TP > 1 and reduce_results) is done here in Python, and
        any exception falls back to the unfused gate_up/act/down projections.
    """

    def forward(self, x, residual = None, rms_norm = None):
        # use all fused ops
        if residual is not None:
            # Whether the fused kernel should also perform the TP all-reduce.
            reduce_result = self.down_proj.reduce_results and self.down_proj.tp_size > 1
            # NOTE(review): no try/except here — a fused-op failure on this
            # path propagates to the caller (presumably handled upstream).
            hidden_states, residual = vacc_fused_mlp_fp8(self,
                                                         x, residual,
                                                         rms_norm,
                                                         reduce_result)
            return hidden_states, residual
        # use default fuse ops
        try:
            # Fused op without residual handling; all-reduce done manually below.
            output_parallel = vacc_fused_mlp_fp8(self, x, residual, rms_norm)
            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
                x = tensor_model_parallel_all_reduce(output_parallel)
            else:
                x = output_parallel
        except Exception as e:
            # Best-effort fallback to the unfused reference implementation.
            logger.warning("fuse_mlp run fail, now use default: %s", e)
            gate_up, _ = self.gate_up_proj(x)
            x = self.act_fn(gate_up)
            x, _ = self.down_proj(x)
        return x
|
||||
|
||||
class DeepseekV2Model(nn.Module):
    """DeepSeek-V2/V3 transformer stack with VACC fused decode paths.

    ``forward`` picks between three execution strategies:
      1. Per-layer Python loop — prefill requests, or fused mode disabled.
      2. ``FUSE_ALL_DECODER_LAYERS`` — whole-model fused ops driven by a
         ``DeepseekWeightCapture`` (weights cached after the first batch).
      3. Otherwise — one fused custom op invocation per decoder layer.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            # Per-layer metadata dict: take the first entry (all layers
            # presumably share the same metadata — TODO confirm).
            attn_metadata = attn_metadata.items().__iter__().__next__()[1]
        # Number of leading dense (non-MoE) layers; defaults to 3 when the
        # config does not carry the field.
        first_k_dense_replace = self.config.first_k_dense_replace if hasattr(self.config, "first_k_dense_replace") else 3

        # Lazily build the weight-capture helper and the caching state on the
        # first forward call.
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.deepseek_weight_capture import DeepseekWeightCapture
            self.weight_capture = DeepseekWeightCapture(self.layers, self.start_layer, self.end_layer)
            self.cached_weights_state = True
            self.cached_batch = 1
            self.layer_nums = self.end_layer - self.start_layer
            self.is_pipeline_first = get_pp_group().is_first_rank

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Later PP ranks receive activations from the previous stage.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        if(attn_metadata.prefill_metadata is not None or not USE_DECODER_LAYER_FUSE_MODE):
            # Strategy 1: plain per-layer execution (prefill, or fusing off).
            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states, residual)
        else:
            # update global seq lens, use for serve infos
            # update_seqence_length(attn_metadata.decode_metadata.seq_lens)

            if FUSE_ALL_DECODER_LAYERS:
                # Strategy 2: whole-model fused decode. Refresh the per-step
                # attention args (seq lens, slots, kv caches, block tables).
                self.weight_capture.update_attn_args(attn_metadata.decode_metadata.seq_lens,
                                                     attn_metadata.slot_mapping,
                                                     [self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(self.start_layer, first_k_dense_replace)],
                                                     [self.layers[i].self_attn.mla_attn.kv_cache[forward_context.virtual_engine] for i in range(first_k_dense_replace, self.end_layer)],
                                                     attn_metadata.decode_metadata.block_tables)

                # NOTE(review): exactly three dense-layer calls are hard-coded
                # here, which assumes first_k_dense_replace == 3 — confirm for
                # configs where it differs.
                hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 0)
                hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 1)
                hidden_states, residual = forward_mla_mlp_single_layer(hidden_states, residual, self.weight_capture, 2)

                if hidden_states.shape[0] != self.cached_batch:
                    # batch size changed: re-run the weight-caching pass
                    self.cached_weights_state = True
                    self.cached_batch = hidden_states.shape[0]

                if self.cached_weights_state:
                    # First pass at this batch size: the op is fed full weights
                    # so it can cache them device-side.
                    self.cached_weights_state = False
                    hidden_states, residual = forward_mla_moe_layers_with_weights(hidden_states, residual, self.weight_capture)
                else:
                    # Subsequent passes reuse the cached weights.
                    hidden_states, residual = forward_mla_moe_layers_without_weights(hidden_states, residual, self.weight_capture)
            else:
                # Strategy 3: one fused custom op call per decoder layer.
                from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode,fuse_mla_moe_v2_allreduce_decode
                for i in range(0, self.layer_nums):
                    layer_id = i + self.start_layer
                    layer = self.layers[layer_id]
                    kv_cache = layer.self_attn.mla_attn.kv_cache[forward_context.virtual_engine]
                    # NOTE(review): this rebinds the `positions` parameter to a
                    # list of last-token indices (seq_len - 1 per sequence);
                    # the original tensor argument is no longer reachable after
                    # the first iteration — confirm intentional.
                    positions = [p - 1 for p in attn_metadata.decode_metadata.seq_lens]
                    # Gather per-token RoPE cos/sin rows for the current step.
                    cos_cache = [layer.self_attn.mla_attn.impl.rotary_emb.cos_cache[p] for p in positions]
                    sin_cache = [layer.self_attn.mla_attn.impl.rotary_emb.sin_cache[p] for p in positions]

                    if layer_id < first_k_dense_replace:
                        # Dense layer: fused MLA attention + MLP + all-reduce.
                        hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
                            hidden_states = hidden_states,
                            residual = residual,
                            hidden_states_norm_weight = self.weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[i],
                            q_a_proj_weight = self.weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[i],
                            q_a_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[i],
                            q_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[i],
                            w_q = self.weight_capture.layer_mlp.attn_args._3_W_Q[i],
                            w_q_scale = self.weight_capture.layer_mlp.attn_args._4_W_Q_scales[i],
                            w_uk = self.weight_capture.layer_mlp.attn_args._5_W_UK[i],
                            w_uk_scale = self.weight_capture.layer_mlp.attn_args._6_W_UK_scales[i],
                            w_qr = self.weight_capture.layer_mlp.attn_args._7_W_QR[i],
                            w_qr_scale = self.weight_capture.layer_mlp.attn_args._8_W_QR_scales[i],
                            kv_a_layernorm_weight = self.weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[i],
                            sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
                            cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
                            slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
                            kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
                            block_tables = attn_metadata.decode_metadata.block_tables,#self.weight_capture.layer_mlp.attn_args._14_block_tables[i],
                            block_group_size = self.weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
                            w_uv = self.weight_capture.layer_mlp.attn_args._16_W_UV[i],
                            w_uv_scale = self.weight_capture.layer_mlp.attn_args._17_W_UV_scales[i],
                            o_proj_weight = self.weight_capture.layer_mlp.attn_args._18_o_proj_weight[i],
                            o_proj_weight_scale_inv = self.weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[i],
                            # mla params
                            seq_lens = attn_metadata.decode_metadata.seq_lens,
                            sm_scale = self.weight_capture.layer_mlp.attn_args._21_sm_scale,
                            head_num = self.weight_capture.layer_mlp.attn_args._22_head_num,
                            # flash attention
                            flash_attention = (USE_FLASH_ATTENTION==1),
                            # mlp weight
                            rms_weight = self.weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[i],
                            mlp_weight_13 = self.weight_capture.layer_mlp.mlp_args._1_mlp_w13[i],
                            mlp_weight_2 = self.weight_capture.layer_mlp.mlp_args._2_mlp_w2[i],
                            mlp_weight_scale_13 = self.weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[i],
                            mlp_weight_scale_2 = self.weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[i],
                            # mlp params
                            mlp_block_size_w13 = self.weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
                            mlp_block_size_w2 = self.weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
                            # vccl info
                            world_size = self.weight_capture.layer_mlp.dist_args._0_world_size,
                            rank = self.weight_capture.layer_mlp.dist_args._1_rank,
                            group_id = self.weight_capture.layer_mlp.dist_args._2_group_id,
                            dev_info = self.weight_capture.layer_mlp.dist_args._3_dev_info)
                    else:
                        # MoE layer: weight indices are relative to the first
                        # MoE layer on the first PP rank, absolute otherwise.
                        wid = i - first_k_dense_replace if self.is_pipeline_first else i
                        hidden_states, residual = fuse_mla_moe_v2_allreduce_decode(
                            hidden_states = hidden_states,
                            residual = residual,
                            hidden_states_norm_weight = self.weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight[wid],
                            q_a_proj_weight = self.weight_capture.layer_moe.attn_args._0_merge_q_kv_weights[wid],
                            q_a_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv[wid],
                            q_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight[wid],
                            w_q = self.weight_capture.layer_moe.attn_args._3_W_Q[wid],
                            w_q_scale = self.weight_capture.layer_moe.attn_args._4_W_Q_scales[wid],
                            w_uk = self.weight_capture.layer_moe.attn_args._5_W_UK[wid],
                            w_uk_scale = self.weight_capture.layer_moe.attn_args._6_W_UK_scales[wid],
                            w_qr = self.weight_capture.layer_moe.attn_args._7_W_QR[wid],
                            w_qr_scale = self.weight_capture.layer_moe.attn_args._8_W_QR_scales[wid],
                            kv_a_layernorm_weight = self.weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight[wid],
                            sin_cache = sin_cache,# self.weight_capture.layer_mlp.attn_args._10_sin_cache,
                            cos_cache = cos_cache,# self.weight_capture.layer_mlp.attn_args._11_cos_cache,
                            slot_mapping = attn_metadata.slot_mapping,#self.weight_capture.layer_mlp.attn_args._12_slot_mapping[i],
                            kv_cache = kv_cache,#self.weight_capture.layer_mlp.attn_args._13_kv_cache[i],
                            block_tables = attn_metadata.decode_metadata.block_tables,
                            block_group_size = self.weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
                            w_uv = self.weight_capture.layer_moe.attn_args._16_W_UV[wid],
                            w_uv_scale = self.weight_capture.layer_moe.attn_args._17_W_UV_scales[wid],
                            o_proj_weight = self.weight_capture.layer_moe.attn_args._18_o_proj_weight[wid],
                            o_proj_weight_scale_inv = self.weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv[wid],
                            # mla params
                            seq_lens = attn_metadata.decode_metadata.seq_lens,
                            sm_scale = self.weight_capture.layer_moe.attn_args._21_sm_scale,
                            head_num = self.weight_capture.layer_moe.attn_args._22_head_num,
                            # flash attention
                            flash_attention = (USE_FLASH_ATTENTION==1),
                            # moe weight
                            rms_weight = self.weight_capture.layer_moe.moe_args._0_moe_rms_weight[wid],
                            mlp_weight_13 = self.weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13[wid],
                            mlp_weight_2 = self.weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2[wid],
                            mlp_weight_scale_13 = self.weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale[wid],
                            mlp_weight_scale_2 = self.weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale[wid],
                            moe_weight_13 = self.weight_capture.layer_moe.moe_args._5_moe_w13[wid],
                            moe_weight_2 = self.weight_capture.layer_moe.moe_args._6_moe_w2[wid],
                            moe_weight_scale_13 = self.weight_capture.layer_moe.moe_args._7_moe_w13_scale[wid],
                            moe_weight_scale_2 = self.weight_capture.layer_moe.moe_args._8_moe_w2_scale[wid],
                            mm_weight = self.weight_capture.layer_moe.moe_args._9_gate_weight[wid],
                            moe_bias = self.weight_capture.layer_moe.moe_args._10_moe_bias[wid],
                            # moe params
                            mlp_block_size_w13 = self.weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
                            mlp_block_size_w2 = self.weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
                            moe_block_size_w13 = self.weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
                            moe_block_size_w2 = self.weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
                            # vccl info
                            world_size = self.weight_capture.layer_moe.dist_args._0_world_size,
                            rank = self.weight_capture.layer_moe.dist_args._1_rank,
                            group_id = self.weight_capture.layer_moe.dist_args._2_group_id,
                            dev_info = self.weight_capture.layer_moe.dist_args._3_dev_info)

        if not get_pp_group().is_last_rank:
            # Hand activations off to the next pipeline stage.
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        # Final norm folds in the residual; its second output is discarded.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
||||
|
||||
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
    """Causal-LM wrapper: VACC memory-recycler setup + vLLM weight loading."""

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into model params; returns the loaded names.

        Also initializes the VACC huge-memory allocator (best effort) and,
        when merged QKV generation is enabled, merges attention weights after
        loading.
        """
        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

        # default is deepseek, config can set to ['deepseek_mtp',]
        model_name = "deepseek"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            if config_infos in ['mtp']:
                model_name = "deepseek_mtp"
            else:
                model_name = config_infos

        # Failure is non-fatal: prefill memory recycling is simply disabled.
        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
            logger.warning("init huge memory allocator fail. prefill memory recycling will disable")

        from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Rotary inverse-frequency buffers are recomputed, never loaded.
            if "rotary_emb.inv_freq" in name:
                continue

            # Debug switch: only load the first few layers (plus embeddings,
            # final norm and lm_head). `test_layer_en` is a module-level flag
            # — presumably read from an env var; confirm at the definition.
            if test_layer_en == "1":
                test_layer = 5
                if name not in ['model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight']:
                    # name looks like "model.layers.<idx>...."; index 2 is the
                    # layer number.
                    if int(name.split(".")[2]) > test_layer:
                        continue
            # TODO(simon): support nextn predict layers
            if hasattr(self.config, "num_nextn_predict_layers"
                       ) and self.config.num_nextn_predict_layers > 0:
                assert self.config.num_nextn_predict_layers == 1
                layer_idx = self.config.num_hidden_layers
                # Skip the (single) speculative next-token layer entirely.
                if name.startswith(f"model.layers.{layer_idx}"):
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Parameter lives on a different pipeline-parallel rank.
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Plain (non-stacked, non-expert) parameter.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        # Post-load fixup: merge Q/KV projection weights for the fused kernel.
        if USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for layer in self.model.layers:
                if isinstance(layer, PPMissingLayer):
                    continue
                layer.self_attn.merge_qkv_weights()

        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the model; on prefill, first reserve recycler memory.

        Allocation failure only degrades efficiency and is logged, not raised.
        """
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            # Per-layer dict: take the first layer's metadata.
            attn_metadata = attn_metadata.items().__iter__().__next__()[1]
        if attn_metadata.prefill_metadata is not None:
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
            # Attribute location differs between metadata versions.
            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens

            # Same model-mode resolution as in load_weights.
            vllm_model_mode = "deepseek"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                if config_infos in ['mtp']:
                    vllm_model_mode = "deepseek_mtp"
                else:
                    vllm_model_mode = config_infos

            # Log once per TP group (rank 0 only).
            if get_tp_group().rank_in_group == 0:
                memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
                logger.info(memory_infos)

            if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size):
                logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states
|
||||
1367
vllm_vacc/vllm/model_executor/models/deepseek_v2_fused.py
Normal file
1367
vllm_vacc/vllm/model_executor/models/deepseek_v2_fused.py
Normal file
File diff suppressed because it is too large
Load Diff
216
vllm_vacc/vllm/model_executor/models/fused_forward.py
Normal file
216
vllm_vacc/vllm/model_executor/models/fused_forward.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode_layers,fuse_mla_moe_v2_allreduce_decode_layers,fuse_mla_moe_v2_allreduce_decode_layers_v2
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .weight_capture.deepseek_weight_capture import DeepseekWeightCapture
|
||||
import time
|
||||
|
||||
from .vars import *
|
||||
|
||||
# single decoder layer: MLA attention + dense MLP (fused decode kernel)
def forward_mla_mlp_single_layer(hidden_states: torch.Tensor,
                                 residual: Optional[torch.Tensor],
                                 weight_capture : DeepseekWeightCapture,
                                 layer_id: int):
    """Run one dense (non-MoE) decoder layer via the fused VACC decode op.

    All per-layer weights and per-step attention args come from
    ``weight_capture.layer_mlp``, indexed by ``layer_id``. Returns the updated
    ``(hidden_states, residual)`` pair.
    """
    from torch_vacc.vacc.custom_ops import fuse_mla_mlp_v2_allreduce_decode

    hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode(
        hidden_states = hidden_states,
        residual = residual,
        hidden_states_norm_weight = weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight[layer_id],
        q_a_proj_weight = weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights[layer_id],
        q_a_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv[layer_id],
        q_a_layernorm_weight = weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight[layer_id],
        w_q = weight_capture.layer_mlp.attn_args._3_W_Q[layer_id],
        w_q_scale = weight_capture.layer_mlp.attn_args._4_W_Q_scales[layer_id],
        w_uk = weight_capture.layer_mlp.attn_args._5_W_UK[layer_id],
        w_uk_scale = weight_capture.layer_mlp.attn_args._6_W_UK_scales[layer_id],
        w_qr = weight_capture.layer_mlp.attn_args._7_W_QR[layer_id],
        w_qr_scale = weight_capture.layer_mlp.attn_args._8_W_QR_scales[layer_id],
        kv_a_layernorm_weight = weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight[layer_id],
        sin_cache = weight_capture.layer_mlp.attn_args._10_sin_cache,
        cos_cache = weight_capture.layer_mlp.attn_args._11_cos_cache,
        slot_mapping = weight_capture.layer_mlp.attn_args._12_slot_mapping,
        kv_cache = weight_capture.layer_mlp.attn_args._13_kv_cache[layer_id],
        block_tables = weight_capture.layer_mlp.attn_args._14_block_tables,
        block_group_size = weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
        w_uv = weight_capture.layer_mlp.attn_args._16_W_UV[layer_id],
        w_uv_scale = weight_capture.layer_mlp.attn_args._17_W_UV_scales[layer_id],
        o_proj_weight = weight_capture.layer_mlp.attn_args._18_o_proj_weight[layer_id],
        o_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv[layer_id],
        # mla params
        seq_lens = weight_capture.layer_mlp.attn_args._20_seq_lens,
        sm_scale = weight_capture.layer_mlp.attn_args._21_sm_scale,
        head_num = weight_capture.layer_mlp.attn_args._22_head_num,
        # flash attention
        flash_attention = (USE_FLASH_ATTENTION==1),
        # mlp weight
        rms_weight = weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight[layer_id],
        mlp_weight_13 = weight_capture.layer_mlp.mlp_args._1_mlp_w13[layer_id],
        mlp_weight_2 = weight_capture.layer_mlp.mlp_args._2_mlp_w2[layer_id],
        mlp_weight_scale_13 = weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale[layer_id],
        mlp_weight_scale_2 = weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale[layer_id],
        # mlp params
        mlp_block_size_w13 = weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
        mlp_block_size_w2 = weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
        # vccl info
        world_size = weight_capture.layer_mlp.dist_args._0_world_size,
        rank = weight_capture.layer_mlp.dist_args._1_rank,
        group_id = weight_capture.layer_mlp.dist_args._2_group_id,
        dev_info = weight_capture.layer_mlp.dist_args._3_dev_info)
    return hidden_states, residual
|
||||
|
||||
# multi-layer MLA + MLP: run all captured dense decoder layers in one fused op
def forward_mla_mlp_layers(hidden_states: torch.Tensor,
                           residual: Optional[torch.Tensor],
                           weight_capture : DeepseekWeightCapture):
    """Run every captured dense (MLA + MLP) decoder layer via one fused op.

    Args:
        hidden_states: current hidden states for this decode step.
        residual: running residual stream, or ``None`` before the first layer.
        weight_capture: pre-captured per-layer weights and attention args.

    Returns:
        The updated ``(hidden_states, residual)`` pair from the fused kernel.
    """
    # Fix: compare to None with identity (`is`), not equality. `residual`
    # is an Optional[torch.Tensor]; `residual == None` dispatches to
    # Tensor.__eq__ rather than performing a None check (PEP 8 mandates
    # `is`/`is not` for singleton comparison).
    if residual is None:
        # The fused op requires a concrete residual tensor, so seed with zeros.
        residual = torch.zeros_like(hidden_states)
    hidden_states, residual = fuse_mla_mlp_v2_allreduce_decode_layers(
        hidden_states = hidden_states,
        residual = residual,
        hidden_states_norm_weight = weight_capture.layer_mlp.attn_args._a_hidden_states_norm_weight,
        q_a_proj_weight = weight_capture.layer_mlp.attn_args._0_merge_q_kv_weights,
        q_a_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._1_merge_q_kv_scale_inv,
        q_a_layernorm_weight = weight_capture.layer_mlp.attn_args._2_q_a_layernorm_weight,
        w_q = weight_capture.layer_mlp.attn_args._3_W_Q,
        w_q_scale = weight_capture.layer_mlp.attn_args._4_W_Q_scales,
        w_uk = weight_capture.layer_mlp.attn_args._5_W_UK,
        w_uk_scale = weight_capture.layer_mlp.attn_args._6_W_UK_scales,
        w_qr = weight_capture.layer_mlp.attn_args._7_W_QR,
        w_qr_scale = weight_capture.layer_mlp.attn_args._8_W_QR_scales,
        kv_a_layernorm_weight = weight_capture.layer_mlp.attn_args._9_kv_a_layernorm_weight,
        sin_cache = weight_capture.layer_mlp.attn_args._10_sin_cache,
        cos_cache = weight_capture.layer_mlp.attn_args._11_cos_cache,
        slot_mapping = weight_capture.layer_mlp.attn_args._12_slot_mapping,
        kv_cache = weight_capture.layer_mlp.attn_args._13_kv_cache,
        block_tables = weight_capture.layer_mlp.attn_args._14_block_tables,
        block_group_size = weight_capture.layer_mlp.attn_args._15_env_blk_grp_size,
        w_uv = weight_capture.layer_mlp.attn_args._16_W_UV,
        w_uv_scale = weight_capture.layer_mlp.attn_args._17_W_UV_scales,
        o_proj_weight = weight_capture.layer_mlp.attn_args._18_o_proj_weight,
        o_proj_weight_scale_inv = weight_capture.layer_mlp.attn_args._19_o_proj_weight_scale_inv,
        # mla params
        seq_lens = weight_capture.layer_mlp.attn_args._20_seq_lens,
        sm_scale = weight_capture.layer_mlp.attn_args._21_sm_scale,
        head_num = weight_capture.layer_mlp.attn_args._22_head_num,
        # flash attention
        flash_attention = (USE_FLASH_ATTENTION==1),
        # mlp weight
        rms_weight = weight_capture.layer_mlp.mlp_args._0_mlp_rms_weight,
        mlp_weight_13 = weight_capture.layer_mlp.mlp_args._1_mlp_w13,
        mlp_weight_2 = weight_capture.layer_mlp.mlp_args._2_mlp_w2,
        mlp_weight_scale_13 = weight_capture.layer_mlp.mlp_args._3_mlp_w13_scale,
        mlp_weight_scale_2 = weight_capture.layer_mlp.mlp_args._4_mlp_w2_scale,
        # mlp params
        mlp_block_size_w13 = weight_capture.layer_mlp.mlp_args._5_mlp_w13_block_size,
        mlp_block_size_w2 = weight_capture.layer_mlp.mlp_args._6_mlp_w2_block_size,
        # vccl info
        world_size = weight_capture.layer_mlp.dist_args._0_world_size,
        rank = weight_capture.layer_mlp.dist_args._1_rank,
        group_id = weight_capture.layer_mlp.dist_args._2_group_id,
        dev_info = weight_capture.layer_mlp.dist_args._3_dev_info)
    return hidden_states, residual
|
||||
|
||||
# multi-layer MLA + MoE, weights not yet cached: this variant feeds the full
# weight set into the fused op (which caches it device-side for later calls)
def forward_mla_moe_layers_with_weights(hidden_states: torch.Tensor,
                                        residual: Optional[torch.Tensor],
                                        weight_capture : DeepseekWeightCapture):
    """Run all captured MoE decoder layers, passing every weight explicitly.

    Must be called before ``forward_mla_moe_layers_without_weights`` so the
    fused op can cache the weights. Returns ``(hidden_states, residual)``.
    """
    hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers(
        hidden_states = hidden_states,
        residual = residual,
        hidden_states_norm_weight = weight_capture.layer_moe.attn_args._a_hidden_states_norm_weight,
        q_a_proj_weight = weight_capture.layer_moe.attn_args._0_merge_q_kv_weights,
        q_a_proj_weight_scale_inv = weight_capture.layer_moe.attn_args._1_merge_q_kv_scale_inv,
        q_a_layernorm_weight = weight_capture.layer_moe.attn_args._2_q_a_layernorm_weight,
        w_q = weight_capture.layer_moe.attn_args._3_W_Q,
        w_q_scale = weight_capture.layer_moe.attn_args._4_W_Q_scales,
        w_uk = weight_capture.layer_moe.attn_args._5_W_UK,
        w_uk_scale = weight_capture.layer_moe.attn_args._6_W_UK_scales,
        w_qr = weight_capture.layer_moe.attn_args._7_W_QR,
        w_qr_scale = weight_capture.layer_moe.attn_args._8_W_QR_scales,
        kv_a_layernorm_weight = weight_capture.layer_moe.attn_args._9_kv_a_layernorm_weight,
        sin_cache = weight_capture.layer_moe.attn_args._10_sin_cache,
        cos_cache = weight_capture.layer_moe.attn_args._11_cos_cache,
        slot_mapping = weight_capture.layer_moe.attn_args._12_slot_mapping,
        kv_cache = weight_capture.layer_moe.attn_args._13_kv_cache,
        block_tables = weight_capture.layer_moe.attn_args._14_block_tables,
        block_group_size = weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
        w_uv = weight_capture.layer_moe.attn_args._16_W_UV,
        w_uv_scale = weight_capture.layer_moe.attn_args._17_W_UV_scales,
        o_proj_weight = weight_capture.layer_moe.attn_args._18_o_proj_weight,
        o_proj_weight_scale_inv = weight_capture.layer_moe.attn_args._19_o_proj_weight_scale_inv,
        # mla params
        seq_lens = weight_capture.layer_moe.attn_args._20_seq_lens,
        sm_scale = weight_capture.layer_moe.attn_args._21_sm_scale,
        head_num = weight_capture.layer_moe.attn_args._22_head_num,
        # flash attention
        flash_attention = (USE_FLASH_ATTENTION==1),
        # moe weight
        rms_weight = weight_capture.layer_moe.moe_args._0_moe_rms_weight,
        mlp_weight_13 = weight_capture.layer_moe.moe_args._1_moe_share_mlp_w13,
        mlp_weight_2 = weight_capture.layer_moe.moe_args._2_moe_share_mlp_w2,
        mlp_weight_scale_13 = weight_capture.layer_moe.moe_args._3_moe_share_mlp_w13_scale,
        mlp_weight_scale_2 = weight_capture.layer_moe.moe_args._4_moe_share_mlp_w2_scale,
        moe_weight_13 = weight_capture.layer_moe.moe_args._5_moe_w13,
        moe_weight_2 = weight_capture.layer_moe.moe_args._6_moe_w2,
        moe_weight_scale_13 = weight_capture.layer_moe.moe_args._7_moe_w13_scale,
        moe_weight_scale_2 = weight_capture.layer_moe.moe_args._8_moe_w2_scale,
        mm_weight = weight_capture.layer_moe.moe_args._9_gate_weight,
        moe_bias = weight_capture.layer_moe.moe_args._10_moe_bias,
        # moe params
        mlp_block_size_w13 = weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
        mlp_block_size_w2 = weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
        moe_block_size_w13 = weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
        moe_block_size_w2 = weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
        # vccl info
        world_size = weight_capture.layer_moe.dist_args._0_world_size,
        rank = weight_capture.layer_moe.dist_args._1_rank,
        group_id = weight_capture.layer_moe.dist_args._2_group_id,
        dev_info = weight_capture.layer_moe.dist_args._3_dev_info)

    return hidden_states, residual
|
||||
|
||||
# multi-layer MLA + MoE using cached weights; MUST only be called after the
# with-weights variant has run once (the op relies on device-side cached
# weights from that call)
def forward_mla_moe_layers_without_weights(hidden_states: torch.Tensor,
                                           residual: Optional[torch.Tensor],
                                           weight_capture : DeepseekWeightCapture):
    """Run all captured MoE decoder layers reusing previously cached weights.

    Only the per-step attention args (rope caches, slots, kv caches, block
    tables, seq lens) and scalar params are passed; the weight tensors are
    omitted because the ``_v2`` op reads them from its internal cache.
    Returns ``(hidden_states, residual)``.
    """
    hidden_states, residual = fuse_mla_moe_v2_allreduce_decode_layers_v2(
        hidden_states = hidden_states,
        residual = residual,
        sin_cache = weight_capture.layer_moe.attn_args._10_sin_cache,
        cos_cache = weight_capture.layer_moe.attn_args._11_cos_cache,
        slot_mapping = weight_capture.layer_moe.attn_args._12_slot_mapping,
        kv_cache = weight_capture.layer_moe.attn_args._13_kv_cache,
        block_tables = weight_capture.layer_moe.attn_args._14_block_tables,
        block_group_size = weight_capture.layer_moe.attn_args._15_env_blk_grp_size,
        # mla params
        seq_lens = weight_capture.layer_moe.attn_args._20_seq_lens,
        sm_scale = weight_capture.layer_moe.attn_args._21_sm_scale,
        head_num = weight_capture.layer_moe.attn_args._22_head_num,
        # flash_attention
        flash_attention = (USE_FLASH_ATTENTION==1),
        # moe weight
        # moe params
        mlp_block_size_w13 = weight_capture.layer_moe.moe_args._11_moe_mlp_w13_block_size,
        mlp_block_size_w2 = weight_capture.layer_moe.moe_args._12_moe_mlp_w2_block_size,
        moe_block_size_w13 = weight_capture.layer_moe.moe_args._13_moe_w13_block_size,
        moe_block_size_w2 = weight_capture.layer_moe.moe_args._14_moe_w2_block_size,
        # vccl info
        world_size = weight_capture.layer_moe.dist_args._0_world_size,
        rank = weight_capture.layer_moe.dist_args._1_rank,
        group_id = weight_capture.layer_moe.dist_args._2_group_id,
        dev_info = weight_capture.layer_moe.dist_args._3_dev_info)

    return hidden_states, residual
|
||||
|
||||
def forward_deepseekv3(model: torch.nn.Module,
                       hidden_states: torch.Tensor,
                       residual: Optional[torch.Tensor],
                       weight_capture : DeepseekWeightCapture):
    """Full fused decode pass: dense (MLA+MLP) layers, then MoE layers.

    ``model`` is accepted for interface compatibility but is not read here;
    all state comes from ``weight_capture``. Returns the final
    ``(hidden_states, residual)`` pair.
    """
    dense_out = forward_mla_mlp_layers(hidden_states, residual, weight_capture)
    moe_out = forward_mla_moe_layers_with_weights(dense_out[0], dense_out[1],
                                                  weight_capture)
    return moe_out
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,239 @@
|
||||
import warnings
|
||||
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
IMAGE_PROCESSOR_NAME,
|
||||
cached_file,
|
||||
is_timm_config_dict,
|
||||
is_timm_local_checkpoint,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
from transformers.models.auto.configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
AutoConfig,
|
||||
model_type_to_module_name,
|
||||
replace_list_option_in_docstrings,
|
||||
)
|
||||
|
||||
from transformers.image_processing_utils import ImageProcessingMixin
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.auto.image_processing_auto import (
|
||||
AutoImageProcessor,
|
||||
logger,
|
||||
FORCE_FAST_IMAGE_PROCESSOR,
|
||||
IMAGE_PROCESSOR_MAPPING_NAMES,
|
||||
IMAGE_PROCESSOR_MAPPING,
|
||||
get_image_processor_class_from_name,
|
||||
resolve_trust_remote_code,
|
||||
_warning_fast_image_processor_available,
|
||||
get_class_from_dynamic_module
|
||||
)
|
||||
|
||||
def check_vacc_support_module(module_class):
    """Return the VACC-enabled replacement for a processor class when one exists.

    Only ``Qwen2VLImageProcessorFast`` currently has a VACC counterpart; every
    other class is passed through unchanged.
    """
    if module_class.__name__ != "Qwen2VLImageProcessorFast":
        return module_class
    from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
    return Qwen2VLImageProcessorFastWithVacc
|
||||
|
||||
"""AutoImageProcessor class."""
|
||||
class AutoImageProcessorWithVacc(AutoImageProcessor):
    """AutoImageProcessor variant that substitutes VACC-enabled processor classes.

    This mirrors ``transformers.AutoImageProcessor.from_pretrained`` but routes
    every resolved class through :func:`check_vacc_support_module` before
    instantiation, so VACC replacements (e.g. for Qwen2-VL) are used transparently.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """Resolve and instantiate an image processor for a checkpoint.

        Resolution order follows upstream transformers: processor config file,
        legacy feature-extractor config, model config, remote (trusted) code,
        then the static IMAGE_PROCESSOR_MAPPING. Raises ValueError when no
        processor class can be determined.
        """
        # `use_auth_token` is the deprecated spelling of `token`; migrate it.
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        # TODO: @yoni, change in v4.48 (use_fast set to True by default)
        use_fast = kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True

        # Resolve the image processor config filename
        # (timm local checkpoints keep the preprocessing info inside config.json).
        if "image_processor_filename" in kwargs:
            image_processor_filename = kwargs.pop("image_processor_filename")
        elif is_timm_local_checkpoint(pretrained_model_name_or_path):
            image_processor_filename = CONFIG_NAME
        else:
            image_processor_filename = IMAGE_PROCESSOR_NAME

        # Load the image processor config
        try:
            # Main path for all transformers models and local TimmWrapper checkpoints
            config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
            )
        except Exception as initial_exception:
            # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
            # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
            # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
            # load `config.json` and if it fails with some error, we raise the initial exception.
            try:
                config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                    pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
                )
            except Exception:
                raise initial_exception

            # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
            # because only timm models have image processing in `config.json`.
            if not is_timm_config_dict(config_dict):
                raise initial_exception

        image_processor_type = config_dict.get("image_processor_type", None)

        # Earlier hook that swapped in the VACC preprocessing class directly;
        # superseded by check_vacc_support_module() below and kept for reference.
        # if image_processor_type == "Qwen2VLImageProcessorFast":
        #     from .qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
        #     return Qwen2VLImageProcessorFastWithVacc.from_dict(config_dict, **kwargs)

        image_processor_auto_map = None
        if "AutoImageProcessor" in config_dict.get("auto_map", {}):
            image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]

        # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
        # and if so, infer the image processor class from there.
        if image_processor_type is None and image_processor_auto_map is None:
            feature_extractor_class = config_dict.pop("feature_extractor_type", None)
            if feature_extractor_class is not None:
                image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
            if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
                feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
                image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")

        # If we don't find the image processor class in the image processor config, let's try the model config.
        if image_processor_type is None and image_processor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    **kwargs,
                )
            # It could be in `config.image_processor_type``
            image_processor_type = getattr(config, "image_processor_type", None)
            if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
                image_processor_auto_map = config.auto_map["AutoImageProcessor"]

        image_processor_class = None
        # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
        if image_processor_type is not None:
            # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
            if use_fast is None:
                use_fast = image_processor_type.endswith("Fast")
                if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
                    use_fast = True
                    logger.warning_once(
                        f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
                        "This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
                        "Note that this behavior will be extended to all models in a future release."
                    )
                if not use_fast:
                    logger.warning_once(
                        "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
                        "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
                        "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
                    )
            if use_fast and not image_processor_type.endswith("Fast"):
                image_processor_type += "Fast"
            if use_fast and not is_torchvision_available():
                # check if there is a slow image processor class to fallback to
                image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
                if image_processor_class is None:
                    raise ValueError(
                        f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
                    )
                logger.warning_once(
                    "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
                )
                use_fast = False
            if use_fast:
                # for/else: the else branch runs only when NO mapping contains the
                # fast class name, i.e. there is no fast version to load.
                for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
                    if image_processor_type in image_processors:
                        break
                else:
                    image_processor_type = image_processor_type[:-4]
                    use_fast = False
                    logger.warning_once(
                        "`use_fast` is set to `True` but the image processor class does not have a fast version. "
                        " Falling back to the slow version."
                    )
                image_processor_class = get_image_processor_class_from_name(image_processor_type)
            else:
                image_processor_type_slow = image_processor_type.removesuffix("Fast")
                image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
                if image_processor_class is None and image_processor_type.endswith("Fast"):
                    raise ValueError(
                        f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
                    )

        has_remote_code = image_processor_auto_map is not None
        has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
        if has_remote_code:
            if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
                # In some configs, only the slow image processor class is stored
                image_processor_auto_map = (image_processor_auto_map, None)
            if use_fast and image_processor_auto_map[1] is not None:
                class_ref = image_processor_auto_map[1]
            else:
                class_ref = image_processor_auto_map[0]
            # "repo--Class" references point at code hosted in another repo.
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            if not use_fast and image_processor_auto_map[1] is not None:
                _warning_fast_image_processor_available(image_processor_auto_map[1])

            image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
            _ = kwargs.pop("code_revision", None)
            image_processor_class.register_for_auto_class()
            # check preprocess module supported by vacc
            image_processor_class = check_vacc_support_module(image_processor_class)
            return image_processor_class.from_dict(config_dict, **kwargs)
        elif image_processor_class is not None:
            # check preprocess module supported by vacc
            image_processor_class = check_vacc_support_module(image_processor_class)
            return image_processor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the IMAGE_PROCESSOR_MAPPING.
        elif type(config) in IMAGE_PROCESSOR_MAPPING:
            image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]

            image_processor_class_py, image_processor_class_fast = image_processor_tuple

            if not use_fast and image_processor_class_fast is not None:
                _warning_fast_image_processor_available(image_processor_class_fast)

            if image_processor_class_fast and (use_fast or image_processor_class_py is None):
                # check preprocess module supported by vacc
                image_processor_class_fast = check_vacc_support_module(image_processor_class_fast)
                return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if image_processor_class_py is not None:
                    # check preprocess module supported by vacc
                    image_processor_class_py = check_vacc_support_module(image_processor_class_py)
                    return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
                    )
        raise ValueError(
            f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
            f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
        )
|
||||
@@ -0,0 +1,402 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for Qwen2-VL."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
|
||||
from transformers.image_processing_utils import BatchFeature
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
logging,
|
||||
)
|
||||
from transformers.video_utils import VideoInput, make_batched_videos
|
||||
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import (
|
||||
Qwen2VLFastImageProcessorKwargs,
|
||||
logger
|
||||
)
|
||||
|
||||
# resize_normalize_repeat_transpose_reshape
|
||||
def fuse_qwen2_vl_preprocess_img_cpu(
    image: "torch.Tensor",
    do_resize: bool,
    min_pixels: int,
    max_pixels: int,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    resized_height: int,
    resized_width: int,
    interpolation: Optional["F.InterpolationMode"],
    patch_size: int,
    temporal_patch_size: int,
    merge_size: int,
    image_mean_0: float,
    image_mean_1: float,
    image_mean_2: float,
    image_std_0: float,
    image_std_1: float,
    image_std_2: float,
    batch_size: int = 1,
    grid_t: int = 1,
    channel: int = 3,
) -> "torch.Tensor":
    """CPU reference for the fused Qwen2-VL preprocessing pipeline.

    Pipeline: resize -> (rescale +) normalize -> temporal repeat ->
    patch transpose/reshape. Returns flattened patches of shape
    ``(grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size**2)``
    where ``grid_h = resized_height // patch_size`` and
    ``grid_w = resized_width // patch_size``.

    Notes:
    - `do_resize`, `min_pixels` and `max_pixels` are accepted for signature
      parity with the VACC op but are not read here: resizing to
      (`resized_height`, `resized_width`) is always applied, as in the
      original implementation.
    - assumes `image` is (C, H, W) or (N, C, H, W) — TODO confirm with callers.

    Raises:
        ValueError: if `resized_height` or `resized_width` is falsy.
    """

    def _resize(img, size_h, size_w, interp, antialias=True):
        # Default interpolation mirrors the original helper (BILINEAR).
        mode = interp if interp is not None else F.InterpolationMode.BILINEAR
        if not (size_h and size_w):
            raise ValueError(
                "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                f" {size_h} and {size_w}."
            )
        return F.resize(img, (size_h, size_w), interpolation=mode, antialias=antialias)

    def _fused_mean_std(device):
        # Build per-channel mean/std once. When rescaling is requested, fold
        # 1/rescale_factor into both so a single F.normalize performs
        # rescale + normalize in one pass: (x*r - m)/s == (x - m/r) / (s/r).
        mean = torch.tensor([image_mean_0, image_mean_1, image_mean_2], device=device)
        std = torch.tensor([image_std_0, image_std_1, image_std_2], device=device)
        if do_rescale:
            mean = mean * (1.0 / rescale_factor)
            std = std * (1.0 / rescale_factor)
        return mean, std

    def _rescale_and_normalize(images):
        # BUGFIX: the previous helper returned None unless BOTH do_rescale and
        # do_normalize were set, which crashed with a TypeError on tuple
        # unpacking for every other flag combination. Each case is now handled.
        if do_normalize:
            mean, std = _fused_mean_std(images.device)
            return F.normalize(images.to(dtype=torch.float32), mean, std)
        if do_rescale:
            return images.to(dtype=torch.float32) * rescale_factor
        return images

    if image.dim() == 3:
        image = image.unsqueeze(0)  # (C, H, W) -> (N, C, H, W)
    patches = _rescale_and_normalize(
        _resize(image, resized_height, resized_width, interpolation)
    )
    if patches.ndim == 4:
        patches = patches.unsqueeze(1)  # add temporal axis -> (N, T, C, H, W)
    if patches.shape[1] % temporal_patch_size != 0:
        # Pad the temporal axis by repeating the last frame.
        repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
        patches = torch.cat([patches, repeats], dim=1)

    grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
    patches = patches.view(
        batch_size,              # batch size
        grid_t,                  # temporal grid
        temporal_patch_size,     # frames per temporal patch
        channel,                 # RGB channels
        grid_h // merge_size,    # merged grid rows
        merge_size,              # merge factor (height)
        patch_size,              # patch height
        grid_w // merge_size,    # merged grid cols
        merge_size,              # merge factor (width)
        patch_size,              # patch width
    )
    # Reorder so each flattened row holds one spatial patch across channels
    # and temporal frames (same permutation as the upstream Qwen2-VL processor).
    patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        channel * temporal_patch_size * patch_size * patch_size,
    )

    return flatten_patches
|
||||
|
||||
|
||||
class Qwen2VLImageProcessorFastWithVacc(BaseImageProcessorFast):
    """Fast Qwen2-VL image processor that runs the fused preprocessing
    (resize / normalize / repeat / transpose / reshape) as a VACC custom op.

    Produces ``pixel_values`` (flattened patches) and ``image_grid_thw``
    (per-image ``[t, h, w]`` grids), mirroring the upstream
    ``Qwen2VLImageProcessorFast`` contract.
    """

    # Default preprocessing configuration (same values as upstream Qwen2-VL).
    do_resize = True
    resample = PILImageResampling.BICUBIC
    size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
    do_rescale = True
    do_normalize = True
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD
    do_convert_rgb = True
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    min_pixels = None
    max_pixels = None
    valid_kwargs = Qwen2VLFastImageProcessorKwargs
    model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

    def __init__(self, **kwargs: Unpack[Qwen2VLFastImageProcessorKwargs]):
        size = kwargs.pop("size", None)
        min_pixels = kwargs.pop("min_pixels", None)
        max_pixels = kwargs.pop("max_pixels", None)
        # backward compatibility: override size with min_pixels and max_pixels if they are provided
        size = self.size if size is None else size
        if min_pixels is not None:
            size["shortest_edge"] = min_pixels
            size.pop("min_pixels", None)
        if max_pixels is not None:
            size["longest_edge"] = max_pixels
            size.pop("max_pixels", None)
        if "shortest_edge" not in size or "longest_edge" not in size:
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")

        super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.
        """
        # min_pixels/max_pixels and size are two spellings of the same bounds;
        # keep them in sync whichever one the caller supplied.
        if min_pixels is not None and max_pixels is not None:
            size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
        elif size is not None:
            if "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
            min_pixels = size["shortest_edge"]
            max_pixels = size["longest_edge"]
        else:
            size = {**self.size}

        return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        videos: Optional[VideoInput] = None,
        **kwargs: Unpack[Qwen2VLFastImageProcessorKwargs],
    ) -> BatchFeature:
        """Preprocess images (and, deprecated, videos) into model inputs."""
        return super().preprocess(images, videos, **kwargs)

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        videos: VideoInput,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Optional[Union[str, "torch.device"]] = None,
        **kwargs: Unpack[DefaultFastImageProcessorKwargs],
    ) -> BatchFeature:
        """
        Preprocess image-like inputs.
        To be overridden by subclasses when image-like inputs other than images should be processed.
        It can be used for segmentation maps, depth maps, etc.
        """
        # Prepare input images
        batch_feature = BatchFeature()
        if images is not None:
            images = self._prepare_image_like_inputs(
                images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
            )
            batch_feature = self._preprocess(images, **kwargs)
        if videos is not None:
            # Deprecated path: videos belong to the dedicated video processor.
            logger.warning(
                "`Qwen2VLImageProcessorFast` works only with image inputs and doesn't process videos anymore. "
                "This is a deprecated behavior and will be removed in v5.0. "
                "Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
            )
            # Can't change _prepare_images_structure to work with videos because it also needs to work with images.
            videos = make_batched_videos(videos)
            videos = [
                torch.stack(self._prepare_image_like_inputs(video, do_convert_rgb, input_data_format, device))
                for video in videos
            ]
            video_outputs = self._preprocess(videos, **kwargs)
            batch_feature.update(
                {"pixel_values_videos": video_outputs.pixel_values, "video_grid_thw": video_outputs.image_grid_thw}
            )
        return batch_feature

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ):
        """Run the fused VACC preprocessing op per image and collect outputs."""
        # FIX: these assignments previously carried trailing commas, creating
        # accidental 1-tuples that were then indexed with [0] everywhere below.
        min_pixels = size["shortest_edge"]
        max_pixels = size["longest_edge"]
        processed_images = []
        processed_grids = []
        for img in images:
            height, width = img.shape[-2:]
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=min_pixels,
                    max_pixels=max_pixels,
                )
            else:
                # FIX: resized_height/width were undefined (NameError) when
                # do_resize was False; fall back to the original dimensions.
                resized_height, resized_width = height, width

            # The fused resize/normalize/repeat/transpose/reshape runs on the
            # VACC device via a custom op; move the image there first.
            import torch_vacc  # local import: only needed when preprocessing actually runs
            if img.device.type != "vacc":
                img = img.to("vacc")
            flatten_patches = torch_vacc.vacc.custom_qwen3_ops.qwen2vl_img_preprocess(
                img,
                do_resize,
                min_pixels,
                max_pixels,
                do_rescale,
                rescale_factor,
                do_normalize,
                resized_height,
                resized_width,
                0x1003,  # interpolation mode id — presumably BICUBIC; TODO confirm against the op's docs
                patch_size,
                temporal_patch_size,
                merge_size,
                image_mean[0], image_mean[1], image_mean[2],
                image_std[0], image_std[1], image_std[2],
            )
            processed_images.append(flatten_patches)
            grid_t = 1  # still images only; the temporal grid is always 1 here
            grid_h = resized_height // patch_size
            grid_w = resized_width // patch_size
            processed_grids.append(torch.tensor([[grid_t, grid_h, grid_w]]))

        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.cat(processed_grids, dim=0)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
        )

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns number of image patches for a given image size.

        Note: Do not remove this method! It is used by vLLM to infer the number of patches and placeholders
        without an image input.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*)
                Any kwargs to override defaults of the image processor.
        Returns:
            `int`: Number of image patches per image.
        """
        # FIX: guard against the default images_kwargs=None — previously
        # `"min_pixels" in None` raised a TypeError.
        images_kwargs = images_kwargs or {}
        min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
        max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
        patch_size = images_kwargs.get("patch_size", self.patch_size)
        merge_size = images_kwargs.get("merge_size", self.merge_size)

        factor = patch_size * merge_size
        resized_height, resized_width = smart_resize(
            height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
        )
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
        return grid_h * grid_w
|
||||
|
||||
|
||||
__all__ = ["Qwen2VLImageProcessorFastWithVacc"]
|
||||
@@ -0,0 +1,91 @@
|
||||
from transformers.models.qwen3_vl import Qwen3VLProcessor
|
||||
# from transformers.models.auto.image_processing_auto import AutoImageProcessor
|
||||
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
||||
|
||||
class Qwen3VLProcessorWithVacc(Qwen3VLProcessor):
    """Qwen3-VL processor whose `AutoImageProcessor` subcomponent is replaced by
    the VACC-aware `AutoImageProcessorWithVacc` during `from_pretrained`."""

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        # NOTE(review): extra **kwargs are accepted but not forwarded to
        # super().__init__ — confirm no caller relies on them being applied.
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Identify and instantiate the subcomponents of Processor classes, like image processors and
        tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
        subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
        the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
        via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
        will be unable to find the relevant subcomponent class and will raise an error.
        """
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            if isinstance(class_name, tuple):
                # Tuple classes carry (slow, fast) variants; pick by `use_fast`.
                classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
                if attribute_name == "image_processor":
                    # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
                    use_fast = kwargs.get("use_fast")
                    if use_fast is None:
                        # NOTE(review): upstream uses logger.warning_once here; this
                        # copy prints instead — confirm intentional.
                        print(
                            "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
                            "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
                            "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
                        )
                else:
                    # Non-image-processor attributes default to the fast variant.
                    use_fast = kwargs.get("use_fast", True)
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = cls.get_possibly_dynamic_module(class_name)

            # Swap the stock AutoImageProcessor for the VACC-aware variant.
            if attribute_class.__name__ == "AutoImageProcessor":
                from .auto_image_preprocessor import AutoImageProcessorWithVacc
                attribute_class = AutoImageProcessorWithVacc
            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

        return args
|
||||
|
||||
class Qwen2VLProcessorWithVacc(Qwen2VLProcessor):
    """Qwen2-VL processor whose `AutoImageProcessor` subcomponent is replaced by
    the VACC-aware `AutoImageProcessorWithVacc` during `from_pretrained`."""

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        # NOTE(review): extra **kwargs are accepted but not forwarded to
        # super().__init__ — confirm no caller relies on them being applied.
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Identify and instantiate the subcomponents of Processor classes, like image processors and
        tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
        subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
        the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
        via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
        will be unable to find the relevant subcomponent class and will raise an error.
        """
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            if isinstance(class_name, tuple):
                # Tuple classes carry (slow, fast) variants; pick by `use_fast`.
                classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
                if attribute_name == "image_processor":
                    # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
                    use_fast = kwargs.get("use_fast")
                    if use_fast is None:
                        # NOTE(review): upstream uses logger.warning_once here; this
                        # copy prints instead — confirm intentional.
                        print(
                            "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
                            "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
                            "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
                        )
                else:
                    # Non-image-processor attributes default to the fast variant.
                    use_fast = kwargs.get("use_fast", True)
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = cls.get_possibly_dynamic_module(class_name)

            # Swap the stock AutoImageProcessor for the VACC-aware variant.
            if attribute_class.__name__ == "AutoImageProcessor":
                from .auto_image_preprocessor import AutoImageProcessorWithVacc
                attribute_class = AutoImageProcessorWithVacc
            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

        return args
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
235
vllm_vacc/vllm/model_executor/models/memory/allocator.py
Normal file
235
vllm_vacc/vllm/model_executor/models/memory/allocator.py
Normal file
@@ -0,0 +1,235 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List
|
||||
|
||||
class VaccHugeMemoryAllocator(nn.Module):
    """Pre-allocates a fixed set of maximum-size byte buffers on the "vacc"
    device once, then hands out token-sized views of them so forward passes
    can recycle memory instead of re-allocating per request.

    Design notes (from the original author):
      - ``_active_bytes`` is the number of bytes of the real tensor buffers
        currently in use; it is the slice length applied to each source buffer.
      - ``_src_buffer_array`` must never be freed: these are the max-size
        source buffers that every view is sliced from.
      - ``_block_bytes`` is the per-buffer maximum size in bytes.
    """

    def __init__(self, blocks, dtype = torch.bfloat16, use_contiguous = False):
        """
        Args:
            blocks: number of independent recyclable buffers to manage.
            dtype: element dtype used to size the buffers.
            use_contiguous: when True, allocate one contiguous region and view
                it as `blocks` slices; otherwise make `blocks` separate
                allocations.
        """
        # FIX: nn.Module subclasses must initialize the base class; the
        # original skipped super().__init__(), which breaks any Module
        # machinery (parameters/buffers/hooks) touched later.
        super().__init__()
        self._total_blocks = blocks
        self._dtype = dtype
        self._enable = False
        self._src_buffer_array = None
        self._block_bytes = 0
        # FIX: initialize here as well; previously only set inside
        # init_buffers*, so reading it earlier raised AttributeError.
        self._all_bytes = 0
        self._max_tokens = 0
        self._hiddens = 0
        self._active_bytes = 0
        self._use_contiguous_buffer = use_contiguous
        # Max tokens for the dynamic buffers; a dynamic buffer is usually
        # larger than a normal one.
        self._dynamic_max_tokens = 0
        self._dynamic_block_bytes = 0

    # malloc the max buffers once; they are never freed afterwards
    def init_buffers(self, max_tokens, hiddens):
        """Allocate `self._total_blocks` equal buffers sized for `max_tokens`
        rows of `hiddens` elements of `self._dtype`.

        Sets `self.enable` on success. On any failure (e.g. torch_vacc not
        available) the allocator stays disabled and only logs the error.
        """
        self._max_tokens = max_tokens
        self._hiddens = hiddens

        try:
            import torch_vacc
            self._block_bytes = self._max_tokens * self._hiddens \
                                * self.get_dtype_bytes(self._dtype)
            self._all_bytes = self._block_bytes * self._total_blocks

            if self._use_contiguous_buffer:
                # One [blocks*N]-byte allocation, viewed as `blocks` slices.
                self._src_buffer = torch.zeros(self._all_bytes,
                                               dtype = torch.uint8,
                                               device = "vacc")
                tmp_buffer_array = self._src_buffer.view(self._total_blocks, -1)
                self._src_buffer_array = [tmp_buffer_array[i]
                                          for i in range(self._total_blocks)]
            else:
                # `blocks` separate [N]-byte allocations.
                # FIX: use the private field directly instead of the read-only
                # `block_bytes` property, consistent with the rest of the class.
                self._src_buffer_array = [torch.zeros(self._block_bytes,
                                                      dtype = torch.uint8,
                                                      device = "vacc")
                                          for i in range(self._total_blocks)]
            self._enable = True
        except Exception as e:
            print(f"vacc huge buffer alloc fail: {e}")

    # For networks that need a "dynamic" buffer; a dynamic buffer may be
    # somewhat larger than the regular max_tokens buffers.
    def init_buffers_with_dynamic(self, max_tokens, dynamic_tokens, hiddens, dynamic_buffers_mask: List):
        """Allocate a mix of normal and dynamic buffers.

        Args:
            max_tokens: row capacity of the normal buffers.
            dynamic_tokens: row capacity of the dynamic buffers.
            hiddens: number of elements per row.
            dynamic_buffers_mask: per-block booleans; True means that block is
                allocated at the dynamic size.
        """
        self._max_tokens = max_tokens
        self._dynamic_max_tokens = dynamic_tokens
        self._hiddens = hiddens

        dynamic_buffers_count = sum(dynamic_buffers_mask)
        normal_buffers_count = self._total_blocks - dynamic_buffers_count

        self._block_bytes = self._max_tokens * self._hiddens \
                            * self.get_dtype_bytes(self._dtype)
        self._dynamic_block_bytes = self._dynamic_max_tokens * self._hiddens \
                                    * self.get_dtype_bytes(self._dtype)

        self._all_bytes = self._block_bytes * normal_buffers_count + \
                          self._dynamic_block_bytes * dynamic_buffers_count

        try:
            assert self._use_contiguous_buffer is False, "malloc dynamic recycle memory buffers only support separation buffer now"

            self._src_buffer_array = []
            for i in range(self._total_blocks):
                if dynamic_buffers_mask[i]:
                    self._src_buffer_array.append(torch.zeros(self._dynamic_block_bytes,
                                                              dtype = torch.uint8,
                                                              device = "vacc"))
                else:
                    # Regular [N]-byte buffer.
                    self._src_buffer_array.append(torch.zeros(self._block_bytes,
                                                              dtype = torch.uint8,
                                                              device = "vacc"))
            self._enable = True
        except Exception as e:
            print("vacc huge buffer alloc fail.", e)

    # Slice views out of the pre-allocated source buffers for a new request.
    # The caller must know the model's buffer layout (e.g. deepseek uses 3/4
    # buffers) and must mind the dtype: the sources are allocated as uint8.
    def alloc_memory_buffers(self, tokens, dtype=torch.bfloat16):
        """Return one `tokens * hiddens * itemsize`-byte view per block, or
        None when `tokens` exceeds the configured maximum."""
        if tokens > self._max_tokens:
            # FIX: corrected message typo ("is large than" -> "is larger than").
            print("alloc memory buffer fail, tokens is larger than max_tokens.", self._max_tokens)
            return None

        self._active_bytes = tokens * self._hiddens * self.get_dtype_bytes(dtype)
        return [sub_array[:self._active_bytes]
                for sub_array in self._src_buffer_array]

    @property
    def memory_buffers(self):
        # The raw (full-size) source buffers.
        return self._src_buffer_array

    # Split the single contiguous source buffer into `part` equal regions.
    # @param part               total number of regions to split into
    # @param return_buffer_list indices of regions to return; empty/None -> all
    def alloc_1_div_N_buffers(self, part = 2,
                              return_buffer_list = None):
        """Return 1/N-sized views over the contiguous source buffer.

        Requires the allocator to have been created with use_contiguous=True.
        """
        # FIX: the original used a mutable default argument ([]), which is
        # shared across calls; None is the safe sentinel with identical
        # behavior for every existing caller.
        if return_buffer_list is None:
            return_buffer_list = []

        if not hasattr(self, "_src_buffer"):
            print("1 div N alloctor need a contiguous buffer")
            return None

        assert isinstance(return_buffer_list, list), "return_buffer_list need List object"

        # Re-view the byte buffer as `part` equal regions (uint8 layout).
        tmp_buffer_array = self._src_buffer.view(part, -1)
        # No explicit selection: return every region.
        if len(return_buffer_list) == 0:
            return [tmp_buffer_array[i] for i in range(part)]

        return [tmp_buffer_array[i] for i in return_buffer_list]

    def get_dtype_bytes(self, dtype):
        """Return the per-element byte size of `dtype` (torch dtype or Python
        scalar type); 0 for anything unrecognized."""
        if isinstance(dtype, torch.dtype):
            if dtype in [torch.float16, torch.bfloat16, torch.half]:
                return 2
            elif dtype in [torch.float32, torch.float, torch.int32]:
                return 4
            elif dtype in [torch.float64, torch.double, torch.int64]:
                return 8
            elif dtype in [torch.int8, torch.uint8, torch.bool]:
                return 1
            else:
                # Unlisted torch dtypes fall back to 1 byte (original behavior).
                return 1
        elif dtype == int:
            return 8
        elif dtype == float:
            return 8
        elif dtype == bool:
            return 1
        return 0

    @property
    def enable(self):
        # True only after a successful init_buffers*() call.
        return self._enable

    @property
    def blocks(self):
        return self._total_blocks

    @property
    def max_tokens(self):
        return self._max_tokens

    @property
    def hiddens(self):
        return self._hiddens

    @property
    def active_bytes(self):
        return self._active_bytes

    @property
    def block_bytes(self):
        return self._block_bytes

    @property
    def dynamic_block_bytes(self):
        return self._dynamic_block_bytes
|
||||
|
||||
|
||||
class LLMMemoryRecycler:
    """Holds the recyclable output-buffer views used by a generic LLM
    forward pass (embedding, MLA o-proj and shared-MLP outputs)."""

    def __init__(self):
        # Number of distinct buffers this recycler manages.
        self.count = 3
        self.clear()

    def clear(self):
        """Drop every buffer reference so new views can be installed."""
        self.embedding_output = None
        self.moe_shared_mlp_output = None
        self.mla_oproj_output = None

    @property
    def EMBEDDING_OUT_BUFFER(self):
        return self.embedding_output

    @property
    def MOE_SHARED_MLP_OUT_BUFFER(self):
        return self.moe_shared_mlp_output

    @property
    def MLA_OPROJ_OUT_BUFFER(self):
        return self.mla_oproj_output
||||
|
||||
def alloc_memory_recycler_llm(tokens,
                              alloctor:VaccHugeMemoryAllocator,
                              recycler:LLMMemoryRecycler,
                              dtype:torch.dtype = torch.bfloat16):
    """Slice `recycler.count` [tokens, hiddens] views of `dtype` out of the
    huge allocator and install them on `recycler`.

    Returns True on success; False when the allocator is uninitialized, the
    slice fails, or the buffer count does not match the recycler.
    """
    if not alloctor.enable:
        print("memory alloctor is not Init.")
        return False

    recycler.clear()
    raw_buffers = alloctor.alloc_memory_buffers(tokens, dtype)
    if raw_buffers is None:
        print("llm memory recycler buffers alloc fail. now disable it")
        return False
    if len(raw_buffers) != recycler.count:
        print("memory recycler buffers not equal llm.")
        return False

    # Reinterpret each raw byte buffer as a [tokens, hiddens] tensor of dtype.
    shaped = [buf.view(dtype).view(tokens, alloctor.hiddens)
              for buf in raw_buffers]
    recycler.embedding_output = shaped[0]
    recycler.mla_oproj_output = shaped[1]
    recycler.moe_shared_mlp_output = shaped[2]
    return True
|
||||
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from .allocator import VaccHugeMemoryAllocator,LLMMemoryRecycler
|
||||
'''
|
||||
DeepSeek support 48K input tokens, there are 3 buffers can recycle:
|
||||
1. parallel_embedding output buffer
|
||||
2. mla_oproject output buffer
|
||||
3. moe_shared_mlp output buffer
|
||||
# 4. moe_expert output buffer
|
||||
each buffer size is 48K * 7168 * 2 bytes
|
||||
'''
|
||||
class DeepseekV3MemoryRecycler(LLMMemoryRecycler):
    """Memory recycler for DeepSeek-V3; currently identical to the generic
    LLM recycler (the extra MoE-expert buffer is disabled below)."""
    def __init__(self):
        super().__init__()
        #self.moe_expert_output = None

    # @property
    # def MOE_EXPERT_OUT_BUFFER(self):
    #     return self.moe_expert_output
|
||||
|
||||
|
||||
class DeepseekMTPMemoryRecycler(LLMMemoryRecycler):
    """Memory recycler for DeepSeek MTP: adds a dynamic broadcast buffer and
    a dedicated view for the MTP decoder-layer input on top of the three
    generic LLM buffers."""
    def __init__(self):
        super().__init__()
        # Free-form buffer used for previous_hidden_states broadcast.
        self.dynamic_output = None
        # Input staging view for the MTP decoder layer.
        self.deepseek_mtp_layer_input = None

    @property
    def DYNAMIC_OUTPUT_BUFFER(self):
        return self.dynamic_output

    @property
    def DEEPSEEK_MTP_LAYER_INPUT(self):
        return self.deepseek_mtp_layer_input
|
||||
|
||||
|
||||
def alloc_memory_recycler_deepseek_v3(tokens,
                                      alloctor:VaccHugeMemoryAllocator,
                                      recycler:DeepseekV3MemoryRecycler,
                                      dtype:torch.dtype = torch.bfloat16):
    """Slice three [tokens, hiddens] views out of `alloctor` and install them
    on `recycler` (embedding, MLA o-proj, shared-MLP). Returns True on
    success, False on any failure (allocator disabled, slice failure, or
    buffer-count mismatch)."""
    if not alloctor.enable:
        print("memory alloctor is not Init.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("deepseek_v3 memory recycler buffers alloc fail. now disable it")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffers not equal deepseek_v3.")
        return False

    # Reinterpret each raw uint8 buffer as a [tokens, hiddens] tensor of dtype.
    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    #recycler.moe_expert_output = out_buffers[3].view(dtype).view(tokens, alloctor.hiddens)
    return True
|
||||
|
||||
|
||||
def alloc_memory_recycler_deepseek_mtp(tokens,
                                       alloctor:VaccHugeMemoryAllocator,
                                       world_size:int,
                                       recycler:DeepseekMTPMemoryRecycler,
                                       dtype:torch.dtype = torch.bfloat16):
    """Install the DeepSeek-MTP buffer layout on `recycler`.

    In addition to the three generic LLM views, sets up `dynamic_output`
    (shared with the mla_oproj buffer) and `deepseek_mtp_layer_input`
    (a 1/tp-sized staging view). Returns True on success, False otherwise.
    """
    if not alloctor.enable:
        print("memory alloctor is not Init.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("deepseek_mtp memory recycler buffers alloc fail. now disable it")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffers not equal deepseek_mtp.")
        return False
    # MTP memory layout:
    # 1. decoder layers of the deepseek main model and the draft model:
    #    a. embedding_output
    #    b. mla_oproj_output
    #    c. moe_shared_mlp_output
    #    Reason: the moe output is re-organized once as previous_hidden_states,
    #    and the re-organized buffer is placed into buffer0 [a.].
    #    embedding_output could also go first, but then the related addresses
    #    would all have to shift; for readability embedding stays last and does
    #    not take part in buffer re-partitioning/reuse.
    # 2. deepseek_mtp draft model (this strategy is NOT enabled):
    #    a. dynamic_output takes 1/2
    #    b. mtp_input takes 1/6, placed right after the dynamic_output buffer
    # dynamic_buffer = alloctor.alloc_1_div_N_buffers(2, [0,])
    # mtp_input_buffer = alloctor.alloc_1_div_N_buffers(6, [3,])
    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)

    # 1/2 is freely assigned for broadcast: MTP broadcasts previous_hidden_states.
    memory_buffers = alloctor.memory_buffers
    recycler.dynamic_output = memory_buffers[1]  # shared with mla_oproj
    # 1/6 is used as the mtp decoder-layer input; that op is special-cased and
    # only needs 1/tp of the output buffer.
    # recycler.deepseek_mtp_layer_input = mtp_input_buffer[0]

    mtp_layer_input_dims = alloctor.hiddens * 2 // world_size
    mtp_layer_input_numels = tokens * mtp_layer_input_dims

    from vllm import envs
    if envs.VLLM_USE_V1:
        # v1 does not need to re-cache previous_hidden_states, so the moe buffer
        # is still occupied; stage the mtp pre-processing in the attention
        # o-buffer instead.
        recycler.deepseek_mtp_layer_input = memory_buffers[1].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
    else:
        # Shared with the moe output buffer.
        recycler.deepseek_mtp_layer_input = memory_buffers[2].view(dtype)[:mtp_layer_input_numels].view(tokens, mtp_layer_input_dims)
    return True
|
||||
132
vllm_vacc/vllm/model_executor/models/memory/memory_recycling.py
Normal file
132
vllm_vacc/vllm/model_executor/models/memory/memory_recycling.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
import torch
|
||||
from .allocator import VaccHugeMemoryAllocator
|
||||
from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler
|
||||
|
||||
# Model family used to choose a memory-recycling layout when the caller does
# not pass one explicitly (see alloc_memory_recycler / init_huge_memory_allocator).
VLLM_MODEL_MODE = os.environ.get("VLLM_MODEL_MODE", "deepseek")

# Module-level singletons: the huge buffer allocator and the per-request
# recycler view over it.
# FIX: the original also had `global huge_memory_alloctor` / `global
# memory_recycler` statements here; `global` at module scope is a no-op and
# only misleads readers, so they were removed (behavior unchanged).
huge_memory_alloctor = None
memory_recycler = None
|
||||
|
||||
# you should call this function when new request is in
|
||||
# you should call this function when new request is in
def alloc_memory_recycler(tokens, dtype=torch.bfloat16, **argv):
    """Build the model-specific memory recycler over the global allocator.

    Args:
        tokens: number of prefill tokens the recycled views must hold.
        dtype: element dtype of the views.
        **argv: `vllm_model` (falls back to VLLM_MODEL_MODE when absent) and,
            for "deepseek_mtp", the required `world_size`.

    Returns:
        True on success. On any failure — including an unknown model name —
        the global `memory_recycler` is reset to None and False is returned.
    """
    global huge_memory_alloctor
    global memory_recycler

    vllm_model = argv.get('vllm_model')
    if vllm_model is None:
        print("model infos is empty, now using VLLM_MODEL_MODE")
        vllm_model = VLLM_MODEL_MODE

    # TODO: use default memory-recycle schedule
    # if vllm_model in ['xxx']:
    #     vllm_model = "llm_default"

    memory_recycler = None
    if vllm_model == "deepseek":
        from .deepseek_v3_memory_recycler import DeepseekV3MemoryRecycler, alloc_memory_recycler_deepseek_v3
        memory_recycler = DeepseekV3MemoryRecycler()
        state = alloc_memory_recycler_deepseek_v3(tokens, huge_memory_alloctor, memory_recycler, dtype)
    elif vllm_model == "deepseek_mtp":
        from .deepseek_v3_memory_recycler import DeepseekMTPMemoryRecycler, alloc_memory_recycler_deepseek_mtp
        memory_recycler = DeepseekMTPMemoryRecycler()
        if argv.get('world_size') is None:
            print("mtp should have TP world size, memory recycler allock fail")
            # FIX: the original left the freshly-created (but unusable)
            # recycler installed in the global on this failure path.
            memory_recycler = None
            return False
        state = alloc_memory_recycler_deepseek_mtp(tokens, huge_memory_alloctor, argv['world_size'], memory_recycler, dtype)
    elif vllm_model == "qwen3_moe":
        from .qwen3_moe_memory_recycler import QWen3MoeMemoryRecycler, alloc_memory_recycler_qwen3_moe
        memory_recycler = QWen3MoeMemoryRecycler()
        state = alloc_memory_recycler_qwen3_moe(tokens, huge_memory_alloctor, memory_recycler, dtype)
    elif vllm_model == "llm_default":
        from .allocator import LLMMemoryRecycler, alloc_memory_recycler_llm
        memory_recycler = LLMMemoryRecycler()
        state = alloc_memory_recycler_llm(tokens, huge_memory_alloctor, memory_recycler, dtype)
    else:
        # FIX: previously an unknown model fell off the end and returned
        # None implicitly; return an explicit False.
        return False

    # FIX: the original "llm_default" branch never returned `state` on
    # success (it only returned False inside the failure check), so a
    # successful allocation reported a falsy result. All branches now share
    # this single tail: reset the global on failure, return the real state.
    if not state:
        memory_recycler = None
    return state
|
||||
|
||||
# Under the LLM pipeline-parallel scheme, for PARTs other than stage 0,
# memory reuse also applies when receiving hiddens/residual from the
# preceding PART. This path is independent of the llm forward pass and is
# maintained separately:
#   hiddens  maps to the moe mlp buffer used during llm forward,
#   residual maps to the embedding buffer used during llm forward.
def alloc_pipeline_parallel_recycler_buffer(size:torch.Size, dtype:torch.dtype, key:str):
    """Return a `size`-shaped view of `dtype` carved from the huge-allocator
    buffer associated with `key` ("hidden_states", "attention" or
    "residual"); None when the allocator is missing or the key is unknown."""
    global huge_memory_alloctor
    if huge_memory_alloctor is None:
        return None

    # Maps each intermediate-tensor role to its source-buffer index.
    intermize_tensor_dict = {
        "hidden_states": 2,
        "attention": 1,
        "residual": 0
    }

    if not key in intermize_tensor_dict:
        return None

    src_tensors = huge_memory_alloctor.memory_buffers[intermize_tensor_dict[key]]

    # Slice the needed byte count, then reinterpret as dtype with the
    # requested shape.
    all_bytes = size.numel() * huge_memory_alloctor.get_dtype_bytes(dtype)
    return src_tensors[:all_bytes].view(dtype).view(size)
|
||||
|
||||
# you should call this function when new server and every workers is start
def init_huge_memory_allocator(max_tokens, hidden_size, vllm_model = None):
    """Create (or re-create) the global VaccHugeMemoryAllocator for
    `vllm_model`. Returns True when a layout exists for the model, False
    otherwise. Any previous allocator is dropped and the vacc cache cleared
    before re-initialization."""
    global huge_memory_alloctor
    if vllm_model is None:
        print("model infos is empty, now using VLLM_MODEL_MODE")
        vllm_model = VLLM_MODEL_MODE

    # Tear down any previous allocator before building a new one.
    if huge_memory_alloctor is not None:
        del huge_memory_alloctor
        torch.vacc.empty_cache()
        huge_memory_alloctor = None

    if vllm_model == "deepseek":
        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
        return True

    # deepseek_mtp buffer recycler layout:
    # (NOTE(review): the original note said "set use_congituous = True", but
    # the constructor below does not pass use_contiguous — confirm intent.)
    # buffer[0]: normal_buffer  -> embedding_output
    # buffer[1]: dynamic_buffer -> mla_oproj_output, dynamic_output
    # buffer[2]: normal_buffer  -> moe_shared_mlp_output, deepseek_mtp_layer_input
    if vllm_model == "deepseek_mtp":
        # deepseek dynamic tokens last block use for mtp-weights
        # dynamic_buffer_max_tokens = max_tokens + 128, we will let mtp only support 48K now
        dynamic_buffer_max_tokens = max_tokens
        # dynamic buffer would use 128 more tokens for broadcast
        # (positions, input tokens)
        deepseek_mtp_max_tokens = max_tokens

        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        # huge_memory_alloctor.init_buffers(max_tokens, 7168)
        huge_memory_alloctor.init_buffers_with_dynamic(deepseek_mtp_max_tokens, dynamic_buffer_max_tokens, hidden_size, [False, True, False])
        return True

    if vllm_model == "qwen3_moe":
        huge_memory_alloctor = VaccHugeMemoryAllocator(3)
        huge_memory_alloctor.init_buffers(max_tokens, hidden_size)
        return True
    return False
|
||||
@@ -0,0 +1,38 @@
|
||||
import torch
|
||||
from .allocator import VaccHugeMemoryAllocator, LLMMemoryRecycler
|
||||
'''
|
||||
QWen3-Moe support 56K input tokens, there are 3 buffers can recycle:
|
||||
1. parallel_embedding output buffer
|
||||
2. mla_oproject output buffer
|
||||
3. moe_shared_mlp output buffer
|
||||
'''
|
||||
class QWen3MoeMemoryRecycler(LLMMemoryRecycler):
    """Memory recycler for Qwen3-MoE; uses the three generic LLM buffers
    unchanged."""
    def __init__(self):
        super().__init__()
|
||||
|
||||
|
||||
def alloc_memory_recycler_qwen3_moe(tokens,
                                    alloctor:VaccHugeMemoryAllocator,
                                    recycler:QWen3MoeMemoryRecycler,
                                    dtype:torch.dtype = torch.bfloat16):
    """Slice three [tokens, hiddens] views of `dtype` out of `alloctor` and
    install them on `recycler` (embedding, MLA o-proj, shared-MLP). Returns
    True on success, False on any failure."""
    if not alloctor.enable:
        print("memory alloctor is not Init.")
        return False

    recycler.clear()
    out_buffers = alloctor.alloc_memory_buffers(tokens, dtype)

    if out_buffers is None:
        print("qwen3_moe memory recycler buffers alloc fail. now disable it")
        return False

    if len(out_buffers) != recycler.count:
        print("memory recycler buffers not equal qwen3_moe.")
        return False

    # Reinterpret each raw uint8 buffer as a [tokens, hiddens] tensor of dtype.
    recycler.embedding_output = out_buffers[0].view(dtype).view(tokens, alloctor.hiddens)
    recycler.mla_oproj_output = out_buffers[1].view(dtype).view(tokens, alloctor.hiddens)
    recycler.moe_shared_mlp_output = out_buffers[2].view(dtype).view(tokens, alloctor.hiddens)
    return True
|
||||
|
||||
1457
vllm_vacc/vllm/model_executor/models/qwen2.py
Normal file
1457
vllm_vacc/vllm/model_executor/models/qwen2.py
Normal file
File diff suppressed because it is too large
Load Diff
33
vllm_vacc/vllm/model_executor/models/qwen2_5_vl.py
Normal file
33
vllm_vacc/vllm/model_executor/models/qwen2_5_vl.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
|
||||
from vacc_tools.trace_logger import get_trace_api
# Per-module trace hooks from vacc_tools.
# NOTE(review): the trace domain is "deepseek" even though this is the
# Qwen2.5-VL module — presumably a shared trace configuration; confirm.
trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
    get_trace_api("deepseek")
)
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2_5_VisionAttention(nn.Module):
    """Vacc override of the Qwen2.5 vision attention: splits the fused QKV
    projection into per-head q/k/v views."""

    # @trace_time('Qwen2_5_VisionAttention_vacc_split_qkv')
    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused [s, b, 3 * head * head_dim] projection into three
        [s, b, head, head_dim] tensors (q, k, v)."""
        seq_len, batch, _ = qkv.shape
        head_shape = (seq_len, batch,
                      self.num_attention_heads_per_partition,
                      self.hidden_size_per_attention_head)
        # chunk() splits the trailing dim in three; each chunk is then
        # re-viewed into (head, head_dim).
        return tuple(part.view(*head_shape) for part in qkv.chunk(3, dim=-1))
|
||||
|
||||
class Qwen2_5_VisionPatchEmbed(nn.Module):
    """Vacc override: the Conv3d patch projection is applied as a plain
    matmul against the flattened kernel (conv3d converted to matmul)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten the conv kernel to (hidden_size, in_features) and project.
        flat_weight = self.proj.weight.view(self.hidden_size, -1)
        return torch.matmul(x, flat_weight.T)
|
||||
|
||||
285
vllm_vacc/vllm/model_executor/models/qwen2_vl.py
Normal file
285
vllm_vacc/vllm/model_executor/models/qwen2_vl.py
Normal file
@@ -0,0 +1,285 @@
|
||||
|
||||
"""Inference-only Qwen2VL model compatible with HuggingFace weights."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce, parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention as Qwen2VisionAttentionOrg
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.logger import init_logger
|
||||
from .hf_processor.qwenvl_processor import Qwen2VLProcessorWithVacc
|
||||
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
|
||||
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm_vacc.vllm.model_executor.models.vars import USE_FUSED_QWEN_ATTENTION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class Qwen2VisionPatchEmbed(nn.Module):
    """Vacc override: applies the patch projection as a flattened linear /
    matmul instead of a Conv3d.

    NOTE(review): the biased path flattens the kernel with `hidden_size`
    while the bias-free path uses `embed_dim` — presumably both attributes
    hold the same value on this module; confirm.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = getattr(self.proj, 'bias', None)
        if bias is not None:
            # Biased projection via F.linear on the flattened conv kernel.
            flat_weight = self.proj.weight.view(self.hidden_size, -1)
            return torch.nn.functional.linear(x, flat_weight, bias)
        # Bias-free path: plain matmul against the transposed flat kernel.
        return torch.matmul(x, self.proj.weight.view(self.embed_dim, -1).T)
|
||||
|
||||
|
||||
class Qwen2VLProcessingInfo():
    """Vacc override of the Qwen2-VL processing info: routes HF processor
    construction to the vacc-aware processor/image-processor classes."""

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessorWithVacc:
        # Force the vacc processor class; fast image processing defaults on.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessorWithVacc,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
        # The image processor always comes from the (vacc) HF processor above.
        return self.get_hf_processor(**kwargs).image_processor
|
||||
|
||||
import torch.nn.functional as F
|
||||
class Qwen2VisionTransformer(nn.Module):
    """Vacc override of the Qwen2 vision tower forward pass; optionally uses
    the torch_vacc fused rotary-embedding path."""

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding

        if USE_FUSED_QWEN_ATTENTION:
            try:
                # Fused sin/cos rotary caches from the vacc custom op.
                from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
                sin_cache, cos_cache = rot_pos_emb_qwenvl(grid_thw, self.embed_dim, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
            except Exception as e:
                # NOTE(review): when the fused op fails, sin_cache/cos_cache
                # stay unbound and their use below would raise NameError —
                # confirm whether this path should fall back instead.
                logger.error(f"rot_pos_emb fused ops run fail, e:{e}")
            rotary_pos_emb = None
        else:
            rotary_pos_emb = self.rot_pos_emb(grid_thw)
            sin_cache, cos_cache = None, None

        # tmp_rotary_pos_emb = self.transformer_rot_pos_emb(grid_thw)
        # qwen3_rotary_pos_emb = self.qwen3_rot_pos_emb(grid_thw)

        # compute cu_seqlens: per-frame patch counts (h*w) repeated t times,
        # accumulated and zero-prefixed.
        grid_thw_ = torch.tensor(grid_thw)
        cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                             grid_thw_[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations

        if USE_FUSED_QWEN_ATTENTION:
            # The fused attention op takes host-side Python seq-lens.
            cu_seqlens = cu_seqlens.tolist()
            max_seqlen, seqlens = None, None
        else:
            max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)

        return x
|
||||
|
||||
|
||||
class Qwen2VisionAttention(nn.Module):
    """Vacc override of Qwen2 vision attention: replaces the original
    ColumnParallelLinear QKV with a QKVParallelLinear and provides a
    transpose-free split_qkv."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional["QuantizationConfig"] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        # NOTE(review): super() is anchored at the upstream class
        # Qwen2VisionAttentionOrg, skipping that class's own __init__; this
        # only works if this class is patched into its MRO — confirm.
        super(Qwen2VisionAttentionOrg, self).__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        # Original upstream layer, kept for reference:
        # self.qkv = ColumnParallelLinear(input_size=embed_dim,
        #                                 output_size=3 * projection_size,
        #                                 quant_config=quant_config,
        #                                 prefix=f"{prefix}.qkv",
        #                                 disable_tp=use_data_parallel)
        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel)

        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj",
                                      disable_tp=use_data_parallel)

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False
        # Prefer upstream flash-attention when the chosen backend is not
        # FLASH_ATTN but upstream FA is available for the default dtype.
        if self.attn_backend != _Backend.FLASH_ATTN and \
                check_upstream_fa_availability(
                    torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now.")
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused [s, b, 3 * head * head_dim] projection into three
        [s, b, head, head_dim] tensors (q, k, v)."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q1, k1, v1 = qkv.chunk(3, dim=-1)
        q1, k1, v1 = (x.view(*new_shape) for x in (q1, k1, v1))
        return q1, k1, v1
|
||||
|
||||
class Qwen2VisionBlock(nn.Module):
    """Vacc override of a Qwen2 vision transformer block: routes attention
    through the fused `torch.vacc.fuse_atten_vit` kernel when enabled."""

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        sin_cache: torch.Tensor,
        cos_cache: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:

        if USE_FUSED_QWEN_ATTENTION:
            total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
            # The fused op performs the all_reduce itself only for payloads
            # under 4 MiB; otherwise the reduction happens outside the kernel.
            reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304

            # hidden_states = self.norm1(x)
            attn_outs = torch.vacc.fuse_atten_vit(
                hidden_states=x.view(-1, x.shape[-1]),
                hidden_states_norm_weight = self.norm1.weight,
                hidden_states_norm_bias = self.norm1.bias,
                # hidden_states_norm_weight = torch.Tensor(),
                # hidden_states_norm_bias = torch.Tensor(),
                qkv_proj_weight=self.attn.qkv.weight,
                qkv_proj_bias=self.attn.qkv.bias,
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                o_proj_weight=self.attn.proj.weight,
                # Only TP rank 0 supplies the o-proj bias so the subsequent
                # all_reduce does not add it world_size times.
                o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
                seq_lens=cu_seqlens,
                sm_scale=-1,
                num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
                flash_attention=True,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos
            )
            # Reduce outside the kernel only when it did not do so itself.
            attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
            attn_out = attn_out.view(x.shape)

            x = x + attn_out
        else:
            # Reference (non-fused) attention path.
            x = x + self.attn(
                self.norm1(x),
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        x = x + self.mlp(self.norm2(x))
        return x
|
||||
|
||||
class Qwen2VisionMLP():
    """Vacc override of the Qwen2 vision MLP: tries the fused
    `fuse_mlp_vision` kernel and falls back to the plain fc1/act/fc2 path
    when the fused op is unavailable or fails."""

    def forward(self, x: torch.Tensor):
        try:
            from torch_vacc.vacc import fuse_mlp_vision
            hiddens_shape = x.shape
            tp_rank_id = get_tp_group().rank_in_group
            # Only rank 0 supplies the fc2 bias so the subsequent all_reduce
            # does not add it world_size times.
            fc2_bias = None if tp_rank_id > 0 else self.fc2.bias
            hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
                                            self.fc1.weight, # nk
                                            self.fc2.weight, # nk
                                            self.fc1.bias,
                                            fc2_bias,
                                            2) # 0 is gelu, 1 is relu, 2 is quick_gelu
            vacc_res = tensor_model_parallel_all_reduce(hidden_states).view(hiddens_shape)
            return vacc_res
        except Exception as e:
            logger.error(f"mlp fused ops run fail, e:{e}")

        # Fallback: plain fc1 -> activation -> fc2.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x
|
||||
|
||||
class Qwen2VisionPatchMerger():
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge vision patches: layer-norm, flatten to hidden_size rows, then
        run the 2-layer merger MLP (fused kernel when available)."""
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)
        # self.mlp is an (fc1, activation, fc2) triple — unpacked so the fused
        # kernel can consume the raw weights directly.
        mlp_fc1, mlp_act, mlp_fc2 = self.mlp

        try:
            # Fast path: fused fc1 -> gelu -> fc2 device kernel + explicit
            # all-reduce; any failure falls through to the unfused path below.
            from torch_vacc.vacc import patch_merger_vision
            tp_rank_id = get_tp_group().rank_in_group
            # Add fc2's bias on rank 0 only, so the subsequent all-reduce does
            # not add it world_size times.
            fc2_bias = None if tp_rank_id > 0 else mlp_fc2.bias

            hidden_states = patch_merger_vision(x,
                                                mlp_fc1.weight,
                                                mlp_fc2.weight,
                                                mlp_fc1.bias,
                                                fc2_bias,
                                                0)  # 0 is gelu, 1 is silu
            vacc_res = tensor_model_parallel_all_reduce(hidden_states)
            return vacc_res
        except Exception as e:
            # Deliberate best-effort fallback; do not raise.
            logger.error(f"merge patch fused vision mlp run fail, cased by:{e}")

        # Reference (unfused) path.
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out
|
||||
194
vllm_vacc/vllm/model_executor/models/qwen3.py
Normal file
194
vllm_vacc/vllm/model_executor/models/qwen3.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union, Any, Dict
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.logger import init_logger
|
||||
from .vars import *
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod as UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# uniform the params names from different quantize method
|
||||
# uniform the params names from different quantize method
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    """Copy *layer*'s quantization-specific tensors into *fused_params* under
    the uniform key schema ``<name>_weight`` / ``<name>_weight_scale`` /
    ``<name>_bias`` / ``<name>_qzeros`` expected by the fused attention kernel.

    Raises:
        ValueError: if *quant_method* is not one of the supported methods.
    """
    # Resolve which attributes hold the (packed) weight and its scale for the
    # given quantization scheme.
    if isinstance(quant_method, UnquantizedLinearMethod):
        weight, scale = layer.weight, torch.Tensor()
    elif isinstance(quant_method, Fp8LinearMethod):
        weight, scale = layer.weight, layer.weight_scale_inv
    elif isinstance(quant_method, (GPTQLinearMethod, AWQLinearMethod)):
        # Both GPTQ and AWQ store packed weights in `qweight` and per-group
        # scales in `scales`.
        weight, scale = layer.qweight, layer.scales
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")

    fused_params[name + '_weight'] = weight
    fused_params[name + '_weight_scale'] = scale
    if isinstance(quant_method, UnquantizedLinearMethod):
        # Unquantized layers never contribute bias/qzeros to the fused kernel.
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    else:
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
|
||||
|
||||
class Qwen3Attention(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None  # new added param for the fused path
    ) -> torch.Tensor:
        # NOTE(review): the annotation says Tensor, but the fused branch
        # returns a (attn_out, res_out) tuple while the original branch returns
        # a single tensor — callers must match USE_FUSED_QWEN_ATTENTION.

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]

        # reshape kvcache to (k/v, num_blocks, block_size=16, kv_heads, head_dim)
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        # attn_metadata may be a per-layer dict; take the first entry
        # (assumes all layers share the same metadata — TODO confirm).
        if isinstance(attn_metadata_all, dict):
            attn_metadata = attn_metadata_all.items().__iter__().__next__()[1]
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        # Decode lets the fused kernel do the reduce itself; prefill reduces
        # afterwards via tensor_model_parallel_all_reduce.
        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
        # # only support 4M now
        # if total_bytes < 4194304:
        #     reduce_result = True

        if USE_FUSED_QWEN_ATTENTION:
            # Build per-sequence rope sin/cos slices: last position only for
            # decode, the whole prefix for prefill.
            if is_decode:
                positions = [i - 1 for i in attn_metadata.seq_lens]
                cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
                sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
            else:
                cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
                sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
            if residual is None:
                # first layer: the pre-attention input doubles as the residual
                res_out = hidden_states
            attn_outs = None
            if not is_decode:
                # Reuse a preallocated output buffer during prefill when the
                # memory recycler is active.
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            # If TP over-splits the kv heads, report world_size heads so every
            # rank owns at least one (replicated) kv head.
            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size

            # Fused kernel: input layernorm + qkv proj + qk-norm + rope +
            # kv-cache update + attention + o_proj (+ optional all-reduce).
            attn_outs = torch.vacc.fuse_atten_qwen3(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,  # tensor
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode use flash_atten by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)

            # NOTE(review): with residual the kernel appears to return
            # (attn_out, residual_out); without it a single tensor — confirm
            # against the fuse_atten_qwen3 contract.
            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]

            return attn_out, res_out
        else:
            # orig (unfused) code path, identical to upstream vLLM Qwen3.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # Add qk-norm: RMSNorm applied per head before rope.
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """One decoder layer: (fused or unfused) self-attention followed by the
        post-attention layernorm and MLP."""
        # Self Attention
        # NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
        if USE_FUSED_QWEN_ATTENTION:
            # Lazily build the fused-kernel parameter dict once per layer; the
            # keys match what Qwen3Attention.forward reads from fused_params.
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')

            # Fused attention consumes and returns the residual itself.
            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual)
        else:
            # Unfused path: apply the (possibly residual-fusing) input
            # layernorm here, as upstream vLLM does.
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
||||
790
vllm_vacc/vllm/model_executor/models/qwen3_moe.py
Normal file
790
vllm_vacc/vllm/model_executor/models/qwen3_moe.py
Normal file
@@ -0,0 +1,790 @@
|
||||
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from torch_vacc.vacc.custom_ops_cpu import (
|
||||
w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu,
|
||||
)
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock, Qwen3MoeMLP
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
|
||||
|
||||
from ..ops.mrope_op import get_sin_cos_mrope
|
||||
from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize
|
||||
from .vars import *
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# uniform the params names from different quantize method
|
||||
# uniform the params names from different quantize method
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    """Copy *layer*'s quantization-specific tensors into *fused_params* under
    the uniform key schema ``<name>_weight`` / ``<name>_weight_scale`` /
    ``<name>_bias`` / ``<name>_qzeros`` expected by the fused attention kernel.

    Raises:
        ValueError: if *quant_method* is not one of the supported methods.
    """
    if isinstance(quant_method, UnquantizedLinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = None
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    # BUGFIX: this branch was a bare `if`, so for UnquantizedLinearMethod the
    # chain fell through the Fp8/GPTQ/AWQ tests into the final `else` and
    # raised ValueError after having populated fused_params. `elif` restores
    # mutually-exclusive dispatch, matching the sibling copy in qwen3.py.
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, GPTQLinearMethod):
        # GPTQ packs weights in `qweight` and keeps per-group scales in `scales`.
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, AWQLinearMethod):
        # AWQ uses the same attribute names as GPTQ for the packed tensors.
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
|
||||
|
||||
|
||||
|
||||
def apply_w8a8_block_fp8_linear_v2(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Block-quantized W8A8 FP8 linear: y = input @ weight.T (+ bias),
    dispatching to the VACC device kernel or the CPU reference kernel.

    Args:
        input: activations; any leading shape, last dim is the in-features.
        weight: FP8 block-quantized weight, shape (out_features, in_features).
        weight_scale: per-block dequant scales; block size is inferred below.
        bias: optional bias added after the matmul.
    """
    input_scale = None  # dynamic activation quantization inside the kernel
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    # Infer the quantization block size from the weight/scale shape ratio.
    block_size = [
        weight.shape[-2] // weight_scale.shape[-2],
        weight.shape[-1] // weight_scale.shape[-1],
    ]

    if input.device.type == "vacc":
        output = torch.vacc.w8a8_block_fp8_linear(
            input_2d, weight, input_scale, weight_scale, block_size
        )
    else:
        # CPU fallback (used e.g. for debugging off-device).
        output = w8a8_block_fp8_linear_cpu(
            input_2d, weight, input_scale, weight_scale, block_size
        )

    if bias is not None:
        output = output + bias
    # Cast back to the activation dtype and restore leading dimensions.
    return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
def vacc_fused_attn_qwen3_naive(
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale: torch.Tensor,
    qkv_proj_bias: Optional[torch.Tensor],
    qkv_proj_qzeros: Optional[torch.Tensor],
    q_layernorm_weight: torch.Tensor,
    k_layernorm_weight: torch.Tensor,
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale: torch.Tensor,
    o_proj_bias: Optional[torch.Tensor],
    o_proj_qzeros: Optional[torch.Tensor],
    seq_lens: List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    block_size: int = 16
):
    """Step-by-step ("naive") reference for the fused Qwen3 attention kernel:
    residual add + rms_norm + qkv proj + qk-norm + rope + kv-cache update +
    attention + o_proj (+ optional all-reduce). Mirrors fuse_atten_qwen3's
    signature so the two are interchangeable for debugging.

    NOTE(review): qkv_proj_bias/qzeros, o_proj_bias/qzeros, flash_attention,
    rank/group_id/dev_info and block parameters are accepted for signature
    parity but unused here — fp8 block-quant weights only.
    """
    if residual is not None:
        # Residual is added *before* the norm; the pre-norm sum is what the
        # caller receives back as the next layer's residual.
        hidden_states = hidden_states + residual
        residual_out = hidden_states

    # input layernorm (RMSNorm, eps=1e-6)
    hidden_states = torch.vacc.rms_norm(
        hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)

    # NOTE: for qwen3 and qwen2.5, head_dim is always 128
    head_dim = 128

    # qkv gen (fp8 block-quantized matmul)
    qkv = apply_w8a8_block_fp8_linear_v2(
        input=hidden_states,
        weight=qkv_proj_weight,
        weight_scale=qkv_proj_weight_scale)

    # NOTE(review): no max(1, ...) clamp here, unlike the caller's kv-head
    # handling — assumes num_key_value_heads is divisible by world_size.
    num_q_heads = num_attention_heads // world_size
    num_kv_heads = num_key_value_heads // world_size

    q_size = head_dim * num_q_heads
    kv_size = head_dim * num_kv_heads

    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # Add qk-norm: per-head RMSNorm applied before rope.
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)

    v = v.view(-1, num_kv_heads, head_dim)

    start = 0
    attn_outs = []

    if is_decode:
        # Convert per-16-token block tables to block-group indices
        # (block_group_size bytes worth of blocks per group, e.g. 8K).
        block_per_group = block_group_size // block_size
        block_tables = (block_tables // block_per_group).to(torch.int32)
        # logger.warning(f"decode block table: {block_tables}")

    num_blocks = kv_cache.shape[1]
    key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
    value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)

    # Per-sequence (batch) loop; `start:end` selects this sequence's tokens
    # in the flattened token dimension.
    for i in range(len(seq_lens)):
        if not is_decode:
            # prefill: all tokens of the sequence
            end = start + seq_lens[i]
        else:
            # decode: exactly one new token per sequence
            end = start + 1

        cos = cos_cache[i].unsqueeze(-2)
        sin = sin_cache[i].unsqueeze(-2)

        # rope (neox interleaving); reuses the names q/k for the per-sequence
        # rotated slices.
        q, k = torch.vacc.RotaryPosEmbedding(
            q_norm[start : end, ...], k_norm[start : end, ...], cos, sin, 0, "neox")

        # Write the new k/v tokens into the paged cache at slot_mapping.
        torch.vacc.reshape_and_cache_attention(k, key_cache_split, slot_mapping[start : end, ...])
        torch.vacc.reshape_and_cache_attention(v[start : end, ...], value_cache_split, slot_mapping[start : end, ...])

        if not is_decode:
            # prefill: causal attention over this sequence's own tokens
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v[start : end, ...],
                attn_mask = None,
                dropout_p = 0.0,
                is_causal = True, #causal_attn and not self.need_mask,
                is_train = False,
                recompute = False,
                flash_attention = False,
                sm_scale=sm_scale)
        else:
            # decode: gather the whole cached prefix for this sequence from
            # the paged cache, then attend the single new token against it.
            key_cache = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            value_cache = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)

            k_slices = key_cache[block_tables[i], ...]
            # NOTE(review): the comprehension variable `i` shadows the batch
            # index; `len(block_tables[i])` in the iterable still uses the
            # outer `i` (evaluated once in the enclosing scope) while
            # `k_slices[i]` uses the comprehension's `i`. Works, but fragile.
            k_cached = torch.cat(
                [k_slices[i].unsqueeze(1) for i in range(len(block_tables[i]))],
                dim=0,
            )
            k_cached = k_cached.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]

            v_slices = value_cache[block_tables[i], ...]
            v_cached = torch.cat(
                [v_slices[i].unsqueeze(1) for i in range(len(block_tables[i]))],
                dim=0,
            )
            v_cached = v_cached.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k_cached,
                value=v_cached,
                attn_mask=None,
                dropout_p=0,
                is_causal=False,  # prefix is fully visible to the new token
                is_train=False,
                recompute=False,
                flash_attention=False,#flash_attention,
                sm_scale=sm_scale)

        attn_outs.append(attn_out)
        # update start to the next sequence's first token
        start = end
    attn_out = torch.cat(attn_outs, dim=0)

    # output projection (fp8 block-quantized matmul)
    o_proj = apply_w8a8_block_fp8_linear_v2(
        input = attn_out.reshape(hidden_states.shape[0], -1),
        weight = o_proj_weight,
        weight_scale = o_proj_weight_scale,
    )

    if reduce_result:
        # Sum the TP partial outputs here when the caller asked for it.
        o_proj = tensor_model_parallel_all_reduce(o_proj)

    # Return shape mirrors fuse_atten_qwen3: (out, residual) when a residual
    # was passed in, a single tensor otherwise.
    if residual is not None:
        return o_proj, residual_out
    return o_proj
|
||||
|
||||
def Qwen3MoeSparseMoeBlock__init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    """Monkeypatched __init__ for Qwen3MoeSparseMoeBlock.

    Mirrors the upstream constructor (experts + gate setup, EPLB bookkeeping)
    and additionally re-lays-out the block-quantized w2 weights/scales for the
    VACC fused MoE kernels.
    """
    super(Qwen3MoeSparseMoeBlock, self).__init__()
    config = vllm_config.model_config.hf_text_config
    parallel_config = vllm_config.parallel_config
    quant_config = vllm_config.quant_config

    self.tp_size = get_tensor_model_parallel_world_size()

    # Expert-parallel group bookkeeping.
    self.ep_group = get_ep_group().device_group
    self.ep_rank = self.ep_group.rank()
    self.ep_size = self.ep_group.size()
    self.n_routed_experts = config.num_experts

    self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

    if self.tp_size > config.num_experts:
        raise ValueError(
            f"Tensor parallel size {self.tp_size} is greater than "
            f"the number of experts {config.num_experts}.")

    # Load balancing settings (EPLB: redundant physical experts spread over
    # the EP group).
    vllm_config = get_current_vllm_config()
    eplb_config = vllm_config.parallel_config.eplb_config
    self.enable_eplb = parallel_config.enable_eplb

    self.n_logical_experts = self.n_routed_experts
    self.n_redundant_experts = eplb_config.num_redundant_experts
    self.n_physical_experts = (self.n_logical_experts +
                               self.n_redundant_experts)
    self.n_local_physical_experts = self.n_physical_experts // self.ep_size

    # Contiguous slice of physical experts owned by this EP rank.
    self.physical_expert_start = (self.ep_rank *
                                  self.n_local_physical_experts)
    self.physical_expert_end = (self.physical_expert_start +
                                self.n_local_physical_experts)

    self.experts = FusedMoE(num_experts=self.n_routed_experts,
                            top_k=config.num_experts_per_tok,
                            hidden_size=config.hidden_size,
                            intermediate_size=config.moe_intermediate_size,
                            reduce_results=True,
                            renormalize=config.norm_topk_prob,
                            quant_config=quant_config,
                            prefix=f"{prefix}.experts",
                            enable_eplb=self.enable_eplb,
                            num_redundant_experts=self.n_redundant_experts,
                            is_sequence_parallel=self.is_sequence_parallel)

    # Router: replicated (not sharded) linear over hidden_size -> num_experts.
    self.gate = ReplicatedLinear(config.hidden_size,
                                 config.num_experts,
                                 bias=False,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.gate")

    # Patch: transpose w2 weight/scale data layout in-place — block-quantized
    # checkpoints only. transpose().contiguous().transpose() keeps the logical
    # shape but rewrites the underlying storage to column-major order,
    # presumably the layout the fused MoE kernel expects — TODO confirm.
    if hasattr(self.experts.quant_method, 'quant_config') and hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
        self.experts.w2_weight.data = self.experts.w2_weight.data.transpose(-1,-2).contiguous().transpose(-1,-2)
        self.experts.w2_weight_scale_inv.data = self.experts.w2_weight_scale_inv.data.transpose(-1,-2).contiguous().transpose(-1,-2)
        if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
            self.experts.w2_weight_scale_inv_prefill.data = self.experts.w2_weight_scale_inv_prefill.data.transpose(-1,-2).contiguous().transpose(-1,-2)
|
||||
|
||||
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
                      attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
                      positions: Union[torch.Tensor, list],
                      is_decode: bool):
    """Build per-sequence rope cos/sin slices for the fused attention kernel.

    Returns a pair (cos_cache, sin_cache) of per-sequence tensor sequences:
    one single-position slice per sequence in decode, the full prefix per
    sequence in prefill. MRoPE embeddings are computed from *positions* and
    split along the token dimension; plain rope slices the precomputed caches.
    """
    seq_lens = attn_metadata.seq_lens
    if isinstance(rotary_emb, MRotaryEmbedding):
        # Multimodal rope: compute cos/sin for the given positions, then carve
        # them up per sequence.
        cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
        if len(seq_lens) <= 1:
            return [cos_cache], [sin_cache]
        if is_decode:
            # One token per sequence -> equal-sized chunks.
            n_seqs = len(seq_lens)
            return (torch.chunk(cos_cache, n_seqs),
                    torch.chunk(sin_cache, n_seqs))
        # Prefill: variable-length split by each sequence's token count.
        return (torch.split(cos_cache, seq_lens),
                torch.split(sin_cache, seq_lens))

    # Plain rope: slice the precomputed caches.
    if is_decode:
        # Last position only for each sequence.
        last_positions = [n - 1 for n in seq_lens]
        return ([rotary_emb.cos_cache[p:p+1, ...] for p in last_positions],
                [rotary_emb.sin_cache[p:p+1, ...] for p in last_positions])
    # Prefill: the whole prefix for each sequence.
    return ([rotary_emb.cos_cache[:n, ...] for n in seq_lens],
            [rotary_emb.sin_cache[:n, ...] for n in seq_lens])
|
||||
|
||||
class Qwen3MoeDecoderLayer(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        cos_cache: list[torch.Tensor],
        sin_cache: list[torch.Tensor]
    ) -> torch.Tensor:
        """One MoE decoder layer: (fused or unfused) self-attention, then a
        fused MoE FFN when the quantization scheme supports it, otherwise the
        stock post-norm + mlp path."""
        # Self Attention
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        # NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
        if USE_FUSED_QWEN_ATTENTION:
            # Lazily build the fused-kernel parameter dict once per layer.
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')

            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                cos_cache=cos_cache,
                sin_cache=sin_cache)
        else:
            # Unfused path: layernorm (residual-fusing after layer 0) here.
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                cos_cache=cos_cache,
                sin_cache=sin_cache
            )

        # TODO for noquant or not block_quant: quant schemes without a
        # weight_block_size (and that are not WNA16) have no fused MoE kernel
        # yet — run the stock path and return early.
        # NOTE(review): this accesses self.mlp.experts unconditionally even
        # though the quant_method selection below allows a non-MoE mlp —
        # confirm dense-layer handling.
        if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
            not hasattr(self.mlp.experts.quant_method.quant_config, 'weight_block_size'):
            if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
                logger.warning('TODO for noquant or other quant')
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual)
                hidden_states = self.mlp(hidden_states)
                return hidden_states, residual

        # Distinguish prefill vs decode from the (possibly per-layer dict)
        # attention metadata.
        if isinstance(attn_metadata, dict):
            attn_metadata_0 = get_forward_context().attn_metadata.items().__iter__().__next__()[1]
            is_prefill = attn_metadata_0.prefill_metadata

        else:
            is_prefill = get_forward_context().attn_metadata.prefill_metadata

        # Sparse MoE block vs dense fallback layer (mlp may be a plain MLP
        # for non-MoE layers).
        quant_method = self.mlp.experts.quant_method if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
            else self.mlp.down_proj.quant_method

        # Dispatch to the fused MoE kernel matching (phase, quant scheme).
        # Each call is best-effort: on failure we log and fall through to the
        # stock path at the bottom.
        if is_prefill is not None:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_prefill_moe_gptq_int4
                    return vacc_fused_prefill_moe_gptq_int4(hidden_states,
                                                            residual,
                                                            self.post_attention_layernorm,
                                                            self.mlp.gate,
                                                            self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_prefill_moe_gptq_int4 fail: {e}')
            else:
                # fp8 block-quant: refresh block-size metadata before prefill.
                recompute_moe_layer_blocksize(self.mlp.experts)
                try:
                    return vacc_fused_prefill_moe_fp8(hidden_states,
                                                      residual,
                                                      self.post_attention_layernorm,
                                                      self.mlp.gate,
                                                      self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_prefill_moe_fp8 fail: {e}')
        else:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_decode_moe_gptq_int4
                    return vacc_fused_decode_moe_gptq_int4(hidden_states,
                                                           residual,
                                                           self.post_attention_layernorm,
                                                           self.mlp.gate,
                                                           self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_decode_moe_gptq_int4 fail: {e}')
            else:
                try:
                    return vacc_fused_decode_moe_fp8(hidden_states,
                                                     residual,
                                                     self.post_attention_layernorm,
                                                     self.mlp.gate,
                                                     self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_decode_moe_fp8 fail: {e}')

        # Stock (unfused) fallback path.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
||||
|
||||
class Qwen3MoeAttention(nn.Module):
    """Qwen3-MoE self-attention with an optional fully fused VACC kernel path.

    When USE_FUSED_QWEN_ATTENTION is set, input-layernorm + QKV projection +
    QK-norm + RoPE + KV-cache write + attention + O-projection are executed by
    a single ``torch.vacc.fuse_atten_qwen3`` call; otherwise the stock
    unfused vLLM path is used.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,  # new added params
        cos_cache: list[torch.Tensor] = None,
        sin_cache: list[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(review): on the fused path this actually returns a
        # (attn_out, res_out) tuple, not a plain tensor as annotated —
        # callers (Qwen3MoeModel's layer loop) unpack two values.

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]

        # reshape kvcache to the layout the fused kernel expects:
        # (k/v, blocks, block_size=16, kv_heads_per_rank, head_dim)
        # -- assumes a block size of 16; TODO confirm against cache config.
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        if isinstance(attn_metadata_all, dict):
            # Per-KV-cache-group dict: take the first (presumably only) entry.
            attn_metadata = attn_metadata_all.items().__iter__().__next__()[1]
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        # reduce_result=True asks the fused kernel to perform the TP
        # all-reduce internally; currently only enabled for decode.
        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
        # # only support 4M now
        # if total_bytes < 4194304:
        #     reduce_result = True

        if USE_FUSED_QWEN_ATTENTION:
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)

            if residual is None:
                # First layer: the residual stream starts as the raw input.
                res_out = hidden_states
            # from torch_vacc.vacc import fuse_atten_qwen3
            attn_outs = None
            if not is_decode:
                # Prefill: reuse a pre-allocated output buffer when the
                # memory recycler is active, to avoid a fresh allocation.
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            # If TP degree exceeds the model's KV heads, KV heads are
            # replicated so each rank still owns at least one.
            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size
            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode use flash_atten by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)
            # debug_qwen3_moe_attention_prefill(hidden_states=hidden_states,
            #                                   residual=residual,
            #                                   attn_outs=attn_outs,
            #                                   fused_params=self.fused_params,
            #                                   attn_metadata=attn_metadata,
            #                                   is_decode=is_decode,
            #                                   sin_cache=sin_cache,
            #                                   cos_cache=cos_cache,
            #                                   kv_cache=kv_cache,
            #                                   env_blk_grp_size=env_blk_grp_size,
            #                                   scaling=self.scaling,
            #                                   total_num_heads=self.total_num_heads,
            #                                   total_num_kv_heads=self.total_num_kv_heads,
            #                                   world_size=get_tp_group().world_size,
            #                                   rank=get_tp_group().rank_in_group,
            #                                   group_id=get_tp_group().group_id,
            #                                   dev_info=get_tp_group().rank_device_infos)

            if residual is None:
                # Kernel returned a single tensor; all-reduce it ourselves
                # unless the kernel already reduced (reduce_result).
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                # Kernel returned (attn_out, updated_residual).
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]

            return attn_out, res_out
        else:
            # orig code: stock (unfused) attention path.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # Add qk-norm: RMS-normalize per head before RoPE.
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm.forward_native(q_by_head)

            q = q_by_head.view(q.shape)

            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm.forward_native(k_by_head)

            k = k_by_head.view(k.shape)

            q, k = self.rotary_emb(positions, q, k)

            attn_output = self.attn(q, k, v)

            output, _ = self.o_proj(attn_output)
            return output
|
||||
|
||||
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE transformer stack with a fully fused per-layer decode path.

    For decode steps on supported (fp8-quantized) models the whole
    attention+MoE layer is executed by ``qwen3_fuse_attention_moe_decode``,
    driven by weights captured into flat per-layer arrays; otherwise the
    regular per-layer Python loop runs.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        # Lazily build the weight-capture view on first forward; it flattens
        # per-layer weights into indexable arrays for the fused kernel.
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import Qwen3Moe_WeightCapture
            self.weight_capture = Qwen3Moe_WeightCapture(self.layers, self.start_layer, self.end_layer)
            self.layer_nums = self.end_layer - self.start_layer

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Non-first PP rank: continue from the previous stage's tensors.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # fused layer decoder only support fp8 quant model now
        # NOTE(review): despite the name, use_default_layer==True selects the
        # FUSED path below — consider renaming in a follow-up.
        use_default_layer = self.weight_capture.support_fused_weights and USE_DECODER_LAYER_FUSE_MODE
        # print('Qwen3MoeModel attn_metadata', attn_metadata)
        if isinstance(attn_metadata_all, dict):
            # is_decode = attn_metadata_all['test'].prefill_metadata is None
            # attn_metadata = attn_metadata_all['test']
            attn_metadata = attn_metadata_all.items().__iter__().__next__()[1]
            is_decode = attn_metadata.prefill_metadata is None

        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        if(use_default_layer and is_decode):
            from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode

            # RoPE caches are shared across layers; compute once from layer 0.
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode=True)

            for i in range(0, self.layer_nums):
                layer = self.layers[i + self.start_layer]
                kv_cache = layer.self_attn.attn.kv_cache[forward_context.virtual_engine]
                num_kv_heads = max(1, layer.self_attn.total_num_kv_heads // get_tp_group().world_size)
                # Kernel KV layout: (k/v, blocks, 16, kv_heads_per_rank, head_dim).
                kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, layer.self_attn.head_dim)
                # Replicate KV heads when TP degree exceeds model KV heads.
                total_num_kv_heads = layer.self_attn.total_num_kv_heads
                if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
                    assert get_tp_group().world_size % layer.self_attn.total_num_kv_heads == 0
                    total_num_kv_heads = get_tp_group().world_size

                # One fused call per layer: layernorm + attention + MoE.
                # The _N_-prefixed attribute names encode kernel argument order.
                hidden_states, residual = qwen3_fuse_attention_moe_decode(hidden_states, residual,
                    hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
                    qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
                    qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
                    qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
                    qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
                    q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
                    k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
                    o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
                    o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
                    o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
                    seq_lens_num=attn_metadata.seq_lens,
                    sm_scale=layer.self_attn.scaling,
                    num_attention_heads=layer.self_attn.total_num_heads,
                    num_key_value_heads=total_num_kv_heads,
                    # NOTE(review): "flash_attentiton" matches the custom op's
                    # (misspelled) keyword — do not "fix" without changing the op.
                    flash_attentiton=True,
                    is_decode=True,
                    reduce_result=True,
                    # moe
                    rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
                    moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
                    moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
                    moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
                    moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
                    gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
                    block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
                    block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
                    # dist
                    world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
                    rank=self.weight_capture.layer_mapper.dist_args._1_rank,
                    group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
                    dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
        else:
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode)

            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states, residual, cos_cache, sin_cache )
                # Deepstack (Qwen3-VL): inject visual features into the first
                # len(deepstack_input_embeds) layers' outputs.
                if deepstack_input_embeds is not None and i in range(0, len(deepstack_input_embeds)):
                    if isinstance(deepstack_input_embeds, IntermediateTensors):
                        hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{i}"]
                    elif isinstance(deepstack_input_embeds, torch.Tensor):
                        hidden_states = hidden_states + deepstack_input_embeds[i]
                    else:
                        raise ValueError(f'unsupported type: {type(deepstack_input_embeds)}')

        if not get_pp_group().is_last_rank:
            # Hand off to the next pipeline stage.
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states, residual)
        return hidden_states
|
||||
|
||||
class Qwen3MoeForCausalLM(nn.Module):
    """Causal-LM wrapper around Qwen3MoeModel with VACC memory recycling.

    On each prefill it (re-)allocates the prefill memory recycler sized to the
    incoming token count; the huge backing allocator is initialized once
    during weight loading.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the model, setting up the prefill memory recycler first."""
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            # Per-KV-cache-group dict: take the first (presumably only) entry.
            attn_metadata = attn_metadata.items().__iter__().__next__()[1]
        # Guard against a missing metadata object (e.g. profiling/dummy runs)
        # in addition to the decode case (prefill_metadata is None).
        if attn_metadata is not None and attn_metadata.prefill_metadata is not None:
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
            # Token count lives either directly on the metadata or on its
            # prefill sub-object, depending on the attention backend.
            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens

            # Model mode defaults to qwen3_moe; the config manager may override.
            vllm_model_mode = "qwen3_moe"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                vllm_model_mode = config_infos

            # Log only once per TP group (rank 0) to avoid duplicated lines.
            if get_tp_group().rank_in_group == 0:
                memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
                logger.info(memory_infos)

            if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size, dtype=self.lm_head.weight.dtype):
                # Fixed log message: said "deepseek ... allock fail. current
                # request may inefficient" (wrong model name, typo, grammar).
                logger.warning("qwen3_moe memory recycler alloc fail. current request may be inefficient %s", tokens)

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, deepstack_input_embeds)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Initialize the huge-memory allocator, then delegate weight loading."""
        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

        # default is deepseek, config can set to ['deepseek_mtp',]
        model_name = "qwen3_moe"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            model_name = config_infos

        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
            # Non-fatal: prefill memory recycling is simply disabled.
            logger.warning("init huge memory allocator fail. prefill memory recycling will disable")

        from vllm.model_executor.models.utils import AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
|
||||
362
vllm_vacc/vllm/model_executor/models/qwen3_vl.py
Normal file
362
vllm_vacc/vllm/model_executor/models/qwen3_vl.py
Normal file
@@ -0,0 +1,362 @@
|
||||
|
||||
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from vllm.logger import init_logger
|
||||
from .hf_processor.qwenvl_processor import Qwen3VLProcessorWithVacc
|
||||
from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc
|
||||
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
|
||||
|
||||
from .vars import USE_FUSED_QWEN_ATTENTION
|
||||
|
||||
# from vacc_tools.trace_logger import get_trace_api
|
||||
# trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
|
||||
# get_trace_api("Qwen3vl")
|
||||
# )
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class Qwen3_VisionPatchEmbed(nn.Module):
    """Projects flattened vision patches into the transformer hidden size."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the patch projection as a plain 2-D linear map.

        The conv-style projection weight is reshaped to
        (hidden_size, in_features) so the whole thing reduces to one matmul,
        with the bias folded in when the projection layer has one.
        """
        weight_2d = self.proj.weight.view(self.hidden_size, -1)
        bias = getattr(self.proj, 'bias', None)
        if bias is not None:
            # Fused matmul + bias via F.linear.
            return torch.nn.functional.linear(x, weight_2d, bias)
        # No bias configured: plain matmul against the transposed weight.
        return torch.matmul(x, weight_2d.T)
|
||||
|
||||
class Qwen3_VisionBlock(nn.Module):
    """Qwen3-VL vision transformer block with an optional fused attention path.

    With USE_FUSED_QWEN_ATTENTION, norm1 + QKV + RoPE + attention + O-proj run
    inside a single ``torch.vacc.fuse_atten_vit`` call; the MLP sub-block is
    always the regular unfused path.
    """

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | list[torch.Tensor],
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        if USE_FUSED_QWEN_ATTENTION:
            # Fused path requires pre-split [sin, cos] caches.
            assert isinstance(rotary_pos_emb, list), "qwen3vl vit-attention need rotary_pos_emb is list[torch.Tensor]"

            # In-kernel all-reduce is only supported below 4 MB total payload.
            total_bytes = x.numel() * x.element_size() * get_tp_group().world_size
            reduce_result = get_tp_group().world_size > 1 and total_bytes < 4194304

            # hidden_states = self.norm1(x)
            attn_outs = torch.vacc.fuse_atten_vit(
                hidden_states=x.view(-1, x.shape[-1]),
                hidden_states_norm_weight = self.norm1.weight,
                hidden_states_norm_bias = self.norm1.bias,
                # hidden_states_norm_weight = torch.Tensor(),
                # hidden_states_norm_bias = torch.Tensor(),
                qkv_proj_weight=self.attn.qkv.weight,
                qkv_proj_bias=self.attn.qkv.bias,
                sin_cache=rotary_pos_emb[0],
                cos_cache=rotary_pos_emb[1],
                o_proj_weight=self.attn.proj.weight,
                # Only TP rank 0 supplies the output bias so the subsequent
                # all-reduce does not add it world_size times.
                o_proj_bias=self.attn.proj.bias if self.attn.proj.tp_rank == 0 else torch.Tensor(),
                seq_lens=cu_seqlens,
                sm_scale=-1,  # -1: let the kernel derive the softmax scale
                num_attention_heads=self.attn.num_attention_heads_per_partition * get_tp_group().world_size,
                flash_attention=True,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos
            )
            # Kernel already reduced when reduce_result; otherwise reduce here.
            attn_out = attn_outs[0] if reduce_result else tensor_model_parallel_all_reduce(attn_outs[0])
            attn_out = attn_out.view(x.shape)

            x = x + attn_out
        else:
            # Unfused reference path.
            x = x + self.attn(self.norm1(x),
                              cu_seqlens=cu_seqlens,
                              rotary_pos_emb=rotary_pos_emb,
                              max_seqlen=max_seqlen,
                              seqlens=seqlens)

        x = x + self.mlp(self.norm2(x))
        return x
|
||||
class Qwen3_VisionTransformer(nn.Module):
    """Qwen3-VL vision encoder with VACC fused fast paths.

    Rotary position embeddings and positional-embedding interpolation each try
    a fused custom op first and fall back to the pure-PyTorch reference
    implementation when the op fails.
    """

    def rot_pos_emb(self, grid_thw):
        """Build 2-D rotary position embeddings for all images/frames.

        grid_thw: per-image (t, h, w) patch grid, as tensor or list of lists.
        Returns a (num_patches, rot_dim) tensor (or whatever the fused op
        returns on the fast path — presumably a [sin, cos] list; verify
        against ``fuse_atten_vit`` usage).
        """
        if USE_FUSED_QWEN_ATTENTION:
            try:
                from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl
                return rot_pos_emb_qwenvl(grid_thw, self.hidden_size, self.num_heads, self.spatial_merge_size, self.dtype, self.device)
            except Exception as e:
                # Fall through to the reference implementation below.
                logger.error(f"rot_pos_emb fused ops run fail, e:{e}")

        pos_ids = []
        # Support both Tensor and list inputs for DP path
        if isinstance(grid_thw, list):
            grid_list = grid_thw
            max_grid_size = max(max(h, w) for _, h, w in grid_list)
        else:
            grid_list = grid_thw.tolist()
            max_grid_size = int(grid_thw[:, 1:].max().item())
        for t, h, w in grid_list:
            # Row indices, reordered into spatial-merge-window order.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column indices, same window reordering.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # Same (h, w) grid repeated for each temporal frame.
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def fast_pos_embed_interpolate(self,
                                   grid_thw: list[list[int]]) -> torch.Tensor:
        """Bilinearly interpolate learned position embeddings to each grid.

        Falls back to the vectorized PyTorch implementation when the fused
        custom op is unavailable or raises.
        """
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim

        try:
            from torch_vacc.vacc.custom_qwen3_ops import fast_pos_embed_interpolate_qwenvl
            return fast_pos_embed_interpolate_qwenvl(self.pos_embed.weight, grid_thw, num_grid_per_side, m_size, hidden_dim)
        except Exception as e:
            logger.error(f"fast_pos_embed_interpolate fused ops run fail, e:{e}")

        outputs = []
        for t, h, w in grid_thw:
            # Fractional sample positions in the learned grid for each axis.
            h_idxs = torch.linspace(0,
                                    num_grid_per_side - 1,
                                    h,
                                    dtype=torch.float32,
                                    device=self.device)
            w_idxs = torch.linspace(0,
                                    num_grid_per_side - 1,
                                    w,
                                    dtype=torch.float32,
                                    device=self.device)

            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            # Fractional parts = bilinear interpolation weights.
            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid view for all h, w vars
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor,
                                                        w_floor,
                                                        indexing='ij')
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil,
                                                      w_ceil,
                                                      indexing='ij')
            h_floor_grid_idx = h_floor_grid * num_grid_per_side
            h_ceil_grid_idx = h_ceil_grid * num_grid_per_side

            # original computation of weights
            # w00 = (1 - dh_grid) * (1 - dw_grid)
            # w01 = (1 - dh_grid) * dw_grid
            # w10 = dh_grid * (1 - dw_grid)
            # w11 = dh_grid * dw_grid
            # we reuse w11 here to avoid duplicate
            # dh_grid * dw_grid computation
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - dw_grid + w11

            # Flat indices of the four neighbouring grid points.
            idx00 = h_floor_grid_idx + w_floor_grid
            idx01 = h_floor_grid_idx + w_ceil_grid
            idx10 = h_ceil_grid_idx + w_floor_grid
            idx11 = h_ceil_grid_idx + w_ceil_grid

            indices = torch.stack([idx00, idx01, idx10, idx11],
                                  dim=0).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11],
                                  dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.dtype, device=self.device)

            embeds = self.pos_embed(indices)
            weighted_embeds = embeds * weights
            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
            combined = p0 + p1 + p2 + p3

            # Rearrange into spatial-merge-window order and repeat per frame.
            combined = combined.view(h * w, hidden_dim)
            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
            repeated = repeated.view(t, h // m_size, m_size, w // m_size,
                                     m_size, hidden_dim)
            repeated = repeated.permute(0, 1, 3, 2, 4,
                                        5).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode patches, returning merged features concatenated with the
        deepstack features taken from intermediate blocks."""
        hidden_states = x.to(device=self.device, dtype=self.dtype)
        hidden_states = self.patch_embed(hidden_states)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        grid_thw_tensor = torch.tensor(grid_thw,
                                       dtype=torch.int32)

        # Cumulative sequence lengths per image/frame for varlen attention.
        cu_seqlens = torch.repeat_interleave(
            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
            grid_thw_tensor[:, 0]).cumsum(
                dim=0,
                dtype=grid_thw_tensor.dtype
                if torch.jit.is_tracing() else torch.int32,
            )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        hidden_states = hidden_states.unsqueeze(1)

        if isinstance(rotary_pos_emb, torch.Tensor):
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)

        if USE_FUSED_QWEN_ATTENTION:
            # Fused kernel consumes cu_seqlens as a plain Python list.
            max_seqlen, seqlens = None, None
            cu_seqlens = cu_seqlens.tolist()
        else:
            max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(hidden_states,
                                cu_seqlens=cu_seqlens,
                                rotary_pos_emb=rotary_pos_emb,
                                max_seqlen=max_seqlen,
                                seqlens=seqlens)
            # Capture merged features at the configured deepstack depths.
            if layer_num in self.deepstack_visual_indexes:
                deepstack_merger_idx = self.deepstack_visual_indexes.index(
                    layer_num)
                deepstack_feature = self.deepstack_merger_list[
                    deepstack_merger_idx](hidden_states)
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        hidden_states = torch.cat(
            [hidden_states] + deepstack_feature_lists,
            dim=1)  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states
|
||||
|
||||
class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP):
    """Qwen3-VL multimodal wrapper (VACC-patched deepstack handling)."""

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed text tokens and splice in multimodal embeddings.

        When deepstack is enabled, the per-layer visual embeddings are
        computed and stashed via ``_set_deepstack_input_embeds`` for the
        language model to consume.
        """
        deepstack_input_embeds = None
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            if self.use_deepstack:
                deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds(  # noqa:E501
                    input_ids, inputs_embeds, multimodal_embeddings)
                self._set_deepstack_input_embeds(deepstack_input_embeds)
            # Overwrite image/video placeholder tokens with visual features.
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id])

        # commit here to remove deepstack_input_embeds copy
        # if self.use_deepstack:
        #     if deepstack_input_embeds is None:
        #         deepstack_input_embeds = torch.zeros_like(
        #             inputs_embeds).unsqueeze(0).repeat(
        #                 self.deepstack_num_level, 1, 1).contiguous()
        #     self._set_deepstack_input_embeds(deepstack_input_embeds)

        return inputs_embeds

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        """Intentionally a no-op: zeroing the deepstack buffer is skipped as a
        VACC optimization — the code below the early return is kept for
        reference only."""
        return  # patch here to optimize deepstack_input_embeds
        # clear deepstack_input_embeds in buffer
        if num_tokens > 0:
            for idx in range(self.deepstack_num_level):
                self.deepstack_input_embeds[idx][:num_tokens].zero_()
|
||||
class Qwen3VLProcessingInfo():
    """Processing-info shim that swaps in the VACC HF processor variants."""

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessorWithVacc:
        """Return the VACC-enabled HF processor (fast path on by default)."""
        processor = self.ctx.get_hf_processor(
            Qwen3VLProcessorWithVacc,
            # Pop so use_fast is not forwarded twice through **kwargs.
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )
        return processor

    def get_image_processor(self,
                            **kwargs: object) -> Qwen2VLImageProcessorFastWithVacc:
        """Convenience accessor for the processor's image sub-processor."""
        return self.get_hf_processor(**kwargs).image_processor

    # def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
    #     return self.get_hf_processor(**kwargs).video_processor
|
||||
|
||||
|
||||
class Qwen3_VisionPatchMerger():
    """Merges neighbouring vision patches through a two-layer MLP.

    Tries the fused ``patch_merger_vision`` VACC op first and falls back to
    the unfused tensor-parallel MLP on any failure.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, run fc1 -> GELU -> fc2, and all-reduce across TP ranks."""
        if self.use_postshuffle_norm:
            # Reshape first, then normalize over the merged hidden size.
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            x = self.norm(x).view(-1, self.hidden_size)

        try:
            from torch_vacc.vacc import patch_merger_vision
            tp_rank_id = get_tp_group().rank_in_group
            # Only rank 0 adds the fc2 bias so the all-reduce below does not
            # sum it world_size times.
            fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias

            hidden_states = patch_merger_vision(x,
                                                self.linear_fc1.weight, self.linear_fc2.weight,
                                                self.linear_fc1.bias, fc2_bias,
                                                0)  # 0 is gelu, 1 is silu
            return tensor_model_parallel_all_reduce(hidden_states)
        except Exception as e:
            # Fixed log typo: "cased by" -> "caused by".
            logger.error(f"merge patch fused vision mlp run fail, caused by:{e}")

        # Fallback: unfused parallel MLP (the parallel linears handle the
        # reduce themselves).
        x_parallel, _ = self.linear_fc1(x)
        x_parallel = self.act_fn(x_parallel)
        out, _ = self.linear_fc2(x_parallel)
        return out
|
||||
|
||||
class Qwen3_VisionMLP():
    """Vision-block MLP with a fused VACC fast path and unfused fallback."""

    def forward(self, x: torch.Tensor):
        """Apply fc1 -> GELU -> fc2 with a TP all-reduce on the fused path."""
        try:
            from torch_vacc.vacc import fuse_mlp_vision
            hiddens_shape = x.shape
            tp_rank_id = get_tp_group().rank_in_group
            # Only rank 0 adds the fc2 bias so the all-reduce does not sum it
            # world_size times.
            fc2_bias = None if tp_rank_id > 0 else self.linear_fc2.bias

            hidden_states = fuse_mlp_vision(x.view(-1, hiddens_shape[-1]),
                                            self.linear_fc1.weight, self.linear_fc2.weight,
                                            self.linear_fc1.bias, fc2_bias,
                                            0)  # 0 is gelu, 1 is silu
            return tensor_model_parallel_all_reduce(hidden_states).view(hiddens_shape)
        except Exception as e:
            # Fixed log typo: "cased by" -> "caused by".
            logger.error(f"qwen3vl fused vision mlp run fail, caused by:{e}")
        # Fallback: unfused path (parallel linears handle the reduce).
        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
|
||||
|
||||
27
vllm_vacc/vllm/model_executor/models/roberta.py
Normal file
27
vllm_vacc/vllm/model_executor/models/roberta.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.models.bert import _decode_token_type_ids
|
||||
|
||||
|
||||
class RobertaEmbedding(nn.Module):
    """RoBERTa embedding layer using a fused VACC embedding+LayerNorm op.

    Token-type ids are recovered from the (packed) input_ids by
    ``_decode_token_type_ids``; position/token-type embedding lookup and the
    LayerNorm are all folded into ``fuse_bge_embedding_stage1``.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:

        token_type_ids = _decode_token_type_ids(input_ids)

        inputs_embeds = self.word_embeddings(input_ids)
        # Reference (unfused) computation, replaced by the fused op below:
        # position_embeddings = self.position_embeddings(position_ids)
        # token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        # embeddings = self.LayerNorm(embeddings)
        embeddings = torch.vacc.fuse_bge_embedding_stage1(inputs_embeds, position_ids, self.position_embeddings.weight, token_type_ids, self.token_type_embeddings.weight, self.LayerNorm.weight, self.LayerNorm.bias, self.LayerNorm.eps)
        return embeddings
|
||||
58
vllm_vacc/vllm/model_executor/models/vars.py
Normal file
58
vllm_vacc/vllm/model_executor/models/vars.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Feature flags and tunables for the VACC model-executor backends.

Flags with int(os.getenv(...)) defaults are runtime-overridable via the
environment; truthy (non-zero) enables the feature.
"""
import os

# support Q,KV Gen with TP
USE_PARALLEL_Q_KV_GEN = True

# support Merge Q,KV Gen, Q,QR weights Merge
USE_MERGE_Q_KV_GEN_AND_Q_QR = True

# Support FP8 Weights for WQ,QR
W_Q_W_QR_WUV_WUK_USE_FP8 = True

# fused prefill
USE_FUSED_PREFILL = True

# fused prefill stage1
USE_FUSED_PREFILL_STAGE1 = True

# All Request Seq Lens (mutable module-level state, set via
# update_seqence_length below)
DO_SEQ_LENS = 0

def update_seqence_length(seq_num):
    """Record the current total request sequence length.

    NOTE: the misspelling ("seqence") is part of the public API and is kept
    for compatibility with existing callers.
    """
    global DO_SEQ_LENS
    DO_SEQ_LENS = seq_num

USE_DS3_SAMPLER = int(os.getenv("USE_DS3_SAMPLER", 1))
USE_DS3_SAMPLER_OP = int(os.getenv("USE_DS3_SAMPLER_OP", 1))

# cut prefill seq len (-1 disables cutting)
CUT_PREFILL_SEQ_LEN = int(os.getenv("CUT_PREFILL_SEQ_LEN", -1))

# llm max prefill seq len (tokens); sizes the huge-memory allocator
LLM_MAX_PREFILL_SEQ_LEN = int(os.getenv("LLM_MAX_PREFILL_SEQ_LEN", 56 * 1024))

# All Fused Decode, default is cpu loop
USE_DECODER_LAYER_FUSE_MODE = int(os.getenv("USE_DECODER_LAYER_FUSE_MODE", 1))

# Fused all layers, use cmcu loop
FUSE_ALL_DECODER_LAYERS = int(os.getenv("FUSE_ALL_DECODER_LAYERS", 1))

# where to use flash attention (default: 1)
USE_FLASH_ATTENTION = int(os.getenv("USE_FLASH_ATTENTION", 1))

# transpose gptq weight KN => NK
TRANSPOSE_GPTQ_WEIGHT = True

# qwen fused attention
USE_FUSED_QWEN_ATTENTION = int(os.getenv("USE_FUSED_QWEN_ATTENTION", 1))

# support MTP eh_proj with TP
USE_PARALLEL_MTP_EH_PROJ = int(os.getenv("USE_PARALLEL_MTP_EH_PROJ", 1))

# kv_cache group size
BLOCK_GROUP_SIZE = int(os.getenv("BLOCK_GROUP_SIZE", 8192))

# bert fused attention
USE_FUSED_BERT_ATTENTION = int(os.getenv("USE_FUSED_BERT_ATTENTION", 1))

# fused mlp vision
USE_FUSED_MLP_VISION = int(os.getenv("USE_FUSED_MLP_VISION", 1))
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,491 @@
|
||||
import torch
|
||||
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
|
||||
DeepseekV2MLAAttention,
|
||||
DeepseekV2MLP,
|
||||
DeepseekV2MoE)
|
||||
|
||||
from ..vars import *
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
OUTPUT_ARGS_LOGS = False
|
||||
|
||||
class DistributedArgs():
    """Holds tensor-parallel distribution parameters for the fused decode op.

    Attribute names carry ``_N_`` prefixes encoding the positional order the
    fused kernel expects.
    """

    def __init__(self):
        # Defaults: 32-way world, unset rank, group 0, no device infos yet.
        self._0_world_size = 32
        self._1_rank = -1
        self._2_group_id = 0
        self._3_dev_info = []

    def logs(self):
        """Print every field for debugging."""
        for label, value in (
            ("dist self._0_world_size = ", self._0_world_size),
            ("dist self._1_rank = ", self._1_rank),
            ("dist self._2_group_id = ", self._2_group_id),
            ("dist self._3_dev_info = ", self._3_dev_info),
        ):
            print(label, value)
|
||||
|
||||
class AttenArgs():
    """Per-layer attention (MLA) weight/metadata arrays for the fused decode op.

    Attribute names carry ``_N_`` prefixes encoding the positional order the
    fused kernel expects; lists are filled per layer by the weight capture.
    """

    def __init__(self):
        self._a_hidden_states_norm_weight = []
        self._0_merge_q_kv_weights = []  # merged Q,KV weights
        self._1_merge_q_kv_scale_inv = []  # merged Q,KV scales
        self._2_q_a_layernorm_weight = []
        self._3_W_Q = []
        self._4_W_Q_scales = []
        self._5_W_UK = []
        self._6_W_UK_scales = []
        self._7_W_QR = []
        self._8_W_QR_scales = []
        self._9_kv_a_layernorm_weight = []
        self._10_sin_cache = None
        self._11_cos_cache = None
        self._12_slot_mapping = None
        self._13_kv_cache = None
        self._14_block_tables = None
        # Defaults to the module-level BLOCK_GROUP_SIZE alias.
        self._15_env_blk_grp_size = env_blk_grp_size
        self._16_W_UV = []
        self._17_W_UV_scales = []
        self._18_o_proj_weight = []
        self._19_o_proj_weight_scale_inv = []
        # mla params
        self._20_seq_lens = []
        self._21_sm_scale = 0.0
        self._22_head_num = 128

    def logs(self):
        """Print the MLA runtime parameters for debugging."""
        print("mla _20_seq_lens block size is:", self._20_seq_lens)
        print("mla _21_sm_scale block size is:", self._21_sm_scale)
        print("mla _22_head_num block size is:", self._22_head_num)
|
||||
|
||||
|
||||
class MlpArgs():
    """Per-layer dense-MLP weight collection.

    Holds the post-attention RMSNorm weight, the fused gate/up (w13) and
    down (w2) projection weights, their quantization scales, and the
    dequantization block sizes.  List attributes accumulate one entry per
    captured dense layer; the block-size attributes are replaced wholesale
    with ``[size0, size1]`` pairs during capture.
    """

    def __init__(self):
        self._0_mlp_rms_weight = []
        self._1_mlp_w13 = []
        self._2_mlp_w2 = []
        self._3_mlp_w13_scale = []
        self._4_mlp_w2_scale = []
        self._5_mlp_w13_block_size = []
        self._6_mlp_w2_block_size = []

    def logs(self):
        """Print the dequant block sizes (debug helper)."""
        for tag, size in (("_5_mlp_w13_block_size", self._5_mlp_w13_block_size),
                          ("_6_mlp_w2_block_size", self._6_mlp_w2_block_size)):
            print("mlp " + tag + " block size is:", size)
||||
|
||||
class MoeArgs():
    """Per-layer MoE weight collection.

    Holds the post-attention RMSNorm weight, the shared-expert MLP
    weights/scales, the routed-expert weights/scales, the router gate
    weight and score-correction bias, plus the dequantization block sizes.
    List attributes accumulate one entry per captured MoE layer; the
    block-size attributes are replaced wholesale during capture.
    """

    def __init__(self):
        self._0_moe_rms_weight = []
        self._1_moe_share_mlp_w13 = []
        self._2_moe_share_mlp_w2 = []
        self._3_moe_share_mlp_w13_scale = []
        self._4_moe_share_mlp_w2_scale = []
        self._5_moe_w13 = []
        self._6_moe_w2 = []
        self._7_moe_w13_scale = []
        self._8_moe_w2_scale = []
        self._9_gate_weight = []
        self._10_moe_bias = []
        self._11_moe_mlp_w13_block_size = []
        self._12_moe_mlp_w2_block_size = []
        self._13_moe_w13_block_size = []
        self._14_moe_w2_block_size = []

    def logs(self):
        """Print the dequant block sizes (debug helper)."""
        for tag, size in (
            ("_11_moe_mlp_w13_block_size", self._11_moe_mlp_w13_block_size),
            ("_12_moe_mlp_w2_block_size", self._12_moe_mlp_w2_block_size),
            ("_13_moe_w13_block_size", self._13_moe_w13_block_size),
            ("_14_moe_w2_block_size", self._14_moe_w2_block_size),
        ):
            print("moe " + tag + " block size is:", size)
||||
|
||||
class WeightMapper():
    # Bundles one complete argument set (attention + dense MLP + MoE +
    # distributed info).  DeepseekWeightCapture keeps one mapper for the
    # dense layers and one for the MoE layers.
    def __init__(self):
        self.attn_args = AttenArgs()
        self.mlp_args = MlpArgs()  # used for the first 3 (mla+mlp dense) layers
        self.moe_args = MoeArgs()  # used for the remaining (mla+moe) layers
        self.dist_args = DistributedArgs()
||||
|
||||
# 1. weight loading
# 2. dequant block-size pre-computation
# 3. parameter caching & extraction
class DeepseekWeightCapture():
    # Captures (references to) the weights of decoder layers [start, end) of
    # a DeepSeek model so they can be handed to the fused kernels in one
    # batch.  Layers with index < mlp_nums are dense (MLA attention + MLP)
    # and are collected into self.layer_mlp; the remaining layers are MoE
    # and are collected into self.layer_moe.
    def __init__(self, layer: torch.nn.ModuleList,
                 start: int,
                 end: int):

        self.layer_mlp = WeightMapper()
        self.layer_moe = WeightMapper()

        # Full rotary tables; per-step slices are taken in update_attn_args().
        self.sin_cache_all = None
        self.cos_cache_all = None

        # First 3 decoder layers are treated as dense, the rest as MoE
        # (matches DeepSeek V2/V3 layouts -- confirm for other variants).
        self.mlp_nums = 3
        self.moe_nums = end - self.mlp_nums

        self.start_idx = start
        self.end_idx = end
        for i in range(start, end):
            if i < self.mlp_nums:
                self.capture_deepseek_mla_attn_weights(layer[i], self.layer_mlp.attn_args)
                self.capture_deepseek_mlp_weights(layer[i])
            else:
                self.capture_deepseek_mla_attn_weights(layer[i], self.layer_moe.attn_args)
                self.capture_deepseek_moe_weights(layer[i])

        if OUTPUT_ARGS_LOGS:
            self.layer_mlp.attn_args.logs()
            self.layer_mlp.mlp_args.logs()

            self.layer_moe.attn_args.logs()
            self.layer_moe.moe_args.logs()

        # Record the tensor-parallel topology on both mappers.
        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_mlp.dist_args._0_world_size = tp_group.world_size
        self.layer_mlp.dist_args._1_rank = tp_group.rank_in_group
        self.layer_mlp.dist_args._2_group_id = tp_group.group_id
        self.layer_mlp.dist_args._3_dev_info = tp_group.rank_device_infos

        self.layer_moe.dist_args._0_world_size = tp_group.world_size
        self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
        self.layer_moe.dist_args._2_group_id = tp_group.group_id
        self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_deepseek_mlp_weights(self, module: DeepseekV2DecoderLayer):
        # Collect one dense layer's MLP weights/scales into layer_mlp.mlp_args.
        assert isinstance(module.mlp, DeepseekV2MLP)

        mlp = module.mlp
        rms_norm = module.post_attention_layernorm

        w13_weight = mlp.gate_up_proj.weight
        w2_weight = mlp.down_proj.weight

        w13_wscale = mlp.gate_up_proj.weight_scale_inv
        w13_iscale = mlp.gate_up_proj.input_scale  # NOTE(review): unused below

        w2_wscale = mlp.down_proj.weight_scale_inv
        w2_iscale = mlp.down_proj.input_scale  # NOTE(review): unused below

        # Convert the quant-config block sizes from elements to scale-grid
        # units (quant_method.scale_n/scale_k elements per scale entry).
        w13_block_size0, w13_block_size1 = mlp.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.gate_up_proj.quant_method.scale_n, mlp.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = mlp.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.down_proj.quant_method.scale_n, mlp.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k

        self.layer_mlp.mlp_args._0_mlp_rms_weight.append(rms_norm.weight)
        self.layer_mlp.mlp_args._1_mlp_w13.append(w13_weight)
        self.layer_mlp.mlp_args._2_mlp_w2.append(w2_weight)
        self.layer_mlp.mlp_args._3_mlp_w13_scale.append(w13_wscale)
        self.layer_mlp.mlp_args._4_mlp_w2_scale.append(w2_wscale)
        # Plain assignment (not append): assumes block sizes are identical
        # across captured layers -- TODO confirm.
        self.layer_mlp.mlp_args._5_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_mlp.mlp_args._6_mlp_w2_block_size = [w2_block_size0, w2_block_size1]

    def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
        # Collect one MoE layer's weights/scales into layer_moe.moe_args.
        assert isinstance(module.mlp, DeepseekV2MoE)

        share_expert_layer = module.mlp.shared_experts
        experts_layer = module.mlp.experts
        rms_norm = module.post_attention_layernorm
        gate = module.mlp.gate

        w13_weight = share_expert_layer.gate_up_proj.weight
        w2_weight = share_expert_layer.down_proj.weight

        w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
        # w13_iscale = share_expert_layer.gate_up_proj.input_scale

        w2_wscale = share_expert_layer.down_proj.weight_scale_inv
        # w2_iscale = share_expert_layer.down_proj.input_scale

        # Shared-expert block sizes: elements -> scale-grid units.
        w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k

        # Routed-expert block size derived from the weight shape vs. the
        # scale tensor shape (elements per scale entry along each axis).
        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )

        self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
        self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
        self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
        self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
        self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
        self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
        self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
        self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
        self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
        self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
        self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
        # Plain assignment: assumes block sizes identical across layers -- TODO confirm.
        self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
        self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
        self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]

    def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
                                          weight_mapper: AttenArgs):
        # Collect one layer's MLA attention weights into *weight_mapper*
        # (shared by dense and MoE layers).
        if(self.sin_cache_all is None):
            # Rotary caches are shared by all layers; grab them once.
            self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
            self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache

        weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)

        fused_params = {}
        # With the merged Q/KV-gen path the q_a / kv_a projections are folded
        # into merge_q_kv_weights, so they are only collected when unmerged.
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.q_a_proj.named_parameters():
                fused_params['q_a_proj_' + name] = param

        for name, param in module.self_attn.q_a_layernorm.named_parameters():
            fused_params['q_a_layernorm_' + name] = param

        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
                fused_params['kv_a_proj_' + name] = param

        for name, param in module.self_attn.kv_a_layernorm.named_parameters():
            fused_params['kv_a_layernorm_' + name] = param

        for name, param in module.self_attn.o_proj.named_parameters():
            fused_params['o_proj_' + name] = param

        import os  # NOTE(review): unused import
        # NOTE(review): this sets the attribute on the capture object, not on
        # weight_mapper; AttenArgs.__init__ already initializes
        # _15_env_blk_grp_size, so this line looks unintended -- confirm.
        self._15_env_blk_grp_size = env_blk_grp_size
        # init sin,cos cache

        # Pre-split/dequant-prepped MLA weights extracted by the backend impl.
        mla_params = module.self_attn.mla_attn.impl.extract_weights()
        fused_params = {**fused_params, **mla_params}

        weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
        weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
        weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
        weight_mapper._3_W_Q.append(fused_params['W_Q'])
        weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
        weight_mapper._5_W_UK.append(fused_params['W_UK'])
        weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
        weight_mapper._7_W_QR.append(fused_params['W_QR'])
        weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
        weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
        # Slots 10-15 (sin/cos caches, slot_mapping, kv_cache, block_tables,
        # block group size) are runtime state filled in by update_attn_args()
        # or already initialized by AttenArgs.__init__.
        weight_mapper._16_W_UV.append(fused_params['W_UV'])
        weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
        weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
        weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
        weight_mapper._20_seq_lens = None
        weight_mapper._21_sm_scale = module.self_attn.scaling
        # Heads handled by this TP rank.
        weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size

    # Possible optimization: pass only Tensors to the C++ side.
    def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
        # Refresh per-step runtime state (rope slices, kv caches, paging
        # tables, sequence lengths) on both argument sets.
        # Last-token position of each sequence (assumes decode step of one
        # new token per sequence -- TODO confirm).
        positions = [i - 1 for i in seq_lens]
        cos_cache = [self.cos_cache_all[i] for i in positions]
        sin_cache = [self.sin_cache_all[i] for i in positions]

        self.layer_mlp.attn_args._10_sin_cache = sin_cache
        self.layer_mlp.attn_args._11_cos_cache = cos_cache

        self.layer_moe.attn_args._10_sin_cache = sin_cache
        self.layer_moe.attn_args._11_cos_cache = cos_cache

        self.layer_mlp.attn_args._20_seq_lens = seq_lens
        self.layer_moe.attn_args._20_seq_lens = seq_lens

        # Dense and MoE layers use separate kv-cache stacks.
        self.layer_mlp.attn_args._13_kv_cache = kv_caches_dense_layer
        self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer

        self.layer_mlp.attn_args._12_slot_mapping = slot_mapping
        self.layer_mlp.attn_args._14_block_tables = block_tables

        self.layer_moe.attn_args._12_slot_mapping = slot_mapping
        self.layer_moe.attn_args._14_block_tables = block_tables

    def logs(self):
        # Debug helper: dump the dense-layer attention and distributed args.
        print("current layer mlp attn: \n")
        self.layer_mlp.attn_args.logs()
        self.layer_mlp.dist_args.logs()
|
||||
|
||||
class DeepseekMTPWegitCapture():
    """Weight capture for the DeepSeek MTP (multi-token prediction) head.

    Unlike DeepseekWeightCapture, MTP wraps exactly one DeepseekDecoderLayer,
    and that layer is always a MoE layer -- so only ``layer_moe`` exists here
    (there is no ``layer_mlp``).
    """

    def __init__(self, layer: torch.nn.Module):
        self.layer_moe = WeightMapper()

        # Full rotary tables; per-step slices are taken in update_attn_args().
        self.sin_cache_all = None
        self.cos_cache_all = None

        self.capture_deepseek_mla_attn_weights(layer, self.layer_moe.attn_args)
        self.capture_deepseek_moe_weights(layer)

        if OUTPUT_ARGS_LOGS:
            self.layer_moe.attn_args.logs()
            self.layer_moe.moe_args.logs()

        # Record the tensor-parallel topology.
        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()

        self.layer_moe.dist_args._0_world_size = tp_group.world_size
        self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
        self.layer_moe.dist_args._2_group_id = tp_group.group_id
        self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
        # Collect the MoE layer's weights/scales into layer_moe.moe_args.
        assert isinstance(module.mlp, DeepseekV2MoE)

        share_expert_layer = module.mlp.shared_experts
        experts_layer = module.mlp.experts
        rms_norm = module.post_attention_layernorm
        gate = module.mlp.gate

        w13_weight = share_expert_layer.gate_up_proj.weight
        w2_weight = share_expert_layer.down_proj.weight

        w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
        w2_wscale = share_expert_layer.down_proj.weight_scale_inv

        # Convert the quant-config block sizes from elements to scale-grid
        # units (quant_method.scale_n/scale_k elements per scale entry).
        w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k

        # Routed-expert block size derived from the weight shape vs. the
        # scale tensor shape (elements per scale entry along each axis).
        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )

        moe_args = self.layer_moe.moe_args
        moe_args._0_moe_rms_weight.append(rms_norm.weight)
        moe_args._1_moe_share_mlp_w13.append(w13_weight)
        moe_args._2_moe_share_mlp_w2.append(w2_weight)
        moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
        moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
        moe_args._5_moe_w13.append(experts_layer.w13_weight)
        moe_args._6_moe_w2.append(experts_layer.w2_weight)
        moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
        moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
        moe_args._9_gate_weight.append(gate.weight)
        moe_args._10_moe_bias.append(gate.e_score_correction_bias)
        moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
        moe_args._13_moe_w13_block_size = [block_size0, block_size1]
        moe_args._14_moe_w2_block_size = [block_size0, block_size1]

    def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
                                          weight_mapper: AttenArgs):
        # Collect the layer's MLA attention weights into *weight_mapper*.
        if(self.sin_cache_all is None):
            # Rotary caches are shared; grab them once.
            self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
            self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache

        weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)

        fused_params = {}
        # With the merged Q/KV-gen path the q_a / kv_a projections are folded
        # into merge_q_kv_weights, so they are only collected when unmerged.
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.q_a_proj.named_parameters():
                fused_params['q_a_proj_' + name] = param

        for name, param in module.self_attn.q_a_layernorm.named_parameters():
            fused_params['q_a_layernorm_' + name] = param

        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
                fused_params['kv_a_proj_' + name] = param

        for name, param in module.self_attn.kv_a_layernorm.named_parameters():
            fused_params['kv_a_layernorm_' + name] = param

        for name, param in module.self_attn.o_proj.named_parameters():
            fused_params['o_proj_' + name] = param

        # NOTE(review): kept for compatibility, but this sets the attribute on
        # the capture object, not on weight_mapper; AttenArgs.__init__ already
        # initializes _15_env_blk_grp_size, so this line is likely redundant.
        self._15_env_blk_grp_size = env_blk_grp_size

        # Pre-split/dequant-prepped MLA weights extracted by the backend impl.
        mla_params = module.self_attn.mla_attn.impl.extract_weights()
        fused_params = {**fused_params, **mla_params}

        weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
        weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
        weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
        weight_mapper._3_W_Q.append(fused_params['W_Q'])
        weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
        weight_mapper._5_W_UK.append(fused_params['W_UK'])
        weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
        weight_mapper._7_W_QR.append(fused_params['W_QR'])
        weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
        weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
        # Slots 10-15 (sin/cos caches, slot_mapping, kv_cache, block_tables,
        # block group size) are runtime state filled in by update_attn_args()
        # or already initialized by AttenArgs.__init__.
        weight_mapper._16_W_UV.append(fused_params['W_UV'])
        weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
        weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
        weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
        weight_mapper._20_seq_lens = None
        weight_mapper._21_sm_scale = module.self_attn.scaling
        # Heads handled by this TP rank.
        weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size

    # Possible optimization: pass only Tensors to the C++ side.
    def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
        # Refresh per-step runtime state.  kv_caches_dense_layer is accepted
        # for signature parity with DeepseekWeightCapture but unused: MTP
        # only has a MoE layer.
        positions = [i - 1 for i in seq_lens]  # last-token position per sequence
        cos_cache = [self.cos_cache_all[i] for i in positions]
        sin_cache = [self.sin_cache_all[i] for i in positions]

        self.layer_moe.attn_args._10_sin_cache = sin_cache
        self.layer_moe.attn_args._11_cos_cache = cos_cache

        self.layer_moe.attn_args._20_seq_lens = seq_lens

        self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer

        self.layer_moe.attn_args._12_slot_mapping = slot_mapping
        self.layer_moe.attn_args._14_block_tables = block_tables

    def logs(self):
        # Debug helper: dump the captured attention and distributed args.
        # Fix: the original printed self.layer_mlp, which this class never
        # creates (MTP only has layer_moe) and would raise AttributeError.
        print("current layer moe attn: \n")
        self.layer_moe.attn_args.logs()
        self.layer_moe.dist_args.logs()
|
||||
@@ -0,0 +1,154 @@
|
||||
import torch
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
|
||||
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer,
|
||||
Qwen3MoeMLP)
|
||||
|
||||
class Qwen3Moe_DistributedArgs():
    """Tensor-parallel topology info handed to the Qwen3-MoE fused kernels."""

    def __init__(self):
        # Placeholder defaults; overwritten from the live TP group.
        self._0_world_size = 32
        self._1_rank = -1
        self._2_group_id = 0
        self._3_dev_info = []

    def __repr__(self):
        return (
            f"[dist] world_size = {self._0_world_size} \n"
            f"[dist] rank = {self._1_rank} \n"
            f"[dist] group_id = {self._2_group_id} \n"
            f"[dist] dev_info = {self._3_dev_info}"
        )
||||
|
||||
class Qwen3Moe_AttenArgs():
    # Per-layer Qwen3-MoE attention weight collection.  The numbered
    # attribute names mirror the positional argument order of the fused
    # attention kernel (do not rename).  List attributes accumulate one
    # entry per captured layer; the None attributes are runtime state
    # supplied later (rope caches, paging tables, sequence lengths).
    def __init__(self):
        self._0_input_layernorm_weight = []
        self._1_qkv_proj_weight = []
        self._2_qkv_proj_weight_scale = []
        self._3_qkv_proj_bias = []
        self._4_qkv_proj_qzeros = []  # quant zero points (WNA16/GPTQ path) -- confirm
        self._5_q_norm_weight = []
        self._6_k_norm_weight = []
        # Runtime state, filled per decode step:
        self._7_sin_cache = None
        self._8_cos_cache = None
        self._9_slot_mapping = None
        self._10_kv_cache = None
        self._11_block_tables = None
        self._12_block_group_size = None
        self._13_o_proj_weight = []
        self._14_o_proj_weight_scale = []
        self._15_o_proj_bias = []
        self._16_o_proj_qzeros = []
        self._17_seq_lens = None
        self._18_sm_scale =None  # softmax scale, set from self_attn.scaling during capture
        self._19_num_attention_heads = None
        self._20_num_key_value_heads = None

    def __repr__(self):
        # Deliberately brief: printing the weight tensors would flood logs.
        attn_infos = "[qwen attn] 21 args" \
                     + f"[qwen attn] weight counts: {len(self._0_input_layernorm_weight)}"
        return attn_infos
||||
|
||||
class Qwen3Moe_MoeArgs():
    """Per-layer Qwen3-MoE expert weight collection.

    Holds the post-attention RMSNorm weight, the expert w13/w2 weights and
    their scales, and the router gate weight.  List attributes accumulate
    one entry per captured layer; the block-size attributes are assigned
    wholesale during capture.
    """

    def __init__(self):
        self._0_rms_norm_weight = []
        self._1_w13_weight = []
        self._2_w2_weight = []
        self._3_w13_weight_scale_inv = []
        self._4_w2_weight_scale_inv = []
        self._5_gate_weight = []
        self._6_w13_block_size = None
        self._7_w2_block_size = None

    def __repr__(self):
        return (
            f"[moe] w13_block_size: {self._6_w13_block_size}"
            f"[moe] w2_block_size: {self._7_w2_block_size}"
            f"[moe] weight counts: {len(self._1_w13_weight)}"
        )
||||
|
||||
class Qwen3Moe_WeightMapper:
    # Bundles the attention, MoE, and distributed argument sets for a
    # captured range of Qwen3-MoE layers.
    def __init__(self):
        self.attn_args = Qwen3Moe_AttenArgs()
        self.moe_args = Qwen3Moe_MoeArgs()
        self.dist_args = Qwen3Moe_DistributedArgs()
||||
|
||||
class Qwen3Moe_WeightCapture():
    # Captures (references to) the weights of layers[start:end) of a
    # Qwen3-MoE model so they can be handed to the fused kernels in one
    # batch.
    def __init__(self, layers: torch.nn.ModuleList,
                 start: int,
                 end: int):
        self.layer_mapper = Qwen3Moe_WeightMapper()
        # qwen3 only support fp8 now
        self.support_fused_weights = False
        for i in range(start, end):
            layer = layers[i]
            self.capture_attn_weights(layer)
            self.capture_moe_weights(layer)

        # Register the multi-device (tensor-parallel) environment info.
        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_mapper.dist_args._0_world_size = tp_group.world_size
        self.layer_mapper.dist_args._1_rank = tp_group.rank_in_group
        self.layer_mapper.dist_args._2_group_id = tp_group.group_id
        self.layer_mapper.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_attn_weights(self, layer):
        # Collect one layer's attention weights into layer_mapper.attn_args.
        from vllm_vacc.vllm.model_executor.models.qwen3_moe import set_fused_params
        # Register the fused-operator parameters.
        fused_params = {}
        fused_params['input_layernorm_weight'] = layer.input_layernorm.weight
        fused_params['q_norm_weight'] = layer.self_attn.q_norm.weight
        fused_params['k_norm_weight'] = layer.self_attn.k_norm.weight
        set_fused_params(fused_params, layer.self_attn.qkv_proj.quant_method, layer.self_attn.qkv_proj, 'qkv_proj')
        set_fused_params(fused_params, layer.self_attn.o_proj.quant_method, layer.self_attn.o_proj, 'o_proj')

        # NOTE(review): overwritten on every captured layer -- the last layer
        # decides the flag for the whole range; confirm that is intended.
        self.support_fused_weights = hasattr(layer.mlp.experts.quant_method, 'quant_config') and hasattr(layer.mlp.experts.quant_method.quant_config, 'weight_block_size')
        if not hasattr(layer.self_attn, "fused_params"):
            layer.self_attn.fused_params = fused_params

        self.layer_mapper.attn_args._0_input_layernorm_weight.append(fused_params['input_layernorm_weight'])
        self.layer_mapper.attn_args._1_qkv_proj_weight.append(fused_params['qkv_proj_weight'])
        self.layer_mapper.attn_args._2_qkv_proj_weight_scale.append(fused_params['qkv_proj_weight_scale'])
        self.layer_mapper.attn_args._3_qkv_proj_bias.append(fused_params['qkv_proj_bias'])
        self.layer_mapper.attn_args._4_qkv_proj_qzeros.append(fused_params['qkv_proj_qzeros'])
        self.layer_mapper.attn_args._5_q_norm_weight.append(fused_params['q_norm_weight'])
        self.layer_mapper.attn_args._6_k_norm_weight.append(fused_params['k_norm_weight'])
        # Slots 7-12 and 17 (rope caches, slot mapping, kv cache, block
        # tables, block group size, seq lens) are runtime state filled later.
        self.layer_mapper.attn_args._13_o_proj_weight.append(fused_params['o_proj_weight'])
        self.layer_mapper.attn_args._14_o_proj_weight_scale.append(fused_params['o_proj_weight_scale'])
        self.layer_mapper.attn_args._15_o_proj_bias.append(fused_params['o_proj_bias'])
        self.layer_mapper.attn_args._16_o_proj_qzeros.append(fused_params['o_proj_qzeros'])
        self.layer_mapper.attn_args._18_sm_scale = layer.self_attn.scaling
        self.layer_mapper.attn_args._19_num_attention_heads = layer.self_attn.total_num_heads
        self.layer_mapper.attn_args._20_num_key_value_heads = layer.self_attn.total_num_kv_heads

    def capture_moe_weights(self, layer: Qwen3MoeDecoderLayer):
        # Collect one layer's expert weights into layer_mapper.moe_args.
        from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
        quant_method = layer.mlp.experts.quant_method if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock) \
                       else layer.mlp.down_proj.quant_method

        # NOTE(review): both branches below read layer.mlp.experts, which a
        # dense (non-sparse) layer.mlp would not have -- confirm dense MLP
        # layers never reach this method.
        if not isinstance(quant_method, MoeWNA16Method):
            # fp8 block-quant path: recompute dequant block sizes first.
            from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import recompute_moe_layer_blocksize
            recompute_moe_layer_blocksize(layer.mlp.experts)
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_weight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_weight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_weight_scale_inv)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_weight_scale_inv)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
        else:
            # WNA16/GPTQ path: qweights and scales mapped into the same slots.
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_qweight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_qweight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_scales)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_scales)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
||||
|
||||
Reference in New Issue
Block a user