Simplify the usage of device (#1734)
This commit is contained in:
@@ -425,7 +425,6 @@ class ScheduleBatch:
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
|
||||
output_ids: torch.Tensor = None
|
||||
|
||||
# For processing logprobs
|
||||
@@ -442,27 +441,23 @@ class ScheduleBatch:
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
|
||||
# device
|
||||
device: str = "cuda"
|
||||
|
||||
# Has regex
|
||||
has_regex: bool = False
|
||||
|
||||
# device
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
return_logprob = any(req.return_logprob for req in reqs)
|
||||
has_stream = any(req.stream for req in reqs)
|
||||
has_regex = any(req.regex_fsm for req in reqs)
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
return_logprob=return_logprob,
|
||||
has_stream=has_stream,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_regex=any(req.regex_fsm for req in reqs),
|
||||
device=req_to_token_pool.device,
|
||||
has_regex=has_regex,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -754,7 +749,7 @@ class ScheduleBatch:
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_for_decode(self):
|
||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
self.input_ids = self.output_ids
|
||||
@@ -767,10 +762,19 @@ class ScheduleBatch:
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
if enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
@@ -882,6 +886,7 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
forward_mode=self.forward_mode,
|
||||
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
|
||||
return ModelWorkerBatch(
|
||||
bid=self.bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids.clone(),
|
||||
input_ids=self.input_ids,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens.clone(),
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
req_to_token_pool_records=self.req_to_token_pool_records,
|
||||
return_logprob=self.return_logprob,
|
||||
|
||||
@@ -103,6 +103,7 @@ class Scheduler:
|
||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||
self.lora_paths = server_args.lora_paths
|
||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||
self.enable_overlap = server_args.enable_overlap_schedule
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
@@ -146,7 +147,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
if self.enable_overlap:
|
||||
TpWorkerClass = TpModelWorkerClient
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
@@ -670,7 +671,7 @@ class Scheduler:
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
self.running_batch.prepare_for_decode()
|
||||
self.running_batch.prepare_for_decode(self.enable_overlap)
|
||||
new_batch.mix_with_running(self.running_batch)
|
||||
new_batch.decoding_reqs = self.running_batch.reqs
|
||||
self.running_batch = None
|
||||
@@ -717,7 +718,7 @@ class Scheduler:
|
||||
return
|
||||
|
||||
# Update batch tensors
|
||||
batch.prepare_for_decode()
|
||||
batch.prepare_for_decode(self.enable_overlap)
|
||||
|
||||
def run_batch(self, batch: ScheduleBatch):
|
||||
"""Run a batch."""
|
||||
|
||||
@@ -51,7 +51,7 @@ class SamplingBatchInfo:
|
||||
disable_penalizer: bool,
|
||||
):
|
||||
reqs = batch.reqs
|
||||
device = batch.input_ids.device
|
||||
device = batch.device
|
||||
temperatures = (
|
||||
torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device=batch.input_ids.device,
|
||||
device=batch.device,
|
||||
Penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
|
||||
Reference in New Issue
Block a user