[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)

Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
chenxj
2025-09-02 13:17:26 +08:00
committed by GitHub
parent 21e1bc475c
commit d4a938417d
11 changed files with 291 additions and 120 deletions

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, EPMoE):
elif isinstance(layer, FusedMoE):
return W4AFp8MoEMethod(self)
return None
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
return []
class W4AFp8MoEMethod(FusedMoEMethodBase):
def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
"""Interleave scales in groups of 4 similar to TRT-LLM implementation."""
s_shape = scales.shape
# Reshape to separate groups of 4
alignment = 4 if s_shape[2] % 4 == 0 else 1
scales_interleaved = scales.reshape(
s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
)
# Permute dimensions to interleave
scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
# Reshape back to original dimensions but with interleaved values
scales_interleaved = scales_interleaved.reshape(
s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
)
return scales_interleaved.contiguous()
class W4AFp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: W4AFp8Config):
self.quant_config = quant_config
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
return
def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
"""Interleave scales in groups of 4 similar to TRT-LLM implementation."""
s_shape = scales.shape
# Reshape to separate groups of 4
scales_interleaved = scales.reshape(
s_shape[0], s_shape[1], (s_shape[2] // 4), 4
)
# Permute dimensions to interleave
scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
# Reshape back to original dimensions but with interleaved values
scales_interleaved = scales_interleaved.reshape(
s_shape[0], s_shape[2] // 4, s_shape[1] * 4
)
return scales_interleaved.contiguous()
def process_weights_after_loading(self, layer: Module) -> None:
dtype = torch.bfloat16
device = layer.w2_weight.device
# Interleave w13_weight_scale (gate_up_proj)
w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
w13_weight_scale = self._interleave_scales(w13_weight_scale)
w13_weight_scale = interleave_scales(w13_weight_scale)
layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
# Interleave w2_weight_scale (down_proj)
w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
w2_weight_scale = self._interleave_scales(w2_weight_scale)
w2_weight_scale = interleave_scales(w2_weight_scale)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
# Process input scales
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
if get_moe_expert_parallel_world_size() > 1:
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
output = cutlass_w4a8_moe(
layer.start_expert_id,