Organize sampling batch info better (#1562)

This commit is contained in:
Lianmin Zheng
2024-10-03 18:29:49 -07:00
committed by GitHub
parent e0b5dbcec1
commit 32eb6e96f2
8 changed files with 43 additions and 35 deletions

View File

@@ -96,7 +96,9 @@ class Scheduler:
if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
@@ -141,9 +143,6 @@ class Scheduler:
nccl_port=port_args.nccl_ports[0],
)
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
self.pad_input_ids_func = getattr(
self.tp_worker.model_runner.model, "pad_input_ids", None
)
# Get token and memory info from the tp worker
(
@@ -154,6 +153,9 @@ class Scheduler:
self.random_seed,
) = self.tp_worker.get_token_and_memory_info()
set_random_seed(self.random_seed)
self.pad_input_ids_func = getattr(
self.tp_worker.model_runner.model, "pad_input_ids", None
)
# Print debug info
logger.info(