[Feature] DeepSeek V3/R1 INT8 Quantization (channel-wise) (#3888)

Co-authored-by: yych0745 <1398089567@qq.com>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: b0urnee <2769086541@qq.com>
HandH1998
2025-03-07 12:54:52 +08:00
committed by GitHub
parent 63ee26d162
commit c7f254468f
5 changed files with 369 additions and 21 deletions

View File

@@ -15,7 +15,10 @@ from vllm import _custom_ops as ops
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
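For reference, the newly imported per-token quantizer can be emulated as below. This is a minimal sketch of the usual symmetric per-token INT8 scheme, not the Triton kernel itself; the exact rounding and clamping details of per_token_quant_int8 are assumptions.

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per token (row): scale = max(|x|) / 127, symmetric quantization.
    x_f = x.float()
    abs_max = x_f.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = abs_max / 127.0
    x_q = torch.clamp(torch.round(x_f / scale), -128, 127).to(torch.int8)
    return x_q, scale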
@@ -117,6 +120,7 @@ def fused_moe_kernel(
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
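To make the sorting and padding concrete, a toy illustration (values are made up, BLOCK_SIZE_M = 4 is assumed): token ids are repeated once per selected expert, grouped by expert, and each group is padded so every M-block maps to exactly one expert.

# Toy example: 5 routed (token, expert) pairs, 2 experts, BLOCK_SIZE_M = 4.
# Tokens 0 and 2 go to expert 0; tokens 1, 3 and 4 go to expert 1.
num_valid_tokens = 5                            # padding slots reuse this sentinel id
sorted_token_ids = [0, 2, 5, 5,  1, 3, 4, 5]    # each expert group padded to 4
expert_ids       = [0,           1]             # one expert id per M-block
# Every BLOCK_SIZE_M-sized block of sorted_token_ids now reads a single expert
# matrix from B, which is what makes the block-wise tl.dot valid.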
@@ -167,17 +171,38 @@ def fused_moe_kernel(
)
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8:
if use_fp8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
if use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# channel-wise
else:
# Load per-column scale for weights
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
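The channel-wise branch above amounts to a per-token scale on A and a per-output-channel scale on B, applied outside the integer dot product. A rough dequantized reference (names and layouts are illustrative, not the Triton kernel):

import torch

def w8a8_channelwise_gemm_ref(a_q, a_scale, b_q, b_scale):
    # a_q: [M, K] int8, a_scale: [M, 1] fp32  (one scale per token)
    # b_q: [N, K] int8, b_scale: [N, 1] fp32  (one scale per output channel)
    acc = a_q.float() @ b_q.float().t()       # stands in for the int32 accumulator
    return acc * a_scale * b_scale.t()        # scales applied after accumulation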
@@ -217,7 +242,11 @@ def fused_moe_kernel(
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
# For int8, accumulate outside tl.dot to avoid running out of shared memory
if use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
@@ -497,9 +526,11 @@ def invoke_fused_moe_kernel(
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
if _is_cuda:
@@ -512,9 +543,10 @@ def invoke_fused_moe_kernel(
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
padded_size = padding_size
A, A_scale = ops.scaled_int8_quant(A, A_scale)
# per-token int8 activation quantization (channel-wise scheme)
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
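When block_shape is given, activations are instead quantized per (token, K-group). A sketch of what per_token_group_quant_int8 presumably computes, assuming K divides evenly into groups of block_k:

import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_k: int):
    # One scale per token and per K-group of size group_k.
    m, k = x.shape
    assert k % group_k == 0
    xg = x.float().view(m, k // group_k, group_k)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(xg / scale), -128, 127).to(torch.int8)
    return x_q.view(m, k), scale.squeeze(-1)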
@@ -1060,7 +1092,6 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape,
)
if activation == "silu":
if _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt.utils import is_cuda_available
from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available()
if is_cuda:
@@ -10,6 +10,7 @@ if is_cuda:
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
@@ -55,9 +56,12 @@ class W8A8Int8Config(QuantizationConfig):
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
@@ -81,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
@@ -115,3 +119,148 @@ class W8A8Int8LinearMethod(LinearMethodBase):
return int8_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
class W8A8Int8MoEMethod:
"""MoE method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
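# Dynamically rebase this class onto FusedMoEMethodBase at construction time;
# the base class is imported lazily here, presumably to avoid a circular import
# with fused_moe_triton at module load time.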
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
tp_size = get_tensor_model_parallel_world_size()
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
)
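For intuition, the fused INT8 path above is numerically equivalent (up to rounding) to dequantizing each expert channel-wise and running the usual gate/up -> SiLU -> down computation. A rough, unoptimized reference for a single expert, with illustrative names and the weight layouts from create_weights above:

import torch
import torch.nn.functional as F

def expert_forward_ref(x, w13_q, w13_scale, w2_q, w2_scale):
    # x: [T, hidden] tokens routed to this expert
    # w13_q: [2 * intermediate, hidden] int8, w13_scale: [2 * intermediate, 1] fp32
    # w2_q:  [hidden, intermediate]     int8, w2_scale:  [hidden, 1]           fp32
    w13 = w13_q.float() * w13_scale           # channel-wise dequantization
    w2 = w2_q.float() * w2_scale
    gate, up = (x.float() @ w13.t()).chunk(2, dim=-1)
    h = F.silu(gate) * up                     # what silu_and_mul computes
    return h @ w2.t()                         # [T, hidden]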

View File

@@ -1202,18 +1202,22 @@ class DeepseekV2ForCausalLM(nn.Module):
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
if (
hasattr(self.quant_config, "weight_block_size")
and w.dtype == torch.int8
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(
weight, weight_scale, weight_block_size
).to(torch.bfloat16)
if w.dtype == torch.int8:
if hasattr(self.quant_config, "weight_block_size"):
# block-wise int8 requires dequantizing with the block-wise scale
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(
weight, weight_scale, weight_block_size
).to(torch.bfloat16)
else:
# channel-wise int8 requires dequantizing with the per-channel scale
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
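Putting the channel-wise branch together: kv_b_proj is dequantized with its per-output-channel scale and then split into the k-nope and v projections, roughly as in this sketch (shapes and the helper name are illustrative):

import torch

def dequant_and_split_kv_b_proj(w_int8, weight_scale, qk_nope_head_dim, v_head_dim):
    # w_int8: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] int8
    # weight_scale: one fp32 scale per output row (channel-wise)
    w = w_int8.to(torch.bfloat16) * weight_scale.to(torch.bfloat16)
    w_kc, w_vc = w.unflatten(
        0, (-1, qk_nope_head_dim + v_head_dim)
    ).split([qk_nope_head_dim, v_head_dim], dim=1)
    return w_kc, w_vc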