Simplify eagle tests and TP sync in grammar backend (#4066)
@@ -1886,33 +1886,22 @@ class Scheduler:
                 break
 
         if self.server_args.enable_dp_attention:
-            if self.attn_tp_size > 1:
-                # Sync across attn TP ranks to make sure they have the same number of ready requests
-                tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
-                torch.distributed.all_reduce(
-                    tensor,
-                    op=torch.distributed.ReduceOp.MAX,
-                    group=self.attn_tp_cpu_group,
-                )
-                num_ready_reqs_max = tensor.item()
-                for i in range(num_ready_reqs, num_ready_reqs_max):
-                    self.grammar_queue[i].grammar = self.grammar_queue[
-                        i
-                    ].grammar.result()
-                num_ready_reqs = num_ready_reqs_max
+            tp_size = self.attn_tp_size
+            tp_group = self.attn_tp_cpu_group
         else:
-            if self.tp_size > 1:
-                # Sync across TP ranks to make sure they have the same number of ready requests
-                tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
-                torch.distributed.all_reduce(
-                    tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
-                )
-                num_ready_reqs_max = tensor.item()
-                for i in range(num_ready_reqs, num_ready_reqs_max):
-                    self.grammar_queue[i].grammar = self.grammar_queue[
-                        i
-                    ].grammar.result()
-                num_ready_reqs = num_ready_reqs_max
+            tp_size = self.tp_size
+            tp_group = self.tp_cpu_group
+
+        if tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
 
         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
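The two branches used to duplicate the entire sync block; after the change they only select `tp_size` and `tp_group`, and a single shared `all_reduce` with `ReduceOp.MAX` makes every rank agree on the largest ready count, so slower ranks block on `.result()` for the remaining grammars instead of diverging. Below is a minimal standalone sketch of the same pattern, assuming a hypothetical two-rank CPU setup on the `gloo` backend (the `num_ready_local` values and port are made up):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # Hypothetical two-rank setup; gloo handles CPU tensors, like the
    # scheduler's *_cpu_group in the hunk above.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend each rank polled a different number of finished grammars.
    num_ready_local = 1 if rank == 0 else 3

    # MAX all-reduce: every rank adopts the largest count, so the
    # slower rank waits for the extra grammars instead of diverging.
    tensor = torch.tensor(num_ready_local, dtype=torch.int32)
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    print(f"rank {rank}: local={num_ready_local}, agreed={tensor.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```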
@@ -31,16 +31,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 logger = logging.getLogger(__name__)
 
 
-def load_token_map(token_map_path: str) -> List[int]:
-    if not os.path.exists(token_map_path):
-        cache_dir = snapshot_download(
-            os.path.dirname(token_map_path),
-            ignore_patterns=["*.bin", "*.safetensors"],
-        )
-        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
-    return torch.load(token_map_path)
-
-
 class EAGLEWorker(TpModelWorker):
 
     def __init__(
@@ -57,6 +47,7 @@ class EAGLEWorker(TpModelWorker):
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
 
+        # Load hot token ids
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -65,6 +56,7 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None
 
+        # Init target worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -88,9 +80,7 @@ class EAGLEWorker(TpModelWorker):
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
         if self.hot_token_id is not None:
             head = head.clone()
-            self.hot_token_id = torch.tensor(
-                self.hot_token_id, dtype=torch.int32, device=head.device
-            )
+            self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
         self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
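With `load_token_map` now returning an `int32` tensor (see the last hunk of this file below), the worker no longer rebuilds the tensor here; it only moves the indices to the head's device before row-indexing the cloned draft head down to the hot vocabulary. A toy illustration of that pruning, with made-up dimensions:

```python
import torch

# Made-up sizes: a 10-token vocabulary, hidden size 4, and a hot
# subset of 3 token ids.
head = torch.nn.Parameter(torch.randn(10, 4))
hot_token_id = torch.tensor([2, 5, 7], dtype=torch.int32)

# Same pattern as the diff: move the indices to the head's device,
# then keep only the hot rows (recent PyTorch accepts int32 indices).
hot_token_id = hot_token_id.to(head.device)
head.data = head.data[hot_token_id]
print(head.shape)  # torch.Size([3, 4]) -- the draft head now scores 3 logits
```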
@@ -369,3 +359,14 @@ class EAGLEWorker(TpModelWorker):
             ][:req_len]
             self.model_runner.token_to_kv_pool.free(kv_indices)
             self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path)
+    return torch.tensor(hot_token_id, dtype=torch.int32)
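The relocated `load_token_map` also changes its contract: when the path is missing locally it is treated as a Hugging Face repo path and fetched via `snapshot_download`, and the loaded object is normalized to an `int32` tensor, which is what lets the worker call `.to(head.device)` directly. A small round-trip sketch of the local-file path, using a made-up file name:

```python
import torch

# Save a token map the way a list-based file would store it.
token_ids = [5, 42, 7, 1000]
torch.save(token_ids, "/tmp/hot_token_map.pt")

# Load and normalize exactly as the new function does.
hot_token_id = torch.tensor(torch.load("/tmp/hot_token_map.pt"), dtype=torch.int32)
print(hot_token_id.dtype)  # torch.int32
```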
@@ -501,6 +501,7 @@ def get_benchmark_args(
     request_rate=float("inf"),
     disable_stream=False,
     disable_ignore_eos=False,
+    seed: int = 0,
     pd_seperated: bool = False,
 ):
     return SimpleNamespace(
@@ -524,7 +525,7 @@ def get_benchmark_args(
         disable_tqdm=False,
         disable_stream=disable_stream,
         return_logprob=False,
-        seed=0,
+        seed=seed,
         disable_ignore_eos=disable_ignore_eos,
         extra_request_body=None,
         apply_chat_template=False,
@@ -549,6 +550,7 @@ def run_bench_serving(
     disable_stream=False,
     disable_ignore_eos=False,
     need_warmup=False,
+    seed: int = 0,
 ):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
@@ -572,6 +574,7 @@ def run_bench_serving(
         request_rate=request_rate,
         disable_stream=disable_stream,
         disable_ignore_eos=disable_ignore_eos,
+        seed=seed,
     )
 
     try:
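The last four hunks thread a `seed` argument from `run_bench_serving` into `get_benchmark_args`, replacing the hard-coded `seed=0` so tests can vary the benchmark seed per run while keeping 0 as the default. A simplified sketch of the plumbing (stand-in functions, not the real test helpers):

```python
from types import SimpleNamespace


# Stand-ins for the real helpers: the point is that the caller's seed
# reaches the args namespace instead of a hard-coded 0.
def get_benchmark_args(seed: int = 0) -> SimpleNamespace:
    return SimpleNamespace(seed=seed)


def run_bench_serving(seed: int = 0) -> SimpleNamespace:
    return get_benchmark_args(seed=seed)


assert run_bench_serving().seed == 0              # default behavior unchanged
assert run_bench_serving(seed=1234).seed == 1234  # seed is now overridable
```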