[Feature] Add test for speculative_token_map (#4016)

This commit is contained in:
William
2025-03-04 20:26:24 +08:00
committed by GitHub
parent 926f8efc0c
commit 0d4e3228cf
2 changed files with 75 additions and 14 deletions

View File

@@ -31,6 +31,16 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
logger = logging.getLogger(__name__)
def load_token_map(token_map_path: str) -> List[int]:
    """Load a token map with ``torch.load``, fetching it first if it is not on disk.

    When ``token_map_path`` exists locally it is loaded directly. Otherwise the
    directory part of the path is handed to ``snapshot_download`` as its first
    argument (presumably a hub repository id -- confirm against the import),
    skipping ``*.bin`` and ``*.safetensors`` files, and the file with the same
    base name is loaded from the directory that call returns.

    Args:
        token_map_path: Local file path, or ``<repo>/<file name>`` style
            reference whose directory part identifies what to download.

    Returns:
        Whatever ``torch.load`` deserializes from the file; annotated as a
        list of token ids.
    """
    # NOTE(review): torch.load unpickles the file, which may have just been
    # downloaded; consider an explicit ``weights_only=True`` if the stored
    # object permits it.
    if os.path.exists(token_map_path):
        return torch.load(token_map_path)

    # Not available locally: split into the download source and the file name.
    source, file_name = os.path.split(token_map_path)
    snapshot_dir = snapshot_download(
        source,
        ignore_patterns=["*.bin", "*.safetensors"],
    )
    return torch.load(os.path.join(snapshot_dir, file_name))
class EAGLEWorker(TpModelWorker):
def __init__(
@@ -48,20 +58,12 @@ class EAGLEWorker(TpModelWorker):
server_args.disable_cuda_graph = True
if server_args.speculative_token_map is not None:
if os.path.exists(server_args.speculative_token_map):
self.hot_token_id = torch.load(server_args.speculative_token_map)
else:
cache_dir = snapshot_download(
os.path.dirname(server_args.speculative_token_map),
ignore_patterns=["*.bin", "*.safetensors"],
)
file_path = os.path.join(
cache_dir, os.path.basename(server_args.speculative_token_map)
)
self.hot_token_id = torch.load(file_path)
self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)
else:
self.hot_token_id = None
super().__init__(
gpu_id=gpu_id,
@@ -84,14 +86,12 @@ class EAGLEWorker(TpModelWorker):
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if server_args.speculative_token_map is not None:
if self.hot_token_id is not None:
head = head.clone()
self.hot_token_id = torch.tensor(
self.hot_token_id, dtype=torch.int32, device=head.device
)
head.data = head.data[self.hot_token_id]
else:
self.hot_token_id = None
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph