support llava video (#426)

This commit is contained in:
Yuanhan Zhang
2024-05-14 07:57:00 +08:00
committed by GitHub
parent 5dc55a5f02
commit 0992d85f92
37 changed files with 1139 additions and 222 deletions

View File

@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
return get_chat_template("vicuna_v1.1")
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava-next-video-7b" in model_path.lower():
return get_chat_template("vicuna_v1.1")
@register_chat_template_matching_function
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    """Pick a ChatML-family chat template from the model path.

    Matching is case-insensitive: the path is lower-cased before any
    substring test is applied.

    Args:
        model_path: Model name or filesystem/hub path to inspect.

    Returns:
        The "chatml" template when the path contains "tinyllama", or contains
        both "qwen" and "chat"; the "chatml-llava" template when the path
        contains one of the 34B LLaVA markers listed below; otherwise None.
    """
    model_path = model_path.lower()
    if "tinyllama" in model_path:
        return get_chat_template("chatml")
    if "qwen" in model_path and "chat" in model_path:
        return get_chat_template("chatml")

    # Each marker is tested on its own. Nesting them under a check for
    # "llava-v1.6-34b" alone would make the other two unreachable.
    llava_34b_markers = (
        "llava-v1.6-34b",
        "llava-v1.6-yi-34b",
        "llava-next-video-34b",
    )
    if any(marker in model_path for marker in llava_34b_markers):
        return get_chat_template("chatml-llava")
    return None
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    """Pick the "yi" chat template for Yi model paths that are not LLaVA.

    Matching is case-insensitive: the path is lower-cased before the
    substring tests are applied.

    Args:
        model_path: Model name or filesystem/hub path to inspect.

    Returns:
        The "yi" template when the path contains "yi" and does not contain
        "llava"; otherwise None.
    """
    model_path = model_path.lower()
    # Paths that also contain "llava" (e.g. "llava-v1.6-yi-34b") are excluded
    # so the plain Yi template is not applied to them.
    if "yi" in model_path and "llava" not in model_path:
        return get_chat_template("yi")
    return None

View File

@@ -28,8 +28,9 @@ from sglang.lang.ir import (
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
SglVideo,
)
from sglang.utils import encode_image_base64, get_exception_traceback
from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
def run_internal(state, program, func_args, func_kwargs, sync):
@@ -361,6 +362,8 @@ class StreamExecutor:
self._execute_role_end(other)
elif isinstance(other, SglImage):
self._execute_image(other)
elif isinstance(other, SglVideo):
self._execute_video(other)
elif isinstance(other, SglVariable):
self._execute_variable(other)
elif isinstance(other, SglVarScopeBegin):
@@ -397,6 +400,16 @@ class StreamExecutor:
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
def _execute_video(self, expr: SglVideo):
    """Encode the video referenced by ``expr`` and append it to the prompt.

    The (path, base64 data) pair is recorded in both ``self.images_`` and
    ``self.cur_images``, and the chat template's image token is appended to
    ``self.text_``, so a video goes through the same bookkeeping lists and
    placeholder token used for images.

    Args:
        expr: Video expression providing ``path`` and ``num_frames``.
    """
    video_path = expr.path
    encoded = encode_video_base64(video_path, expr.num_frames)

    # One immutable pair is recorded in both tracking lists.
    entry = (video_path, encoded)
    for bucket in (self.images_, self.cur_images):
        bucket.append(entry)

    self.text_ += self.chat_template.image_token

    # Eager filling is currently disabled:
    # if global_config.eager_fill_image:
    #     self.backend.fill_image(self)

View File

@@ -330,6 +330,15 @@ class SglImage(SglExpr):
return f"SglImage({self.path})"
class SglVideo(SglExpr):
    """Expression node that refers to a video by path and frame count.

    NOTE(review): this constructor does not call ``SglExpr.__init__``;
    confirm against the base class that this is intended.
    """

    def __init__(self, path, num_frames):
        # path: location of the video.
        # num_frames: frame count handed to the video encoder; presumably the
        # number of frames to sample (confirm against encode_video_base64).
        self.path, self.num_frames = path, num_frames

    def __repr__(self) -> str:
        path, num_frames = self.path, self.num_frames
        return f"SglVideo({path}, {num_frames})"
class SglGen(SglExpr):
def __init__(
self,

View File

@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
##################################
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
assert (size >= 1)
assert size >= 1
if self.only_trace_prefix:
raise StopTracing()