Simplify the usage of device (#1734)

This commit is contained in:
Lianmin Zheng
2024-10-20 18:17:41 -07:00
committed by GitHub
parent 554fbf93cd
commit e12358dc91
3 changed files with 29 additions and 23 deletions

View File

@@ -103,6 +103,7 @@ class Scheduler:
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
+self.enable_overlap = server_args.enable_overlap_schedule
# Init inter-process communication
context = zmq.Context(2)
@@ -146,7 +147,7 @@ class Scheduler:
)
# Launch a tensor parallel worker
-if self.server_args.enable_overlap_schedule:
+if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
@@ -670,7 +671,7 @@ class Scheduler:
# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:
-self.running_batch.prepare_for_decode()
+self.running_batch.prepare_for_decode(self.enable_overlap)
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None
@@ -717,7 +718,7 @@ class Scheduler:
return
# Update batch tensors
-batch.prepare_for_decode()
+batch.prepare_for_decode(self.enable_overlap)
def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""