Clean up imports (#5467)
@@ -1,4 +1,4 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import logging
@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
is_activation_quantization_format,
should_ignore_layer,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

logger = logging.getLogger(__name__)

@@ -1,22 +1,16 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import enum
import logging
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import Callable, List, Optional

import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy

if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs

_is_cuda = is_cuda()

if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
if not _is_cuda:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant

try:
import vllm
@@ -58,8 +51,6 @@ __all__ = [

class CompressedTensorsMoEMethod:
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

if cls is CompressedTensorsMoEMethod:
return super().__new__(cls)
return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
if not VLLM_AVAILABLE:
raise ImportError(
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
)
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)

self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
start += shard_size

layer.w13_weight_scale = torch.nn.Parameter(
@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)

self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
requires_grad=False,
)

marlin_w13_qweight = ops.gptq_marlin_moe_repack(
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w13_weight_packed,
layer.w13_g_idx_sort_indices,
layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.num_bits,
)
replace_tensor("w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts

assert activation == "silu", "Only SiLU activation is supported."
if not VLLM_AVAILABLE:
raise ImportError(
"vllm is not installed, to use fused_marlin_moe, please install vllm"
)
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for " "fused Marlin MoE method."

@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
)
from sglang.srt.layers.quantization.fp8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable,
**kwargs,
):
maybe_create_device_identity()

output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes


@@ -8,15 +8,6 @@ import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter

from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)

try:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
except ImportError:
MARLIN_FP8_AVAILABLE = False

def apply_fp8_marlin_linear(*args, **kwargs):
raise ImportError("vllm is not installed")
def dummy_func(*args, **kwargs):
raise ImportError(
"marlin FP8 requires some operators from vllm. Please install vllm."
)

def prepare_fp8_layer_for_marlin(*args, **kwargs):
raise ImportError("vllm is not installed")
apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func


from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_hip,
permute_weight,
print_warning_once,
set_weight_attrs,
)

ACTIVATION_SCHEMES = ["static", "dynamic"]

_is_hip = is_hip()
_is_cuda = is_cuda()

if _is_hip:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
from aiter.ops.shuffle import shuffle_weight

_is_cuda = is_cuda()
if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant

if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)

@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
)

layer.logical_widths = output_partition_sizes

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale_inv.data, requires_grad=False
)
return

layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
)

if self.use_marlin:
try:
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
except ImportError:
self.use_marlin = False
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale

def apply(
self,
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
) -> torch.Tensor:

if self.use_marlin:
try:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
except ImportError:
self.use_marlin = False
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)

if self.block_quant:
return apply_w8a8_block_fp8_linear(
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
)

# WEIGHTS
if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
self.process_weights_hip_int4(layer)
return

@@ -706,20 +701,12 @@ class Fp8MoEMethod:
requires_grad=False,
)
for expert in range(layer.num_experts):
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

@@ -796,18 +783,10 @@ class Fp8MoEMethod:
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += shard_size

layer.w13_weight_scale = torch.nn.Parameter(
@@ -930,41 +909,11 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)

if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages_win4(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)
if _is_hip and get_bool_env_var("CK_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
if _is_hip:
if get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages_win4(
x,
layer.w13_weight,
layer.w2_weight,
@@ -978,33 +927,65 @@ class Fp8MoEMethod:
else ActivationType.Gelu
),
)
else:
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)

if get_bool_env_var("CK_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)

# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)


class Fp8KVCacheMethod(BaseKVCacheMethod):

@@ -34,15 +34,23 @@ from sglang.srt.utils import (
supports_custom_op,
)

_enable_jit_deepgemm = False

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

_is_cuda = is_cuda()
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if _is_hip:
fp8_max = 224.0
else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max

_enable_jit_deepgemm = False
if _is_cuda:
import deep_gemm
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
from sgl_kernel import (
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_quant_fp8,
)

sm_version = get_device_sm()
if sm_version == 90 and get_bool_env_var(
@@ -53,6 +61,7 @@ if _is_cuda:

logger = logging.getLogger(__name__)


if supports_custom_op():

def deep_gemm_fp8_fp8_bf16_nt(
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor.

Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"

finfo = torch.finfo(dtype)
fp8_max = finfo.max

if _is_hip:
fp8_max = 224.0

fp8_min = -fp8_max

x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
M = x.numel() // group_size
N = group_size
if column_major_scales:
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
):
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"

finfo = torch.finfo(dtype)
fp8_max = finfo.max

fp8_min = -fp8_max

x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)

sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

return x_q, x_s
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(

def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = fp8_type_,
dtype: torch.dtype = _fp8_type,
):
assert x.is_contiguous(), "`x` is not contiguous"

@@ -368,7 +358,6 @@ def static_quant_fp8(
x: torch.Tensor,
x_s: torch.Tensor,
repeat_scale: bool = False,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform static quantization using the given scale on an input tensor `x`.

@@ -386,15 +375,8 @@ def static_quant_fp8(
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_max = finfo.max

if _is_hip:
fp8_max = 224.0

fp8_min = -fp8_max

x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(


def per_tensor_quant_mla_fp8(
x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
x: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 values with tensor-wise quantization
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"

finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0

x_q = x.new_empty(x.size(), dtype=dtype)
x_q = x.new_empty(x.size(), dtype=_fp8_type)
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

num_head, num_seq, head_size = x.shape
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
head_size,
x.stride(0),
x.stride(1),
-fp8_max,
fp8_min,
fp8_max,
BLOCK_SIZE,
)

return x_q, x_s


def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 (8-bit floating point) format.

Args:
input (torch.Tensor): Input tensor to be quantized
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
If None, scales will be computed dynamically.
num_token_padding (Optional[int]): If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
determines the quantization granularity:
- True: compute scale per token
- False: compute single scale per tensor

Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- quantized_tensor: The FP8 quantized version of input
- scale_tensor: The scaling factors used for quantization

Raises:
AssertionError: If input is not 2D or if static scale's numel != 1
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)

if scale is None:
# Dynamic scaling
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
sgl_per_token_quant_fp8(input, output, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=False
) # False for dynamic
else:
# Static scaling
assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=True
) # True for static

return output, scale

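Not part of the commit: a minimal usage sketch of the new `scaled_fp8_quant` entry point defined above. It assumes a CUDA build where the sgl_kernel per-token/per-tensor quant ops are available; the tensor shape and the 0.02 scale are illustrative only.

import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

# 2D activations, as required by the assert in scaled_fp8_quant.
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")

# Dynamic scaling with one scale per token (per row of x).
x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static scaling with a precomputed per-tensor scalar scale.
s = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_q_static, _ = scaled_fp8_quant(x, scale=s)
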
@@ -1,11 +1,19 @@
import os
from typing import List, Optional, Tuple

import torch

try:
from vllm import _custom_ops as vllm_ops

VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False

from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm,
per_token_group_quant_fp8,
scaled_fp8_quant,
sglang_per_token_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
is_hip,
)

try:
import vllm
from vllm import _custom_ops as ops

VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False

use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

_is_hip = is_hip()
_is_cuda = is_cuda()

if _is_hip and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale

_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
TORCH_DEVICE_IDENTITY = None

_TORCH_VERSION = torch.__version__.split("+")[0]
try:
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_block)
scaled_fp8_quant(x_dq_block)
if _is_cuda
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
)
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
) -> Tuple[torch.Tensor, torch.Tensor]:
x_dq_channel = x_q_channel.to(torch.float32) * x_s
x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_channel)
scaled_fp8_quant(x_dq_channel)
if _is_cuda
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
)
@@ -264,7 +262,7 @@ def apply_fp8_linear(
# final solution should be: 1. add support to per-tensor activation scaling.
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
if _is_hip and weight_scale.numel() == 1:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
@@ -275,32 +273,29 @@ def apply_fp8_linear(
)

if cutlass_fp8_supported:
try:
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput,
weight,
x_scale,
weight_scale,
out_dtype=input.dtype,
bias=bias,
)
return output.view(*output_shape)
except (ImportError, NameError, AttributeError):
pass
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = vllm_ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput,
weight,
x_scale,
weight_scale,
out_dtype=input.dtype,
bias=bias,
)
return output.view(*output_shape)

# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
@@ -343,8 +338,10 @@ def apply_fp8_linear(

# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)

# GEMM
# This computes C = (X * W).
@@ -372,13 +369,6 @@ def apply_fp8_linear(
return output.to(dtype=input.dtype).view(*output_shape)


def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
@@ -405,9 +395,7 @@ class Fp8LinearOp:
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
if pad_output is None:
enable_torch_compile = os.environ.get(
"SGLANG_ENABLE_TORCH_COMPILE", "0"
).lower() in ("1", "true", "yes")
enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
pad_output = not enable_torch_compile
self.output_padding = 17 if pad_output else None

@@ -439,13 +427,13 @@ class Fp8LinearOp:
# for sgl-kernel fp8_scaled_mm, it support per channel W now
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
qinput, x_scale = scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
scale_ub=input_scale_ub,
@@ -455,7 +443,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
output = vllm_ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
@@ -482,14 +470,14 @@ class Fp8LinearOp:
else:
# Maybe apply padding to output, see comment in __init__
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
qinput, x_scale = scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
@@ -562,9 +550,12 @@ class Fp8LinearOp:
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place

# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)

output = torch._scaled_mm(
qinput,

@@ -1,18 +1,17 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
from typing import List, Mapping, Tuple, Union

import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant


def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
if _is_cuda:
weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
else:
weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
weight_dq, max_w_scale
)
weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
start = end

return max_w_scale, weight

@@ -1,13 +1,6 @@
from typing import Any, Callable, Dict, List, Optional

import torch

from sglang.srt.utils import is_cuda_available, set_weight_attrs

is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm

from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import is_cuda_available, set_weight_attrs

is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm


class W8A8Int8Config(QuantizationConfig):

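Editor's note, not part of the diff: the pattern this cleanup converges on is binding a single `scaled_fp8_quant` name once per module instead of branching on `_is_cuda` at every call site. A condensed sketch of that pattern, mirroring the import order shown in the hunks above (the `requantize_block` helper is hypothetical):

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if not _is_cuda:
    # Non-CUDA builds fall back to vllm's custom op, as in the hunks above.
    from vllm._custom_ops import scaled_fp8_quant


def requantize_block(weight_dq, max_w_scale):
    # Hypothetical helper: call sites use one name, with no per-call _is_cuda branch.
    w_q, _ = scaled_fp8_quant(weight_dq, max_w_scale)
    return w_q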