Input_embeds support (#2052)
@@ -11,21 +11,23 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can specify either text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob.
+    # If return_logprob is set, the start location in the prompt for returning logprobs.
     # By default, this value is "-1", which means it will only return logprobs for output tokens.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return.
+    # If return_logprob is set, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
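For illustration, a minimal client-side sketch of the new field (the server URL, port, and model name below are assumptions, and the payload shape follows the fields documented above; per the tokenizer-manager change later in this commit, the server must be launched with --disable-radix-cache):

import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical model and local server address; adjust to your deployment.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
URL = "http://127.0.0.1:30000/generate"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)

# Turn a prompt into a [seq_len][hidden_size] list of embeddings.
input_ids = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
embeds = model.get_input_embeddings()(input_ids).squeeze(0).tolist()

resp = requests.post(
    URL,
    json={
        "input_embeds": embeds,  # instead of "text" or "input_ids"
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
    timeout=60,
)
print(resp.json()["text"])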
@@ -29,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can specify either text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
@@ -60,10 +62,16 @@ class GenerateReqInput:
     ] = None
 
     def normalize_batch_and_arguments(self):
-        if (self.text is None and self.input_ids is None) or (
-            self.text is not None and self.input_ids is not None
+        if (
+            self.text is None and self.input_ids is None and self.input_embeds is None
+        ) or (
+            self.text is not None
+            and self.input_ids is not None
+            and self.input_embeds is not None
         ):
-            raise ValueError("Either text or input_ids should be provided.")
+            raise ValueError(
+                "Either text, input_ids or input_embeds should be provided."
+            )
 
         # Derive the batch size
         if self.text is not None:
@@ -73,13 +81,21 @@ class GenerateReqInput:
             else:
                 self.is_single = False
                 self.batch_size = len(self.text)
-        else:
+            self.input_embeds = None
+        elif self.input_ids is not None:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
                 self.is_single = False
                 self.batch_size = len(self.input_ids)
+            self.input_embeds = None
+        else:
+            if isinstance(self.input_embeds[0][0], float):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.input_embeds)
 
         # Handle parallel sampling
         # When parallel sampling is used, we always treat the input as a batch.
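The batch-size derivation above is purely shape-based; a standalone illustration of the convention (toy values, not from the diff): a single request's input_embeds is a [seq_len][hidden_size] list, while a batched request adds an outer batch dimension.

# Single request: [seq_len][hidden_size] -> element [0][0] is a float.
single = [[0.1, 0.2], [0.3, 0.4]]  # one 2-token prompt, hidden_size 2
assert isinstance(single[0][0], float)  # handled as batch_size = 1

# Batched request: [batch][seq_len][hidden_size] -> element [0][0] is a list.
batch = [single, single]  # two prompts
assert isinstance(batch[0][0], list)  # handled as batch_size = len(batch) = 2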
@@ -202,6 +218,8 @@ class TokenizedGenerateReqInput:
 
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
+    # The input embeds
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
     # Session id info for continual prompting
     session_id: Optional[str] = None
@@ -218,6 +236,8 @@ class EmbeddingReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
+    # Dummy input embeds for compatibility
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -178,6 +178,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info
@@ -191,6 +192,7 @@ class Req:
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -448,6 +450,7 @@ class ScheduleBatch:
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -631,6 +634,9 @@ class ScheduleBatch:
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        input_embeds = []
+
         pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -649,6 +655,11 @@ class ScheduleBatch:
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
 
+            # If input embeddings are available for this request, collect them
+            if req.input_embeds is not None:
+                # req.input_embeds is already a list of per-token vectors, so
+                # extend() keeps the batch-level list flat instead of nesting it
+                input_embeds.extend(req.input_embeds)
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(
@@ -671,6 +682,12 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc
 
         self.seq_lens_sum = sum(seq_lens)
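Since each request's embeddings were flattened into one list via extend() earlier in the method, torch.tensor(input_embeds) above yields a packed [total_extend_tokens, hidden_size] tensor. A sketch of the shape arithmetic (toy sizes, names illustrative):

import torch

hidden_size = 4
req_a = [[0.0] * hidden_size] * 3  # request with 3 extend tokens
req_b = [[1.0] * hidden_size] * 5  # request with 5 extend tokens

input_embeds = []
input_embeds.extend(req_a)
input_embeds.extend(req_b)

packed = torch.tensor(input_embeds)
assert packed.shape == (8, hidden_size)  # one row per extend token across requests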
@@ -1053,6 +1070,7 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )
 
     def copy(self):
@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The input embeds
+    input_embeds: Optional[torch.Tensor] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
@@ -526,12 +526,20 @@ class Scheduler:
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # If input_embeds is present, create dummy input_ids of matching length
+            if recv_req.input_embeds is not None:
+                # Generate fake token ids; only their count matters downstream
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer
             if recv_req.session_id is not None:
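The dummy ids only need to match the embedding count: downstream bookkeeping (sequence lengths, KV-cache slots) keys off len(input_ids), and the token-embedding lookup is bypassed once input_embeds reaches the model. A toy illustration (hypothetical hidden size):

hidden_size = 4096  # hypothetical
input_embeds = [[0.0] * hidden_size for _ in range(7)]  # a 7-token prompt

fake_input_ids = [1] * len(input_embeds)
assert len(fake_input_ids) == 7  # the scheduler sees the correct sequence length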
@@ -201,8 +201,18 @@ class TokenizerManager:
     ):
         """Tokenize one request."""
         # Tokenize
+        input_embeds = None
         input_text = obj.text
-        if obj.input_ids is None:
+        if obj.input_embeds is not None:
+            if not self.server_args.disable_radix_cache:
+                raise ValueError(
+                    "input_embeds is provided while disable_radix_cache is False. "
+                    "Please add `--disable-radix-cache` when you launch the server "
+                    "if you want to use input_embeds as inputs."
+                )
+            input_embeds = obj.input_embeds
+            input_ids = obj.input_ids
+        elif obj.input_ids is None:
             input_ids = self.tokenizer.encode(input_text)
         else:
             input_ids = obj.input_ids
@@ -219,7 +229,7 @@ class TokenizerManager:
         session_id = obj.session[0] if obj.session else None
         session_rid = obj.session[1] if obj.session else None
 
-        if len(input_ids) >= self.context_len:
+        if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
@@ -242,7 +252,8 @@ class TokenizerManager:
             logprob_start_len,
             top_logprobs_num,
             obj.stream,
-            obj.lora_path,
+            lora_path=obj.lora_path,
+            input_embeds=input_embeds,
             session_id=session_id,
             session_rid=session_rid,
         )
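A plausible reading of the radix-cache check above: the prefix cache is keyed by token ids, and with input_embeds the scheduler substitutes constant fake ids, so two different embedding inputs would look identical to the cache and could wrongly share KV entries. Illustrative only:

# Two different prompts arrive as embeddings of the same length ...
fake_ids_a = [1] * 7  # stands in for prompt A's embeddings
fake_ids_b = [1] * 7  # stands in for prompt B's embeddings
# ... and are indistinguishable by token ids, which is all a radix cache sees.
assert fake_ids_a == fake_ids_b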
@@ -130,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None
 
+    # For input embeddings
+    input_embeds: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
 
@@ -231,6 +234,7 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )
 
         if ret.global_num_tokens is not None:
@@ -606,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
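forward_extend above assumes the underlying model's forward accepts an input_embeds keyword and skips its token-embedding lookup when it is given; a minimal sketch of that contract (hypothetical model code, not part of this commit):

from typing import Optional

import torch


class ToyModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: object,  # stands in for sglang's ForwardBatch
        input_embeds: Optional[torch.Tensor] = None,
    ):
        # When embeddings are supplied, the token-id lookup is skipped, which is
        # why the scheduler's fake input_ids never influence the output.
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        return hidden_states  # a real model would run decoder layers here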
@@ -14,6 +14,7 @@ suites = {
     "test_double_sparsity.py",
     "test_embedding_openai_server.py",
     "test_eval_accuracy_mini.py",
+    "test_input_embeddings.py",
     "test_json_constrained.py",
     "test_large_max_new_tokens.py",
     "test_metrics.py",
test/srt/test_input_embeddings.py  (new file, 114 lines)
@@ -0,0 +1,114 @@
+import json
+import unittest
+
+import requests
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestInputEmbeds(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
+        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--disable-radix-cache"],
+        )
+        cls.texts = [
+            "The capital of France is",
+            "What is the best time of year to visit Japan for cherry blossoms?",
+        ]
+
+    def generate_input_embeddings(self, text):
+        """Generate input embeddings for a given text."""
+        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
+        embeddings = self.ref_model.get_input_embeddings()(input_ids)
+        return embeddings.squeeze().tolist()  # Convert tensor to a list for API use
+
+    def send_request(self, payload):
+        """Send a POST request to the API and return the response."""
+        response = requests.post(
+            self.base_url + "/generate",
+            json=payload,
+            timeout=30,  # Set a reasonable timeout for the API request
+        )
+        if response.status_code == 200:
+            return response.json()
+        return {
+            "error": f"Request failed with status {response.status_code}: {response.text}"
+        }
+
+    def test_text_based_response(self):
+        """Print API response using text-based input."""
+        for text in self.texts:
+            payload = {
+                "model": self.model,
+                "text": text,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            response = self.send_request(payload)
+            print(
+                f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
+            )
+
+    def test_embedding_based_response(self):
+        """Print API response using input embeddings."""
+        for text in self.texts:
+            embeddings = self.generate_input_embeddings(text)
+            payload = {
+                "model": self.model,
+                "input_embeds": embeddings,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            response = self.send_request(payload)
+            print(
+                f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
+            )
+
+    def test_compare_text_vs_embedding(self):
+        """Print responses for both text-based and embedding-based inputs."""
+        for text in self.texts:
+            # Text-based payload
+            text_payload = {
+                "model": self.model,
+                "text": text,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            # Embedding-based payload
+            embeddings = self.generate_input_embeddings(text)
+            embed_payload = {
+                "model": self.model,
+                "input_embeds": embeddings,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            # Get responses
+            text_response = self.send_request(text_payload)
+            embed_response = self.send_request(embed_payload)
+            # Print responses
+            print(
+                f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
+            )
+            print(
+                f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
+            )
+            self.assertEqual(text_response["text"], embed_response["text"])
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid, include_self=True)
+
+
+if __name__ == "__main__":
+    unittest.main()