Unify forward mode (#1360)

This commit is contained in:
Liangsheng Yin
2024-09-09 13:49:29 -07:00
committed by GitHub
parent 689ff588ec
commit 69b3bb9ae1
9 changed files with 54 additions and 58 deletions

View File

@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
if input_metadata.forward_mode.is_extend():
bs = input_metadata.batch_size
# Got List[List[str]] extend it to List[str]
# The length of the List should be equal to batch size
@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
elif input_metadata.forward_mode.is_decode():
return self.language_model(input_ids, positions, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):