Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)
This commit is contained in:
@@ -519,6 +519,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
+ and not forward_batch.forward_mode.is_target_verify()
|
||||
+ and not forward_batch.forward_mode.is_draft_extend()
|
||||
and forward_batch.extend_prefix_lens.sum() == 0
|
||||
):
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
@@ -680,6 +682,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
+ is_nextn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -731,7 +734,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
- if (
|
||||
+ if is_nextn or (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
|
||||
Reference in New Issue
Block a user