Unify forward mode (#1360)
This commit is contained in:
@@ -530,11 +530,7 @@ class ModelRunner:
|
||||
):
|
||||
return self.cuda_graph_runner.replay(batch)
|
||||
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self,
|
||||
batch,
|
||||
ForwardMode.DECODE,
|
||||
)
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
@@ -542,11 +538,7 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
)
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
@@ -562,11 +554,7 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
)
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
return self.model.forward(
|
||||
batch.input_ids,
|
||||
input_metadata.positions,
|
||||
@@ -577,16 +565,18 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, batch: ScheduleBatch, forward_mode: ForwardMode
|
||||
self, batch: ScheduleBatch
|
||||
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||
assert batch.forward_mode is not None
|
||||
|
||||
if self.is_multimodal_model and batch.forward_mode.is_extend():
|
||||
return self.forward_extend_multi_modal(batch)
|
||||
elif forward_mode == ForwardMode.DECODE:
|
||||
elif batch.forward_mode.is_decode():
|
||||
return self.forward_decode(batch)
|
||||
elif forward_mode == ForwardMode.EXTEND:
|
||||
elif batch.forward_mode.is_extend():
|
||||
return self.forward_extend(batch)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
||||
Reference in New Issue
Block a user