Share target model embed and head weights for nextn (#4033)
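The change below drops the `is_nextn()` gate, so the NEXTN draft model also reuses the target model's embedding and lm_head weights rather than keeping its own copies. A minimal sketch of that sharing pattern, assuming illustrative TargetModel/DraftModel classes that stand in for the real SGLang model classes (only get_embed_and_head / set_embed_and_head mirror the code in the diff):

# Minimal sketch of the embed/lm_head weight-sharing pattern; the classes and
# sizes here are stand-ins, not the actual SGLang implementations.
import torch
import torch.nn as nn

class TargetModel(nn.Module):
    def __init__(self, vocab_size=32000, hidden_size=4096):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def get_embed_and_head(self):
        # Hand out the weight tensors so another model can reuse them.
        return self.embed_tokens.weight, self.lm_head.weight

class DraftModel(nn.Module):
    def __init__(self, vocab_size=32000, hidden_size=4096):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def set_embed_and_head(self, embed, head):
        # Point the draft model's parameters at the target model's tensors,
        # so no second copy of these large weights is kept in memory.
        self.embed_tokens.weight = embed
        self.lm_head.weight = head

target, draft = TargetModel(), DraftModel()
embed, head = target.get_embed_and_head()
draft.set_embed_and_head(embed, head)
# Both models now reference the same underlying storage.
assert draft.lm_head.weight.data_ptr() == target.lm_head.weight.data_ptr()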
@@ -83,23 +83,16 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
 
         # Share the embedding and lm_head
-        if not self.speculative_algorithm.is_nextn():
-            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-            if server_args.speculative_token_map is not None:
-                head = head.clone()
-                self.hot_token_id = torch.tensor(
-                    self.hot_token_id, dtype=torch.int32, device=head.device
-                )
-                head.data = head.data[self.hot_token_id]
-            else:
-                self.hot_token_id = None
-            self.model_runner.model.set_embed_and_head(embed, head)
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if server_args.speculative_token_map is not None:
+            head = head.clone()
+            self.hot_token_id = torch.tensor(
+                self.hot_token_id, dtype=torch.int32, device=head.device
+            )
+            head.data = head.data[self.hot_token_id]
         else:
-            if server_args.speculative_token_map is not None:
-                raise NotImplementedError(
-                    "NEXTN does not support speculative-token-map now"
-                )
             self.hot_token_id = None
+        self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
 
         # Create multi-step attn backends and cuda graph runners