Fix mixed batch for multi modal models (#1702)
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -76,7 +76,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 5 --range-end 16
|
python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
|
||||||
|
|
||||||
unit-test-backend-part-3:
|
unit-test-backend-part-3:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -96,7 +96,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 16
|
python3 run_suite.py --suite minimal --range-begin 17
|
||||||
|
|
||||||
performance-test-1-gpu-part-1:
|
performance-test-1-gpu-part-1:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
|
|||||||
@@ -160,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
image_sizes = [
|
image_sizes = [
|
||||||
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
|
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
|
||||||
]
|
]
|
||||||
image_offsets = [
|
|
||||||
image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
|
|
||||||
]
|
|
||||||
|
|
||||||
########## Encode Image ########
|
########## Encode Image ########
|
||||||
|
|
||||||
@@ -358,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
prefix_len = prefix_lens_cpu[i]
|
prefix_len = prefix_lens_cpu[i]
|
||||||
|
|
||||||
# Multiple images
|
# Multiple images
|
||||||
for j, image_offset in enumerate(image_offsets[i]):
|
for j, image_offset in enumerate(image_inputs[i].image_offsets):
|
||||||
if image_offset < prefix_len:
|
if image_offset < prefix_len:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
|
||||||
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import openai
|
import openai
|
||||||
@@ -288,6 +294,55 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
|||||||
assert isinstance(js_obj["color"], str)
|
assert isinstance(js_obj["color"], str)
|
||||||
assert isinstance(js_obj["number_of_cars"], int)
|
assert isinstance(js_obj["number_of_cars"], int)
|
||||||
|
|
||||||
|
def run_decode_with_image(self, image_id):
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
content = []
|
||||||
|
if image_id == 0:
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif image_id == 1:
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image in a very short sentence.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": content},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.role == "assistant"
|
||||||
|
text = response.choices[0].message.content
|
||||||
|
assert isinstance(text, str)
|
||||||
|
|
||||||
|
def test_mixed_batch(self):
|
||||||
|
image_ids = [0, 1, 2] * 4
|
||||||
|
with ThreadPoolExecutor(4) as executor:
|
||||||
|
list(executor.map(self.run_decode_with_image, image_ids))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user