Feature DeepSeek V3/R1 INT8 Quantization (block-wise) (#3730)
Co-authored-by: HandH1998 <1335248067@qq.com>
This commit is contained in:
@@ -15,7 +15,13 @@ from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_device_name,
|
||||
is_cuda_available,
|
||||
is_hip,
|
||||
)
|
||||
|
||||
is_hip_flag = is_hip()
|
||||
|
||||
@@ -86,6 +92,7 @@ def fused_moe_kernel(
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
even_Ks: tl.constexpr,
|
||||
):
|
||||
@@ -159,7 +166,7 @@ def fused_moe_kernel(
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
offs_bsn = offs_bn // group_n
|
||||
@@ -198,7 +205,7 @@ def fused_moe_kernel(
|
||||
# We accumulate along the K dimension.
|
||||
if use_int8_w8a16:
|
||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||
elif use_fp8_w8a8:
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
@@ -221,7 +228,7 @@ def fused_moe_kernel(
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8:
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
@@ -477,6 +484,7 @@ def invoke_fused_moe_kernel(
|
||||
config: Dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
@@ -499,6 +507,18 @@ def invoke_fused_moe_kernel(
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
if block_shape is None:
|
||||
padded_size = padding_size
|
||||
A, A_scale = ops.scaled_int8_quant(A, A_scale)
|
||||
else:
|
||||
assert len(block_shape) == 2
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a16:
|
||||
assert B_scale is not None
|
||||
else:
|
||||
@@ -548,6 +568,7 @@ def invoke_fused_moe_kernel(
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
even_Ks=even_Ks,
|
||||
**config,
|
||||
@@ -701,9 +722,12 @@ def get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False,
|
||||
use_int8_w8a8: Optional[bool] = False,
|
||||
):
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a8:
|
||||
return "int8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif dtype == torch.float:
|
||||
@@ -721,6 +745,7 @@ def inplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -737,6 +762,7 @@ def inplace_fused_experts(
|
||||
True,
|
||||
activation,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
@@ -754,6 +780,7 @@ def inplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -780,6 +807,7 @@ def outplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -796,6 +824,7 @@ def outplace_fused_experts(
|
||||
False,
|
||||
activation,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
@@ -813,6 +842,7 @@ def outplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -840,6 +870,7 @@ def fused_experts(
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -856,6 +887,7 @@ def fused_experts(
|
||||
topk_ids,
|
||||
activation,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
@@ -873,6 +905,7 @@ def fused_experts(
|
||||
topk_ids,
|
||||
activation,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
@@ -891,6 +924,7 @@ def fused_experts_impl(
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -899,7 +933,7 @@ def fused_experts_impl(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
):
|
||||
padded_size = padding_size
|
||||
if not use_fp8_w8a8 or block_shape is not None:
|
||||
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
|
||||
padded_size = 0
|
||||
|
||||
# Check constraints.
|
||||
@@ -918,6 +952,7 @@ def fused_experts_impl(
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
config_dtype = get_config_dtype_str(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
@@ -1001,6 +1036,7 @@ def fused_experts_impl(
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
@@ -1034,6 +1070,7 @@ def fused_experts_impl(
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
@@ -1078,6 +1115,7 @@ def fused_moe(
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -1105,6 +1143,8 @@ def fused_moe(
|
||||
note: Deepseek V2/V3/R1 series models use grouped_topk
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use int8 weights with 16-bit activations to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
@@ -1144,6 +1184,7 @@ def fused_moe(
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
|
||||
Reference in New Issue
Block a user