[1/n] apply wna16 marlin kernel in moe weight-only quantization (#7683)
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: yych0745 <1398089567@qq.com>
Co-authored-by: HandH1998 <1335248067@qq.com>
Co-authored-by: 弋云 <yiyun.wyt@antgroup.com>
Co-authored-by: walker-ai <2398833647@qq.com>
@@ -29,6 +29,7 @@ from sgl_kernel.elementwise import (
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
@@ -55,6 +56,11 @@ from sgl_kernel.kvcacheio import (
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
from sgl_kernel.marlin import (
    awq_marlin_moe_repack,
    awq_marlin_repack,
    gptq_marlin_repack,
)
from sgl_kernel.moe import (
    apply_shuffle_mul_sum,
    cutlass_fp4_group_mm,
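With these re-exports (assuming, as the surrounding imports suggest, that this hunk edits sgl-kernel/python/sgl_kernel/__init__.py), the new kernels become importable from the package root:

from sgl_kernel import (
    awq_marlin_moe_repack,
    awq_marlin_repack,
    fused_marlin_moe,
    gptq_marlin_repack,
)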
sgl-kernel/python/sgl_kernel/fused_moe.py (new file, 223 lines)
@@ -0,0 +1,223 @@
import functools
from typing import Optional

import torch
from sgl_kernel.scalar_type import scalar_types


def get_scalar_type(num_bits: int, has_zp: bool):
    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4
    else:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
    inplace: bool = False,
) -> torch.Tensor:
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- w1_scale (torch.Tensor): Scale to be used for w1.
|
||||
- w2_scale (torch.Tensor): Scale to be used for w2.
|
||||
- gating_output (torch.Tensor): The output of the gating operation
|
||||
(before softmax).
|
||||
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
|
||||
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
|
||||
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
|
||||
permutation.
|
||||
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
|
||||
permutation.
|
||||
- topk_weights (torch.Tensor): Top-k weights.
|
||||
- topk_ids (torch.Tensor): Indices of topk-k elements.
|
||||
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
|
||||
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
|
||||
- num_bits (bool): The number of bits in expert weights quantization.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
    # Delay the import to avoid circular dependency
    from sglang.srt.layers.moe.fused_moe_triton import (
        moe_align_block_size,
        try_get_optimal_moe_config,
    )

    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2
    ), "Hidden size mismatch w2"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert num_bits in [4, 8]

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

    if global_num_experts == -1:
        global_num_experts = E
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, block_size_m, global_num_experts
    )

    if workspace is None:
        max_workspace_size = (max(2 * N, K) // 64) * (
            sorted_token_ids.size(0) // block_size_m
        )
        device = hidden_states.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        max_workspace_size = min(max_workspace_size, sms * 4)
        workspace = torch.zeros(
            max_workspace_size, dtype=torch.int, device=device, requires_grad=False
        )

    scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
    scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)

    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache13 = torch.empty(
        (M * topk_ids.shape[1] * max(2 * N, K),),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
    intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
    intermediate_cache3 = intermediate_cache3.view(-1, K)

    use_atomic_add = (
        hidden_states.dtype == torch.half
        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
    )

    # First grouped GEMM: hidden_states @ w1 -> fused gate/up activations (2 * N wide).
    intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        hidden_states,
        intermediate_cache1,
        w1,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=topk,
        mul_topk_weights=False,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type1.id,
        size_m=M,
        size_n=2 * N,
        size_k=K,
        is_full_k=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

    torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    if expert_map is not None:
        intermediate_cache3.zero_()

    # Second grouped GEMM: activations @ w2, applying the top-k weights inside the kernel.
    intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        intermediate_cache2,
        intermediate_cache3,
        w2,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=1,
        mul_topk_weights=True,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type2.id,
        size_m=M * topk,
        size_n=K,
        size_k=N,
        is_full_k=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    ).view(-1, topk, K)

    output = hidden_states if inplace else torch.empty_like(hidden_states)
    return torch.sum(
        intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output
    )


def fused_marlin_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)
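A minimal shape-level usage sketch for fused_marlin_moe, assuming 4-bit expert weights that have already been repacked into the Marlin layout (e.g. via gptq_marlin_repack / awq_marlin_moe_repack) and group-128 scales; the randomly filled tensors only stand in for real repacked weights and scales, so they illustrate the expected shapes and dtypes rather than a meaningful model:

import torch

from sgl_kernel import fused_marlin_moe

E, M, K, N = 8, 4, 4096, 14336  # experts, tokens, hidden size, intermediate size
num_bits, group_size = 4, 128
pack_factor = 32 // num_bits  # one int32 word holds eight 4-bit values

x = torch.randn(M, K, dtype=torch.float16, device="cuda")
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")
topk_weights, topk_ids = torch.topk(
    torch.softmax(gating_output.float(), dim=-1), k=2, dim=-1
)

# Marlin-packed expert weights: w1 is the fused gate/up projection (K -> 2N), w2 maps N -> K.
w1 = torch.randint(
    0, 2**31 - 1, (E, K // 16, 2 * N * 16 // pack_factor), dtype=torch.int32, device="cuda"
)
w2 = torch.randint(
    0, 2**31 - 1, (E, N // 16, K * 16 // pack_factor), dtype=torch.int32, device="cuda"
)
w1_scale = torch.rand(E, K // group_size, 2 * N, dtype=torch.float16, device="cuda")
w2_scale = torch.rand(E, N // group_size, K, dtype=torch.float16, device="cuda")

out = fused_marlin_moe(
    x,
    w1,
    w2,
    w1_scale,
    w2_scale,
    gating_output=gating_output,
    topk_weights=topk_weights,
    topk_ids=topk_ids.to(torch.int32),
    num_bits=num_bits,
)
print(out.shape)  # torch.Size([4, 4096]), same shape as the input hidden_states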
sgl-kernel/python/sgl_kernel/marlin.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import torch


def gptq_marlin_repack(
    b_q_weight,
    perm,
    size_k,
    size_n,
    num_bits,
):
    return torch.ops.sgl_kernel.gptq_marlin_repack.default(
        b_q_weight,
        perm,
        size_k,
        size_n,
        num_bits,
    )


def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
            b_q_weight[e], size_k, size_n, num_bits
        )
    return output
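A minimal repack sketch for awq_marlin_moe_repack, assuming 4-bit AWQ expert weights in the usual (num_experts, size_k, size_n // pack_factor) int32 layout; the empty per-expert perm tensor stands in for the unused act_order permutation (this wrapper does not read it), and the random data is only a placeholder:

import torch

from sgl_kernel import awq_marlin_moe_repack

E, K, N, num_bits = 8, 4096, 14336, 4
pack_factor = 32 // num_bits  # eight 4-bit values per int32

qweight = torch.randint(
    0, 2**31 - 1, (E, K, N // pack_factor), dtype=torch.int32, device="cuda"
)
perm = torch.empty((E, 0), dtype=torch.int32, device="cuda")

marlin_qweight = awq_marlin_moe_repack(qweight, perm, size_k=K, size_n=N, num_bits=num_bits)
print(marlin_qweight.shape)  # torch.Size([8, 256, 28672]) == (E, K // 16, N * (num_bits // 2))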
sgl-kernel/python/sgl_kernel/scalar_type.py (new file, 352 lines)
@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

_SCALAR_TYPES_ID_MAP = {}


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp; these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this is an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number of bits representing an integer excluding the sign bit if
    this is an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, use `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represented in this scalar type; returns a NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        double_raw = self._floating_point_max_int()
        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (
                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
            ), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        if self.is_floating_point():
            assert (
                self.is_signed()
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
        else:
            assert (
                not self.is_signed() or self.size_bits <= 64
            ), "Cannot represent min as an int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"

        _SCALAR_TYPES_ID_MAP[val] = self

        return val

    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
          `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means it is zero
        """
        if self.is_floating_point():
            ret = (
                "float"
                + str(self.size_bits)
                + "_e"
                + str(self.exponent)
                + "m"
                + str(self.mantissa)
            )

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        """Create an unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(
        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
    ) -> "ScalarType":
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        assert nan_repr != NanRepr.IEEE_754, (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions"
        )
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def from_id(cls, scalar_type_id: int):
        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
            raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exist.")
        return _SCALAR_TYPES_ID_MAP[scalar_type_id]


# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
#   `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
#   - no-flags: means it follows IEEE 754 conventions
#   - f: means finite values only (no infinities)
#   - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
#   `[u]int<size_bits>[b<bias>]`
#   - if bias is not present it means it is zero


class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
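A small illustration of the biased integer types defined above: the GPTQ-style uint4b8 stores values with a bias of 8, so its logical range is [-8, 7] even though the raw storage is an unsigned 4-bit value, and from_id recovers the cached instance from the packed integer id used at the custom-op boundary:

from sgl_kernel.scalar_type import ScalarType, scalar_types

t = scalar_types.uint4b8
assert t.size_bits == 4 and t.bias == 8
assert (t.min(), t.max()) == (-8, 7)
assert str(t) == "uint4b8"

# `id` packs the fields into a single integer for the kernel call;
# `from_id` returns the cached Python-side instance.
assert ScalarType.from_id(t.id) is t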