Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -59,9 +59,164 @@ from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
#[B,K//8,N] ->[B,K,N]
|
||||
# less memory
|
||||
def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor:
    """Unpack a bit-packed 3D tensor without large intermediates.

    Converts an int32 tensor of shape [B, K // pack_factor, N] into an int8
    tensor of shape [B, K, N], processing only a few bit groups at a time so
    that no [pack_factor, B, K // pack_factor, N] broadcast is ever
    materialized (avoids OOM on large weights).

    Args:
        packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
        num_bits: Bits per packed element (e.g., 4 or 2).
        chunk_size: Bit groups unpacked per iteration (speed/memory tradeoff).

    Returns:
        torch.int8 tensor of shape [B, K, N].
    """
    batch, k_packed, n_cols = packed_w.shape
    pack_factor = 32 // num_bits
    total_k = k_packed * pack_factor
    low_mask = (1 << num_bits) - 1

    # Single allocation for the result; chunks are written in place.
    result = torch.empty((batch, total_k, n_cols), dtype=torch.int8, device=packed_w.device)

    for base in range(0, pack_factor, chunk_size):
        stop = min(base + chunk_size, pack_factor)
        # Shift amounts for this group of bit positions: [chunk, 1, 1, 1].
        group_shifts = (num_bits * torch.arange(base, stop, device=packed_w.device)).view(-1, 1, 1, 1)
        # Only a [chunk, B, k_packed, N] slice is alive at any time.
        piece = ((packed_w.unsqueeze(0) >> group_shifts) & low_mask).to(torch.int8)
        # Element g of each packed word lands at stride pack_factor along K.
        for offset, row in enumerate(piece):
            result[:, (base + offset)::pack_factor, :] = row
        del piece  # release the temporary before the next iteration

    return result
|
||||
|
||||
# more memory
|
||||
def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Fully vectorized unpacking of a bit-packed 3D tensor (batch version).

    Converts an int32 tensor of shape [B, K // pack_factor, N] into an int8
    tensor of shape [B, K, N] in a single broadcast pass (uses more peak
    memory than unpack_k_batch_opt).

    Args:
        packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
        num_bits: Bits per packed element (e.g., 4).

    Returns:
        torch.int8 tensor of shape [B, K, N].
    """
    batch, k_packed, n_cols = packed_w.shape
    pack_factor = 32 // num_bits
    low_mask = (1 << num_bits) - 1

    # Per-position bit offsets: [pack_factor, 1, 1, 1].
    shift_amounts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1)

    # Broadcast extraction of every bit group: [pack_factor, B, k_packed, N].
    groups = ((packed_w.unsqueeze(0) >> shift_amounts) & low_mask).to(torch.int8)

    # Interleave bit groups back into the K dimension: -> [B, K, N].
    return groups.permute(1, 2, 0, 3).reshape(batch, k_packed * pack_factor, n_cols)
|
||||
|
||||
|
||||
#[B,K,N] ->[B,K,N//8]
|
||||
# less memory
|
||||
def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor:
    """Memory-efficient batch packing along N with a configurable bit order.

    Packs [B, K, N] low-bit integer values into [B, K, N // pack_num] int32,
    processing ``chunk_size`` elements of each packed word at a time to keep
    peak memory low.

    Args:
        x: Integer tensor of shape [B, K, N]; each element must fit in
            32 // pack_num bits.
        pack_num: Elements packed per int32 word (default 8 -> 4 bits each).
        order_map: Optional ordering of elements within each packed word;
            element order_map[j] is stored at bit offset j * (32 // pack_num).
        chunk_size: Elements processed per iteration (speed/memory tradeoff).

    Returns:
        torch.int32 tensor of shape [B, K, cols] with cols = N // pack_num.
    """
    B, K, N = x.shape
    assert N % pack_num == 0, "N must be divisible by pack_num"
    cols = N // pack_num
    unit = 32 // pack_num  # bits per element

    if order_map is None:
        order_map = list(range(pack_num))
    order_map = torch.tensor(order_map, device=x.device)

    # Bit offsets are always 0, unit, 2*unit, ...; order_map only reorders
    # which source element fills each offset.
    shifts = unit * torch.arange(pack_num, device=x.device)
    packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device)
    # Fix: mask must match the element width. It was hard-coded to 0xF,
    # which is only correct for pack_num == 8 (4-bit elements).
    mask = (1 << unit) - 1
    x_reshape = x.view(B, K, cols, pack_num) & mask

    # Process elements of each word in chunks for memory efficiency.
    for start in range(0, pack_num, chunk_size):
        end = min(start + chunk_size, pack_num)
        idx_chunk = order_map[start:end]
        shift_chunk = shifts[start:end]

        vals = torch.gather(
            x_reshape, 3, idx_chunk.view(1, 1, 1, -1).expand(B, K, cols, -1)
        ).to(torch.int32)
        for j in range(vals.shape[-1]):
            # Bit ranges are disjoint, so add_ is equivalent to bitwise OR.
            packed.add_(vals[..., j] << shift_chunk[j])

    return packed
|
||||
|
||||
## more memory
|
||||
def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
    """Efficient vectorized batch packing: [B, K, N] -> [B, K, N // pack_num] int32.

    Args:
        x: Integer tensor of shape [B, K, N]; each element must fit in
            32 // pack_num bits (0-15 for the default 4-bit case).
        pack_num: Elements packed per int32 word (default 8 -> 4 bits each).
        order_map: Optional ordering of elements within each packed word;
            element order_map[j] is stored at bit offset j * (32 // pack_num).

    Returns:
        torch.int32 tensor of shape [B, K, N // pack_num].
    """
    B, K, N = x.shape
    assert N % pack_num == 0, "N must be divisible by pack_num"
    cols = N // pack_num

    if order_map is None:
        order_map = list(range(pack_num))
    order_map = torch.tensor(order_map, device=x.device)

    unit = 32 // pack_num  # number of bits per element

    # reshape to [B, K, cols, pack_num]
    x_reshape = x.view(B, K, cols, pack_num)

    # Reorder within each word: element order_map[j] lands at position j.
    x_reorder = torch.gather(
        x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1)
    )

    # Fix: mask must match the element width. It was hard-coded to 0xF,
    # which is only correct for pack_num == 8 (4-bit elements).
    x_reorder = x_reorder & ((1 << unit) - 1)

    # Bit offsets [1, 1, 1, pack_num], broadcast over the reordered values.
    shifts = (unit * torch.arange(pack_num, device=x.device)).view(1, 1, 1, -1)

    # Shift each element into place; summing disjoint bit ranges == bitwise OR.
    return (x_reorder << shifts).sum(dim=-1).to(torch.int32)
|
||||
|
||||
|
||||
|
||||
def get_moe_quant_method(
|
||||
config: "GPTQMarlinConfig",
|
||||
@@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.quant_type.size_bits == 4:
|
||||
self.quant_type = scalar_types.uint4b8
|
||||
elif self.quant_config.quant_type.size_bits == 8:
|
||||
self.quant_type = scalar_types.uint8b128
|
||||
# elif self.quant_config.quant_type.size_bits == 8:
|
||||
# self.quant_type = scalar_types.uint8b128
|
||||
else:
|
||||
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||
self.input_dtype = None
|
||||
@@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_experts,
|
||||
scales_size2,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
# layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
@@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_scales.data = layer.w2_scales.data * 512
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
for e in range(num_experts):
|
||||
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
device = layer.w13_g_idx.device
|
||||
layer.w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# num_experts = layer.w13_g_idx.shape[0]
|
||||
# w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
# w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
# w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
# w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
# for e in range(num_experts):
|
||||
# w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
# torch.int32
|
||||
# )
|
||||
# w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
# torch.int32
|
||||
# )
|
||||
# w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
# w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
# replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
# replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
# replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
# replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
# else:
|
||||
# # Reset g_idx related tensors
|
||||
# num_experts = layer.w13_g_idx.shape[0]
|
||||
# device = layer.w13_g_idx.device
|
||||
# layer.w13_g_idx = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w2_g_idx = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# # Repack weights
|
||||
# marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
# layer.w13_qweight,
|
||||
# layer.w13_g_idx_sort_indices,
|
||||
# layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
# layer.w13_qweight.shape[2],
|
||||
# self.quant_config.quant_type.size_bits,
|
||||
# )
|
||||
# replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
# marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
# layer.w2_qweight,
|
||||
# layer.w2_g_idx_sort_indices,
|
||||
# layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
# layer.w2_qweight.shape[2],
|
||||
# self.quant_config.quant_type.size_bits,
|
||||
# )
|
||||
# replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# # Repack scales
|
||||
# marlin_w13_scales = marlin_moe_permute_scales(
|
||||
# s=layer.w13_scales,
|
||||
# size_k=layer.intermediate_size_per_partition,
|
||||
# size_n=layer.w13_scales.shape[2],
|
||||
# group_size=self.quant_config.group_size,
|
||||
# )
|
||||
# replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
# marlin_w2_scales = marlin_moe_permute_scales(
|
||||
# s=layer.w2_scales,
|
||||
# size_k=layer.w2_scales.shape[1]
|
||||
# * (
|
||||
# self.quant_config.group_size
|
||||
# if self.quant_config.group_size != -1
|
||||
# else self.quant_config.pack_factor
|
||||
# ),
|
||||
# size_n=layer.w2_scales.shape[2],
|
||||
# group_size=self.quant_config.group_size,
|
||||
# )
|
||||
# replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
# The modular kernel expects w13_weight and w2_weight,
|
||||
# but GPTQ uses w13_qweight and w2_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w13_weight = layer.w13_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w2_weight = layer.w2_qweight
|
||||
# if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
# layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
|
||||
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w13_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w13_input_global_scale",
|
||||
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.w2_scales.shape[1]
|
||||
* (
|
||||
self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1
|
||||
else self.quant_config.pack_factor
|
||||
),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
|
||||
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w2_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_global_scale",
|
||||
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
# if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
# layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
if self.quant_config.desc_act:
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod now not support desc_act. please fix it")
|
||||
w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight)
|
||||
w13_qweight_repacked = pack_n_batch(w13_qweight_unpacked,self.quant_config.pack_factor,order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
replace_parameter(layer, "w13_qweight", w13_qweight_repacked)
|
||||
|
||||
# quant vllm/model_executor/layers/quantization/utils/quant_utils.py#quantize_weights
|
||||
# if quant_type.has_bias():
|
||||
# w_q += quant_type.bias
|
||||
# use quant_type.bias as zp,(ixformer support)
|
||||
w13_zp = torch.full_like(layer.w13_scales, self.quant_type.bias, dtype=torch.int32)
|
||||
w13_zp_pack = pack_n_batch(w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
|
||||
replace_parameter(layer, "w13_qzeros", w13_zp_pack)
|
||||
|
||||
w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight)
|
||||
w2_qweight_repacked = pack_n_batch(w2_qweight_unpacked,self.quant_config.pack_factor,order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
replace_parameter(layer, "w2_qweight", w2_qweight_repacked)
|
||||
|
||||
w2_zp = torch.full_like(layer.w2_scales, self.quant_type.bias, dtype=torch.int32)
|
||||
w2_zp_pack = pack_n_batch(w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
|
||||
replace_parameter(layer, "w2_qzeros", w2_zp_pack)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
@@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# Assign the value of shared_experts_output to variable shared_experts_input for fusion
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
getattr(layer, "w13_bias", None),
|
||||
getattr(layer, "w2_bias", None),
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
assert layer.activation.value == "silu", "Only SiLU activation is supported."
|
||||
use_ep = layer.expert_map is not None
|
||||
|
||||
if use_ep:
|
||||
start_eid = layer.ep_rank * layer.local_num_experts
|
||||
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
|
||||
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (hasattr(layer, "w2_bias") and layer.w2_bias is not None):
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod moe_w4a16_group_gemm not supported bias, please fix this")
|
||||
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = layer.global_num_experts
|
||||
|
||||
if use_ep:
|
||||
hidden_size = x.shape[1]
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
expand_tokens,
|
||||
) = ixfops.moe_compute_token_index_ep(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
start_expert_id=start_eid,
|
||||
end_expert_id=end_eid,
|
||||
)
|
||||
if expert_sizes_cpu.sum() == 0:
|
||||
return torch.zeros(
|
||||
(num_tokens, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * layer.top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
) = ixfops.moe_compute_token_index(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
|
||||
# expand + reorder
|
||||
# TODO use kernel
|
||||
expand_hidden_states = ixfops.moe_expand_input(
|
||||
hidden_states=x,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=layer.top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# w4a16 group gemm 1
|
||||
# pt_output_1: (expand_tokens, 2n) dtype
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * layer.top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# return torch.ops.vllm.fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# getattr(layer, "w13_bias", None),
|
||||
# getattr(layer, "w2_bias", None),
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# router_logits,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# quant_type_id=self.quant_type.id,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
# global_num_experts=global_num_experts,
|
||||
# expert_map=expert_map,
|
||||
# g_idx1=layer.w13_g_idx,
|
||||
# g_idx2=layer.w2_g_idx,
|
||||
# sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
# sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
# workspace=layer.workspace,
|
||||
# is_k_full=self.is_k_full)
|
||||
|
||||
Reference in New Issue
Block a user