Support precomputed_embeddings for Llama 4 (#8156)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com> Co-authored-by: Xiang (Kevin) Li <lik@nvidia.com> Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -39,7 +39,11 @@ repos:
|
|||||||
- id: codespell
|
- id: codespell
|
||||||
additional_dependencies: ['tomli']
|
additional_dependencies: ['tomli']
|
||||||
args: ['--toml', 'python/pyproject.toml', '-L', 'cann']
|
args: ['--toml', 'python/pyproject.toml', '-L', 'cann']
|
||||||
exclude: test/srt/test_reasoning_parser.py # Exclude the test file that is expected to fail
|
exclude: |
|
||||||
|
(?x)^(
|
||||||
|
test/srt/test_reasoning_parser\.py|
|
||||||
|
docs/backend/vlm_query\.ipynb
|
||||||
|
)$
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v18.1.8
|
rev: v18.1.8
|
||||||
hooks:
|
hooks:
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -62,6 +62,7 @@ The core features include:
|
|||||||
backend/quantization.md
|
backend/quantization.md
|
||||||
backend/lora.ipynb
|
backend/lora.ipynb
|
||||||
backend/pd_disaggregation.md
|
backend/pd_disaggregation.md
|
||||||
|
backend/vlm_query.ipynb
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|||||||
@@ -55,6 +55,9 @@ def gpu_tensor_hash(tensor: torch.Tensor) -> int:
|
|||||||
|
|
||||||
intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
|
intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
|
||||||
|
|
||||||
|
# Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
||||||
|
# Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
|
||||||
|
with torch.cuda.device(tensor.device):
|
||||||
hash_kernel[grid](
|
hash_kernel[grid](
|
||||||
tensor,
|
tensor,
|
||||||
intermediate_hashes,
|
intermediate_hashes,
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
|||||||
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
||||||
self.vision_config = hf_config.vision_config
|
self.vision_config = hf_config.vision_config
|
||||||
self.text_config = hf_config.text_config
|
self.text_config = hf_config.text_config
|
||||||
self.boi_token_index = hf_config.boi_token_index
|
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||||
self.eoi_token_index = hf_config.eoi_token_index
|
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||||
self.image_token_index = hf_config.image_token_index
|
self.IM_TOKEN_ID = hf_config.image_token_index
|
||||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
self.mm_tokens = MultimodalSpecialTokens(
|
||||||
image_token=_processor.image_token,
|
image_token=_processor.image_token,
|
||||||
image_token_id=self.image_token_index,
|
image_token_id=self.IM_TOKEN_ID,
|
||||||
).build(_processor)
|
).build(_processor)
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
@@ -37,114 +37,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if isinstance(input_text, list):
|
base_output = self.load_mm_data(
|
||||||
assert len(input_text) and isinstance(input_text[0], int)
|
|
||||||
input_text = self._processor.tokenizer.decode(input_text)
|
|
||||||
|
|
||||||
# Process images and text using the base processor's load_mm_data method
|
|
||||||
processed_data = self.load_mm_data(
|
|
||||||
prompt=input_text,
|
prompt=input_text,
|
||||||
multimodal_tokens=self.multimodal_tokens,
|
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
return_text=True,
|
multimodal_tokens=self.mm_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process the images using the processor
|
|
||||||
processor = self._processor
|
|
||||||
|
|
||||||
# Process the prompt and images
|
# Process the prompt and images
|
||||||
processor_output = self.process_mm_data(
|
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||||
input_text=processed_data.input_text,
|
base_output, self.mm_tokens
|
||||||
images=processed_data.images,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle image resolutions and aspect ratios
|
return {
|
||||||
if "pixel_values" not in processor_output: # no image processed
|
"input_ids": input_ids.tolist(),
|
||||||
return None
|
"mm_items": mm_items,
|
||||||
|
"im_start_id": self.IM_START_TOKEN_ID,
|
||||||
image_processor = processor.image_processor
|
"im_end_id": self.IM_END_TOKEN_ID,
|
||||||
tokenizer = self._processor.tokenizer
|
"im_token_id": self.IM_TOKEN_ID,
|
||||||
|
}
|
||||||
# Calculate tile size and find supported resolutions
|
|
||||||
tile_size = self.vision_config.image_size
|
|
||||||
max_num_tiles = getattr(self.vision_config, "max_patches", 1)
|
|
||||||
|
|
||||||
possible_resolutions = find_supported_resolutions(
|
|
||||||
max_num_chunks=max_num_tiles,
|
|
||||||
patch_size=SizeDict(height=tile_size, width=tile_size),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find best fit for each image
|
|
||||||
best_fit_sizes = [
|
|
||||||
get_best_fit(
|
|
||||||
(image.size[1], image.size[0]), # (height, width)
|
|
||||||
torch.tensor(possible_resolutions),
|
|
||||||
resize_to_max_canvas=image_processor.resize_to_max_canvas,
|
|
||||||
)
|
|
||||||
for image in processed_data.images
|
|
||||||
]
|
|
||||||
|
|
||||||
# Calculate aspect ratios and patches per image
|
|
||||||
aspect_ratios = [
|
|
||||||
(image_size[0] // tile_size, image_size[1] // tile_size)
|
|
||||||
for image_size in best_fit_sizes
|
|
||||||
]
|
|
||||||
|
|
||||||
patches_per_image = [
|
|
||||||
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add to image_inputs
|
|
||||||
processor_output["aspect_ratios"] = aspect_ratios
|
|
||||||
processor_output["patches_per_image"] = torch.tensor(patches_per_image)
|
|
||||||
|
|
||||||
# Process embed_is_patch
|
|
||||||
vocab = tokenizer.get_vocab()
|
|
||||||
patch_id = vocab.get(processor.img_patch_token, -1)
|
|
||||||
image_end_id = vocab.get(processor.end_of_img_token, -1)
|
|
||||||
|
|
||||||
if patch_id != -1 and image_end_id != -1:
|
|
||||||
input_ids = processor_output["input_ids"].view(-1)
|
|
||||||
|
|
||||||
# Remove BOS token if present
|
|
||||||
if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
|
|
||||||
input_ids = input_ids[1:]
|
|
||||||
|
|
||||||
# Find image end indices and split input_ids
|
|
||||||
image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
|
|
||||||
|
|
||||||
if image_end_indices.size(0) > 0:
|
|
||||||
# Split at image boundaries
|
|
||||||
split_indices = (image_end_indices + 1)[:-1]
|
|
||||||
split_input_ids = torch.tensor_split(input_ids, split_indices)
|
|
||||||
split_input_ids = [x for x in split_input_ids if x.numel() > 0]
|
|
||||||
|
|
||||||
# Create embed_is_patch for each image
|
|
||||||
embed_is_patch = []
|
|
||||||
for per_image_input_ids in split_input_ids:
|
|
||||||
embed_is_patch.append(per_image_input_ids == patch_id)
|
|
||||||
|
|
||||||
processor_output["embed_is_patch"] = embed_is_patch
|
|
||||||
|
|
||||||
# Convert to the format expected by SGLang
|
|
||||||
processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
|
|
||||||
|
|
||||||
processor_output["im_start_id"] = self.boi_token_index
|
|
||||||
processor_output["im_end_id"] = self.eoi_token_index
|
|
||||||
processor_output["im_token_id"] = self.image_token_index
|
|
||||||
|
|
||||||
image_offsets = self.get_mm_items_offset(
|
|
||||||
input_ids=torch.tensor(processor_output["input_ids"]),
|
|
||||||
mm_token_id=self.image_token_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add metadata for image processing
|
|
||||||
processor_output["mm_items"] = [
|
|
||||||
MultimodalDataItem(
|
|
||||||
feature=processor_output["pixel_values"],
|
|
||||||
modality=Modality.IMAGE,
|
|
||||||
offsets=image_offsets,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return processor_output
|
|
||||||
|
|||||||
@@ -216,5 +216,43 @@ class TestKimiVLImageUnderstandsImage(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# not for CI: too large
|
||||||
|
# class TestLlama4ImageUnderstandsImage(
|
||||||
|
# VLMInputTestBase, unittest.IsolatedAsyncioTestCase
|
||||||
|
# ):
|
||||||
|
# model_path = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||||
|
# chat_template = "llama_4_vision"
|
||||||
|
|
||||||
|
# def setUp(self):
|
||||||
|
# self.engine = Engine(
|
||||||
|
# model_path=self.model_path,
|
||||||
|
# trust_remote_code=True,
|
||||||
|
# chat_template=self.chat_template,
|
||||||
|
# enable_multimodal=True,
|
||||||
|
# mem_fraction_static=0.8,
|
||||||
|
# tp_size=4,
|
||||||
|
# attention_backend="fa3",
|
||||||
|
# context_length=65536,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# @classmethod
|
||||||
|
# def _init_visual(cls):
|
||||||
|
# model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True, torch_dtype="auto")
|
||||||
|
# cls.vision_tower = model.vision_model.eval().to(cls.device)
|
||||||
|
# cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
|
||||||
|
|
||||||
|
# cls.visual = lambda tokenizer_output: cls.mm_projector(
|
||||||
|
# cls.vision_tower(
|
||||||
|
# pixel_values=tokenizer_output["pixel_values"],
|
||||||
|
# ).last_hidden_state.flatten(0, -2)
|
||||||
|
# )
|
||||||
|
|
||||||
|
# def _pixel_values_image_data(self, processor_output):
|
||||||
|
# return dict(
|
||||||
|
# modality="IMAGE",
|
||||||
|
# pixel_values=processor_output["pixel_values"],
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user