[XPU][CPU] Enable the native path of DeepSeek (#4086)
Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
This commit is contained in:
@@ -54,8 +54,18 @@ class QuantizationConfig(ABC):
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
This requirement is due to the custom kernels used by the
|
||||
quantization method or the stock pytorch capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_availability(cls) -> bool:
|
||||
"""Whether the quantization config is available on current device.
|
||||
|
||||
This requirement is due to the custom kernels used by the
|
||||
quantization method or the stock pytorch capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
from sglang.srt.utils import get_device_capability, set_weight_attrs
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
@@ -71,7 +72,20 @@ class BlockInt8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 80
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 80
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_fp8_matmul,
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
apply_fp8_linear,
|
||||
apply_w8a8_block_fp8_linear,
|
||||
@@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_device_capability,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
permute_weight,
|
||||
print_warning_once,
|
||||
@@ -55,6 +62,7 @@ from sglang.srt.utils import (
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_hip:
|
||||
from aiter.fused_moe_bf16_asm import asm_moe
|
||||
@@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 80
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
return 0
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 80
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
return True
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@@ -850,6 +875,52 @@ class Fp8MoEMethod:
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def torch_w8a8_block_fp8_moe(
|
||||
self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape
|
||||
):
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
|
||||
"""This function performs fused moe with block-wise quantization using native torch."""
|
||||
|
||||
B, D = a.shape
|
||||
topk = topk_ids.shape[-1]
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
|
||||
# NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``.
|
||||
a_q = a_q.to(torch.float32)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_fp8_matmul(
|
||||
a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype,
|
||||
)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(
|
||||
act_out, block_k
|
||||
)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_fp8_matmul(
|
||||
act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype,
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -923,7 +994,7 @@ class Fp8MoEMethod:
|
||||
layer.w13_weight_scale1,
|
||||
layer.w2_weight_scale1,
|
||||
)
|
||||
else:
|
||||
elif _is_cuda:
|
||||
# Expert fusion with FP8 quantization
|
||||
return fused_experts(
|
||||
x,
|
||||
@@ -950,6 +1021,24 @@ class Fp8MoEMethod:
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
# for CPU and other accelerators, fallback to native path
|
||||
return self.torch_w8a8_block_fp8_moe(
|
||||
a=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_s=(
|
||||
layer.w13_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w13_weight_scale
|
||||
),
|
||||
w2_s=(
|
||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||
),
|
||||
topk_weight=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
|
||||
@@ -28,10 +28,13 @@ from sglang.srt.utils import (
|
||||
get_device_name,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_triton_available,
|
||||
supports_custom_op,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_triton = is_triton_available()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
tl.store(y_s_ptr, y_s)
|
||||
|
||||
|
||||
def native_per_token_group_quant_fp8(
|
||||
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
|
||||
):
|
||||
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
|
||||
|
||||
It converts the tensor values into float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
Note that only `torch.float8_e4m3fn` is supported for now.
|
||||
"""
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
@@ -232,19 +263,22 @@ def per_token_group_quant_fp8(
|
||||
num_stages=num_stages,
|
||||
)
|
||||
else:
|
||||
_per_token_group_quant_fp8[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
N,
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
if _is_triton:
|
||||
_per_token_group_quant_fp8[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
N,
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
else:
|
||||
x_q, x_s = native_per_token_group_quant_fp8(x, group_size)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
@@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs(
|
||||
return None
|
||||
|
||||
|
||||
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
||||
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
||||
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
"""
|
||||
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1]
|
||||
|
||||
M = A.numel() // A.shape[-1]
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (N,)
|
||||
A = A.reshape(M, A.shape[-1])
|
||||
As = As.reshape(M, As.shape[-1])
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
assert n_tiles == Bs.shape[0]
|
||||
assert k_tiles == Bs.shape[1]
|
||||
|
||||
C_shape = (M, N)
|
||||
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
||||
|
||||
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
|
||||
B_tiles = [
|
||||
[
|
||||
B[
|
||||
j * block_n : min((j + 1) * block_n, N),
|
||||
i * block_k : min((i + 1) * block_k, K),
|
||||
]
|
||||
for i in range(k_tiles)
|
||||
]
|
||||
for j in range(n_tiles)
|
||||
]
|
||||
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
|
||||
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
|
||||
|
||||
for i in range(k_tiles):
|
||||
for j in range(n_tiles):
|
||||
a = A_tiles[i]
|
||||
b = B_tiles[j][i]
|
||||
c = C_tiles[j]
|
||||
s = As_tiles[i] * Bs[j][i]
|
||||
c[:, :] += torch.matmul(a, b.t()) * s
|
||||
|
||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||
return C
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@@ -715,86 +804,90 @@ def w8a8_block_fp8_matmul(
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
"""
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
if _is_triton: # pragma: no cover
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
# optimal config
|
||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||
else:
|
||||
# Default config
|
||||
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_size[0],
|
||||
"BLOCK_SIZE_K": block_size[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
}
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
||||
# Empirical testing shows the sweet spot lies when it's less than the # of
|
||||
# compute units available on the device.
|
||||
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||
N, config["BLOCK_SIZE_N"]
|
||||
)
|
||||
|
||||
# deepgemm only support bf16
|
||||
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
||||
if supports_custom_op():
|
||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
# optimal config
|
||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||
else:
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
else:
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
||||
else _w8a8_block_fp8_matmul
|
||||
# Default config
|
||||
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_size[0],
|
||||
"BLOCK_SIZE_K": block_size[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
}
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
||||
# Empirical testing shows the sweet spot lies when it's less than the # of
|
||||
# compute units available on the device.
|
||||
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||
N, config["BLOCK_SIZE_N"]
|
||||
)
|
||||
|
||||
kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
**config,
|
||||
)
|
||||
# deepgemm only support bf16
|
||||
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
||||
if supports_custom_op():
|
||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||
else:
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||
else:
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
||||
else _w8a8_block_fp8_matmul
|
||||
)
|
||||
|
||||
kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
**config,
|
||||
)
|
||||
else:
|
||||
C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
|
||||
|
||||
return C
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import sys
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,7 +92,20 @@ class GPTQConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 60
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 60
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@@ -209,7 +224,20 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 80
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 80
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@@ -371,7 +399,20 @@ class MarlinConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 80
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 80
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
# Initialize logger for the module
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 89
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 89
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
|
||||
@@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import get_device_capability, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -35,7 +35,20 @@ class W8A8Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 89
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> bool:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 89
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
|
||||
class W8A8Int8Config(QuantizationConfig):
|
||||
@@ -36,7 +38,20 @@ class W8A8Int8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return 75
|
||||
|
||||
# Vendors can update
|
||||
return sys.maxsize
|
||||
|
||||
@classmethod
|
||||
def get_availability(cls) -> int:
|
||||
major, minor = get_device_capability()
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return major * 10 + minor > 75
|
||||
|
||||
# Vendors can update
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user