[feat] Enable chunked prefill for llava-onevision (#2412)

Author: Ying Sheng
Date: 2024-12-09 09:52:38 -08:00
Committed by: GitHub
Parent: 641b7d0ae0
Commit: 8586b72da0
5 changed files with 222 additions and 20 deletions

@@ -129,6 +129,7 @@ class ImageInputs:
image_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
image_pad_len: Optional[list] = None
pad_values: Optional[list] = None
modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
@@ -181,6 +182,7 @@ class ImageInputs:
optional_args = [
"image_sizes",
"image_offsets",
"image_pad_len",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",

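The new image_pad_len field records how many placeholder tokens each image expands to in input_ids; together with image_offsets it pins down the exact token span an image occupies, which is what chunked prefill needs in order to decide whether a given chunk overlaps an image. A minimal sketch of that overlap test (the helper name and numbers are illustrative, not part of the commit):

def images_in_chunk(image_offsets, image_pad_len, prefix_len, seq_len):
    """Indices of images whose padded span overlaps the chunk [prefix_len, prefix_len + seq_len)."""
    hits = []
    for idx, (offset, pad_len) in enumerate(zip(image_offsets, image_pad_len)):
        if offset + pad_len <= prefix_len:
            continue  # image lies entirely inside the already-prefilled prefix
        if offset >= prefix_len + seq_len:
            break  # image starts in a later chunk (offsets are sorted)
        hits.append(idx)
    return hits

# One 576-token image starting at offset 5; the chunk covers tokens [400, 800).
assert images_in_chunk([5], [576], prefix_len=400, seq_len=400) == [0]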
@@ -111,17 +111,20 @@ class ModelRunner:
)
if self.is_multimodal:
server_args.chunked_prefill_size = -1
self.mem_fraction_static *= 0.95
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"and turn off chunked prefill "
f"because this is a multimodal model."
)
if self.model_config.hf_config.architectures == [
"MllamaForConditionalGeneration"
]:
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
server_args.chunked_prefill_size = -1
# TODO: qwen2-vl does not support radix cache yet, so disable_radix_cache is set to True automatically
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
]:
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
# Global vars

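Net effect of this hunk: multimodal models no longer disable chunked prefill across the board; only the static memory fraction is reduced, while mllama and qwen2-vl opt out explicitly (qwen2-vl also disables the radix cache). A minimal sketch of that dispatch, using a SimpleNamespace stand-in rather than sglang's real ServerArgs class, with an illustrative architecture name:

from types import SimpleNamespace

def adjust_multimodal_args(architectures, server_args, mem_fraction_static):
    mem_fraction_static *= 0.95  # leave headroom for the vision tower
    if architectures == ["MllamaForConditionalGeneration"]:
        server_args.chunked_prefill_size = -1  # chunked prefill not supported
    if architectures == ["Qwen2VLForConditionalGeneration"]:
        server_args.chunked_prefill_size = -1
        server_args.disable_radix_cache = True  # radix cache not supported yet
    return mem_fraction_static

args = SimpleNamespace(chunked_prefill_size=8192, disable_radix_cache=False)
frac = adjust_multimodal_args(["LlavaQwenForCausalLM"], args, 0.88)  # architecture name illustrative
# Chunked prefill stays enabled; only the memory fraction shrinks to 0.836.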
@@ -57,6 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
else:
image_aspect_ratio = "anyres"
offset_list = []
image_inputs.image_pad_len = []
for image_idx, image_s in enumerate(image_sizes):
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
@@ -103,6 +104,7 @@ class LlavaBaseForCausalLM(nn.Module):
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
image_inputs.image_pad_len.append(new_image_feature_len)
image_inputs.image_offsets = offset_list
return input_ids
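For context, the expansion that this hunk instruments replaces each single image token in input_ids with new_image_feature_len pad values and records where that span starts and how long it is, so later code can still locate the image even after the prompt is split into chunks. A simplified, self-contained sketch (function and variable names are hypothetical):

def expand_image_tokens(input_ids, image_token_id, pad_values, feature_lens):
    offsets, pad_lens, out = [], [], []
    img_idx = 0
    for tok in input_ids:
        if tok == image_token_id and img_idx < len(feature_lens):
            offsets.append(len(out))                           # where this image's span begins
            pad_lens.append(feature_lens[img_idx])             # how many tokens it occupies
            out.extend([pad_values[img_idx]] * feature_lens[img_idx])
            img_idx += 1
        else:
            out.append(tok)
    return out, offsets, pad_lens

ids, offsets, pad_lens = expand_image_tokens(
    [1, 32000, 7, 8], image_token_id=32000, pad_values=[99999], feature_lens=[4]
)
# ids == [1, 99999, 99999, 99999, 99999, 7, 8], offsets == [1], pad_lens == [4]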
@@ -134,6 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
image_inputs = forward_batch.image_inputs
if forward_batch.forward_mode.is_extend():
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Got List[List[str]]; flatten it to List[str].
# The length of the list should equal the batch size.
modalities_list = []
@@ -142,18 +152,12 @@ class LlavaBaseForCausalLM(nn.Module):
if im and im.modalities is not None:
modalities_list.extend(im.modalities)
if im and im.image_offsets:
max_image_offset.append(max(im.image_offsets))
max_image_offset.append(
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
)
else:
max_image_offset.append(-1)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
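Because a later chunk can start well past an image's first token, the per-request bound used in the vision check now includes the pad length: a chunk needs vision features whenever its start position does not exceed offset + pad_len for some image. A small numeric sketch (numbers are made up) of the new test versus the old one:

import numpy as np

image_offsets = np.array([5])      # image placeholder begins at token 5
image_pad_len = np.array([576])    # and spans 576 tokens
start_positions = np.array([400])  # this chunk starts at token 400

max_image_offset = np.max(image_offsets + image_pad_len)  # 581
need_vision = start_positions <= max_image_offset          # array([ True])
# The old bound, max(image_offsets) == 5, would have skipped the vision path
# for this chunk even though it still overlaps the image.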
@@ -350,6 +354,7 @@ class LlavaBaseForCausalLM(nn.Module):
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
@@ -357,18 +362,36 @@ class LlavaBaseForCausalLM(nn.Module):
continue
start_idx = extend_start_loc_cpu[i]
seq_len = extend_seq_lens[i]
prefix_len = prefix_lens_cpu[i]
# Multiple images
for j, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset < prefix_len:
for image_idx, image_offset in enumerate(
image_inputs[i].image_offsets
):
if (
image_offset + image_inputs[i].image_pad_len[image_idx]
<= prefix_len
):
continue
if image_offset >= prefix_len + seq_len:
break
tmp_image_feature = image_features[pt][j]
tmp_image_feature = image_features[pt][image_idx]
pad_len = tmp_image_feature.shape[0]
left_idx = start_idx + (image_offset - prefix_len)
right_idx = start_idx + (image_offset - prefix_len) + pad_len
input_offset = image_offset - prefix_len
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len
assert right_idx > start_idx
if input_offset < 0:
left_idx = start_idx
tmp_image_feature = tmp_image_feature[-input_offset:]
if right_idx > start_idx + seq_len:
tmp_image_feature = tmp_image_feature[
: start_idx + seq_len - right_idx
]
right_idx = start_idx + seq_len
try:
input_embeds[left_idx:right_idx] = tmp_image_feature
except RuntimeError as e:
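To make the boundary handling above concrete, here is a worked example (sizes are made up) of how the slice of precomputed image features is trimmed when an image straddles the current chunk: the first 50 feature rows already landed in a previous chunk, and the tail past the chunk end is deferred to the next one.

import torch

start_idx, seq_len = 0, 200        # this chunk covers tokens [400, 600)
prefix_len = 400
image_offset, pad_len = 350, 576   # image spans tokens [350, 926)

tmp_image_feature = torch.randn(pad_len, 8)  # (pad_len, hidden_size) features
input_embeds = torch.zeros(seq_len, 8)

input_offset = image_offset - prefix_len     # -50: image started in a prior chunk
left_idx = start_idx + input_offset          # -50
right_idx = left_idx + pad_len               # 526
assert right_idx > start_idx

if input_offset < 0:
    left_idx = start_idx                                     # clamp to the chunk start
    tmp_image_feature = tmp_image_feature[-input_offset:]    # drop the first 50 rows
if right_idx > start_idx + seq_len:
    tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
    right_idx = start_idx + seq_len                          # clamp to the chunk end

input_embeds[left_idx:right_idx] = tmp_image_feature  # rows 0:200 <- feature rows 50:250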