Fix hash collision for multi modal models (#2256)
This commit is contained in:
@@ -147,6 +147,11 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
else:
|
||||
max_image_offset.append(-1)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
|
||||
@@ -597,13 +597,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
||||
`None` if no images are passed.
|
||||
"""
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
image_inputs = None
|
||||
if forward_batch.image_inputs is not None:
|
||||
image_inputs = [
|
||||
img for img in forward_batch.image_inputs if img is not None
|
||||
]
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
@@ -617,6 +619,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
f"(3, seq_len) positions, but got {positions.size()}"
|
||||
)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
|
||||
Reference in New Issue
Block a user