[DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (#6853)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -556,7 +556,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
loaded_weight = loaded_weight.to(param.data.device)
|
loaded_weight = loaded_weight.to(param.data.device)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
param.data[expert_id] != 1
|
"compressed" in self.quant_method.__class__.__name__.lower()
|
||||||
|
and param.data[expert_id] != 1
|
||||||
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
|
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -580,6 +581,23 @@ class FusedMoE(torch.nn.Module):
|
|||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
if "ModelOpt" in self.quant_method.__class__.__name__:
|
||||||
|
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
|
||||||
|
self._load_per_tensor_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_id=expert_id,
|
||||||
|
)
|
||||||
|
elif "weight" in weight_name:
|
||||||
|
self._load_model_weight_or_group_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Case weight scales and zero_points
|
# Case weight scales and zero_points
|
||||||
if "scale" in weight_name or "zero" in weight_name:
|
if "scale" in weight_name or "zero" in weight_name:
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
from sglang.srt.layers.linear import (
|
||||||
|
LinearBase,
|
||||||
|
LinearMethodBase,
|
||||||
|
UnquantizedLinearMethod,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -15,10 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
|
is_sm100_supported,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
|
is_layer_skipped,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -270,9 +277,16 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
)
|
)
|
||||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||||
|
if not kv_cache_quant_algo:
|
||||||
|
kv_cache_quant_algo = "auto"
|
||||||
group_size = quant_config["group_size"]
|
group_size = quant_config["group_size"]
|
||||||
exclude_modules = quant_config["exclude_modules"]
|
exclude_modules = quant_config["exclude_modules"]
|
||||||
if not (group_size and kv_cache_quant_algo and exclude_modules):
|
if not (group_size and kv_cache_quant_algo and exclude_modules):
|
||||||
|
logger.warning(
|
||||||
|
f"group_size: {group_size},"
|
||||||
|
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
|
||||||
|
f"exclude_modules: {exclude_modules}"
|
||||||
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"NVFP4 quantization requires group size and "
|
"NVFP4 quantization requires group size and "
|
||||||
"kv_cache_quant_algo specified in "
|
"kv_cache_quant_algo specified in "
|
||||||
@@ -285,19 +299,30 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
exclude_modules,
|
exclude_modules,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||||
|
import regex as re
|
||||||
|
|
||||||
|
for pattern in exclude_modules:
|
||||||
|
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||||
|
if re.fullmatch(regex_str, prefix):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
if self.exclude_modules and any(
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
module in prefix for module in self.exclude_modules
|
|
||||||
):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
|
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
|
||||||
|
prefix, self.exclude_modules
|
||||||
|
):
|
||||||
|
return UnquantizedLinearMethod()
|
||||||
return ModelOptFp4LinearMethod(self)
|
return ModelOptFp4LinearMethod(self)
|
||||||
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
|
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
|
||||||
return ModelOptFp8KVCacheMethod(self)
|
return ModelOptFp8KVCacheMethod(self)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return ModelOptNvFp4FusedMoEMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
@@ -461,3 +486,305 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
|||||||
if bias is not None:
|
if bias is not None:
|
||||||
out = out + bias
|
out = out + bias
|
||||||
return out.view(*output_shape)
|
return out.view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOptNvFp4FusedMoEMethod:
|
||||||
|
"""
|
||||||
|
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
|
||||||
|
Args:
|
||||||
|
quant_config: NVFP4 Quant Config
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||||
|
|
||||||
|
if not hasattr(cls, "_initialized"):
|
||||||
|
original_init = cls.__init__
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(FusedMoEMethodBase,),
|
||||||
|
{
|
||||||
|
"__init__": original_init,
|
||||||
|
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||||
|
obj.__init__(*args, **kwargs)
|
||||||
|
return obj
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, quant_config: ModelOptFp4Config):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
if not is_sm100_supported():
|
||||||
|
raise ValueError(
|
||||||
|
"Current platform does not support NVFP4"
|
||||||
|
" quantization. Please use Blackwell and"
|
||||||
|
" above."
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
if not self.quant_config.is_checkpoint_nvfp4_serialized:
|
||||||
|
raise ValueError(
|
||||||
|
"NVFP4 quantization was selected, "
|
||||||
|
" dynamic quantization is not supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.num_experts = num_experts
|
||||||
|
layer.params_dtype = params_dtype
|
||||||
|
layer.quant_config = self.quant_config
|
||||||
|
weight_dtype = torch.uint8
|
||||||
|
weight_scale_dtype = torch.float8_e4m3fn
|
||||||
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
# GEMM 1
|
||||||
|
w13_weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
hidden_size // 2,
|
||||||
|
dtype=weight_dtype,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=2,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
|
||||||
|
# GEMM 2
|
||||||
|
w2_weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
intermediate_size_per_partition // 2,
|
||||||
|
dtype=weight_dtype,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=2,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
|
||||||
|
w13_weight_scale = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
hidden_size // self.quant_config.group_size,
|
||||||
|
dtype=weight_scale_dtype,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=2,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
|
||||||
|
w2_weight_scale = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
intermediate_size_per_partition // self.quant_config.group_size,
|
||||||
|
dtype=weight_scale_dtype,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=2,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||||
|
)
|
||||||
|
|
||||||
|
w13_weight_scale_2 = PerTensorScaleParameter(
|
||||||
|
data=torch.empty(num_experts, 2, dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
||||||
|
|
||||||
|
w2_weight_scale_2 = PerTensorScaleParameter(
|
||||||
|
data=torch.empty(num_experts, dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
|
||||||
|
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||||
|
)
|
||||||
|
|
||||||
|
w13_input_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.empty(num_experts, 2, dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
|
|
||||||
|
w2_input_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.empty(num_experts, dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
|
||||||
|
def swizzle_blockscale(self, scale: torch.tensor):
|
||||||
|
assert scale.dtype == torch.float8_e4m3fn
|
||||||
|
# Pad and blockwise interleave weight_scale
|
||||||
|
scale_ndim = scale.ndim
|
||||||
|
if scale.ndim == 2:
|
||||||
|
scale = scale.unsqueeze(0)
|
||||||
|
assert scale.ndim == 3
|
||||||
|
B, M, K = scale.shape
|
||||||
|
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||||
|
M_padded = round_up_multiple(M, 128)
|
||||||
|
K_padded = round_up_multiple(K, 4)
|
||||||
|
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||||
|
padded_scale[:B, :M, :K] = scale
|
||||||
|
batches, rows, cols = padded_scale.shape
|
||||||
|
assert rows % 128 == 0
|
||||||
|
assert cols % 4 == 0
|
||||||
|
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
|
||||||
|
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||||
|
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||||
|
return (
|
||||||
|
swizzled_scale.reshape(M, K)
|
||||||
|
if scale_ndim == 2
|
||||||
|
else swizzled_scale.reshape(B, M, K)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
|
||||||
|
# GEMM 1
|
||||||
|
if not torch.allclose(
|
||||||
|
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
|
||||||
|
):
|
||||||
|
logger.warning_once(
|
||||||
|
"w1_weight_scale_2 must match w3_weight_scale_2. "
|
||||||
|
"Accuracy may be affected."
|
||||||
|
)
|
||||||
|
|
||||||
|
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||||
|
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
|
||||||
|
|
||||||
|
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
|
||||||
|
layer.g1_alphas = Parameter(
|
||||||
|
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
layer.w13_weight_scale.shape[2] % 16 == 0
|
||||||
|
), "Expected weight_scale.dim(1) to be divisible by 16"
|
||||||
|
assert (
|
||||||
|
layer.w13_weight_scale.dtype == torch.float8_e4m3fn
|
||||||
|
), "Weight Blockscale must be represented as FP8-E4M3"
|
||||||
|
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
|
||||||
|
|
||||||
|
layer.w13_blockscale_swizzled = Parameter(
|
||||||
|
w13_blockscale_swizzled, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# This is for quantization, so we need to invert it.
|
||||||
|
layer.w13_input_scale_quant = Parameter(
|
||||||
|
(1 / w13_input_scale).to(torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||||
|
|
||||||
|
# GEMM 2
|
||||||
|
layer.g2_alphas = Parameter(
|
||||||
|
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This is for quantization, so we need to invert it.
|
||||||
|
layer.w2_input_scale_quant = Parameter(
|
||||||
|
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
layer.w2_weight_scale.shape[2] % 16 == 0
|
||||||
|
), "Expected weight_scale.dim(1) to be divisible by 16"
|
||||||
|
assert (
|
||||||
|
layer.w2_weight_scale.dtype == torch.float8_e4m3fn
|
||||||
|
), "Weight Blockscale must be represented as FP8-E4M3"
|
||||||
|
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
|
||||||
|
|
||||||
|
layer.w2_blockscale_swizzled = Parameter(
|
||||||
|
w2_blockscale_swizzled, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||||
|
|
||||||
|
device = layer.w13_weight.device
|
||||||
|
layer.cutlass_moe_params = CutlassMoEParams(
|
||||||
|
CutlassMoEType.BlockscaledFP4,
|
||||||
|
device,
|
||||||
|
num_experts=layer.num_experts,
|
||||||
|
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
|
||||||
|
hidden_size=layer.w13_weight.shape[2] * 2,
|
||||||
|
) # k
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||||
|
|
||||||
|
return cutlass_moe_fp4(
|
||||||
|
a=x,
|
||||||
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
|
w1_fp4=layer.w13_weight,
|
||||||
|
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||||
|
w1_alphas=layer.g1_alphas,
|
||||||
|
a2_gscale=layer.w2_input_scale_quant,
|
||||||
|
w2_fp4=layer.w2_weight,
|
||||||
|
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||||
|
w2_alphas=layer.g2_alphas,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
params=layer.cutlass_moe_params,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
).to(x.dtype)
|
||||||
|
|||||||
@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
global_server_args_dict["disable_shared_experts_fusion"] = False
|
global_server_args_dict["disable_shared_experts_fusion"] = False
|
||||||
log_info_on_rank0(
|
log_info_on_rank0(
|
||||||
logger,
|
logger,
|
||||||
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
"Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self_attn.use_deep_gemm_bmm = True
|
self_attn.use_deep_gemm_bmm = True
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
||||||
|
|
||||||
if is_nextn:
|
if is_nextn:
|
||||||
if hasattr(self.config, "num_nextn_predict_layers"):
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
||||||
num_nextn_layers = self.config.num_nextn_predict_layers
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
||||||
@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
"up_proj.qzeros",
|
"up_proj.qzeros",
|
||||||
"up_proj.scales",
|
"up_proj.scales",
|
||||||
]
|
]
|
||||||
|
elif self.quant_config.get_name() == "modelopt_fp4":
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.weight",
|
||||||
|
"down_proj.weight_scale",
|
||||||
|
"down_proj.weight_scale_2",
|
||||||
|
"down_proj.input_scale",
|
||||||
|
"gate_proj.weight",
|
||||||
|
"gate_proj.weight_scale",
|
||||||
|
"gate_proj.weight_scale_2",
|
||||||
|
"gate_proj.input_scale",
|
||||||
|
"up_proj.weight",
|
||||||
|
"up_proj.weight_scale",
|
||||||
|
"up_proj.weight_scale_2",
|
||||||
|
"up_proj.input_scale",
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
|
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
|
||||||
@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if fuse_qkv_a_proj and (
|
if fuse_qkv_a_proj and (
|
||||||
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
||||||
):
|
):
|
||||||
@@ -2151,9 +2166,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
fused_weight = torch.cat(
|
fused_weight = torch.cat(
|
||||||
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
||||||
)
|
)
|
||||||
|
param_name = (
|
||||||
param_name = name.replace(
|
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
|
||||||
"q_a_proj", "fused_qkv_a_proj_with_mqa"
|
if "q_a_proj" in name
|
||||||
|
else name.replace(
|
||||||
|
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
param = params_dict[param_name]
|
param = params_dict[param_name]
|
||||||
|
|
||||||
@@ -2164,6 +2182,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
cached_a_proj.pop(q_a_proj_name)
|
cached_a_proj.pop(q_a_proj_name)
|
||||||
cached_a_proj.pop(kv_a_proj_name)
|
cached_a_proj.pop(kv_a_proj_name)
|
||||||
else:
|
else:
|
||||||
|
if (
|
||||||
|
"k_scale" in name or "v_scale" in name
|
||||||
|
) and name not in params_dict:
|
||||||
|
# modelopt attn kv scale is named differently
|
||||||
|
if any(scale in name for scale in ["k_scale", "v_scale"]):
|
||||||
|
name = name.replace("_proj", "attn_mqa")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Unknown scale found in checkpoint: {name}"
|
||||||
|
)
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(
|
weight_loader = getattr(
|
||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
|
|||||||
Reference in New Issue
Block a user