refactor: multimodal data (#4754)
```diff
@@ -72,17 +72,38 @@ def eval_mmmu(args):
         if suffix:
             contents += [{"type": "text", "text": suffix}]
         messages = [{"role": "user", "content": contents}]
-        model_inputs = processor.apply_chat_template(
-            messages,
-            tokenize=True,
-            return_dict=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-        ).to(model.device)
-        input_len = model_inputs["input_ids"].shape[-1]
-        generation = model.generate(**model_inputs, generation_config=generation_config)
-        generation = generation[0][input_len:]
-        response = processor.decode(generation, skip_special_tokens=True)
+        try:
+            model_inputs = processor.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                return_dict=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            ).to(model.device)
+            input_len = model_inputs["input_ids"].shape[-1]
+            generation = model.generate(
+                **model_inputs, generation_config=generation_config
+            )
+            generation = generation[0][input_len:]
+            response = processor.decode(generation, skip_special_tokens=True)
+        except:
+            contents = []
+            if prefix:
+                contents += [prefix]
+            image = PIL.Image.open(sample["image_path"])
+            contents += [image]
+            if suffix:
+                contents += [suffix]
+            messages = [{"role": "user", "content": contents}]
+            response = model.chat(
+                msgs=messages,
+                tokenizer=processor.tokenizer,
+                sampling=False,
+                max_new_tokens=sampling_params["max_new_tokens"],
+                use_tts_template=False,
+                generate_audio=False,
+                temperature=0.0,
+            )
         print(f"response: {response}")
         process_result(response, sample, answer_dict, out_samples)

```
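For context, here is a minimal, self-contained sketch of the try/except pattern this hunk introduces: attempt the standard chat-template generation route first, and fall back to a model-specific `chat()` entry point (MiniCPM-style remote code) when the processor route fails. The checkpoint id, image path, and `max_new_tokens` value below are illustrative placeholders rather than values from this commit, and the exact content keys accepted by `apply_chat_template` vary across `transformers` versions.

```python
# Sketch only: a hedged reconstruction of the fallback pattern above.
import PIL.Image
import torch
from transformers import AutoModel, AutoProcessor

MODEL_ID = "openbmb/MiniCPM-o-2_6"  # placeholder checkpoint, not from the commit

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16
).eval().cuda()


def generate_response(image: PIL.Image.Image, prompt: str) -> str:
    """Try the processor chat-template route; fall back to model.chat()."""
    try:
        # Standard transformers route: the processor renders the chat
        # template and tokenizes text + image in a single call.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]
        model_inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        input_len = model_inputs["input_ids"].shape[-1]
        generation = model.generate(**model_inputs, max_new_tokens=256)
        # Slice off the prompt so only newly generated tokens are decoded.
        return processor.decode(generation[0][input_len:], skip_special_tokens=True)
    except Exception:
        # Fallback for checkpoints whose remote code exposes a custom
        # chat() method that takes raw strings and PIL images directly.
        messages = [{"role": "user", "content": [image, prompt]}]
        return model.chat(
            msgs=messages,
            tokenizer=processor.tokenizer,
            sampling=False,
            max_new_tokens=256,
        )


print(generate_response(PIL.Image.open("sample.png"), "Describe the image."))
```

One difference from the committed code: the bare `except:` in the hunk catches everything, including `KeyboardInterrupt`; the sketch narrows it to `except Exception:`, which is usually what is intended for this kind of API fallback.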