Unify forward mode (#1360)
This commit is contained in:
@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
image_sizes: Optional[List[List[int]]] = None,
|
||||
image_offsets: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
bs = input_metadata.batch_size
|
||||
# Got List[List[str]] extend it to List[str]
|
||||
# The length of the List should be equal to batch size
|
||||
@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
elif input_metadata.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user