diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py
index 03bb1cb..99a565b 100644
--- a/examples/offline_inference_audio_language.py
+++ b/examples/offline_inference_audio_language.py
@@ -25,21 +25,38 @@
 on HuggingFace model repository.
 """
+import argparse
 import os
+
+from vllm.assets.audio import AudioAsset
+
+try:
+    import librosa
+except ImportError as e:
+    raise ImportError("librosa is required by this example; "
+                      "install it with `pip install librosa`.") from e
 
 from vllm import LLM, SamplingParams
-from vllm.assets.audio import AudioAsset
 
 os.environ["VLLM_USE_MODELSCOPE"] = "True"
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
-audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
-question_per_audio_count = {
-    1: "What is recited in the audio?",
-    2: "What sport and what nursery rhyme are referenced?"
-}
-
 
-def prepare_inputs(audio_count: int):
+def prepare_inputs(audio_count: int, audio_path1: str, audio_path2: str):
+    # With the default paths, fall back to the audio assets bundled with vLLM.
+    use_vllm_audio_assets = (audio_path1 == "mary_had_lamb"
+                             and audio_path2 == "winning_call")
+    if use_vllm_audio_assets:
+        audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
+    else:
+        # librosa.load returns an (audio, sample_rate) tuple per file.
+        audio_assets = [librosa.load(audio_path1, sr=None),
+                        librosa.load(audio_path2, sr=None)]
+
+    question_per_audio_count = {
+        1: "What is recited in the audio?",
+        2: "What sport and what nursery rhyme are referenced?"
+    }
+
     audio_in_prompt = "".join([
         f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
         for idx in range(audio_count)
@@ -52,7 +69,8 @@ def prepare_inputs(audio_count: int):
 
     mm_data = {
         "audio":
-        [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
+        [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
+        if use_vllm_audio_assets else audio_assets[:audio_count]
     }
 
     # Merge text prompt and audio data into inputs
@@ -60,7 +78,7 @@
     return inputs
 
 
-def main(audio_count: int):
+def main(audio_count: int, audio_path1: str, audio_path2: str):
     # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
     # lower-end GPUs.
     # Unless specified, these settings have been tested to work on a single L4.
@@ -71,7 +89,7 @@
               limit_mm_per_prompt={"audio": audio_count},
               enforce_eager=True)
 
-    inputs = prepare_inputs(audio_count)
+    inputs = prepare_inputs(audio_count, audio_path1, audio_path2)
 
     sampling_params = SamplingParams(temperature=0.2,
                                      max_tokens=64,
@@ -81,9 +99,15 @@
 
     for o in outputs:
         generated_text = o.outputs[0].text
-        print(generated_text)
+        print("generated_text:", generated_text)
 
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Demo on offline inference with audio language models")
+    parser.add_argument("--audio-path1", type=str, default="mary_had_lamb")
+    parser.add_argument("--audio-path2", type=str, default="winning_call")
+    args = parser.parse_args()
+
     audio_count = 2
-    main(audio_count)
+    main(audio_count, args.audio_path1, args.audio_path2)
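With this change, the example can be pointed at local audio files through the new flags, while the defaults ("mary_had_lamb" and "winning_call") preserve the original AudioAsset behavior. A typical invocation might look like the following sketch (the .wav paths are hypothetical placeholders):

    python examples/offline_inference_audio_language.py \
        --audio-path1 /path/to/first_clip.wav \
        --audio-path2 /path/to/second_clip.wav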