Unify the memory pool api and tp worker API (#1724)

This commit is contained in:
Lianmin Zheng
2024-10-19 23:19:26 -07:00
committed by GitHub
parent 95946271af
commit 59cbf47626
8 changed files with 87 additions and 25 deletions

View File

@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
@@ -522,12 +524,12 @@ class ScheduleBatch:
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
req.prefix_indices
self.req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
self.req_to_token_pool.write(
(req.req_pool_idx, slice(pre_len, seq_len)),
out_cache_loc[pt : pt + req.extend_input_len],
)
# Compute the relative logprob_start_len in an extend batch
@@ -765,9 +767,8 @@ class ScheduleBatch:
# Alloc mem
bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs)
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
self.out_cache_loc
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
)
self.seq_lens.add_(1)
@@ -848,7 +849,6 @@ class ScheduleBatch:
extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]
lora_paths = [req.lora_path for req in self.reqs]
if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
self.sampling_info.regex_fsm_states = [
@@ -869,13 +869,14 @@ class ScheduleBatch:
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs,
lora_paths=lora_paths,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta,
)
@@ -911,6 +912,9 @@ class ModelWorkerBatch:
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# The memory pool operation records
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
@@ -940,6 +944,7 @@ class ModelWorkerBatch:
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens.clone(),
out_cache_loc=self.out_cache_loc,
req_to_token_pool_records=self.req_to_token_pool_records,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=self.extend_seq_lens,
@@ -950,3 +955,14 @@ class ModelWorkerBatch:
sampling_info=self.sampling_info.copy(),
mrope_positions_delta=self.mrope_positions_delta,
)
def to(self, device: str):
self.input_ids = self.input_ids.to(device, non_blocking=True)
self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
self.seq_lens = self.seq_lens.to(device, non_blocking=True)
self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
self.req_to_token_pool_records = [
(x, y.to(device, non_blocking=True))
for x, y in self.req_to_token_pool_records
]
self.sampling_info.to(device)