Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (#6861)
This commit is contained in:
@@ -165,14 +165,19 @@ class LoRAAdapter(nn.Module):
|
||||
self.base_hf_config.hidden_size
|
||||
// self.base_hf_config.num_attention_heads
|
||||
)
|
||||
weights[q_name], weights[kv_name] = torch.split(
|
||||
weights[q_name], k_proj_weight, v_proj_weight = torch.split(
|
||||
weights[qkv_name],
|
||||
[
|
||||
head_size * self.base_hf_config.num_attention_heads,
|
||||
head_size * self.base_hf_config.num_key_value_heads * 2,
|
||||
head_size * self.base_hf_config.num_key_value_heads,
|
||||
head_size * self.base_hf_config.num_key_value_heads,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
weights[kv_name] = torch.stack(
|
||||
[k_proj_weight, v_proj_weight],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
def normalize_gate_up_proj(
|
||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||
|
||||
@@ -157,6 +157,10 @@ class LoRAMemoryPool:
|
||||
def load_lora_weight_to_buffer(
|
||||
self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
|
||||
):
|
||||
def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
|
||||
assert (
|
||||
buffer_view.shape == weight.shape
|
||||
), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
|
||||
|
||||
if uid is None:
|
||||
for i in range(self.num_layer):
|
||||
@@ -208,21 +212,27 @@ class LoRAMemoryPool:
|
||||
|
||||
for name, weights in temp_A_buffer.items():
|
||||
c = get_stacked_multiply(name)
|
||||
self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
|
||||
weights
|
||||
)
|
||||
buffer_view = self.A_buffer[name][layer_id][buffer_id][
|
||||
: lora_rank * c, :
|
||||
]
|
||||
check_lora_weight_shape(buffer_view, weights)
|
||||
buffer_view.copy_(weights)
|
||||
|
||||
for name, weights in temp_B_buffer.items():
|
||||
c = get_stacked_multiply(name)
|
||||
if c > 1:
|
||||
for stacked_id in range(c):
|
||||
self.B_buffer[name][layer_id][stacked_id][buffer_id][
|
||||
:, :lora_rank
|
||||
].copy_(weights[stacked_id])
|
||||
buffer_view = self.B_buffer[name][layer_id][stacked_id][
|
||||
buffer_id
|
||||
][:, :lora_rank]
|
||||
check_lora_weight_shape(buffer_view, weights[stacked_id])
|
||||
buffer_view.copy_(weights[stacked_id])
|
||||
else:
|
||||
self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
|
||||
weights
|
||||
)
|
||||
buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
|
||||
:, :lora_rank
|
||||
]
|
||||
check_lora_weight_shape(buffer_view, weights)
|
||||
buffer_view.copy_(weights)
|
||||
|
||||
def get_tensor(
|
||||
self, weight_name: str, layer_id: int, lora_type: LoRAType
|
||||
|
||||
@@ -177,9 +177,19 @@ class TestKimiVLServer(TestOpenAIVisionServer):
|
||||
class TestPhi4MMServer(TestOpenAIVisionServer):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default.
|
||||
from huggingface_hub import constants, snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
"microsoft/Phi-4-multimodal-instruct",
|
||||
allow_patterns=["**/adapter_config.json"],
|
||||
)
|
||||
|
||||
cls.model = "microsoft/Phi-4-multimodal-instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
@@ -188,15 +198,27 @@ class TestPhi4MMServer(TestOpenAIVisionServer):
|
||||
"--trust-remote-code",
|
||||
"--mem-fraction-static",
|
||||
"0.75",
|
||||
"--disable-radix-cache",
|
||||
"--max-loras-per-batch",
|
||||
"1",
|
||||
"--revision",
|
||||
revision,
|
||||
"--lora-paths",
|
||||
f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
pass
|
||||
def get_request_kwargs(self):
|
||||
return {
|
||||
"extra_body": {
|
||||
"lora_path": "vision",
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
def test_multi_images_chat_completion(self):
|
||||
# TODO (lifuhuang): support LoRA to enable Phi4MM multi-image understanding capability.
|
||||
def test_video_chat_completion(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@@ -47,6 +48,9 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def get_request_kwargs(self):
|
||||
return {}
|
||||
|
||||
def test_single_image_chat_completion(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
@@ -68,6 +72,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
**(self.get_request_kwargs()),
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
@@ -130,6 +135,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
**(self.get_request_kwargs()),
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
@@ -172,6 +178,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
**(self.get_request_kwargs()),
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
@@ -284,6 +291,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
stream=False,
|
||||
**(self.get_request_kwargs()),
|
||||
)
|
||||
|
||||
video_response = response.choices[0].message.content
|
||||
@@ -324,6 +332,9 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
extra_kwargs = self.get_request_kwargs()
|
||||
extra_kwargs.setdefault("extra_body", {})["regex"] = regex
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
@@ -342,7 +353,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
extra_body={"regex": regex},
|
||||
**extra_kwargs,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
@@ -388,6 +399,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
{"role": "user", "content": content},
|
||||
],
|
||||
temperature=0,
|
||||
**(self.get_request_kwargs()),
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
@@ -430,6 +442,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
stream=False,
|
||||
**(self.get_request_kwargs()),
|
||||
)
|
||||
|
||||
audio_response = response.choices[0].message.content
|
||||
|
||||
Reference in New Issue
Block a user