Unify forward mode (#1360)

This commit is contained in:
Liangsheng Yin
2024-09-09 13:49:29 -07:00
committed by GitHub
parent 689ff588ec
commit 69b3bb9ae1
9 changed files with 54 additions and 58 deletions

View File

@@ -530,11 +530,7 @@ class ModelRunner:
):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch(
self,
batch,
ForwardMode.DECODE,
)
input_metadata = InputMetadata.from_schedule_batch(self, batch)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -542,11 +538,7 @@ class ModelRunner:
@torch.inference_mode()
def forward_extend(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(
self,
batch,
forward_mode=ForwardMode.EXTEND,
)
input_metadata = InputMetadata.from_schedule_batch(self, batch)
if self.is_generation:
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -562,11 +554,7 @@ class ModelRunner:
@torch.inference_mode()
def forward_extend_multi_modal(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(
self,
batch,
forward_mode=ForwardMode.EXTEND,
)
input_metadata = InputMetadata.from_schedule_batch(self, batch)
return self.model.forward(
batch.input_ids,
input_metadata.positions,
@@ -577,16 +565,18 @@ class ModelRunner:
)
def forward(
    self, batch: ScheduleBatch
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
    """Run one forward pass, dispatching on the mode carried by ``batch``.

    The forward mode is read from ``batch.forward_mode`` rather than being
    passed as a separate argument.

    Args:
        batch: The batch to run. Its ``forward_mode`` attribute must be set
            and must answer ``is_extend()`` / ``is_decode()``.

    Returns:
        Whatever the selected ``forward_extend_multi_modal``,
        ``forward_decode`` or ``forward_extend`` method returns.

    Raises:
        ValueError: If ``batch.forward_mode`` is ``None`` or is neither an
            extend nor a decode mode.
    """
    mode = batch.forward_mode
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn a missing mode into an AttributeError.
    if mode is None:
        raise ValueError(f"Invalid forward mode: {mode}")

    # The multimodal check must come first so that extend batches on a
    # multimodal model never fall through to the plain extend path.
    if self.is_multimodal_model and mode.is_extend():
        return self.forward_extend_multi_modal(batch)
    elif mode.is_decode():
        return self.forward_decode(batch)
    elif mode.is_extend():
        return self.forward_extend(batch)
    else:
        raise ValueError(f"Invalid forward mode: {mode}")
@lru_cache()