Adjust InputeMetadata and ScheduleBatch (#981)
This commit is contained in:
@@ -350,33 +350,18 @@ class ModelRunner:
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
||||
return self.cuda_graph_runner.replay(batch)
|
||||
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
position_ids_offsets=batch.position_ids_offsets,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
return_logprob=batch.return_logprob,
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self, batch, ForwardMode.DECODE
|
||||
)
|
||||
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
position_ids_offsets=batch.position_ids_offsets,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
return_logprob=batch.return_logprob,
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self, batch, forward_mode=ForwardMode.EXTEND
|
||||
)
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
@@ -384,24 +369,16 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
position_ids_offsets=batch.position_ids_offsets,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self, batch, forward_mode=ForwardMode.EXTEND
|
||||
)
|
||||
return self.model.forward(
|
||||
batch.input_ids,
|
||||
input_metadata.positions,
|
||||
input_metadata,
|
||||
batch.pixel_values,
|
||||
batch.image_sizes,
|
||||
batch.image_offsets,
|
||||
input_metadata.pixel_values,
|
||||
input_metadata.image_sizes,
|
||||
input_metadata.image_offsets,
|
||||
)
|
||||
|
||||
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
|
||||
|
||||
Reference in New Issue
Block a user