[XPU][CPU] Enable the native path of DeepSeek (#4086)

Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
This commit is contained in:
Meng, Hengyu
2025-03-13 13:26:29 +08:00
committed by GitHub
parent c76040e31b
commit 71046fcd71
16 changed files with 501 additions and 223 deletions

View File

@@ -54,8 +54,18 @@ class QuantizationConfig(ABC):
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
This requirement is due to the custom kernels used by the
quantization method or the stock pytorch capability.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def get_availability(cls) -> bool:
"""Whether the quantization config is available on current device.
This requirement is due to the custom kernels used by the
quantization method or the stock pytorch capability.
"""
raise NotImplementedError

View File

@@ -1,6 +1,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import get_device_capability, set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -71,7 +72,20 @@ class BlockInt8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether block-INT8 quantization can run on the current device.

    Keeps parity with get_min_capability() (== 80): capability >= 80 is
    required, so use `>=` — the previous `>` wrongly rejected devices at
    exactly sm80 (e.g. A100).
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 80
    # Vendors can update
    return False
@classmethod
def get_config_filenames(cls) -> List[str]:

View File

@@ -1,6 +1,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
native_per_token_group_quant_fp8,
native_w8a8_block_fp8_matmul,
per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
@@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
is_cuda,
is_hip,
permute_weight,
print_warning_once,
@@ -55,6 +62,7 @@ from sglang.srt.utils import (
ACTIVATION_SCHEMES = ["static", "dynamic"]
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe
@@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
if hasattr(torch, "xpu") and torch.xpu.is_available():
return 0
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether FP8 quantization can run on the current device.

    Keeps parity with get_min_capability(): CUDA needs capability >= 80
    (the previous `>` wrongly rejected devices at exactly sm80, e.g. A100);
    XPU is always considered available (native fallback path exists).
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 80
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True
    # Vendors can update
    return False
@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -850,6 +875,52 @@ class Fp8MoEMethod:
)
torch.cuda.empty_cache()
def torch_w8a8_block_fp8_moe(
    self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape
):
    """Fused MoE with block-wise FP8 quantization using native torch.

    Fallback path for CPU / accelerators without fused_experts kernels.

    Args:
        a: (B, D) input activations.
        w1: stacked per-expert gate/up projection weights (FP8).
        w2: stacked per-expert down projection weights (FP8).
        w1_s / w2_s: block-wise scales for w1 / w2.
        topk_weight, topk_ids: (B, topk) routing weights and expert ids.
        block_shape: [block_n, block_k] quantization block sizes.

    Returns:
        (B, w2.shape[1]) combined expert outputs in a.dtype.
    """
    from sglang.srt.layers.activation import SiluAndMul

    B, D = a.shape
    topk = topk_ids.shape[-1]
    # Replicate each token once per selected expert, then flatten.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    # NOTE(HandH1998): Since "index_cuda" is not implemented for 'Float8_e4m3fn',
    # we need to cast `float8` to `float32` before boolean-mask indexing below.
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_fp8_matmul(
                a_q[mask],
                w1[i],
                a_s[mask],
                w1_s[i],
                block_shape,
                output_dtype=a.dtype,
            )
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_fp8(
                act_out, block_k
            )
            # No fp32 cast needed here: act_out_q is not mask-indexed, and
            # native_w8a8_block_fp8_matmul casts its inputs to fp32 itself.
            out[mask] = native_w8a8_block_fp8_matmul(
                act_out_q,
                w2[i],
                act_out_s,
                w2_s[i],
                block_shape,
                output_dtype=a.dtype,
            )
    # Weight each expert's output by its routing weight and sum over topk.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
def apply(
self,
layer: torch.nn.Module,
@@ -923,7 +994,7 @@ class Fp8MoEMethod:
layer.w13_weight_scale1,
layer.w2_weight_scale1,
)
else:
elif _is_cuda:
# Expert fusion with FP8 quantization
return fused_experts(
x,
@@ -950,6 +1021,24 @@ class Fp8MoEMethod:
no_combine=no_combine,
)
# for CPU and other accelerators, fallback to native path
return self.torch_w8a8_block_fp8_moe(
a=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_s=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_s=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
topk_weight=topk_weights,
topk_ids=topk_ids,
block_shape=self.quant_config.weight_block_size,
)
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""

View File

@@ -28,10 +28,13 @@ from sglang.srt.utils import (
get_device_name,
is_cuda,
is_hip,
is_triton_available,
supports_custom_op,
)
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_triton = is_triton_available()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda()
@@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor(
tl.store(y_s_ptr, y_s)
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
@@ -232,19 +263,22 @@ def per_token_group_quant_fp8(
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
if _is_triton:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
x_q, x_s = native_per_token_group_quant_fp8(x, group_size)
return x_q, x_s
@@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs(
return None
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """Block-wise-quantized matmul implemented with native torch.

    Computes C = dequant(A, As) @ dequant(B, Bs).T tile by tile, where the
    scales As / Bs apply per [block_n, block_k] tile, and returns the result
    in `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    N, K = B.shape
    M = A.numel() // A.shape[-1]
    out_shape = A.shape[:-1] + (N,)
    # Flatten all leading dims into a single row dimension.
    A2 = A.reshape(M, A.shape[-1])
    As2 = As.reshape(M, As.shape[-1])

    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)
    # Accumulate one [rows x block_n] output tile contribution per K-tile,
    # scaled by the product of the activation and weight scales.
    for i in range(k_tiles):
        k0, k1 = i * block_k, min((i + 1) * block_k, K)
        a_tile = A2[:, k0:k1]
        a_scale = As2[:, i : i + 1]
        for j in range(n_tiles):
            n0, n1 = j * block_n, min((j + 1) * block_n, N)
            b_tile = B[n0:n1, k0:k1]
            C[:, n0:n1] += torch.matmul(a_tile, b_tile.t()) * (a_scale * Bs[j][i])

    return C.reshape(out_shape).to(output_dtype)
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
@@ -715,86 +804,90 @@ def w8a8_block_fp8_matmul(
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
if _is_triton: # pragma: no cover
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
# deepgemm only support bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
# deepgemm only support bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
else:
C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
return C

View File

@@ -1,4 +1,5 @@
import logging
import sys
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability
logger = logging.getLogger(__name__)
@@ -90,7 +92,20 @@ class GPTQConfig(QuantizationConfig):
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 60
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether GPTQ can run on the current device.

    Keeps parity with get_min_capability() (== 60): use `>=` — the previous
    `>` wrongly rejected devices at exactly sm60.
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 60
    # Vendors can update
    return False
@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -209,7 +224,20 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether GPTQ-Marlin can run on the current device.

    Keeps parity with get_min_capability() (== 80): use `>=` — the previous
    `>` wrongly rejected devices at exactly sm80 (e.g. A100).
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 80
    # Vendors can update
    return False
@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -371,7 +399,20 @@ class MarlinConfig(QuantizationConfig):
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether Marlin can run on the current device.

    Keeps parity with get_min_capability() (== 80): use `>=` — the previous
    `>` wrongly rejected devices at exactly sm80 (e.g. A100).
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 80
    # Vendors can update
    return False
@classmethod
def get_config_filenames(cls) -> List[str]:

View File

@@ -1,6 +1,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
import logging
import sys
from typing import Any, Dict, List, Optional
import torch
@@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
from sglang.srt.utils import get_device_capability
# Initialize logger for the module
logger = logging.getLogger(__name__)
@@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether ModelOpt FP8 can run on the current device.

    Keeps parity with get_min_capability() (== 89): use `>=` — the previous
    `>` wrongly rejected devices at exactly sm89 (Ada), which is the stated
    minimum for this method.
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 89
    # Vendors can update
    return False
@classmethod
def get_config_filenames(cls) -> List[str]:

View File

@@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_device_capability, is_hip
_is_hip = is_hip()
@@ -35,7 +35,20 @@ class W8A8Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 89
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether W8A8-FP8 can run on the current device.

    Keeps parity with get_min_capability() (== 89): use `>=` — the previous
    `>` wrongly rejected devices at exactly sm89 (Ada).
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 89
    # Vendors can update
    return False
@classmethod
def get_name(self) -> str:

View File

@@ -1,3 +1,4 @@
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import get_device_capability
class W8A8Int8Config(QuantizationConfig):
@@ -36,7 +38,20 @@ class W8A8Int8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 75
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 75
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
    """Whether W8A8-INT8 can run on the current device.

    Return annotation fixed from `-> int` to `-> bool` (the method returns a
    boolean, matching the sibling config classes). Keeps parity with
    get_min_capability() (== 75): use `>=` — the previous `>` wrongly
    rejected devices at exactly sm75 (Turing).
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        # Only query the capability when CUDA is actually present.
        major, minor = get_device_capability()
        return major * 10 + minor >= 75
    # Vendors can update
    return False
@classmethod
def get_name(self) -> str: