Simplify eagle tests and TP sync in grammar backend (#4066)

Author: Lianmin Zheng
Date: 2025-03-04 13:40:40 -08:00
Committed by: GitHub
Parent: 03b0364f76
Commit: 77a3954bf7
14 changed files with 122 additions and 126 deletions


@@ -31,16 +31,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 logger = logging.getLogger(__name__)
-def load_token_map(token_map_path: str) -> List[int]:
-    if not os.path.exists(token_map_path):
-        cache_dir = snapshot_download(
-            os.path.dirname(token_map_path),
-            ignore_patterns=["*.bin", "*.safetensors"],
-        )
-        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
-    return torch.load(token_map_path)
 class EAGLEWorker(TpModelWorker):
     def __init__(
@@ -57,6 +47,7 @@ class EAGLEWorker(TpModelWorker):
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+        # Load hot token ids
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -65,6 +56,7 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None
+        # Init target worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -88,9 +80,7 @@ class EAGLEWorker(TpModelWorker):
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
         if self.hot_token_id is not None:
             head = head.clone()
-            self.hot_token_id = torch.tensor(
-                self.hot_token_id, dtype=torch.int32, device=head.device
-            )
+            self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
         self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
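With the loader now returning an int32 tensor, the worker only has to move the ids onto the head's device before pruning the lm_head rows. A minimal standalone sketch of that pruning step, with made-up sizes and ids (vocab_size, hidden_size, and the id values are illustrative, not taken from this repo):

import torch

vocab_size, hidden_size = 32000, 4096                        # hypothetical model dims
head = torch.randn(vocab_size, hidden_size)                  # full lm_head weight
hot_token_id = torch.tensor([1, 5, 42], dtype=torch.int32)   # ids kept for the draft model

# Mirror the diff: move the ids to the head's device, then keep only the hot rows.
# Recent PyTorch accepts int32 index tensors; cast to long on older versions.
hot_token_id = hot_token_id.to(head.device)
pruned_head = head[hot_token_id]                             # shape: (3, hidden_size)
assert pruned_head.shape == (3, hidden_size)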
@@ -369,3 +359,14 @@ class EAGLEWorker(TpModelWorker):
             ][:req_len]
             self.model_runner.token_to_kv_pool.free(kv_indices)
             self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path)
+    return torch.tensor(hot_token_id, dtype=torch.int32)
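For context, a rough sketch of how the relocated helper is typically exercised (the paths below are placeholders, not real files or repos): a path that exists locally is loaded directly, while a missing path is treated as a Hugging Face repo id plus filename and fetched before loading.

# Local file that exists: torch.load reads it directly.
hot_token_id = load_token_map("/data/eagle/token_map.pt")         # placeholder path

# Missing local path: the dirname is treated as a Hugging Face repo id,
# snapshot_download pulls it (skipping *.bin / *.safetensors), and the
# basename is then loaded from the downloaded snapshot.
hot_token_id = load_token_map("some-org/some-repo/token_map.pt")  # placeholder repo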