Unify the memory pool API and TP worker API (#1724)
@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
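The docstring above describes the handoff between the three batch types. Below is a small, self-contained sketch of that flow; the `*Sketch` classes and their methods are illustrative assumptions for this note, not the real classes in `schedule_batch.py`, `tp_worker.py`, or `model_runner.py`.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ScheduleBatchSketch:
    """Scheduler-side batch: high-level, mostly CPU data (here, plain Python lists)."""

    input_ids: List[List[int]]

    def get_model_worker_batch(self) -> "ModelWorkerBatchSketch":
        # Keep only what the forward pass needs, flattened into CPU tensors.
        flat = [tok for req in self.input_ids for tok in req]
        seq_lens = [len(req) for req in self.input_ids]
        return ModelWorkerBatchSketch(
            input_ids=torch.tensor(flat, dtype=torch.int64),
            seq_lens=torch.tensor(seq_lens, dtype=torch.int64),
        )


@dataclass
class ModelWorkerBatchSketch:
    """Worker-side batch: the forward-related subset, movable to the GPU."""

    input_ids: torch.Tensor
    seq_lens: torch.Tensor

    def to(self, device: str):
        self.input_ids = self.input_ids.to(device, non_blocking=True)
        self.seq_lens = self.seq_lens.to(device, non_blocking=True)


@dataclass
class ForwardBatchSketch:
    """Model-runner-side batch: low-level tensors ready for the forward kernels."""

    input_ids: torch.Tensor
    positions: torch.Tensor

    @classmethod
    def init_new(cls, batch: ModelWorkerBatchSketch) -> "ForwardBatchSketch":
        positions = torch.cat([torch.arange(int(n)) for n in batch.seq_lens])
        return cls(input_ids=batch.input_ids, positions=positions.to(batch.input_ids.device))


# ScheduleBatch -> ModelWorkerBatch -> ForwardBatch, kept on CPU for illustration.
schedule_batch = ScheduleBatchSketch(input_ids=[[1, 2, 3], [4, 5]])
worker_batch = schedule_batch.get_model_worker_batch()
forward_batch = ForwardBatchSketch.init_new(worker_batch)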
@@ -522,12 +524,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
-                    req.prefix_indices
-                )
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
+                )

-            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
-            )
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
+            )

             # Compute the relative logprob_start_len in an extend batch
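Both replacements above follow the same pattern: instead of assigning directly into `req_to_token_pool.req_to_token`, the batch now goes through the pool's `write()` API, so each update can be recorded and later shipped to the TP worker. The following is a minimal, self-contained sketch of what such a recording pool could look like; the class name `ReqToTokenPoolSketch`, the constructor, and the `apply_write_records()` helper are assumptions for illustration and not taken from `memory_pool.py`, while the `write()` and `get_write_records()` call shapes match their usage in this diff.

from typing import List, Tuple

import torch


class ReqToTokenPoolSketch:
    """Illustrative stand-in for the req-to-token pool; not the real implementation."""

    def __init__(self, size: int, max_context_len: int):
        # Row i maps request slot i to the KV-cache locations of its tokens.
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int64)
        self.write_records: List[Tuple[Tuple, torch.Tensor]] = []

    def write(self, indices: Tuple, values: torch.Tensor):
        # Apply the update and remember (indices, values) so it can be replayed elsewhere.
        self.req_to_token[indices] = values
        self.write_records.append((indices, values))

    def get_write_records(self) -> List[Tuple[Tuple, torch.Tensor]]:
        # Drain the pending records, e.g. once per get_model_worker_batch() call.
        records = self.write_records
        self.write_records = []
        return records

    def apply_write_records(self, records: List[Tuple[Tuple, torch.Tensor]]):
        # Replay records produced by another copy of the pool (e.g. on a TP worker).
        for indices, values in records:
            self.req_to_token[indices] = values


# Example usage mirroring the calls in this diff:
pool = ReqToTokenPoolSketch(size=8, max_context_len=16)
pool.write((3, slice(0, 5)), torch.arange(5))
records = pool.get_write_records()  # shipped inside ModelWorkerBatch to the TP worker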
@@ -765,9 +767,8 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+        )
         self.seq_lens.add_(1)
@@ -848,7 +849,6 @@ class ScheduleBatch:
         extend_logprob_start_lens = self.extend_logprob_start_lens
         image_inputs = [r.image_inputs for r in self.reqs]

-        lora_paths = [req.lora_path for req in self.reqs]
         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
@@ -869,13 +869,14 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
             image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,
         )
@@ -911,6 +912,9 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
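For illustration, a record matching the `Optional[List[Tuple[Tuple, torch.Tensor]]]` annotation above pairs the index tuple passed to `write()` with the tensor of values written there; the concrete contents below are made up.

import torch

# Hypothetical req_to_token_pool_records for one batch:
req_to_token_pool_records = [
    # extend: fill slots 0..4 of request row 3 with its prefix indices
    ((3, slice(0, 5)), torch.tensor([10, 11, 12, 13, 14])),
    # decode: one new slot per request, indexed by (row, position) tensors
    ((torch.tensor([0, 1]), torch.tensor([7, 9])), torch.tensor([120, 121])),
]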
@@ -940,6 +944,7 @@ class ModelWorkerBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=self.extend_seq_lens,
@@ -950,3 +955,14 @@ class ModelWorkerBatch:
             sampling_info=self.sampling_info.copy(),
             mrope_positions_delta=self.mrope_positions_delta,
         )
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
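The new `to()` method moves the batch's tensors, including the recorded pool writes, onto the worker's device. Below is a hedged sketch of how a receiving TP worker might consume such a batch; the function name `run_on_worker` and the replay loop are assumptions for illustration, not code from this diff.

def run_on_worker(worker_pool, model_worker_batch, device: str = "cuda"):
    # Move all batch tensors, including req_to_token_pool_records, to the worker's device.
    model_worker_batch.to(device)
    # Replay the scheduler's req_to_token updates into the worker's own pool copy,
    # reproducing the same tensor updates that write() performed on the scheduler side.
    for indices, values in model_worker_batch.req_to_token_pool_records or []:
        worker_pool.req_to_token[indices] = values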