init
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,491 @@
import torch

from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
                                                     DeepseekV2MLAAttention,
                                                     DeepseekV2MLP,
                                                     DeepseekV2MoE)

from ..vars import *
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size

OUTPUT_ARGS_LOGS = False

class DistributedArgs():

    def __init__(self):
        self._0_world_size = 32
        self._1_rank = -1
        self._2_group_id = 0
        self._3_dev_info = []

    def logs(self):
        print("dist self._0_world_size = ", self._0_world_size)
        print("dist self._1_rank = ", self._1_rank)
        print("dist self._2_group_id = ", self._2_group_id)
        print("dist self._3_dev_info = ", self._3_dev_info)

class AttenArgs():

    def __init__(self):
        self._a_hidden_states_norm_weight = []
        self._0_merge_q_kv_weights = []      # fused Q and KV weights
        self._1_merge_q_kv_scale_inv = []    # fused Q and KV scales
        self._2_q_a_layernorm_weight = []
        self._3_W_Q = []
        self._4_W_Q_scales = []
        self._5_W_UK = []
        self._6_W_UK_scales = []
        self._7_W_QR = []
        self._8_W_QR_scales = []
        self._9_kv_a_layernorm_weight = []
        self._10_sin_cache = None
        self._11_cos_cache = None
        self._12_slot_mapping = None
        self._13_kv_cache = None
        self._14_block_tables = None
        self._15_env_blk_grp_size = env_blk_grp_size
        self._16_W_UV = []
        self._17_W_UV_scales = []
        self._18_o_proj_weight = []
        self._19_o_proj_weight_scale_inv = []
        # mla params
        self._20_seq_lens = []
        self._21_sm_scale = 0.0
        self._22_head_num = 128

    def logs(self):
        print("mla _20_seq_lens is:", self._20_seq_lens)
        print("mla _21_sm_scale is:", self._21_sm_scale)
        print("mla _22_head_num is:", self._22_head_num)


class MlpArgs():

    def __init__(self):
        # mlp params
        self._0_mlp_rms_weight = []
        self._1_mlp_w13 = []
        self._2_mlp_w2 = []
        self._3_mlp_w13_scale = []
        self._4_mlp_w2_scale = []
        self._5_mlp_w13_block_size = []
        self._6_mlp_w2_block_size = []

    def logs(self):
        print("mlp _5_mlp_w13_block_size block size is:", self._5_mlp_w13_block_size)
        print("mlp _6_mlp_w2_block_size block size is:", self._6_mlp_w2_block_size)

class MoeArgs():

    def __init__(self):
        # moe params
        self._0_moe_rms_weight = []
        self._1_moe_share_mlp_w13 = []
        self._2_moe_share_mlp_w2 = []
        self._3_moe_share_mlp_w13_scale = []
        self._4_moe_share_mlp_w2_scale = []
        self._5_moe_w13 = []
        self._6_moe_w2 = []
        self._7_moe_w13_scale = []
        self._8_moe_w2_scale = []
        self._9_gate_weight = []
        self._10_moe_bias = []
        self._11_moe_mlp_w13_block_size = []
        self._12_moe_mlp_w2_block_size = []
        self._13_moe_w13_block_size = []
        self._14_moe_w2_block_size = []

    def logs(self):
        print("moe _11_moe_mlp_w13_block_size block size is:", self._11_moe_mlp_w13_block_size)
        print("moe _12_moe_mlp_w2_block_size block size is:", self._12_moe_mlp_w2_block_size)
        print("moe _13_moe_w13_block_size block size is:", self._13_moe_w13_block_size)
        print("moe _14_moe_w2_block_size block size is:", self._14_moe_w2_block_size)

class WeightMapper():

    def __init__(self):
        self.attn_args = AttenArgs()
        self.mlp_args = MlpArgs()    # mla + mlp (the 3 dense layers)
        self.moe_args = MoeArgs()    # mla + moe (the 58 MoE layers)
        self.dist_args = DistributedArgs()


# 1. load the weights
# 2. precompute the dequant blocks
# 3. cache & extract the parameters
class DeepseekWeightCapture():

    def __init__(self, layer: torch.nn.ModuleList,
                 start: int,
                 end: int):

        self.layer_mlp = WeightMapper()
        self.layer_moe = WeightMapper()

        self.sin_cache_all = None
        self.cos_cache_all = None

        self.mlp_nums = 3
        self.moe_nums = end - self.mlp_nums

        self.start_idx = start
        self.end_idx = end
        for i in range(start, end):
            if i < self.mlp_nums:
                self.capture_deepseek_mla_attn_weights(layer[i], self.layer_mlp.attn_args)
                self.capture_deepseek_mlp_weights(layer[i])
            else:
                self.capture_deepseek_mla_attn_weights(layer[i], self.layer_moe.attn_args)
                self.capture_deepseek_moe_weights(layer[i])

        if OUTPUT_ARGS_LOGS:
            self.layer_mlp.attn_args.logs()
            self.layer_mlp.mlp_args.logs()

            self.layer_moe.attn_args.logs()
            self.layer_moe.moe_args.logs()

        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_mlp.dist_args._0_world_size = tp_group.world_size
        self.layer_mlp.dist_args._1_rank = tp_group.rank_in_group
        self.layer_mlp.dist_args._2_group_id = tp_group.group_id
        self.layer_mlp.dist_args._3_dev_info = tp_group.rank_device_infos

        self.layer_moe.dist_args._0_world_size = tp_group.world_size
        self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
        self.layer_moe.dist_args._2_group_id = tp_group.group_id
        self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_deepseek_mlp_weights(self, module: DeepseekV2DecoderLayer):
        assert isinstance(module.mlp, DeepseekV2MLP)

        mlp = module.mlp
        rms_norm = module.post_attention_layernorm

        w13_weight = mlp.gate_up_proj.weight
        w2_weight = mlp.down_proj.weight

        w13_wscale = mlp.gate_up_proj.weight_scale_inv
        w13_iscale = mlp.gate_up_proj.input_scale

        w2_wscale = mlp.down_proj.weight_scale_inv
        w2_iscale = mlp.down_proj.input_scale

        w13_block_size0, w13_block_size1 = mlp.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.gate_up_proj.quant_method.scale_n, mlp.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = mlp.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.down_proj.quant_method.scale_n, mlp.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k
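        # Illustrative example (values assumed, not taken from this commit): if
        # quant_config.weight_block_size is (128, 128) and the quant method
        # reports scale_n = scale_k = 1, the block size recorded below stays
        # (128, 128); scale_n / scale_k > 1 would shrink it proportionally.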

        self.layer_mlp.mlp_args._0_mlp_rms_weight.append(rms_norm.weight)
        self.layer_mlp.mlp_args._1_mlp_w13.append(w13_weight)
        self.layer_mlp.mlp_args._2_mlp_w2.append(w2_weight)
        self.layer_mlp.mlp_args._3_mlp_w13_scale.append(w13_wscale)
        self.layer_mlp.mlp_args._4_mlp_w2_scale.append(w2_wscale)
        self.layer_mlp.mlp_args._5_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_mlp.mlp_args._6_mlp_w2_block_size = [w2_block_size0, w2_block_size1]

    def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
        assert isinstance(module.mlp, DeepseekV2MoE)

        share_expert_layer = module.mlp.shared_experts
        experts_layer = module.mlp.experts
        rms_norm = module.post_attention_layernorm
        gate = module.mlp.gate

        w13_weight = share_expert_layer.gate_up_proj.weight
        w2_weight = share_expert_layer.down_proj.weight

        w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
        # w13_iscale = share_expert_layer.gate_up_proj.input_scale

        w2_wscale = share_expert_layer.down_proj.weight_scale_inv
        # w2_iscale = share_expert_layer.down_proj.input_scale

        w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k

        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )
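        # Illustrative example (shapes assumed, not taken from this commit): if
        # experts_layer.w13_weight has shape [E, D1, D2] and its
        # weight_scale_inv has shape [E, B1, B2], the routed-expert block size
        # derived above is simply (D1 // B1, D2 // B2), e.g. (128, 128).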

        self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
        self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
        self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
        self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
        self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
        self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
        self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
        self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
        self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
        self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
        self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
        self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
        self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
        self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]

    def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
                                          weight_mapper: AttenArgs):
        if self.sin_cache_all is None:
            self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
            self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache

        weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)

        fused_params = {}
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.q_a_proj.named_parameters():
                fused_params['q_a_proj_' + name] = param

        for name, param in module.self_attn.q_a_layernorm.named_parameters():
            fused_params['q_a_layernorm_' + name] = param

        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
                fused_params['kv_a_proj_' + name] = param

        for name, param in module.self_attn.kv_a_layernorm.named_parameters():
            fused_params['kv_a_layernorm_' + name] = param

        for name, param in module.self_attn.o_proj.named_parameters():
            fused_params['o_proj_' + name] = param

        weight_mapper._15_env_blk_grp_size = env_blk_grp_size
        # the sin/cos caches are filled in later by update_attn_args

        mla_params = module.self_attn.mla_attn.impl.extract_weights()
        fused_params = {**fused_params, **mla_params}

        weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
        weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
        weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
        weight_mapper._3_W_Q.append(fused_params['W_Q'])
        weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
        weight_mapper._5_W_UK.append(fused_params['W_UK'])
        weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
        weight_mapper._7_W_QR.append(fused_params['W_QR'])
        weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
        weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
        # weight_mapper._10_sin_cache.append(None)
        # weight_mapper._11_cos_cache.append(None)
        # weight_mapper._12_slot_mapping.append(None)
        # weight_mapper._13_kv_cache.append(None)
        # weight_mapper._14_block_tables.append(None)
        # weight_mapper._15_env_blk_grp_size.append(None)
        weight_mapper._16_W_UV.append(fused_params['W_UV'])
        weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
        weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
        weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
        weight_mapper._20_seq_lens = None
        weight_mapper._21_sm_scale = module.self_attn.scaling
        weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size
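        # Example (values assumed): with num_heads = 128 and o_proj.tp_size = 16,
        # each tensor-parallel rank keeps 128 // 16 = 8 attention heads.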

    # Could be optimized: on the C++ side only Tensors are needed.
    def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
        positions = [i - 1 for i in seq_lens]
        cos_cache = [self.cos_cache_all[i] for i in positions]
        sin_cache = [self.sin_cache_all[i] for i in positions]
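        # Example (values assumed): seq_lens = [5, 9] gives positions = [4, 8],
        # i.e. the rotary sin/cos entries of the last (decode) token of each
        # sequence are gathered for this step.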

        self.layer_mlp.attn_args._10_sin_cache = sin_cache
        self.layer_mlp.attn_args._11_cos_cache = cos_cache

        self.layer_moe.attn_args._10_sin_cache = sin_cache
        self.layer_moe.attn_args._11_cos_cache = cos_cache

        self.layer_mlp.attn_args._20_seq_lens = seq_lens
        self.layer_moe.attn_args._20_seq_lens = seq_lens

        self.layer_mlp.attn_args._13_kv_cache = kv_caches_dense_layer
        self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer

        self.layer_mlp.attn_args._12_slot_mapping = slot_mapping
        self.layer_mlp.attn_args._14_block_tables = block_tables

        self.layer_moe.attn_args._12_slot_mapping = slot_mapping
        self.layer_moe.attn_args._14_block_tables = block_tables

        # for i in range(self.mlp_nums):
        #     if i < self.end_idx:
        #         self.layer_mlp.attn_args._12_slot_mapping[i] = slot_mapping
        #         self.layer_mlp.attn_args._14_block_tables[i] = block_tables

        # for i in range(self.moe_nums):
        #     if i < self.end_idx:
        #         self.layer_moe.attn_args._12_slot_mapping[i] = slot_mapping
        #         self.layer_moe.attn_args._14_block_tables[i] = block_tables

    def logs(self):
        print("current layer mlp attn:")
        self.layer_mlp.attn_args.logs()
        self.layer_mlp.dist_args.logs()
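

# Hypothetical usage sketch (call sites assumed, not part of this commit): the
# capture object is built once from the model's decoder layers, and the
# per-step attention inputs are refreshed before each forward pass.
#
#   capture = DeepseekWeightCapture(model.model.layers, start=0, end=num_layers)
#   capture.update_attn_args(seq_lens, slot_mapping,
#                            kv_caches_dense_layer, kv_caches_moe_layer,
#                            block_tables)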


class DeepseekMTPWegitCapture():
    # Unlike DeepseekWeightCapture, MTP has only a single decoder layer,
    # and it is a MoE layer.
    def __init__(self, layer: torch.nn.Module):

        self.layer_moe = WeightMapper()

        self.sin_cache_all = None
        self.cos_cache_all = None

        self.capture_deepseek_mla_attn_weights(layer, self.layer_moe.attn_args)
        self.capture_deepseek_moe_weights(layer)

        if OUTPUT_ARGS_LOGS:
            self.layer_moe.attn_args.logs()
            self.layer_moe.moe_args.logs()

        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()

        self.layer_moe.dist_args._0_world_size = tp_group.world_size
        self.layer_moe.dist_args._1_rank = tp_group.rank_in_group
        self.layer_moe.dist_args._2_group_id = tp_group.group_id
        self.layer_moe.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_deepseek_moe_weights(self, module: DeepseekV2DecoderLayer):
        assert isinstance(module.mlp, DeepseekV2MoE)

        share_expert_layer = module.mlp.shared_experts
        experts_layer = module.mlp.experts
        rms_norm = module.post_attention_layernorm
        gate = module.mlp.gate

        w13_weight = share_expert_layer.gate_up_proj.weight
        w2_weight = share_expert_layer.down_proj.weight

        w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
        w2_wscale = share_expert_layer.down_proj.weight_scale_inv

        w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        w13_block_size0 = w13_block_size0 // scale_n
        w13_block_size1 = w13_block_size1 // scale_k

        w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        w2_block_size0 = w2_block_size0 // scale_n
        w2_block_size1 = w2_block_size1 // scale_k

        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        block_size0, block_size1 = (
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        )

        self.layer_moe.moe_args._0_moe_rms_weight.append(rms_norm.weight)
        self.layer_moe.moe_args._1_moe_share_mlp_w13.append(w13_weight)
        self.layer_moe.moe_args._2_moe_share_mlp_w2.append(w2_weight)
        self.layer_moe.moe_args._3_moe_share_mlp_w13_scale.append(w13_wscale)
        self.layer_moe.moe_args._4_moe_share_mlp_w2_scale.append(w2_wscale)
        self.layer_moe.moe_args._5_moe_w13.append(experts_layer.w13_weight)
        self.layer_moe.moe_args._6_moe_w2.append(experts_layer.w2_weight)
        self.layer_moe.moe_args._7_moe_w13_scale.append(experts_layer.w13_weight_scale_inv)
        self.layer_moe.moe_args._8_moe_w2_scale.append(experts_layer.w2_weight_scale_inv)
        self.layer_moe.moe_args._9_gate_weight.append(gate.weight)
        self.layer_moe.moe_args._10_moe_bias.append(gate.e_score_correction_bias)
        self.layer_moe.moe_args._11_moe_mlp_w13_block_size = [w13_block_size0, w13_block_size1]
        self.layer_moe.moe_args._12_moe_mlp_w2_block_size = [w2_block_size0, w2_block_size1]
        self.layer_moe.moe_args._13_moe_w13_block_size = [block_size0, block_size1]
        self.layer_moe.moe_args._14_moe_w2_block_size = [block_size0, block_size1]

    def capture_deepseek_mla_attn_weights(self, module: DeepseekV2DecoderLayer,
                                          weight_mapper: AttenArgs):
        if self.sin_cache_all is None:
            self.sin_cache_all = module.self_attn.mla_attn.impl.rotary_emb.sin_cache
            self.cos_cache_all = module.self_attn.mla_attn.impl.rotary_emb.cos_cache

        weight_mapper._a_hidden_states_norm_weight.append(module.input_layernorm.weight)

        fused_params = {}
        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.q_a_proj.named_parameters():
                fused_params['q_a_proj_' + name] = param

        for name, param in module.self_attn.q_a_layernorm.named_parameters():
            fused_params['q_a_layernorm_' + name] = param

        if not USE_MERGE_Q_KV_GEN_AND_Q_QR:
            for name, param in module.self_attn.kv_a_proj_with_mqa.named_parameters():
                fused_params['kv_a_proj_' + name] = param

        for name, param in module.self_attn.kv_a_layernorm.named_parameters():
            fused_params['kv_a_layernorm_' + name] = param

        for name, param in module.self_attn.o_proj.named_parameters():
            fused_params['o_proj_' + name] = param

        weight_mapper._15_env_blk_grp_size = env_blk_grp_size
        # the sin/cos caches are filled in later by update_attn_args

        mla_params = module.self_attn.mla_attn.impl.extract_weights()
        fused_params = {**fused_params, **mla_params}

        weight_mapper._0_merge_q_kv_weights.append(module.self_attn.merge_q_kv_weights)
        weight_mapper._1_merge_q_kv_scale_inv.append(module.self_attn.merge_q_kv_scale_inv)
        weight_mapper._2_q_a_layernorm_weight.append(fused_params['q_a_layernorm_weight'])
        weight_mapper._3_W_Q.append(fused_params['W_Q'])
        weight_mapper._4_W_Q_scales.append(fused_params['W_Q_scales'])
        weight_mapper._5_W_UK.append(fused_params['W_UK'])
        weight_mapper._6_W_UK_scales.append(fused_params['W_UK_scales'])
        weight_mapper._7_W_QR.append(fused_params['W_QR'])
        weight_mapper._8_W_QR_scales.append(fused_params['W_QR_scales'])
        weight_mapper._9_kv_a_layernorm_weight.append(fused_params['kv_a_layernorm_weight'])
        # weight_mapper._10_sin_cache.append(None)
        # weight_mapper._11_cos_cache.append(None)
        # weight_mapper._12_slot_mapping.append(None)
        # weight_mapper._13_kv_cache.append(None)
        # weight_mapper._14_block_tables.append(None)
        # weight_mapper._15_env_blk_grp_size.append(None)
        weight_mapper._16_W_UV.append(fused_params['W_UV'])
        weight_mapper._17_W_UV_scales.append(fused_params['W_UV_scales'])
        weight_mapper._18_o_proj_weight.append(fused_params['o_proj_weight'])
        weight_mapper._19_o_proj_weight_scale_inv.append(fused_params['o_proj_weight_scale_inv'])
        weight_mapper._20_seq_lens = None
        weight_mapper._21_sm_scale = module.self_attn.scaling
        weight_mapper._22_head_num = module.self_attn.num_heads // module.self_attn.o_proj.tp_size

    # Could be optimized: on the C++ side only Tensors are needed.
    def update_attn_args(self, seq_lens, slot_mapping, kv_caches_dense_layer, kv_caches_moe_layer, block_tables):
        positions = [i - 1 for i in seq_lens]
        cos_cache = [self.cos_cache_all[i] for i in positions]
        sin_cache = [self.sin_cache_all[i] for i in positions]

        self.layer_moe.attn_args._10_sin_cache = sin_cache
        self.layer_moe.attn_args._11_cos_cache = cos_cache

        self.layer_moe.attn_args._20_seq_lens = seq_lens

        self.layer_moe.attn_args._13_kv_cache = kv_caches_moe_layer

        self.layer_moe.attn_args._12_slot_mapping = slot_mapping
        self.layer_moe.attn_args._14_block_tables = block_tables

    def logs(self):
        print("current layer moe attn:")
        self.layer_moe.attn_args.logs()
        self.layer_moe.dist_args.logs()

@@ -0,0 +1,154 @@
import torch
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer,
                                                  Qwen3MoeMLP)


class Qwen3Moe_DistributedArgs():

    def __init__(self):
        self._0_world_size = 32
        self._1_rank = -1
        self._2_group_id = 0
        self._3_dev_info = []

    def __repr__(self):
        dist_infos = f"[dist] world_size = {self._0_world_size} \n" \
                     + f"[dist] rank = {self._1_rank} \n" \
                     + f"[dist] group_id = {self._2_group_id} \n" \
                     + f"[dist] dev_info = {self._3_dev_info}"
        return dist_infos

class Qwen3Moe_AttenArgs():

    def __init__(self):
        self._0_input_layernorm_weight = []
        self._1_qkv_proj_weight = []
        self._2_qkv_proj_weight_scale = []
        self._3_qkv_proj_bias = []
        self._4_qkv_proj_qzeros = []
        self._5_q_norm_weight = []
        self._6_k_norm_weight = []
        self._7_sin_cache = None
        self._8_cos_cache = None
        self._9_slot_mapping = None
        self._10_kv_cache = None
        self._11_block_tables = None
        self._12_block_group_size = None
        self._13_o_proj_weight = []
        self._14_o_proj_weight_scale = []
        self._15_o_proj_bias = []
        self._16_o_proj_qzeros = []
        self._17_seq_lens = None
        self._18_sm_scale = None
        self._19_num_attention_heads = None
        self._20_num_key_value_heads = None

    def __repr__(self):
        attn_infos = "[qwen attn] 21 args \n" \
                     + f"[qwen attn] weight counts: {len(self._0_input_layernorm_weight)}"
        return attn_infos

class Qwen3Moe_MoeArgs():

    def __init__(self):
        # moe params
        self._0_rms_norm_weight = []
        self._1_w13_weight = []
        self._2_w2_weight = []
        self._3_w13_weight_scale_inv = []
        self._4_w2_weight_scale_inv = []
        self._5_gate_weight = []
        self._6_w13_block_size = None
        self._7_w2_block_size = None

    def __repr__(self):
        moe_infos = f"[moe] w13_block_size: {self._6_w13_block_size} \n" \
                    + f"[moe] w2_block_size: {self._7_w2_block_size} \n" \
                    + f"[moe] weight counts: {len(self._1_w13_weight)}"
        return moe_infos

class Qwen3Moe_WeightMapper:

    def __init__(self):
        self.attn_args = Qwen3Moe_AttenArgs()
        self.moe_args = Qwen3Moe_MoeArgs()
        self.dist_args = Qwen3Moe_DistributedArgs()


class Qwen3Moe_WeightCapture():

    def __init__(self, layers: torch.nn.ModuleList,
                 start: int,
                 end: int):
        self.layer_mapper = Qwen3Moe_WeightMapper()
        # Qwen3 only supports fp8 fused weights for now
        self.support_fused_weights = False
        for i in range(start, end):
            layer = layers[i]
            self.capture_attn_weights(layer)
            self.capture_moe_weights(layer)

        # register the multi-device (TP) environment info
        from vllm.distributed import get_tp_group
        tp_group = get_tp_group()
        self.layer_mapper.dist_args._0_world_size = tp_group.world_size
        self.layer_mapper.dist_args._1_rank = tp_group.rank_in_group
        self.layer_mapper.dist_args._2_group_id = tp_group.group_id
        self.layer_mapper.dist_args._3_dev_info = tp_group.rank_device_infos

    def capture_attn_weights(self, layer):
        from vllm_vacc.vllm.model_executor.models.qwen3_moe import set_fused_params
        # register the fused-operator parameters
        fused_params = {}
        fused_params['input_layernorm_weight'] = layer.input_layernorm.weight
        fused_params['q_norm_weight'] = layer.self_attn.q_norm.weight
        fused_params['k_norm_weight'] = layer.self_attn.k_norm.weight
        set_fused_params(fused_params, layer.self_attn.qkv_proj.quant_method, layer.self_attn.qkv_proj, 'qkv_proj')
        set_fused_params(fused_params, layer.self_attn.o_proj.quant_method, layer.self_attn.o_proj, 'o_proj')

        self.support_fused_weights = hasattr(layer.mlp.experts.quant_method, 'quant_config') \
            and hasattr(layer.mlp.experts.quant_method.quant_config, 'weight_block_size')
        if not hasattr(layer.self_attn, "fused_params"):
            layer.self_attn.fused_params = fused_params

        self.layer_mapper.attn_args._0_input_layernorm_weight.append(fused_params['input_layernorm_weight'])
        self.layer_mapper.attn_args._1_qkv_proj_weight.append(fused_params['qkv_proj_weight'])
        self.layer_mapper.attn_args._2_qkv_proj_weight_scale.append(fused_params['qkv_proj_weight_scale'])
        self.layer_mapper.attn_args._3_qkv_proj_bias.append(fused_params['qkv_proj_bias'])
        self.layer_mapper.attn_args._4_qkv_proj_qzeros.append(fused_params['qkv_proj_qzeros'])
        self.layer_mapper.attn_args._5_q_norm_weight.append(fused_params['q_norm_weight'])
        self.layer_mapper.attn_args._6_k_norm_weight.append(fused_params['k_norm_weight'])
        # self.layer_mapper.attn_args._7_sin_cache
        # self.layer_mapper.attn_args._8_cos_cache
        # self.layer_mapper.attn_args._9_slot_mapping
        # self.layer_mapper.attn_args._10_kv_cache
        # self.layer_mapper.attn_args._11_block_tables
        # self.layer_mapper.attn_args._12_block_group_size
        self.layer_mapper.attn_args._13_o_proj_weight.append(fused_params['o_proj_weight'])
        self.layer_mapper.attn_args._14_o_proj_weight_scale.append(fused_params['o_proj_weight_scale'])
        self.layer_mapper.attn_args._15_o_proj_bias.append(fused_params['o_proj_bias'])
        self.layer_mapper.attn_args._16_o_proj_qzeros.append(fused_params['o_proj_qzeros'])
        # self.layer_mapper.attn_args._17_seq_lens
        self.layer_mapper.attn_args._18_sm_scale = layer.self_attn.scaling
        self.layer_mapper.attn_args._19_num_attention_heads = layer.self_attn.total_num_heads
        self.layer_mapper.attn_args._20_num_key_value_heads = layer.self_attn.total_num_kv_heads

    def capture_moe_weights(self, layer: Qwen3MoeDecoderLayer):
        from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
        quant_method = layer.mlp.experts.quant_method if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock) \
            else layer.mlp.down_proj.quant_method

        if not isinstance(quant_method, MoeWNA16Method):
            from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import recompute_moe_layer_blocksize
            recompute_moe_layer_blocksize(layer.mlp.experts)
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_weight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_weight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_weight_scale_inv)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_weight_scale_inv)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
        else:
            self.layer_mapper.moe_args._0_rms_norm_weight.append(layer.post_attention_layernorm.weight)
            self.layer_mapper.moe_args._1_w13_weight.append(layer.mlp.experts.w13_qweight)
            self.layer_mapper.moe_args._2_w2_weight.append(layer.mlp.experts.w2_qweight)
            self.layer_mapper.moe_args._3_w13_weight_scale_inv.append(layer.mlp.experts.w13_scales)
            self.layer_mapper.moe_args._4_w2_weight_scale_inv.append(layer.mlp.experts.w2_scales)
            self.layer_mapper.moe_args._5_gate_weight.append(layer.mlp.gate.weight)
            self.layer_mapper.moe_args._6_w13_block_size = layer.mlp.experts.w13_block_size
            self.layer_mapper.moe_args._7_w2_block_size = layer.mlp.experts.w2_block_size
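

# Hypothetical usage sketch (call site assumed, not part of this commit):
#   capture = Qwen3Moe_WeightCapture(model.model.layers, start=0, end=num_layers)
#   print(capture.layer_mapper.attn_args)   # __repr__ gives a short summary
#   print(capture.layer_mapper.moe_args)
#   print(capture.layer_mapper.dist_args)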