add flashinfer mxfp4 (#8847)
This commit is contained in:
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
|
|||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
next_power_of_2,
|
next_power_of_2,
|
||||||
|
round_up,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
@@ -146,7 +147,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.num_fused_shared_experts = num_fused_shared_experts
|
self.num_fused_shared_experts = num_fused_shared_experts
|
||||||
self.expert_map_cpu = None
|
self.expert_map_cpu = None
|
||||||
@@ -206,6 +206,16 @@ class FusedMoE(torch.nn.Module):
|
|||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
if (
|
||||||
|
self.quant_config is not None
|
||||||
|
and self.quant_config.get_name() == "mxfp4"
|
||||||
|
and (
|
||||||
|
get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
|
||||||
|
or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
|
||||||
|
)
|
||||||
|
):
|
||||||
|
hidden_size = round_up(hidden_size, 256)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
num_experts=self.num_local_experts,
|
num_experts=self.num_local_experts,
|
||||||
@@ -784,6 +794,14 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||||
|
origin_hidden_states_dim = hidden_states.shape[-1]
|
||||||
|
if self.hidden_size != origin_hidden_states_dim:
|
||||||
|
hidden_states = torch.nn.functional.pad(
|
||||||
|
hidden_states,
|
||||||
|
(0, self.hidden_size - origin_hidden_states_dim),
|
||||||
|
mode="constant",
|
||||||
|
value=0.0,
|
||||||
|
)
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
||||||
@@ -829,7 +847,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
return final_hidden_states
|
return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_expert_params_mapping(
|
def make_expert_params_mapping(
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
|
get_bool_env_var,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
@@ -31,6 +32,12 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||||
|
|
||||||
|
# Environment variables for FlashInfer MXFP4 MoE backend
|
||||||
|
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
|
||||||
|
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
|
||||||
|
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
|
||||||
|
)
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
# from flashinfer.fused_moe import cutlass_fused_moe
|
# from flashinfer.fused_moe import cutlass_fused_moe
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
@@ -228,16 +235,28 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
weight_dtype = torch.uint8
|
weight_dtype = torch.uint8
|
||||||
scale_dtype = torch.uint8
|
scale_dtype = torch.uint8
|
||||||
|
|
||||||
intermediate_size *= 2
|
|
||||||
mxfp4_block = 32
|
mxfp4_block = 32
|
||||||
|
|
||||||
self.intermediate_size = intermediate_size
|
# pad the intermediate size to a backend-dependent multiple (256, 128 or 64; 64 is 2 * mxfp4_block)
|
||||||
|
# so that it can hold non-uniformly sharded tensors and supports swizzling
|
||||||
|
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||||
|
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
|
||||||
|
hidden_size = round_up(hidden_size, 256)
|
||||||
|
elif is_hip():
|
||||||
|
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
|
||||||
|
else:
|
||||||
|
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
|
||||||
|
|
||||||
|
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
# Fused gate_up_proj (column parallel)
|
# Fused gate_up_proj (column parallel)
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
|
hidden_size // 2,
|
||||||
|
dtype=weight_dtype,
|
||||||
),
|
),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
@@ -247,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
w13_weight_scale = torch.nn.Parameter(
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts,
|
num_experts,
|
||||||
2 * intermediate_size,
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
hidden_size // mxfp4_block,
|
hidden_size // mxfp4_block,
|
||||||
dtype=scale_dtype,
|
dtype=scale_dtype,
|
||||||
),
|
),
|
||||||
@@ -257,7 +276,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
w13_weight_bias = torch.nn.Parameter(
|
w13_weight_bias = torch.nn.Parameter(
|
||||||
torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
|
torch.zeros(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w13_weight_bias", w13_weight_bias)
|
layer.register_parameter("w13_weight_bias", w13_weight_bias)
|
||||||
@@ -266,7 +289,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
# down_proj (row parallel)
|
# down_proj (row parallel)
|
||||||
w2_weight = torch.nn.Parameter(
|
w2_weight = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size_per_partition_after_pad // 2,
|
||||||
|
dtype=weight_dtype,
|
||||||
),
|
),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
@@ -277,7 +303,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts,
|
num_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
intermediate_size // mxfp4_block,
|
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||||
dtype=scale_dtype,
|
dtype=scale_dtype,
|
||||||
),
|
),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -293,6 +319,158 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
|
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
|
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||||
|
logger.info(
|
||||||
|
"Shuffling MoE weights for FlashInfer, it might take a while..."
|
||||||
|
)
|
||||||
|
layer.gemm1_alpha = Parameter(
|
||||||
|
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.gemm1_beta = Parameter(
|
||||||
|
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.gemm1_clamp_limit = Parameter(
|
||||||
|
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
sf_block_size = 32 # mxfp4 block size
|
||||||
|
|
||||||
|
assert (
|
||||||
|
layer.w13_weight.dim() == 3
|
||||||
|
and layer.w13_weight.shape[0] == self.num_experts
|
||||||
|
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||||
|
and layer.w13_weight.shape[2] == self.hidden_size // 2
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
layer.w13_weight_scale.dim() == 3
|
||||||
|
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||||
|
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
|
||||||
|
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
layer.w2_weight.dim() == 3
|
||||||
|
and layer.w2_weight.shape[0] == self.num_experts
|
||||||
|
and layer.w2_weight.shape[1] == self.hidden_size
|
||||||
|
and layer.w2_weight.shape[2] == self.intermediate_size // 2
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
layer.w2_weight_scale.dim() == 3
|
||||||
|
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||||
|
and layer.w2_weight_scale.shape[2]
|
||||||
|
== self.intermediate_size // sf_block_size
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
layer.w13_weight_bias.dim() == 2
|
||||||
|
and layer.w13_weight_bias.shape[0] == self.num_experts
|
||||||
|
and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
layer.w2_weight_bias.dim() == 2
|
||||||
|
and layer.w2_weight_bias.shape[0] == self.num_experts
|
||||||
|
and layer.w2_weight_bias.shape[1] == self.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
w13_weight_scale = layer.w13_weight_scale.data
|
||||||
|
w2_weight_scale = layer.w2_weight_scale.data
|
||||||
|
w13_weight = layer.w13_weight.data
|
||||||
|
w2_weight = layer.w2_weight.data
|
||||||
|
w13_bias = layer.w13_weight_bias.data.to(torch.float32)
|
||||||
|
w2_bias = layer.w2_weight_bias.data.to(torch.float32)
|
||||||
|
|
||||||
|
# Swap w1 and w3 as the definition of
|
||||||
|
# swiglu is different in the trtllm-gen
|
||||||
|
def swap_every_two_rows(x, axis=-1):
    """Swap each adjacent pair of entries of ``x`` along ``axis``.

    Entries at positions ``2k`` and ``2k + 1`` along ``axis`` trade places;
    every other dimension is left untouched and the result has the same
    shape as ``x``.

    Args:
        x: Tensor whose size along ``axis`` is even.
        axis: Dimension to operate on; negative values count from the end.

    Returns:
        A tensor of the same shape as ``x`` with every pair along ``axis``
        exchanged.

    Raises:
        ValueError: If the size of ``x`` along ``axis`` is odd, since the
            entries could not be grouped into pairs.
    """
    shape = x.shape
    if axis < 0:
        axis = len(shape) + axis

    # Fail early with a clear message instead of an opaque reshape error.
    if shape[axis] % 2 != 0:
        raise ValueError(
            f"swap_every_two_rows requires an even size along axis {axis}, "
            f"got {shape[axis]}"
        )

    # Split the target axis into (pairs, 2) so each pair gets its own dim.
    paired_shape = list(shape)
    paired_shape[axis] = shape[axis] // 2
    paired_shape.insert(axis + 1, 2)

    # Flipping the size-2 dim swaps the members of every pair; then restore
    # the original layout.
    return x.reshape(paired_shape).flip(axis + 1).reshape(shape)
|
||||||
|
|
||||||
|
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
|
||||||
|
w13_weight = swap_every_two_rows(w13_weight, -2)
|
||||||
|
w13_bias = swap_every_two_rows(w13_bias, -1)
|
||||||
|
|
||||||
|
# Shuffle weights and scaling factors for transposed mma output
|
||||||
|
gemm1_weights_mxfp4_shuffled = []
|
||||||
|
gemm1_scales_mxfp4_shuffled = []
|
||||||
|
gemm2_weights_mxfp4_shuffled = []
|
||||||
|
gemm2_scales_mxfp4_shuffled = []
|
||||||
|
gemm1_bias_shuffled = []
|
||||||
|
gemm2_bias_shuffled = []
|
||||||
|
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||||
|
for i in range(self.num_experts):
|
||||||
|
gemm1_weights_mxfp4_shuffled.append(
|
||||||
|
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||||
|
)
|
||||||
|
gemm1_scales_mxfp4_shuffled.append(
|
||||||
|
shuffle_matrix_sf_a(
|
||||||
|
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||||
|
)
|
||||||
|
)
|
||||||
|
gemm1_bias_shuffled.append(
|
||||||
|
shuffle_matrix_a(
|
||||||
|
w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
gemm2_weights_mxfp4_shuffled.append(
|
||||||
|
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||||
|
)
|
||||||
|
gemm2_scales_mxfp4_shuffled.append(
|
||||||
|
shuffle_matrix_sf_a(
|
||||||
|
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||||
|
)
|
||||||
|
)
|
||||||
|
gemm2_bias_shuffled.append(
|
||||||
|
shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m)
|
||||||
|
)
|
||||||
|
|
||||||
|
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
|
||||||
|
w13_weight_scale = (
|
||||||
|
torch.stack(gemm1_scales_mxfp4_shuffled)
|
||||||
|
.reshape(
|
||||||
|
self.num_experts,
|
||||||
|
2 * self.intermediate_size,
|
||||||
|
self.hidden_size // sf_block_size,
|
||||||
|
)
|
||||||
|
.view(torch.float8_e4m3fn)
|
||||||
|
)
|
||||||
|
|
||||||
|
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
|
||||||
|
w2_weight_scale = (
|
||||||
|
torch.stack(gemm2_scales_mxfp4_shuffled)
|
||||||
|
.reshape(
|
||||||
|
self.num_experts,
|
||||||
|
self.hidden_size,
|
||||||
|
self.intermediate_size // sf_block_size,
|
||||||
|
)
|
||||||
|
.view(torch.float8_e4m3fn)
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||||
|
layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
|
||||||
|
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||||
|
layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
|
||||||
|
layer.w13_weight_bias = Parameter(
|
||||||
|
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.w2_weight_bias = Parameter(
|
||||||
|
torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||||
|
|
||||||
@@ -366,22 +544,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
activation_alpha: Optional[float] = None,
|
activation_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# avoid import error when triton_kernel is not installed
|
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||||
# from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
|
# When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input:
|
||||||
# triton_kernel_moe_forward)
|
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
|
||||||
|
# which can theoretically improve performance
|
||||||
"""
|
if USE_FLASHINFER_MXFP4_BF16_MOE:
|
||||||
if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
|
|
||||||
or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
|
|
||||||
assert not self.moe.use_ep, (
|
|
||||||
"EP is not supported for flashinfer mxfp4 moe backend yet.")
|
|
||||||
if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
|
|
||||||
assert x.dtype == torch.bfloat16
|
assert x.dtype == torch.bfloat16
|
||||||
x_quant = x
|
x_quant = x
|
||||||
x_scale = None
|
x_scale = None
|
||||||
else:
|
else:
|
||||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||||
|
|
||||||
|
topk_weights, topk_ids, router_logits = topk_output
|
||||||
|
top_k = topk_weights.shape[-1]
|
||||||
|
|
||||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||||
router_logits.to(torch.bfloat16),
|
router_logits.to(torch.bfloat16),
|
||||||
None, # routing_bias
|
None, # routing_bias
|
||||||
@@ -412,7 +589,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
True, # do finalize
|
True, # do finalize
|
||||||
)[0]
|
)[0]
|
||||||
return trtllm_gen_output
|
return trtllm_gen_output
|
||||||
"""
|
|
||||||
|
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
if self.with_bias:
|
if self.with_bias:
|
||||||
|
|||||||
@@ -464,7 +464,21 @@ class ServerArgs:
|
|||||||
model_arch = self.get_hf_config().architectures[0]
|
model_arch = self.get_hf_config().architectures[0]
|
||||||
if model_arch in ["GptOssForCausalLM"]:
|
if model_arch in ["GptOssForCausalLM"]:
|
||||||
self.attention_backend = "triton"
|
self.attention_backend = "triton"
|
||||||
self.enable_triton_kernel_moe = True
|
|
||||||
|
# Check if FlashInfer MXFP4 MoE is enabled
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
|
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
|
||||||
|
"SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
|
||||||
|
)
|
||||||
|
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
|
||||||
|
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only enable Triton kernel MoE if FlashInfer is not enabled
|
||||||
|
if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
|
||||||
|
self.enable_triton_kernel_moe = True
|
||||||
|
|
||||||
self.disable_hybrid_swa_memory = True
|
self.disable_hybrid_swa_memory = True
|
||||||
|
|
||||||
quantization_config = getattr(
|
quantization_config = getattr(
|
||||||
|
|||||||
Reference in New Issue
Block a user