This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,250 @@
import torch
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
# MoE layer block sizes:
#   1. shared-expert block sizes (w13, w2)
#   2. routed-expert block size (w13w2)
def recompute_moe_layer_blocksize(moe_layer, share_expert_layer, experts_layer):
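    """Derive and cache quantization block sizes for the fused MoE path.

    Shared-expert block sizes (w13_block_size, w2_block_size) are taken from
    the quant config's weight_block_size, divided by the quant method's
    scale_n / scale_k. The routed-expert block size (w13w2_block_size) is
    inferred from the ratio of the w13 weight shape to its weight_scale_inv
    shape. Results are cached on the layers so they are computed only once.
    """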
if not hasattr(share_expert_layer, "w13_block_size"):
w13_block_size0, w13_block_size1 = share_expert_layer.gate_up_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.gate_up_proj.quant_method.scale_n, share_expert_layer.gate_up_proj.quant_method.scale_k
assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
w13_block_size0 = w13_block_size0 // scale_n
w13_block_size1 = w13_block_size1 // scale_k
share_expert_layer.w13_block_size = [w13_block_size0, w13_block_size1]
w2_block_size0, w2_block_size1 = share_expert_layer.down_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = share_expert_layer.down_proj.quant_method.scale_n, share_expert_layer.down_proj.quant_method.scale_k
assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
w2_block_size0 = w2_block_size0 // scale_n
w2_block_size1 = w2_block_size1 // scale_k
share_expert_layer.w2_block_size = [w2_block_size0, w2_block_size1]
if not hasattr(experts_layer, "w13w2_block_size"):
hidden_dims, inter_dims = experts_layer.w13_weight.shape[1], experts_layer.w13_weight.shape[2]
hidden_blocks, inter_blocks = experts_layer.w13_weight_scale_inv.shape[1], experts_layer.w13_weight_scale_inv.shape[2]
block_size0, block_size1 = (
hidden_dims // hidden_blocks,
inter_dims // inter_blocks,
)
experts_layer.w13w2_block_size = [block_size0, block_size1]
def vacc_fused_mlp_fp8(mlp, hidden_states, residual = None, rms_norm = None, reduce_result = False, moe_share = False):
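    """Fused FP8 MLP (gate_up_proj and down_proj) on VACC.

    With residual=None the fused_mlp_fp8 custom op is called directly,
    optionally writing into the recycled MOE_SHARED_MLP_OUT_BUFFER when the
    MLP acts as a MoE shared expert. With a residual, the MLP is fused with
    the preceding RMSNorm and, when fused_mlp_allreduce is available, with
    the tensor-parallel all-reduce as well.
    """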
if not hasattr(mlp, "w13_block_size"):
w13_block_size0, w13_block_size1 = mlp.gate_up_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = mlp.gate_up_proj.quant_method.scale_n, mlp.gate_up_proj.quant_method.scale_k
assert w13_block_size0 % scale_n == 0 and w13_block_size1 % scale_k == 0
w13_block_size0 = w13_block_size0 // scale_n
w13_block_size1 = w13_block_size1 // scale_k
mlp.w13_block_size = [w13_block_size0, w13_block_size1]
w2_block_size0, w2_block_size1 = mlp.down_proj.quant_method.quant_config.weight_block_size
scale_n, scale_k = mlp.down_proj.quant_method.scale_n, mlp.down_proj.quant_method.scale_k
assert w2_block_size0 % scale_n == 0 and w2_block_size1 % scale_k == 0
w2_block_size0 = w2_block_size0 // scale_n
w2_block_size1 = w2_block_size1 // scale_k
mlp.w2_block_size = [w2_block_size0, w2_block_size1]
# start mlp fused ops
w13_weight = mlp.gate_up_proj.weight
w2_weight = mlp.down_proj.weight
w13_wscale = mlp.gate_up_proj.weight_scale_inv
w13_iscale = mlp.gate_up_proj.input_scale
w2_wscale = mlp.down_proj.weight_scale_inv
w2_iscale = mlp.down_proj.input_scale
    if residual is None:
from torch_vacc.vacc.custom_ops import fused_mlp_fp8 as fused_share_expert
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
mlp_output = hidden_states
        # MoE shared-expert output cannot be written in place over hidden_states, so use a dedicated buffer
if moe_share and memory_recycler is not None:
mlp_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
mlp_output = fused_share_expert(hidden_states, w13_weight, w2_weight,
True,
w13_wscale, w2_wscale,
w13_iscale, w2_iscale,
mlp.w13_block_size,
mlp.w2_block_size,
output=mlp_output)
return mlp_output
else:
try:
from torch_vacc.vacc.custom_ops import fused_mlp_allreduce
tp_group = get_tp_group()
final_hidden_states, residual = fused_mlp_allreduce(hidden_states, residual, rms_norm.weight,
w13_weight, w2_weight,
w13_wscale, w2_wscale,
mlp.w13_block_size,
mlp.w2_block_size,
red_op_type=0,
world_size=tp_group.world_size,
rank=tp_group.rank_in_group,
root_rank=0,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
        except Exception as e:
            print("fused_mlp_allreduce failed, falling back to unfused MLP + all-reduce:", e)
final_hidden_states, residual = torch.vacc.fused_mlp_with_rmsnorm(hidden_states, residual, rms_norm.weight,
w13_weight, w2_weight,
w13_wscale, w2_wscale,
mlp.w13_block_size,
mlp.w2_block_size,)
if reduce_result:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states, residual
def vacc_fused_decode_moe_fp8(moe_layer, share_expert_layer, hidden_states, residual, rms_norm, gate, experts_layer, routed_scaling_factor, reduce_result):
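    """Decode-time fused MoE (shared + routed experts) in FP8.

    For tensor-parallel payloads under 4 MB the whole step (RMSNorm, shared
    expert, routed experts, gating and all-reduce) is attempted with
    fuse_moe_decode_v2_allreduce; otherwise, or if that op fails, it falls
    back to torch.vacc.fused_mlp_moe_with_rmsnorm plus a separate all-reduce.
    """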
w13_weight = share_expert_layer.gate_up_proj.weight
w2_weight = share_expert_layer.down_proj.weight
w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
# w13_iscale = share_expert_layer.gate_up_proj.input_scale
w2_wscale = share_expert_layer.down_proj.weight_scale_inv
# w2_iscale = share_expert_layer.down_proj.input_scale
recompute_moe_layer_blocksize(moe_layer, share_expert_layer, experts_layer)
if reduce_result:
tp_group = get_tp_group()
total_bytes = hidden_states.numel() * hidden_states.element_size() * tp_group.world_size
        # the fused all-reduce path only supports payloads under 4 MB for now
if total_bytes < 4194304:
try:
from torch_vacc.vacc.custom_ops import fuse_moe_decode_v2_allreduce
final_hidden_states, residual = fuse_moe_decode_v2_allreduce(hidden_states, residual, rms_norm.weight,
w13_weight, w2_weight,
w13_wscale, w2_wscale,
experts_layer.w13_weight, experts_layer.w2_weight,
experts_layer.w13_weight_scale_inv, experts_layer.w2_weight_scale_inv,
gate.weight, gate.e_score_correction_bias,
share_expert_layer.w13_block_size,
share_expert_layer.w2_block_size,
experts_layer.w13w2_block_size,
experts_layer.w13w2_block_size,
red_op_type=0,
world_size=tp_group.world_size,
rank = tp_group.rank_in_group,
root_rank=0,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos)
return final_hidden_states, residual
except Exception as e:
print("fuse_moe with all_reduce run Fail, now use unfused moe&all_reduce", e)
        # Fallback path, reached when either:
        #   1. total_bytes >= 4 MB, or
        #   2. fuse_moe_decode_v2_allreduce failed
final_hidden_states, residual = torch.vacc.fused_mlp_moe_with_rmsnorm(hidden_states, residual, rms_norm.weight,
w13_weight, w2_weight,
w13_wscale, w2_wscale,
experts_layer.w13_weight, experts_layer.w2_weight,
experts_layer.w13_weight_scale_inv, experts_layer.w2_weight_scale_inv,
gate.weight, gate.e_score_correction_bias,
share_expert_layer.w13_block_size, share_expert_layer.w2_block_size,
experts_layer.w13w2_block_size, experts_layer.w13w2_block_size)
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
else:
final_hidden_states, residual = torch.vacc.fused_mlp_moe_with_rmsnorm(hidden_states, residual, rms_norm.weight,
w13_weight, w2_weight,
w13_wscale, w2_wscale,
experts_layer.w13_weight, experts_layer.w2_weight,
experts_layer.w13_weight_scale_inv, experts_layer.w2_weight_scale_inv,
gate.weight, gate.e_score_correction_bias,
share_expert_layer.w13_block_size, share_expert_layer.w2_block_size,
experts_layer.w13w2_block_size, experts_layer.w13w2_block_size)
return final_hidden_states, residual
def vacc_fused_prefill_moe_fp8(moe_layer, share_expert_layer, hidden_states, residual, rms_norm, gate, experts_layer, routed_scaling_factor, reduce_result):
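    """Prefill-time fused MoE (shared + routed experts) in FP8.

    fuse_moe_prefill_stage0 fuses RMSNorm, the shared-expert MLP and the
    gating/top-k selection; fused_experts then runs the routed experts in
    FP8, writing into the shared-expert output via output_opt, and the
    result is all-reduced across the tensor-parallel group.
    """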
from torch_vacc.vacc.custom_ops import fused_experts, fuse_moe_prefill_stage0
w13_weight = share_expert_layer.gate_up_proj.weight
w2_weight = share_expert_layer.down_proj.weight
w13_wscale = share_expert_layer.gate_up_proj.weight_scale_inv
# w13_iscale = share_expert_layer.gate_up_proj.input_scale
w2_wscale = share_expert_layer.down_proj.weight_scale_inv
# w2_iscale = share_expert_layer.down_proj.input_scale
recompute_moe_layer_blocksize(moe_layer, share_expert_layer, experts_layer)
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
share_experts_output = None
#experts_output = None
if memory_recycler is not None:
share_experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
#experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
hidden_states, shared_output, topk_weights, topk_ids, residual = fuse_moe_prefill_stage0(hidden_states,
residual,
rms_norm.weight,
w13_weight,
w2_weight,
w13_wscale,
w2_wscale,
gate.weight,
gate.e_score_correction_bias,
share_expert_layer.w13_block_size,
share_expert_layer.w2_block_size,
rms_hidden_state_opt=hidden_states,
mlp_hidden_state_opt=share_experts_output)
# output_opt must be used
final_hidden_states = fused_experts(
hidden_states,
experts_layer.w13_weight,
experts_layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=experts_layer.w13_weight_scale_inv,
w2_scale=experts_layer.w2_weight_scale_inv,
a13_scale=None,
a2_scale=None,
block_shape=experts_layer.w13w2_block_size,
decode_with_batch=False,
output_opt=shared_output,
)
# final_hidden_states = torch.add(shared_output, final_hidden_states, alpha=routed_scaling_factor)
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states, residual
# shared_output.add_(final_hidden_states, alpha=routed_scaling_factor)
# shared_output = tensor_model_parallel_all_reduce(
# shared_output)
# return shared_output, residual

View File

@@ -0,0 +1,41 @@
import torch
from torch import nn
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
def get_sin_cos_mrope(rotary_emb: MRotaryEmbedding, positions: torch.Tensor, use_fuse: bool = True):
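    """Build the cos/sin tensors for multimodal RoPE (M-RoPE) positions.

    With use_fuse=True the torch.vacc.mrope_get_sin_cos fused op is used;
    otherwise the caches are gathered by position and re-assembled per
    mrope_section, mirroring MRotaryEmbedding.forward_native.
    """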
    # ref: MRotaryEmbedding.forward_native
# odsp
if use_fuse:
return torch.vacc.mrope_get_sin_cos(
rotary_emb.sin_cache,
rotary_emb.cos_cache,
positions,
rotary_emb.mrope_section,
rotary_emb.mrope_interleaved
)
cos = rotary_emb.cos_cache[positions]
sin = rotary_emb.sin_cache[positions]
if rotary_emb.mrope_interleaved:
cos_cache = apply_interleaved_rope(cos, rotary_emb.mrope_section)
sin_cache = apply_interleaved_rope(sin, rotary_emb.mrope_section)
else:
cos_cache = torch.cat([
m[i] for i, m in enumerate(
cos.split(rotary_emb.mrope_section, dim=-1))
],
dim=-1)
sin_cache = torch.cat([
m[i] for i, m in enumerate(
sin.split(rotary_emb.mrope_section, dim=-1))
],
dim=-1)
return cos_cache, sin_cache

View File

@@ -0,0 +1,222 @@
import torch
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
def recompute_moe_layer_blocksize(experts_layer):
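    """Derive and cache quantization block sizes on the experts layer.

    Each block size is the ratio between a weight shape and the matching
    weight_scale_inv shape. Separate prefill block sizes are cached when
    dedicated *_prefill scale tensors exist. Results are computed once and
    reused on later calls.
    """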
if not hasattr(experts_layer, "w13_block_size"):
block_size_13 = [experts_layer.w13_weight.shape[1] // experts_layer.w13_weight_scale_inv.shape[1],
experts_layer.w13_weight.shape[2] // experts_layer.w13_weight_scale_inv.shape[2]]
experts_layer.w13_block_size = block_size_13
if not hasattr(experts_layer, "w2_block_size"):
# block_size_2 = [experts_layer.w2_weight.shape[2] // experts_layer.w2_weight_scale_inv.shape[2],
# experts_layer.w2_weight.shape[1] // experts_layer.w2_weight_scale_inv.shape[1]]
block_size_2 = [experts_layer.w2_weight.shape[1] // experts_layer.w2_weight_scale_inv.shape[1],
experts_layer.w2_weight.shape[2] // experts_layer.w2_weight_scale_inv.shape[2]]
experts_layer.w2_block_size = block_size_2
if not hasattr(experts_layer, "w13_block_size_prefill") and hasattr(experts_layer, "w13_weight_scale_inv_prefill"):
w13_block_size_prefill = [experts_layer.w13_weight.shape[1] // experts_layer.w13_weight_scale_inv_prefill.shape[1],
experts_layer.w13_weight.shape[2] // experts_layer.w13_weight_scale_inv_prefill.shape[2]]
experts_layer.w13_block_size_prefill = w13_block_size_prefill
if not hasattr(experts_layer, "w2_block_size_prefill") and hasattr(experts_layer, "w2_weight_scale_inv_prefill"):
w2_block_size_prefill = [experts_layer.w2_weight.shape[1] // experts_layer.w2_weight_scale_inv_prefill.shape[1],
experts_layer.w2_weight.shape[2] // experts_layer.w2_weight_scale_inv_prefill.shape[2]]
experts_layer.w2_block_size_prefill = w2_block_size_prefill
def vacc_fused_decode_moe_fp8(hidden_states, residual, rms_norm, gate, experts_layer, reduce_result=True):
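    """Decode-time fused MoE in FP8 for the routed-experts-only (Qwen-style) path.

    With reduce_result=True and a payload under 4 MB, fuse_moe_decode_qwen
    also performs the tensor-parallel all-reduce; larger payloads raise
    ValueError. Otherwise the fused op runs locally and the all-reduce is
    issued as a separate step.
    """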
if reduce_result:
tp_group = get_tp_group()
total_bytes = hidden_states.numel() * hidden_states.element_size() * tp_group.world_size
        # the fused all-reduce path only supports payloads under 4 MB for now
if total_bytes < 4194304:
from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_decode_qwen
final_hidden_states, residual = fuse_moe_decode_qwen(hidden_states,
residual,
rms_norm.weight,
experts_layer.w13_weight,
experts_layer.w2_weight,
experts_layer.w13_weight_scale_inv,
experts_layer.w2_weight_scale_inv,
gate.weight,
experts_layer.w13_block_size,
experts_layer.w2_block_size,
world_size=tp_group.world_size,
rank = tp_group.rank_in_group,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos
)
return final_hidden_states, residual
        # Payloads of 4 MB or more are not supported by the fused all-reduce path.
raise ValueError('total_bytes >= 4M')
else:
## recompute_moe_layer_blocksize(experts_layer)
# print('hidden_states', hidden_states.shape, hidden_states.dtype)
# print('residual', residual.shape, residual.dtype)
# print('rms_norm.weight', rms_norm.weight.shape, rms_norm.weight.dtype)
# print('w13', experts_layer.w13_weight.shape, experts_layer.w13_weight.dtype, experts_layer.w13_weight.stride())
# print('w2', experts_layer.w2_weight.shape, experts_layer.w2_weight.dtype, experts_layer.w2_weight.stride())
# print('w13_scale', experts_layer.w13_weight_scale_inv.shape, experts_layer.w13_weight_scale_inv.dtype, experts_layer.w13_weight_scale_inv.stride())
# print('w2_scale', experts_layer.w2_weight_scale_inv.shape, experts_layer.w2_weight_scale_inv.dtype, experts_layer.w2_weight_scale_inv.stride())
# print('gate.weight', gate.weight.shape, gate.weight.dtype, gate.weight.stride())
# print('block_size_13', experts_layer.w13_block_size)
# print('block_size_2', experts_layer.w2_block_size)
from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_decode_qwen
final_hidden_states, residual = fuse_moe_decode_qwen(hidden_states,
residual,
rms_norm.weight,
experts_layer.w13_weight,
experts_layer.w2_weight,
experts_layer.w13_weight_scale_inv,
experts_layer.w2_weight_scale_inv,
gate.weight,
experts_layer.w13_block_size,
experts_layer.w2_block_size,
)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states, residual
def vacc_fused_prefill_moe_fp8(hidden_states, residual, rms_norm, gate, experts_layer):
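    """Prefill-time fused MoE in FP8 for the routed-experts-only path.

    fuse_moe_prefill_stage0_qwen fuses RMSNorm and gating/top-k selection,
    then fused_experts runs the routed experts (preferring the dedicated
    prefill scale tensors when present), and the result is all-reduced
    across the tensor-parallel group.
    """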
from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_prefill_stage0_qwen
from torch_vacc.vacc.custom_ops import fused_experts
    # fuse_moe_prefill_stage0_qwen is the new API replacing fuse_moe_prefill_stage0
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
zero_moe_hidden_state_tensor = None
if memory_recycler is not None:
zero_moe_hidden_state_tensor = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
hidden_states, zero_moe_hidden_state_opt, topk_weights, topk_ids, residual = fuse_moe_prefill_stage0_qwen(hidden_states,
residual,
rms_norm.weight,
gate.weight,
rms_hidden_state_opt=hidden_states,
zero_moe_hidden_state_opt=zero_moe_hidden_state_tensor)
# recompute_moe_layer_blocksize(experts_layer)
# block_size_2 = experts_layer.w2_block_size
# experts_num = experts_layer.w13_weight.shape[0]
    w13_scale = experts_layer.w13_weight_scale_inv
    w2_scale = experts_layer.w2_weight_scale_inv
block_size_13 = experts_layer.w13_block_size
if hasattr(experts_layer, "w13_weight_scale_inv_prefill") and hasattr(experts_layer, "w2_weight_scale_inv_prefill"):
w13_scale = experts_layer.w13_weight_scale_inv_prefill
        w2_scale = experts_layer.w2_weight_scale_inv_prefill
block_size_13 = experts_layer.w13_block_size_prefill
# output_opt must be used
final_hidden_states = fused_experts(
hidden_states,
experts_layer.w13_weight,
experts_layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=w13_scale,
w2_scale=w2_scale,
a13_scale=None,
a2_scale=None,
block_shape=block_size_13,
decode_with_batch=False,
output_opt=zero_moe_hidden_state_opt,
)
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states, residual
def vacc_fused_prefill_moe_gptq_int4(hidden_states, residual, rms_norm, gate, experts_layer):
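    """Prefill-time fused MoE for GPTQ INT4-quantized routed experts.

    Reuses fuse_moe_prefill_stage0_qwen for the RMSNorm + gating stage, then
    runs the routed experts with fused_experts_int4_prefill using per-group
    INT4 scales, followed by a tensor-parallel all-reduce.
    """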
from torch_vacc.vacc.custom_ops import fused_experts_int4_prefill
from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_prefill_stage0_qwen
    # fuse_moe_prefill_stage0_qwen is the new API replacing fuse_moe_prefill_stage0
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
zero_moe_hidden_state_tensor = None
if memory_recycler is not None:
zero_moe_hidden_state_tensor = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
hidden_states, zero_moe_hidden_state_opt, topk_weights, topk_ids, residual = fuse_moe_prefill_stage0_qwen(hidden_states,
residual,
rms_norm.weight,
gate.weight,
rms_hidden_state_opt=hidden_states,
zero_moe_hidden_state_opt=zero_moe_hidden_state_tensor)
    w13_scale = experts_layer.w13_scales  # w13_weight_scale_inv
    w2_scale = experts_layer.w2_scales  # w2_weight_scale_inv
block_size_13 = (1, experts_layer.w13_block_size)
block_size_2 = (1, experts_layer.w2_block_size)
# output_opt must be used
final_hidden_states = fused_experts_int4_prefill(
hidden_states,
experts_layer.w13_qweight,
experts_layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
w13_scale=w13_scale,
w2_scale=w2_scale,
a13_scale=None,
a2_scale=None,
w13_block_shape=block_size_13,
        w2_block_shape=block_size_2,
output_opt=zero_moe_hidden_state_opt,)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states, residual
def vacc_fused_decode_moe_gptq_int4(hidden_states, residual, rms_norm, gate, experts_layer, reduce_result=True):
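    """Decode-time fused MoE for GPTQ INT4-quantized routed experts.

    Mirrors vacc_fused_decode_moe_fp8: with reduce_result=True and a payload
    under 4 MB, fuse_moe_decode_qwen also performs the tensor-parallel
    all-reduce; larger payloads raise ValueError. Otherwise the all-reduce
    is issued as a separate step.
    """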
if reduce_result:
tp_group = get_tp_group()
total_bytes = hidden_states.numel() * hidden_states.element_size() * tp_group.world_size
        # the fused all-reduce path only supports payloads under 4 MB for now
if total_bytes < 4194304:
from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_decode_qwen
block_size_13 = (1, experts_layer.w13_block_size)
block_size_2 = (1, experts_layer.w2_block_size)
final_hidden_states, residual = fuse_moe_decode_qwen(hidden_states,
residual,
rms_norm.weight,
experts_layer.w13_qweight,
experts_layer.w2_qweight,
experts_layer.w13_scales,
experts_layer.w2_scales,
gate.weight,
block_size_13,
block_size_2,
world_size=tp_group.world_size,
rank = tp_group.rank_in_group,
group_id=tp_group.group_id,
dev_info=tp_group.rank_device_infos
)
return final_hidden_states, residual
        # Payloads of 4 MB or more are not supported by the fused all-reduce path.
raise ValueError('total_bytes >= 4M')
else:
from torch_vacc.vacc.custom_qwen3_ops import fuse_moe_decode_qwen
final_hidden_states, residual = fuse_moe_decode_qwen(hidden_states,
residual,
rms_norm.weight,
experts_layer.w13_qweight,
experts_layer.w2_qweight,
experts_layer.w13_scales,
experts_layer.w2_scales,
gate.weight,
experts_layer.w13_block_size,
experts_layer.w2_block_size,
)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states, residual