Share target model embed and head weights for nextn (#4033)

Ke Bao
2025-03-04 05:30:04 +08:00
committed by GitHub
parent 146ac8df07
commit 9fafa62db7
7 changed files with 47 additions and 45 deletions


@@ -83,23 +83,16 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         # Share the embedding and lm_head
-        if not self.speculative_algorithm.is_nextn():
-            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-            if server_args.speculative_token_map is not None:
-                head = head.clone()
-                self.hot_token_id = torch.tensor(
-                    self.hot_token_id, dtype=torch.int32, device=head.device
-                )
-                head.data = head.data[self.hot_token_id]
-            else:
-                self.hot_token_id = None
-            self.model_runner.model.set_embed_and_head(embed, head)
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if server_args.speculative_token_map is not None:
+            head = head.clone()
+            self.hot_token_id = torch.tensor(
+                self.hot_token_id, dtype=torch.int32, device=head.device
+            )
+            head.data = head.data[self.hot_token_id]
         else:
-            if server_args.speculative_token_map is not None:
-                raise NotImplementedError(
-                    "NEXTN does not support speculative-token-map now"
-                )
             self.hot_token_id = None
+        self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
         # Create multi-step attn backends and cuda graph runners
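
For context, a minimal sketch of the weight-sharing pattern this diff extends to nextn. This is not the sglang implementation; the tensor sizes and variable names below are illustrative. The draft model's embedding and output head point at the target model's parameters, and, when a speculative token map is used, the head is cloned and sliced down to the hot-token rows (analogous to head.data[self.hot_token_id] above).

    import torch
    import torch.nn as nn

    # Target model weights (sizes are illustrative).
    target_embed = nn.Embedding(32000, 4096)
    target_head = nn.Linear(4096, 32000, bias=False)

    # Draft (nextn) model reuses the same Parameters instead of keeping its
    # own copies, so no extra GPU memory is spent on embedding and lm_head.
    draft_embed = nn.Embedding(32000, 4096)
    draft_head = nn.Linear(4096, 32000, bias=False)
    draft_embed.weight = target_embed.weight
    draft_head.weight = target_head.weight

    # With a speculative token map, clone the head first and keep only a
    # subset of "hot" vocabulary rows.
    hot_token_id = torch.tensor([17, 523, 9042], dtype=torch.long)
    restricted_head = target_head.weight.data.clone()[hot_token_id]
    print(restricted_head.shape)  # torch.Size([3, 4096])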