Simplify the usage of device (#1734)
This commit is contained in:
@@ -425,7 +425,6 @@ class ScheduleBatch:
|
|||||||
req_pool_indices: torch.Tensor = None
|
req_pool_indices: torch.Tensor = None
|
||||||
seq_lens: torch.Tensor = None
|
seq_lens: torch.Tensor = None
|
||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor = None
|
||||||
|
|
||||||
output_ids: torch.Tensor = None
|
output_ids: torch.Tensor = None
|
||||||
|
|
||||||
# For processing logprobs
|
# For processing logprobs
|
||||||
@@ -442,27 +441,23 @@ class ScheduleBatch:
|
|||||||
# Stream
|
# Stream
|
||||||
has_stream: bool = False
|
has_stream: bool = False
|
||||||
|
|
||||||
# device
|
|
||||||
device: str = "cuda"
|
|
||||||
|
|
||||||
# Has regex
|
# Has regex
|
||||||
has_regex: bool = False
|
has_regex: bool = False
|
||||||
|
|
||||||
|
# device
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||||
return_logprob = any(req.return_logprob for req in reqs)
|
|
||||||
has_stream = any(req.stream for req in reqs)
|
|
||||||
has_regex = any(req.regex_fsm for req in reqs)
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=req_to_token_pool,
|
req_to_token_pool=req_to_token_pool,
|
||||||
token_to_kv_pool=token_to_kv_pool,
|
token_to_kv_pool=token_to_kv_pool,
|
||||||
tree_cache=tree_cache,
|
tree_cache=tree_cache,
|
||||||
return_logprob=return_logprob,
|
return_logprob=any(req.return_logprob for req in reqs),
|
||||||
has_stream=has_stream,
|
has_stream=any(req.stream for req in reqs),
|
||||||
|
has_regex=any(req.regex_fsm for req in reqs),
|
||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
has_regex=has_regex,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -754,7 +749,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
return jump_forward_reqs
|
return jump_forward_reqs
|
||||||
|
|
||||||
def prepare_for_decode(self):
|
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||||
self.forward_mode = ForwardMode.DECODE
|
self.forward_mode = ForwardMode.DECODE
|
||||||
|
|
||||||
self.input_ids = self.output_ids
|
self.input_ids = self.output_ids
|
||||||
@@ -767,10 +762,19 @@ class ScheduleBatch:
|
|||||||
# Alloc mem
|
# Alloc mem
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||||
self.req_to_token_pool.write(
|
|
||||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
if enable_overlap:
|
||||||
)
|
# Do not use in-place operations in the overlap mode
|
||||||
self.seq_lens.add_(1)
|
self.req_to_token_pool.write(
|
||||||
|
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||||
|
)
|
||||||
|
self.seq_lens = self.seq_lens + 1
|
||||||
|
else:
|
||||||
|
# A faster in-place version
|
||||||
|
self.req_to_token_pool.write(
|
||||||
|
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||||
|
)
|
||||||
|
self.seq_lens.add_(1)
|
||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
@@ -882,6 +886,7 @@ class ScheduleBatch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
|
# Only contain fields that will be used by process_batch_result
|
||||||
return ScheduleBatch(
|
return ScheduleBatch(
|
||||||
reqs=self.reqs,
|
reqs=self.reqs,
|
||||||
forward_mode=self.forward_mode,
|
forward_mode=self.forward_mode,
|
||||||
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
|
|||||||
return ModelWorkerBatch(
|
return ModelWorkerBatch(
|
||||||
bid=self.bid,
|
bid=self.bid,
|
||||||
forward_mode=self.forward_mode,
|
forward_mode=self.forward_mode,
|
||||||
input_ids=self.input_ids.clone(),
|
input_ids=self.input_ids,
|
||||||
req_pool_indices=self.req_pool_indices,
|
req_pool_indices=self.req_pool_indices,
|
||||||
seq_lens=self.seq_lens.clone(),
|
seq_lens=self.seq_lens,
|
||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
req_to_token_pool_records=self.req_to_token_pool_records,
|
req_to_token_pool_records=self.req_to_token_pool_records,
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class Scheduler:
|
|||||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||||
self.lora_paths = server_args.lora_paths
|
self.lora_paths = server_args.lora_paths
|
||||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||||
|
self.enable_overlap = server_args.enable_overlap_schedule
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
@@ -146,7 +147,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Launch a tensor parallel worker
|
# Launch a tensor parallel worker
|
||||||
if self.server_args.enable_overlap_schedule:
|
if self.enable_overlap:
|
||||||
TpWorkerClass = TpModelWorkerClient
|
TpWorkerClass = TpModelWorkerClient
|
||||||
self.resolve_next_token_ids = (
|
self.resolve_next_token_ids = (
|
||||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||||
@@ -670,7 +671,7 @@ class Scheduler:
|
|||||||
|
|
||||||
# Mixed-style chunked prefill
|
# Mixed-style chunked prefill
|
||||||
if self.is_mixed_chunk and self.running_batch is not None:
|
if self.is_mixed_chunk and self.running_batch is not None:
|
||||||
self.running_batch.prepare_for_decode()
|
self.running_batch.prepare_for_decode(self.enable_overlap)
|
||||||
new_batch.mix_with_running(self.running_batch)
|
new_batch.mix_with_running(self.running_batch)
|
||||||
new_batch.decoding_reqs = self.running_batch.reqs
|
new_batch.decoding_reqs = self.running_batch.reqs
|
||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
@@ -717,7 +718,7 @@ class Scheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Update batch tensors
|
# Update batch tensors
|
||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode(self.enable_overlap)
|
||||||
|
|
||||||
def run_batch(self, batch: ScheduleBatch):
|
def run_batch(self, batch: ScheduleBatch):
|
||||||
"""Run a batch."""
|
"""Run a batch."""
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class SamplingBatchInfo:
|
|||||||
disable_penalizer: bool,
|
disable_penalizer: bool,
|
||||||
):
|
):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
device = batch.input_ids.device
|
device = batch.device
|
||||||
temperatures = (
|
temperatures = (
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[r.sampling_params.temperature for r in reqs],
|
[r.sampling_params.temperature for r in reqs],
|
||||||
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
|
|||||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
device=batch.input_ids.device,
|
device=batch.device,
|
||||||
Penalizers={
|
Penalizers={
|
||||||
penaltylib.BatchedFrequencyPenalizer,
|
penaltylib.BatchedFrequencyPenalizer,
|
||||||
penaltylib.BatchedMinNewTokensPenalizer,
|
penaltylib.BatchedMinNewTokensPenalizer,
|
||||||
|
|||||||
Reference in New Issue
Block a user