qwen2vl fix bug for #1971 #1897 (#1984)

This commit is contained in:
yizhang2077
2024-11-11 00:10:45 +08:00
committed by GitHub
parent 47ffe7af81
commit a8aad9357d
3 changed files with 8 additions and 14 deletions

View File

@@ -133,6 +133,7 @@ class ImageInputs:
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
@staticmethod
def from_dict(obj, vocab_size):
@@ -251,9 +252,6 @@ class Req:
# The number of cached tokens, that were already cached in the KV cache
self.cached_tokens = 0
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object
# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -983,8 +981,6 @@ class ScheduleBatch:
global bid
bid += 1
mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
@@ -1007,7 +1003,6 @@ class ScheduleBatch:
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta,
)
def copy(self):
@@ -1074,9 +1069,6 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
# For Qwen2-VL
mrope_positions_delta: List[List[int]]
def copy(self):
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())