Organize sampling batch info better (#1562)
This commit is contained in:
@@ -96,7 +96,9 @@ class Scheduler:
|
||||
|
||||
if self.tp_rank == 0:
|
||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
|
||||
self.recv_from_tokenizer.bind(
|
||||
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
|
||||
)
|
||||
|
||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||
self.send_to_detokenizer.connect(
|
||||
@@ -141,9 +143,6 @@ class Scheduler:
|
||||
nccl_port=port_args.nccl_ports[0],
|
||||
)
|
||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||
self.pad_input_ids_func = getattr(
|
||||
self.tp_worker.model_runner.model, "pad_input_ids", None
|
||||
)
|
||||
|
||||
# Get token and memory info from the tp worker
|
||||
(
|
||||
@@ -154,6 +153,9 @@ class Scheduler:
|
||||
self.random_seed,
|
||||
) = self.tp_worker.get_token_and_memory_info()
|
||||
set_random_seed(self.random_seed)
|
||||
self.pad_input_ids_func = getattr(
|
||||
self.tp_worker.model_runner.model, "pad_input_ids", None
|
||||
)
|
||||
|
||||
# Print debug info
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user