Clean up batch data structures: Introducing ModelWorkerBatch (#1544)
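Taken together, the hunks below make three changes: `ScheduleBatch.merge` is renamed `merge_batch`; the scheduler stops building a `ForwardBatch` itself and instead hands the TP worker a `ModelWorkerBatch` obtained from `batch.get_model_worker_batch()`; and the model's optional `pad_input_ids` hook is resolved once at startup into `self.pad_input_ids_func`. The `ModelWorkerBatch` definition itself is not part of these hunks; the sketch below is a minimal guess at its shape from the call sites, with every field name an assumption:

```python
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class ModelWorkerBatch:
    # All fields are assumptions inferred from how the scheduler and the
    # TP worker use the object in this diff, not the PR's actual definition.
    forward_mode: Any                     # extend (prefill) vs. decode step
    input_ids: torch.Tensor               # flattened input token ids
    seq_lens: torch.Tensor                # per-request sequence lengths
    out_cache_loc: torch.Tensor           # KV-cache write locations
    sampling_info: Any                    # state needed to sample next tokens
    image_inputs: Optional[list] = None   # multimodal inputs, when present
```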
@@ -141,6 +141,9 @@ class Scheduler:
             nccl_port=port_args.nccl_ports[0],
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )
 
         # Get token and memory info from the tp worker
         (
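The `getattr` call above resolves the per-model `pad_input_ids` hook once, falling back to `None` for models that do not implement it. A minimal sketch of this pattern (both model classes are stand-ins, not SGLang's):

```python
class TextOnlyModel:
    pass  # defines no pad_input_ids, so getattr() falls back to None


class VisionLanguageModel:
    def pad_input_ids(self, input_ids, image_inputs):
        # Toy padding: append placeholder ids where image embeddings go.
        return input_ids + [0] * image_inputs["num_image_tokens"]


pad_func = getattr(TextOnlyModel(), "pad_input_ids", None)
assert pad_func is None  # text-only models skip padding entirely

pad_func = getattr(VisionLanguageModel(), "pad_input_ids", None)
assert pad_func([1, 2, 3], {"num_image_tokens": 2}) == [1, 2, 3, 0, 0]
```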
@@ -292,7 +295,7 @@ class Scheduler:
             if self.running_batch is None:
                 self.running_batch = new_batch
             else:
-                self.running_batch.merge(new_batch)
+                self.running_batch.merge_batch(new_batch)
         else:
             # Run a decode batch
             if self.running_batch is not None:
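As far as this hunk shows, the rename from `merge` to `merge_batch` is behavior-preserving. For readers new to the scheduler, a toy illustration of what the merge accomplishes (`TinyBatch` is a simplified stand-in, not `ScheduleBatch`):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TinyBatch:
    # Stand-in for ScheduleBatch; the real class also merges per-request
    # tensors and sampling state, not just the request list.
    reqs: List[object] = field(default_factory=list)

    def merge_batch(self, other: "TinyBatch") -> None:
        # The running decode batch absorbs newly prefilled requests so the
        # next decode step covers old and new requests together.
        self.reqs.extend(other.reqs)
```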
@@ -370,7 +373,7 @@ class Scheduler:
             req.image_inputs = ImageInputs.from_dict(
                 recv_req.image_inputs, self.model_config.vocab_size
             )
-            req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
+            req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids_unpadded, req.image_inputs
             )
 
@@ -575,9 +578,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                forward_batch = batch.get_forward_batch()
+                model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    forward_batch, batch
+                    model_worker_batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
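`forward_batch_generation` previously took the scheduler-built `ForwardBatch` plus the `ScheduleBatch` itself; it now takes only the `ModelWorkerBatch`. The worker-side counterpart is not shown in this diff; the sketch below illustrates the division of labor the new call site implies, where `build_forward_batch`, `forward`, and `sample` are assumed method names:

```python
class TpModelWorkerSketch:
    """Illustrative only: shows the handoff implied by the new call site,
    not SGLang's actual TpModelWorker."""

    def __init__(self, model_runner):
        self.model_runner = model_runner

    def forward_batch_generation(self, model_worker_batch):
        # The scheduler no longer constructs runner-level batches; the
        # worker derives whatever it needs from the ModelWorkerBatch.
        forward_batch = self.model_runner.build_forward_batch(model_worker_batch)
        logits_output = self.model_runner.forward(forward_batch)
        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
        return logits_output, next_token_ids
```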
@@ -641,8 +644,8 @@ class Scheduler:
             )
         else:
             assert batch.extend_num_tokens != 0
-            forward_batch = batch.get_forward_batch()
-            embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
+            model_worker_batch = batch.get_model_worker_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -759,9 +762,7 @@ class Scheduler:
 
         # Check for jump-forward
         if not self.disable_regex_jump_forward:
-            jump_forward_reqs = batch.check_for_jump_forward(
-                self.tp_worker.model_runner
-            )
+            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
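`check_for_jump_forward` previously received the whole `model_runner`, apparently only to reach the model's `pad_input_ids`; passing `self.pad_input_ids_func` narrows the dependency to a single callable (or `None`). A hedged sketch of the re-padding step this enables, mirroring the request fields from the hunk at `-370` above; the helper itself is illustrative, not the PR's code:

```python
def repad_after_jump_forward(req, pad_input_ids_func):
    # After a jump-forward rewrites a request's token stream, multimodal
    # requests must be re-padded the same way the scheduler pads them at
    # request-creation time.
    if req.image_inputs is not None and pad_input_ids_func is not None:
        req.origin_input_ids = pad_input_ids_func(
            req.origin_input_ids_unpadded, req.image_inputs
        )
```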
@@ -771,9 +772,9 @@ class Scheduler:
         batch.prepare_for_decode()
 
         # Forward and sample the next tokens
-        forward_batch = batch.get_forward_batch()
+        model_worker_batch = batch.get_model_worker_batch()
         logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            forward_batch, batch
+            model_worker_batch
         )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids