[XPU][CPU] Enable the native path of DeepSeek (#4086)
Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
@@ -57,7 +57,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
srt_xpu = ["sglang[runtime_common]", "vllm>=0.6.4.post1,<=0.7.2", "outlines>=0.0.44,<=0.1.11"]

# For Intel Gaudi(device : hpu) follow the installation guide
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html

@@ -54,8 +54,18 @@ class QuantizationConfig(ABC):
"""Minimum GPU capability to support the quantization method.

E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
This requirement is due to the custom kernels used by the
quantization method or the capability of stock PyTorch.
"""
raise NotImplementedError

@classmethod
@abstractmethod
def get_availability(cls) -> bool:
"""Whether the quantization config is available on the current device.

This requirement is due to the custom kernels used by the
quantization method or the capability of stock PyTorch.
"""
raise NotImplementedError
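
For context, the new interface pairs the existing CUDA capability gate with a boolean availability check, so a backend-specific config can opt in without a compute-capability number. A minimal sketch of how a subclass might fill in the two classmethods (the class name and the threshold of 75 below are hypothetical, not part of this commit):

import sys

import torch


class HypotheticalXpuQuantConfig:
    # Illustrative only: mirrors the new QuantizationConfig contract.
    @classmethod
    def get_min_capability(cls) -> int:
        # CUDA keeps a numeric gate; other vendors effectively disable it.
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            return 75
        return sys.maxsize

    @classmethod
    def get_availability(cls) -> bool:
        # Non-CUDA devices answer a plain yes/no instead.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return True
        return False
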
@@ -1,6 +1,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

import logging
import sys
from typing import Any, Callable, Dict, List, Optional

import torch

@@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import get_device_capability, set_weight_attrs

ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -71,7 +72,20 @@ class BlockInt8Config(QuantizationConfig):

@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 80

# Vendors can update
return False

@classmethod
def get_config_filenames(cls) -> List[str]:

@@ -1,6 +1,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

import logging
import sys
from typing import Any, Callable, Dict, List, Optional

import torch

@@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
native_per_token_group_quant_fp8,
native_w8a8_block_fp8_matmul,
per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,

@@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
is_cuda,
is_hip,
permute_weight,
print_warning_once,

@@ -55,6 +62,7 @@ from sglang.srt.utils import (
ACTIVATION_SCHEMES = ["static", "dynamic"]

_is_hip = is_hip()
_is_cuda = is_cuda()

if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe
@@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig):

@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
if hasattr(torch, "xpu") and torch.xpu.is_available():
return 0

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 80
if hasattr(torch, "xpu") and torch.xpu.is_available():
return True

# Vendors can update
return False

@classmethod
def get_config_filenames(cls) -> List[str]:

@@ -850,6 +875,52 @@ class Fp8MoEMethod:
)
torch.cuda.empty_cache()

def torch_w8a8_block_fp8_moe(
self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape
):
from sglang.srt.layers.activation import SiluAndMul

"""This function performs fused moe with block-wise quantization using native torch."""

B, D = a.shape
topk = topk_ids.shape[-1]
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)

_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
# NOTE(HandH1998): Since `index_cuda` is not implemented for 'Float8_e4m3fn', we need to cast `float8` to `float32`.
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_fp8_matmul(
a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype,
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k
)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_fp8_matmul(
act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype,
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)

def apply(
self,
layer: torch.nn.Module,

@@ -923,7 +994,7 @@ class Fp8MoEMethod:
layer.w13_weight_scale1,
layer.w2_weight_scale1,
)
else:
elif _is_cuda:
# Expert fusion with FP8 quantization
return fused_experts(
x,

@@ -950,6 +1021,24 @@ class Fp8MoEMethod:
no_combine=no_combine,
)

# For CPU and other accelerators, fall back to the native path
return self.torch_w8a8_block_fp8_moe(
a=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_s=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_s=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
topk_weight=topk_weights,
topk_ids=topk_ids,
block_shape=self.quant_config.weight_block_size,
)
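
In short, the dispatch in Fp8MoEMethod.apply now has three tiers: the aiter asm path on ROCm, the fused Triton experts on CUDA, and the pure-torch reference above for everything else (CPU, XPU, other accelerators). A condensed sketch of that control flow with simplified names (not the literal code paths):

def select_moe_backend(hip: bool, cuda: bool) -> str:
    # Mirrors the branch order introduced here: HIP first, then CUDA,
    # otherwise the native torch fallback added in this commit.
    if hip:
        return "asm_moe"
    if cuda:
        return "fused_experts"
    return "torch_w8a8_block_fp8_moe"


assert select_moe_backend(False, False) == "torch_w8a8_block_fp8_moe"
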


class Fp8KVCacheMethod(BaseKVCacheMethod):
"""

@@ -28,10 +28,13 @@ from sglang.srt.utils import (
get_device_name,
is_cuda,
is_hip,
is_triton_available,
supports_custom_op,
)

_is_hip = is_hip()
_is_cuda = is_cuda()
_is_triton = is_triton_available()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

_is_cuda = is_cuda()

@@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor(
tl.store(y_s_ptr, y_s)


def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.

It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` must be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"

finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max

x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))

return x_q, x_s
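
A quick usage sketch for the helper above; the shapes and the tolerance are arbitrary choices for illustration, and the import path assumes the function lands in fp8_kernel.py as shown:

import torch

from sglang.srt.layers.quantization.fp8_kernel import native_per_token_group_quant_fp8

x = torch.randn(4, 256, dtype=torch.float32)  # 4 tokens, hidden size 256
x_q, x_s = native_per_token_group_quant_fp8(x, group_size=128)
print(x_q.shape, x_q.dtype)  # torch.Size([4, 256]) torch.float8_e4m3fn
print(x_s.shape)  # torch.Size([4, 2]), one scale per 128-wide group
# Coarse round trip: dequantize and compare against the fp32 input.
dequant = x_q.to(torch.float32) * x_s.repeat_interleave(128, dim=-1)
print(torch.allclose(x, dequant, atol=0.5))
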


def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,

@@ -232,19 +263,22 @@ def per_token_group_quant_fp8(
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
if _is_triton:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
x_q, x_s = native_per_token_group_quant_fp8(x, group_size)

return x_q, x_s

@@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs(
return None


def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.

It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""

A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]

M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]

C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]

for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s

C = C.reshape(origin_C_shape).to(output_dtype)
return C
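
As a sanity-check sketch of the reference matmul above (weight scales set to one purely to keep the example short; not how real checkpoints are quantized):

import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_fp8_matmul,
)

M, K, N, blk = 4, 256, 512, 128
A = torch.randn(M, K)
B = torch.randn(N, K)

A_q, A_s = native_per_token_group_quant_fp8(A, blk)  # per-token-group activation scales
B_q = B.to(torch.float8_e4m3fn)  # crude weight cast, demo only
B_s = torch.ones(N // blk, K // blk)  # unit block scales, demo only
C = native_w8a8_block_fp8_matmul(A_q, B_q, A_s, B_s, [blk, blk], output_dtype=torch.float32)
print(C.shape)  # torch.Size([4, 512]), approximately A @ B.T
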


def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,

@@ -715,86 +804,90 @@ def w8a8_block_fp8_matmul(
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
if _is_triton: # pragma: no cover
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]

assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]

assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]

C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)

configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}

def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)

# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)

# deepgemm only supports bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}

def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
)

# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)

kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
# deepgemm only supports bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)

kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
else:
C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)

return C

@@ -1,4 +1,5 @@
import logging
import sys
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union

@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability

logger = logging.getLogger(__name__)

@@ -90,7 +92,20 @@ class GPTQConfig(QuantizationConfig):
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 60

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 60

# Vendors can update
return False

@classmethod
def get_config_filenames(cls) -> List[str]:

@@ -209,7 +224,20 @@ class GPTQMarlinConfig(QuantizationConfig):

@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 80

# Vendors can update
return False

@classmethod
def get_config_filenames(cls) -> List[str]:

@@ -371,7 +399,20 @@ class MarlinConfig(QuantizationConfig):
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 80

# Vendors can update
return False

@classmethod
def get_config_filenames(cls) -> List[str]:

@@ -1,6 +1,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py

import logging
import sys
from typing import Any, Dict, List, Optional

import torch

@@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
from sglang.srt.utils import get_device_capability

# Initialize logger for the module
logger = logging.getLogger(__name__)

@@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig):

@classmethod
def get_min_capability(cls) -> int:
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 89

# Vendors can update
return False

@classmethod
def get_config_filenames(cls) -> List[str]:

@@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_device_capability, is_hip

_is_hip = is_hip()

@@ -35,7 +35,20 @@ class W8A8Fp8Config(QuantizationConfig):

@classmethod
def get_min_capability(cls) -> int:
return 89
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 89

# Vendors can update
return False

@classmethod
def get_name(self) -> str:

@@ -1,3 +1,4 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import torch

@@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import get_device_capability


class W8A8Int8Config(QuantizationConfig):

@@ -36,7 +38,20 @@ class W8A8Int8Config(QuantizationConfig):

@classmethod
def get_min_capability(cls) -> int:
return 75
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 75

# Vendors can update
return sys.maxsize

@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor > 75

# Vendors can update
return False

@classmethod
def get_name(self) -> str:

@@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
) -> None:
super().__init__()
self.head_size = head_size

@@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp):
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
self.device = device

cache = self._compute_cos_sin_cache()
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability

@@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factors: Union[List[float], float],
dtype: torch.dtype,
device: str,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]

@@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)

def _compute_cos_sin_cache(self) -> torch.Tensor:

@@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,

@@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:

@@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,

@@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
device: Optional[str] = "cuda",
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor

@@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
self.device = device
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:

@@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,

@@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:

@@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)

self.mrope_section = mrope_section

@@ -1003,9 +1045,14 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: str = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
from sglang.srt.managers.schedule_batch import global_server_args_dict

device = global_server_args_dict["device"]
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {

@@ -1030,7 +1077,7 @@ def get_rope(

if rope_scaling is None:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
)
else:
if "rope_type" in rope_scaling:

@@ -1052,6 +1099,7 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
scaling_factor,
low_freq_factor,
high_freq_factor,

@@ -1066,6 +1114,7 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
mrope_section=rope_scaling["mrope_section"],
)
else:

@@ -1076,6 +1125,7 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]

@@ -1087,6 +1137,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]

@@ -1098,6 +1149,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]

@@ -1116,6 +1168,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":

@@ -1143,6 +1196,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
**extra_kwargs,
)
elif scaling_type == "longrope":

@@ -1253,21 +1307,8 @@ def get_rope_wrapper(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
):
if device != "cpu":
return get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling,
dtype,
partial_rotary_factor,
)

return get_rope_cpu(
return get_rope(
head_size,
rotary_dim,
max_position,

@@ -1276,5 +1317,4 @@ def get_rope_wrapper(
rope_scaling,
dtype,
partial_rotary_factor,
device,
)
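
With the CPU special case folded into get_rope, get_rope_wrapper is now a thin pass-through that simply forwards the device; a minimal call sketch (argument values are illustrative only, and the keyword names are taken from the signatures shown above):

from sglang.srt.layers.rotary_embedding import get_rope_wrapper

rotary_emb = get_rope_wrapper(
    head_size=64,
    rotary_dim=64,
    max_position=4096,
    base=10000,
    is_neox_style=False,
    rope_scaling=None,
    dtype=None,  # get_rope falls back to torch.get_default_dtype()
    device="cpu",  # previously routed to get_rope_cpu; now handled inside get_rope
)
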

@@ -35,6 +35,8 @@ from sglang.srt.distributed import (
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,

@@ -108,15 +108,25 @@ def _get_quantization_config(
quant_config = get_quant_config(model_config, load_config)
major, minor = get_device_capability()

if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
if not hasattr(quant_config, "get_availability"):
# Update vLLM to support get_availability
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
)
else:
if not quant_config.get_availability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
f"Current capability: {major, minor}."
)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:

@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,

@@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)

if rope_scaling:

@@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"

self.rotary_emb = get_rope(
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,

@@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module):
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

if self.w_kc.dtype == torch.float8_e4m3fnuz:
if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available():
q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)

else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

if self.w_vc.dtype == torch.float8_e4m3fnuz:
if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available():
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_vc.dtype == torch.float8_e4m3fn:
elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available():
attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn
)

@@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale,
torch.bfloat16,
)

else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
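
Condensed, the attention-output projection now chooses its bmm path from both the weight dtype and whether CUDA is present; the helper below is an illustrative restatement of that branch (names and return strings are not from the diff):

import torch


def pick_attn_bmm_path(w_dtype: torch.dtype, cuda_available: bool) -> str:
    # float8_e4m3fnuz only occurs on ROCm; without CUDA the fp8 weights are
    # upcast to bf16 and multiplied with plain torch.bmm (the new native path).
    if w_dtype == torch.float8_e4m3fnuz or not cuda_available:
        return "bf16_torch_bmm"
    if w_dtype == torch.float8_e4m3fn:
        return "bmm_fp8"
    return "plain_torch_bmm"


assert pick_attn_bmm_path(torch.float8_e4m3fn, cuda_available=False) == "bf16_torch_bmm"
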

@@ -126,6 +126,14 @@ def is_cuda_available():
return is_cuda()


def is_triton_available():
if is_cuda() or is_xpu() or is_hip():
return get_bool_env_var("TRITON_AVAILABLE", default="true")
else:
# update once CPU/HPU supports triton
return False
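
The new is_triton_available check also gives deployments an explicit opt-out; for example (assuming get_bool_env_var treats "false" as falsy, which is how the default="true" call above reads):

import os

os.environ["TRITON_AVAILABLE"] = "false"  # force the Triton kernels off

from sglang.srt.utils import is_triton_available

print(is_triton_available())  # False on CPU/HPU regardless; False elsewhere once the flag is off
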


def enable_show_time_cost():
global show_time_cost
show_time_cost = True

@@ -1136,6 +1144,7 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
major, minor = None, None
if hasattr(torch, "cuda") and torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability(device_id)
assert 0 <= minor < 10

if hasattr(torch, "xpu") and torch.xpu.is_available():
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(

@@ -7,6 +7,8 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
native_per_token_group_quant_fp8,
native_w8a8_block_fp8_matmul,
per_token_group_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,

@@ -15,35 +17,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
_is_cuda = torch.cuda.is_available() and torch.version.cuda


# For test
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.

It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"

finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max

x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))

return x_q, x_s


class TestPerTokenGroupQuantFP8(unittest.TestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]

@@ -154,62 +127,6 @@ class TestStaticQuantFP8(unittest.TestCase):
self._static_quant_fp8(*params)


# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.

It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""

A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]

M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]

C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]

for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s

C = C.reshape(origin_C_shape).to(output_dtype)
return C


class TestW8A8BlockFP8Matmul(unittest.TestCase):

if not _is_cuda: