Simplify the usage of device (#1734)
This commit is contained in:
@@ -103,6 +103,7 @@ class Scheduler:
|
||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||
self.lora_paths = server_args.lora_paths
|
||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||
self.enable_overlap = server_args.enable_overlap_schedule
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
@@ -146,7 +147,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
if self.enable_overlap:
|
||||
TpWorkerClass = TpModelWorkerClient
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
@@ -670,7 +671,7 @@ class Scheduler:
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
self.running_batch.prepare_for_decode()
|
||||
self.running_batch.prepare_for_decode(self.enable_overlap)
|
||||
new_batch.mix_with_running(self.running_batch)
|
||||
new_batch.decoding_reqs = self.running_batch.reqs
|
||||
self.running_batch = None
|
||||
@@ -717,7 +718,7 @@ class Scheduler:
|
||||
return
|
||||
|
||||
# Update batch tensors
|
||||
batch.prepare_for_decode()
|
||||
batch.prepare_for_decode(self.enable_overlap)
|
||||
|
||||
def run_batch(self, batch: ScheduleBatch):
|
||||
"""Run a batch."""
|
||||
|
||||
Reference in New Issue
Block a user