[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)

Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
chenxj
2025-09-02 13:17:26 +08:00
committed by GitHub
parent 21e1bc475c
commit d4a938417d
11 changed files with 291 additions and 120 deletions

View File

@@ -405,9 +405,10 @@ class ModelConfig:
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None)
if quant_cfg is None:
# check if is modelopt model -- modelopt doesn't have corresponding field
# check if it is a modelopt or mixed-precision model -- neither has a corresponding field
# in hf `config.json`, but both have a standalone `hf_quant_config.json` in the root directory
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
# example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
is_local = os.path.exists(self.model_path)
modelopt_quant_config = {"quant_method": "modelopt"}
if not is_local:

View File

@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert (
w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
and w1_scale.shape[2] == w1_q.shape[1] * 4
), "W1 scale shape mismatch"
assert (
w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
and w2_scale.shape[2] == w2_q.shape[1] * 4
), "W2 scale shape mismatch"
assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
num_experts = w1_q.size(0)
m = a.size(0)

View File

@@ -114,9 +114,6 @@ class EPMoE(FusedMoE):
with_bias=with_bias,
)
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config):

View File

@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
if self.moe_ep_size > 1:
# TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):
if (
"compressed" in self.quant_method.__class__.__name__.lower()
and param.data[expert_id] != 1
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
or "w4afp8" in self.quant_config.get_name()
and (param.data[expert_id] != 1).any()
and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
):
raise ValueError(
"input_scales of w1 and w3 of a layer "

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, EPMoE):
elif isinstance(layer, FusedMoE):
return W4AFp8MoEMethod(self)
return None
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
return []
class W4AFp8MoEMethod(FusedMoEMethodBase):
def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    """Regroup the last axis of a 3-D scale tensor and swap it with the middle axis.

    The last axis is split into chunks of 4 (similar to the TRT-LLM
    implementation, per the original note); when its length is not a
    multiple of 4 the chunk size falls back to 1, which reduces the
    operation to swapping axes 1 and 2.

    Args:
        scales: Tensor indexed as ``[d0, d1, d2]``.

    Returns:
        A contiguous tensor of shape ``[d0, d2 // chunk, d1 * chunk]``.
    """
    dim0, dim1, dim2 = scales.shape[0], scales.shape[1], scales.shape[2]
    # Use chunks of 4 only when the last axis divides evenly.
    chunk = 1 if dim2 % 4 else 4
    num_chunks = dim2 // chunk
    # [d0, d1, d2] -> [d0, d1, chunks, chunk] -> [d0, chunks, d1, chunk]
    regrouped = scales.reshape(dim0, dim1, num_chunks, chunk).transpose(1, 2)
    # Collapse the trailing two axes so each chunk row holds d1 * chunk values.
    return regrouped.reshape(dim0, num_chunks, dim1 * chunk).contiguous()
class W4AFp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: W4AFp8Config):
self.quant_config = quant_config
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
return
def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
"""Interleave scales in groups of 4 similar to TRT-LLM implementation."""
s_shape = scales.shape
# Reshape to separate groups of 4
scales_interleaved = scales.reshape(
s_shape[0], s_shape[1], (s_shape[2] // 4), 4
)
# Permute dimensions to interleave
scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
# Reshape back to original dimensions but with interleaved values
scales_interleaved = scales_interleaved.reshape(
s_shape[0], s_shape[2] // 4, s_shape[1] * 4
)
return scales_interleaved.contiguous()
def process_weights_after_loading(self, layer: Module) -> None:
dtype = torch.bfloat16
device = layer.w2_weight.device
# Interleave w13_weight_scale (gate_up_proj)
w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
w13_weight_scale = self._interleave_scales(w13_weight_scale)
w13_weight_scale = interleave_scales(w13_weight_scale)
layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
# Interleave w2_weight_scale (down_proj)
w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
w2_weight_scale = self._interleave_scales(w2_weight_scale)
w2_weight_scale = interleave_scales(w2_weight_scale)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
# Process input scales
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
if get_moe_expert_parallel_world_size() > 1:
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
output = cutlass_w4a8_moe(
layer.start_expert_id,

View File

@@ -2185,6 +2185,8 @@ class DeepseekV2ForCausalLM(nn.Module):
disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
elif self.quant_config.get_name() == "w4afp8":
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2496,6 +2498,9 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
# Params for special naming rules in mixed-precision models, for example:
# model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
# see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
if self.quant_config and self.quant_config.get_name() == "w4afp8":
expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from typing import Literal, Optional
import pytest
import torch
@@ -25,7 +25,7 @@ def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Ten
return packed_tensor.to(torch.int8)
def pack_interleave(num_experts, ref_weight, ref_scale):
def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4):
n, k = ref_weight.shape[1], ref_weight.shape[2]
weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
@@ -33,11 +33,16 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
w_q = w_q.contiguous()
scale_interleaved = ref_scale.reshape(
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
ref_scale.shape[0],
ref_scale.shape[1],
(ref_scale.shape[2] // alignment),
alignment,
) # [E, N, K/4, 4]
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4]
scale_interleaved = scale_interleaved.reshape(
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
ref_scale.shape[0],
ref_scale.shape[2] // alignment,
ref_scale.shape[1] * alignment,
) # [E, K/4, N*4]
w_scale = scale_interleaved.contiguous()
@@ -48,12 +53,17 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
@pytest.mark.parametrize("N", [2048])
@pytest.mark.parametrize("K", [7168])
@pytest.mark.parametrize("E", [256])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.parametrize("tp_size", [8])
@pytest.mark.parametrize("use_ep_moe", [True, False])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
local_e = E // ep_size
def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype):
if use_ep_moe:
local_e = E // tp_size
else: # tp mode
local_e = E
N = N // tp_size
debug = False
if debug:
@@ -87,7 +97,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
)
w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
if use_ep_moe:
w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
else:
w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2, 1)
device = "cuda"
a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
@@ -265,7 +278,9 @@ def ref(
gate, fc1 = fc1.chunk(2, dim=-1)
fc1 = fc1 * torch.nn.functional.silu(gate)
act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
act = torch.clamp((fc1 / pre_quant_scale_2.float()), -448.0, 448.0).to(
torch.float8_e4m3fn
)
act = act.to(dtype)
w2 = ref_weight_2[e_idx]