Simplify the usage of device (#1734)
This commit is contained in:
@@ -425,7 +425,6 @@ class ScheduleBatch:
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
|
||||
output_ids: torch.Tensor = None
|
||||
|
||||
# For processing logprobs
|
||||
@@ -442,27 +441,23 @@ class ScheduleBatch:
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
|
||||
# device
|
||||
device: str = "cuda"
|
||||
|
||||
# Has regex
|
||||
has_regex: bool = False
|
||||
|
||||
# device
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
return_logprob = any(req.return_logprob for req in reqs)
|
||||
has_stream = any(req.stream for req in reqs)
|
||||
has_regex = any(req.regex_fsm for req in reqs)
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
return_logprob=return_logprob,
|
||||
has_stream=has_stream,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_regex=any(req.regex_fsm for req in reqs),
|
||||
device=req_to_token_pool.device,
|
||||
has_regex=has_regex,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -754,7 +749,7 @@ class ScheduleBatch:
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_for_decode(self):
|
||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
self.input_ids = self.output_ids
|
||||
@@ -767,10 +762,19 @@ class ScheduleBatch:
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
if enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
@@ -882,6 +886,7 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
forward_mode=self.forward_mode,
|
||||
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
|
||||
return ModelWorkerBatch(
|
||||
bid=self.bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids.clone(),
|
||||
input_ids=self.input_ids,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens.clone(),
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
req_to_token_pool_records=self.req_to_token_pool_records,
|
||||
return_logprob=self.return_logprob,
|
||||
|
||||
Reference in New Issue
Block a user