Urgent model support: support gemma-3-it (#4424)
This commit is contained in:
@@ -33,6 +33,7 @@ from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -331,6 +332,32 @@ class ForwardBatch:
|
||||
|
||||
return ret
|
||||
|
||||
def get_merged_image_inputs(self) -> Optional[ImageInputs]:
|
||||
"""
|
||||
Merge all image inputs in the batch into a single ImageInputs object.
|
||||
|
||||
Returns:
|
||||
if none, current batch contains no image input
|
||||
|
||||
"""
|
||||
if not self.image_inputs or all(x is None for x in self.image_inputs):
|
||||
return None
|
||||
|
||||
# Filter out None values
|
||||
valid_inputs = [x for x in self.image_inputs if x is not None]
|
||||
|
||||
# Start with the first valid image input
|
||||
merged = valid_inputs[0]
|
||||
|
||||
# Merge remaining inputs
|
||||
for img_input in valid_inputs[1:]:
|
||||
merged.merge(img_input)
|
||||
|
||||
if isinstance(merged.pixel_values, np.ndarray):
|
||||
merged.pixel_values = torch.from_numpy(merged.pixel_values)
|
||||
|
||||
return merged
|
||||
|
||||
def _compute_mrope_positions(
|
||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user