support llava video (#426)
This commit is contained in:
@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
if "llava-v1.5" in model_path.lower():
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
if "llava-next-video-7b" in model_path.lower():
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_ml(model_path: str):
|
||||
# import pdb;pdb.set_trace()
|
||||
model_path = model_path.lower()
|
||||
if "tinyllama" in model_path:
|
||||
return get_chat_template("chatml")
|
||||
if "qwen" in model_path and "chat" in model_path:
|
||||
return get_chat_template("chatml")
|
||||
if "llava-v1.6-34b" in model_path:
|
||||
if (
|
||||
"llava-v1.6-34b" in model_path
|
||||
or "llava-v1.6-yi-34b" in model_path
|
||||
or "llava-next-video-34b" in model_path
|
||||
):
|
||||
return get_chat_template("chatml-llava")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_yi(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "yi" in model_path:
|
||||
if "yi" in model_path and "llava" not in model_path:
|
||||
return get_chat_template("yi")
|
||||
|
||||
|
||||
|
||||
@@ -28,8 +28,9 @@ from sglang.lang.ir import (
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
SglVideo,
|
||||
)
|
||||
from sglang.utils import encode_image_base64, get_exception_traceback
|
||||
from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
|
||||
|
||||
|
||||
def run_internal(state, program, func_args, func_kwargs, sync):
|
||||
@@ -361,6 +362,8 @@ class StreamExecutor:
|
||||
self._execute_role_end(other)
|
||||
elif isinstance(other, SglImage):
|
||||
self._execute_image(other)
|
||||
elif isinstance(other, SglVideo):
|
||||
self._execute_video(other)
|
||||
elif isinstance(other, SglVariable):
|
||||
self._execute_variable(other)
|
||||
elif isinstance(other, SglVarScopeBegin):
|
||||
@@ -397,6 +400,16 @@ class StreamExecutor:
|
||||
self.cur_images.append((path, base64_data))
|
||||
self.text_ += self.chat_template.image_token
|
||||
|
||||
def _execute_video(self, expr: SglVideo):
|
||||
path = expr.path
|
||||
num_frames = expr.num_frames
|
||||
|
||||
base64_data = encode_video_base64(path, num_frames)
|
||||
|
||||
self.images_.append((path, base64_data))
|
||||
self.cur_images.append((path, base64_data))
|
||||
self.text_ += self.chat_template.image_token
|
||||
|
||||
# if global_config.eager_fill_image:
|
||||
# self.backend.fill_image(self)
|
||||
|
||||
|
||||
@@ -330,6 +330,15 @@ class SglImage(SglExpr):
|
||||
return f"SglImage({self.path})"
|
||||
|
||||
|
||||
class SglVideo(SglExpr):
|
||||
def __init__(self, path, num_frames):
|
||||
self.path = path
|
||||
self.num_frames = num_frames
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SglVideo({self.path}, {self.num_frames})"
|
||||
|
||||
|
||||
class SglGen(SglExpr):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
|
||||
##################################
|
||||
|
||||
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
|
||||
assert (size >= 1)
|
||||
assert size >= 1
|
||||
|
||||
if self.only_trace_prefix:
|
||||
raise StopTracing()
|
||||
|
||||
Reference in New Issue
Block a user