Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -59,9 +59,164 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils.collection_utils import is_list_of
import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
# [B, K//8, N] -> [B, K, N]
# less memory
def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor:
"""
Memory-efficient unpacking for 3D tensor.
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor,
without broadcasting huge intermediate tensors (avoids OOM).
Args:
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
num_bits: Number of bits per packed element (e.g., 4 or 2).
chunk_size: How many bit groups to unpack at once (tradeoff between speed and memory).
Returns:
unpacked: torch.int8 tensor of shape [B, K, N].
"""
B, k_packed, N = packed_w.shape
pack_factor = 32 // num_bits
K = k_packed * pack_factor
mask = (1 << num_bits) - 1
# Allocate output tensor once
unpacked = torch.empty((B, K, N), dtype=torch.int8, device=packed_w.device)
# Process bit chunks iteratively to save memory
for i in range(0, pack_factor, chunk_size):
# Precompute shifts for this chunk
shift_vals = num_bits * torch.arange(i, min(i + chunk_size, pack_factor), device=packed_w.device)
# [chunk_size, 1, 1, 1]
shifts = shift_vals.view(-1, 1, 1, 1)
# Compute small chunk only
chunk = ((packed_w.unsqueeze(0) >> shifts) & mask).to(torch.int8)
# chunk: [chunk_size, B, k_packed, N]
# write into output
for j in range(chunk.shape[0]):
unpacked[:, (i + j)::pack_factor, :] = chunk[j]
del chunk # release memory early
return unpacked
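# Hedged sanity sketch (not part of the original file): with num_bits=4 the lowest
# nibble of each packed int32 becomes the first of its eight K positions, i.e.
# bits [4*j, 4*j + 4) of packed_w[b, m, n] land at unpacked[b, 8*m + j, n].
#
#   _word = sum(j << (4 * j) for j in range(8))                # nibbles 0..7
#   _packed = torch.tensor([[[_word]]], dtype=torch.int32)     # [B=1, K//8=1, N=1]
#   assert unpack_k_batch_opt(_packed).flatten().tolist() == list(range(8))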
# more memory
def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
"""
Efficient vectorized unpacking for 3D tensor (batch version).
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor.
Args:
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
num_bits: Number of bits per packed element (e.g., 4).
Returns:
unpacked: torch.int8 tensor of shape [B, K, N].
"""
B, k_packed, n = packed_w.shape
pack_factor = 32 // num_bits
k = k_packed * pack_factor
mask = (1 << num_bits) - 1
# [pack_factor, 1, 1, 1]
shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1)
# [1, B, k_packed, N]
packed_expanded = packed_w.unsqueeze(0)
# Extract each group of num_bits using bitwise ops
unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
# [pack_factor, B, k_packed, N] → [B, K, N]
unpacked = unpacked_groups.permute(1, 2, 0, 3).reshape(B, k, n)
return unpacked
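# Hedged equivalence check (assumption, not in this diff): unpack_k_batch and
# unpack_k_batch_opt produce bit-identical results; the _opt variant only trades a
# Python loop over small bit chunks for a much lower peak memory footprint.
#
#   _p = torch.randint(torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max,
#                      (2, 4, 16), dtype=torch.int32)
#   assert torch.equal(unpack_k_batch(_p), unpack_k_batch_opt(_p))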
# [B, K, N] -> [B, K, N//8]
# less memory
def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor:
"""
Memory-efficient batch packing with correct bit order.
[B, K, N] int4 -> [B, K, N//pack_num] int32.
"""
B, K, N = x.shape
assert N % pack_num == 0, "N must be divisible by pack_num"
cols = N // pack_num
unit = 32 // pack_num
if order_map is None:
order_map = list(range(pack_num))
order_map = torch.tensor(order_map, device=x.device)
shifts = unit * torch.arange(pack_num, device=x.device) # always 0..unit*(pack_num-1)
packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device)
x_reshape = x.view(B, K, cols, pack_num) & 0xF
# process in chunks for memory efficiency
for start in range(0, pack_num, chunk_size):
end = min(start + chunk_size, pack_num)
idx_chunk = order_map[start:end]
shift_chunk = shifts[start:end]
vals = torch.gather(x_reshape, 3, idx_chunk.view(1, 1, 1, -1).expand(B, K, cols, -1)).to(torch.int32)
for j in range(vals.shape[-1]):
packed.add_(vals[..., j] << shift_chunk[j])
return packed
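# Hedged sketch (not part of the original file): with the identity order_map the
# lowest nibble of each output word holds the first of its eight N columns, the
# same bit layout as the K unpacking above, applied along N instead.
#
#   _x = torch.arange(8, dtype=torch.int32).view(1, 1, 8)      # [B=1, K=1, N=8]
#   assert int(pack_n_batch_opt(_x)) == 0x76543210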
# more memory
def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
"""
Efficient vectorized batch packing: [B, K, N] int4 -> [B, K, N//pack_num] int32.
Args:
x: torch.int32 tensor of shape [B, K, N], each element 0-15 (int4).
pack_num: Number of 4-bit elements per packed int32 (default=8).
order_map: Optional order of elements within each packed int32.
Returns:
torch.int32 tensor of shape [B, K, N//pack_num].
"""
B, K, N = x.shape
assert N % pack_num == 0, "N must be divisible by pack_num"
cols = N // pack_num
if order_map is None:
order_map = list(range(pack_num))
order_map = torch.tensor(order_map, device=x.device)
unit = 32 // pack_num # number of bits per element
# reshape to [B, K, cols, pack_num]
pack_num_int = int(pack_num)
x_reshape = x.view(B, K, cols, pack_num_int)
# reorder according to order_map
x_reorder = torch.gather(
x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1)
)
# mask low 4 bits
x_reorder = x_reorder & 0xF
# bit shifts [pack_num] -> [1,1,1,pack_num] broadcastable
shifts = (unit * torch.arange(pack_num_int, device=x.device)).view(1, 1, 1, -1)
# shift and sum along last dimension to combine bits
packed = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
return packed
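# Hedged round-trip check (assumption, not in this diff): pack_n_batch matches the
# chunked variant bit for bit, and order_map=[0, 2, 4, 6, 1, 3, 5, 7] (the layout
# requested for ixformer further down) places even source columns in the four low
# nibbles and odd source columns in the four high nibbles of each int32.
#
#   _x = torch.randint(0, 16, (2, 4, 32), dtype=torch.int32)
#   _awq_order = [0, 2, 4, 6, 1, 3, 5, 7]
#   assert torch.equal(pack_n_batch(_x, order_map=_awq_order),
#                      pack_n_batch_opt(_x, order_map=_awq_order))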
def get_moe_quant_method(
config: "GPTQMarlinConfig",
@@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
# elif self.quant_config.quant_type.size_bits == 8:
# self.quant_type = scalar_types.uint8b128
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 now.")
self.input_dtype = None
@@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=params_dtype,
dtype=torch.int32,
),
requires_grad=False,
)
@@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype,
dtype=torch.int32,
),
requires_grad=False,
)
@@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4)
# layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
@@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order
if self.quant_config.desc_act:
# if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
torch.int32
)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32
)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# num_experts = layer.w13_g_idx.shape[0]
# w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
# w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
# w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
# w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
# for e in range(num_experts):
# w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
# torch.int32
# )
# w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
# torch.int32
# )
# w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
# w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
# replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
# replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
# replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
# replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
# else:
# # Reset g_idx related tensors
# num_experts = layer.w13_g_idx.shape[0]
# device = layer.w13_g_idx.device
# layer.w13_g_idx = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w2_g_idx = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w13_g_idx_sort_indices = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w2_g_idx_sort_indices = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# # Repack weights
# marlin_w13_qweight = ops.gptq_marlin_moe_repack(
# layer.w13_qweight,
# layer.w13_g_idx_sort_indices,
# layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
# layer.w13_qweight.shape[2],
# self.quant_config.quant_type.size_bits,
# )
# replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
# marlin_w2_qweight = ops.gptq_marlin_moe_repack(
# layer.w2_qweight,
# layer.w2_g_idx_sort_indices,
# layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
# layer.w2_qweight.shape[2],
# self.quant_config.quant_type.size_bits,
# )
# replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# # Repack scales
# marlin_w13_scales = marlin_moe_permute_scales(
# s=layer.w13_scales,
# size_k=layer.intermediate_size_per_partition,
# size_n=layer.w13_scales.shape[2],
# group_size=self.quant_config.group_size,
# )
# replace_parameter(layer, "w13_scales", marlin_w13_scales)
# marlin_w2_scales = marlin_moe_permute_scales(
# s=layer.w2_scales,
# size_k=layer.w2_scales.shape[1]
# * (
# self.quant_config.group_size
# if self.quant_config.group_size != -1
# else self.quant_config.pack_factor
# ),
# size_n=layer.w2_scales.shape[2],
# group_size=self.quant_config.group_size,
# )
# replace_parameter(layer, "w2_scales", marlin_w2_scales)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
# layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1]
* (
self.quant_config.group_size
if self.quant_config.group_size != -1
else self.quant_config.pack_factor
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
# if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
# layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
if self.quant_config.desc_act:
raise NotImplementedError(
"GPTQMarlinMoEMethod now not support desc_act. please fix it")
w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight)
w13_qweight_repacked = pack_n_batch(w13_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
replace_parameter(layer, "w13_qweight", w13_qweight_repacked)
# quantize_weights in vllm/model_executor/layers/quantization/utils/quant_utils.py does:
#     if quant_type.has_bias():
#         w_q += quant_type.bias
# so quant_type.bias is reused here as the zero point (ixformer accepts an explicit zp).
w13_zp = torch.full_like(layer.w13_scales, self.quant_type.bias, dtype=torch.int32)
w13_zp_pack = pack_n_batch(w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
replace_parameter(layer, "w13_qzeros", w13_zp_pack)
w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight)
w2_qweight_repacked = pack_n_batch(w2_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
replace_parameter(layer, "w2_qweight", w2_qweight_repacked)
w2_zp = torch.full_like(layer.w2_scales, self.quant_type.bias, dtype=torch.int32)
w2_zp_pack = pack_n_batch(w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
replace_parameter(layer, "w2_qzeros", w2_zp_pack)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
# shared_experts_input carries the shared experts' output so it can be fused into the final reduction
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
getattr(layer, "w13_bias", None),
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
)
assert layer.activation.value == "silu", "Only SiLU activation is supported."
use_ep = layer.expert_map is not None
if use_ep:
start_eid = layer.ep_rank * layer.local_num_experts
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"GPTQMarlinMoEMethod Apply router weight on input is not supported for"
"fused Marlin MoE method.")
if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (hasattr(layer, "w2_bias") and layer.w2_bias is not None):
raise NotImplementedError(
"GPTQMarlinMoEMethod moe_w4a16_group_gemm not supported bias, please fix this")
num_tokens = topk_ids.shape[0]
num_experts = layer.global_num_experts
if use_ep:
hidden_size = x.shape[1]
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
expand_tokens,
) = ixfops.moe_compute_token_index_ep(
topk_ids=topk_ids,
num_experts=num_experts,
start_expert_id=start_eid,
end_expert_id=end_eid,
)
if expert_sizes_cpu.sum() == 0:
return torch.zeros(
(num_tokens, hidden_size),
device=x.device,
dtype=x.dtype,
)
else:
expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
) = ixfops.moe_compute_token_index(
topk_ids=topk_ids,
num_experts=num_experts,
)
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
# TODO use kernel
expand_hidden_states = ixfops.moe_expand_input(
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=layer.top_k,
src_to_dst=src_to_dst,
)
# w4a16 group gemm 1
# pt_output_1: (expand_tokens, 2 * intermediate_size), in the activation dtype
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, hidden_size), in the activation dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * layer.top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
tokens_per_experts_gpu=expert_sizes_gpu,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
return final_hidden_states
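# Hedged dataflow summary (comment only, added for readability): this ixformer
# path replaces fused_marlin_moe with an explicit pipeline:
#   1. moe_compute_token_index / moe_compute_token_index_ep bucket the top-k
#      token assignments per (local) expert and return scatter/gather indices;
#   2. moe_expand_input gathers each token once per selected expert;
#   3. moe_w4a16_group_gemm runs the grouped w13 GEMM on the AWQ-layout weights
#      prepared in process_weights_after_loading;
#   4. silu_and_mul applies the gated activation;
#   5. a second grouped GEMM (w2) writes results back in token order; and
#   6. moe_output_reduce_sum applies the router weights (and, in the non-EP case,
#      folds in shared_experts_input) to produce the final hidden states.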
# return torch.ops.vllm.fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# getattr(layer, "w13_bias", None),
# getattr(layer, "w2_bias", None),
# layer.w13_scales,
# layer.w2_scales,
# router_logits,
# topk_weights,
# topk_ids,
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=apply_router_weight_on_input,
# global_num_experts=global_num_experts,
# expert_map=expert_map,
# g_idx1=layer.w13_g_idx,
# g_idx2=layer.w2_g_idx,
# sort_indices1=layer.w13_g_idx_sort_indices,
# sort_indices2=layer.w2_g_idx_sort_indices,
# workspace=layer.workspace,
# is_k_full=self.is_k_full)