Simplify the usage of device (#1734)

This commit is contained in:
Lianmin Zheng
2024-10-20 18:17:41 -07:00
committed by GitHub
parent 554fbf93cd
commit e12358dc91
3 changed files with 29 additions and 23 deletions

View File

@@ -425,7 +425,6 @@ class ScheduleBatch:
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
out_cache_loc: torch.Tensor = None
output_ids: torch.Tensor = None
# For processing logprobs
@@ -442,27 +441,23 @@ class ScheduleBatch:
# Stream
has_stream: bool = False
# device
device: str = "cuda"
# Has regex
has_regex: bool = False
# device
device: str = "cuda"
@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
return_logprob = any(req.return_logprob for req in reqs)
has_stream = any(req.stream for req in reqs)
has_regex = any(req.regex_fsm for req in reqs)
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
return_logprob=return_logprob,
has_stream=has_stream,
return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs),
has_regex=any(req.regex_fsm for req in reqs),
device=req_to_token_pool.device,
has_regex=has_regex,
)
def batch_size(self):
@@ -754,7 +749,7 @@ class ScheduleBatch:
return jump_forward_reqs
def prepare_for_decode(self):
def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE
self.input_ids = self.output_ids
@@ -767,10 +762,19 @@ class ScheduleBatch:
# Alloc mem
bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs)
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
)
self.seq_lens.add_(1)
if enable_overlap:
# Do not use in-place operations in the overlap mode
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
)
self.seq_lens = self.seq_lens + 1
else:
# A faster in-place version
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
)
self.seq_lens.add_(1)
def filter_batch(
self,
@@ -882,6 +886,7 @@ class ScheduleBatch:
)
def copy(self):
# Only contain fields that will be used by process_batch_result
return ScheduleBatch(
reqs=self.reqs,
forward_mode=self.forward_mode,
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
return ModelWorkerBatch(
bid=self.bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids.clone(),
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens.clone(),
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
req_to_token_pool_records=self.req_to_token_pool_records,
return_logprob=self.return_logprob,