[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)
Co-authored-by: yuhyao <827623970@qq.com>
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, EPMoE):
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return W4AFp8MoEMethod(self)
|
||||
return None
|
||||
|
||||
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
|
||||
return []
|
||||
|
||||
|
||||
class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales in groups of 4 similar to TRT-LLM implementation.

    The last dimension is split into groups of ``alignment`` consecutive
    values, where ``alignment`` is 4 when the last dimension is divisible by
    4 and 1 otherwise. The group axis is then swapped with the middle axis,
    so an input of shape ``(a, b, c)`` becomes
    ``(a, c // alignment, b * alignment)``.

    Args:
        scales: A 3-dimensional tensor of scales.

    Returns:
        A contiguous tensor of shape ``(a, c // alignment, b * alignment)``
        holding the interleaved values.

    Raises:
        ValueError: If ``scales`` is not 3-dimensional. Without this check a
            higher-rank input whose element count happens to match would be
            silently flattened into a 3-D result.
    """
    if scales.dim() != 3:
        raise ValueError(
            f"interleave_scales expects a 3-D tensor, got shape {tuple(scales.shape)}"
        )

    dim0, dim1, dim2 = scales.shape
    # Fall back to a group size of 1 when the last dimension cannot be split
    # into groups of 4; the result is then a plain swap of the last two axes.
    alignment = 4 if dim2 % 4 == 0 else 1
    num_groups = dim2 // alignment

    # Reshape to separate groups of `alignment` values.
    grouped = scales.reshape(dim0, dim1, num_groups, alignment)
    # Permute so that the group axis comes before the original middle axis.
    swapped = grouped.permute(0, 2, 1, 3)
    # Collapse back to three dimensions with the values interleaved.
    interleaved = swapped.reshape(dim0, num_groups, dim1 * alignment)
    return interleaved.contiguous()
|
||||
|
||||
|
||||
class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
def __init__(self, quant_config: W4AFp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return
|
||||
|
||||
def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
|
||||
"""Interleave scales in groups of 4 similar to TRT-LLM implementation."""
|
||||
s_shape = scales.shape
|
||||
# Reshape to separate groups of 4
|
||||
scales_interleaved = scales.reshape(
|
||||
s_shape[0], s_shape[1], (s_shape[2] // 4), 4
|
||||
)
|
||||
# Permute dimensions to interleave
|
||||
scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
|
||||
# Reshape back to original dimensions but with interleaved values
|
||||
scales_interleaved = scales_interleaved.reshape(
|
||||
s_shape[0], s_shape[2] // 4, s_shape[1] * 4
|
||||
)
|
||||
return scales_interleaved.contiguous()
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
dtype = torch.bfloat16
|
||||
device = layer.w2_weight.device
|
||||
|
||||
# Interleave w13_weight_scale (gate_up_proj)
|
||||
w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
|
||||
w13_weight_scale = self._interleave_scales(w13_weight_scale)
|
||||
w13_weight_scale = interleave_scales(w13_weight_scale)
|
||||
layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
|
||||
|
||||
# Interleave w2_weight_scale (down_proj)
|
||||
w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
|
||||
w2_weight_scale = self._interleave_scales(w2_weight_scale)
|
||||
w2_weight_scale = interleave_scales(w2_weight_scale)
|
||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
|
||||
|
||||
# Process input scales
|
||||
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
local_topk_ids = topk_ids
|
||||
local_topk_ids = torch.where(
|
||||
topk_ids == -1,
|
||||
layer.num_experts,
|
||||
topk_ids,
|
||||
)
|
||||
if get_moe_expert_parallel_world_size() > 1:
|
||||
local_topk_ids = torch.where(
|
||||
topk_ids == -1,
|
||||
layer.num_experts,
|
||||
topk_ids,
|
||||
)
|
||||
|
||||
output = cutlass_w4a8_moe(
|
||||
layer.start_expert_id,
|
||||
|
||||
Reference in New Issue
Block a user