fix: disable dsv3_router_gemm in dsv3_nextn (#7793)
This commit is contained in:
@@ -210,8 +210,10 @@ class MoEGate(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
is_nextn: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.is_nextn = is_nextn
|
||||||
self.weight = nn.Parameter(
|
self.weight = nn.Parameter(
|
||||||
torch.empty((config.n_routed_experts, config.hidden_size))
|
torch.empty((config.n_routed_experts, config.hidden_size))
|
||||||
)
|
)
|
||||||
@@ -233,8 +235,10 @@ class MoEGate(nn.Module):
|
|||||||
True, # is_vnni
|
True, # is_vnni
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE: For some unknown reason, router_gemm seems degrade accept length.
|
||||||
if (
|
if (
|
||||||
_is_cuda
|
_is_cuda
|
||||||
|
and not self.is_nextn
|
||||||
and hidden_states.shape[0] < 4
|
and hidden_states.shape[0] < 4
|
||||||
and hidden_states.shape[1] == 7168
|
and hidden_states.shape[1] == 7168
|
||||||
and self.weight.shape[0] == 256
|
and self.weight.shape[0] == 256
|
||||||
@@ -258,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||||
|
is_nextn: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
@@ -284,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
"Only silu is supported for now."
|
"Only silu is supported for now."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
|
self.gate = MoEGate(
|
||||||
|
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||||
|
)
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class()(
|
||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
@@ -1776,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
prefix=add_prefix("mlp", prefix),
|
prefix=add_prefix("mlp", prefix),
|
||||||
layer_id=self.layer_id,
|
layer_id=self.layer_id,
|
||||||
alt_stream=alt_stream,
|
alt_stream=alt_stream,
|
||||||
|
is_nextn=is_nextn,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if enable_moe_dense_fully_dp():
|
if enable_moe_dense_fully_dp():
|
||||||
|
|||||||
Reference in New Issue
Block a user