sglang quant module remove vllm dependency (#4507)
This commit is contained in:
@@ -6,21 +6,41 @@ from copy import deepcopy
|
|||||||
from typing import Callable, Dict, Optional, Type, Union
|
from typing import Callable, Dict, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
|
||||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
try:
|
||||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
||||||
CompressedTensorsConfig,
|
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||||
)
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
CompressedTensorsConfig,
|
||||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
)
|
||||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
|
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||||
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
GPTQMarlin24Config,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||||
|
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||||
|
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||||
|
|
||||||
|
VLLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
|
# Define empty classes as placeholders when vllm is not available
|
||||||
|
class DummyConfig:
|
||||||
|
pass
|
||||||
|
|
||||||
|
AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
|
||||||
|
CompressedTensorsConfig
|
||||||
|
) = DummyConfig
|
||||||
|
DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
|
||||||
|
GPTQMarlin24Config
|
||||||
|
) = DummyConfig
|
||||||
|
MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||||
@@ -30,29 +50,37 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
|||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
|
|
||||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
# Base quantization methods that don't depend on vllm
|
||||||
"aqlm": AQLMConfig,
|
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||||
"awq": AWQConfig,
|
|
||||||
"deepspeedfp": DeepSpeedFPConfig,
|
|
||||||
"tpu_int8": Int8TpuConfig,
|
|
||||||
"fp8": Fp8Config,
|
"fp8": Fp8Config,
|
||||||
"blockwise_int8": BlockInt8Config,
|
"blockwise_int8": BlockInt8Config,
|
||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
|
||||||
"marlin": MarlinConfig,
|
|
||||||
"modelopt": ModelOptFp8Config,
|
"modelopt": ModelOptFp8Config,
|
||||||
"gguf": GGUFConfig,
|
|
||||||
"gptq_marlin_24": GPTQMarlin24Config,
|
|
||||||
"gptq_marlin": GPTQMarlinConfig,
|
"gptq_marlin": GPTQMarlinConfig,
|
||||||
"awq_marlin": AWQMarlinConfig,
|
|
||||||
"gptq": GPTQConfig,
|
"gptq": GPTQConfig,
|
||||||
"compressed-tensors": CompressedTensorsConfig,
|
|
||||||
"bitsandbytes": BitsAndBytesConfig,
|
|
||||||
"qqq": QQQConfig,
|
|
||||||
"experts_int8": ExpertsInt8Config,
|
|
||||||
"w8a8_int8": W8A8Int8Config,
|
"w8a8_int8": W8A8Int8Config,
|
||||||
"w8a8_fp8": W8A8Fp8Config,
|
"w8a8_fp8": W8A8Fp8Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add vllm-dependent methods if available
|
||||||
|
QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
|
||||||
|
if VLLM_AVAILABLE:
|
||||||
|
VLLM_QUANTIZATION_METHODS = {
|
||||||
|
"aqlm": AQLMConfig,
|
||||||
|
"awq": AWQConfig,
|
||||||
|
"deepspeedfp": DeepSpeedFPConfig,
|
||||||
|
"tpu_int8": Int8TpuConfig,
|
||||||
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
|
"marlin": MarlinConfig,
|
||||||
|
"gguf": GGUFConfig,
|
||||||
|
"gptq_marlin_24": GPTQMarlin24Config,
|
||||||
|
"awq_marlin": AWQMarlinConfig,
|
||||||
|
"compressed-tensors": CompressedTensorsConfig,
|
||||||
|
"bitsandbytes": BitsAndBytesConfig,
|
||||||
|
"qqq": QQQConfig,
|
||||||
|
"experts_int8": ExpertsInt8Config,
|
||||||
|
}
|
||||||
|
QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
|
||||||
|
|
||||||
|
|
||||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||||
if quantization not in QUANTIZATION_METHODS:
|
if quantization not in QUANTIZATION_METHODS:
|
||||||
@@ -157,25 +185,31 @@ def get_linear_quant_method(
|
|||||||
|
|
||||||
|
|
||||||
def gptq_get_quant_method(self, layer, prefix):
|
def gptq_get_quant_method(self, layer, prefix):
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
if not VLLM_AVAILABLE:
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
return None
|
||||||
GPTQMarlinLinearMethod,
|
|
||||||
GPTQMarlinMoEMethod,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
try:
|
||||||
|
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||||
if isinstance(layer, FusedMoE):
|
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||||
return GPTQMarlinMoEMethod(self)
|
GPTQMarlinLinearMethod,
|
||||||
|
GPTQMarlinMoEMethod,
|
||||||
if isinstance(self, GPTQConfig):
|
|
||||||
return get_linear_quant_method(
|
|
||||||
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
|
||||||
)
|
|
||||||
elif isinstance(self, GPTQMarlinConfig):
|
|
||||||
return get_linear_quant_method(
|
|
||||||
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
|
if isinstance(layer, FusedMoE):
|
||||||
|
return GPTQMarlinMoEMethod(self)
|
||||||
|
|
||||||
|
if isinstance(self, GPTQConfig):
|
||||||
|
return get_linear_quant_method(
|
||||||
|
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
||||||
|
)
|
||||||
|
elif isinstance(self, GPTQMarlinConfig):
|
||||||
|
return get_linear_quant_method(
|
||||||
|
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -187,33 +221,40 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
|
|||||||
Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
|
Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
|
||||||
can recognize sglang layers
|
can recognize sglang layers
|
||||||
"""
|
"""
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
return
|
||||||
|
|
||||||
if reverse:
|
if reverse:
|
||||||
builtins.isinstance = original_isinstance
|
builtins.isinstance = original_isinstance
|
||||||
return
|
return
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
try:
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
VocabParallelEmbedding,
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
)
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
FusedMoE as PatchedFusedMoE,
|
||||||
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
)
|
||||||
)
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
def patched_isinstance(obj, classinfo):
|
def patched_isinstance(obj, classinfo):
|
||||||
if classinfo is LinearBase:
|
if classinfo is LinearBase:
|
||||||
return original_isinstance(obj, PatchedLinearBase)
|
return original_isinstance(obj, PatchedLinearBase)
|
||||||
if classinfo is FusedMoE:
|
if classinfo is FusedMoE:
|
||||||
return original_isinstance(obj, PatchedFusedMoE)
|
return original_isinstance(obj, PatchedFusedMoE)
|
||||||
if classinfo is VocabParallelEmbedding:
|
if classinfo is VocabParallelEmbedding:
|
||||||
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
||||||
return original_isinstance(obj, classinfo)
|
return original_isinstance(obj, classinfo)
|
||||||
|
|
||||||
builtins.isinstance = patched_isinstance
|
builtins.isinstance = patched_isinstance
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||||
@@ -221,72 +262,88 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
|||||||
Monkey patch the apply function of vllm's FusedMoEMethodBase.
|
Monkey patch the apply function of vllm's FusedMoEMethodBase.
|
||||||
Convert sglang arguments to vllm arguments.
|
Convert sglang arguments to vllm arguments.
|
||||||
"""
|
"""
|
||||||
original_apply = class_obj.apply
|
if not VLLM_AVAILABLE:
|
||||||
sig = inspect.signature(original_apply)
|
return
|
||||||
param_names = list(sig.parameters.keys())
|
|
||||||
has_correction_bias = "e_score_correction_bias" in param_names
|
|
||||||
|
|
||||||
def new_apply(
|
try:
|
||||||
self,
|
original_apply = class_obj.apply
|
||||||
layer: torch.nn.Module,
|
sig = inspect.signature(original_apply)
|
||||||
x: torch.Tensor,
|
param_names = list(sig.parameters.keys())
|
||||||
router_logits: torch.Tensor,
|
has_correction_bias = "e_score_correction_bias" in param_names
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
activation: str = "silu",
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
):
|
|
||||||
assert activation == "silu"
|
|
||||||
assert inplace and not no_combine
|
|
||||||
|
|
||||||
kwargs = {
|
def new_apply(
|
||||||
"self": self,
|
self,
|
||||||
"layer": layer,
|
layer: torch.nn.Module,
|
||||||
"x": x,
|
x: torch.Tensor,
|
||||||
"router_logits": router_logits,
|
router_logits: torch.Tensor,
|
||||||
"top_k": top_k,
|
top_k: int,
|
||||||
"renormalize": renormalize,
|
renormalize: bool,
|
||||||
"use_grouped_topk": use_grouped_topk,
|
use_grouped_topk: bool,
|
||||||
"topk_group": topk_group,
|
topk_group: Optional[int] = None,
|
||||||
"num_expert_group": num_expert_group,
|
num_expert_group: Optional[int] = None,
|
||||||
"custom_routing_function": custom_routing_function,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
}
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
if correction_bias is not None:
|
activation: str = "silu",
|
||||||
if not has_correction_bias:
|
inplace: bool = True,
|
||||||
raise ValueError(
|
no_combine: bool = False,
|
||||||
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
):
|
||||||
)
|
assert activation == "silu"
|
||||||
kwargs["e_score_correction_bias"] = correction_bias
|
assert inplace and not no_combine
|
||||||
return original_apply(**kwargs)
|
|
||||||
|
|
||||||
setattr(class_obj, "apply", new_apply)
|
kwargs = {
|
||||||
|
"self": self,
|
||||||
|
"layer": layer,
|
||||||
|
"x": x,
|
||||||
|
"router_logits": router_logits,
|
||||||
|
"top_k": top_k,
|
||||||
|
"renormalize": renormalize,
|
||||||
|
"use_grouped_topk": use_grouped_topk,
|
||||||
|
"topk_group": topk_group,
|
||||||
|
"num_expert_group": num_expert_group,
|
||||||
|
"custom_routing_function": custom_routing_function,
|
||||||
|
}
|
||||||
|
if correction_bias is not None:
|
||||||
|
if not has_correction_bias:
|
||||||
|
raise ValueError(
|
||||||
|
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
|
||||||
|
)
|
||||||
|
kwargs["e_score_correction_bias"] = correction_bias
|
||||||
|
return original_apply(**kwargs)
|
||||||
|
|
||||||
|
setattr(class_obj, "apply", new_apply)
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_quant_configs():
|
def monkey_patch_quant_configs():
|
||||||
"""Apply all monkey patches in one place."""
|
"""Apply all monkey patches in one place."""
|
||||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
if not VLLM_AVAILABLE:
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
return
|
||||||
CompressedTensorsW8A8Fp8MoEMethod,
|
|
||||||
CompressedTensorsWNA16MoEMethod,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod
|
|
||||||
|
|
||||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
try:
|
||||||
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||||
|
CompressedTensorsW8A8Fp8MoEMethod,
|
||||||
|
CompressedTensorsWNA16MoEMethod,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||||
|
GPTQMarlinMoEMethod,
|
||||||
|
)
|
||||||
|
|
||||||
monkey_patch_moe_apply(AWQMoEMethod)
|
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||||
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
||||||
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
|
||||||
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
monkey_patch_moe_apply(AWQMoEMethod)
|
||||||
|
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
|
||||||
|
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
||||||
|
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
monkey_patch_quant_configs()
|
# Only apply monkey patches if vllm is available
|
||||||
|
if VLLM_AVAILABLE:
|
||||||
|
monkey_patch_quant_configs()
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
@@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
||||||
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|||||||
@@ -7,20 +7,33 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
apply_fp8_marlin_linear,
|
|
||||||
prepare_fp8_layer_for_marlin,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
|
is_layer_skipped,
|
||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
|
apply_fp8_marlin_linear,
|
||||||
|
prepare_fp8_layer_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
|
MARLIN_FP8_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
MARLIN_FP8_AVAILABLE = False
|
||||||
|
|
||||||
|
def apply_fp8_marlin_linear(*args, **kwargs):
|
||||||
|
raise ImportError("vllm is not installed")
|
||||||
|
|
||||||
|
def prepare_fp8_layer_for_marlin(*args, **kwargs):
|
||||||
|
raise ImportError("vllm is not installed")
|
||||||
|
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
LinearBase,
|
LinearBase,
|
||||||
@@ -46,6 +59,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
permute_weight,
|
permute_weight,
|
||||||
print_warning_once,
|
print_warning_once,
|
||||||
@@ -60,6 +74,13 @@ if _is_hip:
|
|||||||
from aiter.fused_moe_bf16_asm import asm_moe
|
from aiter.fused_moe_bf16_asm import asm_moe
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
|
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||||
|
else:
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -173,7 +194,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||||
# kernel for fast weight-only FP8 quantization
|
# kernel for fast weight-only FP8 quantization
|
||||||
self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
|
self.use_marlin = (
|
||||||
|
get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
|
||||||
|
)
|
||||||
# Disable marlin for ROCm
|
# Disable marlin for ROCm
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
self.use_marlin = False
|
self.use_marlin = False
|
||||||
@@ -371,9 +394,12 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
prepare_fp8_layer_for_marlin(layer)
|
try:
|
||||||
# Activations not quantized for marlin.
|
prepare_fp8_layer_for_marlin(layer)
|
||||||
del layer.input_scale
|
# Activations not quantized for marlin.
|
||||||
|
del layer.input_scale
|
||||||
|
except ImportError:
|
||||||
|
self.use_marlin = False
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -383,15 +409,18 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
return apply_fp8_marlin_linear(
|
try:
|
||||||
input=x,
|
return apply_fp8_marlin_linear(
|
||||||
weight=layer.weight,
|
input=x,
|
||||||
weight_scale=layer.weight_scale,
|
weight=layer.weight,
|
||||||
workspace=layer.workspace,
|
weight_scale=layer.weight_scale,
|
||||||
size_n=layer.output_size_per_partition,
|
workspace=layer.workspace,
|
||||||
size_k=layer.input_size_per_partition,
|
size_n=layer.output_size_per_partition,
|
||||||
bias=bias,
|
size_k=layer.input_size_per_partition,
|
||||||
)
|
bias=bias,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
self.use_marlin = False
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
return apply_w8a8_block_fp8_linear(
|
return apply_w8a8_block_fp8_linear(
|
||||||
@@ -680,12 +709,20 @@ class Fp8MoEMethod:
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
for expert in range(layer.num_experts):
|
for expert in range(layer.num_experts):
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
if _is_cuda:
|
||||||
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
)
|
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
)
|
||||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||||
)
|
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
|
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
|
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||||
|
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,12 @@ if _is_cuda:
|
|||||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
|
||||||
|
|
||||||
if use_vllm_cutlass_w8a8_fp8_kernel:
|
if use_vllm_cutlass_w8a8_fp8_kernel:
|
||||||
from vllm import _custom_ops as ops
|
try:
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
VLLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VLLM_AVAILABLE = False
|
||||||
else:
|
else:
|
||||||
from sgl_kernel import fp8_scaled_mm
|
from sgl_kernel import fp8_scaled_mm
|
||||||
|
|
||||||
@@ -219,90 +224,97 @@ def apply_fp8_linear(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cutlass_fp8_supported:
|
if cutlass_fp8_supported:
|
||||||
if use_vllm_cutlass_w8a8_fp8_kernel:
|
try:
|
||||||
# Fall back to vllm cutlass w8a8 fp8 kernel
|
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
||||||
output = ops.cutlass_scaled_mm(
|
# Fall back to vllm cutlass w8a8 fp8 kernel
|
||||||
qinput,
|
output = ops.cutlass_scaled_mm(
|
||||||
weight,
|
qinput,
|
||||||
out_dtype=input.dtype,
|
weight,
|
||||||
scale_a=x_scale,
|
out_dtype=input.dtype,
|
||||||
scale_b=weight_scale,
|
scale_a=x_scale,
|
||||||
bias=bias,
|
scale_b=weight_scale,
|
||||||
)
|
bias=bias,
|
||||||
else:
|
)
|
||||||
assert (
|
else:
|
||||||
weight_scale.numel() == weight.shape[1]
|
assert (
|
||||||
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
|
weight_scale.numel() == weight.shape[1]
|
||||||
output = fp8_scaled_mm(
|
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
|
||||||
qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
|
output = fp8_scaled_mm(
|
||||||
)
|
qinput,
|
||||||
return output.view(*output_shape)
|
weight,
|
||||||
|
x_scale,
|
||||||
|
weight_scale,
|
||||||
|
out_dtype=input.dtype,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
return output.view(*output_shape)
|
||||||
|
except (ImportError, NameError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
# torch.scaled_mm supports per tensor weights + activations only
|
# torch.scaled_mm supports per tensor weights + activations only
|
||||||
# so fallback to naive if per channel or per token
|
# so fallback to naive if per channel or per token
|
||||||
|
per_tensor_weights = weight_scale.numel() == 1
|
||||||
|
per_tensor_activations = x_scale.numel() == 1
|
||||||
|
|
||||||
|
if per_tensor_weights and per_tensor_activations:
|
||||||
|
# Fused GEMM_DQ
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
qinput,
|
||||||
|
weight,
|
||||||
|
out_dtype=input.dtype,
|
||||||
|
scale_a=x_scale,
|
||||||
|
scale_b=weight_scale,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
# A fix for discrepancy in scaled_mm which returns tuple
|
||||||
|
# for torch < 2.5 and a single value in torch >= 2.5
|
||||||
|
if type(output) is tuple and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
|
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
per_tensor_weights = weight_scale.numel() == 1
|
# Fallback for channelwise case, where we use unfused DQ
|
||||||
per_tensor_activations = x_scale.numel() == 1
|
# due to limitations with scaled_mm
|
||||||
|
|
||||||
if per_tensor_weights and per_tensor_activations:
|
# Symmetric quantized GEMM by definition computes the following:
|
||||||
# Fused GEMM_DQ
|
# C = (s_x * X) (s_w * W) + bias
|
||||||
output = torch._scaled_mm(
|
# This is equivalent to dequantizing the weights and activations
|
||||||
qinput,
|
# before applying a GEMM.
|
||||||
weight,
|
#
|
||||||
out_dtype=input.dtype,
|
# In order to compute quantized operands, a quantized kernel
|
||||||
scale_a=x_scale,
|
# will rewrite the above like so:
|
||||||
scale_b=weight_scale,
|
# C = s_w * s_x * (X * W) + bias
|
||||||
bias=bias,
|
#
|
||||||
)
|
# For the scaled_mm fallback case, we break this down, since it
|
||||||
# A fix for discrepancy in scaled_mm which returns tuple
|
# does not support s_w being a vector.
|
||||||
# for torch < 2.5 and a single value in torch >= 2.5
|
|
||||||
if type(output) is tuple and len(output) == 2:
|
|
||||||
output = output[0]
|
|
||||||
|
|
||||||
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
# Making sure the dummy tensor is on the same device as the weight
|
||||||
|
global TORCH_DEVICE_IDENTITY
|
||||||
|
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
||||||
|
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
||||||
|
|
||||||
else:
|
# GEMM
|
||||||
# Fallback for channelwise case, where we use unfused DQ
|
# This computes C = (X * W).
|
||||||
# due to limitations with scaled_mm
|
# Output in fp32 to allow subsequent ops to happen in-place
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
qinput,
|
||||||
|
weight,
|
||||||
|
scale_a=TORCH_DEVICE_IDENTITY,
|
||||||
|
scale_b=TORCH_DEVICE_IDENTITY,
|
||||||
|
out_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
# A fix for discrepancy in scaled_mm which returns tuple
|
||||||
|
# for torch < 2.5 and a single value in torch >= 2.5
|
||||||
|
if type(output) is tuple and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
# Unpad (undo num_token_padding)
|
||||||
|
output = torch.narrow(output, 0, 0, input_2d.shape[0])
|
||||||
|
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
|
||||||
|
|
||||||
# Symmetric quantized GEMM by definition computes the following:
|
# DQ
|
||||||
# C = (s_x * X) (s_w * W) + bias
|
# C = sw * sx * (X * W) + bias
|
||||||
# This is equivalent to dequantizing the weights and activations
|
output = output * x_scale * weight_scale.t()
|
||||||
# before applying a GEMM.
|
if bias is not None:
|
||||||
#
|
output = output + bias
|
||||||
# In order to compute quantized operands, a quantized kernel
|
return output.to(dtype=input.dtype).view(*output_shape)
|
||||||
# will rewrite the above like so:
|
|
||||||
# C = s_w * s_x * (X * W) + bias
|
|
||||||
#
|
|
||||||
# For the scaled_mm fallback case, we break this down, since it
|
|
||||||
# does not support s_w being a vector.
|
|
||||||
|
|
||||||
# Making sure the dummy tensor is on the same device as the weight
|
|
||||||
global TORCH_DEVICE_IDENTITY
|
|
||||||
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
|
||||||
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
|
||||||
|
|
||||||
# GEMM
|
|
||||||
# This computes C = (X * W).
|
|
||||||
# Output in fp32 to allow subsequent ops to happen in-place
|
|
||||||
output = torch._scaled_mm(
|
|
||||||
qinput,
|
|
||||||
weight,
|
|
||||||
scale_a=TORCH_DEVICE_IDENTITY,
|
|
||||||
scale_b=TORCH_DEVICE_IDENTITY,
|
|
||||||
out_dtype=torch.float32,
|
|
||||||
)
|
|
||||||
# A fix for discrepancy in scaled_mm which returns tuple
|
|
||||||
# for torch < 2.5 and a single value in torch >= 2.5
|
|
||||||
if type(output) is tuple and len(output) == 2:
|
|
||||||
output = output[0]
|
|
||||||
# Unpad (undo num_token_padding)
|
|
||||||
output = torch.narrow(output, 0, 0, input_2d.shape[0])
|
|
||||||
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
|
|
||||||
|
|
||||||
# DQ
|
|
||||||
# C = sw * sx * (X * W) + bias
|
|
||||||
output = output * x_scale * weight_scale.t()
|
|
||||||
if bias is not None:
|
|
||||||
output = output + bias
|
|
||||||
return output.to(dtype=input.dtype).view(*output_shape)
|
|
||||||
|
|||||||
@@ -3,11 +3,21 @@ from fractions import Fraction
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.scalar_type import scalar_types
|
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.quantization.utils import scalar_types
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
try:
|
||||||
|
import vllm
|
||||||
|
|
||||||
|
VLLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -110,6 +120,9 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["GPTQLinearMethod"]:
|
) -> Optional["GPTQLinearMethod"]:
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
raise ImportError("vllm is not installed")
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||||
|
|
||||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||||
@@ -263,6 +276,9 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
raise ImportError("vllm is not installed")
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||||
GPTQMarlinLinearMethod,
|
GPTQMarlinLinearMethod,
|
||||||
GPTQMarlinMoEMethod,
|
GPTQMarlinMoEMethod,
|
||||||
@@ -285,6 +301,9 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
return False
|
||||||
|
|
||||||
quant_method = quant_config.get("quant_method", "").lower()
|
quant_method = quant_config.get("quant_method", "").lower()
|
||||||
num_bits = quant_config.get("bits")
|
num_bits = quant_config.get("bits")
|
||||||
group_size = quant_config.get("group_size")
|
group_size = quant_config.get("group_size")
|
||||||
@@ -294,9 +313,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
check_marlin_supported,
|
check_marlin_supported,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
if not current_platform.is_cuda():
|
if not _is_cuda:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if quant_method != "gptq":
|
if quant_method != "gptq":
|
||||||
@@ -407,6 +425,9 @@ class MarlinConfig(QuantizationConfig):
|
|||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["MarlinLinearMethod"]:
|
) -> Optional["MarlinLinearMethod"]:
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
raise ImportError("vllm is not installed")
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||||
|
|
||||||
if isinstance(layer, LinearBase) or (
|
if isinstance(layer, LinearBase) or (
|
||||||
|
|||||||
98
python/sglang/srt/layers/quantization/kv_cache.py
Normal file
98
python/sglang/srt/layers/quantization/kv_cache.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kv_cache.py
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||||
|
"""
|
||||||
|
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
||||||
|
Attention layer to support loading those scaling factors from checkpoints.
|
||||||
|
The k/v_scale will be used to:
|
||||||
|
- quantize k/v_cache entries before saving them to the cache
|
||||||
|
- dequantize k/v_cache entries before fetching them from the cache
|
||||||
|
|
||||||
|
:param quant_config: the appropriate QuantizationConfig
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config: QuantizationConfig):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
def create_weights(self, layer: torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Create "weight" (aka k_scale and v_scale) for an attention layer.
|
||||||
|
"""
|
||||||
|
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
||||||
|
# If the k/v_scale appears in the checkpoint, it will be
|
||||||
|
# overwritten when loading weights.
|
||||||
|
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||||
|
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_fp8_fnuz(cls) -> bool:
|
||||||
|
# only device 0 is checked, this assumes MI300 platforms are homogeneous
|
||||||
|
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
||||||
|
|
||||||
|
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||||
|
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||||
|
# regardless whether the kv-scale is available in the checkpoint.
|
||||||
|
# No need to process kv scales after loading if we are going to
|
||||||
|
# calculate them on the fly.
|
||||||
|
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
|
||||||
|
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||||
|
# We prefer to use separate k_scale and v_scale if present
|
||||||
|
k_scale = layer.k_scale.to("cpu").tolist()
|
||||||
|
v_scale = layer.v_scale.to("cpu").tolist()
|
||||||
|
if _is_hip and self.is_fp8_fnuz():
|
||||||
|
k_scale *= 2
|
||||||
|
v_scale *= 2
|
||||||
|
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||||
|
# If no scales were loaded (both scales are invalid negative
|
||||||
|
# values), use the default value of 1.0
|
||||||
|
k_scale = 1.0
|
||||||
|
v_scale = 1.0
|
||||||
|
else:
|
||||||
|
# If we find a single kv_scale in the checkpoint, we remap
|
||||||
|
# kv_scale to k_scale during weight loading, and duplicate
|
||||||
|
# k_scale to v_scale here
|
||||||
|
assert layer.k_scale > 0.0
|
||||||
|
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||||
|
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
|
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
|
if _is_hip and self.is_fp8_fnuz():
|
||||||
|
k_scale *= 2
|
||||||
|
v_scale *= 2
|
||||||
|
|
||||||
|
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
|
||||||
|
raise ValueError(
|
||||||
|
"Only support per-tensor scaling factor " "for fp8 KV cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are used in the final Attention.forward()
|
||||||
|
layer._k_scale.copy_(k_scale)
|
||||||
|
layer._v_scale.copy_(v_scale)
|
||||||
|
layer._k_scale_float = k_scale
|
||||||
|
layer._v_scale_float = v_scale
|
||||||
|
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
|
||||||
|
logger.warning(
|
||||||
|
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
||||||
|
"may cause accuracy issues. Please make sure k/v_scale "
|
||||||
|
"scaling factors are available in the fp8 checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
|
del layer.k_scale
|
||||||
|
del layer.v_scale
|
||||||
@@ -5,12 +5,6 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
convert_to_channelwise,
|
|
||||||
cutlass_fp8_supported,
|
|
||||||
requantize_with_max_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||||
@@ -19,7 +13,15 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
|
apply_fp8_linear,
|
||||||
|
cutlass_fp8_supported,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from sglang.srt.layers.quantization.utils import (
|
||||||
|
convert_to_channelwise,
|
||||||
|
requantize_with_max_scale,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize logger for the module
|
# Initialize logger for the module
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
442
python/sglang/srt/layers/quantization/utils.py
Normal file
442
python/sglang/srt/layers/quantization/utils.py
Normal file
@@ -0,0 +1,442 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/scalar_type.py
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import struct
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from types import MappingProxyType
|
||||||
|
from typing import List, Mapping, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def is_layer_skipped(
|
||||||
|
prefix: str,
|
||||||
|
ignored_layers: List[str],
|
||||||
|
fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
|
||||||
|
) -> bool:
|
||||||
|
# prefix: model.layers.0.self_attn.q_proj
|
||||||
|
# proj_name: q_proj
|
||||||
|
proj_name = prefix.split(".")[-1]
|
||||||
|
|
||||||
|
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||||
|
# in the safetensors checkpoint. So, we convert the name
|
||||||
|
# from the fused version to unfused + check to make sure that
|
||||||
|
# each shard of the fused layer has the same scheme.
|
||||||
|
if proj_name in fused_mapping:
|
||||||
|
shard_prefixes = [
|
||||||
|
prefix.replace(proj_name, shard_proj_name)
|
||||||
|
for shard_proj_name in fused_mapping[proj_name]
|
||||||
|
]
|
||||||
|
|
||||||
|
is_skipped = None
|
||||||
|
for shard_prefix in shard_prefixes:
|
||||||
|
is_shard_skipped = shard_prefix in ignored_layers
|
||||||
|
|
||||||
|
if is_skipped is None:
|
||||||
|
is_skipped = is_shard_skipped
|
||||||
|
elif is_shard_skipped != is_skipped:
|
||||||
|
raise ValueError(
|
||||||
|
f"Detected some but not all shards of {prefix} "
|
||||||
|
"are quantized. All shards of fused layers "
|
||||||
|
"to have the same precision."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
is_skipped = prefix in ignored_layers
|
||||||
|
|
||||||
|
assert is_skipped is not None
|
||||||
|
return is_skipped
|
||||||
|
|
||||||
|
|
||||||
|
def per_tensor_dequantize(
|
||||||
|
tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
fake_qweight = tensor.to(torch.float16)
|
||||||
|
dq_weight = fake_qweight * inv_scale
|
||||||
|
return dq_weight
|
||||||
|
|
||||||
|
|
||||||
|
def all_close_1d(x: torch.Tensor) -> bool:
|
||||||
|
assert len(x.shape) == 1
|
||||||
|
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_channelwise(
|
||||||
|
weight_scale: torch.Tensor, logical_widths: List[int]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Create channelwise buffer
|
||||||
|
weight_scale_channel = torch.empty(
|
||||||
|
(sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expand each scale to match the size of each logical matrix.
|
||||||
|
start = 0
|
||||||
|
for idx, logical_width in enumerate(logical_widths):
|
||||||
|
end = start + logical_width
|
||||||
|
weight_scale_channel[start:end, :] = weight_scale[idx]
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return weight_scale_channel
|
||||||
|
|
||||||
|
|
||||||
|
def requantize_with_max_scale(
|
||||||
|
weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Max scale to be used for requanitzation.
|
||||||
|
max_w_scale = weight_scale.max()
|
||||||
|
|
||||||
|
# QKV / MLP is fused in the on disk checkpoint if any of the
|
||||||
|
# weight scales are still set to the default since we initialize
|
||||||
|
# N weight scales for N shards but we only load 1 weight scale
|
||||||
|
# from disk in this case. Skip requantization in this case (since)
|
||||||
|
# we already are quantized with the single scale.
|
||||||
|
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||||
|
unfused_module_in_checkpoint = (
|
||||||
|
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
||||||
|
)
|
||||||
|
|
||||||
|
# If unfused checkpoint, need requanize with the single scale.
|
||||||
|
if unfused_module_in_checkpoint:
|
||||||
|
start = 0
|
||||||
|
for idx, logical_width in enumerate(logical_widths):
|
||||||
|
end = start + logical_width
|
||||||
|
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
|
||||||
|
weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale)
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return max_w_scale, weight
|
||||||
|
|
||||||
|
|
||||||
|
# Mirrors enum in `core/scalar_type.hpp`
|
||||||
|
class NanRepr(Enum):
|
||||||
|
NONE = 0 # nans are not supported
|
||||||
|
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
||||||
|
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
||||||
|
|
||||||
|
|
||||||
|
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
||||||
|
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
|
||||||
|
# in sync until the inductor fully supports custom C++ classes.
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ScalarType:
|
||||||
|
"""
|
||||||
|
ScalarType can represent a wide range of floating point and integer
|
||||||
|
types, in particular it can be used to represent sub-byte data types
|
||||||
|
(something that torch.dtype currently does not support). It is also
|
||||||
|
capable of representing types with a bias, i.e.:
|
||||||
|
`stored_value = value + bias`,
|
||||||
|
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
|
||||||
|
of 8). The implementation for this class can be found in
|
||||||
|
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
|
||||||
|
with that file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
exponent: int
|
||||||
|
"""
|
||||||
|
Number of bits in the exponent if this is a floating point type
|
||||||
|
(zero if this an integer type)
|
||||||
|
"""
|
||||||
|
|
||||||
|
mantissa: int
|
||||||
|
"""
|
||||||
|
Number of bits in the mantissa if this is a floating point type,
|
||||||
|
or the number bits representing an integer excluding the sign bit if
|
||||||
|
this an integer type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
signed: bool
|
||||||
|
"If the type is signed (i.e. has a sign bit)"
|
||||||
|
|
||||||
|
bias: int
|
||||||
|
"""
|
||||||
|
bias used to encode the values in this scalar type
|
||||||
|
(value = stored_value - bias, default 0) for example if we store the
|
||||||
|
type as an unsigned integer with a bias of 128 then the value 0 will be
|
||||||
|
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_finite_values_only: bool = False
|
||||||
|
"""
|
||||||
|
Private: if infs are supported, used `has_infs()` instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nan_repr: NanRepr = NanRepr.IEEE_754
|
||||||
|
"""
|
||||||
|
How NaNs are represent in this scalar type, returns NanRepr value.
|
||||||
|
(not applicable for integer types)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _floating_point_max_int(self) -> int:
|
||||||
|
assert (
|
||||||
|
self.mantissa <= 52 and self.exponent <= 11
|
||||||
|
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||||
|
|
||||||
|
max_mantissa = (1 << self.mantissa) - 1
|
||||||
|
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
|
||||||
|
max_mantissa = max_mantissa - 1
|
||||||
|
|
||||||
|
max_exponent = (1 << self.exponent) - 2
|
||||||
|
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
|
||||||
|
assert (
|
||||||
|
self.exponent < 11
|
||||||
|
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||||
|
max_exponent = max_exponent + 1
|
||||||
|
|
||||||
|
# adjust the exponent to match that of a double
|
||||||
|
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
|
||||||
|
# e is the exponent bits), there is some precedent for non-standard
|
||||||
|
# biases, example `float8_e4m3b11fnuz` here:
|
||||||
|
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
|
||||||
|
# complication we are just assuming the standard exponent bias until
|
||||||
|
# there is a need to support non-standard biases
|
||||||
|
exponent_bias = (1 << (self.exponent - 1)) - 1
|
||||||
|
exponent_bias_double = (1 << 10) - 1 # double e = 11
|
||||||
|
|
||||||
|
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
|
||||||
|
|
||||||
|
# shift the mantissa and exponent into the proper positions for an
|
||||||
|
# IEEE double and bitwise-or them together.
|
||||||
|
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
|
||||||
|
|
||||||
|
def _floating_point_max(self) -> float:
|
||||||
|
double_raw = self._floating_point_max_int()
|
||||||
|
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
|
||||||
|
|
||||||
|
def _raw_max(self) -> Union[int, float]:
|
||||||
|
if self.is_floating_point():
|
||||||
|
return self._floating_point_max()
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
|
||||||
|
), "Cannot represent max as an int"
|
||||||
|
return (1 << self.mantissa) - 1
|
||||||
|
|
||||||
|
def _raw_min(self) -> Union[int, float]:
|
||||||
|
if self.is_floating_point():
|
||||||
|
assert (
|
||||||
|
self.is_signed()
|
||||||
|
), "We currently assume all floating point types are signed"
|
||||||
|
sign_bit_double = 1 << 63
|
||||||
|
|
||||||
|
max_raw = self._floating_point_max_int()
|
||||||
|
min_raw = max_raw | sign_bit_double
|
||||||
|
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
not self.is_signed() or self.size_bits <= 64
|
||||||
|
), "Cannot represent min as a int64_t"
|
||||||
|
|
||||||
|
if self.is_signed():
|
||||||
|
return -(1 << (self.size_bits - 1))
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def id(self) -> int:
|
||||||
|
"""
|
||||||
|
Convert the ScalarType to an int which can be passed to pytorch custom
|
||||||
|
ops. This layout of the int must be kept in sync with the C++
|
||||||
|
ScalarType's from_id method.
|
||||||
|
"""
|
||||||
|
val = 0
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
def or_and_advance(member, bit_width):
|
||||||
|
nonlocal val
|
||||||
|
nonlocal offset
|
||||||
|
bit_mask = (1 << bit_width) - 1
|
||||||
|
val = val | (int(member) & bit_mask) << offset
|
||||||
|
offset = offset + bit_width
|
||||||
|
|
||||||
|
or_and_advance(self.exponent, 8)
|
||||||
|
or_and_advance(self.mantissa, 8)
|
||||||
|
or_and_advance(self.signed, 1)
|
||||||
|
or_and_advance(self.bias, 32)
|
||||||
|
or_and_advance(self._finite_values_only, 1)
|
||||||
|
or_and_advance(self.nan_repr.value, 8)
|
||||||
|
|
||||||
|
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size_bits(self) -> int:
|
||||||
|
return self.exponent + self.mantissa + int(self.signed)
|
||||||
|
|
||||||
|
def min(self) -> Union[int, float]:
|
||||||
|
"""
|
||||||
|
Min representable value for this scalar type.
|
||||||
|
(accounting for bias if there is one)
|
||||||
|
"""
|
||||||
|
return self._raw_min() - self.bias
|
||||||
|
|
||||||
|
def max(self) -> Union[int, float]:
|
||||||
|
"""
|
||||||
|
Max representable value for this scalar type.
|
||||||
|
(accounting for bias if there is one)
|
||||||
|
"""
|
||||||
|
return self._raw_max() - self.bias
|
||||||
|
|
||||||
|
def is_signed(self) -> bool:
|
||||||
|
"""
|
||||||
|
If the type is signed (i.e. has a sign bit), same as `signed`
|
||||||
|
added for consistency with:
|
||||||
|
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
|
||||||
|
"""
|
||||||
|
return self.signed
|
||||||
|
|
||||||
|
def is_floating_point(self) -> bool:
|
||||||
|
"If the type is a floating point type"
|
||||||
|
return self.exponent != 0
|
||||||
|
|
||||||
|
def is_integer(self) -> bool:
|
||||||
|
"If the type is an integer type"
|
||||||
|
return self.exponent == 0
|
||||||
|
|
||||||
|
def has_bias(self) -> bool:
|
||||||
|
"If the type has a non-zero bias"
|
||||||
|
return self.bias != 0
|
||||||
|
|
||||||
|
def has_infs(self) -> bool:
|
||||||
|
"If the type is floating point and supports infinity"
|
||||||
|
return not self._finite_values_only
|
||||||
|
|
||||||
|
def has_nans(self) -> bool:
|
||||||
|
return self.nan_repr != NanRepr.NONE.value
|
||||||
|
|
||||||
|
def is_ieee_754(self) -> bool:
|
||||||
|
"""
|
||||||
|
If the type is a floating point type that follows IEEE 754
|
||||||
|
conventions
|
||||||
|
"""
|
||||||
|
return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""
|
||||||
|
naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||||
|
for floating point types (leading f) the scheme is:
|
||||||
|
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||||
|
flags:
|
||||||
|
- no-flags: means it follows IEEE 754 conventions
|
||||||
|
- f: means finite values only (no infinities)
|
||||||
|
- n: means nans are supported (non-standard encoding)
|
||||||
|
for integer types the scheme is:
|
||||||
|
`[u]int<size_bits>[b<bias>]`
|
||||||
|
- if bias is not present it means its zero
|
||||||
|
"""
|
||||||
|
if self.is_floating_point():
|
||||||
|
ret = (
|
||||||
|
"float"
|
||||||
|
+ str(self.size_bits)
|
||||||
|
+ "_e"
|
||||||
|
+ str(self.exponent)
|
||||||
|
+ "m"
|
||||||
|
+ str(self.mantissa)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.is_ieee_754():
|
||||||
|
if self._finite_values_only:
|
||||||
|
ret = ret + "f"
|
||||||
|
if self.nan_repr != NanRepr.NONE:
|
||||||
|
ret = ret + "n"
|
||||||
|
|
||||||
|
return ret
|
||||||
|
else:
|
||||||
|
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
|
||||||
|
if self.has_bias():
|
||||||
|
ret = ret + "b" + str(self.bias)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "ScalarType." + self.__str__()
|
||||||
|
|
||||||
|
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
|
||||||
|
# opcheck to work.
|
||||||
|
def __len__(self) -> int:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
#
|
||||||
|
# Convenience Constructors
|
||||||
|
#
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
||||||
|
"Create a signed integer scalar type (size_bits includes sign-bit)."
|
||||||
|
ret = cls(0, size_bits - 1, True, bias if bias else 0)
|
||||||
|
ret.id # noqa B018: make sure the id is cached
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
||||||
|
"""Create a unsigned integer scalar type."""
|
||||||
|
ret = cls(0, size_bits, False, bias if bias else 0)
|
||||||
|
ret.id # noqa B018: make sure the id is cached
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
|
||||||
|
"""
|
||||||
|
Create a standard floating point type
|
||||||
|
(i.e. follows IEEE 754 conventions).
|
||||||
|
"""
|
||||||
|
assert mantissa > 0 and exponent > 0
|
||||||
|
ret = cls(exponent, mantissa, True, 0)
|
||||||
|
ret.id # noqa B018: make sure the id is cached
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def float_(
|
||||||
|
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
|
||||||
|
) -> "ScalarType":
|
||||||
|
"""
|
||||||
|
Create a non-standard floating point type
|
||||||
|
(i.e. does not follow IEEE 754 conventions).
|
||||||
|
"""
|
||||||
|
assert mantissa > 0 and exponent > 0
|
||||||
|
assert nan_repr != NanRepr.IEEE_754, (
|
||||||
|
"use `float_IEEE754` constructor for floating point types that "
|
||||||
|
"follow IEEE 754 conventions"
|
||||||
|
)
|
||||||
|
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
|
||||||
|
ret.id # noqa B018: make sure the id is cached
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||||
|
# for floating point types (leading f) the scheme is:
|
||||||
|
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||||
|
# flags:
|
||||||
|
# - no-flags: means it follows IEEE 754 conventions
|
||||||
|
# - f: means finite values only (no infinities)
|
||||||
|
# - n: means nans are supported (non-standard encoding)
|
||||||
|
# for integer types the scheme is:
|
||||||
|
# `[u]int<size_bits>[b<bias>]`
|
||||||
|
# - if bias is not present it means its zero
|
||||||
|
|
||||||
|
|
||||||
|
class scalar_types:
|
||||||
|
int4 = ScalarType.int_(4, None)
|
||||||
|
uint4 = ScalarType.uint(4, None)
|
||||||
|
int8 = ScalarType.int_(8, None)
|
||||||
|
uint8 = ScalarType.uint(8, None)
|
||||||
|
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
|
||||||
|
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
|
||||||
|
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
|
||||||
|
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
|
||||||
|
|
||||||
|
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
|
||||||
|
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
|
||||||
|
|
||||||
|
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||||
|
float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)
|
||||||
|
|
||||||
|
# "gptq" types
|
||||||
|
uint2b2 = ScalarType.uint(2, 2)
|
||||||
|
uint3b4 = ScalarType.uint(3, 4)
|
||||||
|
uint4b8 = ScalarType.uint(4, 8)
|
||||||
|
uint8b128 = ScalarType.uint(8, 128)
|
||||||
|
|
||||||
|
# colloquial names
|
||||||
|
bfloat16 = float16_e8m7
|
||||||
|
float16 = float16_e5m10
|
||||||
Reference in New Issue
Block a user