Rename InputMetadata -> ForwardBatch (#1543)
This commit is contained in:
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
|
||||
unpad_image,
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
@@ -130,12 +130,12 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
image_inputs = input_metadata.image_inputs
|
||||
image_inputs = forward_batch.image_inputs
|
||||
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
bs = input_metadata.batch_size
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
bs = forward_batch.batch_size
|
||||
# Got List[List[str]] extend it to List[str]
|
||||
# The length of the List should be equal to batch size
|
||||
modalities_list = []
|
||||
@@ -151,7 +151,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
|
||||
need_vision = start_positions <= np.array(max_image_offset)
|
||||
|
||||
if need_vision.any():
|
||||
@@ -348,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
image_features = new_image_features
|
||||
|
||||
# Fill in the placeholder for the image
|
||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
@@ -379,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
pt += 1
|
||||
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
input_ids, positions, forward_batch, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
elif forward_batch.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Load clip vision model by cfg['mm_vision_tower']:
|
||||
|
||||
Reference in New Issue
Block a user