fix: disable dsv3_router_gemm in dsv3_nextn (#7793)
This commit is contained in:
@@ -210,8 +210,10 @@ class MoEGate(nn.Module):
|
||||
self,
|
||||
config,
|
||||
prefix: str = "",
|
||||
is_nextn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.is_nextn = is_nextn
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((config.n_routed_experts, config.hidden_size))
|
||||
)
|
||||
@@ -233,8 +235,10 @@ class MoEGate(nn.Module):
|
||||
True, # is_vnni
|
||||
)
|
||||
|
||||
# NOTE: For some unknown reason, router_gemm seems degrade accept length.
|
||||
if (
|
||||
_is_cuda
|
||||
and not self.is_nextn
|
||||
and hidden_states.shape[0] < 4
|
||||
and hidden_states.shape[1] == 7168
|
||||
and self.weight.shape[0] == 256
|
||||
@@ -258,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
is_nextn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -284,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
|
||||
self.gate = MoEGate(
|
||||
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||
)
|
||||
|
||||
self.experts = get_moe_impl_class()(
|
||||
num_experts=config.n_routed_experts
|
||||
@@ -1776,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
layer_id=self.layer_id,
|
||||
alt_stream=alt_stream,
|
||||
is_nextn=is_nextn,
|
||||
)
|
||||
else:
|
||||
if enable_moe_dense_fully_dp():
|
||||
|
||||
Reference in New Issue
Block a user