Adjust InputeMetadata and ScheduleBatch (#981)

This commit is contained in:
Liangsheng Yin
2024-08-08 01:11:22 -07:00
committed by GitHub
parent 20a4f927dc
commit 1ac304eeb4
4 changed files with 203 additions and 192 deletions

View File

@@ -350,33 +350,18 @@ class ModelRunner:
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.DECODE,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
input_metadata = InputMetadata.from_schedule_batch(
self, batch, ForwardMode.DECODE
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
@torch.inference_mode()
def forward_extend(self, batch: ScheduleBatch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
input_metadata = InputMetadata.from_schedule_batch(
self, batch, forward_mode=ForwardMode.EXTEND
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -384,24 +369,16 @@ class ModelRunner:
@torch.inference_mode()
def forward_extend_multi_modal(self, batch: ScheduleBatch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
input_metadata = InputMetadata.from_schedule_batch(
self, batch, forward_mode=ForwardMode.EXTEND
)
return self.model.forward(
batch.input_ids,
input_metadata.positions,
input_metadata,
batch.pixel_values,
batch.image_sizes,
batch.image_offsets,
input_metadata.pixel_values,
input_metadata.image_sizes,
input_metadata.image_offsets,
)
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):