Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)

Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
Authored by Yury Sulsky on 2025-05-16 12:26:15 -07:00, committed by GitHub
parent c23a7072b6
commit f19a9204cd
14 changed files with 592 additions and 125 deletions
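Only the hunks touching the skip-tokenizer-init test are excerpted below; the feature itself lives in the other changed files. As a rough, hypothetical sketch of the idea in the title — sending vision-encoder outputs computed offline instead of raw image bytes — the field name "precomputed_features" and the server URL below are illustrative placeholders, not necessarily the exact API this commit adds:

# Hypothetical illustration only: "precomputed_features" is a placeholder
# field name, not confirmed to be the API added by this commit. The idea is
# to ship vision-tower outputs computed offline so the server can skip the
# image encoder for models such as Qwen-VL and Gemma3.
import requests

features = [[0.0] * 1152 for _ in range(256)]  # stand-in for real vision-encoder output

payload = {
    "input_ids": [1, 2, 3],  # stand-in ids; a real prompt holds the model's image placeholder tokens
    "image_data": [{"precomputed_features": features}],  # placeholder field name
    "sampling_params": {"max_new_tokens": 32},
}
# Assumes a local server on the default port; adjust as needed.
response = requests.post("http://localhost:30000/generate", json=payload)
print(response.json())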


@@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase):
):
input_ids = self.get_input_ids(prompt_text)
request = self.get_request_json(
input_ids=input_ids,
return_logprob=return_logprob,
top_logprobs_num=top_logprobs_num,
max_new_tokens=max_new_tokens,
stream=False,
n=n,
)
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [self.tokenizer.eos_token_id],
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
json=request,
)
ret = response.json()
print(json.dumps(ret, indent=2))
@@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase):
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
if return_logprob:
num_input_logprobs = len(input_ids) - request["logprob_start_len"]
if num_input_logprobs > len(input_ids):
num_input_logprobs -= len(input_ids)
self.assertEqual(
len(item["meta_info"]["input_token_logprobs"]),
len(input_ids),
num_input_logprobs,
f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {num_input_logprobs}',
)
self.assertEqual(
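The rewritten assertion above derives the expected number of input logprobs from the request's logprob_start_len. A quick stand-alone check of the two values these tests use (0 for the text tests, -1 for the VLM subclass further down), mirroring the arithmetic shown above:

# Mirror of the arithmetic in the assertion above; plain Python, no server needed.
def expected_input_logprobs(num_input_ids: int, logprob_start_len: int) -> int:
    n = num_input_ids - logprob_start_len
    if n > num_input_ids:   # only possible when logprob_start_len < 0
        n -= num_input_ids
    return n

print(expected_input_logprobs(10, 0))   # 10: one logprob per prompt token
print(expected_input_logprobs(10, -1))  # 1: the VLM case, which skips image embeddings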
@@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase):
requests.post(self.base_url + "/flush_cache")
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": self.eos_token_id,
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
json=self.get_request_json(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
return_logprob=return_logprob,
top_logprobs_num=top_logprobs_num,
stream=False,
n=n,
),
)
ret = response.json()
print(json.dumps(ret))
@@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase):
requests.post(self.base_url + "/flush_cache")
response_stream = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": self.eos_token_id,
},
"stream": True,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
json=self.get_request_json(
input_ids=input_ids,
return_logprob=return_logprob,
top_logprobs_num=top_logprobs_num,
stream=True,
n=n,
),
)
response_stream_json = []
@@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase):
].tolist()
return input_ids
def get_request_json(
self,
input_ids,
max_new_tokens=32,
return_logprob=False,
top_logprobs_num=0,
stream=False,
n=1,
):
return {
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": self.eos_token_id,
},
"stream": stream,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
}
class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
@classmethod
@@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
return inputs.input_ids[0].tolist()
def get_request_json(self, *args, **kwargs):
ret = super().get_request_json(*args, **kwargs)
ret["image_data"] = [self.image_url]
ret["logprob_start_len"] = (
-1
) # Do not try to calculate logprobs of image embeddings.
return ret
def test_simple_decode_stream(self):
# TODO mick
pass
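Taken together, the VLM subclass sends the base helper's payload plus the image fields. A minimal sketch of the resulting request, using stand-in values for input_ids, the image URL, and the EOS id purely for illustration:

# Stand-in values; the tests use self.get_input_ids(...), self.image_url and
# self.eos_token_id instead.
input_ids = [2, 255999, 108]
image_url = "https://example.com/cat.png"
eos_token_ids = [1]

# What get_request_json(input_ids=input_ids) returns with the defaults (n=1).
request = {
    "input_ids": input_ids,
    "sampling_params": {
        "temperature": 0,          # 0 because n == 1
        "max_new_tokens": 32,
        "n": 1,
        "stop_token_ids": eos_token_ids,
    },
    "stream": False,
    "return_logprob": False,
    "top_logprobs_num": 0,
    "logprob_start_len": 0,
}
# The VLM override then attaches the image and disables input logprobs over
# the image embeddings.
request["image_data"] = [image_url]
request["logprob_start_len"] = -1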