diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index b67dfc741..6b81b7ba1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -280,7 +280,8 @@ class ForwardBatch: ).to(device, non_blocking=True) if ( model_runner.server_args.attention_backend != "torch_native" - and model_runner.server_args.speculative_algorithm != "NEXTN" + # TODO: Fix triton kernel illegal memory access for EAGLE + and model_runner.server_args.speculative_algorithm != "EAGLE" ): ret.extend_num_tokens = batch.extend_num_tokens positions, ret.extend_start_loc = compute_position_triton( diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 9588eb87e..0dfa69a2e 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -116,14 +116,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): self.model = DeepseekModelNextN(config, quant_config) if global_server_args_dict["enable_dp_attention"]: - self.model.shared_head.head = ReplicatedLinear( + self.lm_head = ReplicatedLinear( config.hidden_size, config.vocab_size, bias=False, ) self.logits_processor = LogitsProcessor(config, skip_all_gather=True) else: - self.model.shared_head.head = ParallelLMHead( + self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config, @@ -139,7 +139,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.model.shared_head.head, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -168,10 +168,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): nextn_layer_prefix = "model.layers.0" nextn_spec_weight_names = [ - "shared_head.head", "shared_head.norm", "eh_proj", - "embed_tokens", "enorm", "hnorm", ] @@ -180,17 +178,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): for name, loaded_weight in weights: if not name.startswith(nextn_layer_prefix): continue - else: - is_decoder = True - # For nextn specific weights - for weight_name in nextn_spec_weight_names: - if weight_name in name: - name = name.replace(nextn_layer_prefix, "model") - is_decoder = False - break - # For decoder layer weights - if is_decoder: - name = name.replace(nextn_layer_prefix, "model.decoder") + + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + is_decoder = True + # For nextn specific weights + for weight_name in nextn_spec_weight_names: + if weight_name in name: + name = name.replace(nextn_layer_prefix, "model") + is_decoder = False + break + # For decoder layer weights + if is_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") if "rotary_emb.inv_freq" in name: continue diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c68403ea9..8d7562120 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module): if is_hip_: self_attn.w_scale *= 2.0 + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 32c3dbd7c..5833e1266 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -270,10 +270,11 @@ class ServerArgs: ) # Speculative Decoding - if ( - self.speculative_algorithm == "EAGLE" - or self.speculative_algorithm == "NEXTN" - ): + if self.speculative_algorithm == "NEXTN": + # NEXTN shares the same implementation of EAGLE + self.speculative_algorithm = "EAGLE" + + if self.speculative_algorithm == "EAGLE": self.disable_overlap_schedule = True self.prefill_only_one_req = True self.disable_cuda_graph_padding = True diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 7639bd999..810514429 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -83,23 +83,16 @@ class EAGLEWorker(TpModelWorker): self.server_args = server_args # Share the embedding and lm_head - if not self.speculative_algorithm.is_nextn(): - embed, head = self.target_worker.model_runner.model.get_embed_and_head() - if server_args.speculative_token_map is not None: - head = head.clone() - self.hot_token_id = torch.tensor( - self.hot_token_id, dtype=torch.int32, device=head.device - ) - head.data = head.data[self.hot_token_id] - else: - self.hot_token_id = None - self.model_runner.model.set_embed_and_head(embed, head) + embed, head = self.target_worker.model_runner.model.get_embed_and_head() + if server_args.speculative_token_map is not None: + head = head.clone() + self.hot_token_id = torch.tensor( + self.hot_token_id, dtype=torch.int32, device=head.device + ) + head.data = head.data[self.hot_token_id] else: - if server_args.speculative_token_map is not None: - raise NotImplementedError( - "NEXTN does not support speculative-token-map now" - ) self.hot_token_id = None + self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph # Create multi-step attn backends and cuda graph runners diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 3d3520018..af45ac423 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -5,24 +5,16 @@ class SpeculativeAlgorithm(IntEnum): NONE = auto() EAGLE = auto() - # NEXTN spec decoding is for DeepSeek V3/R1 - # currently it's implemented based on EAGLE - NEXTN = auto() - def is_none(self): return self == SpeculativeAlgorithm.NONE def is_eagle(self): - return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN - - def is_nextn(self): - return self == SpeculativeAlgorithm.NEXTN + return self == SpeculativeAlgorithm.EAGLE @staticmethod def from_string(name: str): name_map = { "EAGLE": SpeculativeAlgorithm.EAGLE, - "NEXTN": SpeculativeAlgorithm.NEXTN, None: SpeculativeAlgorithm.NONE, } if name is not None: diff --git a/scripts/export_deepseek_nextn.py b/scripts/export_deepseek_nextn.py index ad6bb0406..35a06b645 100644 --- a/scripts/export_deepseek_nextn.py +++ b/scripts/export_deepseek_nextn.py @@ -62,6 +62,8 @@ def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id): continue for key in matching_keys: + if "embed_tokens" in key or "shared_head.head" in key: + continue new_key = key.replace(prefix, "model.layers.0") params[new_key] = f.get_tensor(key)