[DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (#6853)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -556,7 +556,8 @@ class FusedMoE(torch.nn.Module):
|
||||
loaded_weight = loaded_weight.to(param.data.device)
|
||||
|
||||
if (
|
||||
param.data[expert_id] != 1
|
||||
"compressed" in self.quant_method.__class__.__name__.lower()
|
||||
and param.data[expert_id] != 1
|
||||
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -580,6 +581,23 @@ class FusedMoE(torch.nn.Module):
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
return
|
||||
if "ModelOpt" in self.quant_method.__class__.__name__:
|
||||
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
|
||||
self._load_per_tensor_weight_scale(
|
||||
shard_id=shard_id,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
elif "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
return
|
||||
|
||||
# Case weight scales and zero_points
|
||||
if "scale" in weight_name or "zero" in weight_name:
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
@@ -15,10 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
apply_fp8_linear,
|
||||
cutlass_fp8_supported,
|
||||
is_sm100_supported,
|
||||
)
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
convert_to_channelwise,
|
||||
is_layer_skipped,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
@@ -270,9 +277,16 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
)
|
||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||
if not kv_cache_quant_algo:
|
||||
kv_cache_quant_algo = "auto"
|
||||
group_size = quant_config["group_size"]
|
||||
exclude_modules = quant_config["exclude_modules"]
|
||||
if not (group_size and kv_cache_quant_algo and exclude_modules):
|
||||
logger.warning(
|
||||
f"group_size: {group_size},"
|
||||
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
|
||||
f"exclude_modules: {exclude_modules}"
|
||||
)
|
||||
raise ValueError(
|
||||
"NVFP4 quantization requires group size and "
|
||||
"kv_cache_quant_algo specified in "
|
||||
@@ -285,19 +299,30 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
exclude_modules,
|
||||
)
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||
import regex as re
|
||||
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
if self.exclude_modules and any(
|
||||
module in prefix for module in self.exclude_modules
|
||||
):
|
||||
return None
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
|
||||
prefix, self.exclude_modules
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
return ModelOptFp4LinearMethod(self)
|
||||
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptNvFp4FusedMoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
@@ -461,3 +486,305 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoEMethod:
|
||||
"""
|
||||
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
|
||||
Args:
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp4Config):
|
||||
self.quant_config = quant_config
|
||||
if not is_sm100_supported():
|
||||
raise ValueError(
|
||||
"Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above."
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
if not self.quant_config.is_checkpoint_nvfp4_serialized:
|
||||
raise ValueError(
|
||||
"NVFP4 quantization was selected, "
|
||||
" dynamic quantization is not supported."
|
||||
)
|
||||
|
||||
layer.num_experts = num_experts
|
||||
layer.params_dtype = params_dtype
|
||||
layer.quant_config = self.quant_config
|
||||
weight_dtype = torch.uint8
|
||||
weight_scale_dtype = torch.float8_e4m3fn
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
# GEMM 1
|
||||
w13_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
|
||||
# GEMM 2
|
||||
w2_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
|
||||
w13_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // self.quant_config.group_size,
|
||||
dtype=weight_scale_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // self.quant_config.group_size,
|
||||
dtype=weight_scale_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||
)
|
||||
|
||||
w13_weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, 2, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
||||
|
||||
w2_weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||
)
|
||||
|
||||
w13_input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, 2, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert scale.dtype == torch.float8_e4m3fn
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (
|
||||
swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2
|
||||
else swizzled_scale.reshape(B, M, K)
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# GEMM 1
|
||||
if not torch.allclose(
|
||||
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
|
||||
):
|
||||
logger.warning_once(
|
||||
"w1_weight_scale_2 must match w3_weight_scale_2. "
|
||||
"Accuracy may be affected."
|
||||
)
|
||||
|
||||
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
|
||||
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||
layer.g1_alphas = Parameter(
|
||||
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
assert (
|
||||
layer.w13_weight_scale.shape[2] % 16 == 0
|
||||
), "Expected weight_scale.dim(1) to be divisible by 16"
|
||||
assert (
|
||||
layer.w13_weight_scale.dtype == torch.float8_e4m3fn
|
||||
), "Weight Blockscale must be represented as FP8-E4M3"
|
||||
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
|
||||
|
||||
layer.w13_blockscale_swizzled = Parameter(
|
||||
w13_blockscale_swizzled, requires_grad=False
|
||||
)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w13_input_scale_quant = Parameter(
|
||||
(1 / w13_input_scale).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
layer.g2_alphas = Parameter(
|
||||
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w2_input_scale_quant = Parameter(
|
||||
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
assert (
|
||||
layer.w2_weight_scale.shape[2] % 16 == 0
|
||||
), "Expected weight_scale.dim(1) to be divisible by 16"
|
||||
assert (
|
||||
layer.w2_weight_scale.dtype == torch.float8_e4m3fn
|
||||
), "Weight Blockscale must be represented as FP8-E4M3"
|
||||
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(
|
||||
w2_blockscale_swizzled, requires_grad=False
|
||||
)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||
|
||||
device = layer.w13_weight.device
|
||||
layer.cutlass_moe_params = CutlassMoEParams(
|
||||
CutlassMoEType.BlockscaledFP4,
|
||||
device,
|
||||
num_experts=layer.num_experts,
|
||||
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
|
||||
hidden_size=layer.w13_weight.shape[2] * 2,
|
||||
) # k
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
|
||||
return cutlass_moe_fp4(
|
||||
a=x,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w1_alphas=layer.g1_alphas,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alphas=layer.g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
params=layer.cutlass_moe_params,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
).to(x.dtype)
|
||||
|
||||
@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
global_server_args_dict["disable_shared_experts_fusion"] = False
|
||||
log_info_on_rank0(
|
||||
logger,
|
||||
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
||||
"Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
||||
)
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self_attn.use_deep_gemm_bmm = True
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||
|
||||
if is_nextn:
|
||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||
@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
"up_proj.qzeros",
|
||||
"up_proj.scales",
|
||||
]
|
||||
elif self.quant_config.get_name() == "modelopt_fp4":
|
||||
suffix_list = [
|
||||
"down_proj.weight",
|
||||
"down_proj.weight_scale",
|
||||
"down_proj.weight_scale_2",
|
||||
"down_proj.input_scale",
|
||||
"gate_proj.weight",
|
||||
"gate_proj.weight_scale",
|
||||
"gate_proj.weight_scale_2",
|
||||
"gate_proj.input_scale",
|
||||
"up_proj.weight",
|
||||
"up_proj.weight_scale",
|
||||
"up_proj.weight_scale_2",
|
||||
"up_proj.input_scale",
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
|
||||
@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if fuse_qkv_a_proj and (
|
||||
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
||||
):
|
||||
@@ -2151,9 +2166,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
fused_weight = torch.cat(
|
||||
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
||||
)
|
||||
|
||||
param_name = name.replace(
|
||||
"q_a_proj", "fused_qkv_a_proj_with_mqa"
|
||||
param_name = (
|
||||
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
|
||||
if "q_a_proj" in name
|
||||
else name.replace(
|
||||
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
|
||||
)
|
||||
)
|
||||
param = params_dict[param_name]
|
||||
|
||||
@@ -2164,6 +2182,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
cached_a_proj.pop(q_a_proj_name)
|
||||
cached_a_proj.pop(kv_a_proj_name)
|
||||
else:
|
||||
if (
|
||||
"k_scale" in name or "v_scale" in name
|
||||
) and name not in params_dict:
|
||||
# modelopt attn kv scale is named differently
|
||||
if any(scale in name for scale in ["k_scale", "v_scale"]):
|
||||
name = name.replace("_proj", "attn_mqa")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown scale found in checkpoint: {name}"
|
||||
)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
|
||||
Reference in New Issue
Block a user