example: add vlm to token in & out example (#3941)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
@@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you ma
 * `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
 * `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
 * `revision`: Adjust if a specific version of the model should be used.
-* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out_llm.py).
+* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out/).
 * `json_model_override_args`: Override model config with the provided JSON.
 * `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
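For orientation, the `skip_tokenizer_init` flow referenced above looks roughly like this when driven through the offline engine (a minimal sketch, not repo code; the model path is illustrative and the `output_ids` field matches the examples further down):

```python
# Minimal token-in/token-out sketch: tokenize outside the engine, pass ids in,
# get ids back, and decode on the caller's side.
from transformers import AutoTokenizer

import sglang as sgl

model_path = "Qwen/Qwen2.5-1.5B-Instruct"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = tokenizer("The capital of France is").input_ids

engine = sgl.Engine(model_path=model_path, skip_tokenizer_init=True)
output = engine.generate(
    input_ids=input_ids,
    sampling_params={"temperature": 0.8, "max_new_tokens": 32},
)

# With skip_tokenizer_init, the engine returns token ids rather than text.
print(tokenizer.decode(output["output_ids"], skip_special_tokens=True))
engine.shutdown()
```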
@@ -9,15 +9,15 @@ SGLang provides a direct inference engine without the need for an HTTP server. T
 
 ## Examples
 
-### 1. [Offline Batch Inference](./offline_batch_inference.py)
+### [Offline Batch Inference](./offline_batch_inference.py)
 
 In this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine will intelligently schedule the requests to process efficiently and prevent OOM (Out of Memory) errors.
 
-### 2. [Embedding Generation](./embedding.py)
+### [Embedding Generation](./embedding.py)
 
 In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.
 
-### 3. [Custom Server](./custom_server.py)
+### [Custom Server](./custom_server.py)
 
 This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.
 
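For context, the offline-batch-inference example described above reduces to a few lines (a sketch; model path and prompts are illustrative):

```python
# Offline batch inference sketch: the engine batches and schedules the prompts itself.
import sglang as sgl

llm = sgl.Engine(model_path="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative model

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = {"temperature": 0.8, "top_p": 0.95}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print(prompt, "->", output["text"])

llm.shutdown()
```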
@@ -43,3 +43,7 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio
 ```
 
 This will send both non-streaming and streaming requests to the server.
+
+### [Token-In-Token-Out for RLHF](./token_in_token_out)
+
+In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
@@ -30,7 +30,7 @@ def main():
     # Print the outputs.
     for prompt, output in zip(prompts, outputs):
         print("===============================")
-        print(f"Prompt: {prompt}\nGenerated token ids: {output['token_ids']}")
+        print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
 
 
 # The __main__ condition is necessary here because we use "spawn" to create subprocesses
@@ -0,0 +1,75 @@
+import argparse
+import dataclasses
+from io import BytesIO
+from typing import Tuple
+
+import requests
+from PIL import Image
+from transformers import AutoProcessor
+
+from sglang import Engine
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.server_args import ServerArgs
+from sglang.test.test_utils import DEFAULT_IMAGE_URL
+
+
+def get_input_ids(
+    server_args: ServerArgs, model_config: ModelConfig
+) -> Tuple[list[int], list]:
+    chat_template = get_chat_template_by_model_path(model_config.model_path)
+    text = f"{chat_template.image_token}What is in this picture?"
+    images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
+    image_data = [DEFAULT_IMAGE_URL]
+
+    processor = AutoProcessor.from_pretrained(
+        model_config.model_path, trust_remote_code=server_args.trust_remote_code
+    )
+
+    inputs = processor(
+        text=[text],
+        images=images,
+        return_tensors="pt",
+    )
+
+    return inputs.input_ids[0].tolist(), image_data
+
+
+def token_in_out_example(
+    server_args: ServerArgs,
+):
+    input_ids, image_data = get_input_ids(
+        server_args,
+        ModelConfig(
+            server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_override_args=server_args.json_model_override_args,
+        ),
+    )
+    backend = Engine(**dataclasses.asdict(server_args))
+
+    output = backend.generate(
+        input_ids=input_ids,
+        image_data=image_data,
+        sampling_params={
+            "temperature": 0.8,
+            "max_new_tokens": 32,
+        },
+    )
+
+    print("===============================")
+    print(f"Output token ids: ", output["output_ids"])
+
+    backend.shutdown()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = [
+        "--model-path=Qwen/Qwen2-VL-2B",
+    ]
+    args = parser.parse_args(args=args)
+    server_args = ServerArgs.from_cli_args(args)
+    server_args.skip_tokenizer_init = True
+    token_in_out_example(server_args)
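If you want text rather than raw ids from the example above, a natural follow-up is to decode on the caller's side (a sketch; it assumes the tokenizer loaded from the same model path matches the one the engine skipped, and it reuses `output` from the example):

```python
# Sketch: decode the ids returned by the token-in-token-out VLM example above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B")
print(tokenizer.decode(output["output_ids"], skip_special_tokens=True))
```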
@@ -40,7 +40,7 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-        model_override_args: Optional[dict] = None,
+        model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
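Note the call site in the new example above passes `ServerArgs.json_model_override_args`, i.e. a JSON string rather than a dict; a small sketch of that usage (the override payload here is hypothetical):

```python
# Sketch: model_override_args is supplied as a JSON string.
from sglang.srt.configs.model_config import ModelConfig

config = ModelConfig(
    "Qwen/Qwen2-VL-2B",
    trust_remote_code=True,
    model_override_args='{"rope_scaling": null}',  # hypothetical override payload
)
```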
@@ -42,7 +42,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
-    BatchMultimodalDecodeReq,
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
@@ -104,7 +103,6 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
-    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -1199,7 +1197,6 @@ class Scheduler:
         self.spec_num_total_forward_ct += batch.batch_size()
         self.num_generated_tokens += num_accepted_tokens
         batch.output_ids = next_token_ids
-
         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
         # we can use the correct values in output processing.
@@ -1480,7 +1477,6 @@ class Scheduler:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         self.stream_output(batch.reqs, batch.return_logprob)
 
         self.token_to_kv_pool.free_group_end()
@@ -1580,11 +1576,11 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.input_top_logprobs_val = [None]
                 req.input_top_logprobs_idx = [None]
+                assert len(req.temp_input_token_ids_logprobs_val) == len(
+                    req.temp_input_token_ids_logprobs_idx
+                )
                 for val, idx in zip(
-                    req.temp_input_top_logprobs_val,
-                    req.temp_input_top_logprobs_idx,
-                    strict=True,
+                    req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
                 ):
                     req.input_top_logprobs_val.extend(val)
                     req.input_top_logprobs_idx.extend(idx)
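Background on the `zip` call above: `strict=True` is only available on Python 3.10+, and plain `zip` silently truncates to the shorter input, so an explicit length assert is the usual substitute. A generic illustration (not repo code):

```python
vals = [[0.1], [0.2]]
idxs = [[7]]  # one entry short

# Plain zip stops at the shorter input and silently drops the second value:
assert list(zip(vals, idxs)) == [([0.1], [7])]

# An explicit length check fails loudly instead of truncating:
try:
    assert len(vals) == len(idxs)
except AssertionError:
    print("misaligned logprob lists detected")
```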
@@ -1779,7 +1775,6 @@ class Scheduler:
         if rids:
             if self.model_config.is_multimodal_gen:
                 raise NotImplementedError()
-
             self.send_to_detokenizer.send_pyobj(
                 BatchTokenIDOut(
                     rids,
@@ -11,7 +11,7 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import gguf
 import huggingface_hub
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.device_config import DeviceConfig
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
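The environment-variable gate above only triggers on the literal string "True"; a standalone illustration (not repo code):

```python
import os


def use_modelscope() -> bool:
    # Mirrors the check above: only the exact string "True" enables ModelScope downloads.
    return os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True"


os.environ["SGLANG_USE_MODELSCOPE"] = "True"
assert use_modelscope()

os.environ["SGLANG_USE_MODELSCOPE"] = "1"  # any other value is treated as disabled
assert not use_modelscope()
```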
@@ -43,10 +43,15 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Ins
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
+DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"
 
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 
+DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
+DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
+
 
 def is_in_ci():
     """Return whether it is in CI runner."""
@@ -5,13 +5,18 @@ python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_st
 
 import json
 import unittest
+from io import BytesIO
 
 import requests
-from transformers import AutoTokenizer
+from PIL import Image
+from transformers import AutoProcessor, AutoTokenizer
 
+from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
+    DEFAULT_IMAGE_URL,
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_VLM_MODEL_NAME,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -29,6 +34,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=["--skip-tokenizer-init", "--stream-output"],
         )
+        cls.eos_token_id = [119690]
         cls.tokenizer = AutoTokenizer.from_pretrained(
             DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
         )
@@ -45,9 +51,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
         top_logprobs_num=0,
         n=1,
     ):
-        input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
-            0
-        ].tolist()
+        input_ids = self.get_input_ids(prompt_text)
 
         response = requests.post(
             self.base_url + "/generate",
@@ -104,7 +108,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
 
     def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
         max_new_tokens = 32
-        input_ids = [128000, 791, 6864, 315, 9822, 374]  # The capital of France is
+        input_ids = self.get_input_ids("The capital of France is")
         requests.post(self.base_url + "/flush_cache")
         response = requests.post(
             self.base_url + "/generate",
|
|||||||
"temperature": 0 if n == 1 else 0.5,
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": max_new_tokens,
|
||||||
"n": n,
|
"n": n,
|
||||||
"stop_token_ids": [119690],
|
"stop_token_ids": self.eos_token_id,
|
||||||
},
|
},
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
@@ -125,6 +129,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
         ret = response.json()
         print(json.dumps(ret))
         output_ids = ret["output_ids"]
+        print("output from non-streaming request:")
+        print(output_ids)
+        print(self.tokenizer.decode(output_ids, skip_special_tokens=True))
 
         requests.post(self.base_url + "/flush_cache")
         response_stream = requests.post(
@@ -135,7 +142,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
                     "temperature": 0 if n == 1 else 0.5,
                     "max_new_tokens": max_new_tokens,
                     "n": n,
-                    "stop_token_ids": [119690],
+                    "stop_token_ids": self.eos_token_id,
                 },
                 "stream": True,
                 "return_logprob": return_logprob,
@@ -143,13 +150,10 @@ class TestSkipTokenizerInit(unittest.TestCase):
                 "logprob_start_len": 0,
             },
         )
-        ret = response.json()
-        output_ids = ret["output_ids"]
-        print("output from non-streaming request:")
-        print(output_ids)
 
         response_stream_json = []
         for line in response_stream.iter_lines():
+            print(line)
             if line.startswith(b"data: ") and line[6:] != b"[DONE]":
                 response_stream_json.append(json.loads(line[6:]))
         out_stream_ids = []
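The streaming branch consumes server-sent lines of the form `data: {...}` and concatenates the per-chunk ids; condensed into a standalone helper (a sketch under the same assumptions as the non-streaming request above):

```python
import json

import requests


def stream_output_ids(base_url: str, payload: dict) -> list[int]:
    """Collect output_ids from a streaming /generate response (sketch)."""
    response = requests.post(
        base_url + "/generate", json={**payload, "stream": True}, stream=True
    )
    output_ids = []
    for line in response.iter_lines():
        # Each event looks like b'data: {...}'; the terminator is b'data: [DONE]'.
        if line.startswith(b"data: ") and line[6:] != b"[DONE]":
            chunk = json.loads(line[6:])
            output_ids += chunk["output_ids"]
    return output_ids
```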
@@ -157,6 +161,8 @@ class TestSkipTokenizerInit(unittest.TestCase):
             out_stream_ids += x["output_ids"]
         print("output from streaming request:")
         print(out_stream_ids)
+        print(self.tokenizer.decode(out_stream_ids, skip_special_tokens=True))
 
         assert output_ids == out_stream_ids
 
     def test_simple_decode(self):
@@ -175,6 +181,46 @@ class TestSkipTokenizerInit(unittest.TestCase):
     def test_simple_decode_stream(self):
         self.run_decode_stream()
 
+    def get_input_ids(self, prompt_text) -> list[int]:
+        input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
+            0
+        ].tolist()
+        return input_ids
+
+
+class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
+    @classmethod
+    def setUpClass(cls):
+        cls.image_url = DEFAULT_IMAGE_URL
+        response = requests.get(cls.image_url)
+        cls.image = Image.open(BytesIO(response.content))
+        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model, use_fast=False)
+        cls.processor = AutoProcessor.from_pretrained(cls.model, trust_remote_code=True)
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--skip-tokenizer-init"],
+        )
+        cls.eos_token_id = [cls.tokenizer.eos_token_id]
+
+    def get_input_ids(self, _prompt_text) -> list[int]:
+        chat_template = get_chat_template_by_model_path(self.model)
+        text = f"{chat_template.image_token}What is in this picture?"
+        inputs = self.processor(
+            text=[text],
+            images=[self.image],
+            return_tensors="pt",
+        )
+
+        return inputs.input_ids[0].tolist()
+
+    def test_simple_decode_stream(self):
+        # TODO mick
+        pass
+
+
 if __name__ == "__main__":
     unittest.main()