[XPU][CPU] Enable the native path of DeepSeek (#4086)
Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
This commit is contained in:
@@ -52,6 +52,15 @@ cd ..
|
|||||||
pip install -e "python[all_hip]"
|
pip install -e "python[all_hip]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Note: To Intel GPU, do following instead:
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/sgl-project/sglang.git
|
||||||
|
cd sglang
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install -e "python[all_xpu]"
|
||||||
|
```
|
||||||
|
|
||||||
## Method 3: Using docker
|
## Method 3: Using docker
|
||||||
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
|
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
|
||||||
Replace `<secret>` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens).
|
Replace `<secret>` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens).
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1
|
|||||||
|
|
||||||
# xpu is not enabled in public vllm and torch whl,
|
# xpu is not enabled in public vllm and torch whl,
|
||||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
|
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
|
||||||
srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
|
srt_xpu = ["sglang[runtime_common]", "vllm>=0.6.4.post1,<=0.7.2", "outlines>=0.0.44,<=0.1.11"]
|
||||||
|
|
||||||
# For Intel Gaudi(device : hpu) follow the installation guide
|
# For Intel Gaudi(device : hpu) follow the installation guide
|
||||||
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
|
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
|
||||||
|
|||||||
@@ -54,8 +54,18 @@ class QuantizationConfig(ABC):
|
|||||||
"""Minimum GPU capability to support the quantization method.
|
"""Minimum GPU capability to support the quantization method.
|
||||||
|
|
||||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||||
This requirement is due to the custom CUDA kernels used by the
|
This requirement is due to the custom kernels used by the
|
||||||
quantization method.
|
quantization method or the stock pytorch capability.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
"""Whether the quantization config is available on current device.
|
||||||
|
|
||||||
|
This requirement is due to the custom kernels used by the
|
||||||
|
quantization method or the stock pytorch capability.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import get_device_capability, set_weight_attrs
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
@@ -71,7 +72,20 @@ class BlockInt8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 80
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 80
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 80
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
native_per_token_group_quant_fp8,
|
||||||
|
native_w8a8_block_fp8_matmul,
|
||||||
|
per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
apply_w8a8_block_fp8_linear,
|
apply_w8a8_block_fp8_linear,
|
||||||
@@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
get_device_capability,
|
||||||
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
permute_weight,
|
permute_weight,
|
||||||
print_warning_once,
|
print_warning_once,
|
||||||
@@ -55,6 +62,7 @@ from sglang.srt.utils import (
|
|||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from aiter.fused_moe_bf16_asm import asm_moe
|
from aiter.fused_moe_bf16_asm import asm_moe
|
||||||
@@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 80
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 80
|
||||||
|
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 80
|
||||||
|
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
@@ -850,6 +875,52 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def torch_w8a8_block_fp8_moe(
|
||||||
|
self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
|
||||||
|
"""This function performs fused moe with block-wise quantization using native torch."""
|
||||||
|
|
||||||
|
B, D = a.shape
|
||||||
|
topk = topk_ids.shape[-1]
|
||||||
|
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||||
|
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||||
|
topk_weight = topk_weight.view(-1)
|
||||||
|
topk_ids = topk_ids.view(-1)
|
||||||
|
|
||||||
|
_, block_k = block_shape[0], block_shape[1]
|
||||||
|
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
|
||||||
|
# NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``.
|
||||||
|
a_q = a_q.to(torch.float32)
|
||||||
|
for i in range(w1.shape[0]):
|
||||||
|
mask = topk_ids == i
|
||||||
|
if mask.sum():
|
||||||
|
inter_out = native_w8a8_block_fp8_matmul(
|
||||||
|
a_q[mask],
|
||||||
|
w1[i],
|
||||||
|
a_s[mask],
|
||||||
|
w1_s[i],
|
||||||
|
block_shape,
|
||||||
|
output_dtype=a.dtype,
|
||||||
|
)
|
||||||
|
act_out = SiluAndMul().forward_native(inter_out)
|
||||||
|
act_out_q, act_out_s = native_per_token_group_quant_fp8(
|
||||||
|
act_out, block_k
|
||||||
|
)
|
||||||
|
act_out = act_out.to(torch.float32)
|
||||||
|
out[mask] = native_w8a8_block_fp8_matmul(
|
||||||
|
act_out_q,
|
||||||
|
w2[i],
|
||||||
|
act_out_s,
|
||||||
|
w2_s[i],
|
||||||
|
block_shape,
|
||||||
|
output_dtype=a.dtype,
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||||
|
).sum(dim=1)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -923,7 +994,7 @@ class Fp8MoEMethod:
|
|||||||
layer.w13_weight_scale1,
|
layer.w13_weight_scale1,
|
||||||
layer.w2_weight_scale1,
|
layer.w2_weight_scale1,
|
||||||
)
|
)
|
||||||
else:
|
elif _is_cuda:
|
||||||
# Expert fusion with FP8 quantization
|
# Expert fusion with FP8 quantization
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
@@ -950,6 +1021,24 @@ class Fp8MoEMethod:
|
|||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# for CPU and other accelerators, fallback to native path
|
||||||
|
return self.torch_w8a8_block_fp8_moe(
|
||||||
|
a=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
w1_s=(
|
||||||
|
layer.w13_weight_scale_inv
|
||||||
|
if self.block_quant
|
||||||
|
else layer.w13_weight_scale
|
||||||
|
),
|
||||||
|
w2_s=(
|
||||||
|
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||||
|
),
|
||||||
|
topk_weight=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -28,10 +28,13 @@ from sglang.srt.utils import (
|
|||||||
get_device_name,
|
get_device_name,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
is_triton_available,
|
||||||
supports_custom_op,
|
supports_custom_op,
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
_is_triton = is_triton_available()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor(
|
|||||||
tl.store(y_s_ptr, y_s)
|
tl.store(y_s_ptr, y_s)
|
||||||
|
|
||||||
|
|
||||||
|
def native_per_token_group_quant_fp8(
|
||||||
|
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
|
||||||
|
):
|
||||||
|
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
|
||||||
|
|
||||||
|
It converts the tensor values into float8 values and returns the
|
||||||
|
quantized tensor along with the scaling factor used for quantization.
|
||||||
|
Note that only `torch.float8_e4m3fn` is supported for now.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
x.shape[-1] % group_size == 0
|
||||||
|
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||||
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
fp8_min = finfo.min
|
||||||
|
fp8_max = finfo.max
|
||||||
|
|
||||||
|
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||||
|
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||||
|
x_s = amax / fp8_max
|
||||||
|
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||||
|
x_q = x_q.reshape(x.shape)
|
||||||
|
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
||||||
|
|
||||||
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
def per_token_group_quant_fp8(
|
def per_token_group_quant_fp8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
@@ -232,19 +263,22 @@ def per_token_group_quant_fp8(
|
|||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_per_token_group_quant_fp8[(M,)](
|
if _is_triton:
|
||||||
x,
|
_per_token_group_quant_fp8[(M,)](
|
||||||
x_q,
|
x,
|
||||||
x_s,
|
x_q,
|
||||||
group_size,
|
x_s,
|
||||||
N,
|
group_size,
|
||||||
eps,
|
N,
|
||||||
fp8_min=fp8_min,
|
eps,
|
||||||
fp8_max=fp8_max,
|
fp8_min=fp8_min,
|
||||||
BLOCK=BLOCK,
|
fp8_max=fp8_max,
|
||||||
num_warps=num_warps,
|
BLOCK=BLOCK,
|
||||||
num_stages=num_stages,
|
num_warps=num_warps,
|
||||||
)
|
num_stages=num_stages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x_q, x_s = native_per_token_group_quant_fp8(x, group_size)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
@@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
||||||
|
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
||||||
|
|
||||||
|
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||||
|
The output is returned in the specified `output_dtype`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
A = A.to(torch.float32)
|
||||||
|
B = B.to(torch.float32)
|
||||||
|
assert A.shape[-1] == B.shape[-1]
|
||||||
|
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||||
|
assert len(block_size) == 2
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
|
||||||
|
assert A.shape[:-1] == As.shape[:-1]
|
||||||
|
|
||||||
|
M = A.numel() // A.shape[-1]
|
||||||
|
N, K = B.shape
|
||||||
|
origin_C_shape = A.shape[:-1] + (N,)
|
||||||
|
A = A.reshape(M, A.shape[-1])
|
||||||
|
As = As.reshape(M, As.shape[-1])
|
||||||
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
assert n_tiles == Bs.shape[0]
|
||||||
|
assert k_tiles == Bs.shape[1]
|
||||||
|
|
||||||
|
C_shape = (M, N)
|
||||||
|
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
||||||
|
|
||||||
|
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
|
||||||
|
B_tiles = [
|
||||||
|
[
|
||||||
|
B[
|
||||||
|
j * block_n : min((j + 1) * block_n, N),
|
||||||
|
i * block_k : min((i + 1) * block_k, K),
|
||||||
|
]
|
||||||
|
for i in range(k_tiles)
|
||||||
|
]
|
||||||
|
for j in range(n_tiles)
|
||||||
|
]
|
||||||
|
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
|
||||||
|
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
|
||||||
|
|
||||||
|
for i in range(k_tiles):
|
||||||
|
for j in range(n_tiles):
|
||||||
|
a = A_tiles[i]
|
||||||
|
b = B_tiles[j][i]
|
||||||
|
c = C_tiles[j]
|
||||||
|
s = As_tiles[i] * Bs[j][i]
|
||||||
|
c[:, :] += torch.matmul(a, b.t()) * s
|
||||||
|
|
||||||
|
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
def w8a8_block_fp8_matmul(
|
def w8a8_block_fp8_matmul(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@@ -715,86 +804,90 @@ def w8a8_block_fp8_matmul(
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The result of matmul.
|
torch.Tensor: The result of matmul.
|
||||||
"""
|
"""
|
||||||
assert len(block_size) == 2
|
if _is_triton: # pragma: no cover
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
assert len(block_size) == 2
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
|
||||||
assert A.shape[-1] == B.shape[-1]
|
assert A.shape[-1] == B.shape[-1]
|
||||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||||
M = A.numel() // A.shape[-1]
|
M = A.numel() // A.shape[-1]
|
||||||
|
|
||||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||||
N, K = B.shape
|
N, K = B.shape
|
||||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||||
|
|
||||||
C_shape = A.shape[:-1] + (N,)
|
C_shape = A.shape[:-1] + (N,)
|
||||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||||
|
|
||||||
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
|
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
|
||||||
if configs:
|
if configs:
|
||||||
# If an optimal configuration map has been found, look up the
|
# If an optimal configuration map has been found, look up the
|
||||||
# optimal config
|
# optimal config
|
||||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||||
else:
|
|
||||||
# Default config
|
|
||||||
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
|
|
||||||
config = {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": block_size[0],
|
|
||||||
"BLOCK_SIZE_K": block_size[1],
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
def grid(META):
|
|
||||||
return (
|
|
||||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
|
||||||
# Empirical testing shows the sweet spot lies when it's less than the # of
|
|
||||||
# compute units available on the device.
|
|
||||||
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
|
||||||
N, config["BLOCK_SIZE_N"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# deepgemm only support bf16
|
|
||||||
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
|
||||||
if supports_custom_op():
|
|
||||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
|
||||||
else:
|
else:
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
# Default config
|
||||||
else:
|
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
|
||||||
kernel = (
|
config = {
|
||||||
_w8a8_block_fp8_matmul_unrolledx4
|
"BLOCK_SIZE_M": 64,
|
||||||
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
"BLOCK_SIZE_N": block_size[0],
|
||||||
else _w8a8_block_fp8_matmul
|
"BLOCK_SIZE_K": block_size[1],
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
return (
|
||||||
|
triton.cdiv(M, META["BLOCK_SIZE_M"])
|
||||||
|
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
||||||
|
# Empirical testing shows the sweet spot lies when it's less than the # of
|
||||||
|
# compute units available on the device.
|
||||||
|
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||||
|
N, config["BLOCK_SIZE_N"]
|
||||||
)
|
)
|
||||||
|
|
||||||
kernel[grid](
|
# deepgemm only support bf16
|
||||||
A,
|
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
||||||
B,
|
if supports_custom_op():
|
||||||
C,
|
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||||
As,
|
else:
|
||||||
Bs,
|
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||||
M,
|
else:
|
||||||
N,
|
kernel = (
|
||||||
K,
|
_w8a8_block_fp8_matmul_unrolledx4
|
||||||
block_n,
|
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
||||||
block_k,
|
else _w8a8_block_fp8_matmul
|
||||||
A.stride(-2),
|
)
|
||||||
A.stride(-1),
|
|
||||||
B.stride(1),
|
kernel[grid](
|
||||||
B.stride(0),
|
A,
|
||||||
C.stride(-2),
|
B,
|
||||||
C.stride(-1),
|
C,
|
||||||
As.stride(-2),
|
As,
|
||||||
As.stride(-1),
|
Bs,
|
||||||
Bs.stride(1),
|
M,
|
||||||
Bs.stride(0),
|
N,
|
||||||
**config,
|
K,
|
||||||
)
|
block_n,
|
||||||
|
block_k,
|
||||||
|
A.stride(-2),
|
||||||
|
A.stride(-1),
|
||||||
|
B.stride(1),
|
||||||
|
B.stride(0),
|
||||||
|
C.stride(-2),
|
||||||
|
C.stride(-1),
|
||||||
|
As.stride(-2),
|
||||||
|
As.stride(-1),
|
||||||
|
Bs.stride(1),
|
||||||
|
Bs.stride(0),
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
|
||||||
|
|
||||||
return C
|
return C
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types
|
|||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from sglang.srt.utils import get_device_capability
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -90,7 +92,20 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
@classmethod
|
@classmethod
|
||||||
# Need to figure it out
|
# Need to figure it out
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 60
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 60
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 60
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
@@ -209,7 +224,20 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 80
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 80
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 80
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
@@ -371,7 +399,20 @@ class MarlinConfig(QuantizationConfig):
|
|||||||
@classmethod
|
@classmethod
|
||||||
# Need to figure it out
|
# Need to figure it out
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 80
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 80
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 80
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
|
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
|
||||||
|
from sglang.srt.utils import get_device_capability
|
||||||
|
|
||||||
# Initialize logger for the module
|
# Initialize logger for the module
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 89
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 89
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import get_device_capability, is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
@@ -35,7 +35,20 @@ class W8A8Fp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 89
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 89
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> bool:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 89
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> str:
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
|
from sglang.srt.utils import get_device_capability
|
||||||
|
|
||||||
|
|
||||||
class W8A8Int8Config(QuantizationConfig):
|
class W8A8Int8Config(QuantizationConfig):
|
||||||
@@ -36,7 +38,20 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 75
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return 75
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return sys.maxsize
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_availability(cls) -> int:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
|
return major * 10 + minor > 75
|
||||||
|
|
||||||
|
# Vendors can update
|
||||||
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> str:
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
base: int,
|
base: int,
|
||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
self.base = base
|
self.base = base
|
||||||
self.is_neox_style = is_neox_style
|
self.is_neox_style = is_neox_style
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
cache = self._compute_cos_sin_cache()
|
cache = self._compute_cos_sin_cache()
|
||||||
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
|
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
|
||||||
@@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
scaling_factors: Union[List[float], float],
|
scaling_factors: Union[List[float], float],
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(scaling_factors, float):
|
if isinstance(scaling_factors, float):
|
||||||
scaling_factors = [scaling_factors]
|
scaling_factors = [scaling_factors]
|
||||||
self.scaling_factors: List[float] = scaling_factors # noqa
|
self.scaling_factors: List[float] = scaling_factors # noqa
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
# Lazy initialized.
|
# Lazy initialized.
|
||||||
self._scaling_factor_to_offset: Dict[float, int]
|
self._scaling_factor_to_offset: Dict[float, int]
|
||||||
@@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
scaling_factor: float,
|
scaling_factor: float,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||||
@@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
scaling_factor: float,
|
scaling_factor: float,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
*,
|
*,
|
||||||
extrapolation_factor: float = 1,
|
extrapolation_factor: float = 1,
|
||||||
attn_factor: float = 1,
|
attn_factor: float = 1,
|
||||||
@@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
# Get n-d magnitude scaling corrected for interpolation
|
# Get n-d magnitude scaling corrected for interpolation
|
||||||
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||||
@@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
scaling_factor: float,
|
scaling_factor: float,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
*,
|
*,
|
||||||
extrapolation_factor: float = 1,
|
extrapolation_factor: float = 1,
|
||||||
attn_factor: float = 1,
|
attn_factor: float = 1,
|
||||||
@@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
beta_slow: int = 1,
|
beta_slow: int = 1,
|
||||||
mscale: float = 1,
|
mscale: float = 1,
|
||||||
mscale_all_dim: float = 0,
|
mscale_all_dim: float = 0,
|
||||||
device: Optional[str] = "cuda",
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
self.extrapolation_factor = extrapolation_factor
|
self.extrapolation_factor = extrapolation_factor
|
||||||
@@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
||||||
* attn_factor
|
* attn_factor
|
||||||
)
|
)
|
||||||
self.device = device
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||||
@@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
|||||||
base: int,
|
base: int,
|
||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
scaling_factor: float,
|
scaling_factor: float,
|
||||||
low_freq_factor: float,
|
low_freq_factor: float,
|
||||||
high_freq_factor: float,
|
high_freq_factor: float,
|
||||||
@@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
|||||||
self.high_freq_factor = high_freq_factor
|
self.high_freq_factor = high_freq_factor
|
||||||
self.orig_max_position = orig_max_position
|
self.orig_max_position = orig_max_position
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
str,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||||
@@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
base: int,
|
base: int,
|
||||||
is_neox_style: bool,
|
is_neox_style: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
mrope_section: Optional[List[int]] = None,
|
mrope_section: Optional[List[int]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mrope_section = mrope_section
|
self.mrope_section = mrope_section
|
||||||
@@ -1003,9 +1045,14 @@ def get_rope(
|
|||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
partial_rotary_factor: float = 1.0,
|
partial_rotary_factor: float = 1.0,
|
||||||
|
device: str = None,
|
||||||
) -> RotaryEmbedding:
|
) -> RotaryEmbedding:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = torch.get_default_dtype()
|
dtype = torch.get_default_dtype()
|
||||||
|
if device is None:
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
|
device = global_server_args_dict["device"]
|
||||||
if rope_scaling is not None:
|
if rope_scaling is not None:
|
||||||
# Transforms every value that is a list into a tuple for caching calls
|
# Transforms every value that is a list into a tuple for caching calls
|
||||||
rope_scaling_tuple = {
|
rope_scaling_tuple = {
|
||||||
@@ -1030,7 +1077,7 @@ def get_rope(
|
|||||||
|
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
rotary_emb = RotaryEmbedding(
|
rotary_emb = RotaryEmbedding(
|
||||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if "rope_type" in rope_scaling:
|
if "rope_type" in rope_scaling:
|
||||||
@@ -1052,6 +1099,7 @@ def get_rope(
|
|||||||
base,
|
base,
|
||||||
is_neox_style,
|
is_neox_style,
|
||||||
dtype,
|
dtype,
|
||||||
|
device,
|
||||||
scaling_factor,
|
scaling_factor,
|
||||||
low_freq_factor,
|
low_freq_factor,
|
||||||
high_freq_factor,
|
high_freq_factor,
|
||||||
@@ -1066,6 +1114,7 @@ def get_rope(
|
|||||||
base,
|
base,
|
||||||
is_neox_style,
|
is_neox_style,
|
||||||
dtype,
|
dtype,
|
||||||
|
device,
|
||||||
mrope_section=rope_scaling["mrope_section"],
|
mrope_section=rope_scaling["mrope_section"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1076,6 +1125,7 @@ def get_rope(
|
|||||||
base,
|
base,
|
||||||
is_neox_style,
|
is_neox_style,
|
||||||
dtype,
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
elif scaling_type == "linear":
|
elif scaling_type == "linear":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
@@ -1087,6 +1137,7 @@ def get_rope(
|
|||||||
is_neox_style,
|
is_neox_style,
|
||||||
scaling_factor,
|
scaling_factor,
|
||||||
dtype,
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
elif scaling_type == "dynamic":
|
elif scaling_type == "dynamic":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
@@ -1098,6 +1149,7 @@ def get_rope(
|
|||||||
is_neox_style,
|
is_neox_style,
|
||||||
scaling_factor,
|
scaling_factor,
|
||||||
dtype,
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
elif scaling_type == "yarn":
|
elif scaling_type == "yarn":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
@@ -1116,6 +1168,7 @@ def get_rope(
|
|||||||
is_neox_style,
|
is_neox_style,
|
||||||
scaling_factor,
|
scaling_factor,
|
||||||
dtype,
|
dtype,
|
||||||
|
device,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
elif scaling_type == "deepseek_yarn":
|
elif scaling_type == "deepseek_yarn":
|
||||||
@@ -1143,6 +1196,7 @@ def get_rope(
|
|||||||
is_neox_style,
|
is_neox_style,
|
||||||
scaling_factor,
|
scaling_factor,
|
||||||
dtype,
|
dtype,
|
||||||
|
device,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
elif scaling_type == "longrope":
|
elif scaling_type == "longrope":
|
||||||
@@ -1253,21 +1307,8 @@ def get_rope_wrapper(
|
|||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
partial_rotary_factor: float = 1.0,
|
partial_rotary_factor: float = 1.0,
|
||||||
device: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
if device != "cpu":
|
return get_rope(
|
||||||
return get_rope(
|
|
||||||
head_size,
|
|
||||||
rotary_dim,
|
|
||||||
max_position,
|
|
||||||
base,
|
|
||||||
is_neox_style,
|
|
||||||
rope_scaling,
|
|
||||||
dtype,
|
|
||||||
partial_rotary_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
return get_rope_cpu(
|
|
||||||
head_size,
|
head_size,
|
||||||
rotary_dim,
|
rotary_dim,
|
||||||
max_position,
|
max_position,
|
||||||
@@ -1276,5 +1317,4 @@ def get_rope_wrapper(
|
|||||||
rope_scaling,
|
rope_scaling,
|
||||||
dtype,
|
dtype,
|
||||||
partial_rotary_factor,
|
partial_rotary_factor,
|
||||||
device,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ from sglang.srt.distributed import (
|
|||||||
set_custom_all_reduce,
|
set_custom_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||||
|
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
|
||||||
|
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_tp_group,
|
get_attention_tp_group,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
|
|||||||
@@ -108,15 +108,25 @@ def _get_quantization_config(
|
|||||||
quant_config = get_quant_config(model_config, load_config)
|
quant_config = get_quant_config(model_config, load_config)
|
||||||
major, minor = get_device_capability()
|
major, minor = get_device_capability()
|
||||||
|
|
||||||
if major is not None and minor is not None:
|
if not hasattr(quant_config, "get_availability"):
|
||||||
assert 0 <= minor < 10
|
# Update VLLM to support get_available
|
||||||
capability = major * 10 + minor
|
if major is not None and minor is not None:
|
||||||
if capability < quant_config.get_min_capability():
|
assert 0 <= minor < 10
|
||||||
|
capability = major * 10 + minor
|
||||||
|
if capability < quant_config.get_min_capability():
|
||||||
|
raise ValueError(
|
||||||
|
f"The quantization method {model_config.quantization} "
|
||||||
|
"is not supported for the current GPU. "
|
||||||
|
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||||
|
f"Current capability: {capability}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not quant_config.get_availability():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The quantization method {model_config.quantization} "
|
f"The quantization method {model_config.quantization} "
|
||||||
"is not supported for the current GPU. "
|
"is not supported for the current GPU. "
|
||||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||||
f"Current capability: {capability}."
|
f"Current capability: {major, minor}."
|
||||||
)
|
)
|
||||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||||
if model_config.dtype not in supported_dtypes:
|
if model_config.dtype not in supported_dtypes:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import (
|
|||||||
block_dequant as int8_block_dequant,
|
block_dequant as int8_block_dequant,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
@@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
is_neox_style=False,
|
is_neox_style=False,
|
||||||
device=global_server_args_dict["device"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
@@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope_wrapper(
|
||||||
qk_rope_head_dim,
|
qk_rope_head_dim,
|
||||||
rotary_dim=qk_rope_head_dim,
|
rotary_dim=qk_rope_head_dim,
|
||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
@@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
if self.w_kc.dtype == torch.float8_e4m3fnuz:
|
if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only
|
||||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||||
q_nope_out = torch.bmm(
|
q_nope_out = torch.bmm(
|
||||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||||
)
|
)
|
||||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available():
|
||||||
q_nope_val, q_nope_scale = input_to_float8(
|
q_nope_val, q_nope_scale = input_to_float8(
|
||||||
q_nope.transpose(0, 1), torch.float8_e4m3fn
|
q_nope.transpose(0, 1), torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
q_nope_out = bmm_fp8(
|
q_nope_out = bmm_fp8(
|
||||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||||
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
|
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
|
||||||
@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
|
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
|
||||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
if self.w_vc.dtype == torch.float8_e4m3fnuz:
|
if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available():
|
||||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||||
attn_bmm_output = torch.bmm(
|
attn_bmm_output = torch.bmm(
|
||||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||||
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
||||||
)
|
)
|
||||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available():
|
||||||
attn_output_val, attn_output_scale = input_to_float8(
|
attn_output_val, attn_output_scale = input_to_float8(
|
||||||
attn_output.transpose(0, 1), torch.float8_e4m3fn
|
attn_output.transpose(0, 1), torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
@@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.w_scale,
|
self.w_scale,
|
||||||
torch.bfloat16,
|
torch.bfloat16,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
|
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
|
||||||
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||||
|
|||||||
@@ -126,6 +126,14 @@ def is_cuda_available():
|
|||||||
return is_cuda()
|
return is_cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def is_triton_available():
|
||||||
|
if is_cuda() or is_xpu() or is_hip():
|
||||||
|
return get_bool_env_var("TRITON_AVAILABLE", default="true")
|
||||||
|
else:
|
||||||
|
# update once CPU/HPU supports triton
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def enable_show_time_cost():
|
def enable_show_time_cost():
|
||||||
global show_time_cost
|
global show_time_cost
|
||||||
show_time_cost = True
|
show_time_cost = True
|
||||||
@@ -1136,6 +1144,7 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
|||||||
major, minor = None, None
|
major, minor = None, None
|
||||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
major, minor = torch.cuda.get_device_capability(device_id)
|
major, minor = torch.cuda.get_device_capability(device_id)
|
||||||
|
assert 0 <= minor < 10
|
||||||
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
|
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import torch
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
native_per_token_group_quant_fp8,
|
||||||
|
native_w8a8_block_fp8_matmul,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
static_quant_fp8,
|
static_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
@@ -15,35 +17,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||||
|
|
||||||
|
|
||||||
# For test
|
|
||||||
def native_per_token_group_quant_fp8(
|
|
||||||
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
|
|
||||||
):
|
|
||||||
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
|
|
||||||
|
|
||||||
It converts the tensor values into float8 values and returns the
|
|
||||||
quantized tensor along with the scaling factor used for quantization.
|
|
||||||
Note that only `torch.float8_e4m3fn` is supported for now.
|
|
||||||
"""
|
|
||||||
assert (
|
|
||||||
x.shape[-1] % group_size == 0
|
|
||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
|
||||||
|
|
||||||
finfo = torch.finfo(dtype)
|
|
||||||
fp8_min = finfo.min
|
|
||||||
fp8_max = finfo.max
|
|
||||||
|
|
||||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
|
||||||
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
|
||||||
x_s = amax / fp8_max
|
|
||||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
|
||||||
x_q = x_q.reshape(x.shape)
|
|
||||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
|
||||||
|
|
||||||
return x_q, x_s
|
|
||||||
|
|
||||||
|
|
||||||
class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||||
NUM_TOKENS = [7, 83, 2048]
|
NUM_TOKENS = [7, 83, 2048]
|
||||||
@@ -154,62 +127,6 @@ class TestStaticQuantFP8(unittest.TestCase):
|
|||||||
self._static_quant_fp8(*params)
|
self._static_quant_fp8(*params)
|
||||||
|
|
||||||
|
|
||||||
# For test
|
|
||||||
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
|
||||||
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
|
||||||
|
|
||||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
|
||||||
The output is returned in the specified `output_dtype`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
A = A.to(torch.float32)
|
|
||||||
B = B.to(torch.float32)
|
|
||||||
assert A.shape[-1] == B.shape[-1]
|
|
||||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
|
||||||
assert len(block_size) == 2
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
|
||||||
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
|
|
||||||
assert A.shape[:-1] == As.shape[:-1]
|
|
||||||
|
|
||||||
M = A.numel() // A.shape[-1]
|
|
||||||
N, K = B.shape
|
|
||||||
origin_C_shape = A.shape[:-1] + (N,)
|
|
||||||
A = A.reshape(M, A.shape[-1])
|
|
||||||
As = As.reshape(M, As.shape[-1])
|
|
||||||
n_tiles = (N + block_n - 1) // block_n
|
|
||||||
k_tiles = (K + block_k - 1) // block_k
|
|
||||||
assert n_tiles == Bs.shape[0]
|
|
||||||
assert k_tiles == Bs.shape[1]
|
|
||||||
|
|
||||||
C_shape = (M, N)
|
|
||||||
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
|
||||||
|
|
||||||
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
|
|
||||||
B_tiles = [
|
|
||||||
[
|
|
||||||
B[
|
|
||||||
j * block_n : min((j + 1) * block_n, N),
|
|
||||||
i * block_k : min((i + 1) * block_k, K),
|
|
||||||
]
|
|
||||||
for i in range(k_tiles)
|
|
||||||
]
|
|
||||||
for j in range(n_tiles)
|
|
||||||
]
|
|
||||||
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
|
|
||||||
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
|
|
||||||
|
|
||||||
for i in range(k_tiles):
|
|
||||||
for j in range(n_tiles):
|
|
||||||
a = A_tiles[i]
|
|
||||||
b = B_tiles[j][i]
|
|
||||||
c = C_tiles[j]
|
|
||||||
s = As_tiles[i] * Bs[j][i]
|
|
||||||
c[:, :] += torch.matmul(a, b.t()) * s
|
|
||||||
|
|
||||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
|
||||||
return C
|
|
||||||
|
|
||||||
|
|
||||||
class TestW8A8BlockFP8Matmul(unittest.TestCase):
|
class TestW8A8BlockFP8Matmul(unittest.TestCase):
|
||||||
|
|
||||||
if not _is_cuda:
|
if not _is_cuda:
|
||||||
|
|||||||
Reference in New Issue
Block a user