init ascend tts
This commit is contained in:
177
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/README.md
Normal file
177
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/README.md
Normal file
@@ -0,0 +1,177 @@
|
||||
# Inference
|
||||
|
||||
The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts.
|
||||
|
||||
**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**
|
||||
|
||||
Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**.
|
||||
|
||||
To avoid possible inference failures, make sure you have seen through the following instructions.
|
||||
|
||||
- Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
|
||||
- <ins>Uppercased letters</ins> (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
|
||||
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
|
||||
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
|
||||
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
|
||||
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
|
||||
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
|
||||
|
||||
|
||||
## Gradio App
|
||||
|
||||
Currently supported features:
|
||||
|
||||
- Basic TTS with Chunk Inference
|
||||
- Multi-Style / Multi-Speaker Generation
|
||||
- Voice Chat powered by Qwen2.5-3B-Instruct
|
||||
- [Custom inference with more language support](SHARED.md)
|
||||
|
||||
The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference.
|
||||
|
||||
The script will load model checkpoints from Huggingface. You can also manually download files and update the path to `load_model()` in `infer_gradio.py`. Currently only load TTS models first, will load ASR model to do transcription if `ref_text` not provided, will load LLM model if use Voice Chat.
|
||||
|
||||
More flags options:
|
||||
|
||||
```bash
|
||||
# Automatically launch the interface in the default web browser
|
||||
f5-tts_infer-gradio --inbrowser
|
||||
|
||||
# Set the root path of the application, if it's not served from the root ("/") of the domain
|
||||
# For example, if the application is served at "https://example.com/myapp"
|
||||
f5-tts_infer-gradio --root_path "/myapp"
|
||||
```
|
||||
|
||||
Could also be used as a component for larger application:
|
||||
```python
|
||||
import gradio as gr
|
||||
from f5_tts.infer.infer_gradio import app
|
||||
|
||||
with gr.Blocks() as main_app:
|
||||
gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
|
||||
|
||||
# ... other Gradio components
|
||||
|
||||
app.render()
|
||||
|
||||
main_app.launch()
|
||||
```
|
||||
|
||||
|
||||
## CLI Inference
|
||||
|
||||
The cli command `f5-tts_infer-cli` equals to `python src/f5_tts/infer/infer_cli.py`, which is a command line tool for inference.
|
||||
|
||||
The script will load model checkpoints from Huggingface. You can also manually download files and use `--ckpt_file` to specify the model you want to load, or directly update in `infer_cli.py`.
|
||||
|
||||
For change vocab.txt use `--vocab_file` to provide your `vocab.txt` file.
|
||||
|
||||
Basically you can inference with flags:
|
||||
```bash
|
||||
# Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
|
||||
f5-tts_infer-cli \
|
||||
--model F5TTS_v1_Base \
|
||||
--ref_audio "ref_audio.wav" \
|
||||
--ref_text "The content, subtitle or transcription of reference audio." \
|
||||
--gen_text "Some text you want TTS model generate for you."
|
||||
|
||||
# Use BigVGAN as vocoder. Currently only support F5TTS_Base.
|
||||
f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
|
||||
|
||||
# Use custom path checkpoint, e.g.
|
||||
f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors
|
||||
|
||||
# More instructions
|
||||
f5-tts_infer-cli --help
|
||||
```
|
||||
|
||||
And a `.toml` file would help with more flexible usage.
|
||||
|
||||
```bash
|
||||
f5-tts_infer-cli -c custom.toml
|
||||
```
|
||||
|
||||
For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
|
||||
|
||||
```toml
|
||||
# F5TTS_v1_Base | E2TTS_Base
|
||||
model = "F5TTS_v1_Base"
|
||||
ref_audio = "infer/examples/basic/basic_ref_en.wav"
|
||||
# If an empty "", transcribes the reference audio automatically.
|
||||
ref_text = "Some call me nature, others call me mother nature."
|
||||
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
||||
# File with text to generate. Ignores the text above.
|
||||
gen_file = ""
|
||||
remove_silence = false
|
||||
output_dir = "tests"
|
||||
```
|
||||
|
||||
You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
|
||||
|
||||
```toml
|
||||
# F5TTS_v1_Base | E2TTS_Base
|
||||
model = "F5TTS_v1_Base"
|
||||
ref_audio = "infer/examples/multi/main.flac"
|
||||
# If an empty "", transcribes the reference audio automatically.
|
||||
ref_text = ""
|
||||
gen_text = ""
|
||||
# File with text to generate. Ignores the text above.
|
||||
gen_file = "infer/examples/multi/story.txt"
|
||||
remove_silence = true
|
||||
output_dir = "tests"
|
||||
|
||||
[voices.town]
|
||||
ref_audio = "infer/examples/multi/town.flac"
|
||||
ref_text = ""
|
||||
|
||||
[voices.country]
|
||||
ref_audio = "infer/examples/multi/country.flac"
|
||||
ref_text = ""
|
||||
```
|
||||
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
|
||||
|
||||
## API Usage
|
||||
|
||||
```python
|
||||
from importlib.resources import files
|
||||
from f5_tts.api import F5TTS
|
||||
|
||||
f5tts = F5TTS()
|
||||
wav, sr, spec = f5tts.infer(
|
||||
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
||||
ref_text="some call me nature, others call me mother nature.",
|
||||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
||||
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
||||
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
||||
seed=None,
|
||||
)
|
||||
```
|
||||
Check [api.py](../api.py) for more details.
|
||||
|
||||
## TensorRT-LLM Deployment
|
||||
|
||||
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
## Socket Real-time Service
|
||||
|
||||
Real-time voice output with chunk stream:
|
||||
|
||||
```bash
|
||||
# Start socket server
|
||||
python src/f5_tts/socket_server.py
|
||||
|
||||
# If PyAudio not installed
|
||||
sudo apt-get install portaudio19-dev
|
||||
pip install pyaudio
|
||||
|
||||
# Communicate with socket client
|
||||
python src/f5_tts/socket_client.py
|
||||
```
|
||||
|
||||
## Speech Editing
|
||||
|
||||
To test speech editing capabilities, use the following command:
|
||||
|
||||
```bash
|
||||
python src/f5_tts/infer/speech_edit.py
|
||||
```
|
||||
|
||||
193
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/SHARED.md
Normal file
193
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/SHARED.md
Normal file
@@ -0,0 +1,193 @@
|
||||
<!-- omit in toc -->
|
||||
# Shared Model Cards
|
||||
|
||||
<!-- omit in toc -->
|
||||
### **Prerequisites of using**
|
||||
- This document is serving as a quick lookup table for the community training/finetuning result, with various language support.
|
||||
- The models in this repository are open source and are based on voluntary contributions from contributors.
|
||||
- The use of models must be conditioned on respect for the respective creators. The convenience brought comes from their efforts.
|
||||
|
||||
<!-- omit in toc -->
|
||||
### **Welcome to share here**
|
||||
- Have a pretrained/finetuned result: model checkpoint (pruned best to facilitate inference, i.e. leave only `ema_model_state_dict`) and corresponding vocab file (for tokenization).
|
||||
- Host a public [huggingface model repository](https://huggingface.co/new) and upload the model related files.
|
||||
- Make a pull request adding a model card to the current page, i.e. `src\f5_tts\infer\SHARED.md`.
|
||||
|
||||
<!-- omit in toc -->
|
||||
### Supported Languages
|
||||
- [Multilingual](#multilingual)
|
||||
- [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
|
||||
- [English](#english)
|
||||
- [Finnish](#finnish)
|
||||
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
|
||||
- [French](#french)
|
||||
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
|
||||
- [German](#german)
|
||||
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
|
||||
- [Hindi](#hindi)
|
||||
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
|
||||
- [Italian](#italian)
|
||||
- [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
|
||||
- [Japanese](#japanese)
|
||||
- [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
|
||||
- [Mandarin](#mandarin)
|
||||
- [Russian](#russian)
|
||||
- [F5-TTS Base @ ru @ HotDro4illa](#f5-tts-base--ru--hotdro4illa)
|
||||
- [Spanish](#spanish)
|
||||
- [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)
|
||||
|
||||
|
||||
## Multilingual
|
||||
|
||||
#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
|
||||
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
|
||||
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
||||
```
|
||||
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
|
||||
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
*Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
|
||||
|
||||
|
||||
## English
|
||||
|
||||
|
||||
## Finnish
|
||||
|
||||
#### F5-TTS Base @ fi @ AsmoKoskinen
|
||||
|Model|🤗Hugging Face|Data|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
|
||||
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
|
||||
## French
|
||||
|
||||
#### F5-TTS Base @ fr @ RASPIAUDIO
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
|
||||
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
|
||||
- [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
|
||||
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
|
||||
|
||||
|
||||
## German
|
||||
|
||||
#### F5-TTS Base @ de @ hvoss-techfak
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
|
||||
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
|
||||
|
||||
|
||||
## Hindi
|
||||
|
||||
#### F5-TTS Small @ hi @ SPRINGLab
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
|
||||
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
|
||||
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Authors: SPRING Lab, Indian Institute of Technology, Madras
|
||||
- Website: https://asr.iitm.ac.in/
|
||||
|
||||
|
||||
## Italian
|
||||
|
||||
#### F5-TTS Base @ it @ alien79
|
||||
|Model|🤗Hugging Face|Data|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
|
||||
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Trained by [Mithril Man](https://github.com/MithrilMan)
|
||||
- Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian)
|
||||
- Open to collaborations to further improve the model
|
||||
|
||||
|
||||
## Japanese
|
||||
|
||||
#### F5-TTS Base @ ja @ Jmica
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
|
||||
Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
|
||||
## Mandarin
|
||||
|
||||
|
||||
## Russian
|
||||
|
||||
#### F5-TTS Base @ ru @ HotDro4illa
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hotstone228/F5-TTS-Russian)|[Common voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0)|cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
|
||||
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
|
||||
- Any improvements are welcome
|
||||
|
||||
|
||||
## Spanish
|
||||
|
||||
#### F5-TTS Base @ es @ jpgallegoar
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
|
||||
|
||||
- @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.
|
||||
@@ -0,0 +1,11 @@
|
||||
# F5TTS_v1_Base | E2TTS_Base
|
||||
model = "F5TTS_v1_Base"
|
||||
ref_audio = "infer/examples/basic/basic_ref_en.wav"
|
||||
# If an empty "", transcribes the reference audio automatically.
|
||||
ref_text = "Some call me nature, others call me mother nature."
|
||||
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
||||
# File with text to generate. Ignores the text above.
|
||||
gen_file = ""
|
||||
remove_silence = false
|
||||
output_dir = "tests"
|
||||
output_file = "infer_cli_basic.wav"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,20 @@
|
||||
# F5TTS_v1_Base | E2TTS_Base
|
||||
model = "F5TTS_v1_Base"
|
||||
ref_audio = "infer/examples/multi/main.flac"
|
||||
# If an empty "", transcribes the reference audio automatically.
|
||||
ref_text = ""
|
||||
gen_text = ""
|
||||
# File with text to generate. Ignores the text above.
|
||||
gen_file = "infer/examples/multi/story.txt"
|
||||
remove_silence = true
|
||||
output_dir = "tests"
|
||||
output_file = "infer_cli_story.wav"
|
||||
|
||||
[voices.town]
|
||||
ref_audio = "infer/examples/multi/town.flac"
|
||||
ref_text = ""
|
||||
|
||||
[voices.country]
|
||||
ref_audio = "infer/examples/multi/country.flac"
|
||||
ref_text = ""
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
|
||||
Binary file not shown.
2545
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/examples/vocab.txt
Normal file
2545
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/examples/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
368
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/infer_cli.py
Normal file
368
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/infer_cli.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import argparse
|
||||
import codecs
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import tomli
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
cfg_strength,
|
||||
cross_fade_duration,
|
||||
device,
|
||||
fix_duration,
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
mel_spec_type,
|
||||
nfe_step,
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
speed,
|
||||
sway_sampling_coef,
|
||||
target_rms,
|
||||
)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="python3 infer-cli.py",
|
||||
description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
|
||||
epilog="Specify options above to override one or more settings from config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
type=str,
|
||||
default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
|
||||
help="The configuration file, default see infer/examples/basic/basic.toml",
|
||||
)
|
||||
|
||||
|
||||
# Note. Not to provide default value here in order to read default from config file
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
type=str,
|
||||
help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-mc",
|
||||
"--model_cfg",
|
||||
type=str,
|
||||
help="The path to F5-TTS model config file .yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--ckpt_file",
|
||||
type=str,
|
||||
help="The path to model checkpoint .pt, leave blank to use default",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--vocab_file",
|
||||
type=str,
|
||||
help="The path to vocab file .txt, leave blank to use default",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--ref_audio",
|
||||
type=str,
|
||||
help="The reference audio file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--ref_text",
|
||||
type=str,
|
||||
help="The transcript/subtitle for the reference audio",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--gen_text",
|
||||
type=str,
|
||||
help="The text to make model synthesize a speech",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--gen_file",
|
||||
type=str,
|
||||
help="The file with text to generate, will ignore --gen_text",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
type=str,
|
||||
help="The path to output folder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--output_file",
|
||||
type=str,
|
||||
help="The name of output file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_chunk",
|
||||
action="store_true",
|
||||
help="To save each audio chunks during inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remove_silence",
|
||||
action="store_true",
|
||||
help="To remove long silence found in ouput",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_vocoder_from_local",
|
||||
action="store_true",
|
||||
help="To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocoder_name",
|
||||
type=str,
|
||||
choices=["vocos", "bigvgan"],
|
||||
help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_rms",
|
||||
type=float,
|
||||
help=f"Target output speech loudness normalization value, default {target_rms}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cross_fade_duration",
|
||||
type=float,
|
||||
help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nfe_step",
|
||||
type=int,
|
||||
help=f"The number of function evaluation (denoising steps), default {nfe_step}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cfg_strength",
|
||||
type=float,
|
||||
help=f"Classifier-free guidance strength, default {cfg_strength}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sway_sampling_coef",
|
||||
type=float,
|
||||
help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speed",
|
||||
type=float,
|
||||
help=f"The speed of the generated audio, default {speed}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fix_duration",
|
||||
type=float,
|
||||
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
help="Specify the device to run on",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# config file
|
||||
|
||||
config = tomli.load(open(args.config, "rb"))
|
||||
|
||||
|
||||
# command-line interface parameters
|
||||
|
||||
model = args.model or config.get("model", "F5TTS_v1_Base")
|
||||
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
|
||||
vocab_file = args.vocab_file or config.get("vocab_file", "")
|
||||
|
||||
ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
|
||||
ref_text = (
|
||||
args.ref_text
|
||||
if args.ref_text is not None
|
||||
else config.get("ref_text", "Some call me nature, others call me mother nature.")
|
||||
)
|
||||
gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
|
||||
gen_file = args.gen_file or config.get("gen_file", "")
|
||||
|
||||
output_dir = args.output_dir or config.get("output_dir", "tests")
|
||||
output_file = args.output_file or config.get(
|
||||
"output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
|
||||
)
|
||||
|
||||
save_chunk = args.save_chunk or config.get("save_chunk", False)
|
||||
remove_silence = args.remove_silence or config.get("remove_silence", False)
|
||||
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
|
||||
|
||||
vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
|
||||
target_rms = args.target_rms or config.get("target_rms", target_rms)
|
||||
cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
|
||||
nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
|
||||
cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
|
||||
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
|
||||
speed = args.speed or config.get("speed", speed)
|
||||
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
|
||||
device = args.device or config.get("device", device)
|
||||
|
||||
|
||||
# patches for pip pkg user
|
||||
if "infer/examples/" in ref_audio:
|
||||
ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
|
||||
if "infer/examples/" in gen_file:
|
||||
gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
|
||||
if "voices" in config:
|
||||
for voice in config["voices"]:
|
||||
voice_ref_audio = config["voices"][voice]["ref_audio"]
|
||||
if "infer/examples/" in voice_ref_audio:
|
||||
config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
|
||||
|
||||
|
||||
# ignore gen_text if gen_file provided
|
||||
|
||||
if gen_file:
|
||||
gen_text = codecs.open(gen_file, "r", "utf-8").read()
|
||||
|
||||
|
||||
# output path
|
||||
|
||||
wave_path = Path(output_dir) / output_file
|
||||
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
||||
if save_chunk:
|
||||
output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
|
||||
if not os.path.exists(output_chunk_dir):
|
||||
os.makedirs(output_chunk_dir)
|
||||
|
||||
|
||||
# load vocoder
|
||||
|
||||
if vocoder_name == "vocos":
|
||||
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
|
||||
elif vocoder_name == "bigvgan":
|
||||
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
||||
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
|
||||
)
|
||||
|
||||
|
||||
# load TTS model
|
||||
|
||||
model_cfg = OmegaConf.load(
|
||||
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
||||
)
|
||||
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
||||
model_arc = model_cfg.model.arch
|
||||
|
||||
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
|
||||
|
||||
if model != "F5TTS_Base":
|
||||
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
|
||||
|
||||
# override for previous models
|
||||
if model == "F5TTS_Base":
|
||||
if vocoder_name == "vocos":
|
||||
ckpt_step = 1200000
|
||||
elif vocoder_name == "bigvgan":
|
||||
model = "F5TTS_Base_bigvgan"
|
||||
ckpt_type = "pt"
|
||||
elif model == "E2TTS_Base":
|
||||
repo_name = "E2-TTS"
|
||||
ckpt_step = 1200000
|
||||
|
||||
if not ckpt_file:
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
||||
|
||||
print(f"Using {model}...")
|
||||
ema_model = load_model(
|
||||
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
|
||||
)
|
||||
|
||||
|
||||
# inference process
|
||||
|
||||
|
||||
def main():
|
||||
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
|
||||
if "voices" not in config:
|
||||
voices = {"main": main_voice}
|
||||
else:
|
||||
voices = config["voices"]
|
||||
voices["main"] = main_voice
|
||||
for voice in voices:
|
||||
print("Voice:", voice)
|
||||
print("ref_audio ", voices[voice]["ref_audio"])
|
||||
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
|
||||
voices[voice]["ref_audio"], voices[voice]["ref_text"]
|
||||
)
|
||||
print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
|
||||
|
||||
generated_audio_segments = []
|
||||
reg1 = r"(?=\[\w+\])"
|
||||
chunks = re.split(reg1, gen_text)
|
||||
reg2 = r"\[(\w+)\]"
|
||||
for text in chunks:
|
||||
if not text.strip():
|
||||
continue
|
||||
match = re.match(reg2, text)
|
||||
if match:
|
||||
voice = match[1]
|
||||
else:
|
||||
print("No voice tag found, using main.")
|
||||
voice = "main"
|
||||
if voice not in voices:
|
||||
print(f"Voice {voice} not found, using main.")
|
||||
voice = "main"
|
||||
text = re.sub(reg2, "", text)
|
||||
ref_audio_ = voices[voice]["ref_audio"]
|
||||
ref_text_ = voices[voice]["ref_text"]
|
||||
gen_text_ = text.strip()
|
||||
print(f"Voice: {voice}")
|
||||
audio_segment, final_sample_rate, spectrogram = infer_process(
|
||||
ref_audio_,
|
||||
ref_text_,
|
||||
gen_text_,
|
||||
ema_model,
|
||||
vocoder,
|
||||
mel_spec_type=vocoder_name,
|
||||
target_rms=target_rms,
|
||||
cross_fade_duration=cross_fade_duration,
|
||||
nfe_step=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
)
|
||||
generated_audio_segments.append(audio_segment)
|
||||
|
||||
if save_chunk:
|
||||
if len(gen_text_) > 200:
|
||||
gen_text_ = gen_text_[:200] + " ... "
|
||||
sf.write(
|
||||
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
|
||||
audio_segment,
|
||||
final_sample_rate,
|
||||
)
|
||||
|
||||
if generated_audio_segments:
|
||||
final_wave = np.concatenate(generated_audio_segments)
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
with open(wave_path, "wb") as f:
|
||||
sf.write(f.name, final_wave, final_sample_rate)
|
||||
# Remove silence
|
||||
if remove_silence:
|
||||
remove_silence_for_generated_wav(f.name)
|
||||
print(f.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1121
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/infer_gradio.py
Normal file
1121
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/infer_gradio.py
Normal file
File diff suppressed because it is too large
Load Diff
205
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/speech_edit.py
Normal file
205
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/speech_edit.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import os
|
||||
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
|
||||
from importlib.resources import files
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "xpu"
|
||||
if torch.xpu.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------- infer setting ---------------------- #
|
||||
|
||||
seed = None # int | None
|
||||
|
||||
exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
|
||||
ckpt_step = 1250000
|
||||
|
||||
nfe_step = 32 # 16, 32
|
||||
cfg_strength = 2.0
|
||||
ode_method = "euler" # euler | midpoint
|
||||
sway_sampling_coef = -1.0
|
||||
speed = 1.0
|
||||
target_rms = 0.1
|
||||
|
||||
|
||||
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
|
||||
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
||||
model_arc = model_cfg.model.arch
|
||||
|
||||
dataset_name = model_cfg.datasets.name
|
||||
tokenizer = model_cfg.model.tokenizer
|
||||
|
||||
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
||||
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
|
||||
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
|
||||
hop_length = model_cfg.model.mel_spec.hop_length
|
||||
win_length = model_cfg.model.mel_spec.win_length
|
||||
n_fft = model_cfg.model.mel_spec.n_fft
|
||||
|
||||
|
||||
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
output_dir = "tests"
|
||||
|
||||
|
||||
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
||||
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
|
||||
# [write the origin_text into a file, e.g. tests/test_edit.txt]
|
||||
# ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
|
||||
# [result will be saved at same path of audio file]
|
||||
# [--language "zho" for Chinese, "eng" for English]
|
||||
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
|
||||
|
||||
audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
|
||||
origin_text = "Some call me nature, others call me mother nature."
|
||||
target_text = "Some call me optimist, others call me realist."
|
||||
parts_to_edit = [
|
||||
[1.42, 2.44],
|
||||
[4.04, 4.9],
|
||||
] # stard_ends of "nature" & "mother nature", in seconds
|
||||
fix_duration = [
|
||||
1.2,
|
||||
1,
|
||||
] # fix duration for "optimist" & "realist", in seconds
|
||||
|
||||
# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
|
||||
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
|
||||
# target_text = "对,那就是你,万人敬仰的太白金星。"
|
||||
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
|
||||
# fix_duration = None # use origin text duration
|
||||
|
||||
|
||||
# -------------------------------------------------#
|
||||
|
||||
use_ema = True
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Vocoder model
|
||||
local = False
|
||||
if mel_spec_type == "vocos":
|
||||
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||
elif mel_spec_type == "bigvgan":
|
||||
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
||||
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
|
||||
|
||||
# Tokenizer
|
||||
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
||||
|
||||
# Model
|
||||
model = CFM(
|
||||
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
||||
mel_spec_kwargs=dict(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
),
|
||||
odeint_kwargs=dict(
|
||||
method=ode_method,
|
||||
),
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||
|
||||
# Audio
|
||||
audio, sr = torchaudio.load(audio_to_edit)
|
||||
if audio.shape[0] > 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
||||
if rms < target_rms:
|
||||
audio = audio * target_rms / rms
|
||||
if sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
||||
audio = resampler(audio)
|
||||
offset = 0
|
||||
audio_ = torch.zeros(1, 0)
|
||||
edit_mask = torch.zeros(1, 0, dtype=torch.bool)
|
||||
for part in parts_to_edit:
|
||||
start, end = part
|
||||
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
|
||||
part_dur = part_dur * target_sample_rate
|
||||
start = start * target_sample_rate
|
||||
audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
|
||||
edit_mask = torch.cat(
|
||||
(
|
||||
edit_mask,
|
||||
torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
|
||||
torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
offset = end * target_sample_rate
|
||||
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
|
||||
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
|
||||
audio = audio.to(device)
|
||||
edit_mask = edit_mask.to(device)
|
||||
|
||||
# Text
|
||||
text_list = [target_text]
|
||||
if tokenizer == "pinyin":
|
||||
final_text_list = convert_char_to_pinyin(text_list)
|
||||
else:
|
||||
final_text_list = [text_list]
|
||||
print(f"text : {text_list}")
|
||||
print(f"pinyin: {final_text_list}")
|
||||
|
||||
# Duration
|
||||
ref_audio_len = 0
|
||||
duration = audio.shape[-1] // hop_length
|
||||
|
||||
# Inference
|
||||
with torch.inference_mode():
|
||||
generated, trajectory = model.sample(
|
||||
cond=audio,
|
||||
text=final_text_list,
|
||||
duration=duration,
|
||||
steps=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
seed=seed,
|
||||
edit_mask=edit_mask,
|
||||
)
|
||||
print(f"Generated mel: {generated.shape}")
|
||||
|
||||
# Final result
|
||||
generated = generated.to(torch.float32)
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
gen_mel_spec = generated.permute(0, 2, 1)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(gen_mel_spec).cpu()
|
||||
elif mel_spec_type == "bigvgan":
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
|
||||
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
|
||||
print(f"Generated wav: {generated_wave.shape}")
|
||||
610
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/utils_infer.py
Normal file
610
ascend_910-f5-tts/F5-TTS/src/f5_tts/infer/utils_infer.py
Normal file
@@ -0,0 +1,610 @@
|
||||
# A unified script for inference process
|
||||
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import tempfile
|
||||
from importlib.resources import files
|
||||
|
||||
import matplotlib
|
||||
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import tqdm
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pydub import AudioSegment, silence
|
||||
from transformers import pipeline
|
||||
from vocos import Vocos
|
||||
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
_ref_audio_cache = {}
|
||||
_ref_text_cache = {}
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "xpu"
|
||||
if torch.xpu.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
|
||||
|
||||
# -----------------------------------------
|
||||
|
||||
target_sample_rate = 24000
|
||||
n_mel_channels = 100
|
||||
hop_length = 256
|
||||
win_length = 1024
|
||||
n_fft = 1024
|
||||
mel_spec_type = "vocos"
|
||||
target_rms = 0.1
|
||||
cross_fade_duration = 0.15
|
||||
ode_method = "euler"
|
||||
nfe_step = 32 # 16, 32
|
||||
cfg_strength = 2.0
|
||||
sway_sampling_coef = -1.0
|
||||
speed = 1.0
|
||||
fix_duration = None
|
||||
|
||||
# -----------------------------------------
|
||||
|
||||
|
||||
# chunk text into smaller pieces
|
||||
|
||||
|
||||
def chunk_text(text, max_chars=135):
|
||||
"""
|
||||
Splits the input text into chunks, each with a maximum number of characters.
|
||||
|
||||
Args:
|
||||
text (str): The text to be split.
|
||||
max_chars (int): The maximum number of characters per chunk.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of text chunks.
|
||||
"""
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
# Split the text into sentences based on punctuation followed by whitespace
|
||||
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
||||
|
||||
for sentence in sentences:
|
||||
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
# load vocoder
|
||||
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
|
||||
if vocoder_name == "vocos":
|
||||
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
||||
if is_local:
|
||||
print(f"Load vocos from local path {local_path}")
|
||||
config_path = f"{local_path}/config.yaml"
|
||||
model_path = f"{local_path}/pytorch_model.bin"
|
||||
else:
|
||||
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
||||
repo_id = "charactr/vocos-mel-24khz"
|
||||
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
|
||||
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
|
||||
vocoder = Vocos.from_hparams(config_path)
|
||||
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
||||
from vocos.feature_extractors import EncodecFeatures
|
||||
|
||||
if isinstance(vocoder.feature_extractor, EncodecFeatures):
|
||||
encodec_parameters = {
|
||||
"feature_extractor.encodec." + key: value
|
||||
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
|
||||
}
|
||||
state_dict.update(encodec_parameters)
|
||||
vocoder.load_state_dict(state_dict)
|
||||
vocoder = vocoder.eval().to(device)
|
||||
elif vocoder_name == "bigvgan":
|
||||
try:
|
||||
from third_party.BigVGAN import bigvgan
|
||||
except ImportError:
|
||||
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
||||
if is_local:
|
||||
# download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
|
||||
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
||||
else:
|
||||
vocoder = bigvgan.BigVGAN.from_pretrained(
|
||||
"nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
|
||||
)
|
||||
|
||||
vocoder.remove_weight_norm()
|
||||
vocoder = vocoder.eval().to(device)
|
||||
return vocoder
|
||||
|
||||
|
||||
# load asr pipeline
|
||||
|
||||
asr_pipe = None
|
||||
|
||||
|
||||
def initialize_asr_pipeline(device: str = device, dtype=None):
|
||||
if dtype is None:
|
||||
dtype = (
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 7
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
global asr_pipe
|
||||
asr_pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
torch_dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
# transcribe
|
||||
|
||||
|
||||
def transcribe(ref_audio, language=None):
|
||||
global asr_pipe
|
||||
if asr_pipe is None:
|
||||
initialize_asr_pipeline(device=device)
|
||||
return asr_pipe(
|
||||
ref_audio,
|
||||
chunk_length_s=30,
|
||||
batch_size=128,
|
||||
generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
|
||||
return_timestamps=False,
|
||||
)["text"].strip()
|
||||
|
||||
|
||||
# load model checkpoint for inference
|
||||
|
||||
|
||||
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
||||
if dtype is None:
|
||||
print(f'device: {device}', flush=True)
|
||||
try:
|
||||
dtype = (
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 7
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
except Exception as e:
|
||||
# print(f"Error determining dtype: {e}", flush=True)
|
||||
dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
model = model.to(dtype)
|
||||
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
checkpoint = load_file(ckpt_path, device=device)
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
|
||||
if use_ema:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"ema_model_state_dict": checkpoint}
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
|
||||
# patch for backward compatibility, 305e3ea
|
||||
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
|
||||
if key in checkpoint["model_state_dict"]:
|
||||
del checkpoint["model_state_dict"][key]
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
else:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"model_state_dict": checkpoint}
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
del checkpoint
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model.to(device)
|
||||
|
||||
|
||||
# load model for inference
|
||||
|
||||
|
||||
def load_model(
|
||||
model_cls,
|
||||
model_cfg,
|
||||
ckpt_path,
|
||||
mel_spec_type=mel_spec_type,
|
||||
vocab_file="",
|
||||
ode_method=ode_method,
|
||||
use_ema=True,
|
||||
device=device,
|
||||
):
|
||||
if vocab_file == "":
|
||||
vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
|
||||
tokenizer = "custom"
|
||||
|
||||
print("\nvocab : ", vocab_file)
|
||||
print("token : ", tokenizer)
|
||||
print("model : ", ckpt_path, "\n")
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
|
||||
model = CFM(
|
||||
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
||||
mel_spec_kwargs=dict(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
),
|
||||
odeint_kwargs=dict(
|
||||
method=ode_method,
|
||||
),
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def remove_silence_edges(audio, silence_threshold=-42):
|
||||
# Remove silence from the start
|
||||
non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
|
||||
audio = audio[non_silent_start_idx:]
|
||||
|
||||
# Remove silence from the end
|
||||
non_silent_end_duration = audio.duration_seconds
|
||||
for ms in reversed(audio):
|
||||
if ms.dBFS > silence_threshold:
|
||||
break
|
||||
non_silent_end_duration -= 0.001
|
||||
trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
|
||||
|
||||
return trimmed_audio
|
||||
|
||||
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
||||
show_info("Converting audio...")
|
||||
|
||||
# Compute a hash of the reference audio file
|
||||
with open(ref_audio_orig, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
audio_hash = hashlib.md5(audio_data).hexdigest()
|
||||
|
||||
global _ref_audio_cache
|
||||
|
||||
if audio_hash in _ref_audio_cache:
|
||||
show_info("Using cached preprocessed reference audio...")
|
||||
ref_audio = _ref_audio_cache[audio_hash]
|
||||
|
||||
else: # first pass, do preprocess
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
|
||||
temp_path = f.name
|
||||
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
# 1. try to find long silence for clipping
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
aseg = non_silent_wave
|
||||
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
|
||||
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
||||
aseg.export(temp_path, format="wav")
|
||||
ref_audio = temp_path
|
||||
|
||||
# Cache the processed reference audio
|
||||
_ref_audio_cache[audio_hash] = ref_audio
|
||||
|
||||
if not ref_text.strip():
|
||||
global _ref_text_cache
|
||||
if audio_hash in _ref_text_cache:
|
||||
# Use cached asr transcription
|
||||
show_info("Using cached reference text...")
|
||||
ref_text = _ref_text_cache[audio_hash]
|
||||
else:
|
||||
show_info("No reference text provided, transcribing reference audio...")
|
||||
ref_text = transcribe(ref_audio)
|
||||
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
|
||||
_ref_text_cache[audio_hash] = ref_text
|
||||
else:
|
||||
show_info("Using custom reference text...")
|
||||
|
||||
# Ensure ref_text ends with a proper sentence-ending punctuation
|
||||
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
|
||||
if ref_text.endswith("."):
|
||||
ref_text += " "
|
||||
else:
|
||||
ref_text += ". "
|
||||
|
||||
print("\nref_text ", ref_text)
|
||||
|
||||
return ref_audio, ref_text
|
||||
|
||||
|
||||
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
|
||||
|
||||
|
||||
def infer_process(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
gen_text,
|
||||
model_obj,
|
||||
vocoder,
|
||||
mel_spec_type=mel_spec_type,
|
||||
show_info=print,
|
||||
progress=tqdm,
|
||||
target_rms=target_rms,
|
||||
cross_fade_duration=cross_fade_duration,
|
||||
nfe_step=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
):
|
||||
# Split the input text into batches
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f"gen_text {i}", gen_text)
|
||||
print("\n")
|
||||
|
||||
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
||||
return next(
|
||||
infer_batch_process(
|
||||
(audio, sr),
|
||||
ref_text,
|
||||
gen_text_batches,
|
||||
model_obj,
|
||||
vocoder,
|
||||
mel_spec_type=mel_spec_type,
|
||||
progress=progress,
|
||||
target_rms=target_rms,
|
||||
cross_fade_duration=cross_fade_duration,
|
||||
nfe_step=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# infer batches
|
||||
|
||||
|
||||
def infer_batch_process(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
gen_text_batches,
|
||||
model_obj,
|
||||
vocoder,
|
||||
mel_spec_type="vocos",
|
||||
progress=tqdm,
|
||||
target_rms=0.1,
|
||||
cross_fade_duration=0.15,
|
||||
nfe_step=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
speed=1,
|
||||
fix_duration=None,
|
||||
device=None,
|
||||
streaming=False,
|
||||
chunk_size=2048,
|
||||
):
|
||||
audio, sr = ref_audio
|
||||
if audio.shape[0] > 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
|
||||
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
||||
if rms < target_rms:
|
||||
audio = audio * target_rms / rms
|
||||
if sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
||||
audio = resampler(audio)
|
||||
audio = audio.to(device)
|
||||
|
||||
generated_waves = []
|
||||
spectrograms = []
|
||||
|
||||
if len(ref_text[-1].encode("utf-8")) == 1:
|
||||
ref_text = ref_text + " "
|
||||
|
||||
def process_batch(gen_text):
|
||||
local_speed = speed
|
||||
if len(gen_text.encode("utf-8")) < 10:
|
||||
local_speed = 0.3
|
||||
|
||||
# Prepare the text
|
||||
text_list = [ref_text + gen_text]
|
||||
final_text_list = convert_char_to_pinyin(text_list)
|
||||
|
||||
ref_audio_len = audio.shape[-1] // hop_length
|
||||
if fix_duration is not None:
|
||||
duration = int(fix_duration * target_sample_rate / hop_length)
|
||||
else:
|
||||
# Calculate duration
|
||||
ref_text_len = len(ref_text.encode("utf-8"))
|
||||
gen_text_len = len(gen_text.encode("utf-8"))
|
||||
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
|
||||
|
||||
# inference
|
||||
with torch.inference_mode():
|
||||
generated, _ = model_obj.sample(
|
||||
cond=audio,
|
||||
text=final_text_list,
|
||||
duration=duration,
|
||||
steps=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
)
|
||||
del _
|
||||
|
||||
generated = generated.to(torch.float32) # generated mel spectrogram
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated = generated.permute(0, 2, 1)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(generated)
|
||||
elif mel_spec_type == "bigvgan":
|
||||
generated_wave = vocoder(generated)
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
# wav -> numpy
|
||||
generated_wave = generated_wave.squeeze().cpu().numpy()
|
||||
|
||||
if streaming:
|
||||
for j in range(0, len(generated_wave), chunk_size):
|
||||
yield generated_wave[j : j + chunk_size], target_sample_rate
|
||||
else:
|
||||
generated_cpu = generated[0].cpu().numpy()
|
||||
del generated
|
||||
yield generated_wave, generated_cpu
|
||||
|
||||
if streaming:
|
||||
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
||||
for chunk in process_batch(gen_text):
|
||||
yield chunk
|
||||
else:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
|
||||
for future in progress.tqdm(futures) if progress is not None else futures:
|
||||
result = future.result()
|
||||
if result:
|
||||
generated_wave, generated_mel_spec = next(result)
|
||||
generated_waves.append(generated_wave)
|
||||
spectrograms.append(generated_mel_spec)
|
||||
|
||||
if generated_waves:
|
||||
if cross_fade_duration <= 0:
|
||||
# Simply concatenate
|
||||
final_wave = np.concatenate(generated_waves)
|
||||
else:
|
||||
# Combine all generated waves with cross-fading
|
||||
final_wave = generated_waves[0]
|
||||
for i in range(1, len(generated_waves)):
|
||||
prev_wave = final_wave
|
||||
next_wave = generated_waves[i]
|
||||
|
||||
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
||||
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
||||
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
|
||||
|
||||
if cross_fade_samples <= 0:
|
||||
# No overlap possible, concatenate
|
||||
final_wave = np.concatenate([prev_wave, next_wave])
|
||||
continue
|
||||
|
||||
# Overlapping parts
|
||||
prev_overlap = prev_wave[-cross_fade_samples:]
|
||||
next_overlap = next_wave[:cross_fade_samples]
|
||||
|
||||
# Fade out and fade in
|
||||
fade_out = np.linspace(1, 0, cross_fade_samples)
|
||||
fade_in = np.linspace(0, 1, cross_fade_samples)
|
||||
|
||||
# Cross-faded overlap
|
||||
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
||||
|
||||
# Combine
|
||||
new_wave = np.concatenate(
|
||||
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
|
||||
)
|
||||
|
||||
final_wave = new_wave
|
||||
|
||||
# Create a combined spectrogram
|
||||
combined_spectrogram = np.concatenate(spectrograms, axis=1)
|
||||
|
||||
yield final_wave, target_sample_rate, combined_spectrogram
|
||||
|
||||
else:
|
||||
yield None, target_sample_rate, None
|
||||
|
||||
|
||||
# remove silence from generated wav
|
||||
|
||||
|
||||
def remove_silence_for_generated_wav(filename):
|
||||
aseg = AudioSegment.from_file(filename)
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
non_silent_wave += non_silent_seg
|
||||
aseg = non_silent_wave
|
||||
aseg.export(filename, format="wav")
|
||||
|
||||
|
||||
# save spectrogram
|
||||
|
||||
|
||||
def save_spectrogram(spectrogram, path):
|
||||
plt.figure(figsize=(12, 4))
|
||||
plt.imshow(spectrogram, origin="lower", aspect="auto")
|
||||
plt.colorbar()
|
||||
plt.savefig(path)
|
||||
plt.close()
|
||||
Reference in New Issue
Block a user