init
This commit is contained in:
0
vllm_vacc/vllm/model_executor/ops/__init__.py
Normal file
0
vllm_vacc/vllm/model_executor/ops/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
250
vllm_vacc/vllm/model_executor/ops/deepseek_fused_mlp_moe.py
Normal file
250
vllm_vacc/vllm/model_executor/ops/deepseek_fused_mlp_moe.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import torch
|
||||
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
|
||||
|
||||
# moe layer blocksize
|
||||
# 1. shared_expert_blocksize (w13, w2)
|
||||
# 2. expert_blocksize (w)
|
||||
def _quant_proj_block_size(proj):
    """Return the effective [n, k] quantization block size for a projection.

    The configured ``weight_block_size`` is divided by the quant method's
    scale granularity (``scale_n``, ``scale_k``); both must divide evenly.
    """
    block_n, block_k = proj.quant_method.quant_config.weight_block_size
    scale_n, scale_k = proj.quant_method.scale_n, proj.quant_method.scale_k
    assert block_n % scale_n == 0 and block_k % scale_k == 0
    return [block_n // scale_n, block_k // scale_k]


# moe layer blocksize
# 1. shared_expert_blocksize (w13, w2)
# 2. expert_blocksize (w)
def recompute_moe_layer_blocksize(moe_layer, share_expert_layer, experts_layer):
    """Cache quantization block sizes on the MoE sub-layers (idempotent).

    Sets ``share_expert_layer.w13_block_size`` / ``.w2_block_size`` from the
    quant config of its projections, and ``experts_layer.w13w2_block_size``
    from the ratio of weight dims to scale-grid dims.  Each attribute is only
    computed once; later calls are no-ops.  ``moe_layer`` is currently unused
    (kept for interface compatibility).
    """
    if not hasattr(share_expert_layer, "w13_block_size"):
        share_expert_layer.w13_block_size = _quant_proj_block_size(
            share_expert_layer.gate_up_proj)
        share_expert_layer.w2_block_size = _quant_proj_block_size(
            share_expert_layer.down_proj)

    if not hasattr(experts_layer, "w13w2_block_size"):
        # w13_weight: (experts, hidden, inter); scale_inv grid divides each dim.
        hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
        hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
        experts_layer.w13w2_block_size = [
            hidden_dims // hidden_blocks,
            inter_dims // inter_blocks,
        ]
|
||||
|
||||
|
||||
def vacc_fused_mlp_fp8(mlp, hidden_states, residual = None, rms_norm = None, reduce_result = False, moe_share = False):
    """Run a dense fp8 MLP through the VACC fused kernels.

    Two modes:
      * ``residual is None`` — plain fused MLP (shared-expert path); returns
        only the MLP output.  When ``moe_share`` is set and a memory recycler
        exists, the output is written into a recycled buffer.
      * otherwise — fused rms-norm + MLP + all-reduce; falls back to the
        unfused kernel plus an explicit all-reduce on failure.  Returns
        ``(final_hidden_states, residual)``.
    """
    # Lazily cache the effective quant block sizes on the layer (first call only).
    if not hasattr(mlp, "w13_block_size"):
        w13_block_size0, w13_block_size1 = mlp.gate_up_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.gate_up_proj.quant_method.scale_n, mlp.gate_up_proj.quant_method.scale_k
        assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
        mlp.w13_block_size = [w13_block_size0 // scale_n, w13_block_size1 // scale_k]

        w2_block_size0, w2_block_size1 = mlp.down_proj.quant_method.quant_config.weight_block_size
        scale_n, scale_k = mlp.down_proj.quant_method.scale_n, mlp.down_proj.quant_method.scale_k
        assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
        mlp.w2_block_size = [w2_block_size0 // scale_n, w2_block_size1 // scale_k]

    # start mlp fused ops
    w13_weight = mlp.gate_up_proj.weight
    w2_weight = mlp.down_proj.weight

    w13_wscale = mlp.gate_up_proj.weight_scale_inv
    w13_iscale = mlp.gate_up_proj.input_scale

    w2_wscale = mlp.down_proj.weight_scale_inv
    w2_iscale = mlp.down_proj.input_scale

    if residual is None:  # fix: identity comparison instead of `== None`
        from torch_vacc.vacc.custom_ops import fused_mlp_fp8 as fused_share_expert
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
        mlp_output = hidden_states
        # moe share experts out, can't be inplaced
        if moe_share and memory_recycler is not None:
            mlp_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER

        mlp_output = fused_share_expert(hidden_states, w13_weight, w2_weight,
                                        True,
                                        w13_wscale, w2_wscale,
                                        w13_iscale, w2_iscale,
                                        mlp.w13_block_size,
                                        mlp.w2_block_size,
                                        output=mlp_output)
        return mlp_output
    else:
        try:
            from torch_vacc.vacc.custom_ops import fused_mlp_allreduce
            tp_group = get_tp_group()
            final_hidden_states, residual = fused_mlp_allreduce(hidden_states, residual, rms_norm.weight,
                                                                w13_weight, w2_weight,
                                                                w13_wscale, w2_wscale,
                                                                mlp.w13_block_size,
                                                                mlp.w2_block_size,
                                                                red_op_type=0,
                                                                world_size=tp_group.world_size,
                                                                rank=tp_group.rank_in_group,
                                                                root_rank=0,
                                                                group_id=tp_group.group_id,
                                                                dev_info=tp_group.rank_device_infos)
        except Exception as e:
            # Fused kernel + all-reduce failed: run the unfused kernel and
            # reduce explicitly when requested.
            print("fused mlp_fp8_allreduce run false, now use unfused mlp and allreduce", e)
            final_hidden_states, residual = torch.vacc.fused_mlp_with_rmsnorm(hidden_states, residual, rms_norm.weight,
                                                                              w13_weight, w2_weight,
                                                                              w13_wscale, w2_wscale,
                                                                              mlp.w13_block_size,
                                                                              mlp.w2_block_size,)
            if reduce_result:
                final_hidden_states = tensor_model_parallel_all_reduce(
                    final_hidden_states)
        return final_hidden_states, residual
|
||||
|
||||
def vacc_fused_decode_moe_fp8(moe_layer, share_expert_layer, hidden_states, residual, rms_norm, gate, experts_layer, routed_scaling_factor, reduce_result):
    """Decode-path fused MoE (fp8): shared expert + routed experts + rms-norm.

    When ``reduce_result`` is set and the all-reduce payload is under 4 MiB,
    the fully fused ``fuse_moe_decode_v2_allreduce`` kernel is tried first;
    otherwise (or on failure) the unfused ``fused_mlp_moe_with_rmsnorm`` kernel
    runs, followed by an explicit all-reduce when requested.  Returns
    ``(final_hidden_states, residual)``.  ``routed_scaling_factor`` is
    currently unused on this path.
    """
    # Shared-expert (dense MLP) weights and fp8 weight scales.
    w13_weight = share_expert_layer.gate_up_proj.weight
    w2_weight = share_expert_layer.down_proj.weight

    w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
    w2_wscale = share_expert_layer.down_proj.weight_scale_inv

    # Ensure the cached block sizes exist (no-op after the first call).
    recompute_moe_layer_blocksize(moe_layer, share_expert_layer, experts_layer)

    if reduce_result:
        tp_group = get_tp_group()
        total_bytes = hidden_states.numel() * hidden_states.element_size() * tp_group.world_size
        # The fused all-reduce kernel only supports payloads below 4 MiB.
        if total_bytes < 4194304:
            try:
                from torch_vacc.vacc.custom_ops import fuse_moe_decode_v2_allreduce
                final_hidden_states, residual = fuse_moe_decode_v2_allreduce(
                    hidden_states, residual, rms_norm.weight,
                    w13_weight, w2_weight,
                    w13_wscale, w2_wscale,
                    experts_layer.w13_weight, experts_layer.w2_weight,
                    experts_layer.w13_weight_scale_inv, experts_layer.w2_weight_scale_inv,
                    gate.weight, gate.e_score_correction_bias,
                    share_expert_layer.w13_block_size,
                    share_expert_layer.w2_block_size,
                    experts_layer.w13w2_block_size,
                    experts_layer.w13w2_block_size,
                    red_op_type=0,
                    world_size=tp_group.world_size,
                    rank=tp_group.rank_in_group,
                    root_rank=0,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
                return final_hidden_states, residual
            except Exception as e:
                print("fuse_moe with all_reduce run Fail, now use unfused moe&all_reduce", e)

    # Fallback (previously duplicated in two branches):
    # 1. payload >= 4 MiB, 2. fuse_moe_decode_v2_allreduce ran and failed,
    # 3. reduce_result is False.
    final_hidden_states, residual = torch.vacc.fused_mlp_moe_with_rmsnorm(
        hidden_states, residual, rms_norm.weight,
        w13_weight, w2_weight,
        w13_wscale, w2_wscale,
        experts_layer.w13_weight, experts_layer.w2_weight,
        experts_layer.w13_weight_scale_inv, experts_layer.w2_weight_scale_inv,
        gate.weight, gate.e_score_correction_bias,
        share_expert_layer.w13_block_size, share_expert_layer.w2_block_size,
        experts_layer.w13w2_block_size, experts_layer.w13w2_block_size)
    if reduce_result:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

    return final_hidden_states, residual
|
||||
|
||||
def vacc_fused_prefill_moe_fp8(moe_layer, share_expert_layer, hidden_states, residual, rms_norm, gate, experts_layer, routed_scaling_factor, reduce_result):
    """Prefill-path fused MoE (fp8) for a DeepSeek-style layer.

    Stage 0 (``fuse_moe_prefill_stage0``) fuses rms-norm, the shared-expert
    MLP and router scoring; stage 1 (``fused_experts``) runs the routed
    experts, writing into the shared-expert output via ``output_opt``.
    Returns ``(final_hidden_states, residual)`` after a tensor-parallel
    all-reduce.  ``routed_scaling_factor`` and ``reduce_result`` are unused
    here — presumably the scaling happens inside ``fused_experts``; TODO
    confirm against the kernel implementation.
    """
    from torch_vacc.vacc.custom_ops import fused_experts, fuse_moe_prefill_stage0

    # Shared-expert (dense MLP) weights and fp8 weight scales.
    w13_weight = share_expert_layer.gate_up_proj.weight
    w2_weight = share_expert_layer.down_proj.weight

    w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
    # w13_iscale = share_expert_layer.gate_up_proj.input_scale

    w2_wscale = share_expert_layer.down_proj.weight_scale_inv
    # w2_iscale = share_expert_layer.down_proj.input_scale

    # Ensure the cached block sizes exist (no-op after the first call).
    recompute_moe_layer_blocksize(moe_layer, share_expert_layer, experts_layer)

    from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
    share_experts_output = None
    #experts_output = None
    if memory_recycler is not None:
        # Reuse a recycled buffer for the shared-expert MLP output.
        share_experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
        #experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER

    hidden_states, shared_output, topk_weights, topk_ids, residual = fuse_moe_prefill_stage0(hidden_states,
                                                                                             residual,
                                                                                             rms_norm.weight,
                                                                                             w13_weight,
                                                                                             w2_weight,
                                                                                             w13_wscale,
                                                                                             w2_wscale,
                                                                                             gate.weight,
                                                                                             gate.e_score_correction_bias,
                                                                                             share_expert_layer.w13_block_size,
                                                                                             share_expert_layer.w2_block_size,
                                                                                             rms_hidden_state_opt=hidden_states,
                                                                                             mlp_hidden_state_opt=share_experts_output)
    # output_opt must be used
    final_hidden_states = fused_experts(
        hidden_states,
        experts_layer.w13_weight,
        experts_layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        use_fp8_w8a8=True,
        w13_scale=experts_layer.w13_weight_scale_inv,
        w2_scale=experts_layer.w2_weight_scale_inv,
        a13_scale=None,
        a2_scale=None,
        block_shape=experts_layer.w13w2_block_size,
        decode_with_batch=False,
        output_opt=shared_output,
    )
    # final_hidden_states = torch.add(shared_output, final_hidden_states, alpha=routed_scaling_factor)
    final_hidden_states = tensor_model_parallel_all_reduce(
        final_hidden_states)
    return final_hidden_states, residual
    # shared_output.add_(final_hidden_states, alpha=routed_scaling_factor)
    # shared_output = tensor_model_parallel_all_reduce(
    #     shared_output)
    # return shared_output, residual
|
||||
41
vllm_vacc/vllm/model_executor/ops/mrope_op.py
Normal file
41
vllm_vacc/vllm/model_executor/ops/mrope_op.py
Normal file
@@ -0,0 +1,41 @@
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
|
||||
|
||||
|
||||
def get_sin_cos_mrope(rotary_emb: MRotaryEmbedding, positions: torch.Tensor, use_fuse: bool = True):
    """Return the (cos, sin) rotary caches gathered for ``positions`` of M-RoPE.

    When ``use_fuse`` is True the gather and section merge is delegated to the
    ``torch.vacc.mrope_get_sin_cos`` device kernel; otherwise it is computed
    natively, mirroring ``MRotaryEmbedding.forward_native``.
    """
    #ref: MRotaryEmbedding forward_native
    # odsp
    if use_fuse:
        # NOTE(review): the fused op takes (sin_cache, cos_cache) while the
        # native path below returns (cos_cache, sin_cache) — confirm the
        # fused op's return order matches the native path.
        return torch.vacc.mrope_get_sin_cos(
            rotary_emb.sin_cache,
            rotary_emb.cos_cache,
            positions,
            rotary_emb.mrope_section,
            rotary_emb.mrope_interleaved
        )

    # Native path: gather per-position rows from the precomputed caches.
    cos = rotary_emb.cos_cache[positions]
    sin = rotary_emb.sin_cache[positions]

    if rotary_emb.mrope_interleaved:
        cos_cache = apply_interleaved_rope(cos, rotary_emb.mrope_section)
        sin_cache = apply_interleaved_rope(sin, rotary_emb.mrope_section)
    else:
        # Split along the last dim into the mrope sections, take index i of
        # the leading axis for section i, then re-concatenate.
        cos_cache = torch.cat([
            m[i] for i, m in enumerate(
                cos.split(rotary_emb.mrope_section, dim=-1))
        ],
                              dim=-1)
        sin_cache = torch.cat([
            m[i] for i, m in enumerate(
                sin.split(rotary_emb.mrope_section, dim=-1))
        ],
                              dim=-1)

    return cos_cache, sin_cache
|
||||
|
||||
222
vllm_vacc/vllm/model_executor/ops/qwen3_fused_moe.py
Normal file
222
vllm_vacc/vllm/model_executor/ops/qwen3_fused_moe.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import torch
|
||||
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
|
||||
|
||||
|
||||
def recompute_moe_layer_blocksize(experts_layer):
    """Cache quantization block sizes on ``experts_layer`` (idempotent).

    Each block size is the element-wise ratio of a weight's dims 1 and 2 to
    the corresponding scale-grid dims.  Prefill-specific sizes are computed
    only when the ``*_prefill`` scale tensors exist.  Attributes that already
    exist are left untouched.
    """
    def _block_size(weight, scale):
        # [dim1_block, dim2_block]: how many weight elements each scale covers.
        return [weight.shape[1] // scale.shape[1],
                weight.shape[2] // scale.shape[2]]

    if not hasattr(experts_layer, "w13_block_size"):
        experts_layer.w13_block_size = _block_size(
            experts_layer.w13_weight, experts_layer.w13_weight_scale_inv)
    if not hasattr(experts_layer, "w2_block_size"):
        experts_layer.w2_block_size = _block_size(
            experts_layer.w2_weight, experts_layer.w2_weight_scale_inv)

    if not hasattr(experts_layer, "w13_block_size_prefill") and hasattr(experts_layer, "w13_weight_scale_inv_prefill"):
        experts_layer.w13_block_size_prefill = _block_size(
            experts_layer.w13_weight, experts_layer.w13_weight_scale_inv_prefill)
    if not hasattr(experts_layer, "w2_block_size_prefill") and hasattr(experts_layer, "w2_weight_scale_inv_prefill"):
        experts_layer.w2_block_size_prefill = _block_size(
            experts_layer.w2_weight, experts_layer.w2_weight_scale_inv_prefill)
|
||||
|
||||
|
||||
def vacc_fused_decode_moe_fp8(hidden_states, residual, rms_norm, gate, experts_layer, reduce_result=True):
    """Decode-path fused MoE (fp8) for a Qwen3-style layer.

    With ``reduce_result`` set, the fully fused ``fuse_moe_decode_qwen``
    kernel with built-in all-reduce is used; it only supports all-reduce
    payloads below 4 MiB and a ``ValueError`` is raised above that.  Returns
    ``(final_hidden_states, residual)``.
    """
    from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_decode_qwen

    if reduce_result:
        tp_group = get_tp_group()
        total_bytes = hidden_states.numel() * hidden_states.element_size() * tp_group.world_size
        # The fused all-reduce kernel only supports payloads below 4 MiB.
        if total_bytes >= 4194304:
            raise ValueError('total_bytes >= 4M')
        final_hidden_states, residual = fuse_moe_decode_qwen(
            hidden_states,
            residual,
            rms_norm.weight,
            experts_layer.w13_weight,
            experts_layer.w2_weight,
            experts_layer.w13_weight_scale_inv,
            experts_layer.w2_weight_scale_inv,
            gate.weight,
            experts_layer.w13_block_size,
            experts_layer.w2_block_size,
            world_size=tp_group.world_size,
            rank=tp_group.rank_in_group,
            group_id=tp_group.group_id,
            dev_info=tp_group.rank_device_infos)
        return final_hidden_states, residual

    final_hidden_states, residual = fuse_moe_decode_qwen(
        hidden_states,
        residual,
        rms_norm.weight,
        experts_layer.w13_weight,
        experts_layer.w2_weight,
        experts_layer.w13_weight_scale_inv,
        experts_layer.w2_weight_scale_inv,
        gate.weight,
        experts_layer.w13_block_size,
        experts_layer.w2_block_size,
    )
    # NOTE(review): this all-reduce runs even though reduce_result is False —
    # behavior preserved from the original; confirm it is intentional.
    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states, residual
|
||||
|
||||
|
||||
def vacc_fused_prefill_moe_fp8(hidden_states, residual, rms_norm, gate, experts_layer):
    """Prefill-path fused MoE (fp8) for a Qwen3-style layer.

    Stage 0 (``fuse_moe_prefill_stage0_qwen``) fuses rms-norm and router
    scoring; ``fused_experts`` then runs the routed experts, writing into the
    zeroed buffer via ``output_opt``.  Prefill-specific scale tensors are used
    when present on ``experts_layer``.  Returns ``(final_hidden_states,
    residual)`` after a tensor-parallel all-reduce.
    """
    from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_prefill_stage0_qwen
    from torch_vacc.vacc.custom_ops import fused_experts
    # fuse_moe_prefill_stage0 -> new api
    from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler

    # Reuse a recycled buffer for the stage-0 output when available.
    zero_moe_hidden_state_tensor = None
    if memory_recycler is not None:
        zero_moe_hidden_state_tensor = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER

    hidden_states, zero_moe_hidden_state_opt, topk_weights, topk_ids, residual = fuse_moe_prefill_stage0_qwen(
        hidden_states,
        residual,
        rms_norm.weight,
        gate.weight,
        rms_hidden_state_opt=hidden_states,
        zero_moe_hidden_state_opt=zero_moe_hidden_state_tensor)

    # Default scales/block size; switch to the prefill-specific ones when present.
    w13_scale = experts_layer.w13_weight_scale_inv
    w2_scale = experts_layer.w2_weight_scale_inv
    block_size_13 = experts_layer.w13_block_size
    if hasattr(experts_layer, "w13_weight_scale_inv_prefill") and hasattr(experts_layer, "w2_weight_scale_inv_prefill"):
        w13_scale = experts_layer.w13_weight_scale_inv_prefill
        w2_scale = experts_layer.w2_weight_scale_inv_prefill
        block_size_13 = experts_layer.w13_block_size_prefill

    # output_opt must be used
    final_hidden_states = fused_experts(
        hidden_states,
        experts_layer.w13_weight,
        experts_layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        use_fp8_w8a8=True,
        w13_scale=w13_scale,
        w2_scale=w2_scale,
        a13_scale=None,
        a2_scale=None,
        block_shape=block_size_13,
        decode_with_batch=False,
        output_opt=zero_moe_hidden_state_opt,
    )
    final_hidden_states = tensor_model_parallel_all_reduce(
        final_hidden_states)
    return final_hidden_states, residual
|
||||
|
||||
|
||||
def vacc_fused_prefill_moe_gptq_int4(hidden_states, residual, rms_norm, gate, experts_layer):
    """Prefill-path fused MoE for GPTQ int4 experts (Qwen3-style layer).

    Stage 0 (``fuse_moe_prefill_stage0_qwen``) fuses rms-norm and router
    scoring; ``fused_experts_int4_prefill`` then runs the routed experts,
    writing into the zeroed buffer via ``output_opt``.  The scalar group sizes
    on ``experts_layer`` are wrapped into ``(1, group)`` block shapes.
    Returns ``(final_hidden_states, residual)`` after a tensor-parallel
    all-reduce.
    """
    from torch_vacc.vacc.custom_ops import fused_experts_int4_prefill
    from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_prefill_stage0_qwen

    # fuse_moe_prefill_stage0 -> new api
    from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler

    # Reuse a recycled buffer for the stage-0 output when available.
    zero_moe_hidden_state_tensor = None
    if memory_recycler is not None:
        zero_moe_hidden_state_tensor = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER

    hidden_states, zero_moe_hidden_state_opt, topk_weights, topk_ids, residual = fuse_moe_prefill_stage0_qwen(hidden_states,
                                                                                                              residual,
                                                                                                              rms_norm.weight,
                                                                                                              gate.weight,
                                                                                                              rms_hidden_state_opt=hidden_states,
                                                                                                              zero_moe_hidden_state_opt=zero_moe_hidden_state_tensor)

    w13_scale=experts_layer.w13_scales # w13_weight_scale_inv
    w2_scale=experts_layer.w2_scales # w2_weight_scale_inv
    # Wrap the scalar group size into a 2-tuple block shape for the kernel.
    block_size_13 = (1, experts_layer.w13_block_size)
    block_size_2 = (1, experts_layer.w2_block_size)

    # output_opt must be used
    final_hidden_states = fused_experts_int4_prefill(
        hidden_states,
        experts_layer.w13_qweight,
        experts_layer.w2_qweight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        w13_scale=w13_scale,
        w2_scale=w2_scale,
        a13_scale=None,
        a2_scale=None,
        w13_block_shape=block_size_13,
        w2_block_shape =block_size_2,
        output_opt=zero_moe_hidden_state_opt,)

    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states, residual
|
||||
|
||||
|
||||
def vacc_fused_decode_moe_gptq_int4(hidden_states, residual, rms_norm, gate, experts_layer, reduce_result=True):
    """Decode-path fused MoE for GPTQ int4 experts (Qwen3-style layer).

    With ``reduce_result`` set, the fully fused ``fuse_moe_decode_qwen``
    kernel with built-in all-reduce is used; it only supports all-reduce
    payloads below 4 MiB and a ``ValueError`` is raised above that.  Returns
    ``(final_hidden_states, residual)``.
    """
    from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_decode_qwen

    if reduce_result:
        tp_group = get_tp_group()
        total_bytes = hidden_states.numel() * hidden_states.element_size() * tp_group.world_size
        # The fused all-reduce kernel only supports payloads below 4 MiB.
        if total_bytes >= 4194304:
            raise ValueError('total_bytes >= 4M')
        # Wrap the scalar group size into a 2-tuple block shape for the kernel.
        block_size_13 = (1, experts_layer.w13_block_size)
        block_size_2 = (1, experts_layer.w2_block_size)
        final_hidden_states, residual = fuse_moe_decode_qwen(
            hidden_states,
            residual,
            rms_norm.weight,
            experts_layer.w13_qweight,
            experts_layer.w2_qweight,
            experts_layer.w13_scales,
            experts_layer.w2_scales,
            gate.weight,
            block_size_13,
            block_size_2,
            world_size=tp_group.world_size,
            rank=tp_group.rank_in_group,
            group_id=tp_group.group_id,
            dev_info=tp_group.rank_device_infos)
        return final_hidden_states, residual

    # NOTE(review): unlike the branch above, the raw scalar block sizes are
    # passed here without the (1, ...) wrapping, and the all-reduce runs even
    # though reduce_result is False — behavior preserved from the original;
    # confirm both are intentional.
    final_hidden_states, residual = fuse_moe_decode_qwen(
        hidden_states,
        residual,
        rms_norm.weight,
        experts_layer.w13_qweight,
        experts_layer.w2_qweight,
        experts_layer.w13_scales,
        experts_layer.w2_scales,
        gate.weight,
        experts_layer.w13_block_size,
        experts_layer.w2_block_size,
    )
    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states, residual
|
||||
|
||||
Reference in New Issue
Block a user