[Feature] Add test for speculative_token_map (#4016)
This commit is contained in:
@@ -31,6 +31,16 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_token_map(token_map_path: str) -> List[int]:
    """Load a speculative-decoding token map from disk or a remote repo.

    If ``token_map_path`` exists on the local filesystem it is loaded
    directly. Otherwise its directory part is passed to
    ``snapshot_download`` as the repository to fetch (presumably a
    Hugging Face Hub repo id -- confirm against the file's imports), and
    the file is looked up by its base name inside the downloaded snapshot.

    Args:
        token_map_path: Local file path, or ``<repo>/<file name>`` for a
            token map stored in a remote repository.

    Returns:
        The object stored in the file, as deserialized by ``torch.load``.
        Callers take ``len()`` of it and use it to build an index tensor.
    """
    if not os.path.exists(token_map_path):
        # Only the small token-map file is needed, so skip weight files.
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))

    # SECURITY: the file may have just been downloaded from a remote
    # repository. A token map holds only ints/tensors, so restrict
    # unpickling to plain data instead of running the full pickle machinery,
    # which can execute arbitrary code.
    return torch.load(token_map_path, weights_only=True)
|
||||
|
||||
|
||||
class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def __init__(
|
||||
@@ -48,20 +58,12 @@ class EAGLEWorker(TpModelWorker):
|
||||
server_args.disable_cuda_graph = True
|
||||
|
||||
if server_args.speculative_token_map is not None:
|
||||
if os.path.exists(server_args.speculative_token_map):
|
||||
self.hot_token_id = torch.load(server_args.speculative_token_map)
|
||||
else:
|
||||
cache_dir = snapshot_download(
|
||||
os.path.dirname(server_args.speculative_token_map),
|
||||
ignore_patterns=["*.bin", "*.safetensors"],
|
||||
)
|
||||
file_path = os.path.join(
|
||||
cache_dir, os.path.basename(server_args.speculative_token_map)
|
||||
)
|
||||
self.hot_token_id = torch.load(file_path)
|
||||
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
||||
server_args.json_model_override_args = (
|
||||
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
|
||||
)
|
||||
else:
|
||||
self.hot_token_id = None
|
||||
|
||||
super().__init__(
|
||||
gpu_id=gpu_id,
|
||||
@@ -84,14 +86,12 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Share the embedding and lm_head
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
if server_args.speculative_token_map is not None:
|
||||
if self.hot_token_id is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = torch.tensor(
|
||||
self.hot_token_id, dtype=torch.int32, device=head.device
|
||||
)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
else:
|
||||
self.hot_token_id = None
|
||||
self.model_runner.model.set_embed_and_head(embed, head)
|
||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||
|
||||
|
||||
Reference in New Issue
Block a user