2025-08-20 17:05:36 +08:00
parent 6ae2fb5868
commit 82140eff2c
259 changed files with 505375 additions and 251 deletions

View File

@@ -1,10 +1,13 @@
 FROM corex:3.2.1
 WORKDIR /workspace
-COPY GPT-SoVITS constraints_gsv.txt gsv_server.py launch_gsv.sh /workspace/
-RUN pip install -r GPT-SOVITS/extra-req.txt --no-deps \
-    && pip install -r GPT-SoVITS/requirements.txt -c constraints_gsv.txt \
+COPY constraints_gsv.txt launch_gsv.sh /workspace/
+COPY GPT-SoVITS /workspace/GPT-SoVITS
+RUN pip install -r GPT-SoVITS/extra-req.txt --no-deps \
+    && pip install -r GPT-SoVITS/requirements.txt -c constraints_gsv.txt \
     && apt update \
-    && apt install -y ffmpeg libsox-dev
+    && apt install -y ffmpeg libsox-dev redis
+RUN pip install redis
 COPY wav /workspace/wav
 ENTRYPOINT ["/bin/bash", "launch_gsv.sh"]

View File

@@ -0,0 +1,198 @@
GPT_SoVITS/pretrained_models/*
tools/asr/models/*
tools/uvr5/uvr5_weights/*
.git
.DS_Store
.vscode
*.pyc
env
runtime
.idea
output
logs
SoVITS_weights*/
GPT_weights*/
TEMP
weight.json
ffmpeg*
ffprobe*
cfg.json
speakers.json
ref_audios
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc

bi_v100-gpt-sovits/GPT-SoVITS/.gitignore (vendored)
View File

@@ -0,0 +1,195 @@
.DS_Store
.vscode
__pycache__
*.pyc
env
runtime
.idea
output
logs
SoVITS_weights*/
GPT_weights*/
TEMP
weight.json
ffmpeg*
ffprobe*
cfg.json
speakers.json
ref_audios
tools/AP_BWE_main/24kto48k/*
!tools/AP_BWE_main/24kto48k/readme.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc

View File

@@ -0,0 +1,15 @@
ci:
  autoupdate_schedule: monthly

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7
    hooks:
      # Run the linter.
      - id: ruff
        types_or: [ python, pyi ]
        args: [ --fix, "--exit-zero" ]
      # Run the formatter.
      - id: ruff-format
        types_or: [ python, pyi ]
        args: [ --line-length, "120", --target-version, "py310" ]

View File

@@ -0,0 +1,191 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-Inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GPT-SoVITS Infer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Env Setup (Run Once Only)\n",
"## 环境配置, 只需运行一次"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e9b7iFV3dm1f"
},
"outputs": [],
"source": [
"%%writefile /content/setup.sh\n",
"set -e\n",
"\n",
"cd /content\n",
"\n",
"git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
"\n",
"cd GPT-SoVITS\n",
"\n",
"mkdir -p GPT_weights\n",
"\n",
"mkdir -p SoVITS_weights\n",
"\n",
"if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
" :\n",
"else\n",
" conda create -n GPTSoVITS python=3.10 -y\n",
"fi\n",
"\n",
"source activate GPTSoVITS\n",
"\n",
"pip install ipykernel\n",
"\n",
"bash install.sh --device CU126 --source HF"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "0NgxXg5sjv7z"
},
"outputs": [],
"source": [
"%pip install -q condacolab\n",
"import condacolab\n",
"condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
"!cd /content && bash setup.sh"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Download Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download From HuggingFace"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbZY-LnM0tzq"
},
"outputs": [],
"source": [
"# Modify These\n",
"USER_ID = \"AkitoP\"\n",
"REPO_NAME = \"GPT-SoVITS-v2-aegi\"\n",
"BRANCH = \"main\"\n",
"GPT_PATH = \"new_aegigoe-e100.ckpt\"\n",
"SOVITS_PATH = \"new_aegigoe_e60_s32220.pth\"\n",
"\n",
"# Do Not Modify\n",
"HF_BASE = \"https://huggingface.co\"\n",
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{GPT_PATH}\"\n",
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{SOVITS_PATH}\"\n",
"\n",
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download From ModelScope"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Modify These\n",
"USER_ID = \"aihobbyist\"\n",
"REPO_NAME = \"GPT-SoVits-V2-models\"\n",
"BRANCH = \"master\"\n",
"GPT_PATH = \"Genshin_Impact/EN/GPT_GenshinImpact_EN_5.1.ckpt\"\n",
"SOVITS_PATH = \"Wuthering_Waves/CN/SV_WutheringWaves_CN_1.3.pth\"\n",
"\n",
"# Do Not Modify\n",
"HF_BASE = \"https://www.modelscope.cn/models\"\n",
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{GPT_PATH}\"\n",
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{SOVITS_PATH}\"\n",
"\n",
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Launch WebUI\n",
"# 启动 WebUI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "4oRGUzkrk8C7"
},
"outputs": [],
"source": [
"!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -0,0 +1,117 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GPT-SoVITS WebUI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_o6a8GS2lWQM"
},
"source": [
"## Env Setup (Run Once Only)\n",
"## 环境配置, 只需运行一次"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile /content/setup.sh\n",
"set -e\n",
"\n",
"cd /content\n",
"\n",
"git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
"\n",
"cd GPT-SoVITS\n",
"\n",
"if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
" :\n",
"else\n",
" conda create -n GPTSoVITS python=3.10 -y\n",
"fi\n",
"\n",
"source activate GPTSoVITS\n",
"\n",
"pip install ipykernel\n",
"\n",
"bash install.sh --device CU126 --source HF --download-uvr5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -q condacolab\n",
"import condacolab\n",
"condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
"!cd /content && bash setup.sh"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch WebUI\n",
"## 启动 WebUI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4oRGUzkrk8C7"
},
"outputs": [],
"source": [
"!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -0,0 +1,33 @@
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
cd "$SCRIPT_DIR" || exit 1
cd .. || exit 1
set -e
source "$HOME/miniconda3/etc/profile.d/conda.sh"
mkdir -p GPT_SoVITS
mkdir -p GPT_SoVITS/text
ln -s /workspace/models/pretrained_models /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
ln -s /workspace/models/G2PWModel /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
bash install.sh --device "CU${CUDA_VERSION//./}" --source HF
pip cache purge
pip show torch
rm -rf /tmp/* /var/tmp/*
rm -rf "$HOME/miniconda3/pkgs"
mkdir -p "$HOME/miniconda3/pkgs"
rm -rf /root/.conda /root/.cache

View File

@@ -0,0 +1,70 @@
#!/bin/bash

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"

cd "$SCRIPT_DIR" || exit 1
cd .. || exit 1

if [ -d "$HOME/miniconda3" ]; then
    exit 0
fi

WORKFLOW=${WORKFLOW:-"false"}
TARGETPLATFORM=${TARGETPLATFORM:-"linux/amd64"}

if [ "$WORKFLOW" = "true" ]; then
    WGET_CMD=(wget -nv --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
else
    WGET_CMD=(wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
fi

if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
    "${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
    "${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
else
    exit 1
fi

LOG_PATH="/tmp/miniconda-install.log"

# Test the install command directly: with `set -e` active, the script would
# already have exited before a separate `$?` check could ever see a failure.
if bash miniconda.sh -b -p "$HOME/miniconda3" >"$LOG_PATH" 2>&1; then
    echo "== Miniconda Installed =="
else
    echo "Failed to Install miniconda"
    tail -n 50 "$LOG_PATH"
    exit 1
fi

rm miniconda.sh

source "$HOME/miniconda3/etc/profile.d/conda.sh"

"$HOME/miniconda3/bin/conda" config --add channels conda-forge
"$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null
"$HOME/miniconda3/bin/conda" install python=3.11 -q -y
"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y

if [ "$CUDA_VERSION" = "12.8" ]; then
    "$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
elif [ "$CUDA_VERSION" = "12.6" ]; then
    "$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
fi

"$HOME/miniconda3/bin/pip" cache purge

rm "$LOG_PATH"

rm -rf "$HOME/miniconda3/pkgs"
mkdir -p "$HOME/miniconda3/pkgs"
rm -rf "$HOME/.conda" "$HOME/.cache"

View File

@@ -0,0 +1,62 @@
ARG CUDA_VERSION=12.6
ARG TORCH_BASE=full
FROM xxxxrt666/torch-base:cu${CUDA_VERSION}-${TORCH_BASE}
LABEL maintainer="XXXXRT"
LABEL version="V4"
LABEL description="Docker image for GPT-SoVITS"
ARG CUDA_VERSION=12.6
ENV CUDA_VERSION=${CUDA_VERSION}
SHELL ["/bin/bash", "-c"]
WORKDIR /workspace/GPT-SoVITS
COPY Docker /workspace/GPT-SoVITS/Docker/
ARG LITE=false
ENV LITE=${LITE}
ARG WORKFLOW=false
ENV WORKFLOW=${WORKFLOW}
ARG TARGETPLATFORM
ENV TARGETPLATFORM=${TARGETPLATFORM}
RUN bash Docker/miniconda_install.sh
COPY extra-req.txt /workspace/GPT-SoVITS/
COPY requirements.txt /workspace/GPT-SoVITS/
COPY install.sh /workspace/GPT-SoVITS/
RUN bash Docker/install_wrapper.sh
EXPOSE 9871 9872 9873 9874 9880
ENV PYTHONPATH="/workspace/GPT-SoVITS"
RUN conda init bash && echo "conda activate base" >> ~/.bashrc
WORKDIR /workspace
RUN rm -rf /workspace/GPT-SoVITS
WORKDIR /workspace/GPT-SoVITS
COPY . /workspace/GPT-SoVITS
CMD ["/bin/bash", "-c", "\
rm -rf /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models && \
rm -rf /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel && \
rm -rf /workspace/GPT-SoVITS/tools/asr/models && \
rm -rf /workspace/GPT-SoVITS/tools/uvr5/uvr5_weights && \
ln -s /workspace/models/pretrained_models /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models && \
ln -s /workspace/models/G2PWModel /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel && \
ln -s /workspace/models/asr_models /workspace/GPT-SoVITS/tools/asr/models && \
ln -s /workspace/models/uvr5_weights /workspace/GPT-SoVITS/tools/uvr5/uvr5_weights && \
exec bash"]

View File

@@ -0,0 +1,149 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
# reference: https://github.com/lifeiteng/vall-e
import itertools
import math
import random
from random import shuffle
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

__all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar("T_co", covariant=True)


class DistributedBucketSampler(Sampler[T_co]):
    r"""
    sort the dataset wrt. input length
    divide samples into buckets
    sort within buckets
    divide buckets into batches
    sort batches
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 32,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank() if torch.cuda.is_available() else 0
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas,  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.dataset) / self.num_replicas,
            )  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.id_with_length = self._get_sample_lengths()
        self.id_buckets = self.make_buckets(bucket_width=2.0)

    def _get_sample_lengths(self):
        id_with_lengths = []
        for i in range(len(self.dataset)):
            id_with_lengths.append((i, self.dataset.get_sample_length(i)))
        id_with_lengths.sort(key=lambda x: x[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float = 2.0):
        buckets = []
        cur = []
        max_sec = bucket_width
        for id, sec in self.id_with_length:
            if sec < max_sec:
                cur.append(id)
            else:
                buckets.append(cur)
                cur = [id]
                max_sec += bucket_width
        if len(cur) > 0:
            buckets.append(cur)
        return buckets

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            random.seed(self.epoch + self.seed)
            shuffled_bucket = []
            for buc in self.id_buckets:
                buc_copy = buc.copy()
                shuffle(buc_copy)
                shuffled_bucket.append(buc_copy)
            grouped_batch_size = self.batch_size * self.num_replicas
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
            batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
            shuffle(batches)
            indices = list(itertools.chain(*batches))
        else:
            # type: ignore[arg-type]
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
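
A quick usage sketch (illustrative, not part of the file above): the sampler only needs a map-style dataset that additionally exposes get_sample_length(idx) in seconds. ToyDataset is a hypothetical stand-in, and passing num_replicas/rank explicitly keeps the sketch runnable without torch.distributed being initialized.

import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: just stores per-sample durations in seconds."""

    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        return idx  # the payload is irrelevant to the sampler itself

    def get_sample_length(self, idx):
        return self.lengths[idx]


ds = ToyDataset([1.2, 0.8, 3.5, 2.1, 5.0, 0.5, 2.9, 4.2])
sampler = DistributedBucketSampler(ds, num_replicas=1, rank=0, batch_size=4)
for epoch in range(2):
    sampler.set_epoch(epoch)  # deterministic reshuffle per epoch
    print(epoch, list(iter(sampler)))  # length-bucketed, shuffled indices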

View File

@@ -0,0 +1,81 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
# reference: https://github.com/lifeiteng/vall-e
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset


class Text2SemanticDataModule(LightningDataModule):
    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    def prepare_data(self):
        pass

    def setup(self, stage=None, output_logs=False):
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
        )
        self._dev_dataset = self._train_dataset
        # self._dev_dataset = Text2SemanticDataset(
        #     phoneme_path=self.dev_phoneme_path,
        #     semantic_path=self.dev_semantic_path,
        #     max_sample=self.config['data']['max_eval_sample'],
        #     max_sec=self.config['data']['max_sec'],
        #     pad_val=self.config['data']['pad_val'])

    def train_dataloader(self):
        batch_size = (
            self.config["train"]["batch_size"] // 2
            if self.config["train"].get("if_dpo", False) is True
            else self.config["train"]["batch_size"]
        )
        batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # keep batch size small enough that checkpoints still get saved
        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self._train_dataset.collate,
            num_workers=self.num_workers,
            persistent_workers=True,
            prefetch_factor=16,
        )

    def val_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            num_workers=max(self.num_workers, 12),
            persistent_workers=True,
            prefetch_factor=16,
        )

    # is this ever actually used?
    def test_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
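
For orientation, a sketch of the config shape this module expects. The key names are taken from the lookups in the code above; the numeric values are made-up placeholders, and the file paths are examples following the "2-name2text.txt" / "6-name2semantic.tsv" naming mentioned in AR/data/dataset.py, not guaranteed project defaults.

config = {
    "data": {
        "num_workers": 4,   # read in __init__
        "max_sec": 54,      # forwarded to Text2SemanticDataset
        "pad_val": 1024,    # padding id for semantic tokens
    },
    "train": {
        "batch_size": 16,   # halved when if_dpo is True (two forward passes per step)
        "if_dpo": False,
    },
}
dm = Text2SemanticDataModule(
    config,
    train_semantic_path="logs/exp/6-name2semantic.tsv",  # placeholder path
    train_phoneme_path="logs/exp/2-name2text.txt",       # placeholder path
)
dm.setup()
loader = dm.train_dataloader()  # length-bucketed via DistributedBucketSampler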

View File

@@ -0,0 +1,320 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
# reference: https://github.com/lifeiteng/vall-e
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import os
import traceback
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

version = os.environ.get("version", None)

from text import cleaned_text_to_sequence

# from config import exp_dir


def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
    seq = sequences[0]
    ndim = seq.ndim
    if axis < 0:
        axis += ndim
    dtype = seq.dtype
    pad_value = dtype.type(pad_value)
    seq_lengths = [seq.shape[axis] for seq in sequences]
    max_length = np.max(seq_lengths)

    padded_sequences = []
    for seq, length in zip(sequences, seq_lengths):
        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
        padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
        padded_sequences.append(padded_seq)
    batch = np.stack(padded_sequences)
    return batch
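

# Worked example (illustrative, not in the original file): padding two 1-D
# token arrays to a common length along axis 0.
#
#     a = np.array([1, 2, 3])
#     b = np.array([4, 5])
#     batch_sequences([a, b], pad_value=0)
#     # -> array([[1, 2, 3],
#     #           [4, 5, 0]])
#
# pad_value is cast to the sequences' dtype above, so integer batches stay integer.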
class Text2SemanticDataset(Dataset):
    """dataset class for text tokens to semantic model training."""

    def __init__(
        self,
        phoneme_path: str,
        semantic_path: str,
        max_sample: int = None,
        max_sec: int = 100,
        pad_val: int = 1024,
        # min value of phoneme/sec
        min_ps_ratio: int = 3,
        # max value of phoneme/sec
        max_ps_ratio: int = 25,
    ) -> None:
        super().__init__()

        self.semantic_data = pd.read_csv(
            semantic_path,
            delimiter="\t",
            encoding="utf-8",
        )
        # get dict
        self.path2 = phoneme_path  # "%s/2-name2text.txt"%exp_dir#phoneme_path
        self.path3 = "%s/3-bert" % (
            os.path.dirname(
                phoneme_path,
            )
        )  # "%s/3-bert"%exp_dir#bert_dir
        self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)

        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]

        # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
        # pad for semantic tokens
        self.PAD: int = pad_val
        # self.hz = 25
        # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
        # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
        # self.hz=int(data[:-2])#
        self.hz = int(os.environ.get("hz", "25hz")[:-2])

        # max seconds of semantic token
        self.max_sec = max_sec
        self.min_ps_ratio = min_ps_ratio
        self.max_ps_ratio = max_ps_ratio

        if max_sample is not None:
            self.semantic_data = self.semantic_data[:max_sample]

        # {idx: (semantic, phoneme)}
        # semantic list, phoneme list
        self.semantic_phoneme = []
        self.item_names = []

        self.inited = False

        if not self.inited:
            # run the one-time initialization
            self.init_batch()
            self.inited = True
            del self.semantic_data
            del self.phoneme_data
        # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
        # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")

    def init_batch(self):
        semantic_data_len = len(self.semantic_data)
        phoneme_data_len = len(self.phoneme_data.keys())
        print("semantic_data_len:", semantic_data_len)
        print("phoneme_data_len:", phoneme_data_len)
        print(self.semantic_data)
        idx = 0
        num_not_in = 0
        num_deleted_bigger = 0
        num_deleted_ps = 0
        for i in range(semantic_data_len):
            # walk the rows in order
            # get str
            item_name = self.semantic_data.iloc[i, 0]
            # print(self.phoneme_data)
            try:
                phoneme, word2ph, text = self.phoneme_data[item_name]
            except Exception:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                num_not_in += 1
                continue

            semantic_str = self.semantic_data.iloc[i, 1]
            # get token list
            semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
            # (T); no need to reshape to (1, T), since we still need len()
            # filter out overly long samples
            if (
                len(semantic_ids) > self.max_sec * self.hz
            ):  # 1. filter by total duration inferred from token count (max_sec set in config, e.g. 40s * 25Hz = 1k tokens)
                num_deleted_bigger += 1
                continue

            # (T,); this is fast enough to do upfront rather than per item in __getitem__
            phoneme = phoneme.split(" ")
            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                num_not_in += 1
                continue
            # if len(phoneme_ids) > 400:  # 2. changed to a fixed cap of semantic length / 2.5
            if len(phoneme_ids) > self.max_sec * self.hz / 2.5:  # 2. fixed cap: phoneme count must not exceed semantic length / 2.5
                num_deleted_ps += 1
                continue
            # if len(semantic_ids) > 1000:  # 3.
            #     num_deleted_bigger += 1
            #     continue

            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  # 4. phonemes per second must fall within [3, 25]
                num_deleted_ps += 1
                # print(item_name)
                continue

            self.semantic_phoneme.append((semantic_ids, phoneme_ids))
            idx += 1
            self.item_names.append(item_name)

        min_num = 100  # below ~20 samples duplicating does not help; around 30, even duplicated, no ckpt gets saved
        leng = len(self.semantic_phoneme)
        if leng < min_num:
            tmp1 = self.semantic_phoneme
            tmp2 = self.item_names
            self.semantic_phoneme = []
            self.item_names = []
            for _ in range(max(2, int(min_num / leng))):
                self.semantic_phoneme += tmp1
                self.item_names += tmp2

        if num_not_in > 0:
            print(f"there are {num_not_in} semantic datas not in phoneme datas")
        if num_deleted_bigger > 0:
            print(
                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
            )
        if num_deleted_ps > 0:
            # 4702 for LibriTTS; LibriTTS is labeled data, so does it still need filtering? => yes, it contains extreme values as high as 100
            print(
                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
            )
        """
        there are 31 semantic datas not in phoneme datas
        deleted 34 audios who's duration are bigger than 54 seconds
        deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
        dataset.__len__(): 366463
        """
        # 345410 for LibriTTS
        print("dataset.__len__():", self.__len__())

    def __get_item_names__(self) -> List[str]:
        return self.item_names

    def __len__(self) -> int:
        return len(self.semantic_phoneme)

    def __getitem__(self, idx: int) -> Dict:
        semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
        item_name = self.item_names[idx]
        phoneme_ids_len = len(phoneme_ids)
        # semantic tokens target
        semantic_ids_len = len(semantic_ids)

        flag = 0
        path_bert = "%s/%s.pt" % (self.path3, item_name)
        if os.path.exists(path_bert):
            bert_feature = torch.load(path_bert, map_location="cpu")
        else:
            flag = 1
        if flag == 1:
            # bert_feature = torch.zeros_like(phoneme_ids, dtype=torch.float32)
            bert_feature = None
        else:
            assert bert_feature.shape[-1] == len(phoneme_ids)
        return {
            "idx": idx,
            "phoneme_ids": phoneme_ids,
            "phoneme_ids_len": phoneme_ids_len,
            "semantic_ids": semantic_ids,
            "semantic_ids_len": semantic_ids_len,
            "bert_feature": bert_feature,
        }

    def get_sample_length(self, idx: int):
        semantic_ids = self.semantic_phoneme[idx][0]
        sec = 1.0 * len(semantic_ids) / self.hz
        return sec

    def collate(self, examples: List[Dict]) -> Dict:
        sample_index: List[int] = []
        phoneme_ids: List[torch.Tensor] = []
        phoneme_ids_lens: List[int] = []
        semantic_ids: List[torch.Tensor] = []
        semantic_ids_lens: List[int] = []
        # return

        for item in examples:
            sample_index.append(item["idx"])
            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
            phoneme_ids_lens.append(item["phoneme_ids_len"])
            semantic_ids_lens.append(item["semantic_ids_len"])

        # pad 0
        phoneme_ids = batch_sequences(phoneme_ids)
        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)

        # convert each batch to torch.tensor
        phoneme_ids = torch.tensor(phoneme_ids)
        semantic_ids = torch.tensor(semantic_ids)
        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
        semantic_ids_lens = torch.tensor(semantic_ids_lens)
        bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
        bert_padded.zero_()

        for idx, item in enumerate(examples):
            bert = item["bert_feature"]
            if bert is not None:
                bert_padded[idx, :, : bert.shape[-1]] = bert

        return {
            # List[int]
            "ids": sample_index,
            # torch.Tensor (B, max_phoneme_length)
            "phoneme_ids": phoneme_ids,
            # torch.Tensor (B)
            "phoneme_ids_len": phoneme_ids_lens,
            # torch.Tensor (B, max_semantic_ids_length)
            "semantic_ids": semantic_ids,
            # torch.Tensor (B)
            "semantic_ids_len": semantic_ids_lens,
            # torch.Tensor (B, 1024, max_phoneme_length)
            "bert_feature": bert_padded,
        }


if __name__ == "__main__":
    root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
    dataset = Text2SemanticDataset(
        phoneme_path=root_dir + "phoneme_train.npy",
        semantic_path=root_dir + "semantic_train.tsv",
    )

    batch_size = 12
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate,
        shuffle=False,
    )
    for i, batch in enumerate(dataloader):
        if i % 1000 == 0:
            print(i)
        # if i == 0:
        #     print('batch["ids"]:', batch["ids"])
        #     print('batch["phoneme_ids"]:', batch["phoneme_ids"],
        #           batch["phoneme_ids"].shape)
        #     print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
        #           batch["phoneme_ids_len"].shape)
        #     print('batch["semantic_ids"]:', batch["semantic_ids"],
        #           batch["semantic_ids"].shape)
        #     print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
        #           batch["semantic_ids_len"].shape)

View File

@@ -0,0 +1,146 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os
import sys

now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict

import torch
from pytorch_lightning import LightningModule

from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam


class Text2SemanticLightningModule(LightningModule):
    def __init__(self, config, output_dir, is_train=True):
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # print(self.load_state_dict(torch.load(pretrained_s1, map_location="cpu")["state_dict"]))
            print(
                self.load_state_dict(
                    torch.load(
                        pretrained_s1,
                        map_location="cpu",
                        weights_only=False,
                    )["weight"],
                )
            )
        if is_train:
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        forward = self.model.forward if self.config["train"].get("if_dpo", False) is True else self.model.forward_old
        loss, acc = forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):
        return
        # # get loss
        # loss, acc = self.model.forward(
        #     batch['phoneme_ids'], batch['phoneme_ids_len'],
        #     batch['semantic_ids'], batch['semantic_ids_len'],
        #     batch['bert_feature']
        # )
        #
        # self.log(
        #     "val_total_loss",
        #     loss,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        # self.log(
        #     f"val_top_{self.top_k}_acc",
        #     acc,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        #
        # # get infer output
        # semantic_len = batch['semantic_ids'].size(1)
        # prompt_len = min(int(semantic_len * 0.5), 150)
        # prompt = batch['semantic_ids'][:, :prompt_len]
        # pred_semantic = self.model.infer(batch['phoneme_ids'],
        #                                  batch['phoneme_ids_len'], prompt,
        #                                  batch['bert_feature']
        #                                  )
        # save_name = f'semantic_toks_{batch_idx}.pt'
        # save_path = os.path.join(self.eval_dir, save_name)
        # torch.save(pred_semantic.detach().cpu(), save_path)

    def configure_optimizers(self):
        model_parameters = self.model.parameters()
        parameters_names = []
        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )
        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler": WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config["optimizer"]["lr_init"],
                    peak_lr=self.config["optimizer"]["lr"],
                    end_lr=self.config["optimizer"]["lr_end"],
                    warmup_steps=self.config["optimizer"]["warmup_steps"],
                    total_steps=self.config["optimizer"]["decay_steps"],
                )
            },
        }
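
A minimal wiring sketch (illustrative, not the project's shipped config): it shows only the keys this module and Text2SemanticDecoder actually read. The "model" values mirror default_config in AR/models/t2s_model.py; the optimizer numbers are placeholders, not tuned defaults.

from pathlib import Path

config = {
    "model": {
        "hidden_dim": 512,
        "embedding_dim": 512,
        "head": 8,
        "n_layer": 12,
        "vocab_size": 1024 + 1,
        "phoneme_vocab_size": 512,
        "dropout": 0.0,
        "EOS": 1024,
    },
    "train": {"if_dpo": False},  # False routes training_step to forward_old (plain CE loss)
    "optimizer": {               # placeholder schedule values
        "lr": 0.01,
        "lr_init": 0.00001,
        "lr_end": 0.0001,
        "warmup_steps": 2000,
        "decay_steps": 40000,
    },
    # "pretrained_s1": "GPT_weights/some_model.ckpt",  # optional warm start (hypothetical path)
}
module = Text2SemanticLightningModule(config, Path("logs/s1"), is_train=True)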

View File

@@ -0,0 +1,110 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os
import sys

now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict

import torch
from pytorch_lightning import LightningModule

from AR.models.t2s_model_onnx import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam


class Text2SemanticLightningModule(LightningModule):
    def __init__(self, config, output_dir, is_train=True):
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # print(self.load_state_dict(torch.load(pretrained_s1, map_location="cpu")["state_dict"]))
            print(
                self.load_state_dict(
                    torch.load(
                        pretrained_s1,
                        map_location="cpu",
                    )["weight"],
                ),
            )
        if is_train:
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):
        return

    def configure_optimizers(self):
        model_parameters = self.model.parameters()
        parameters_names = []
        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )
        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler": WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config["optimizer"]["lr_init"],
                    peak_lr=self.config["optimizer"]["lr"],
                    end_lr=self.config["optimizer"]["lr_end"],
                    warmup_steps=self.config["optimizer"]["warmup_steps"],
                    total_steps=self.config["optimizer"]["decay_steps"],
                )
            },
        }

View File

@@ -0,0 +1,935 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional

import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm

from AR.models.utils import (
    dpo_loss,
    get_batch_logps,
    make_pad_mask,
    make_pad_mask_left,
    make_reject_y,
    sample,
    topk_sampling,
)
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer

default_config = {
    "embedding_dim": 512,
    "hidden_dim": 512,
    "num_head": 8,
    "num_layers": 12,
    "num_codebook": 8,
    "p_dropout": 0.0,
    "vocab_size": 1024 + 1,
    "phoneme_vocab_size": 512,
    "EOS": 1024,
}


# @torch.jit.script  ## scripting this makes the first inference extremely slow and inference speed unstable afterwards
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
    if scale is None:
        scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
    else:
        scale_factor = scale
    attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_bias += attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_weight.masked_fill_(attn_mask, 0)
        else:
            attn_mask[attn_mask != float("-inf")] = 0
            attn_mask[attn_mask == float("-inf")] = 1
            attn_weight.masked_fill_(attn_mask, 0)

    return attn_weight @ value
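

# Illustrative check (not part of the original file): for a boolean mask the
# fallback above should agree with torch's fused kernel. Note the inverted
# convention: here True means "do not attend", while
# F.scaled_dot_product_attention expects True to mean "attend", hence the
# ~mask below, mirroring how T2SBlock calls both paths.
# _sdpa_equivalence_check is a hypothetical helper, never called in a hot path.
def _sdpa_equivalence_check():
    q = torch.randn(1, 2, 5, 8)
    k = torch.randn(1, 2, 5, 8)
    v = torch.randn(1, 2, 5, 8)
    causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # True above the diagonal
    ours = scaled_dot_product_attention(q, k, v, causal)
    ref = F.scaled_dot_product_attention(q, k, v, ~causal)
    assert torch.allclose(ours, ref, atol=1e-5)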
@torch.jit.script
class T2SMLP:
    def __init__(self, w1, b1, w2, b2):
        self.w1 = w1
        self.b1 = b1
        self.w2 = w2
        self.b2 = b2

    def forward(self, x):
        x = F.relu(F.linear(x, self.w1, self.b1))
        x = F.linear(x, self.w2, self.b2)
        return x


@torch.jit.script
class T2SBlock:
    def __init__(
        self,
        num_heads,
        hidden_dim: int,
        mlp: T2SMLP,
        qkv_w,
        qkv_b,
        out_w,
        out_b,
        norm_w1,
        norm_b1,
        norm_eps1,
        norm_w2,
        norm_b2,
        norm_eps2,
    ):
        self.num_heads = num_heads
        self.mlp = mlp
        self.hidden_dim: int = hidden_dim
        self.qkv_w = qkv_w
        self.qkv_b = qkv_b
        self.out_w = out_w
        self.out_b = out_b
        self.norm_w1 = norm_w1
        self.norm_b1 = norm_b1
        self.norm_eps1 = norm_eps1
        self.norm_w2 = norm_w2
        self.norm_b2 = norm_b2
        self.norm_eps2 = norm_eps2

        self.false = torch.tensor(False, dtype=torch.bool)

    @torch.jit.ignore
    def to_mask(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ):
        if padding_mask is None:
            return x

        if padding_mask.dtype == torch.bool:
            return x.masked_fill(padding_mask, 0)
        else:
            return x * padding_mask

    def process_prompt(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        torch_sdpa: bool = True,
    ):
        q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        batch_size = q.shape[0]
        q_len = q.shape[1]
        kv_len = k.shape[1]

        q = self.to_mask(q, padding_mask)
        k_cache = self.to_mask(k, padding_mask)
        v_cache = self.to_mask(v, padding_mask)

        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

        if torch_sdpa:
            attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
        else:
            attn = scaled_dot_product_attention(q, k, v, attn_mask)

        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
        attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)

        x = x + attn
        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
        x = x + self.mlp.forward(x)
        x = F.layer_norm(
            x,
            [self.hidden_dim],
            self.norm_w2,
            self.norm_b2,
            self.norm_eps2,
        )
        return x, k_cache, v_cache

    def decode_next_token(
        self,
        x: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_mask: torch.Tensor = None,
        torch_sdpa: bool = True,
    ):
        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        k_cache = torch.cat([k_cache, k], dim=1)
        v_cache = torch.cat([v_cache, v], dim=1)

        batch_size = q.shape[0]
        q_len = q.shape[1]
        kv_len = k_cache.shape[1]

        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

        if torch_sdpa:
            attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
        else:
            attn = scaled_dot_product_attention(q, k, v, attn_mask)

        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
        attn = F.linear(attn, self.out_w, self.out_b)

        x = x + attn
        x = F.layer_norm(
            x,
            [self.hidden_dim],
            self.norm_w1,
            self.norm_b1,
            self.norm_eps1,
        )
        x = x + self.mlp.forward(x)
        x = F.layer_norm(
            x,
            [self.hidden_dim],
            self.norm_w2,
            self.norm_b2,
            self.norm_eps2,
        )
        return x, k_cache, v_cache


@torch.jit.script
class T2STransformer:
    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
        self.num_blocks: int = num_blocks
        self.blocks = blocks

    def process_prompt(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        torch_sdpa: bool = True,
    ):
        k_cache: List[torch.Tensor] = []
        v_cache: List[torch.Tensor] = []
        for i in range(self.num_blocks):
            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
            k_cache.append(k_cache_)
            v_cache.append(v_cache_)
        return x, k_cache, v_cache

    def decode_next_token(
        self,
        x: torch.Tensor,
        k_cache: List[torch.Tensor],
        v_cache: List[torch.Tensor],
        attn_mask: torch.Tensor = None,
        torch_sdpa: bool = True,
    ):
        for i in range(self.num_blocks):
            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
                x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
            )
        return x, k_cache, v_cache
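

# Decoding protocol (explanatory note, not part of the original file):
# process_prompt runs full self-attention over the whole prefix once and
# returns per-layer K/V caches; each later step feeds a single new token
# through decode_next_token, which appends to those caches instead of
# re-attending over the prefix from scratch. A sketch of the call pattern,
# with placeholder tensor names:
#
#     x, k_cache, v_cache = t2s_transformer.process_prompt(xy_pos, attn_mask)
#     while not done:
#         x, k_cache, v_cache = t2s_transformer.decode_next_token(
#             next_token_embedding, k_cache, v_cache, step_attn_mask
#         )
#
# infer_panel_batch_infer below follows exactly this pattern (idx == 0 is the
# prompt pass, every other iteration is an incremental step).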
class Text2SemanticDecoder(nn.Module):
    def __init__(self, config, norm_first=False, top_k=3):
        super(Text2SemanticDecoder, self).__init__()
        self.model_dim = config["model"]["hidden_dim"]
        self.embedding_dim = config["model"]["embedding_dim"]
        self.num_head = config["model"]["head"]
        self.num_layers = config["model"]["n_layer"]
        self.norm_first = norm_first
        self.vocab_size = config["model"]["vocab_size"]
        self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
        self.p_dropout = config["model"]["dropout"]
        self.EOS = config["model"]["EOS"]
        self.norm_first = norm_first
        assert self.EOS == self.vocab_size - 1
        # should be same as num of kmeans bin
        # assert self.EOS == 1024
        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_text_embedding = TokenEmbedding(
            self.embedding_dim,
            self.phoneme_vocab_size,
            self.p_dropout,
        )
        self.ar_text_position = SinePositionalEmbedding(
            self.embedding_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_embedding = TokenEmbedding(
            self.embedding_dim,
            self.vocab_size,
            self.p_dropout,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            self.embedding_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.h = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=self.model_dim,
                nhead=self.num_head,
                dim_feedforward=self.model_dim * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=self.num_layers,
            norm=LayerNorm(self.model_dim) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
        self.loss_fct = nn.CrossEntropyLoss(reduction="sum")

        self.ar_accuracy_metric = MulticlassAccuracy(
            self.vocab_size,
            top_k=top_k,
            average="micro",
            multidim_average="global",
            ignore_index=self.EOS,
        )

        blocks = []
        for i in range(self.num_layers):
            layer = self.h.layers[i]
            t2smlp = T2SMLP(
                layer.linear1.weight,
                layer.linear1.bias,
                layer.linear2.weight,
                layer.linear2.bias,
            )
            block = T2SBlock(
                self.num_head,
                self.model_dim,
                t2smlp,
                layer.self_attn.in_proj_weight,
                layer.self_attn.in_proj_bias,
                layer.self_attn.out_proj.weight,
                layer.self_attn.out_proj.bias,
                layer.norm1.weight,
                layer.norm1.bias,
                layer.norm1.eps,
                layer.norm2.weight,
                layer.norm2.bias,
                layer.norm2.eps,
            )
            blocks.append(block)
        self.t2s_transformer = T2STransformer(self.num_layers, blocks)

    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
        x_mask = make_pad_mask_left(x_lens)

        y_mask = make_pad_mask(y_lens)
        y_mask_int = y_mask.type(torch.int64)
        codes = y.type(torch.int64) * (1 - y_mask_int)

        # Training
        # AR Decoder
        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
        x_len = x_lens.max()
        y_len = y_lens.max()
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)

        ar_xy_padding_mask = xy_padding_mask

        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        # x_attn_mask[:, x_len] = False
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )

        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )

        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask
        # x and the full y are fed to the model in a single pass
        xy_pos = torch.concat([x, y_pos], dim=1)

        return xy_pos, xy_attn_mask, targets

    def forward(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
        y: semantic_ids
        """
        reject_y, reject_y_lens = make_reject_y(y, y_lens)

        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)

        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        x_len = x_lens.max()
        logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :])

        ###### DPO #############
        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
            x, x_lens, reject_y, reject_y_lens, bert_feature
        )

        reject_xy_dec, _ = self.h(
            (reject_xy_pos, None),
            mask=reject_xy_attn_mask,
        )
        x_len = x_lens.max()
        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len - 1 :])

        # loss
        # from feiteng: longer durations deserve proportionally larger gradient updates, hence the sum reduction
        loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()

        A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
        loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)

        loss = loss_1 + loss_2

        return loss, acc

    def forward_old(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
        y: semantic_ids
        """
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
        x_mask = make_pad_mask_left(x_lens)

        y_mask = make_pad_mask(y_lens)
        y_mask_int = y_mask.type(torch.int64)
        codes = y.type(torch.int64) * (1 - y_mask_int)

        # Training
        # AR Decoder
        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
        x_len = x_lens.max()
        y_len = y_lens.max()
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
        ar_xy_padding_mask = xy_padding_mask

        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask
        # x and the full y are fed to the model in a single pass
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :]).permute(0, 2, 1)
        # loss
        # from feiteng: longer durations deserve proportionally larger gradient updates, hence the sum reduction
        loss = F.cross_entropy(logits, targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.detach(), targets).item()
        return loss, acc

    # TODO: clarify how this differs from forward, and what to feed as prompts when there is no semantic input
    def infer(
        self,
        x,
        x_lens,
        prompts,
        bert_feature,
        top_k: int = -100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
    ):
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)

        # AR Decoder
        y = prompts
        prefix_len = y.shape[1]
        x_len = x.shape[1]
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
        stop = False
        for _ in tqdm(range(1500)):
            y_emb = self.ar_audio_embedding(y)
            y_pos = self.ar_audio_position(y_emb)
            # x is fed together with the steadily growing y
            xy_pos = torch.concat([x, y_pos], dim=1)
            y_len = y.shape[1]
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)

            xy_dec, _ = self.h(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)

            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                print("use early stop num:", early_stop_num)
                stop = True

            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                stop = True
            if stop:
                if prompts.shape[1] == y.shape[1]:
                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                    print("bad zero prediction")
                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                break
            # the semantic_ids generated this step are appended to y for the next step
            # print(samples.shape)  # [1, 1]; the first 1 is the batch size
            # import os
            # os._exit(2333)
            y = torch.concat([y, samples], dim=1)
        return y

    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
        # shift inputs and targets by one position
        return targets[:, :-1], targets
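
    # Worked example (illustrative, not in the original): for a single
    # unpadded sequence y = [[7, 8, 9]] with eos_id = EOS, targets becomes
    # [[7, 8, 9, EOS]] and the returned pair is
    #     inputs  = [[7, 8, 9]]       (targets[:, :-1], fed to the decoder)
    #     targets = [[7, 8, 9, EOS]]  (the supervision, including the EOS)
    # Padded positions have y_mask_int == 1, so those slots (zeroed in
    # `codes` by the caller) are overwritten with eos_id instead.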
def infer_panel_batch_infer(
self,
x: List[torch.LongTensor], #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs,
):
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(
x,
x_lens,
prompts,
bert_feature,
top_k=top_k,
top_p=top_p,
early_stop_num=early_stop_num,
temperature=temperature,
**kwargs,
)
max_len = kwargs.get("max_len", x_lens.max())
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = (
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
) ### padding left
x_list.append(x_item)
x: torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
x_len = x.shape[1]
stop = False
k_cache = None
v_cache = None
################### first step ##########################
assert y is not None, "Error: Prompt free is not supported batch_infer!"
ref_free = False
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_padding_mask = make_pad_mask_left(y_lens, y_len)
x_padding_mask = make_pad_mask_left(x_lens, max_len)
# (bsz, x_len + y_len)
padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
x_mask = F.pad(
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_mask = F.pad(  # extend the upper-triangular 1s of the y-y block leftward with 0s for x, shape (y, x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### the line above is wrong: it would let padded tokens be "seen" by attention
# the correct padding_mask is:
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],  the first 3 rows should arguably be masked too, but they are kept so attention does not produce NaN; this does not affect the result
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# the correct attn_mask looks like this:
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],  the first 3 rows should arguably be masked too, but they are kept so attention does not produce NaN; this does not affect the result
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
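# editor's sketch: with bsz=1, pad_len=0, x_len=3, y_len=3 and no padding, the
# final attn_mask has shape (1, num_head, 6, 6); True marks key positions a
# query must NOT attend to, exactly as drawn above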
###### decode #####
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None] * y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
logits = logits[:, :-1]
else:
attn_mask = F.pad(attn_mask, (0, 1), value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### remove sequences that have finished generating from the batch, to further cut computation
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or (self.EOS in tokens):  # stop once EOS is generated
l1 = samples[:, 0] == self.EOS
l2 = tokens == self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(~l)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# keep only the sequences in the batch that are not yet finished
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if None not in idx_list:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if None in idx_list:
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500 - 1  # if EOS was never generated, substitute the maximum length
if ref_free:
return y_list, [0] * x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_naive_batched(
self,
x: List[torch.LongTensor],  # all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  # reference audio tokens
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs,
):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs,
)
y_list.append(y[0])
idx_list.append(idx)
return y_list, idx_list
def infer_panel_naive(
self,
x: torch.LongTensor,  # all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  # reference audio tokens
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
bsz = x.shape[0]
src_len = x_len + y_len
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),  # extend the all-0 x-x block with all-1s for the x-y part, shape (x, x+y)
value=True,
)
y_attn_mask = F.pad(  # extend the upper-triangular 1s of the y-y block leftward with 0s for x, shape (y, x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.num_head, -1, -1)
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
def infer_panel(
self,
x: torch.LongTensor,  # all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  # reference audio tokens
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs,
):
return self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)

View File

@@ -0,0 +1,394 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
"num_head": 8,
"num_layers": 12,
"num_codebook": 8,
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024,
}
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
def logits_to_probs(
logits,
previous_tokens=None,
temperature: float = 1.0,
top_k=None,
top_p=None,
repetition_penalty: float = 1.0,
):
if previous_tokens is not None:
    previous_tokens = previous_tokens.squeeze()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, top_k)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, inf_tensor_value, logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
def multinomial_sample_one_no_sync(
probs_sort,
): # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
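# Editor's note: unlike the Exponential(1) noise used by the non-ONNX sampler in
# this commit's utils module, this variant draws Gaussian noise (presumably for
# ONNX operator support), so the argmax is no longer an exact categorical sample
# over probs_sort.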
def sample(
logits,
previous_tokens,
**sampling_kwargs,
):
probs = logits_to_probs(
logits=logits,
previous_tokens=previous_tokens,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
class OnnxEncoder(nn.Module):
def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
super().__init__()
self.ar_text_embedding = ar_text_embedding
self.bert_proj = bert_proj
self.ar_text_position = ar_text_position
def forward(self, x, bert_feature):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
return self.ar_text_position(x)
class T2SFirstStageDecoder(nn.Module):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
self.h = h
self.ar_predict_layer = ar_predict_layer
self.loss_fct = loss_fct
self.ar_accuracy_metric = ar_accuracy_metric
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, x, prompt):
y = prompt
x_example = x[:, :, 0] * 0.0
# N, 1, 512
cache = {
"all_stage": self.num_layers,
"k": None,
"v": None,
"y_emb": None,
"first_infer": 1,
"stage": 0,
}
y_emb = self.ar_audio_embedding(y)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:, :, 0] * 0.0
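# editor's note: the masks below are built with matmuls over zeroed "example"
# tensors so that every shape is traced from the inputs themselves, which keeps
# the graph free of Python-side shape arithmetic and therefore ONNX-exportable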
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
y = torch.concat([y, samples], dim=1)
return y, cache["k"], cache["v"], cache["y_emb"], x_example
class T2SStageDecoder(nn.Module):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
self.h = h
self.ar_predict_layer = ar_predict_layer
self.loss_fct = loss_fct
self.ar_accuracy_metric = ar_accuracy_metric
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, y, k, v, y_emb, x_example):
cache = {
"all_stage": self.num_layers,
"k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
"v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
"y_emb": y_emb,
"first_infer": 0,
"stage": 0,
}
y_emb = torch.cat(
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
y_example = y_pos[:, :, 0] * 0.0
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
y = torch.concat([y, samples], dim=1)
return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
self.num_head = config["model"]["head"]
self.num_layers = config["model"]["n_layer"]
self.norm_first = norm_first
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = float(config["model"]["dropout"])
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.h = TransformerEncoder(
TransformerEncoderLayer(
d_model=self.model_dim,
nhead=self.num_head,
dim_feedforward=self.model_dim * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
),
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
average="micro",
multidim_average="global",
ignore_index=self.EOS,
)
self.top_k = torch.LongTensor([1])
self.early_stop_num = torch.LongTensor([-1])
def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
def forward(self, x, prompts, bert_feature):
early_stop_num = self.early_stop_num
prefix_len = prompts.shape[1]
x = self.onnx_encoder(x, bert_feature)
y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
stop = False
for idx in range(1, 1500):
enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
y, k, v, y_emb, stage, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
break
y[0, -1] = 0
return y, idx
def infer(self, x, prompts, bert_feature):
top_k = self.top_k
early_stop_num = self.early_stop_num
x = self.onnx_encoder(x, bert_feature)
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
x_example = x[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
stop = False
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers,
"v": [None] * self.num_layers,
"y_emb": None,
"first_infer": 1,
"stage": 0,
}
for idx in range(1500):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
xy_pos = torch.concat([x, y_pos], dim=1)
else:
xy_pos = y_pos[:, -1:]
y_len = y_pos.shape[1]
if cache["first_infer"] == 1:
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
else:
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
break
y = torch.concat([y, samples], dim=1)
cache["first_infer"] = 0
return y, idx

View File

@@ -0,0 +1,282 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
# reference: https://github.com/lifeiteng/vall-e
from typing import Tuple
import torch
import torch.nn.functional as F
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
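# Editor's doctest sketch (not part of the original file):
# >>> sequence_mask(torch.tensor([2, 4]))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])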
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
#>>> lengths = torch.tensor([1, 3, 2, 5])
#>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
return expanded_lengths >= lengths.unsqueeze(-1)
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
#>>> lengths = torch.tensor([1, 3, 2, 5])
#>>> make_pad_mask(lengths)
tensor(
[
[True, True, False],
[True, False, False],
[True, True, False],
...
]
)
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expanded_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expanded_lengths -= (max_len - lengths).unsqueeze(-1)
return expanded_lengths < 0
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
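# Editor's usage sketch (hypothetical values):
# >>> logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
# >>> top_k_top_p_filtering(logits.clone(), top_k=2)
# tensor([[2., 1., -inf, -inf]])  # only the two largest logits survive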
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
# temperature: (`optional`) float
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
# top_k: (`optional`) int
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
# top_p: (`optional`) float
# The cumulative probability mass of the highest-probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
logits = logits / temperature
# Top-p/top-k filtering
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
# Sample
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
return token
from typing import Optional
def multinomial_sample_one_no_sync(
probs_sort,
): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
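# (dividing probs by i.i.d. Exponential(1) noise and taking the argmax is the
# Gumbel-max trick in disguise: the result is distributed as Categorical(probs_sort)
# while avoiding the device synchronization that torch.multinomial incurs)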
def logits_to_probs(
logits,
previous_tokens: Optional[torch.Tensor] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
def sample(
logits,
previous_tokens: Optional[torch.Tensor] = None,
**sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
def dpo_loss(
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
if reference_free:
ref_logratios = 0
logits = pi_logratios - ref_logratios
losses = -F.logsigmoid(beta * logits)
chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
return losses.mean(), chosen_rewards, rejected_rewards
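# Editor's sketch (hypothetical log-probs): the loss rewards the policy for
# widening the chosen/rejected gap relative to the reference model.
# >>> pc, pr = torch.tensor([-10.0]), torch.tensor([-12.0])
# >>> rc, rr = torch.tensor([-11.0]), torch.tensor([-11.5])
# >>> dpo_loss(pc, pr, rc, rr, beta=0.1)[0]  # -logsigmoid(0.1 * 1.5) ≈ 0.62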
def get_batch_logps(
logits_target: torch.FloatTensor,
logits_reject: torch.FloatTensor,
labels_target: torch.LongTensor,
labels_reject: torch.LongTensor,
average_log_prob: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
per_token_logps_target = torch.gather(
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
).squeeze(2)
per_token_logps_reject = torch.gather(
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
).squeeze(2)
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
def make_reject_y(y_o, y_lens):
def repeat_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, range_text, range_text, shf])
return new_y
def lost_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, shf])
return new_y
bs = len(y_lens)
reject_y = []
reject_y_lens = []
for b in range(bs):
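# note (editor): torch.randint(0, 1, ...) always returns 0, so only repeat_P is
# ever applied and the lost_P branch below is effectively unreachable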
process_item_idx = torch.randint(0, 1, size=(1,))[0]
if process_item_idx == 0:
new_y = repeat_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
elif process_item_idx == 1:
new_y = lost_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
max_length = max(reject_y_lens)
for b in range(bs):
pad_length = max_length - reject_y_lens[b]
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
reject_y = torch.stack(reject_y, dim=0)
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
return reject_y, reject_y_lens
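# Editor's sketch: a minimal smoke test for the helpers above (not part of the
# original file; run standalone to inspect mask and sampling behavior).
if __name__ == "__main__":
    lengths = torch.tensor([1, 3, 2, 5])
    print(make_pad_mask(lengths))       # right-padding mask, True = padded
    print(make_pad_mask_left(lengths))  # left-padding mask, True = padded
    demo_logits = torch.randn(2, 8)
    print(topk_sampling(demo_logits, top_k=3, top_p=0.9, temperature=0.8))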

View File

@@ -0,0 +1,413 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
``forward()`` will use a special optimized implementation if all of the following
conditions are met:
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
restriction will be loosened in the future.)
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- dropout is 0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``batch_first`` is ``True`` and the input is batched
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
Default: ``False``.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Examples::
>>> # xdoctest: +SKIP
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs),
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs),
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs),
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
self._reset_parameters()
else:
if not self._qkv_same_embed_dim:
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = self.in_proj_linear.bias
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
self.add_zero_attn = add_zero_attn
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.0)
constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask,
):
raise AssertionError("only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
)
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
)
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
cache=cache,
)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
cache=cache,
)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights

View File

@@ -0,0 +1,188 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
class MultiheadAttention(Module):
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty(
(embed_dim, embed_dim),
**factory_kwargs,
)
)
self.k_proj_weight = Parameter(
torch.empty(
(embed_dim, self.kdim),
**factory_kwargs,
)
)
self.v_proj_weight = Parameter(
torch.empty(
(embed_dim, self.vdim),
**factory_kwargs,
)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty(
(3 * embed_dim, embed_dim),
**factory_kwargs,
)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs),
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self._reset_parameters()
else:
if not self._qkv_same_embed_dim:
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = self.in_proj_linear.bias
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
self.add_zero_attn = add_zero_attn
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.0)
constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
any_nested = query.is_nested or key.is_nested or value.is_nested
query = key = value = query.transpose(1, 0)
attn_output = multi_head_attention_forward_patched(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
cache=cache,
)
return attn_output.transpose(1, 0)

View File

@@ -0,0 +1,78 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
import math
import torch
from torch import nn
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.dropout = torch.nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self) -> torch.Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor):
x = self.word_embeddings(x)
x = self.dropout(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.dropout = torch.nn.Dropout(p=dropout)
self.reverse = False
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, 4000))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)
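# Editor's sketch: a minimal shape check (not part of the original file).
if __name__ == "__main__":
    emb = TokenEmbedding(embedding_dim=512, vocab_size=1025)
    pos = SinePositionalEmbedding(embedding_dim=512, dropout=0.0, scale=False, alpha=True)
    tokens = torch.randint(0, 1025, (2, 50))
    x = emb(tokens)       # (2, 50, 512)
    print(pos(x).shape)   # torch.Size([2, 50, 512])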

View File

@@ -0,0 +1,63 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
import math
import torch
from torch import nn
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.dropout = torch.nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self) -> torch.Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor):
x = self.word_embeddings(x)
x = self.dropout(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.dropout = torch.nn.Dropout(p=dropout)
self.reverse = False
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
def extend_pe(self, x):
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
scpe = (position * self.div_term).unsqueeze(0)
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
pe = pe.contiguous().view(1, -1, self.embedding_dim)
return pe
def forward(self, x: torch.Tensor) -> torch.Tensor:
pe = self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * pe
return self.dropout(output)
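# Editor's note: unlike embedding.py, this variant recomputes the positional
# encoding on every forward pass from a cumulative-sum position index, so the
# exported ONNX graph has no cached buffer whose size depends on Python state.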

View File

@@ -0,0 +1,85 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
# reference: https://github.com/lifeiteng/vall-e
import math
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import Adam
class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
"""
Warmup learning-rate schedule: the LR ramps linearly from 'init_lr' to 'peak_lr' over 'warmup_steps', then anneals toward 'end_lr' on a cosine curve.
"""
def __init__(
self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0,
):
self.init_lr = init_lr
self.peak_lr = peak_lr
self.end_lr = end_lr
self.optimizer = optimizer
self._warmup_rate = (peak_lr - init_lr) / warmup_steps
self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
self._current_step = current_step
self.lr = init_lr
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self._last_lr = [self.lr]
def set_lr(self, lr):
self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups:
# g['lr'] = lr
g["lr"] = self.end_lr ###锁定用线性
def step(self):
if self._current_step < self.warmup_steps:
lr = self.init_lr + self._warmup_rate * self._current_step
elif self._current_step > self.total_steps:
lr = self.end_lr
else:
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
self.lr = lr = self.end_lr = 0.002  # locked to a constant: override the cosine schedule and pin the LR to 0.002
self.set_lr(lr)
self.lr = lr
self._current_step += 1
return self.lr
if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0,
)
lrs = []
for i in range(25000):
s.step()
lrs.append(s.lr)
print(s.lr)
plt.plot(range(0, 25000), lrs)
plt.show()

View File

@@ -0,0 +1,593 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from collections import defaultdict
from typing import List, Tuple
import torch
from torch import Tensor
from torch.optim import Optimizer
class BatchedOptimizer(Optimizer):
"""
This class adds to class Optimizer the capability to optimize parameters in batches:
it will stack the parameters and their grads for you so the optimizer can work
on tensors with an extra leading dimension. This is intended for speed with GPUs,
as it reduces the number of kernels launched in the optimizer.
Args:
params: iterable of parameters (or dicts defining parameter groups) to optimize.
defaults: a dict containing default values of optimization options, as in torch.optim.Optimizer.
"""
def __init__(self, params, defaults):
super(BatchedOptimizer, self).__init__(params, defaults)
@contextlib.contextmanager
def batched_params(self, param_group, group_params_names):
"""
This function returns (technically, yields) a list of
tuples (p, state, p_names), where
p is a `fake` parameter that is stacked (over axis 0) from real parameters
that share the same shape, and its gradient is also stacked;
`state` is the state corresponding to this batch of parameters
(it will be physically located in the "state" for one of the real
parameters, the last one that has any particular shape and dtype).
This function is decorated as a context manager so that it can
write parameters back to their "real" locations.
The idea is, instead of doing:
<code>
for p in group["params"]:
state = self.state[p]
...
</code>
you can do:
<code>
with self.batched_params(group["params"]) as batches:
for p, state, p_names in batches:
...
</code>
Args:
group: a parameter group, which is a list of parameters; should be
one of self.param_groups.
group_params_names: name for each parameter in group,
which is List[str].
"""
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(list)  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
assert len(param_group) == len(group_params_names)
for p, named_p in zip(param_group, group_params_names):
key = (str(p.dtype), *p.shape)
batches[key].append(p)
batches_names[key].append(named_p)
batches_names_keys = list(batches_names.keys())
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
# turn batches into a list, in deterministic order.
# tuples will contain tuples of (stacked_param, state, stacked_params_names),
# one for each batch in `batches`.
tuples = []
for batch, batch_names in zip(batches, batches_names):
p = batch[0]
# we arbitrarily store the state in the
# state corresponding to the 1st parameter in the
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
p_stacked.grad = grad
tuples.append((p_stacked, state, batch_names))
yield tuples # <-- calling code will do the actual optimization here!
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
class ScaledAdam(BatchedOptimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
proportional to the norm of that parameter; and also learn the scale of the parameter,
in log space, subject to upper and lower limits (as if we had factored each parameter as
param = underlying_param * log_scale.exp())
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
clipping_scale: (e.g. 2.0)
A scale for gradient-clipping: if specified, the normalized gradients
over the whole model will be clipped to have 2-norm equal to
`clipping_scale` times the median 2-norm over the most recent period
of `clipping_update_period` minibatches. By "normalized gradients",
we mean after multiplying by the rms parameter value for this tensor
[for non-scalars]; this is appropriate because our update is scaled
by this quantity.
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
Must satisfy 0 < beta1 <= beta2 < 1.
scalar_lr_scale: A scaling factor on the learning rate that we use to update the
scale of each parameter tensor and the scalar parameters of the model.
If each parameter were decomposed
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
would be the scaling factor on the learning rate of p_scale.
eps: A general-purpose epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be >= this value)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be <= this value)
scalar_max: Maximum absolute value for scalar parameters (applicable if your
model has any parameters with numel() == 1).
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time
in the update.
clipping_update_period: if clipping_scale is specified, this is the period (in steps)
with which the gradient-norm statistics used for clipping are refreshed.
"""
def __init__(
self,
params,
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
scalar_lr_scale=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=3.0,
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True,
):
assert parameters_names is not None, (
    "Please prepare parameters_names, which is a List[List[str]]: "
    "one List[str] per parameter group, and one str per parameter."
)
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
betas=betas,
scalar_lr_scale=scalar_lr_scale,
eps=eps,
param_min_rms=param_min_rms,
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period,
)
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
self.parameters_names = parameters_names
self.show_dominant_parameters = show_dominant_parameters
def __setstate__(self, state):
super(ScaledAdam, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
clipping_scale = 1
else:
clipping_scale = self._get_clipping_scale(group, batches)
for p, state, _ in batches:
# Perform optimization step.
# grad is not going to be None, we handled that when creating the batches.
grad = p.grad
if grad.is_sparse:
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
self._step_one_batch(group, p, state, clipping_scale)
return loss
def _init_state(self, group: dict, p: Tensor, state: dict):
"""
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
is actually the batch dimension, corresponding to batched-together
parameters of a given shape.
Args:
group: Dict to look up configuration values.
p: The parameter that we are initializing the state for
state: Dict from string to whatever state we are initializing
"""
size_update_period = group["size_update_period"]
state["step"] = 0
kwargs = {"device": p.device, "dtype": p.dtype}
# 'delta' implements conventional momentum. There are
# several different kinds of update going on, so rather than
# compute "exp_avg" like in Adam, we store and decay a
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
batch_size = p.shape[0]
# per-parameter element count (dim 0 is the stacking/batch dim), consistent
# with the scalar/tensor split used in _step_one_batch below
numel = p.numel() // batch_size
if numel > 1:
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
Args:
group: the parameter group, an item in self.param_groups
tuples: a list of tuples of (param, state, param_names)
where param is a batched set of parameters,
with a .grad (1st dim is batch dim)
and state is the state-dict where optimization parameters are kept.
param_names is a List[str] while each str is name for a parameter
in batched set of parameters "param".
"""
assert len(tuples) >= 1
clipping_scale = group["clipping_scale"]
(first_p, first_state, _) = tuples[0]
step = first_state["step"]
if clipping_scale is None or step == 0:
# no clipping. return early on step == 0 because the other
# parameters' state won't have been initialized yet.
return 1.0
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
tot_norm = tot_sumsq.sqrt()
if "model_norms" not in first_state:
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
# Print some stats.
# We don't reach here if step == 0 because we would have returned
# above.
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
quartiles = []
for n in range(0, 5):
index = min(
clipping_update_period - 1,
(clipping_update_period // 4) * n,
)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
threshold = clipping_scale * median
first_state["model_norm_threshold"] = threshold
percent_clipped = (
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
)
first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to.
else:
try:
model_norm_threshold = first_state["model_norm_threshold"]
except KeyError:
logging.info(
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
)
return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
return ans
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
"""
Show information about the parameter that dominates tot_sumsq.
Args:
tuples: a list of tuples of (param, state, param_names)
where param is a batched set of parameters,
with a .grad (1st dim is batch dim)
and state is the state-dict where optimization parameters are kept.
param_names is a List[str] while each str is name for a parameter
in batched set of parameters "param".
tot_sumsq: sumsq of all parameters. Though it could be calculated
from tuples, we pass it in to save some time.
"""
all_sumsq_orig = {}
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
batch_sumsq_orig = batch_grad**2
# Dummy values used by the following `zip` statement.
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(
batch_param_names,
batch_sumsq_orig,
batch_rms_orig,
batch_grad,
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0),
)
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(),
key=lambda item: item[1][0],
reverse=True,
)
}
dominant_param_name = next(iter(sorted_by_proportion))
(
dominant_proportion,
dominant_sumsq,
dominant_rms,
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
Args:
group: dict to look up configuration values
p: parameter to update (actually multiple parameters stacked together
as a batch)
state: state-dict for p, to look up the optimizer state
"""
lr = group["lr"]
size_update_period = group["size_update_period"]
beta1 = group["betas"][0]
grad = p.grad
if clipping_scale != 1.0:
grad = grad * clipping_scale
step = state["step"]
delta = state["delta"]
delta.mul_(beta1)
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
self._size_update(group, scale_grads, p, state)
if numel == 1:
# For parameters with 1 element we just use regular Adam.
# Updates delta.
self._step_scalar(group, p, state)
else:
self._step(group, p, state)
state["step"] = step + 1
def _size_update(
self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict,
) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
gradient descent on underlying param and on scale, this function does the update
on `scale`.
Args:
group: dict to look up configuration values
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
grads w.r.t. the scales.
p: The parameter to update
state: The state-dict of p
"""
param_rms = state["param_rms"]
beta1, beta2 = group["betas"]
size_lr = group["lr"] * group["scalar_lr_scale"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
eps = group["eps"]
step = state["step"]
batch_size = p.shape[0]
size_update_period = scale_grads.shape[0]
# correct beta2 for the size update period: we will have
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
bias_correction2 = 1 - beta2_corr**size_step
# we don't bother with bias_correction1; this will help prevent divergence
# at the start of training.
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms
# when the param gets too small, just don't shrink it any further.
scale_step.masked_fill_(is_too_small, 0.0)
# when it gets too large, stop it from getting any larger.
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1 - beta1))
def _step(self, group: dict, p: Tensor, state: dict):
"""
This function does the core update of self.step(), in the case where the members of
the batch have more than 1 element.
Args:
group: A dict which will be used to look up configuration values
p: The parameter to be updated
grad: The grad of p
state: The state-dict corresponding to parameter p
This function modifies p.
"""
grad = p.grad
lr = group["lr"]
beta1, beta2 = group["betas"]
eps = group["eps"]
param_min_rms = group["param_min_rms"]
step = state["step"]
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
denom = exp_avg_sq.sqrt()
denom += eps
grad = grad / denom
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
delta = state["delta"]
delta.add_(grad * alpha)
p.add_(delta)
def _step_scalar(self, group: dict, p: Tensor, state: dict):
"""
A simplified form of the core update for scalar tensors, where we cannot get a good
estimate of the parameter rms.
"""
beta1, beta2 = group["betas"]
scalar_max = group["scalar_max"]
eps = group["eps"]
lr = group["lr"] * group["scalar_lr_scale"]
grad = p.grad
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"]
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
p.clamp_(min=-scalar_max, max=scalar_max)
p.add_(delta)
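
A minimal usage sketch for ScaledAdam (the model and hyper-parameters here are illustrative): the required parameters_names is one List[str] per param group, aligned one-to-one with that group's parameters, which named_parameters() yields in matching order.

import torch
import torch.nn as nn

model = nn.Linear(10, 10)
names = [[name for name, _ in model.named_parameters()]]
optimizer = ScaledAdam(
    model.parameters(),
    lr=3e-02,
    clipping_scale=2.0,
    parameters_names=names,
)
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()  # on the very first step clipping_scale is treated as 1.0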

View File

@@ -0,0 +1,428 @@
import math

import torch
from torch.nn.functional import *
from torch.nn.functional import (
    _mha_shape_check,
    _canonical_mask,
    _none_or_dtype,
    _in_projection,
    _in_projection_packed,
)
def multi_head_attention_forward_patched(
query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p: float,
out_proj_weight,
out_proj_bias,
training=True,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
use_separate_proj_weight=False,
q_proj_weight=None,
k_proj_weight=None,
v_proj_weight=None,
static_k=None,
static_v=None,
average_attn_weights=True,
is_causal=False,
cache=None,
):
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
Default: `True`
Note: `need_weights` defaults to `True`, but it should be set to `False`
for best performance when attention weights are not needed.
*Setting need_weights to `True`
leads to a significant performance degradation.*
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
is_causal: If specified, applies a causal mask as attention mask, and ignores
attn_mask for computing scaled dot product attention.
Default: ``False``.
.. warning::
is_causal provides a hint that the attn_mask is the
causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
use_separate_proj_weight: the function accepts the proj. weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
when ``need_weights=True``. Default: True
Shape:
Inputs:
- query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a FloatTensor is provided, it will be directly added to the value.
If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
Outputs:
- attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""
tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
tens_ops,
query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training=training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
is_causal=is_causal,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attn_weights,
cache=cache,
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
# batch dimension so that the output doesn't carry this temporary batch dimension.
if not is_batched:
# unsqueeze if the input is unbatched
query = query.unsqueeze(1)
key = key.unsqueeze(1)
value = value.unsqueeze(1)
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.unsqueeze(0)
# set up shape vars
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype,
)
if is_causal and attn_mask is None:
raise RuntimeError(
"Need attn_mask if specifying the is_causal hint. "
"You may use the Transformer module method "
"`generate_square_subsequent_mask` to create this mask."
)
if is_causal and key_padding_mask is None and not need_weights:
# when we have a kpm or need weights, we need attn_mask
# Otherwise, we pass the is_causal hint straight through as the
# is_causal indicator to SDPA.
attn_mask = None
else:
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False
assert embed_dim == embed_dim_to_check, (
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
)
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], (
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
)
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(
query,
key,
value,
q_proj_weight,
k_proj_weight,
v_proj_weight,
b_q,
b_k,
b_v,
)
if cache is not None:
    if cache["first_infer"] == 1:
        cache["k"][cache["stage"]] = k
        cache["v"][cache["stage"]] = v
    else:  # each of the layers keeps its own KV cache, one slot per stage
        # the time axis is normally dim 1, but the projection may have
        # transposed it, so time is dim 0 here; concatenate the new step along it
        cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]], k], 0)
        cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
    src_len = cache["k"][cache["stage"]].shape[0]
    k = cache["k"][cache["stage"]]
    v = cache["v"][cache["stage"]]
    cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
# prep attention mask
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=q.dtype,
check_other=False,
)
if attn_mask is not None:
# ensure attn_mask's dim is 3
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
assert bias_v is None
#
# reshape q, k, v for multihead attention and make em batch first
#
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, (
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
)
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, (
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
)
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
src_len = k.size(1)
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (
bsz,
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
else:
attn_mask = attn_mask + key_padding_mask
# adjust dropout probability
if not training:
dropout_p = 0.0
#
# (deep breath) calculate attention and out projection
#
if need_weights:
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
attn_output_weights = dropout(attn_output_weights, p=dropout_p)
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
# optionally average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(dim=1)
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, attn_output_weights
else:
# attn_mask can be either (L,S) or (N*num_heads, L, S)
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
# in order to match the input for SDPA of (N, num_heads, L, S)
if attn_mask is not None:
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
return attn_output, None
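
For reference, a sketch of the cache layout this patched forward expects, inferred from the accesses above (the field names come from the code; the sizes are illustrative):

n_layers = 24  # illustrative: one "stage" per attention layer sharing this dict
cache = {
    "all_stage": n_layers,   # total number of stages
    "stage": 0,              # round-robin index, advanced once per call
    "first_infer": 1,        # 1 on the first decode step (reset K/V), 0 afterwards
    "k": [None] * n_layers,  # per-layer key cache, time-major (S, N, E)
    "v": [None] * n_layers,  # per-layer value cache, same layout
}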

View File

@@ -0,0 +1,85 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn.functional import *
from torch.nn.functional import (
    _canonical_mask,
)
def multi_head_attention_forward_patched(
query,
key,
value,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight,
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
cache=None,
) -> Tensor:  # unlike the non-ONNX variant, only attn_output is returned
# set up shape vars
_, _, embed_dim = query.shape
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
head_dim = embed_dim // num_heads
proj_qkv = linear(query, in_proj_weight, in_proj_bias)
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
cache["v"][cache["stage"]] = v
else:
    # replace the last cached frame with the current step's K/V so the cached
    # length stays constant (fixed shapes keep the exported graph static)
    cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
    cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=q.dtype,
check_other=False,
)
attn_mask = attn_mask.unsqueeze(0)
q = q.view(-1, num_heads, head_dim).transpose(0, 1)
k = k.view(-1, num_heads, head_dim).transpose(0, 1)
v = v.view(-1, num_heads, head_dim).transpose(0, 1)
dropout_p = 0.0
attn_mask = attn_mask.unsqueeze(0)
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(-1, 1, attn_output.size(1))
return attn_output
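
A tiny standalone illustration of the fixed-length concatenation in the else branch above (shapes are illustrative):

import torch

cached = torch.zeros(8, 1, 512)  # pre-sized, time-major KV buffer
new = torch.randn(1, 1, 512)     # the current step's projected K (or V)
cached = torch.cat([cached[:-1], new], 0)
assert cached.shape[0] == 8      # length unchanged, so traced/exported shapes stay static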

View File

@@ -0,0 +1,320 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Optional
from typing import Tuple
import torch
import torch.nn as nn
from torch import Tensor
class DoubleSwishFunction(torch.autograd.Function):
"""
double_swish(x) = x * torch.sigmoid(x-1)
This is a definition, originally motivated by its close numerical
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
Memory-efficient derivative computation:
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
Now, s'(x) = s(x) * (1-s(x)).
double_swish'(x) = x * s'(x) + s(x).
= x * s(x) * (1-s(x)) + s(x).
= double_swish(x) * (1-s(x)) + s(x)
... so we just need to remember s(x) but not x itself.
"""
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
x_dtype = x.dtype
if x.dtype == torch.float16:
x = x.to(torch.float32)
s = torch.sigmoid(x - 1.0)
y = x * s
if requires_grad:
deriv = y * (1 - s) + s
# notes on derivative of x * sigmoid(x - 1):
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
assert d_scaled.max() < 256.0
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
return y
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
(d,) = ctx.saved_tensors
# the same constants as used in forward pass.
floor = -0.043637
ceil = 1.2
d = d * ((ceil - floor) / 255.0) + floor
return y_grad * d
class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
xgt0 = x > 0
if sign_factor is None:
ctx.save_for_backward(xgt0, scale_factor)
else:
ctx.save_for_backward(xgt0, scale_factor, sign_factor)
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
if len(ctx.saved_tensors) == 3:
xgt0, scale_factor, sign_factor = ctx.saved_tensors
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
scale_factor = scale_factor.unsqueeze(-1)
sign_factor = sign_factor.unsqueeze(-1)
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
else:
xgt0, scale_factor = ctx.saved_tensors
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return (
x_grad - neg_delta_grad,
None,
None,
None,
)
def _compute_scale_factor(
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
if min_abs == 0.0:
below_threshold = 0.0
else:
# below_threshold is 0 if x_abs_mean > min_abs; it can be at most max_factor,
# approached as x_abs_mean goes to 0.
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
return below_threshold - above_threshold
def _compute_sign_factor(
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
if min_positive == 0.0:
factor1 = 0.0
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if proportion_positive <= max_positive; else can contribute
# as much as -max_factor (it enters sign_factor with a minus sign).
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor
)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
return sign_factor
def _no_op(x: Tensor) -> Tensor:
    # Identity that still records a node in the autograd graph. Restored here
    # (assumed from the upstream icefall scaling.py this file derives from)
    # because ActivationBalancer.forward below references it; the chunk trick
    # works around issues with backward hooks.
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x
    else:
        return x.chunk(1, dim=-1)[0]
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for
each channel, that it is positive at least a proportion `threshold` of the
time. It does this by multiplying negative derivative values by up to
(1+max_factor), and positive derivative values by up to (1-max_factor),
interpolated from 1 at the threshold to those extremal values when none
of the inputs are positive.
Args:
num_channels: the number of channels
channel_dim: the dimension/axis corresponding to the channel, e.g.
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
min_positive: the minimum, per channel, of the proportion of the time
that (x > 0), below which we start to modify the derivatives.
max_positive: the maximum, per channel, of the proportion of the time
that (x > 0), above which we start to modify the derivatives.
max_factor: the maximum factor by which we modify the derivatives for
either the sign constraint or the magnitude constraint;
e.g. with max_factor=0.02, the derivatives would be multiplied by
values in the range [0.98..1.02].
sign_gain_factor: determines the 'gain' with which we increase the
change in gradient once the constraints on min_positive and max_positive
are violated.
scale_gain_factor: determines the 'gain' with which we increase the
change in gradient once the constraints on min_abs and max_abs
are violated.
min_abs: the minimum average-absolute-value difference from the mean
value per channel, which we allow, before we start to modify
the derivatives to prevent this.
max_abs: the maximum average-absolute-value difference from the mean
value per channel, which we allow, before we start to modify
the derivatives to prevent this.
min_prob: determines the minimum probability with which we modify the
gradients for the {min,max}_positive and {min,max}_abs constraints,
on each forward(). This is done randomly to prevent all layers
from doing it at the same time. Early in training we may use
higher probabilities than this; it will decay to this value.
"""
def __init__(
self,
num_channels: int,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.04,
sign_gain_factor: float = 0.01,
scale_gain_factor: float = 0.02,
min_abs: float = 0.2,
max_abs: float = 100.0,
min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.min_positive = min_positive
self.max_positive = max_positive
self.max_factor = max_factor
self.min_abs = min_abs
self.max_abs = max_abs
self.min_prob = min_prob
self.sign_gain_factor = sign_gain_factor
self.scale_gain_factor = scale_gain_factor
# count measures how many times the forward() function has been called.
# We occasionally sync this to a tensor called `count`, that exists to
# make sure it is synced to disk when we load and save the model.
self.cpu_count = 0
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
return _no_op(x)
count = self.cpu_count
self.cpu_count += 1
if random.random() < 0.01:
# Occasionally sync self.cpu_count with self.count.
# count affects the decay of 'prob'. don't do this on every iter,
# because syncing with the GPU is slow.
self.cpu_count = max(self.cpu_count, self.count.item())
self.count.fill_(self.cpu_count)
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
if random.random() < prob:
if self.min_positive != 0.0 or self.max_positive != 1.0:
sign_factor = _compute_sign_factor(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor,
)
else:
sign_factor = None
scale_factor = _compute_scale_factor(
x.detach(),
self.channel_dim,
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply(
x,
scale_factor,
sign_factor,
self.channel_dim,
)
else:
return _no_op(x)
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
return nn.Sequential(
balancer,
DoubleSwish(),
)
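
A quick numerical sanity check for the pieces above (CPU, float32; shapes are illustrative). In the forward pass DoubleSwish computes exactly x * sigmoid(x - 1); only the saved derivative is uint8-quantized for the backward pass, and ActivationBalancer only touches gradients:

import torch

x = torch.randn(4, 16, 256, requires_grad=True)
m = BalancedDoubleSwish(256)  # ActivationBalancer (identity in forward) -> DoubleSwish
y = m(x)
assert torch.allclose(y, x * torch.sigmoid(x - 1.0))
y.sum().backward()            # backward uses the quantized derivative
print(x.grad.abs().mean())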

View File

@@ -0,0 +1,362 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
import copy
import numbers
from functools import partial
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from AR.modules.activation import MultiheadAttention
from AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch.nn import functional as F
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
assert embedding is None
return input
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
return_layer_states: return layers' state (optional).
Shape:
see the docs in Transformer class.
"""
if return_layer_states:
layer_states = [] # layers' output
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
layer_states.append(output[0])
if self.norm is not None:
output = self.norm(output)
return layer_states, output
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model, # 512 16
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
activation = activation(d_model)
elif activation == BalancedDoubleSwish:
activation = BalancedDoubleSwish(d_model)
# # We can't test self.activation in forward() in TorchScript,
# # so stash some information about it instead.
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
# self.activation_relu_or_gelu = 1
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
# self.activation_relu_or_gelu = 2
# else:
# self.activation_relu_or_gelu = 0
self.activation = activation
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
else:
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
else:
self.norm1 = norm1
self.norm2 = norm2
def __setstate__(self, state):
super(TransformerEncoderLayer, self).__setstate__(state)
if not hasattr(self, "activation"):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
x, stage_embedding = src, None
is_src_tuple = False
if isinstance(src, tuple):
x, stage_embedding = src
is_src_tuple = True
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
raise AssertionError("only bool and floating types of key_padding_mask are supported")
if self.norm_first:
x = x + self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,
cache=cache,
)
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
return (x, stage_embedding)
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
cache=cache,
)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class AdaptiveLayerNorm(nn.Module):
r"""Adaptive Layer Normalization"""
def __init__(self, d_model, norm) -> None:
super(AdaptiveLayerNorm, self).__init__()
self.project_layer = nn.Linear(d_model, 2 * d_model)
self.norm = norm
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Optional[Tensor] = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return weight * self.norm(input) + bias
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
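
A short sketch of the AdaptiveLayerNorm contract above (shapes are illustrative): the conditioning embedding is projected to a per-element scale and shift applied to the normalized input.

import torch

d_model = 512
aln = AdaptiveLayerNorm(d_model, LayerNorm(d_model))
x = torch.randn(10, 2, d_model)    # (seq, batch, d_model)
emb = torch.randn(10, 2, d_model)  # conditioning embedding, same shape
y = aln(x, emb)                    # weight * norm(x) + bias
assert y.shape == x.shape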

View File

@@ -0,0 +1,281 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
import copy
import numbers
from functools import partial
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from AR.modules.activation_onnx import MultiheadAttention
from AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch.nn import functional as F
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
assert embedding is None
return input
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
enable_nested_tensor: if True, input will automatically convert to nested tensor
(and convert back on output). This will improve the overall performance of
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
cache=None,
) -> Tensor:
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
if isinstance(activation, str):
activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
activation = activation(d_model)
elif activation == BalancedDoubleSwish:
activation = BalancedDoubleSwish(d_model)
self.activation = activation
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
else:
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
else:
self.norm1 = norm1
self.norm2 = norm2
def __setstate__(self, state):
super(TransformerEncoderLayer, self).__setstate__(state)
if not hasattr(self, "activation"):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
x = src
stage_embedding = None
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
return x
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
cache=cache,
)
return self.dropout1(x)
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class AdaptiveLayerNorm(nn.Module):
r"""Adaptive Layer Normalization"""
def __init__(self, d_model, norm) -> None:
super(AdaptiveLayerNorm, self).__init__()
self.project_layer = nn.Linear(d_model, 2 * d_model)
self.norm = norm
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return weight * self.norm(input) + bias
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

View File

@@ -0,0 +1,72 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
# reference: https://github.com/lifeiteng/vall-e
import itertools
import re
from typing import Dict
from typing import List
import regex
from gruut import sentences
from gruut.const import Sentence
from gruut.const import Word
from AR.text_processing.symbols import SYMBOL_TO_ID
class GruutPhonemizer:
def __init__(self, language: str):
self._phonemizer = sentences
self.lang = language
self.symbol_to_id = SYMBOL_TO_ID
self._special_cases_dict: Dict[str, str] = {
r"\.\.\.": "... ",
";": "; ",
":": ": ",
",": ", ",
r"\.": ". ",
"!": "! ",
r"\?": "? ",
"": "",
"": "",
"«": "«",
"»": "»",
}
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(r"\pZ+", r" ", text)
return text.strip()
def _convert_punctuation(self, word: Word) -> str:
if not word.phonemes:
return ""
if word.phonemes[0] in ["‖", "|"]:
return word.text.strip()
phonemes = "".join(word.phonemes)
# remove modifier characters ˈˌː with regex
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
return phonemes.strip()
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
return " ".join(words)
def transform(self, phonemes):
# convert phonemes to ids
# dictionary is in symbols.py
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
if __name__ == "__main__":
phonemizer = GruutPhonemizer("en-us")
# text -> IPA
phonemes = phonemizer.phonemize("Hello, wor-ld ?")
print("phonemes:", phonemes)
print("len(phonemes):", len(phonemes))
phoneme_ids = phonemizer.transform(phonemes)
print("phoneme_ids:", phoneme_ids)
print("len(phoneme_ids):", len(phoneme_ids))

View File

@@ -0,0 +1,12 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
# reference: https://github.com/lifeiteng/vall-e
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}

View File

@@ -0,0 +1,36 @@
import re
def str2bool(s):
return s.lower() == "true"
def get_newest_ckpt(string_list):
# Regex pattern matching the epoch and step numbers in a checkpoint filename
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
# Extract the numbers from each string and build a list of (epoch, step, name) tuples
extracted_info = []
for string in string_list:
match = re.match(pattern, string)
if match:
epoch = int(match.group(1))
step = int(match.group(2))
extracted_info.append((epoch, step, string))
# Sort by epoch, then by step, in descending order
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
# The newest ckpt filename comes first
newest_ckpt = sorted_info[0][2]
return newest_ckpt
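# Example (hypothetical filenames):
#   get_newest_ckpt(["epoch=3-step=500.ckpt", "epoch=10-step=20.ckpt"])
#   -> "epoch=10-step=20.ckpt" (sorted by epoch first, then step)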
# Return the first line of the file if it exists and is non-empty, else False
def check_txt_file(file_path):
try:
with open(file_path, "r") as file:
text = file.readline().strip()
assert text != ""
return text
except Exception:
return False

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""Initialize modules for espnet2 neural networks."""
import torch
from typeguard import check_argument_types
def initialize(model: torch.nn.Module, init: str):
"""Initialize weights of a neural network module.
Parameters are initialized using the given method or distribution.
Custom initialization routines can be implemented into submodules
as function `espnet_initialization_fn` within the custom module.
Args:
model: Target.
init: Method of initialization.
"""
assert check_argument_types()
print("init with", init)
# weight init
for p in model.parameters():
if p.dim() > 1:
if init == "xavier_uniform":
torch.nn.init.xavier_uniform_(p.data)
elif init == "xavier_normal":
torch.nn.init.xavier_normal_(p.data)
elif init == "kaiming_uniform":
torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
elif init == "kaiming_normal":
torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
else:
raise ValueError("Unknown initialization: " + init)
# bias init
for name, p in model.named_parameters():
if ".bias" in name and p.dim() == 1:
p.data.zero_()
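# Usage sketch (any torch.nn.Module works as the target):
#   model = torch.nn.Linear(256, 256)
#   initialize(model, "xavier_uniform")  # re-initializes weights, zeroes biases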

View File

@@ -0,0 +1,30 @@
import sys
import torch
import yaml
def load_yaml_config(path):
with open(path) as f:
config = yaml.full_load(f)
return config
def save_config_to_yaml(config, path):
assert path.endswith(".yaml")
with open(path, "w") as f:
f.write(yaml.dump(config))
def write_args(args, path):
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write("\n==> args:\n")
for k, v in sorted(args_dict.items()):
args_file.write(" %s: %s\n" % (str(k), str(v)))
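# Usage sketch (with an argparse.Namespace `args`):
#   config = load_yaml_config("exp/config.yaml")
#   write_args(args, "exp/args.txt")  # appends torch/cudnn versions, argv, and all args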

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 NVIDIA CORPORATION.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,266 @@
## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
## News
- **Sep 2024 (v2.4):**
- We have updated the pretrained checkpoints trained for 5M steps. This is the final release of the BigVGAN-v2 checkpoints.
- **Jul 2024 (v2.3):**
- General refactor and code improvements for improved readability.
- Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with inference speed benchmark.
- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
- Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
- Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
- Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
- We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
## Installation
The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
```shell
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
conda activate bigvgan
```
Clone the repository and install dependencies:
```shell
git clone https://github.com/NVIDIA/BigVGAN
cd BigVGAN
pip install -r requirements.txt
```
## Inference Quickstart using 🤗 Hugging Face Hub
The example below shows how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and generate a synthesized waveform using the mel spectrogram as the model's input.
```python
device = 'cuda'
import torch
import bigvgan
import librosa
from meldataset import get_mel_spectrogram
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
# remove weight norm in the model and set to eval mode
model.remove_weight_norm()
model = model.eval().to(device)
# load wav file and compute mel spectrogram
wav_path = '/path/to/your/audio.wav'
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
# compute mel spectrogram from the ground truth audio
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
# generate waveform from mel
with torch.inference_mode():
wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
# you can convert the generated waveform to 16 bit linear PCM
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
```
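To write the result to disk, one option (an assumption beyond this example: `scipy` must be installed) is `scipy.io.wavfile.write`, which accepts int16 PCM directly:
```python
from scipy.io.wavfile import write as write_wav

# wav_gen_int16 has shape [1, T_time]; scipy expects a 1-D array for mono audio
write_wav('output.wav', model.h.sampling_rate, wav_gen_int16[0])
```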
## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
You can run a local gradio demo using the commands below:
```shell
pip install -r demo/requirements.txt
python demo/app.py
```
## Training
Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
```shell
cd filelists/LibriTTS && \
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
ln -s /path/to/your/LibriTTS/test-other test-other && \
cd ../..
```
Train the BigVGAN model. Below is an example command for training BigVGAN-v2 on the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input:
```shell
python train.py \
--config configs/bigvgan_v2_24khz_100band_256x.json \
--input_wavs_dir filelists/LibriTTS \
--input_training_file filelists/LibriTTS/train-full.txt \
--input_validation_file filelists/LibriTTS/val-full.txt \
--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
```
## Synthesis
Synthesize from a trained BigVGAN model. Below is an example command for generating audio from the model.
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
```shell
python inference.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_wavs_dir /path/to/your/input_wav \
--output_dir /path/to/your/output_wav
```
`inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
Make sure that the STFT hyperparameters for the mel spectrogram match those of the model, which are defined in the `config.json` of the corresponding model.
```shell
python inference_e2e.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_mels_dir /path/to/your/input_mel \
--output_dir /path/to/your/output_wav
```
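If you only have waveforms, below is a small sketch for producing compatible `.npy` mel inputs, reusing `get_mel_spectrogram` from the quickstart above. It assumes `h` holds the hyperparameters from the model's `config.json` (the quickstart exposes them as `model.h`); file paths are placeholders:
```python
import numpy as np
import librosa
import torch
from meldataset import get_mel_spectrogram

wav, sr = librosa.load('/path/to/your/audio.wav', sr=h.sampling_rate, mono=True)
mel = get_mel_spectrogram(torch.FloatTensor(wav).unsqueeze(0), h)  # [1, C_mel, T_frame]
np.save('/path/to/your/input_mel/audio.npy', mel.cpu().numpy())
```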
## Using Custom CUDA Kernel for Synthesis
You can apply the fast CUDA inference kernel by passing the `use_cuda_kernel` parameter when instantiating BigVGAN:
```python
generator = BigVGAN(h, use_cuda_kernel=True)
```
You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
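For example:
```shell
python inference.py \
--checkpoint_file /path/to/your/bigvgan_generator.pt \
--input_wavs_dir /path/to/your/input_wav \
--output_dir /path/to/your/output_wav \
--use_cuda_kernel
```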
When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
Please make sure that both `nvcc` and `ninja` are installed on your system and that your `nvcc` version matches the CUDA version your PyTorch build is using.
We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See the example command below and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
```shell
python tests/test_cuda_vs_torch_model.py \
--checkpoint_file /path/to/your/bigvgan_generator.pt
```
```shell
loading plain Pytorch BigVGAN
...
loading CUDA kernel BigVGAN with auto-build
Detected CUDA files, patching ldflags
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
Building extension module anti_alias_activation_cuda...
...
Loading extension module anti_alias_activation_cuda...
...
Loading '/path/to/your/bigvgan_generator.pt'
...
[Success] test CUDA fused vs. plain torch BigVGAN inference
> mean_difference=0.0007238413265440613
...
```
If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
## Pretrained Models
We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
Note that the checkpoints use the `snakebeta` activation with log-scale parameterization, which gives the best overall quality.
You can fine-tune the models by:
1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`, as sketched in the example below
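Below is a sketch of such a resume command. The dataset paths and the `exp/my_finetune` directory (which holds both downloaded checkpoints) are placeholders, and the unseen-validation flags from the training example above are omitted for brevity:
```shell
python train.py \
--config configs/bigvgan_v2_24khz_100band_256x.json \
--input_wavs_dir filelists/MyDataset \
--input_training_file filelists/MyDataset/train.txt \
--input_validation_file filelists/MyDataset/val.txt \
--checkpoint_path exp/my_finetune
```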
## Training Details of BigVGAN-v2
Compared to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and were trained using 8 A100 GPUs.
Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` by default to fit a single A100 GPU for training. You can adjust `batch_size` to match your GPUs when fine-tuning.
When training BigVGAN-v2 from scratch with a small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such cases, we recommend lowering the `clip_grad_norm` value (e.g., to `100`) for the early training iterations (e.g., the first 20000 steps) and then raising it back to the default `500`.
## Evaluation Results of BigVGAN-v2
Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements across the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:|
| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
## Speed Benchmark
Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
| | | True | 3916.5 | 163.2x | 1.3 |
| | 2048 | False | 1899.6 | 79.2x | 1.7 |
| | | True | 5330.1 | 222.1x | 1.7 |
| | 16384 | False | 1973.8 | 82.2x | 5.0 |
| | | True | 5761.7 | 240.1x | 4.4 |
| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
| | | True | 1598.1 | 66.6x | 1.3 |
| | 2048 | False | 929.9 | 38.7x | 1.7 |
| | | True | 1971.3 | 82.1x | 1.6 |
| | 16384 | False | 943.4 | 39.3x | 5.0 |
| | | True | 2026.5 | 84.4x | 3.9 |
| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
| | | True | 811.3 | 33.8x | 1.3 |
| | 2048 | False | 576.5 | 24.0x | 1.7 |
| | | True | 1023.0 | 42.6x | 1.5 |
| | 16384 | False | 589.4 | 24.6x | 5.0 |
| | | True | 1068.1 | 44.5x | 3.2 |
## Acknowledgements
We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
## References
- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)

View File

@@ -0,0 +1,122 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
class Snake(nn.Module):
"""
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = Snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
"""
super(Snake, self).__init__()
self.in_features = in_features
# Initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # Log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
else: # Linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = SnakeBeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super(SnakeBeta, self).__init__()
self.in_features = in_features
# Initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # Log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
self.beta = Parameter(torch.zeros(in_features) * alpha)
else: # Linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.beta = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x

View File

@@ -0,0 +1,69 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load
anti_alias_activation_cuda = load.load()
class FusedAntiAliasActivation(torch.autograd.Function):
"""
Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
The hyperparameters are hard-coded in the kernel to maximize speed.
NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
"""
@staticmethod
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
return activation_results
@staticmethod
def backward(ctx, output_grads):
# The fused kernel is inference-only; backward is not implemented.
raise NotImplementedError
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12,
fused: bool = True,
):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
self.fused = fused # Whether to use fused CUDA kernel or not
def forward(self, x):
if not self.fused:
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
else:
if self.act.__class__.__name__ == "Snake":
beta = self.act.alpha.data # Snake uses same params for alpha and beta
else:
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
alpha = self.act.alpha.data
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
alpha = torch.log(alpha)
beta = torch.log(beta)
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
return x

View File

@@ -0,0 +1,23 @@
/* coding=utf-8
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}

View File

@@ -0,0 +1,246 @@
/* coding=utf-8
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace
{
// Hard-coded hyperparameters
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
constexpr int BUFFER_SIZE = 32;
constexpr int FILTER_SIZE = 12;
constexpr int HALF_FILTER_SIZE = 6;
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
template <typename input_t, typename output_t, typename acc_t>
__global__ void anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
int batch_size,
int channels,
int seq_len)
{
// Up and downsample filters
input_t up_filter[FILTER_SIZE];
input_t down_filter[FILTER_SIZE];
// Load data from global memory including extra indices reserved for replication paddings
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
// Output stores downsampled output before writing to dst
output_t output[BUFFER_SIZE];
// blockDim/threadIdx = (128, 1, 1)
// gridDim/blockIdx = (seq_blocks, channels, batches)
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
int local_offset = threadIdx.x * BUFFER_SIZE;
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
// intermediate have double the seq_len
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
// Get values needed for replication padding before moving pointer
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
input_t seq_left_most_value = right_most_pntr[0];
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
// Move src and dst pointers
src += block_offset + local_offset;
dst += block_offset + local_offset;
// Alpha and beta values for snake activations. Applies exp by default
alpha = alpha + blockIdx.y;
input_t alpha_val = expf(alpha[0]);
beta = beta + blockIdx.y;
input_t beta_val = expf(beta[0]);
#pragma unroll
for (int it = 0; it < FILTER_SIZE; it += 1)
{
up_filter[it] = up_ftr[it];
down_filter[it] = down_ftr[it];
}
// Apply replication padding for upsampling, matching torch impl
#pragma unroll
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
{
int element_index = seq_offset + it; // index for element
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
}
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
}
if ((element_index >= 0) && (element_index < seq_len))
{
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
}
}
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
#pragma unroll
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
{
input_t acc = 0.0;
int element_index = intermediate_seq_offset + it; // index for intermediate
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
{
if ((element_index + f_idx) >= 0)
{
acc += up_filter[f_idx] * elements[it + f_idx];
}
}
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
}
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
double no_div_by_zero = 0.000000001;
#pragma unroll
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
{
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
}
// Apply replication padding before downsampling conv from intermediates
#pragma unroll
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
{
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
}
#pragma unroll
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
{
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
}
// Apply downsample strided convolution (assuming stride=2) from intermediates
#pragma unroll
for (int it = 0; it < BUFFER_SIZE; it += 1)
{
input_t acc = 0.0;
#pragma unroll
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
{
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
}
output[it] = acc;
}
// Write output to dst
#pragma unroll
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
{
int element_index = seq_offset + it;
if (element_index < seq_len)
{
dst[it] = output[it];
}
}
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_anti_alias_activation_forward(
output_t *dst,
const input_t *src,
const input_t *up_ftr,
const input_t *down_ftr,
const input_t *alpha,
const input_t *beta,
int batch_size,
int channels,
int seq_len)
{
if (seq_len == 0)
{
return;
}
else
{
// Use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
constexpr int seq_len_per_block = 4096;
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
dim3 blocks(blocks_per_seq_len, channels, batch_size);
dim3 threads(threads_per_block, 1, 1);
anti_alias_activation_forward<input_t, output_t, acc_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
}
}
}
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
const int batches = input.size(0);
const int channels = input.size(1);
const int seq_len = input.size(2);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor anti_alias_activation_results =
torch::empty({batches, channels, seq_len}, act_options);
void *input_ptr = static_cast<void *>(input.data_ptr());
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
void *beta_ptr = static_cast<void *>(beta.data_ptr());
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch anti alias activation_forward",
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
reinterpret_cast<const scalar_t *>(input_ptr),
reinterpret_cast<const scalar_t *>(up_filter_ptr),
reinterpret_cast<const scalar_t *>(down_filter_ptr),
reinterpret_cast<const scalar_t *>(alpha_ptr),
reinterpret_cast<const scalar_t *>(beta_ptr),
batches,
channels,
seq_len););
return anti_alias_activation_results;
}

View File

@@ -0,0 +1,29 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import pathlib
import subprocess
from torch.utils import cpp_extension
"""
Setting this param to a list can generate different compilation commands (with a different order of architectures), leading to recompilation of the fused kernels.
Set it to an empty string to avoid recompilation and assign the arch flags explicitly in extra_cuda_cflags below.
"""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def load():
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / "build"
_create_build_dir(buildpath)
# Helper function to build the kernels.
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=[
"-O3",
],
extra_cuda_cflags=[
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"--use_fast_math",
]
+ extra_cuda_flags
+ cc_flag,
verbose=True,
)
extra_cuda_flags = [
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
]
sources = [
srcpath / "anti_alias_activation.cpp",
srcpath / "anti_alias_activation_cuda.cu",
]
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
return anti_alias_activation_cuda
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")

View File

@@ -0,0 +1,92 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch (TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}

View File

@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
from .filter import *
from .resample import *
from .act import *

View File

@@ -0,0 +1,30 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12,
):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
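# Usage sketch: wrap any pointwise activation so it runs at 2x the sample rate
# (upsample -> activation -> downsample) to suppress aliasing:
#   act = Activation1d(activation=torch.nn.SiLU())
#   y = act(torch.randn(1, 8, 100))  # output keeps the [B, C, T] shape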

View File

@@ -0,0 +1,99 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if "sinc" in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different from julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0
half_size = kernel_size // 2
# For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
"""
Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
"""
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
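# Example: kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
# returns a [1, 1, 12] tensor of low-pass FIR taps normalized to sum to 1.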
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = "replicate",
kernel_size: int = 12,
):
"""
kernel_size should be an even number for the StyleGAN3 setup; in this implementation, an odd number is also possible.
"""
super().__init__()
if cutoff < 0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
# Input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
return out

View File

@@ -0,0 +1,48 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
return self.lowpass(x)

View File

@@ -0,0 +1,461 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import os
import json
from pathlib import Path
from typing import Optional, Union, Dict
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
from . import activations
from .utils0 import init_weights, get_padding
from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
from .env import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
def load_hparams_from_json(path) -> AttrDict:
with open(path) as f:
data = f.read()
return AttrDict(json.loads(data))
class AMPBlock1(torch.nn.Module):
"""
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
AMPBlock1 additionally contains self.convs2, extra Conv1d layers with a fixed dilation=1 that follow each layer in self.convs1
Args:
h (AttrDict): Hyperparameters.
channels (int): Number of convolution channels.
kernel_size (int): Size of the convolution kernel. Default is 3.
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
"""
def __init__(
self,
h: AttrDict,
channels: int,
kernel_size: int = 3,
dilation: tuple = (1, 3, 5),
activation: str = None,
):
super().__init__()
self.h = h
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=get_padding(kernel_size, d),
)
)
for d in dilation
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
)
for _ in range(len(dilation))
]
)
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
# Select which Activation1d to use; lazily import the CUDA version to keep backward compatibility
if self.h.get("use_cuda_kernel", False):
from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
Activation1d = CudaActivation1d
else:
Activation1d = TorchActivation1d
# Activation functions
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
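# Editor-added sketch (not in the original file): AMPBlock1 is shape-preserving,
# so it can be sanity-checked standalone with minimal hyperparameters.
if __name__ == "__main__":
    _h = AttrDict({"snake_logscale": True})
    _block = AMPBlock1(_h, channels=64, activation="snake")
    _x = torch.randn(1, 64, 100)         # [B, C, T]
    assert _block(_x).shape == _x.shape  # residual conv stack keeps [B, C, T]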
class AMPBlock2(torch.nn.Module):
"""
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
Args:
h (AttrDict): Hyperparameters.
channels (int): Number of convolution channels.
kernel_size (int): Size of the convolution kernel. Default is 3.
dilation (tuple): Dilation rates for the convolutions. Each dilation rate has a single convolution. Default is (1, 3, 5).
activation (str): Activation function type; must be either 'snake' or 'snakebeta' (anything else, including the default None, raises NotImplementedError).
"""
def __init__(
self,
h: AttrDict,
channels: int,
kernel_size: int = 3,
dilation: tuple = (1, 3, 5),
activation: str = None,
):
super().__init__()
self.h = h
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=get_padding(kernel_size, d),
)
)
for d in dilation
]
)
self.convs.apply(init_weights)
self.num_layers = len(self.convs) # Total number of conv layers
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
Activation1d = CudaActivation1d
else:
Activation1d = TorchActivation1d
# Activation functions
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class BigVGAN(
torch.nn.Module,
PyTorchModelHubMixin,
# library_name="bigvgan",
# repo_url="https://github.com/NVIDIA/BigVGAN",
# docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
# pipeline_tag="audio-to-audio",
# license="mit",
# tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
):
"""
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
Args:
h (AttrDict): Hyperparameters.
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
Note:
- Training with CUDA kernels is not supported; enable `use_cuda_kernel` for inference only.
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
"""
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
super().__init__()
self.h = h
self.h["use_cuda_kernel"] = use_cuda_kernel
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
Activation1d = CudaActivation1d
else:
Activation1d = TorchActivation1d
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# Pre-conv
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
if h.resblock == "1":
resblock_class = AMPBlock1
elif h.resblock == "2":
resblock_class = AMPBlock2
else:
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
# Transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList(
[
weight_norm(
ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
]
)
)
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
# Post-conv
activation_post = (
activations.Snake(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snake"
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
)
if activation_post is None:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
self.activation_post = Activation1d(activation=activation_post)
# Whether to use bias for the final conv_post. Default to True for backward compatibility
self.use_bias_at_final = h.get("use_bias_at_final", True)
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
# Weight initialization
for i in range(len(self.ups)):
self.ups[i].apply(init_weights)
self.conv_post.apply(init_weights)
# Final tanh activation. Defaults to True for backward compatibility
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
def forward(self, x):
# Pre-conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# Upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# Post-conv
x = self.activation_post(x)
x = self.conv_post(x)
# Final tanh activation
if self.use_tanh_at_final:
x = torch.tanh(x)
else:
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
return x
def remove_weight_norm(self):
try:
# print("Removing weight norm...")
for l in self.ups:
for l_i in l:
remove_weight_norm(l_i)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
except ValueError:
print("[INFO] Model already removed weight norm. Skipping!")
pass
# Additional methods for huggingface_hub support
def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights and config.json from a Pytorch model to a local directory."""
model_path = save_directory / "bigvgan_generator.pt"
torch.save({"generator": self.state_dict()}, model_path)
config_path = save_directory / "config.json"
with open(config_path, "w") as config_file:
json.dump(self.h, config_file, indent=4)
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: str,
cache_dir: str,
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Union[str, bool, None],
map_location: str = "cpu", # Additional argument
strict: bool = False, # Additional argument
use_cuda_kernel: bool = False,
**model_kwargs,
):
"""Load Pytorch pretrained weights and return the loaded model."""
# Download and load hyperparameters (h) used by BigVGAN
if os.path.isdir(model_id):
# print("Loading config.json from local directory")
config_file = os.path.join(model_id, "config.json")
else:
config_file = hf_hub_download(
repo_id=model_id,
filename="config.json",
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
h = load_hparams_from_json(config_file)
# instantiate BigVGAN using h
if use_cuda_kernel:
print(
"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
)
print(
"[WARNING] Building the custom kernel requires nvcc and ninja installed on your system, matching the CUDA version your PyTorch build uses. Otherwise, the model will fail to initialize or will generate incorrect waveforms!"
)
print(
"[WARNING] For details, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
)
model = cls(h, use_cuda_kernel=use_cuda_kernel)
# Download and load pretrained generator weight
if os.path.isdir(model_id):
# print("Loading weights from local directory")
model_file = os.path.join(model_id, "bigvgan_generator.pt")
else:
# print(f"Loading weights from {model_id}")
model_file = hf_hub_download(
repo_id=model_id,
filename="bigvgan_generator.pt",
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
checkpoint_dict = torch.load(model_file, map_location=map_location)
try:
model.load_state_dict(checkpoint_dict["generator"])
except RuntimeError:
print(
"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
)
model.remove_weight_norm()
model.load_state_dict(checkpoint_dict["generator"])
return model
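# Editor-added usage sketch (not part of the original file). The repo id below is
# illustrative, standing in for any checkpoint published as config.json plus
# bigvgan_generator.pt.
if __name__ == "__main__":
    model = BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x")  # illustrative id
    model.remove_weight_norm()  # fold weight norm into the weights for inference
    model.eval()
    mel = torch.randn(1, model.h.num_mels, 200)  # [B, num_mels, frames]
    with torch.no_grad():
        wav = model(mel)  # [B, 1, frames * prod(upsample_rates)]
    print(wav.shape)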

View File

@@ -0,0 +1,45 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}
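One editor-added sanity check that applies to this config and the variants below (the file name is illustrative): the product of upsample_rates must equal hop_size, so that each mel frame expands into exactly one hop of waveform samples.
import json
import math

h = json.load(open("bigvgan_22khz_80band.json"))  # illustrative file name
assert math.prod(h["upsample_rates"]) == h["hop_size"]  # 4*4*2*2*2*2 == 256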

View File

@@ -0,0 +1,45 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 100,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 24000,
"fmin": 0,
"fmax": 12000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,45 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,45 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 100,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 24000,
"fmin": 0,
"fmax": 12000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,61 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": null,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,61 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,61 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 100,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 24000,
"fmin": 0,
"fmax": null,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,61 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 128,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 44100,
"fmin": 0,
"fmax": null,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,61 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [8,4,2,2,2,2],
"upsample_kernel_sizes": [16,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 128,
"num_freq": 2049,
"n_fft": 2048,
"hop_size": 512,
"win_size": 2048,
"sampling_rate": 44100,
"fmin": 0,
"fmax": null,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@@ -0,0 +1,625 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm, spectral_norm
from torchaudio.transforms import Spectrogram, Resample
from env import AttrDict
from utils import get_padding
import typing
from typing import List, Tuple
class DiscriminatorP(torch.nn.Module):
def __init__(
self,
h: AttrDict,
period: int,
kernel_size: int = 5,
stride: int = 3,
use_spectral_norm: bool = False,
):
super().__init__()
self.period = period
self.d_mult = h.discriminator_channel_mult
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.ModuleList(
[
norm_f(
Conv2d(
1,
int(32 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(32 * self.d_mult),
int(128 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(128 * self.d_mult),
int(512 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(512 * self.d_mult),
int(1024 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(1024 * self.d_mult),
int(1024 * self.d_mult),
(kernel_size, 1),
1,
padding=(2, 0),
)
),
]
)
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, h: AttrDict):
super().__init__()
self.mpd_reshapes = h.mpd_reshapes
print(f"mpd_reshapes: {self.mpd_reshapes}")
self.discriminators = nn.ModuleList(
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(self, cfg: AttrDict, resolution: List[int]):
super().__init__()
self.resolution = resolution
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
self.lrelu_slope = 0.1
norm_f = spectral_norm if cfg.use_spectral_norm else weight_norm
if hasattr(cfg, "mrd_use_spectral_norm"):
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
norm_f = spectral_norm if cfg.mrd_use_spectral_norm else weight_norm
self.d_mult = cfg.discriminator_channel_mult
if hasattr(cfg, "mrd_channel_mult"):
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
self.d_mult = cfg.mrd_channel_mult
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 9),
stride=(1, 2),
padding=(1, 4),
)
),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 9),
stride=(1, 2),
padding=(1, 4),
)
),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 9),
stride=(1, 2),
padding=(1, 4),
)
),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 3),
padding=(1, 1),
)
),
]
)
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
x = self.spectrogram(x)
x = x.unsqueeze(1)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, self.lrelu_slope)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
n_fft, hop_length, win_length = self.resolution
x = F.pad(
x,
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
x = x.squeeze(1)
x = torch.stft(
x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=False,
return_complex=True,
)
x = torch.view_as_real(x) # [B, F, TT, 2]
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
return mag
class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False):
super().__init__()
self.resolutions = cfg.resolutions
assert len(self.resolutions) == 3, (
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
)
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorB(nn.Module):
def __init__(
self,
window_length: int,
channels: int = 32,
hop_factor: float = 0.25,
bands: Tuple[Tuple[float, float], ...] = (
(0.0, 0.1),
(0.1, 0.25),
(0.25, 0.5),
(0.5, 0.75),
(0.75, 1.0),
),
):
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.spec_fn = Spectrogram(
n_fft=window_length,
hop_length=int(window_length * hop_factor),
win_length=window_length,
power=None,
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.spec_fn(x)
x = torch.view_as_real(x)
x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
return x_bands
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x_bands = self.spectrogram(x.squeeze(1))
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for i, layer in enumerate(stack):
band = layer(band)
band = torch.nn.functional.leaky_relu(band, 0.1)
if i > 0:
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
x = self.conv_post(x)
fmap.append(x)
return x, fmap
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiBandDiscriminator(nn.Module):
def __init__(
self,
h,
):
"""
Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec,
and the modified code adapted from https://github.com/gemelo-ai/vocos.
"""
super().__init__()
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorCQT(nn.Module):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
super().__init__()
self.cfg = cfg
self.filters = cfg["cqtd_filters"]
self.max_filters = cfg["cqtd_max_filters"]
self.filters_scale = cfg["cqtd_filters_scale"]
self.kernel_size = (3, 9)
self.dilations = cfg["cqtd_dilations"]
self.stride = (1, 2)
self.in_channels = cfg["cqtd_in_channels"]
self.out_channels = cfg["cqtd_out_channels"]
self.fs = cfg["sampling_rate"]
self.hop_length = hop_length
self.n_octaves = n_octaves
self.bins_per_octave = bins_per_octave
# Lazy-load
from nnAudio import features
self.cqt_transform = features.cqt.CQT2010v2(
sr=self.fs * 2,
hop_length=self.hop_length,
n_bins=self.bins_per_octave * self.n_octaves,
bins_per_octave=self.bins_per_octave,
output_format="Complex",
pad_mode="constant",
)
self.conv_pres = nn.ModuleList()
for _ in range(self.n_octaves):
self.conv_pres.append(
nn.Conv2d(
self.in_channels * 2,
self.in_channels * 2,
kernel_size=self.kernel_size,
padding=self.get_2d_padding(self.kernel_size),
)
)
self.convs = nn.ModuleList()
self.convs.append(
nn.Conv2d(
self.in_channels * 2,
self.filters,
kernel_size=self.kernel_size,
padding=self.get_2d_padding(self.kernel_size),
)
)
in_chs = min(self.filters_scale * self.filters, self.max_filters)
for i, dilation in enumerate(self.dilations):
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
self.convs.append(
weight_norm(
nn.Conv2d(
in_chs,
out_chs,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=(dilation, 1),
padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
)
)
)
in_chs = out_chs
out_chs = min(
(self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
self.max_filters,
)
self.convs.append(
weight_norm(
nn.Conv2d(
in_chs,
out_chs,
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
)
)
)
self.conv_post = weight_norm(
nn.Conv2d(
out_chs,
self.out_channels,
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
)
)
self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
if self.cqtd_normalize_volume:
print(
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
)
def get_2d_padding(
self,
kernel_size: typing.Tuple[int, int],
dilation: typing.Tuple[int, int] = (1, 1),
):
return (
((kernel_size[0] - 1) * dilation[0]) // 2,
((kernel_size[1] - 1) * dilation[1]) // 2,
)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
if self.cqtd_normalize_volume:
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.resample(x)
z = self.cqt_transform(x)
z_amplitude = z[:, :, :, 0].unsqueeze(1)
z_phase = z[:, :, :, 1].unsqueeze(1)
z = torch.cat([z_amplitude, z_phase], dim=1)
z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
latent_z = []
for i in range(self.n_octaves):
latent_z.append(
self.conv_pres[i](
z[
:,
:,
:,
i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
]
)
)
latent_z = torch.cat(latent_z, dim=-1)
for i, l in enumerate(self.convs):
latent_z = l(latent_z)
latent_z = self.activation(latent_z)
fmap.append(latent_z)
latent_z = self.conv_post(latent_z)
return latent_z, fmap
class MultiScaleSubbandCQTDiscriminator(nn.Module):
def __init__(self, cfg: AttrDict):
super().__init__()
self.cfg = cfg
# Using get with defaults
self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
# Multi-scale params to loop over
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
self.discriminators = nn.ModuleList(
[
DiscriminatorCQT(
self.cfg,
hop_length=self.cfg["cqtd_hop_lengths"][i],
n_octaves=self.cfg["cqtd_n_octaves"][i],
bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
)
for i in range(len(self.cfg["cqtd_hop_lengths"]))
]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for disc in self.discriminators:
y_d_r, fmap_r = disc(y)
y_d_g, fmap_g = disc(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class CombinedDiscriminator(nn.Module):
"""
Wrapper for chaining multiple discriminator architectures behind a single forward().
Example: combine MBD and CQTD as a single class.
"""
def __init__(self, list_discriminator: List[nn.Module]):
super().__init__()
self.discriminators = nn.ModuleList(list_discriminator)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for disc in self.discriminators:
y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
y_d_rs.extend(y_d_r)
fmap_rs.extend(fmap_r)
y_d_gs.extend(y_d_g)
fmap_gs.extend(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
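# Editor-added concrete instance of the docstring's example (not in the original
# file). The cfg values mirror the v2 configs earlier in this commit, and
# DiscriminatorCQT lazily imports nnAudio, which must be installed.
if __name__ == "__main__":
    cfg = AttrDict({
        "sampling_rate": 24000,
        "cqtd_hop_lengths": [512, 256, 256],
        "cqtd_n_octaves": [9, 9, 9],
        "cqtd_bins_per_octaves": [24, 36, 48],
    })
    disc = CombinedDiscriminator([
        MultiBandDiscriminator(cfg),
        MultiScaleSubbandCQTDiscriminator(cfg),
    ])
    y, y_hat = torch.randn(2, 1, 16384), torch.randn(2, 1, 16384)  # real / generated audio
    scores_r, scores_g, fmaps_r, fmaps_g = disc(y, y_hat)
    print(len(scores_r), len(fmaps_r))  # one entry per sub-discriminator across both models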

View File

@@ -0,0 +1,18 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import os
import shutil
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))
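# Editor-added illustration of the trick above: binding __dict__ to the dict
# itself makes keys and attributes interchangeable.
if __name__ == "__main__":
    h = AttrDict({"sampling_rate": 24000})
    assert h.sampling_rate == h["sampling_rate"] == 24000
    h.hop_size = 256  # attribute writes land in the dict as well
    assert h["hop_size"] == 256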

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 Jungil Kong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 Edward Dixon
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2019, Seungwon Park 박승원
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,16 @@
Copyright 2020 Alexandre Défossez
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023-present, Descript
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Charactr Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Amphion
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,85 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import argparse
import json
import torch
import librosa
from utils import load_checkpoint
from meldataset import get_mel_spectrogram
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import MAX_WAV_VALUE
from bigvgan import BigVGAN as Generator
h = None
device = None
torch.backends.cudnn.benchmark = False
def inference(a, h):
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
state_dict_g = load_checkpoint(a.checkpoint_file, device)
generator.load_state_dict(state_dict_g["generator"])
filelist = os.listdir(a.input_wavs_dir)
os.makedirs(a.output_dir, exist_ok=True)
generator.eval()
generator.remove_weight_norm()
with torch.no_grad():
for i, filename in enumerate(filelist):
# Load the ground truth audio and resample if necessary
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filename), sr=h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav).to(device)
# Compute mel spectrogram from the ground truth audio
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
y_g_hat = generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(a.output_dir, os.path.splitext(filename)[0] + "_generated.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)
def main():
print("Initializing Inference Process..")
parser = argparse.ArgumentParser()
parser.add_argument("--input_wavs_dir", default="test_files")
parser.add_argument("--output_dir", default="generated_files")
parser.add_argument("--checkpoint_file", required=True)
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
a = parser.parse_args()
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
with open(config_file) as f:
data = f.read()
global h
json_config = json.loads(data)
h = AttrDict(json_config)
torch.manual_seed(h.seed)
global device
if torch.cuda.is_available():
torch.cuda.manual_seed(h.seed)
device = torch.device("cuda")
else:
device = torch.device("cpu")
inference(a, h)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,100 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
from __future__ import absolute_import, division, print_function, unicode_literals
import glob
import os
import numpy as np
import argparse
import json
import torch
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import MAX_WAV_VALUE
from bigvgan import BigVGAN as Generator
h = None
device = None
torch.backends.cudnn.benchmark = False
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + "*")
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return ""
return sorted(cp_list)[-1]
def inference(a, h):
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
state_dict_g = load_checkpoint(a.checkpoint_file, device)
generator.load_state_dict(state_dict_g["generator"])
filelist = os.listdir(a.input_mels_dir)
os.makedirs(a.output_dir, exist_ok=True)
generator.eval()
generator.remove_weight_norm()
with torch.no_grad():
for i, filename in enumerate(filelist):
# Load the mel spectrogram in .npy format
x = np.load(os.path.join(a.input_mels_dir, filename))
x = torch.FloatTensor(x).to(device)
if len(x.shape) == 2:
x = x.unsqueeze(0)
y_g_hat = generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(a.output_dir, os.path.splitext(filename)[0] + "_generated_e2e.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)
def main():
print("Initializing Inference Process..")
parser = argparse.ArgumentParser()
parser.add_argument("--input_mels_dir", default="test_mel_files")
parser.add_argument("--output_dir", default="generated_files_from_mel")
parser.add_argument("--checkpoint_file", required=True)
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
a = parser.parse_args()
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
with open(config_file) as f:
data = f.read()
global h
json_config = json.loads(data)
h = AttrDict(json_config)
torch.manual_seed(h.seed)
global device
if torch.cuda.is_available():
torch.cuda.manual_seed(h.seed)
device = torch.device("cuda")
else:
device = torch.device("cpu")
inference(a, h)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,238 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from scipy import signal
import typing
from typing import List, Tuple
from collections import namedtuple
import math
import functools
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiScaleMelSpectrogramLoss(nn.Module):
"""Compute distance between mel spectrograms. Can be used
in a multi-scale way.
Parameters
----------
n_mels : List[int]
Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
window_lengths : List[int], optional
Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
loss_fn : typing.Callable, optional
How to compare each loss, by default nn.L1Loss()
clamp_eps : float, optional
Clamp on the log magnitude, below, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 0.0 (no amplification on the magnitude part)
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 1.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
"""
def __init__(
self,
sampling_rate: int,
n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
loss_fn: typing.Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 0.0,
log_weight: float = 1.0,
pow: float = 1.0,
weight: float = 1.0,
match_stride: bool = False,
mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
mel_fmax: List[float] = [None, None, None, None, None, None, None],
window_type: str = "hann",
):
super().__init__()
self.sampling_rate = sampling_rate
STFTParams = namedtuple(
"STFTParams",
["window_length", "hop_length", "window_type", "match_stride"],
)
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
]
self.n_mels = n_mels
self.loss_fn = loss_fn
self.clamp_eps = clamp_eps
self.log_weight = log_weight
self.mag_weight = mag_weight
self.weight = weight
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.pow = pow
@staticmethod
@functools.lru_cache(None)
def get_window(
window_type,
window_length,
):
return signal.get_window(window_type, window_length)
@staticmethod
@functools.lru_cache(None)
def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
def mel_spectrogram(
self,
wav,
n_mels,
fmin,
fmax,
window_length,
hop_length,
match_stride,
window_type,
):
"""
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
"""
B, C, T = wav.shape
if match_stride:
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(T / hop_length) * hop_length - T
pad = (window_length - hop_length) // 2
else:
right_pad = 0
pad = 0
wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")
window = self.get_window(window_type, window_length)
window = torch.from_numpy(window).to(wav.device).float()
stft = torch.stft(
wav.reshape(-1, T),
n_fft=window_length,
hop_length=hop_length,
window=window,
return_complex=True,
center=True,
)
_, nf, nt = stft.shape
stft = stft.reshape(B, C, nf, nt)
if match_stride:
"""
Drop the first two and last two frames, which were added by the padding. Now num_frames * hop_length = num_samples.
"""
stft = stft[..., 2:-2]
magnitude = torch.abs(stft)
nf = magnitude.shape[2]
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
return mel_spectrogram
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes mel loss between an estimate and a reference
signal.
Parameters
----------
x : torch.Tensor
Estimate signal
y : torch.Tensor
Reference signal
Returns
-------
torch.Tensor
Mel loss.
"""
loss = 0.0
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
kwargs = {
"n_mels": n_mels,
"fmin": fmin,
"fmax": fmax,
"window_length": s.window_length,
"hop_length": s.hop_length,
"match_stride": s.match_stride,
"window_type": s.window_type,
}
x_mels = self.mel_spectrogram(x, **kwargs)
y_mels = self.mel_spectrogram(y, **kwargs)
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
return loss
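# Minimal usage sketch (an assumption, not part of the original file): given
# batched waveforms y_hat (estimate) and y (reference) of shape [B, 1, T] with
# values in [-1, 1]:
#   fn_mel = MultiScaleMelSpectrogramLoss(sampling_rate=24000)
#   loss = fn_mel(y_hat, y)  # scalar tensor: weighted L1 over log-mel scales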
# Loss functions
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss * 2 # This equates to lambda=2.0 for the feature matching loss
def discriminator_loss(
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(
disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses

View File

@@ -0,0 +1,370 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
import librosa
from librosa.filters import mel as librosa_mel_fn
import pathlib
from tqdm import tqdm
from typing import List, Tuple, Optional
from .env import AttrDict
MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1 to prevent int16 overflow (results in popping sound in corner cases)
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
return dynamic_range_decompression_torch(magnitudes)
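# Note: spectral_normalize_torch applies log-amplitude compression to mel
# magnitudes; spectral_de_normalize_torch (exp) is its exact inverse.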
mel_basis_cache = {}
hann_window_cache = {}
def mel_spectrogram(
y: torch.Tensor,
n_fft: int,
num_mels: int,
sampling_rate: int,
hop_size: int,
win_size: int,
fmin: int,
fmax: int = None,
center: bool = False,
) -> torch.Tensor:
"""
Calculate the mel spectrogram of an input signal.
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
Args:
y (torch.Tensor): Input signal.
n_fft (int): FFT size.
num_mels (int): Number of mel bins.
sampling_rate (int): Sampling rate of the input signal.
hop_size (int): Hop size for STFT.
win_size (int): Window size for STFT.
fmin (int): Minimum frequency for mel filterbank.
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
center (bool): Whether to pad the input to center the frames. Default is False.
Returns:
torch.Tensor: Mel spectrogram.
"""
if torch.min(y) < -1.0:
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
if torch.max(y) > 1.0:
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
device = y.device
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = spectral_normalize_torch(mel_spec)
return mel_spec
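# Shape note (assumption for illustration): for a batched mono waveform y of
# shape [B, T] in [-1, 1], center=False plus the (n_fft - hop_size) // 2 reflect
# padding above yields a mel spectrogram of shape [B, num_mels, T // hop_size].
#   y = torch.randn(1, 24000).clamp(-1, 1)
#   mel = mel_spectrogram(y, 1024, 100, 24000, 256, 1024, 0)  # -> [1, 100, 93]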
def get_mel_spectrogram(wav, h):
"""
Generate mel spectrogram from a waveform using given hyperparameters.
Args:
wav (torch.Tensor): Input waveform.
h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
Returns:
torch.Tensor: Mel spectrogram.
"""
return mel_spectrogram(
wav,
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax,
)
def get_dataset_filelist(a):
training_files = []
validation_files = []
list_unseen_validation_files = []
with open(a.input_training_file, "r", encoding="utf-8") as fi:
training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first training file: {training_files[0]}")
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first validation file: {validation_files[0]}")
for i in range(len(a.list_input_unseen_validation_file)):
with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
unseen_validation_files = [
os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
]
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
list_unseen_validation_files.append(unseen_validation_files)
return training_files, validation_files, list_unseen_validation_files
class MelDataset(torch.utils.data.Dataset):
def __init__(
self,
training_files: List[str],
hparams: AttrDict,
segment_size: int,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
sampling_rate: int,
fmin: int,
fmax: Optional[int],
split: bool = True,
shuffle: bool = True,
device: str = None,
fmax_loss: Optional[int] = None,
fine_tuning: bool = False,
base_mels_path: str = None,
is_seen: bool = True,
):
self.audio_files = training_files
random.seed(1234)
if shuffle:
random.shuffle(self.audio_files)
self.hparams = hparams
self.is_seen = is_seen
if self.is_seen:
self.name = pathlib.Path(self.audio_files[0]).parts[0]
else:
self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
self.segment_size = segment_size
self.sampling_rate = sampling_rate
self.split = split
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.fmax_loss = fmax_loss
self.device = device
self.fine_tuning = fine_tuning
self.base_mels_path = base_mels_path
print("[INFO] checking dataset integrity...")
for i in tqdm(range(len(self.audio_files))):
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
try:
filename = self.audio_files[index]
# Use librosa.load that ensures loading waveform into mono with [-1, 1] float values
# Audio is ndarray with shape [T_time]. Disable auto-resampling here to minimize overhead
# The on-the-fly resampling during training will be done only for the obtained random chunk
audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True)
# Main logic that uses <mel, audio> pair for training BigVGAN
if not self.fine_tuning:
if self.split: # Training step
# Obtain randomized audio chunk
if source_sampling_rate != self.sampling_rate:
# Adjust segment size to crop if the source sr is different
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
else:
target_segment_size = self.segment_size
# Compute upper bound index for the random chunk
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
# Crop or pad audio to obtain random chunk with target_segment_size
if audio.shape[0] >= target_segment_size:
audio_start = random.randint(0, random_chunk_upper_bound)
audio = audio[audio_start : audio_start + target_segment_size]
else:
audio = np.pad(
audio,
(0, target_segment_size - audio.shape[0]),
mode="constant",
)
# Resample audio chunk to self.sampling_rate
if source_sampling_rate != self.sampling_rate:
audio = librosa.resample(
audio,
orig_sr=source_sampling_rate,
target_sr=self.sampling_rate,
)
if audio.shape[0] > self.segment_size:
# trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384)
audio = audio[: self.segment_size]
else: # Validation step
# Resample full audio clip to target sampling rate
if source_sampling_rate != self.sampling_rate:
audio = librosa.resample(
audio,
orig_sr=source_sampling_rate,
target_sr=self.sampling_rate,
)
# Trim last elements to match audio length to self.hop_size * n for evaluation
if (audio.shape[0] % self.hop_size) != 0:
audio = audio[: -(audio.shape[0] % self.hop_size)]
# BigVGAN is trained using volume-normalized waveform
audio = librosa.util.normalize(audio) * 0.95
# Cast ndarray to torch tensor
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0) # [B(1), self.segment_size]
# Compute mel spectrogram corresponding to audio
mel = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax,
center=False,
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
# Fine-tuning logic that uses pre-computed mel. Example: Using TTS model-generated mel as input
else:
# For fine-tuning, assert that the waveform is in the defined sampling_rate
# To keep fine-tuning fool-proof, on-the-fly resampling is not supported (the dataset should have been prepared properly)
assert source_sampling_rate == self.sampling_rate, (
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
)
# Cast ndarray to torch tensor
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0) # [B(1), T_time]
# Load pre-computed mel from disk
mel = np.load(
os.path.join(
self.base_mels_path,
os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
)
)
mel = torch.from_numpy(mel)
if len(mel.shape) < 3:
mel = mel.unsqueeze(0) # ensure [B, C, T]
if self.split:
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
if audio.size(1) >= self.segment_size:
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[
:,
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
]
# Pad pre-computed mel and audio to matching lengths to ensure fine-tuning proceeds without error.
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
mel_loss = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax_loss,
center=False,
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
# Shape sanity checks
assert (
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), (
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
# If an error is encountered while loading the data, skip this sample and load a random other sample into the batch
except Exception as e:
if self.fine_tuning:
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
else:
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
return self[random.randrange(len(self))]
def __len__(self):
return len(self.audio_files)

View File

@@ -0,0 +1,4 @@
| Field | Response |
| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- |
| Participation considerations from adversely impacted groups (protected classes) in model design and testing: | None |
| Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. |

View File

@@ -0,0 +1,13 @@
| Field | Response |
| :---------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Intended Application & Domain: | Generating waveform from mel spectrogram. |
| Model Type: | Convolutional Neural Network (CNN) |
| Intended Users: | This model is intended for developers to synthesize waveforms from AI-generated mel spectrograms. |
| Output: | Audio Waveform |
| Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. |
| Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable |
| Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. |
| Verified to have met prescribed NVIDIA quality standards: | Yes |
| Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voiced/Unvoiced F1 Score (V/UV F1) |
| Potential Known Risks: | This model may generate low-quality or distorted soundwaves. |
| Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE |

View File

@@ -0,0 +1,126 @@
# Model Overview
## Description:
BigVGAN is a generative AI model specialized in synthesizing audio waveforms from mel spectrogram inputs.
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolution followed by multiple residual dilated convolution layers.
BigVGAN consists of a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP specializes in synthesizing high-frequency and periodic soundwaves, drawing inspiration from audio signal processing principles.
It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture for generating periodic soundwaves. It also applies anti-aliasing filters to reduce undesired artifacts in the generated waveforms. <br>
This model is ready for commercial use.<br>
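As an illustrative sketch (assumed from the Snake formulation in the paper, not copied from this repository; the official `activations.py` may parameterize it differently), the Snake activation computes `x + (1/alpha) * sin^2(alpha * x)` with a learnable per-channel `alpha`:

```python
import torch
import torch.nn as nn

class Snake(nn.Module):
    """Sketch of the Snake periodic activation: x + (1/alpha) * sin^2(alpha * x)."""

    def __init__(self, channels: int):
        super().__init__()
        # One learnable alpha per channel, broadcast over [batch, channels, time]
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + (1.0 / (self.alpha + 1e-9)) * torch.sin(self.alpha * x) ** 2
```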
## Reference(s):
- [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658) <br>
- [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) <br>
- [Audio Demo](https://bigvgan-demo.github.io/) <br>
## Model Architecture:
**Architecture Type:** Convolutional Neural Network (CNN) <br>
**Network Architecture:** Details of this model are available at https://github.com/NVIDIA/BigVGAN; the related paper can be found at https://arxiv.org/abs/2206.04658<br>
**Model Version:** 2.0 <br>
## Input:
**Input Type:** Audio <br>
**Input Format:** Mel Spectrogram <br>
**Input Parameters:** None <br>
**Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrarily long `frames`, limited only by available GPU memory.
## Output:
**Output Type:** Audio <br>
**Output Format:** Audio Waveform <br>
**Output Parameters:** None <br>
**Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channel and `time` refers to the temporal length. `time` is a fixed integer multiple of the input `frames`, given by the upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consists of float values in the range `[-1, 1]`.
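A minimal shape sketch of this input/output contract (assuming `generator` is a loaded BigVGAN instance and `upsample_ratio` is the product of its upsampling rates; both names are illustrative, not from this repository):

```python
import torch

mel = torch.randn(1, 100, 128)  # [batch, channels(num_mels), frames]
with torch.inference_mode():
    wav = generator(mel)        # [batch, 1, frames * upsample_ratio]
assert wav.shape[-1] == mel.shape[-1] * upsample_ratio
```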
## Software Integration:
**Runtime Engine(s):** PyTorch
**Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta <br>
## Preferred/Supported Operating System(s):
Linux
## Model Version(s):
v2.0
## Training, Testing, and Evaluation Datasets:
### Training Dataset:
The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
**Links:**
- [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629)
- [AudioCaps](https://audiocaps.github.io/)
- [AudioSet](https://research.google.com/audioset/index.html)
- [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent)
- [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440)
- [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection)
- [FSDnoisy18k](https://zenodo.org/records/2529934)
- [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384)
- [Greatest Hits dataset](https://andrewowens.com/vis/)
- [GTZAN](https://ieeexplore.ieee.org/document/1021072)
- [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus)
- [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194)
- [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/)
- [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench)
- [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps)
- [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset)
- [NSynth](https://magenta.tensorflow.org/datasets/nsynth)
- [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset)
- [Audio Piano Triads Dataset](https://zenodo.org/records/4740877)
- [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097)
- [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543)
- [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433)
- [WavText5K](https://github.com/microsoft/WavText5K)
- [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10)
- [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/)
- [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/)
- [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875)
- [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60)
- [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/)
- [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353)
- [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/)
- [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus)
- [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443)
\*\* Data Collection Method by dataset <br>
- Human <br>
\*\* Labeling Method by dataset (for those with labels) <br>
- Hybrid: Automated, Human, Unknown <br>
### Evaluating Dataset:
Properties: The audio generation quality of BigVGAN is evaluated using `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets include speech in the English language with an equal balance of genders.
\*\* Data Collection Method by dataset <br>
- Human <br>
\*\* Labeling Method by dataset <br>
- Automated <br>
## Inference:
**Engine:** PyTorch <br>
**Test Hardware:** NVIDIA A100 GPU <br>
## Ethical Considerations:
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).

View File

@@ -0,0 +1,14 @@
| Field | Response |
| :------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------- |
| Generatable or reverse engineerable personal information? | None |
| Protected class data used to create this model? | None |
| Was consent obtained for any personal data used? | Not Applicable (No Personal Data) |
| How often is dataset reviewed? | Before Release |
| Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable |
| If personal data was collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable |
| If personal data was collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable |
| If personal data was collected for the development of this AI model, was it minimized to only what was required? | Not Applicable |
| Is data in dataset traceable? | Yes |
| Is there provenance for all datasets used in training? | Yes |
| Does data labeling (annotation, metadata) comply with privacy laws? | Yes |
| Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. |

View File

@@ -0,0 +1,6 @@
| Field | Response |
| :---------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Model Application(s): | Synthetic Audio Generation |
| Describe the life critical impact (if present). | Not Applicable |
| Use Case Restrictions: | None |
| Model and dataset restrictions: | The principle of least privilege (PoLP) is applied, limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints are adhered to. |

View File

@@ -0,0 +1,13 @@
torch
numpy
librosa>=0.8.1
scipy
tensorboard
soundfile
matplotlib
pesq
auraloss
tqdm
nnAudio
ninja
huggingface_hub>=0.23.4

View File

@@ -0,0 +1,62 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import torch
from alias_free_activation.cuda import activation1d
from activations import Snake
def test_load_fused_kernels():
try:
print("[Success] load_fused_kernels")
except ImportError as e:
print("[Fail] load_fused_kernels")
raise e
def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations.Snake cuda vs. torch
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
if __name__ == "__main__":
from alias_free_activation.cuda import load
load.load()
test_load_fused_kernels()
test_anti_alias_activation()

View File

@@ -0,0 +1,62 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import torch
from alias_free_activation.cuda import activation1d
from activations import SnakeBeta
def test_load_fused_kernels():
try:
print("[Success] load_fused_kernels")
except ImportError as e:
print("[Fail] load_fused_kernels")
raise e
def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations.SnakeBeta CUDA vs. Torch
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
if __name__ == "__main__":
from alias_free_activation.cuda import load
load.load()
test_load_fused_kernels()
test_anti_alias_activation()

View File

@@ -0,0 +1,215 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import torch
import json
from env import AttrDict
from bigvgan import BigVGAN
from time import time
from tqdm import tqdm
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from scipy.io.wavfile import write
import numpy as np
import argparse
torch.backends.cudnn.benchmark = True
# For easier debugging
torch.set_printoptions(linewidth=200, threshold=10_000)
def generate_soundwave(duration=5.0, sr=24000):
t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32)
modulation = np.sin(2 * np.pi * t / duration)
min_freq = 220
max_freq = 1760
frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2
soundwave = np.sin(2 * np.pi * frequencies * t)
soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95
return soundwave, sr
def get_mel(x, h):
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
parser.add_argument(
"--checkpoint_file",
type=str,
required=True,
help="Path to the checkpoint file. Assumes config.json exists in the directory.",
)
args = parser.parse_args()
config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
with open(config_file) as f:
config = f.read()
json_config = json.loads(config)
h = AttrDict({**json_config})
print("loading plain Pytorch BigVGAN")
generator_original = BigVGAN(h).to("cuda")
print("loading CUDA kernel BigVGAN with auto-build")
generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")
state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
generator_original.load_state_dict(state_dict_g["generator"])
generator_cuda_kernel.load_state_dict(state_dict_g["generator"])
generator_original.remove_weight_norm()
generator_original.eval()
generator_cuda_kernel.remove_weight_norm()
generator_cuda_kernel.eval()
# define number of samples and length of mel frame to benchmark
num_sample = 10
num_mel_frame = 16384
# CUDA kernel correctness check
diff = 0.0
for i in tqdm(range(num_sample)):
# Random mel
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
with torch.inference_mode():
audio_original = generator_original(data)
with torch.inference_mode():
audio_cuda_kernel = generator_cuda_kernel(data)
# Both outputs should be (almost) the same
test_result = (audio_original - audio_cuda_kernel).abs()
diff += test_result.mean(dim=-1).item()
diff /= num_sample
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
print(
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}"
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
)
else:
print(
f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}"
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
)
del data, audio_original, audio_cuda_kernel
# Variables for tracking total time and VRAM usage
toc_total_original = 0
toc_total_cuda_kernel = 0
vram_used_original_total = 0
vram_used_cuda_kernel_total = 0
audio_length_total = 0
# Measure Original inference in isolation
for i in tqdm(range(num_sample)):
torch.cuda.reset_peak_memory_stats(device="cuda")
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
torch.cuda.synchronize()
tic = time()
with torch.inference_mode():
audio_original = generator_original(data)
torch.cuda.synchronize()
toc = time() - tic
toc_total_original += toc
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_original
torch.cuda.empty_cache()
# Measure CUDA kernel inference in isolation
for i in tqdm(range(num_sample)):
torch.cuda.reset_peak_memory_stats(device="cuda")
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
torch.cuda.synchronize()
tic = time()
with torch.inference_mode():
audio_cuda_kernel = generator_cuda_kernel(data)
torch.cuda.synchronize()
toc = time() - tic
toc_total_cuda_kernel += toc
audio_length_total += audio_cuda_kernel.shape[-1]
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_cuda_kernel
torch.cuda.empty_cache()
# Calculate metrics
audio_second = audio_length_total / h.sampling_rate
khz_original = audio_length_total / toc_total_original / 1000
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
# Print results
print(
f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
)
print(
f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
)
print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")
# Use artificial sine waves for inference test
audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
audio_real = torch.tensor(audio_real).to("cuda")
# Compute mel spectrogram from the ground truth audio
x = get_mel(audio_real.unsqueeze(0), h)
with torch.inference_mode():
y_g_hat_original = generator_original(x)
y_g_hat_cuda_kernel = generator_cuda_kernel(x)
audio_real = audio_real.squeeze()
audio_real = audio_real * MAX_WAV_VALUE
audio_real = audio_real.cpu().numpy().astype("int16")
audio_original = y_g_hat_original.squeeze()
audio_original = audio_original * MAX_WAV_VALUE
audio_original = audio_original.cpu().numpy().astype("int16")
audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze()
audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE
audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16")
os.makedirs("tmp", exist_ok=True)
output_file_real = os.path.join("tmp", "audio_real.wav")
output_file_original = os.path.join("tmp", "audio_generated_original.wav")
output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
write(output_file_real, h.sampling_rate, audio_real)
write(output_file_original, h.sampling_rate, audio_original)
write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
print("Done")

View File

@@ -0,0 +1,716 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
import itertools
import os
import time
import argparse
import json
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from env import AttrDict, build_env
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE
from bigvgan import BigVGAN
from discriminators import (
MultiPeriodDiscriminator,
MultiResolutionDiscriminator,
MultiBandDiscriminator,
MultiScaleSubbandCQTDiscriminator,
)
from loss import (
feature_loss,
generator_loss,
discriminator_loss,
MultiScaleMelSpectrogramLoss,
)
from utils import (
plot_spectrogram,
plot_spectrogram_clipped,
scan_checkpoint,
load_checkpoint,
save_checkpoint,
save_audio,
)
import torchaudio as ta
from pesq import pesq
from tqdm import tqdm
import auraloss
torch.backends.cudnn.benchmark = False
def train(rank, a, h):
if h.num_gpus > 1:
# initialize distributed
init_process_group(
backend=h.dist_config["dist_backend"],
init_method=h.dist_config["dist_url"],
world_size=h.dist_config["world_size"] * h.num_gpus,
rank=rank,
)
# Set seed and device
torch.cuda.manual_seed(h.seed)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank:d}")
# Define BigVGAN generator
generator = BigVGAN(h).to(device)
# Define discriminators. MPD is used by default
mpd = MultiPeriodDiscriminator(h).to(device)
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
# Variable name is kept as "mrd" for backward compatibility & minimal code change
mrd = MultiBandDiscriminator(h).to(device)
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
else: # Fallback to original MRD in BigVGAN-v1
mrd = MultiResolutionDiscriminator(h).to(device)
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
if h.get("use_multiscale_melloss", False):
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
sampling_rate=h.sampling_rate
) # NOTE: accepts waveform as input
else:
fn_mel_loss_singlescale = F.l1_loss
# Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory
if rank == 0:
print(generator)
print(mpd)
print(mrd)
print(f"Generator params: {sum(p.numel() for p in generator.parameters())}")
print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}")
print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}")
os.makedirs(a.checkpoint_path, exist_ok=True)
print(f"Checkpoints directory: {a.checkpoint_path}")
if os.path.isdir(a.checkpoint_path):
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
cp_do = scan_checkpoint(
a.checkpoint_path,
prefix="do_",
renamed_file="bigvgan_discriminator_optimizer.pt",
)
# Load the latest checkpoint if exists
steps = 0
if cp_g is None or cp_do is None:
state_dict_do = None
last_epoch = -1
else:
state_dict_g = load_checkpoint(cp_g, device)
state_dict_do = load_checkpoint(cp_do, device)
generator.load_state_dict(state_dict_g["generator"])
mpd.load_state_dict(state_dict_do["mpd"])
mrd.load_state_dict(state_dict_do["mrd"])
steps = state_dict_do["steps"] + 1
last_epoch = state_dict_do["epoch"]
# Initialize DDP, optimizers, and schedulers
if h.num_gpus > 1:
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(
itertools.chain(mrd.parameters(), mpd.parameters()),
h.learning_rate,
betas=[h.adam_b1, h.adam_b2],
)
if state_dict_do is not None:
optim_g.load_state_dict(state_dict_do["optim_g"])
optim_d.load_state_dict(state_dict_do["optim_d"])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
# Define training and validation datasets
"""
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
Example: trained on LibriTTS, validate on VCTK
"""
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
trainset = MelDataset(
training_filelist,
h,
h.segment_size,
h.n_fft,
h.num_mels,
h.hop_size,
h.win_size,
h.sampling_rate,
h.fmin,
h.fmax,
shuffle=False if h.num_gpus > 1 else True,
fmax_loss=h.fmax_for_loss,
device=device,
fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir,
is_seen=True,
)
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
train_loader = DataLoader(
trainset,
num_workers=h.num_workers,
shuffle=False,
sampler=train_sampler,
batch_size=h.batch_size,
pin_memory=True,
drop_last=True,
)
if rank == 0:
validset = MelDataset(
validation_filelist,
h,
h.segment_size,
h.n_fft,
h.num_mels,
h.hop_size,
h.win_size,
h.sampling_rate,
h.fmin,
h.fmax,
False,
False,
fmax_loss=h.fmax_for_loss,
device=device,
fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir,
is_seen=True,
)
validation_loader = DataLoader(
validset,
num_workers=1,
shuffle=False,
sampler=None,
batch_size=1,
pin_memory=True,
drop_last=True,
)
list_unseen_validset = []
list_unseen_validation_loader = []
for i in range(len(list_unseen_validation_filelist)):
unseen_validset = MelDataset(
list_unseen_validation_filelist[i],
h,
h.segment_size,
h.n_fft,
h.num_mels,
h.hop_size,
h.win_size,
h.sampling_rate,
h.fmin,
h.fmax,
False,
False,
fmax_loss=h.fmax_for_loss,
device=device,
fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir,
is_seen=False,
)
unseen_validation_loader = DataLoader(
unseen_validset,
num_workers=1,
shuffle=False,
sampler=None,
batch_size=1,
pin_memory=True,
drop_last=True,
)
list_unseen_validset.append(unseen_validset)
list_unseen_validation_loader.append(unseen_validation_loader)
# Tensorboard logger
sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs"))
if a.save_audio: # Also save audio to disk if --save_audio is set to True
os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True)
"""
Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset).
If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors
"""
def validate(rank, a, h, loader, mode="seen"):
assert rank == 0, "validate should only run on rank=0"
generator.eval()
torch.cuda.empty_cache()
val_err_tot = 0
val_pesq_tot = 0
val_mrstft_tot = 0
# Modules for evaluation metrics
pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda()
loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda")
if a.save_audio: # Also save audio to disk if --save_audio is set to True
os.makedirs(
os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"),
exist_ok=True,
)
os.makedirs(
os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"),
exist_ok=True,
)
with torch.no_grad():
print(f"step {steps} {mode} speaker validation...")
# Loop over validation set and compute metrics
for j, batch in enumerate(tqdm(loader)):
x, y, _, y_mel = batch
y = y.to(device)
if hasattr(generator, "module"):
y_g_hat = generator.module(x.to(device))
else:
y_g_hat = generator(x.to(device))
y_mel = y_mel.to(device, non_blocking=True)
y_g_hat_mel = mel_spectrogram(
y_g_hat.squeeze(1),
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax_for_loss,
)
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
# Resample to 16000 for pesq
y_16k = pesq_resampler(y)
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
# MRSTFT calculation
min_t = min(y.size(-1), y_g_hat.size(-1))
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
# Log audio and figures to Tensorboard
if j % a.eval_subsample == 0: # Subsample every nth from validation set
if steps >= 0:
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y[0],
os.path.join(
a.checkpoint_path,
"samples",
f"gt_{mode}",
f"{j:04d}.wav",
),
h.sampling_rate,
)
sw.add_figure(
f"gt_{mode}/y_spec_{j}",
plot_spectrogram(x[0]),
steps,
)
sw.add_audio(
f"generated_{mode}/y_hat_{j}",
y_g_hat[0],
steps,
h.sampling_rate,
)
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y_g_hat[0, 0],
os.path.join(
a.checkpoint_path,
"samples",
f"{mode}_{steps:08d}",
f"{j:04d}.wav",
),
h.sampling_rate,
)
# Spectrogram of synthesized audio
y_hat_spec = mel_spectrogram(
y_g_hat.squeeze(1),
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax,
)
sw.add_figure(
f"generated_{mode}/y_hat_spec_{j}",
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()),
steps,
)
"""
Visualization of spectrogram difference between GT and synthesized audio, difference higher than 1 is clipped for better visualization.
"""
spec_delta = torch.clamp(
torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()),
min=1e-6,
max=1.0,
)
sw.add_figure(
f"delta_dclip1_{mode}/spec_{j}",
plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0),
steps,
)
val_err = val_err_tot / (j + 1)
val_pesq = val_pesq_tot / (j + 1)
val_mrstft = val_mrstft_tot / (j + 1)
# Log evaluation metrics to Tensorboard
sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps)
sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps)
sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps)
generator.train()
# If the checkpoint is loaded, start with validation loop
if steps != 0 and rank == 0 and not a.debug:
if not a.skip_seen:
validate(
rank,
a,
h,
validation_loader,
mode=f"seen_{train_loader.dataset.name}",
)
for i in range(len(list_unseen_validation_loader)):
validate(
rank,
a,
h,
list_unseen_validation_loader[i],
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
)
# Exit the script if --evaluate is set to True
if a.evaluate:
exit()
# Main training loop
generator.train()
mpd.train()
mrd.train()
for epoch in range(max(0, last_epoch), a.training_epochs):
if rank == 0:
start = time.time()
print(f"Epoch: {epoch + 1}")
if h.num_gpus > 1:
train_sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
if rank == 0:
start_b = time.time()
x, y, _, y_mel = batch
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
y_mel = y_mel.to(device, non_blocking=True)
y = y.unsqueeze(1)
y_g_hat = generator(x)
y_g_hat_mel = mel_spectrogram(
y_g_hat.squeeze(1),
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax_for_loss,
)
optim_d.zero_grad()
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MRD
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
# Set clip_grad_norm value
clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000
# Whether to freeze D for initial training steps
if steps >= a.freeze_step:
loss_disc_all.backward()
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
optim_d.step()
else:
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
grad_norm_mpd = 0.0
grad_norm_mrd = 0.0
# Generator
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
else: # Uses mel <y_mel, y_g_hat_mel> for loss
loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss
# MPD loss
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
# MRD loss
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
if steps >= a.freeze_step:
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
else:
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
loss_gen_all = loss_mel
loss_gen_all.backward()
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
print(
f"Steps: {steps:d}, "
f"Gen Loss Total: {loss_gen_all:4.3f}, "
f"Mel Error: {mel_error:4.3f}, "
f"s/b: {time.time() - start_b:4.3f} "
f"lr: {optim_g.param_groups[0]['lr']:4.7f} "
f"grad_norm_g: {grad_norm_g:4.3f}"
)
# Checkpointing
if steps % a.checkpoint_interval == 0 and steps != 0:
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
save_checkpoint(
checkpoint_path,
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
)
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
save_checkpoint(
checkpoint_path,
{
"mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
"mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(),
"optim_g": optim_g.state_dict(),
"optim_d": optim_d.state_dict(),
"steps": steps,
"epoch": epoch,
},
)
# Tensorboard summary logging
if steps % a.summary_interval == 0:
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps)
sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
sw.add_scalar("training/epoch", epoch + 1, steps)
# Validation
if steps % a.validation_interval == 0:
# Plot training input x so far used
for i_x in range(x.shape[0]):
sw.add_figure(
f"training_input/x_{i_x}",
plot_spectrogram(x[i_x].cpu()),
steps,
)
sw.add_audio(
f"training_input/y_{i_x}",
y[i_x][0],
steps,
h.sampling_rate,
)
# Seen and unseen speakers validation loops
if not a.debug and steps != 0:
validate(
rank,
a,
h,
validation_loader,
mode=f"seen_{train_loader.dataset.name}",
)
for i in range(len(list_unseen_validation_loader)):
validate(
rank,
a,
h,
list_unseen_validation_loader[i],
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
)
steps += 1
# BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level
scheduler_g.step()
scheduler_d.step()
if rank == 0:
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
def main():
print("Initializing Training Process..")
parser = argparse.ArgumentParser()
parser.add_argument("--group_name", default=None)
parser.add_argument("--input_wavs_dir", default="LibriTTS")
parser.add_argument("--input_mels_dir", default="ft_dataset")
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
parser.add_argument(
"--list_input_unseen_wavs_dir",
nargs="+",
default=["tests/LibriTTS", "tests/LibriTTS"],
)
parser.add_argument(
"--list_input_unseen_validation_file",
nargs="+",
default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"],
)
parser.add_argument("--checkpoint_path", default="exp/bigvgan")
parser.add_argument("--config", default="")
parser.add_argument("--training_epochs", default=100000, type=int)
parser.add_argument("--stdout_interval", default=5, type=int)
parser.add_argument("--checkpoint_interval", default=50000, type=int)
parser.add_argument("--summary_interval", default=100, type=int)
parser.add_argument("--validation_interval", default=50000, type=int)
parser.add_argument(
"--freeze_step",
default=0,
type=int,
help="freeze D for the first specified steps. G only uses regression loss for these steps.",
)
parser.add_argument("--fine_tuning", default=False, type=bool)
parser.add_argument(
"--debug",
default=False,
type=bool,
help="debug mode. skips validation loop throughout training",
)
parser.add_argument(
"--evaluate",
default=False,
type=bool,
help="only run evaluation from checkpoint and exit",
)
parser.add_argument(
"--eval_subsample",
default=5,
type=int,
help="subsampling during evaluation loop",
)
parser.add_argument(
"--skip_seen",
default=False,
type=bool,
help="skip seen dataset. useful for test set inference",
)
parser.add_argument(
"--save_audio",
default=False,
type=bool,
help="save audio of test set inference to disk",
)
a = parser.parse_args()
with open(a.config) as f:
data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
build_env(a.config, "config.json", a.checkpoint_path)
torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        h.batch_size = int(h.batch_size / h.num_gpus)
        print(f"Batch size per GPU: {h.batch_size}")
    else:
        h.num_gpus = 1  # CPU fallback; keeps the num_gpus checks below well-defined
if h.num_gpus > 1:
mp.spawn(
train,
nprocs=h.num_gpus,
args=(
a,
h,
),
)
else:
train(0, a, h)
if __name__ == "__main__":
main()
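# Illustrative launch (the config path is a placeholder, not shipped by this file):
#
#     python train.py --config <bigvgan_config.json> --checkpoint_path exp/bigvgan --freeze_step 10000
#
# With --freeze_step N, the first N steps skip discriminator updates entirely and
# train the generator on the mel regression loss alone, as the loop above implements.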

View File

@@ -0,0 +1,99 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
from .meldataset import MAX_WAV_VALUE
from scipy.io.wavfile import write
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(
spectrogram,
aspect="auto",
origin="lower",
interpolation="none",
vmin=1e-6,
vmax=clip_max,
)
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
weight_norm(m)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
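# For stride-1 convolutions this gives "same" output length with odd kernels,
# e.g. get_padding(3) == 1 and get_padding(7, dilation=2) == 6.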
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def save_checkpoint(filepath, obj):
print(f"Saving checkpoint to {filepath}")
torch.save(obj, filepath)
print("Complete.")
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
# Fallback to original scanning logic first
pattern = os.path.join(cp_dir, prefix + "????????")
cp_list = glob.glob(pattern)
if len(cp_list) > 0:
last_checkpoint_path = sorted(cp_list)[-1]
print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
return last_checkpoint_path
# If no pattern-based checkpoints are found, check for renamed file
if renamed_file:
renamed_path = os.path.join(cp_dir, renamed_file)
if os.path.isfile(renamed_path):
print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
return renamed_path
return None
def save_audio(audio, path, sr):
    # audio: 1-D float tensor in [-1, 1]; scaled to 16-bit PCM before writing
    audio = audio * MAX_WAV_VALUE
    audio = audio.cpu().numpy().astype("int16")
    write(path, sr, audio)
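# A minimal resume sketch (illustrative; the directory, prefix, and renamed file
# name are assumptions, not fixed by this module):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     cp_g = scan_checkpoint("exp/bigvgan", "g_", renamed_file="bigvgan_generator.pt")
#     state_dict_g = load_checkpoint(cp_g, device) if cp_g else None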

File diff suppressed because it is too large

View File

@@ -0,0 +1,239 @@
import os
import sys
import threading
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
import re
import torch
from text.LangSegmenter import LangSegmenter
from text import chinese
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from tools.i18n.i18n import I18nAuto, scan_language_list
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(["!", "?", "…", ",", ".", "-"])
def get_first(text: str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def merge_short_text_in_array(texts: List[str], threshold: int) -> List[str]:
    if len(texts) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if len(text) > 0:
        if len(result) == 0:
            result.append(text)
        else:
            result[-1] += text
    return result
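# Worked example (illustrative): with threshold=5, segments are merged forward
# until a chunk reaches the threshold, and a short trailing piece is appended to
# the previous chunk:
#
#     merge_short_text_in_array(["你好。", "早。", "今天天气不错。"], 5)
#     # -> ["你好。早。", "今天天气不错。"]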
class TextPreprocessor:
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.bert_lock = threading.RLock()
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(f"############ {i18n('提取文本Bert特征')} ############")
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
if phones is None or norm_text == "":
continue
res = {
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
}
result.append(res)
return result
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
text = text.strip("\n")
if len(text) == 0:
return []
if text[0] not in splits and len(get_first(text)) < 4:
text = "" + text if lang != "en" else "." + text
print(i18n("实际输入的目标文本:"))
print(text)
seg_method = get_seg_method(text_split_method)
text = seg_method(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
_texts = text.split("\n")
_texts = self.filter_text(_texts)
_texts = merge_short_text_in_array(_texts, 5)
texts = []
        for text in _texts:
            # Skip blank lines in the target text to avoid downstream errors
            if len(text.strip()) == 0:
                continue
            if not re.sub(r"\W+", "", text):
                # Skip segments that are punctuation/symbols only
                continue
            if text[-1] not in splits:
                text += "。" if lang != "en" else "."
            # Split overly long sentences so they do not exceed BERT's 510-token limit
            if len(text) > 510:
                texts.extend(split_big_text(text))
            else:
                texts.append(text)
print(i18n("实际输入的目标文本(切句后):"))
print(texts)
return texts
def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1"
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version)
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock:
text = re.sub(r' {2,}', ' ', text)
textlist = []
langlist = []
if language == "all_zh":
for tmp in LangSegmenter.getTexts(text,"zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True)
return phones, bert, norm_text
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
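        # word2ph[i] is the number of phonemes the i-th character expands to; each
        # character-level BERT vector is repeated that many times so the output
        # aligns 1:1 with the phoneme sequence (final shape: 1024 x n_phones).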
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
language = language.replace("all_", "")
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
language = language.replace("all_", "")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else:
feature = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
return feature
def filter_text(self, texts):
_text = []
if all(text in [None, " ", "\n", ""] for text in texts):
raise ValueError(i18n("请输入有效文本"))
for text in texts:
if text in [None, " ", ""]:
pass
else:
_text.append(text)
return _text
def replace_consecutive_punctuation(self, text):
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
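# A minimal usage sketch (model path and device are assumptions; the repo keeps
# the BERT weights under GPT_SoVITS/pretrained_models):
#
#     bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
#     tokenizer = AutoTokenizer.from_pretrained(bert_path)
#     bert = AutoModelForMaskedLM.from_pretrained(bert_path)
#     tp = TextPreprocessor(bert, tokenizer, torch.device("cpu"))
#     frags = tp.preprocess("今天天气不错。", lang="all_zh", text_split_method="cut5")
#     # each element: {"phones": [...], "bert_features": Tensor, "norm_text": str}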

View File

@@ -0,0 +1 @@
from . import TTS, text_segmentation_method

View File

@@ -0,0 +1,190 @@
import re
from typing import Callable
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
METHODS = dict()
def get_method(name: str) -> Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def get_method_names() -> list:
return list(METHODS.keys())
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {
    "，",
    "。",
    "？",
    "！",
    ",",
    ".",
    "?",
    "!",
    "~",
    ":",
    "：",
    "—",
    "…",
}
def split_big_text(text, max_len=510):
    # Full-width and half-width punctuation marks
    punctuation = "".join(splits)
    # Split the text on punctuation, keeping the delimiters
    segments = re.split("([" + punctuation + "])", text)
    # Accumulate segments until adding the next one would exceed max_len
    result = []
    current_segment = ""
    for segment in segments:
        # If the combined length exceeds max_len, flush the current chunk and start a new one
        if len(current_segment + segment) > max_len:
            result.append(current_segment)
            current_segment = segment
        else:
            current_segment += segment
    # Append the final chunk
    if current_segment:
        result.append(current_segment)
    return result
def split(todo_text):
    todo_text = todo_text.replace("……", "。").replace("——", "，")
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            break  # The text is guaranteed to end with punctuation, so the last segment was already appended
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts
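# Example (illustrative): split() cuts after every punctuation mark in `splits`:
#
#     split("你好。我很好!")  # -> ["你好。", "我很好!"]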
# cut0: no splitting
@register_method("cut0")
def cut0(inp):
    if not set(inp).issubset(punctuation):
        return inp
    else:
        return "\n"  # punctuation-only input: nothing left to synthesize
# cut1: split every four sentences
@register_method("cut1")
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# cut2: split roughly every 50 characters
@register_method("cut2")
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
    # print(opts)
    if len(opts) > 1 and len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# cut3: split on the Chinese full stop "。"
@register_method("cut3")
def cut3(inp):
    inp = inp.strip("\n")
    opts = ["%s" % item for item in inp.strip("。").split("。")]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)
# cut4: split on the English period "."
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# cut5: split on punctuation marks
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
inp = inp.strip("\n")
    punds = {",", ".", ";", "?", "!", "、", "，", "。", "？", "！", "；", "：", "…"}
mergeitems = []
items = []
for i, char in enumerate(inp):
if char in punds:
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char)
else:
items.append(char)
mergeitems.append("".join(items))
items = []
else:
items.append(char)
if items:
mergeitems.append("".join(items))
opt = [item for item in mergeitems if not set(item).issubset(punds)]
return "\n".join(opt)
if __name__ == "__main__":
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

View File

@@ -0,0 +1 @@
*.yaml

View File

@@ -0,0 +1,91 @@
{
"train": {
"log_interval": 100,
"eval_interval": 500,
"seed": 1234,
"epochs": 100,
"learning_rate": 0.0001,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 32,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 20480,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"text_low_lr_rate": 0.4,
"grad_ckpt": false
},
"data": {
"max_wav_value": 32768.0,
"sampling_rate": 32000,
"filter_length": 2048,
"hop_length": 640,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 300,
"cleaned_text": true
},
"model": {
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
10,
8,
2,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 512,
"semantic_frame_rate": "25hz",
"freeze_quantizer": true
},
"s2_ckpt_dir": "logs/s2/big2k1",
"content_module": "cnhubert"
}
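A quick consistency check one can run against this config (illustrative): the decoder's total upsampling factor must equal hop_length, so each mel frame expands to exactly one hop of waveform samples.

    import math

    upsample_rates = [10, 8, 2, 2, 2]
    assert math.prod(upsample_rates) == 640  # hop_length; 32000 / 640 = 50 mel frames/s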

View File

@@ -0,0 +1,91 @@
{
"train": {
"log_interval": 100,
"eval_interval": 500,
"seed": 1234,
"epochs": 100,
"learning_rate": 0.0001,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 32,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 20480,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"text_low_lr_rate": 0.4,
"grad_ckpt": false
},
"data": {
"max_wav_value": 32768.0,
"sampling_rate": 32000,
"filter_length": 2048,
"hop_length": 640,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 300,
"cleaned_text": true
},
"model": {
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.0,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
10,
8,
2,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 1024,
"semantic_frame_rate": "25hz",
"freeze_quantizer": true
},
"s2_ckpt_dir": "logs/s2/big2k1",
"content_module": "cnhubert"
}

View File

@@ -0,0 +1,91 @@
{
"train": {
"log_interval": 100,
"eval_interval": 500,
"seed": 1234,
"epochs": 100,
"learning_rate": 0.0001,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 32,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 20480,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"text_low_lr_rate": 0.4,
"grad_ckpt": false
},
"data": {
"max_wav_value": 32768.0,
"sampling_rate": 32000,
"filter_length": 2048,
"hop_length": 640,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 300,
"cleaned_text": true
},
"model": {
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.0,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
10,
8,
2,
2,
2
],
"upsample_initial_channel": 768,
"upsample_kernel_sizes": [
20,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 1024,
"semantic_frame_rate": "25hz",
"freeze_quantizer": true
},
"s2_ckpt_dir": "logs/s2/big2k1",
"content_module": "cnhubert"
}

View File

@@ -0,0 +1,13 @@
import os
import sys
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
from text.g2pw import G2PWPinyin
g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
v_to_u=False,
neutral_tone_with_five=True,
)

View File

@@ -0,0 +1,264 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = "inplace" if self.inplace else ""
return self.__class__.__name__ + " (" + inplace_str + ")"
class BasicBlockERes2Net(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
fuse_models = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = self.fuse_models[i - 1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class ERes2Net(nn.Module):
def __init__(
self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func="TSTP",
two_emb_layer=False,
):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
# Downsampling module for each layer
self.layer1_downsample = nn.Conv2d(
m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
)
self.layer2_downsample = nn.Conv2d(
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer3_downsample = nn.Conv2d(
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
)
# Bottom-up fusion module
self.fuse_mode12 = AFF(channels=m_channels * 4)
self.fuse_mode123 = AFF(channels=m_channels * 8)
self.fuse_mode1234 = AFF(channels=m_channels * 16)
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
stats = self.pool(fuse_out1234)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
def forward3(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
return fuse_out1234
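# Note: forward() returns the 192-dim utterance embedding, while forward3()
# mean-pools the multi-scale fusion map over time and returns the flattened
# channel x frequency features instead (with the defaults above: 512 x 10 = 5120 dims).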
if __name__ == "__main__":
x = torch.zeros(10, 300, 80)
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
model.eval()
out = model(x)
print(out.shape) # torch.Size([10, 192])
num_params = sum(param.numel() for param in model.parameters())
print("{} M".format(num_params / 1e6)) # 6.61M

View File

@@ -0,0 +1,272 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
within each stage. However, this modification also increases the number of model parameters and computational complexity.
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
both the model parameters and its computational cost.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = "inplace" if self.inplace else ""
return self.__class__.__name__ + " (" + inplace_str + ")"
class BasicBlockERes2NetV2(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
self.expansion = expansion
convs = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class BasicBlockERes2NetV2AFF(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
self.expansion = expansion
convs = []
fuse_models = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width, r=4))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = self.fuse_models[i - 1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class ERes2NetV2(nn.Module):
def __init__(
self,
block=BasicBlockERes2NetV2,
block_fuse=BasicBlockERes2NetV2AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embedding_size=192,
baseWidth=26,
scale=2,
expansion=2,
pooling_func="TSTP",
two_emb_layer=False,
):
super(ERes2NetV2, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.baseWidth = baseWidth
self.scale = scale
self.expansion = expansion
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
# Downsampling module
self.layer3_ds = nn.Conv2d(
m_channels * 4 * self.expansion,
m_channels * 8 * self.expansion,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
# Bottom-up fusion module
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(
block(
self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
)
)
self.in_planes = planes * self.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
stats = self.pool(fuse_out34)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
def forward3(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
        # fuse_out34 shape example: torch.Size([16, 2048, 10, 72])
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
# stats = self.pool(fuse_out34)
#
# embed_a = self.seg_1(stats)
# if self.two_emb_layer:
# out = F.relu(embed_a)
# out = self.seg_bn_1(out)
# embed_b = self.seg_2(out)
# return embed_b
# else:
# return embed_a
if __name__ == "__main__":
x = torch.randn(1, 300, 80)
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
model.eval()
y = model(x)
print(y.size())
macs, num_params = profile(model, inputs=(x,))
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
print("MACs: {} G".format(macs / 1e9)) # 12.69 G

View File

@@ -0,0 +1,289 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = "inplace" if self.inplace else ""
return self.__class__.__name__ + " (" + inplace_str + ")"
class BasicBlockERes2Net(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
convs = []
fuse_models = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = self.fuse_models[i - 1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class ERes2Net(nn.Module):
def __init__(
self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embedding_size=192,
pooling_func="TSTP",
two_emb_layer=False,
):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
self.layer1_downsample = nn.Conv2d(
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer2_downsample = nn.Conv2d(
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer3_downsample = nn.Conv2d(
m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
)
self.fuse_mode12 = AFF(channels=m_channels * 8)
self.fuse_mode123 = AFF(channels=m_channels * 16)
self.fuse_mode1234 = AFF(channels=m_channels * 32)
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
stats = self.pool(fuse_out1234)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
def forward2(self, x, if_mean):
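        # if_mean=False: project each time frame of the first batch item separately,
        # returning per-frame embeddings of shape (T, 192); if_mean=True: mean-pool
        # over time first, returning one (B, 192)-shaped embedding per utterance.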
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
if if_mean == False:
mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
else:
mean = fuse_out1234.mean(2) # bs,20480
mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
return self.seg_1(mean_std) # (T,192)
# stats = self.pool(fuse_out1234)
# if self.two_emb_layer:
# out = F.relu(embed_a)
# out = self.seg_bn_1(out)
# embed_b = self.seg_2(out)
# return embed_b
# else:
# return embed_a
def forward3(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
return fuse_out1234
# print(fuse_out1234.shape)
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
# pdb.set_trace()

Some files were not shown because too many files have changed in this diff