From 4474eaf5528aa073ce5ea6dc8c4136dc2b8f7449 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Wed, 4 Jun 2025 22:08:30 -0700 Subject: [PATCH] Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (#6861) --- python/sglang/srt/lora/lora.py | 9 ++++-- python/sglang/srt/lora/mem_pool.py | 28 ++++++++++++------ test/srt/test_vision_openai_server_b.py | 30 +++++++++++++++++--- test/srt/test_vision_openai_server_common.py | 15 +++++++++- 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index a6cbc7a28..c1ebe2dcd 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -165,14 +165,19 @@ class LoRAAdapter(nn.Module): self.base_hf_config.hidden_size // self.base_hf_config.num_attention_heads ) - weights[q_name], weights[kv_name] = torch.split( + weights[q_name], k_proj_weight, v_proj_weight = torch.split( weights[qkv_name], [ head_size * self.base_hf_config.num_attention_heads, - head_size * self.base_hf_config.num_key_value_heads * 2, + head_size * self.base_hf_config.num_key_value_heads, + head_size * self.base_hf_config.num_key_value_heads, ], dim=0, ) + weights[kv_name] = torch.stack( + [k_proj_weight, v_proj_weight], + dim=0, + ) def normalize_gate_up_proj( self, weight_names: List[str], weights: Dict[str, torch.Tensor] diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 6db4e14f3..8b8d21332 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -157,6 +157,10 @@ class LoRAMemoryPool: def load_lora_weight_to_buffer( self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None ): + def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): + assert ( + buffer_view.shape == weight.shape + ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}." if uid is None: for i in range(self.num_layer): @@ -208,21 +212,27 @@ class LoRAMemoryPool: for name, weights in temp_A_buffer.items(): c = get_stacked_multiply(name) - self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_( - weights - ) + buffer_view = self.A_buffer[name][layer_id][buffer_id][ + : lora_rank * c, : + ] + check_lora_weight_shape(buffer_view, weights) + buffer_view.copy_(weights) for name, weights in temp_B_buffer.items(): c = get_stacked_multiply(name) if c > 1: for stacked_id in range(c): - self.B_buffer[name][layer_id][stacked_id][buffer_id][ - :, :lora_rank - ].copy_(weights[stacked_id]) + buffer_view = self.B_buffer[name][layer_id][stacked_id][ + buffer_id + ][:, :lora_rank] + check_lora_weight_shape(buffer_view, weights[stacked_id]) + buffer_view.copy_(weights[stacked_id]) else: - self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_( - weights - ) + buffer_view = self.B_buffer[name][layer_id][0][buffer_id][ + :, :lora_rank + ] + check_lora_weight_shape(buffer_view, weights) + buffer_view.copy_(weights) def get_tensor( self, weight_name: str, layer_id: int, lora_type: LoRAType diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index 6043dd107..2d05b0688 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -177,9 +177,19 @@ class TestKimiVLServer(TestOpenAIVisionServer): class TestPhi4MMServer(TestOpenAIVisionServer): @classmethod def setUpClass(cls): + # Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default. + from huggingface_hub import constants, snapshot_download + + snapshot_download( + "microsoft/Phi-4-multimodal-instruct", + allow_patterns=["**/adapter_config.json"], + ) + cls.model = "microsoft/Phi-4-multimodal-instruct" cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" + + revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a" cls.process = popen_launch_server( cls.model, cls.base_url, @@ -188,15 +198,27 @@ class TestPhi4MMServer(TestOpenAIVisionServer): "--trust-remote-code", "--mem-fraction-static", "0.75", + "--disable-radix-cache", + "--max-loras-per-batch", + "1", + "--revision", + revision, + "--lora-paths", + f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora", ], ) cls.base_url += "/v1" - def test_video_chat_completion(self): - pass + def get_request_kwargs(self): + return { + "extra_body": { + "lora_path": "vision", + "top_k": 1, + "top_p": 1.0, + } + } - def test_multi_images_chat_completion(self): - # TODO (lifuhuang): support LoRA to enable Phi4MM multi-image understanding capability. + def test_video_chat_completion(self): pass diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 637345e2d..3687d9381 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -1,4 +1,5 @@ import base64 +import copy import io import json import os @@ -47,6 +48,9 @@ class TestOpenAIVisionServer(CustomTestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) + def get_request_kwargs(self): + return {} + def test_single_image_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) @@ -68,6 +72,7 @@ class TestOpenAIVisionServer(CustomTestCase): }, ], temperature=0, + **(self.get_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -130,6 +135,7 @@ class TestOpenAIVisionServer(CustomTestCase): }, ], temperature=0, + **(self.get_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -172,6 +178,7 @@ class TestOpenAIVisionServer(CustomTestCase): }, ], temperature=0, + **(self.get_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -284,6 +291,7 @@ class TestOpenAIVisionServer(CustomTestCase): temperature=0, max_tokens=1024, stream=False, + **(self.get_request_kwargs()), ) video_response = response.choices[0].message.content @@ -324,6 +332,9 @@ class TestOpenAIVisionServer(CustomTestCase): + r"""\}""" ) + extra_kwargs = self.get_request_kwargs() + extra_kwargs.setdefault("extra_body", {})["regex"] = regex + response = client.chat.completions.create( model="default", messages=[ @@ -342,7 +353,7 @@ class TestOpenAIVisionServer(CustomTestCase): }, ], temperature=0, - extra_body={"regex": regex}, + **extra_kwargs, ) text = response.choices[0].message.content @@ -388,6 +399,7 @@ class TestOpenAIVisionServer(CustomTestCase): {"role": "user", "content": content}, ], temperature=0, + **(self.get_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -430,6 +442,7 @@ class TestOpenAIVisionServer(CustomTestCase): temperature=0, max_tokens=128, stream=False, + **(self.get_request_kwargs()), ) audio_response = response.choices[0].message.content