[Feat/Fix] Refactoring Llava models into single file (#475)

This commit is contained in:
Li Bo
2024-05-27 03:29:51 +08:00
committed by GitHub
parent 947bda73fe
commit 2b605ab1d7
6 changed files with 118 additions and 688 deletions

View File

@@ -22,11 +22,7 @@ import aiohttp
import requests
from llava.conversation import (
default_conversation,
conv_templates,
SeparatorStyle,
conv_llava_llama_3,
conv_qwen,
)
@@ -43,7 +39,8 @@ async def test_concurrent(args):
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_llava_llama_3)
conv_template.append_message(role="user", message=prompt)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
response = []
for i in range(1):
@@ -74,7 +71,8 @@ def test_streaming(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_llava_llama_3)
conv_template.append_message(role="user", message=prompt)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
pload = {
"text": prompt_with_template,

View File

@@ -22,11 +22,7 @@ import aiohttp
import requests
from llava.conversation import (
default_conversation,
conv_templates,
SeparatorStyle,
conv_llava_llama_3,
conv_qwen,
conv_qwen
)
@@ -43,7 +39,8 @@ async def test_concurrent(args):
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_qwen)
conv_template.append_message(role="user", message=prompt)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
response = []
for i in range(1):
@@ -74,7 +71,8 @@ def test_streaming(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_qwen)
conv_template.append_message(role="user", message=prompt)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
pload = {
"text": prompt_with_template,
@@ -113,5 +111,5 @@ if __name__ == "__main__":
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
# asyncio.run(test_concurrent(args))
asyncio.run(test_concurrent(args))
test_streaming(args)