diff --git a/.github/workflows/export-gtcrn.yaml b/.github/workflows/export-gtcrn.yaml
new file mode 100644
index 00000000..88d668f4
--- /dev/null
+++ b/.github/workflows/export-gtcrn.yaml
@@ -0,0 +1,108 @@
+# Exports the GTCRN speech-enhancement model to ONNX, attaches metadata,
+# runs a smoke test, and publishes the result to HuggingFace + GitHub releases.
+name: export-gtcrn-to-onnx
+
+on:
+  push:
+    branches:
+      - export-gtcrn
+
+  workflow_dispatch:
+
+concurrency:
+  group: export-gtcrn-to-onnx-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  export-gtcrn-to-onnx:
+    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
+    # Fix: the original referenced the undefined key ``matrix.version``,
+    # which interpolates to an empty string; use the python-version entry
+    # that the matrix actually defines.
+    name: export gtcrn ${{ matrix.python-version }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        # Fix: ``matrix.python-version`` was referenced by the setup-python
+        # step below but never declared, so it expanded to an empty string.
+        python-version: ["3.10"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Python dependencies
+        shell: bash
+        run: |
+          pip install "numpy<=1.26.4" onnx==1.16.0 onnxruntime==1.17.1 librosa soundfile torch==2.6.0+cpu -f https://download.pytorch.org/whl/torch "kaldi-native-fbank>=1.21.1"
+
+      - name: Run
+        shell: bash
+        run: |
+          cd scripts/gtcrn
+          ./run.sh
+          ./test.py
+          ls -lh
+
+      - name: Collect results
+        shell: bash
+        run: |
+          src=scripts/gtcrn
+          cp -v $src/*.onnx ./
+          ls -lh *.onnx
+
+      - name: Publish to huggingface 0.19
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 20
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            rm -rf huggingface
+            export GIT_LFS_SKIP_SMUDGE=1
+            export GIT_CLONE_PROTECTION_ACTIVE=false
+
+            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/speech-enhancement-models huggingface
+            cd huggingface
+            git fetch
+            git pull
+
+            cp -v ../gtcrn_simple.onnx ./
+
+            git lfs track "*.onnx"
+            git add .
+
+            ls -lh
+
+            git status
+
+            git commit -m "add models"
+            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/speech-enhancement-models main || true
+
+      - name: Release
+        if: github.repository_owner == 'csukuangfj'
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.onnx
+          overwrite: true
+          repo_name: k2-fsa/sherpa-onnx
+          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+          tag: speech-enhancement-models
+
+      - name: Release
+        if: github.repository_owner == 'k2-fsa'
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.onnx
+          overwrite: true
+          tag: speech-enhancement-models
diff --git a/scripts/gtcrn/README.md b/scripts/gtcrn/README.md
new file mode 100644
index 00000000..44c90376
--- /dev/null
+++ b/scripts/gtcrn/README.md
@@ -0,0 +1,4 @@
+# Introduction
+
+This folder contains scripts for adding metadata to models from
+https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx
diff --git a/scripts/gtcrn/add_meta_data.py b/scripts/gtcrn/add_meta_data.py
new file mode 100755
index 00000000..c367d6e6
--- /dev/null
+++ b/scripts/gtcrn/add_meta_data.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp.
(authors: Fangjun Kuang) + +""" +NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2]) +NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33]) +NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16]) +NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16]) +----- +NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2]) +NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33]) +NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16]) +NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16]) +""" + +import onnx +import onnxruntime as ort + + +def show(filename): + session_opts = ort.SessionOptions() + session_opts.log_severity_level = 3 + sess = ort.InferenceSession(filename, session_opts) + for i in sess.get_inputs(): + print(i) + + print("-----") + + for i in sess.get_outputs(): + print(i) + + +def main(): + filename = "./gtcrn_simple.onnx" + show(filename) + model = onnx.load(filename) + + meta_data = { + "model_type": "gtcrn", + "comment": "gtcrn_simple", + "version": 1, + "sample_rate": 16000, + "model_url": "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx", + "maintainer": "k2-fsa", + "comment2": "Please see also https://github.com/Xiaobin-Rong/gtcrn", + "conv_cache_shape": "2,1,16,16,33", + "tra_cache_shape": "2,3,1,1,16", + "inter_cache_shape": "2,1,33,16", + "n_fft": 512, + "hop_length": 256, + "window_length": 512, + "window_type": "hann_sqrt", + } + + print(model.metadata_props) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + print("--------------------") + + print(model.metadata_props) + + onnx.save(model, filename) + + +if __name__ == "__main__": + main() diff --git a/scripts/gtcrn/run.sh b/scripts/gtcrn/run.sh new file mode 100755 index 
00000000..0e7b8d0e
--- /dev/null
+++ b/scripts/gtcrn/run.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+#
+
+# Download the streaming GTCRN model unless it is already present.
+if [ ! -f gtcrn_simple.onnx ]; then
+  wget https://github.com/Xiaobin-Rong/gtcrn/raw/refs/heads/main/stream/onnx_models/gtcrn_simple.onnx
+fi
+
+# Download a 16 kHz test wave used by test.py.
+if [ ! -f ./inp_16k.wav ]; then
+  wget https://github.com/yuyun2000/SpeechDenoiser/raw/refs/heads/main/16k/inp_16k.wav
+fi
+
+python3 ./add_meta_data.py
diff --git a/scripts/gtcrn/test.py b/scripts/gtcrn/test.py
new file mode 100755
index 00000000..350dc4a2
--- /dev/null
+++ b/scripts/gtcrn/test.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
+
+from typing import Tuple
+
+import kaldi_native_fbank as knf
+import numpy as np
+import onnxruntime as ort
+import soundfile as sf
+import torch
+
+
+def load_audio(filename: str) -> Tuple[np.ndarray, int]:
+    """Read a wave file; return (mono float32 samples, sample rate)."""
+    data, sample_rate = sf.read(
+        filename,
+        always_2d=True,
+        dtype="float32",
+    )
+    data = data[:, 0]  # use only the first channel
+    samples = np.ascontiguousarray(data)
+    return samples, sample_rate
+
+
+class OnnxModel:
+    # Streaming wrapper around gtcrn_simple.onnx (one STFT frame per call).
+    def __init__(self):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+        self.model = ort.InferenceSession(
+            "./gtcrn_simple.onnx",
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        # STFT parameters were written into the model by add_meta_data.py.
+        meta = self.model.get_modelmeta().custom_metadata_map
+        self.sample_rate = int(meta["sample_rate"])
+        self.n_fft = int(meta["n_fft"])
+        self.hop_length = int(meta["hop_length"])
+        self.window_length = int(meta["window_length"])
+        assert meta["window_type"] == "hann_sqrt", meta["window_type"]
+
+        # Square-root Hann window, matching window_type == "hann_sqrt".
+        self.window = torch.hann_window(self.window_length).pow(0.5)
+
+    def get_init_states(self):
+        # Cache shapes are stored as comma-separated strings in the metadata,
+        # e.g. "2,1,16,16,33".
+        meta = self.model.get_modelmeta().custom_metadata_map
+        conv_cache_shape = list(map(int, meta["conv_cache_shape"].split(",")))
+        tra_cache_shape = list(map(int, meta["tra_cache_shape"].split(",")))
+
inter_cache_shape = list(map(int, meta["inter_cache_shape"].split(","))) + + conv_cache_shape = np.zeros(conv_cache_shape, dtype=np.float32) + tra_cache = np.zeros(tra_cache_shape, dtype=np.float32) + inter_cache = np.zeros(inter_cache_shape, dtype=np.float32) + + return conv_cache_shape, tra_cache, inter_cache + + def __call__(self, x, states): + """ + Args: + x: (1, n_fft/2+1, 1, 2) + Returns: + o: (1, n_fft/2+1, 1, 2) + """ + out, next_conv_cache, next_tra_cache, next_inter_cache = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + self.model.get_outputs()[2].name, + self.model.get_outputs()[3].name, + ], + { + self.model.get_inputs()[0].name: x, + self.model.get_inputs()[1].name: states[0], + self.model.get_inputs()[2].name: states[1], + self.model.get_inputs()[3].name: states[2], + }, + ) + + return out, (next_conv_cache, next_tra_cache, next_inter_cache) + + +def main(): + model = OnnxModel() + + filename = "./inp_16k.wav" + wave, sample_rate = load_audio(filename) + if sample_rate != model.sample_rate: + import librosa + + wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=model.sample_rate) + sample_rate = model.sample_rate + + stft_config = knf.StftConfig( + n_fft=model.n_fft, + hop_length=model.hop_length, + win_length=model.window_length, + window=model.window.tolist(), + ) + stft = knf.Stft(stft_config) + stft_result = stft(wave) + num_frames = stft_result.num_frames + real = np.array(stft_result.real, dtype=np.float32).reshape(num_frames, -1) + imag = np.array(stft_result.imag, dtype=np.float32).reshape(num_frames, -1) + + states = model.get_init_states() + outputs = [] + for i in range(num_frames): + x_real = real[i : i + 1] + x_imag = imag[i : i + 1] + x = np.vstack([x_real, x_imag]).transpose() + x = np.expand_dims(x, axis=0) + x = np.expand_dims(x, axis=2) + + o, states = model(x, states) + outputs.append(o) + + outputs = np.concatenate(outputs, axis=2) + outputs = outputs.squeeze(0).transpose(1, 
0, 2)
+
+    # Split the enhanced spectrum into real/imag planes, each (num_frames, n_bins).
+    enhanced_real = outputs[:, :, 0]
+    enhanced_imag = outputs[:, :, 1]
+    enhanced_stft_result = knf.StftResult(
+        real=enhanced_real.reshape(-1).tolist(),
+        imag=enhanced_imag.reshape(-1).tolist(),
+        num_frames=enhanced_real.shape[0],
+    )
+
+    # Inverse STFT with the same config reconstructs the time-domain signal.
+    istft = knf.IStft(stft_config)
+    enhanced = istft(enhanced_stft_result)
+
+    sf.write("./enhanced_16k.wav", enhanced, model.sample_rate)
+
+
+if __name__ == "__main__":
+    main()