Support casting bf16 NextN moe to fp8 (#11613)
This commit is contained in:
@@ -25,13 +25,18 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
|
|||||||
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
|
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.quantization import Fp8Config
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
from sglang.srt.models.deepseek_v2 import (
|
||||||
|
DeepseekV2DecoderLayer,
|
||||||
|
DeepseekV3ForCausalLM,
|
||||||
|
enable_nextn_moe_bf16_cast_to_fp8,
|
||||||
|
)
|
||||||
from sglang.srt.server_args import get_global_server_args
|
from sglang.srt.server_args import get_global_server_args
|
||||||
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
|
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
|
||||||
|
|
||||||
@@ -49,6 +54,16 @@ class DeepseekModelNextN(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
|
||||||
|
# refer to real DeepSeek V3 quant config
|
||||||
|
moe_quant_config = Fp8Config(
|
||||||
|
is_checkpoint_fp8_serialized=True,
|
||||||
|
weight_block_size=[128, 128],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
moe_quant_config = None
|
||||||
|
|
||||||
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
|
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
|
"Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
|
||||||
@@ -74,6 +89,7 @@ class DeepseekModelNextN(nn.Module):
|
|||||||
config,
|
config,
|
||||||
0,
|
0,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
moe_quant_config=moe_quant_config,
|
||||||
is_nextn=True,
|
is_nextn=True,
|
||||||
prefix=add_prefix("decoder", prefix),
|
prefix=add_prefix("decoder", prefix),
|
||||||
alt_stream=self.alt_stream,
|
alt_stream=self.alt_stream,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from tqdm import tqdm, trange
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from sglang.srt import single_batch_overlap
|
from sglang.srt import single_batch_overlap
|
||||||
@@ -82,7 +83,7 @@ from sglang.srt.layers.moe import (
|
|||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
|
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import Fp8Config, deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
is_fp8_fnuz,
|
is_fp8_fnuz,
|
||||||
@@ -196,6 +197,15 @@ _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
|
||||||
|
return (
|
||||||
|
quant_config is not None
|
||||||
|
and quant_config.get_name() == "modelopt_fp4"
|
||||||
|
and get_moe_a2a_backend().is_deepep()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
|
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
|
||||||
"fa3",
|
"fa3",
|
||||||
"nsa",
|
"nsa",
|
||||||
@@ -526,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.alt_stream = alt_stream
|
self.alt_stream = alt_stream
|
||||||
|
self.is_nextn = is_nextn
|
||||||
|
|
||||||
if self.tp_size > config.n_routed_experts:
|
if self.tp_size > config.n_routed_experts:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -2381,6 +2392,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
moe_quant_config: Optional[QuantizationConfig] = None,
|
||||||
is_nextn: bool = False,
|
is_nextn: bool = False,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||||
@@ -2430,7 +2442,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
if self.is_layer_sparse:
|
if self.is_layer_sparse:
|
||||||
self.mlp = DeepseekV2MoE(
|
self.mlp = DeepseekV2MoE(
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=moe_quant_config or quant_config,
|
||||||
prefix=add_prefix("mlp", prefix),
|
prefix=add_prefix("mlp", prefix),
|
||||||
layer_id=self.layer_id,
|
layer_id=self.layer_id,
|
||||||
alt_stream=alt_stream,
|
alt_stream=alt_stream,
|
||||||
@@ -3109,6 +3121,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
):
|
):
|
||||||
self._weight_requant_ue8m0(is_nextn)
|
self._weight_requant_ue8m0(is_nextn)
|
||||||
|
|
||||||
|
if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
|
||||||
|
self._transform_scale_nextn_moe_ue8m0()
|
||||||
|
|
||||||
def _weight_requant_ue8m0(self, is_nextn=False):
|
def _weight_requant_ue8m0(self, is_nextn=False):
|
||||||
weight_block_size = self.quant_config.weight_block_size
|
weight_block_size = self.quant_config.weight_block_size
|
||||||
|
|
||||||
@@ -3174,6 +3189,28 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
module.weight, module.weight_scale_inv, weight_block_size
|
module.weight, module.weight_scale_inv, weight_block_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
|
||||||
|
def _transform_scale_nextn_moe_ue8m0(self):
|
||||||
|
layer = self.model.decoder
|
||||||
|
|
||||||
|
shared_experts = getattr(layer.mlp, "shared_experts", None)
|
||||||
|
if shared_experts is not None:
|
||||||
|
for module in [
|
||||||
|
shared_experts.gate_up_proj,
|
||||||
|
shared_experts.down_proj,
|
||||||
|
]:
|
||||||
|
transform_scale_ue8m0_inplace(
|
||||||
|
module.weight_scale_inv, mn=module.weight.shape[-2]
|
||||||
|
)
|
||||||
|
|
||||||
|
experts = layer.mlp.experts
|
||||||
|
if isinstance(experts, DeepEPMoE):
|
||||||
|
for w in [
|
||||||
|
experts.w13_weight_fp8,
|
||||||
|
experts.w2_weight_fp8,
|
||||||
|
]:
|
||||||
|
transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||||
|
|
||||||
if is_nextn:
|
if is_nextn:
|
||||||
@@ -3189,6 +3226,11 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("num_nextn_predict_layers is not in the config")
|
raise ValueError("num_nextn_predict_layers is not in the config")
|
||||||
|
|
||||||
|
if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
|
||||||
|
weights = self._quant_nextn_moe_to_fp8_ue8m0(
|
||||||
|
weights, nextn_layer_id=nextn_layer_id
|
||||||
|
)
|
||||||
|
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
@@ -3418,6 +3460,38 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
|
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
|
||||||
|
|
||||||
|
# TODO avoid code dup
|
||||||
|
def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
|
||||||
|
weights_dict = dict(weights)
|
||||||
|
|
||||||
|
# temporarily only support DeepSeek V3/R1
|
||||||
|
weight_block_size = [128, 128]
|
||||||
|
|
||||||
|
for layer_id in [nextn_layer_id]:
|
||||||
|
for expert_sub_name in [
|
||||||
|
"shared_experts",
|
||||||
|
*[
|
||||||
|
f"experts.{expert_id}"
|
||||||
|
for expert_id in range(self.config.n_routed_experts)
|
||||||
|
],
|
||||||
|
]:
|
||||||
|
for stem in [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
"down_proj",
|
||||||
|
]:
|
||||||
|
partial_name = (
|
||||||
|
f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
|
||||||
|
)
|
||||||
|
original_weight = weights_dict[f"{partial_name}.weight"]
|
||||||
|
out_w, out_s = quant_weight_ue8m0(
|
||||||
|
original_weight, weight_block_size=weight_block_size
|
||||||
|
)
|
||||||
|
weights_dict[f"{partial_name}.weight"] = out_w
|
||||||
|
weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
|
||||||
|
|
||||||
|
return list(weights_dict.items())
|
||||||
|
|
||||||
def get_embed_and_head(self):
|
def get_embed_and_head(self):
|
||||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user