Clean up batch data structures: Introducing ModelWorkerBatch (#1544)

This commit is contained in:
Lianmin Zheng
2024-09-30 06:41:49 -07:00
committed by GitHub
parent 36d5acfca5
commit 63ba2f8d7b
9 changed files with 274 additions and 155 deletions

View File

@@ -141,6 +141,9 @@ class Scheduler:
nccl_port=port_args.nccl_ports[0],
)
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+self.pad_input_ids_func = getattr(
+self.tp_worker.model_runner.model, "pad_input_ids", None
+)
# Get token and memory info from the tp worker
(
@@ -292,7 +295,7 @@ class Scheduler:
if self.running_batch is None:
self.running_batch = new_batch
else:
-self.running_batch.merge(new_batch)
+self.running_batch.merge_batch(new_batch)
else:
# Run a decode batch
if self.running_batch is not None:
@@ -370,7 +373,7 @@ class Scheduler:
req.image_inputs = ImageInputs.from_dict(
recv_req.image_inputs, self.model_config.vocab_size
)
-req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
+req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)
@@ -575,9 +578,9 @@ class Scheduler:
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
-forward_batch = batch.get_forward_batch()
+model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-forward_batch, batch
+model_worker_batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
@@ -641,8 +644,8 @@ class Scheduler:
)
else:
assert batch.extend_num_tokens != 0
-forward_batch = batch.get_forward_batch()
-embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
+model_worker_batch = batch.get_model_worker_batch()
+embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -759,9 +762,7 @@ class Scheduler:
# Check for jump-forward
if not self.disable_regex_jump_forward:
-jump_forward_reqs = batch.check_for_jump_forward(
-self.tp_worker.model_runner
-)
+jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
return
@@ -771,9 +772,9 @@ class Scheduler:
batch.prepare_for_decode()
# Forward and sample the next tokens
-forward_batch = batch.get_forward_batch()
+model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-forward_batch, batch
+model_worker_batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids