[XPU][CPU] Enable the native path of DeepSeek (#4086)
Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
This commit is contained in:
@@ -7,6 +7,8 @@ import torch
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_fp8_matmul,
|
||||
per_token_group_quant_fp8,
|
||||
static_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
@@ -15,35 +17,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
|
||||
|
||||
# For test
|
||||
def native_per_token_group_quant_fp8(
|
||||
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
|
||||
):
|
||||
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
|
||||
|
||||
It converts the tensor values into float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
Note that only `torch.float8_e4m3fn` is supported for now.
|
||||
"""
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 83, 2048]
|
||||
@@ -154,62 +127,6 @@ class TestStaticQuantFP8(unittest.TestCase):
|
||||
self._static_quant_fp8(*params)
|
||||
|
||||
|
||||
# For test
|
||||
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
||||
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
||||
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
"""
|
||||
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1]
|
||||
|
||||
M = A.numel() // A.shape[-1]
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (N,)
|
||||
A = A.reshape(M, A.shape[-1])
|
||||
As = As.reshape(M, As.shape[-1])
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
assert n_tiles == Bs.shape[0]
|
||||
assert k_tiles == Bs.shape[1]
|
||||
|
||||
C_shape = (M, N)
|
||||
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
||||
|
||||
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
|
||||
B_tiles = [
|
||||
[
|
||||
B[
|
||||
j * block_n : min((j + 1) * block_n, N),
|
||||
i * block_k : min((i + 1) * block_k, K),
|
||||
]
|
||||
for i in range(k_tiles)
|
||||
]
|
||||
for j in range(n_tiles)
|
||||
]
|
||||
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
|
||||
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
|
||||
|
||||
for i in range(k_tiles):
|
||||
for j in range(n_tiles):
|
||||
a = A_tiles[i]
|
||||
b = B_tiles[j][i]
|
||||
c = C_tiles[j]
|
||||
s = As_tiles[i] * Bs[j][i]
|
||||
c[:, :] += torch.matmul(a, b.t()) * s
|
||||
|
||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||
return C
|
||||
|
||||
|
||||
class TestW8A8BlockFP8Matmul(unittest.TestCase):
|
||||
|
||||
if not _is_cuda:
|
||||
|
||||
Reference in New Issue
Block a user