diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f7d55ed9b..c4daa8a07 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -145,15 +145,17 @@ class ImageInputs: # Use image hash as fake token_ids, which is then used for prefix matching ret = ImageInputs( pixel_values=obj["pixel_values"], - image_hashes=hash(tuple(obj["image_hashes"])), + image_hashes=obj["image_hashes"], ) - image_hash = ret.image_hashes - ret.pad_values = [ - (image_hash) % vocab_size, - (image_hash >> 16) % vocab_size, - (image_hash >> 32) % vocab_size, - (image_hash >> 64) % vocab_size, - ] + if not isinstance(ret.image_hashes, list): + ret.pad_values = [ + (ret.image_hashes) % vocab_size, + (ret.image_hashes >> 16) % vocab_size, + (ret.image_hashes >> 32) % vocab_size, + (ret.image_hashes >> 64) % vocab_size, + ] + else: + ret.pad_values = [x % vocab_size for x in ret.image_hashes] optional_args = [ "image_sizes", @@ -171,14 +173,18 @@ class ImageInputs: def merge(self, other, vocab_size): assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:] self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values]) - self.image_hashes += other.image_hashes - self.pad_values = [ - (self.image_hashes) % vocab_size, - (self.image_hashes >> 16) % vocab_size, - (self.image_hashes >> 32) % vocab_size, - (self.image_hashes >> 64) % vocab_size, - ] + if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list): + self.image_hashes += other.image_hashes + self.pad_values = [x % vocab_size for x in self.image_hashes] + else: + self.image_hashes = hash((self.image_hashes, other.image_hashes)) + self.pad_values = [ + (self.image_hashes) % vocab_size, + (self.image_hashes >> 16) % vocab_size, + (self.image_hashes >> 32) % vocab_size, + (self.image_hashes >> 64) % vocab_size, + ] optional_args = [ "image_sizes", diff --git 
a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index b07474ad9..514c7c1bd 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -57,7 +57,7 @@ class LlavaBaseForCausalLM(nn.Module): else: image_aspect_ratio = "anyres" offset_list = [] - for image_s in image_sizes: + for image_idx, image_s in enumerate(image_sizes): if len(image_sizes) > 16: # 2x2 pooling with stride 2 new_image_feature_len = ( @@ -92,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module): new_w = int(new_w // times) new_image_feature_len += new_h * (new_w + 1) - pad_ids = pad_values * ( - (new_image_feature_len + len(pad_values)) // len(pad_values) - ) - # print("calculated new_image_feature_len: ", new_image_feature_len) try: offset = input_ids.index(self.config.image_token_index) except ValueError: @@ -103,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module): # old_len + pad_len - 1, because we need to remove image_token_id input_ids = ( input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :] ) offset_list.append(offset) diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 3d3876243..0258ed332 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -500,7 +500,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): return num_image_tokens # Use grid_t * grid_w * grid_h to pad tokens for each image - # and replaced padding by unique image hash + # and replace padding by unique image hash def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): image_grid_thws = image_inputs.image_grid_thws pad_values = image_inputs.pad_values diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 8558b4249..47169aeaa 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -301,6 +301,8 @@ class 
TestSessionControlVision(unittest.TestCase): assert response["meta_info"]["finish_reason"]["type"] == "abort" # 2. not use session control + requests.post(self.base_url + "/flush_cache") + input_ids_first_req = None input_ids = [] outputs_normal = []