Support whisper models (#238)
This commit is contained in:
63
.github/workflows/export-whisper-to-onnx.yaml
vendored
Normal file
63
.github/workflows/export-whisper-to-onnx.yaml
vendored
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
name: export-whisper-to-onnx
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: release-whisper-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release-whisper-models:
|
||||||
|
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
|
||||||
|
name: ${{ matrix.model }}
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [macos-latest]
|
||||||
|
model: ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 -m pip install openai-whisper torch onnxruntime onnx
|
||||||
|
|
||||||
|
- name: export ${{ matrix.model }}
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd scripts/whisper
|
||||||
|
python3 ./export-onnx.py --model ${{ matrix.model }}
|
||||||
|
python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
|
||||||
|
|
||||||
|
ls -lh
|
||||||
|
|
||||||
|
ls -lh ~/.cache/whisper
|
||||||
|
|
||||||
|
- name: Publish ${{ matrix.model }} to huggingface
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
run: |
|
||||||
|
cd scripts/whisper
|
||||||
|
|
||||||
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
|
||||||
|
|
||||||
|
cp *.onnx ./huggingface
|
||||||
|
cp *.ort ./huggingface
|
||||||
|
cp *tokens.txt ./huggingface
|
||||||
|
|
||||||
|
cd huggingface
|
||||||
|
git status
|
||||||
|
ls -lh
|
||||||
|
git lfs track "*.onnx"
|
||||||
|
git lfs track "*.ort"
|
||||||
|
git add .
|
||||||
|
git commit -m "upload ${{ matrix.model }}"
|
||||||
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
|
||||||
4
.github/workflows/run-java-test.yaml
vendored
4
.github/workflows/run-java-test.yaml
vendored
@@ -23,14 +23,14 @@ on:
|
|||||||
- 'sherpa-onnx/jni/*'
|
- 'sherpa-onnx/jni/*'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: jni-${{ github.ref }}
|
group: run-java-test-${{ github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
jni:
|
run_java_test:
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.5.5")
|
set(SHERPA_ONNX_VERSION "1.6.0")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
function(download_kaldi_native_fbank)
|
function(download_kaldi_native_fbank)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.17.tar.gz")
|
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")
|
||||||
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.17.tar.gz")
|
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")
|
||||||
set(kaldi_native_fbank_HASH "SHA256=300dc282d51d738e70f194ef13a50bf4cf8d54a3b2686d75f7fc2fb821f8c1e6")
|
set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7")
|
||||||
|
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
|
|||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download kaldi-native-fbank
|
# please pre-download kaldi-native-fbank
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.17.tar.gz
|
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz
|
||||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.17.tar.gz
|
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz
|
||||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.17.tar.gz
|
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz
|
||||||
/tmp/kaldi-native-fbank-1.17.tar.gz
|
/tmp/kaldi-native-fbank-1.18.1.tar.gz
|
||||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.17.tar.gz
|
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright (c) 2023 by manyeyes
|
# Copyright (c) 2023 by manyeyes
|
||||||
|
# Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||||
@@ -34,6 +35,27 @@ file(s) with a non-streaming model.
|
|||||||
|
|
||||||
(3) For CTC models from NeMo
|
(3) For CTC models from NeMo
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
|
||||||
|
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(4) For Whisper models
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||||
|
--num-threads=1 \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||||
|
|
||||||
Please refer to
|
Please refer to
|
||||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||||
to install sherpa-onnx and to download the pre-trained models
|
to install sherpa-onnx and to download the pre-trained models
|
||||||
@@ -144,6 +166,20 @@ def get_args():
|
|||||||
help="Number of threads for neural network computation",
|
help="Number of threads for neural network computation",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoding-method",
|
"--decoding-method",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -247,6 +283,8 @@ def main():
|
|||||||
if args.encoder:
|
if args.encoder:
|
||||||
assert len(args.paraformer) == 0, args.paraformer
|
assert len(args.paraformer) == 0, args.paraformer
|
||||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
|
||||||
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
|
||||||
if contexts:
|
if contexts:
|
||||||
@@ -271,6 +309,9 @@ def main():
|
|||||||
)
|
)
|
||||||
elif args.paraformer:
|
elif args.paraformer:
|
||||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
|
||||||
assert_file_exists(args.paraformer)
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||||
@@ -283,6 +324,11 @@ def main():
|
|||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
elif args.nemo_ctc:
|
elif args.nemo_ctc:
|
||||||
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
|
||||||
|
assert_file_exists(args.nemo_ctc)
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||||
model=args.nemo_ctc,
|
model=args.nemo_ctc,
|
||||||
tokens=args.tokens,
|
tokens=args.tokens,
|
||||||
@@ -292,6 +338,18 @@ def main():
|
|||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
|
elif args.whisper_encoder:
|
||||||
|
assert_file_exists(args.whisper_encoder)
|
||||||
|
assert_file_exists(args.whisper_decoder)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
|
||||||
|
encoder=args.whisper_encoder,
|
||||||
|
decoder=args.whisper_decoder,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("Please specify at least one model")
|
print("Please specify at least one model")
|
||||||
return
|
return
|
||||||
|
|||||||
4
scripts/whisper/.gitignore
vendored
Normal file
4
scripts/whisper/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
*.onnx
|
||||||
|
*.config
|
||||||
|
*.ort
|
||||||
|
*-tokens.txt
|
||||||
9
scripts/whisper/README.md
Normal file
9
scripts/whisper/README.md
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# Introduction
|
||||||
|
|
||||||
|
This folder contains code showing how to convert [Whisper][whisper] to onnx
|
||||||
|
and use onnxruntime to replace PyTorch for speech recognition.
|
||||||
|
|
||||||
|
You can use [sherpa-onnx][sherpa-onnx] to run the converted model.
|
||||||
|
|
||||||
|
[whisper]: https://github.com/openai/whisper
|
||||||
|
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
|
||||||
439
scripts/whisper/export-onnx.py
Executable file
439
scripts/whisper/export-onnx.py
Executable file
@@ -0,0 +1,439 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
"""
|
||||||
|
Note: Code in this file is modified from
|
||||||
|
https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
|
||||||
|
|
||||||
|
Thanks to https://github.com/TadaoYamaoka
|
||||||
|
for making the onnx export script public.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
import whisper
|
||||||
|
from whisper.model import (
|
||||||
|
AudioEncoder,
|
||||||
|
MultiHeadAttention,
|
||||||
|
ResidualAttentionBlock,
|
||||||
|
TextDecoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
# fmt: off
|
||||||
|
choices=[
|
||||||
|
"tiny", "tiny.en", "base", "base.en",
|
||||||
|
"small", "small.en", "medium", "medium.en",
|
||||||
|
"large", "large-v1", "large-v2"],
|
||||||
|
# fmt: on
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = str(value)
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioEncoderTensorCache(nn.Module):
|
||||||
|
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
|
||||||
|
super().__init__()
|
||||||
|
self.audioEncoder = inAudioEncoder
|
||||||
|
self.textDecoder = inTextDecoder
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
audio_features = self.audioEncoder(x)
|
||||||
|
|
||||||
|
n_layer_cross_k_list = []
|
||||||
|
n_layer_cross_v_list = []
|
||||||
|
for block in self.textDecoder.blocks:
|
||||||
|
n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
|
||||||
|
n_layer_cross_v_list.append(block.cross_attn.value(audio_features))
|
||||||
|
|
||||||
|
return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttentionCross(nn.Module):
|
||||||
|
def __init__(self, inMultiHeadAttention: MultiHeadAttention):
|
||||||
|
super().__init__()
|
||||||
|
self.multiHeadAttention = inMultiHeadAttention
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
k: Tensor,
|
||||||
|
v: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
q = self.multiHeadAttention.query(x)
|
||||||
|
wv, qk = self.multiHeadAttention.qkv_attention(q, k, v, mask)
|
||||||
|
return self.multiHeadAttention.out(wv)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttentionSelf(nn.Module):
|
||||||
|
def __init__(self, inMultiHeadAttention: MultiHeadAttention):
|
||||||
|
super().__init__()
|
||||||
|
self.multiHeadAttention = inMultiHeadAttention
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor, # (b, n_ctx , n_state)
|
||||||
|
k_cache: Tensor, # (b, n_ctx_cache, n_state)
|
||||||
|
v_cache: Tensor, # (b, n_ctx_cache, n_state)
|
||||||
|
mask: Tensor,
|
||||||
|
):
|
||||||
|
q = self.multiHeadAttention.query(x) # (b, n_ctx, n_state)
|
||||||
|
k = self.multiHeadAttention.key(x) # (b, n_ctx, n_state)
|
||||||
|
v = self.multiHeadAttention.value(x) # (b, n_ctx, n_state)
|
||||||
|
|
||||||
|
k_cache[:, -k.shape[1] :, :] = k # (b, n_ctx_cache + n_ctx, n_state)
|
||||||
|
v_cache[:, -v.shape[1] :, :] = v # (b, n_ctx_cache + n_ctx, n_state)
|
||||||
|
|
||||||
|
wv, qk = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask)
|
||||||
|
return self.multiHeadAttention.out(wv), k_cache, v_cache
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAttentionBlockTensorCache(nn.Module):
|
||||||
|
def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
|
||||||
|
super().__init__()
|
||||||
|
self.originalBlock = inResidualAttentionBlock
|
||||||
|
self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
|
||||||
|
self.cross_attn = (
|
||||||
|
MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn)
|
||||||
|
if inResidualAttentionBlock.cross_attn
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
self_k_cache: Tensor,
|
||||||
|
self_v_cache: Tensor,
|
||||||
|
cross_k: Tensor,
|
||||||
|
cross_v: Tensor,
|
||||||
|
mask: Tensor,
|
||||||
|
):
|
||||||
|
self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(
|
||||||
|
self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask
|
||||||
|
)
|
||||||
|
x = x + self_attn_x
|
||||||
|
|
||||||
|
if self.cross_attn:
|
||||||
|
x = x + self.cross_attn(
|
||||||
|
self.originalBlock.cross_attn_ln(x), cross_k, cross_v
|
||||||
|
)
|
||||||
|
|
||||||
|
x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
|
||||||
|
return x, self_k_cache_updated, self_v_cache_updated
|
||||||
|
|
||||||
|
|
||||||
|
class TextDecoderTensorCache(nn.Module):
|
||||||
|
def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.textDecoder = inTextDecoder
|
||||||
|
self.n_ctx = in_n_ctx
|
||||||
|
|
||||||
|
self.blocks = []
|
||||||
|
for orginal_block in self.textDecoder.blocks:
|
||||||
|
self.blocks.append(ResidualAttentionBlockTensorCache(orginal_block))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: Tensor,
|
||||||
|
n_layer_self_k_cache: Tensor,
|
||||||
|
n_layer_self_v_cache: Tensor,
|
||||||
|
n_layer_cross_k: Tensor,
|
||||||
|
n_layer_cross_v: Tensor,
|
||||||
|
offset: Tensor,
|
||||||
|
):
|
||||||
|
x = (
|
||||||
|
self.textDecoder.token_embedding(tokens)
|
||||||
|
+ self.textDecoder.positional_embedding[
|
||||||
|
offset[0] : offset[0] + tokens.shape[-1]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
x = x.to(n_layer_cross_k[0].dtype)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
for block in self.blocks:
|
||||||
|
self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
|
||||||
|
self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
|
||||||
|
x, self_k_cache, self_v_cache = block(
|
||||||
|
x,
|
||||||
|
self_k_cache=self_k_cache,
|
||||||
|
self_v_cache=self_v_cache,
|
||||||
|
cross_k=n_layer_cross_k[i],
|
||||||
|
cross_v=n_layer_cross_v[i],
|
||||||
|
mask=self.textDecoder.mask,
|
||||||
|
)
|
||||||
|
n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
|
||||||
|
n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
x = self.textDecoder.ln(x)
|
||||||
|
|
||||||
|
logits = (
|
||||||
|
x
|
||||||
|
@ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
|
||||||
|
).float()
|
||||||
|
|
||||||
|
return logits, n_layer_self_k_cache, n_layer_self_v_cache
|
||||||
|
|
||||||
|
|
||||||
|
# ref: https://github.com/ggerganov/whisper.cpp/blob/master/models/convert-pt-to-ggml.py#L232
|
||||||
|
def convert_tokens(name, model):
|
||||||
|
whisper_dir = Path(whisper.__file__).parent
|
||||||
|
multilingual = model.is_multilingual
|
||||||
|
tokenizer = (
|
||||||
|
whisper_dir
|
||||||
|
/ "assets"
|
||||||
|
/ (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
|
||||||
|
)
|
||||||
|
if not tokenizer.is_file():
|
||||||
|
raise ValueError(f"Cannot find {tokenizer}")
|
||||||
|
|
||||||
|
# import base64
|
||||||
|
|
||||||
|
with open(tokenizer, "r") as f:
|
||||||
|
contents = f.read()
|
||||||
|
# tokens = {
|
||||||
|
# base64.b64decode(token): int(rank)
|
||||||
|
# for token, rank in (line.split() for line in contents.splitlines() if line)
|
||||||
|
# }
|
||||||
|
tokens = {
|
||||||
|
token: int(rank)
|
||||||
|
for token, rank in (line.split() for line in contents.splitlines() if line)
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(f"{name}-tokens.txt", "w") as f:
|
||||||
|
for t, i in tokens.items():
|
||||||
|
f.write(f"{t} {i}\n")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
name = args.model
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
model = whisper.load_model(name)
|
||||||
|
convert_tokens(name=name, model=model)
|
||||||
|
|
||||||
|
# write tokens
|
||||||
|
|
||||||
|
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
||||||
|
model.eval()
|
||||||
|
print(model.dims)
|
||||||
|
audio = torch.rand(16000 * 2)
|
||||||
|
audio = whisper.pad_or_trim(audio)
|
||||||
|
assert audio.shape == (16000 * 30,), audio.shape
|
||||||
|
|
||||||
|
# make log-Mel spectrogram and move to the same device as the model
|
||||||
|
mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
|
||||||
|
batch_size = 1
|
||||||
|
assert mel.shape == (batch_size, 80, 30 * 100)
|
||||||
|
|
||||||
|
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
|
||||||
|
n_layer_cross_k, n_layer_cross_v = encoder(mel)
|
||||||
|
assert n_layer_cross_k.shape == (
|
||||||
|
model.dims.n_text_layer,
|
||||||
|
batch_size,
|
||||||
|
model.dims.n_audio_ctx,
|
||||||
|
model.dims.n_text_state,
|
||||||
|
), n_layer_cross_k.shape
|
||||||
|
assert n_layer_cross_v.shape == (
|
||||||
|
model.dims.n_text_layer,
|
||||||
|
batch_size,
|
||||||
|
model.dims.n_audio_ctx,
|
||||||
|
model.dims.n_text_state,
|
||||||
|
), n_layer_cross_v.shape
|
||||||
|
|
||||||
|
encoder_filename = f"{name}-encoder.onnx"
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder,
|
||||||
|
mel,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["mel"],
|
||||||
|
output_names=["n_layer_cross_k", "n_layer_cross_v"],
|
||||||
|
dynamic_axes={
|
||||||
|
"mel": {0: "n_audio"}, # n_audio is also known as batch_size
|
||||||
|
"n_layer_cross_k": {1: "n_audio"},
|
||||||
|
"n_layer_cross_v": {1: "n_audio"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_meta_data = {
|
||||||
|
"model_type": f"whisper-{name}",
|
||||||
|
"version": "1",
|
||||||
|
"maintainer": "k2-fsa",
|
||||||
|
"n_mels": model.dims.n_mels,
|
||||||
|
"n_audio_ctx": model.dims.n_audio_ctx,
|
||||||
|
"n_audio_state": model.dims.n_audio_state,
|
||||||
|
"n_audio_head": model.dims.n_audio_head,
|
||||||
|
"n_audio_layer": model.dims.n_audio_layer,
|
||||||
|
"n_vocab": model.dims.n_vocab,
|
||||||
|
"n_text_ctx": model.dims.n_text_ctx,
|
||||||
|
"n_text_state": model.dims.n_text_state,
|
||||||
|
"n_text_head": model.dims.n_text_head,
|
||||||
|
"n_text_layer": model.dims.n_text_layer,
|
||||||
|
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
|
||||||
|
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
|
||||||
|
"all_language_codes": ",".join(tokenizer.all_language_codes),
|
||||||
|
"sot": tokenizer.sot,
|
||||||
|
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
|
||||||
|
"eot": tokenizer.eot,
|
||||||
|
"blank_id": tokenizer.encode(" ")[0],
|
||||||
|
"is_multilingual": int(model.is_multilingual),
|
||||||
|
"no_speech": tokenizer.no_speech,
|
||||||
|
"non_speech_tokens": ",".join(list(map(str, tokenizer.non_speech_tokens))),
|
||||||
|
"transcribe": tokenizer.transcribe,
|
||||||
|
"translate": tokenizer.translate,
|
||||||
|
"sot_prev": tokenizer.sot_prev,
|
||||||
|
"sot_lm": tokenizer.sot_lm,
|
||||||
|
"no_timestamps": tokenizer.no_timestamps,
|
||||||
|
}
|
||||||
|
print(f"encoder_meta_data: {encoder_meta_data}")
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)
|
||||||
|
|
||||||
|
n_audio = mel.shape[0]
|
||||||
|
tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to(
|
||||||
|
mel.device
|
||||||
|
) # [n_audio, 3]
|
||||||
|
decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)
|
||||||
|
n_layer_self_k_cache = torch.zeros(
|
||||||
|
(
|
||||||
|
len(model.decoder.blocks),
|
||||||
|
n_audio,
|
||||||
|
model.dims.n_text_ctx,
|
||||||
|
model.dims.n_text_state,
|
||||||
|
),
|
||||||
|
device=mel.device,
|
||||||
|
)
|
||||||
|
n_layer_self_v_cache = torch.zeros(
|
||||||
|
(
|
||||||
|
len(model.decoder.blocks),
|
||||||
|
n_audio,
|
||||||
|
model.dims.n_text_ctx,
|
||||||
|
model.dims.n_text_state,
|
||||||
|
),
|
||||||
|
device=mel.device,
|
||||||
|
)
|
||||||
|
offset = torch.zeros(1, dtype=torch.int64).to(mel.device)
|
||||||
|
logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
|
||||||
|
tokens,
|
||||||
|
n_layer_self_k_cache,
|
||||||
|
n_layer_self_v_cache,
|
||||||
|
n_layer_cross_k,
|
||||||
|
n_layer_cross_v,
|
||||||
|
offset,
|
||||||
|
)
|
||||||
|
assert logits.shape == (n_audio, tokens.shape[1], model.dims.n_vocab)
|
||||||
|
assert n_layer_self_k_cache.shape == (
|
||||||
|
model.dims.n_text_layer,
|
||||||
|
n_audio,
|
||||||
|
model.dims.n_text_ctx,
|
||||||
|
model.dims.n_text_state,
|
||||||
|
)
|
||||||
|
assert n_layer_self_v_cache.shape == (
|
||||||
|
model.dims.n_text_layer,
|
||||||
|
n_audio,
|
||||||
|
model.dims.n_text_ctx,
|
||||||
|
model.dims.n_text_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device)
|
||||||
|
tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||||
|
|
||||||
|
logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
|
||||||
|
tokens,
|
||||||
|
n_layer_self_k_cache,
|
||||||
|
n_layer_self_v_cache,
|
||||||
|
n_layer_cross_k,
|
||||||
|
n_layer_cross_v,
|
||||||
|
offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename = f"{name}-decoder.onnx"
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder,
|
||||||
|
(
|
||||||
|
tokens,
|
||||||
|
n_layer_self_k_cache,
|
||||||
|
n_layer_self_v_cache,
|
||||||
|
n_layer_cross_k,
|
||||||
|
n_layer_cross_v,
|
||||||
|
offset,
|
||||||
|
),
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"tokens",
|
||||||
|
"in_n_layer_self_k_cache",
|
||||||
|
"in_n_layer_self_v_cache",
|
||||||
|
"n_layer_cross_k",
|
||||||
|
"n_layer_cross_v",
|
||||||
|
"offset",
|
||||||
|
],
|
||||||
|
output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
|
||||||
|
dynamic_axes={
|
||||||
|
"tokens": {0: "n_audio", 1: "n_tokens"},
|
||||||
|
"in_n_layer_self_k_cache": {1: "n_audio"},
|
||||||
|
"in_n_layer_self_v_cache": {1: "n_audio"},
|
||||||
|
"n_layer_cross_k": {1: "n_audio"},
|
||||||
|
"n_layer_cross_v": {1: "n_audio"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
print("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = f"{name}-encoder.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = f"{name}-decoder.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1
scripts/whisper/requirements.txt
Normal file
1
scripts/whisper/requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
openai-whisper
|
||||||
241
scripts/whisper/test.py
Executable file
241
scripts/whisper/test.py
Executable file
@@ -0,0 +1,241 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
"""
|
||||||
|
Please first run ./export-onnx.py
|
||||||
|
before you run this script
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import kaldi_native_fbank as knf
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import whisper
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
# fmt: off
|
||||||
|
choices=[
|
||||||
|
"tiny", "tiny.en", "base", "base.en",
|
||||||
|
"small", "small.en", "medium", "medium.en",
|
||||||
|
"large", "large-v1", "large-v2"],
|
||||||
|
# fmt: on
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder: str,
|
||||||
|
decoder: str,
|
||||||
|
):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 4
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
|
||||||
|
self.init_encoder(encoder)
|
||||||
|
self.init_decoder(decoder)
|
||||||
|
|
||||||
|
def init_encoder(self, encoder: str):
|
||||||
|
self.encoder = ort.InferenceSession(
|
||||||
|
encoder,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||||
|
self.n_text_layer = int(meta["n_text_layer"])
|
||||||
|
self.n_text_ctx = int(meta["n_text_ctx"])
|
||||||
|
self.n_text_state = int(meta["n_text_state"])
|
||||||
|
self.sot = int(meta["sot"])
|
||||||
|
self.eot = int(meta["eot"])
|
||||||
|
self.translate = int(meta["translate"])
|
||||||
|
self.no_timestamps = int(meta["no_timestamps"])
|
||||||
|
self.no_speech = int(meta["no_speech"])
|
||||||
|
self.blank = int(meta["blank_id"])
|
||||||
|
|
||||||
|
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
||||||
|
|
||||||
|
self.is_multilingual = int(meta["is_multilingual"]) == 1
|
||||||
|
|
||||||
|
def init_decoder(self, decoder: str):
|
||||||
|
self.decoder = ort.InferenceSession(
|
||||||
|
decoder,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_encoder(
|
||||||
|
self,
|
||||||
|
mel: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
n_layer_cross_k, n_layer_cross_v = self.encoder.run(
|
||||||
|
[
|
||||||
|
self.encoder.get_outputs()[0].name,
|
||||||
|
self.encoder.get_outputs()[1].name,
|
||||||
|
],
|
||||||
|
{
|
||||||
|
self.encoder.get_inputs()[0].name: mel.numpy(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v)
|
||||||
|
|
||||||
|
def run_decoder(
|
||||||
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
n_layer_self_k_cache: torch.Tensor,
|
||||||
|
n_layer_self_v_cache: torch.Tensor,
|
||||||
|
n_layer_cross_k: torch.Tensor,
|
||||||
|
n_layer_cross_v: torch.Tensor,
|
||||||
|
offset: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
|
||||||
|
[
|
||||||
|
self.decoder.get_outputs()[0].name,
|
||||||
|
self.decoder.get_outputs()[1].name,
|
||||||
|
self.decoder.get_outputs()[2].name,
|
||||||
|
],
|
||||||
|
{
|
||||||
|
self.decoder.get_inputs()[0].name: tokens.numpy(),
|
||||||
|
self.decoder.get_inputs()[1].name: n_layer_self_k_cache.numpy(),
|
||||||
|
self.decoder.get_inputs()[2].name: n_layer_self_v_cache.numpy(),
|
||||||
|
self.decoder.get_inputs()[3].name: n_layer_cross_k.numpy(),
|
||||||
|
self.decoder.get_inputs()[4].name: n_layer_cross_v.numpy(),
|
||||||
|
self.decoder.get_inputs()[5].name: offset.numpy(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
torch.from_numpy(logits),
|
||||||
|
torch.from_numpy(out_n_layer_self_k_cache),
|
||||||
|
torch.from_numpy(out_n_layer_self_v_cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
batch_size = 1
|
||||||
|
n_layer_self_k_cache = torch.zeros(
|
||||||
|
self.n_text_layer,
|
||||||
|
batch_size,
|
||||||
|
self.n_text_ctx,
|
||||||
|
self.n_text_state,
|
||||||
|
)
|
||||||
|
n_layer_self_v_cache = torch.zeros(
|
||||||
|
self.n_text_layer,
|
||||||
|
batch_size,
|
||||||
|
self.n_text_ctx,
|
||||||
|
self.n_text_state,
|
||||||
|
)
|
||||||
|
return n_layer_self_k_cache, n_layer_self_v_cache
|
||||||
|
|
||||||
|
def suppress_tokens(self, logits, is_initial: bool) -> None:
|
||||||
|
# suppress blank
|
||||||
|
if is_initial:
|
||||||
|
logits[self.eot] = float("-inf")
|
||||||
|
logits[self.blank] = float("-inf")
|
||||||
|
|
||||||
|
# suppress <|notimestamps|>
|
||||||
|
logits[self.no_timestamps] = float("-inf")
|
||||||
|
|
||||||
|
logits[self.sot] = float("-inf")
|
||||||
|
logits[self.no_speech] = float("-inf")
|
||||||
|
|
||||||
|
# logits is changed in-place
|
||||||
|
logits[self.translate] = float("-inf")
|
||||||
|
|
||||||
|
|
||||||
|
def load_tokens(filename):
|
||||||
|
tokens = dict()
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
t, i = line.split()
|
||||||
|
tokens[int(i)] = t
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
name = args.model
|
||||||
|
|
||||||
|
encoder = f"./{name}-encoder.onnx"
|
||||||
|
decoder = f"./{name}-decoder.onnx"
|
||||||
|
audio = whisper.load_audio("0.wav")
|
||||||
|
|
||||||
|
features = []
|
||||||
|
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
||||||
|
online_whisper_fbank.accept_waveform(16000, audio)
|
||||||
|
online_whisper_fbank.input_finished()
|
||||||
|
for i in range(online_whisper_fbank.num_frames_ready):
|
||||||
|
f = online_whisper_fbank.get_frame(i)
|
||||||
|
f = torch.from_numpy(f)
|
||||||
|
features.append(f)
|
||||||
|
|
||||||
|
features = torch.stack(features)
|
||||||
|
|
||||||
|
log_spec = torch.clamp(features, min=1e-10).log10()
|
||||||
|
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||||
|
mel = (log_spec + 4.0) / 4.0
|
||||||
|
target = 3000
|
||||||
|
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||||
|
mel = mel.t().unsqueeze(0)
|
||||||
|
|
||||||
|
model = OnnxModel(encoder, decoder)
|
||||||
|
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||||
|
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
||||||
|
|
||||||
|
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
||||||
|
offset = torch.zeros(1, dtype=torch.int64)
|
||||||
|
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||||
|
tokens=tokens,
|
||||||
|
n_layer_self_k_cache=n_layer_self_k_cache,
|
||||||
|
n_layer_self_v_cache=n_layer_self_v_cache,
|
||||||
|
n_layer_cross_k=n_layer_cross_k,
|
||||||
|
n_layer_cross_v=n_layer_cross_v,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
# logits.shape (batch_size, tokens.shape[1], vocab_size)
|
||||||
|
logits = logits[0, -1]
|
||||||
|
model.suppress_tokens(logits, is_initial=True)
|
||||||
|
# logits = logits.softmax(dim=-1)
|
||||||
|
# for greedy search, we don't need to compute softmax or log_softmax
|
||||||
|
max_token_id = logits.argmax(dim=-1)
|
||||||
|
results = []
|
||||||
|
for i in range(model.n_text_ctx):
|
||||||
|
if max_token_id == model.eot:
|
||||||
|
break
|
||||||
|
results.append(max_token_id.item())
|
||||||
|
tokens = torch.tensor([[results[-1]]])
|
||||||
|
offset += 1
|
||||||
|
|
||||||
|
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||||
|
tokens=tokens,
|
||||||
|
n_layer_self_k_cache=n_layer_self_k_cache,
|
||||||
|
n_layer_self_v_cache=n_layer_self_v_cache,
|
||||||
|
n_layer_cross_k=n_layer_cross_k,
|
||||||
|
n_layer_cross_v=n_layer_cross_v,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
logits = logits[0, -1]
|
||||||
|
model.suppress_tokens(logits, is_initial=False)
|
||||||
|
max_token_id = logits.argmax(dim=-1)
|
||||||
|
token_table = load_tokens(f"./{name}-tokens.txt")
|
||||||
|
s = b""
|
||||||
|
for i in results:
|
||||||
|
if i in token_table:
|
||||||
|
s += base64.b64decode(token_table[i])
|
||||||
|
else:
|
||||||
|
print("oov", i)
|
||||||
|
|
||||||
|
print(s.decode().strip())
|
||||||
|
print(results)
|
||||||
|
print(model.sot_sequence)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -11,6 +11,7 @@ if(SHERPA_ONNX_ENABLE_PYTHON)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(sources
|
set(sources
|
||||||
|
base64-decode.cc
|
||||||
cat.cc
|
cat.cc
|
||||||
context-graph.cc
|
context-graph.cc
|
||||||
endpoint.cc
|
endpoint.cc
|
||||||
@@ -35,6 +36,9 @@ set(sources
|
|||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
offline-transducer-model.cc
|
offline-transducer-model.cc
|
||||||
offline-transducer-modified-beam-search-decoder.cc
|
offline-transducer-modified-beam-search-decoder.cc
|
||||||
|
offline-whisper-greedy-search-decoder.cc
|
||||||
|
offline-whisper-model-config.cc
|
||||||
|
offline-whisper-model.cc
|
||||||
online-conformer-transducer-model.cc
|
online-conformer-transducer-model.cc
|
||||||
online-lm-config.cc
|
online-lm-config.cc
|
||||||
online-lm.cc
|
online-lm.cc
|
||||||
@@ -50,12 +54,12 @@ set(sources
|
|||||||
online-zipformer-transducer-model.cc
|
online-zipformer-transducer-model.cc
|
||||||
online-zipformer2-transducer-model.cc
|
online-zipformer2-transducer-model.cc
|
||||||
onnx-utils.cc
|
onnx-utils.cc
|
||||||
session.cc
|
|
||||||
packed-sequence.cc
|
packed-sequence.cc
|
||||||
pad-sequence.cc
|
pad-sequence.cc
|
||||||
parse-options.cc
|
parse-options.cc
|
||||||
provider.cc
|
provider.cc
|
||||||
resample.cc
|
resample.cc
|
||||||
|
session.cc
|
||||||
slice.cc
|
slice.cc
|
||||||
stack.cc
|
stack.cc
|
||||||
symbol-table.cc
|
symbol-table.cc
|
||||||
|
|||||||
67
sherpa-onnx/csrc/base64-decode.cc
Normal file
67
sherpa-onnx/csrc/base64-decode.cc
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// sherpa-onnx/csrc/base64-decode.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/base64-decode.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static int32_t Ord(char c) {
|
||||||
|
if (c >= 'A' && c <= 'Z') {
|
||||||
|
return c - 'A';
|
||||||
|
} else if (c >= 'a' && c <= 'z') {
|
||||||
|
return c - 'a' + ('Z' - 'A') + 1;
|
||||||
|
} else if (c >= '0' && c <= '9') {
|
||||||
|
return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
|
||||||
|
} else if (c == '+') {
|
||||||
|
return 62;
|
||||||
|
} else if (c == '/') {
|
||||||
|
return 63;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c);
|
||||||
|
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// see
|
||||||
|
// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243
|
||||||
|
std::string Base64Decode(const std::string &s) {
|
||||||
|
if (s.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Empty string!");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t n = s.size() / 4 * 3;
|
||||||
|
|
||||||
|
std::string ans;
|
||||||
|
ans.reserve(n);
|
||||||
|
|
||||||
|
int32_t i = 0;
|
||||||
|
while (i < static_cast<int32_t>(s.size())) {
|
||||||
|
if (s[i] == '=') {
|
||||||
|
return " ";
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4);
|
||||||
|
ans.push_back(first);
|
||||||
|
|
||||||
|
if (i + 2 < static_cast<int32_t>(s.size()) && s[i + 2] != '=') {
|
||||||
|
int32_t second =
|
||||||
|
((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2);
|
||||||
|
ans.push_back(second);
|
||||||
|
|
||||||
|
if (i + 3 < static_cast<int32_t>(s.size()) && s[i + 3] != '=') {
|
||||||
|
int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]);
|
||||||
|
ans.push_back(third);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i += 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
19
sherpa-onnx/csrc/base64-decode.h
Normal file
19
sherpa-onnx/csrc/base64-decode.h
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
// sherpa-onnx/csrc/base64-decode.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/** @param s A base64 encoded string.
|
||||||
|
* @return Return the decoded string.
|
||||||
|
*/
|
||||||
|
std::string Base64Decode(const std::string &s);
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_BASE64_DECODE_H_
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
// sherpa-onnx/csrc/macros.h
|
// sherpa-onnx/csrc/macros.h
|
||||||
//
|
//
|
||||||
// Copyright 2023 Xiaomi Corporation
|
// Copyright 2023 Xiaomi Corporation
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
transducer.Register(po);
|
transducer.Register(po);
|
||||||
paraformer.Register(po);
|
paraformer.Register(po);
|
||||||
nemo_ctc.Register(po);
|
nemo_ctc.Register(po);
|
||||||
|
whisper.Register(po);
|
||||||
|
|
||||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
|||||||
|
|
||||||
po->Register("model-type", &model_type,
|
po->Register("model-type", &model_type,
|
||||||
"Specify it to reduce model initialization time. "
|
"Specify it to reduce model initialization time. "
|
||||||
"Valid values are: transducer, paraformer, nemo_ctc. "
|
"Valid values are: transducer, paraformer, nemo_ctc, whisper."
|
||||||
"All other values lead to loading the model twice.");
|
"All other values lead to loading the model twice.");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,6 +52,10 @@ bool OfflineModelConfig::Validate() const {
|
|||||||
return nemo_ctc.Validate();
|
return nemo_ctc.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!whisper.encoder.empty()) {
|
||||||
|
return whisper.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
return transducer.Validate();
|
return transducer.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,6 +66,7 @@ std::string OfflineModelConfig::ToString() const {
|
|||||||
os << "transducer=" << transducer.ToString() << ", ";
|
os << "transducer=" << transducer.ToString() << ", ";
|
||||||
os << "paraformer=" << paraformer.ToString() << ", ";
|
os << "paraformer=" << paraformer.ToString() << ", ";
|
||||||
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
||||||
|
os << "whisper=" << whisper.ToString() << ", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ struct OfflineModelConfig {
|
|||||||
OfflineTransducerModelConfig transducer;
|
OfflineTransducerModelConfig transducer;
|
||||||
OfflineParaformerModelConfig paraformer;
|
OfflineParaformerModelConfig paraformer;
|
||||||
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
||||||
|
OfflineWhisperModelConfig whisper;
|
||||||
|
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
int32_t num_threads = 2;
|
int32_t num_threads = 2;
|
||||||
@@ -37,11 +39,13 @@ struct OfflineModelConfig {
|
|||||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||||
const OfflineParaformerModelConfig ¶former,
|
const OfflineParaformerModelConfig ¶former,
|
||||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||||
|
const OfflineWhisperModelConfig &whisper,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug,
|
const std::string &tokens, int32_t num_threads, bool debug,
|
||||||
const std::string &provider, const std::string &model_type)
|
const std::string &provider, const std::string &model_type)
|
||||||
: transducer(transducer),
|
: transducer(transducer),
|
||||||
paraformer(paraformer),
|
paraformer(paraformer),
|
||||||
nemo_ctc(nemo_ctc),
|
nemo_ctc(nemo_ctc),
|
||||||
|
whisper(whisper),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
debug(debug),
|
debug(debug),
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
|
|||||||
|
|
||||||
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
|
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
|
||||||
if (!FileExists(model)) {
|
if (!FileExists(model)) {
|
||||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) {
|
|||||||
|
|
||||||
bool OfflineParaformerModelConfig::Validate() const {
|
bool OfflineParaformerModelConfig::Validate() const {
|
||||||
if (!FileExists(model)) {
|
if (!FileExists(model)) {
|
||||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/text-utils.h"
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
@@ -26,6 +27,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||||
} else if (model_type == "nemo_ctc") {
|
} else if (model_type == "nemo_ctc") {
|
||||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
|
} else if (model_type == "whisper") {
|
||||||
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||||
@@ -43,6 +46,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
model_filename = config.model_config.paraformer.model;
|
model_filename = config.model_config.paraformer.model;
|
||||||
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
||||||
model_filename = config.model_config.nemo_ctc.model;
|
model_filename = config.model_config.nemo_ctc.model;
|
||||||
|
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||||
|
model_filename = config.model_config.whisper.encoder;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Please provide a model");
|
SHERPA_ONNX_LOGE("Please provide a model");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
@@ -77,6 +82,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
"\n "
|
"\n "
|
||||||
"https://huggingface.co/csukuangfj/"
|
"https://huggingface.co/csukuangfj/"
|
||||||
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
||||||
|
"\n "
|
||||||
|
"(3) Whisper"
|
||||||
"\n");
|
"\n");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
@@ -95,12 +102,17 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (strncmp(model_type.c_str(), "whisper", 7) == 0) {
|
||||||
|
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"\nUnsupported model_type: %s\n"
|
"\nUnsupported model_type: %s\n"
|
||||||
"We support only the following model types at present: \n"
|
"We support only the following model types at present: \n"
|
||||||
" - Non-streaming transducer models from icefall\n"
|
" - Non-streaming transducer models from icefall\n"
|
||||||
" - Non-streaming Paraformer models from FunASR\n"
|
" - Non-streaming Paraformer models from FunASR\n"
|
||||||
" - EncDecCTCModelBPE models from NeMo\n",
|
" - EncDecCTCModelBPE models from NeMo\n"
|
||||||
|
" - Whisper models\n",
|
||||||
model_type.c_str());
|
model_type.c_str());
|
||||||
|
|
||||||
exit(-1);
|
exit(-1);
|
||||||
|
|||||||
152
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Normal file
152
sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||||
|
const SymbolTable &sym_table) {
|
||||||
|
OfflineRecognitionResult r;
|
||||||
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
|
||||||
|
for (auto i : src.tokens) {
|
||||||
|
if (!sym_table.contains(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &s = sym_table[i];
|
||||||
|
r.text += s;
|
||||||
|
r.tokens.push_back(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||||
|
public:
|
||||||
|
explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
symbol_table_(config_.model_config.tokens),
|
||||||
|
model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
|
||||||
|
// tokens.txt from whisper is base64 encoded, so we need to decode it
|
||||||
|
symbol_table_.ApplyBase64Decode();
|
||||||
|
|
||||||
|
if (config.decoding_method == "greedy_search") {
|
||||||
|
decoder_ =
|
||||||
|
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Only greedy_search is supported at present for whisper. Given %s",
|
||||||
|
config.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
|
return std::make_unique<OfflineStream>(WhisperTag{});
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||||
|
// batch decoding is not implemented yet
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
DecodeStream(ss[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void DecodeStream(OfflineStream *s) const {
|
||||||
|
int32_t max_num_frames = 3000;
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
int32_t feat_dim = s->FeatureDim();
|
||||||
|
std::vector<float> f = s->GetFrames();
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
|
||||||
|
if (num_frames > max_num_frames) {
|
||||||
|
SHERPA_ONNX_LOGE("Only waves less than 30 seconds are supported.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> shape{1, max_num_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
||||||
|
model_->Allocator(), shape.data(), shape.size());
|
||||||
|
float *p_mel = mel.GetTensorMutableData<float>();
|
||||||
|
std::copy(f.begin(), f.end(), p_mel);
|
||||||
|
|
||||||
|
memset(p_mel + f.size(), 0,
|
||||||
|
(max_num_frames - num_frames) * feat_dim * sizeof(float));
|
||||||
|
mel = Transpose12(model_->Allocator(), &mel);
|
||||||
|
|
||||||
|
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||||
|
auto results =
|
||||||
|
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
||||||
|
|
||||||
|
auto r = Convert(results[0], symbol_table_);
|
||||||
|
s->SetResult(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void NormalizeFeatures(float *features, int32_t num_frames,
|
||||||
|
int32_t feat_dim) {
|
||||||
|
// log_spec = torch.clamp(features, min=1e-10).log10()
|
||||||
|
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||||
|
// mel = (log_spec + 4.0) / 4.0
|
||||||
|
|
||||||
|
int32_t n = num_frames * feat_dim;
|
||||||
|
float max_v = -1e20;
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
float f = features[i];
|
||||||
|
|
||||||
|
f = std::max<float>(f, 1e-10);
|
||||||
|
f = std::log10(f);
|
||||||
|
|
||||||
|
max_v = std::max(f, max_v);
|
||||||
|
|
||||||
|
features[i] = f;
|
||||||
|
}
|
||||||
|
|
||||||
|
max_v -= 8;
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
float f = features[i];
|
||||||
|
f = std::max(f, max_v);
|
||||||
|
|
||||||
|
f = (f + 4) / 4;
|
||||||
|
|
||||||
|
features[i] = f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineRecognizerConfig config_;
|
||||||
|
SymbolTable symbol_table_;
|
||||||
|
std::unique_ptr<OfflineWhisperModel> model_;
|
||||||
|
std::unique_ptr<OfflineWhisperDecoder> decoder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_
|
||||||
@@ -86,6 +86,15 @@ class OfflineStream::Impl {
|
|||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
|
||||||
|
: context_graph_(context_graph) {
|
||||||
|
config_.normalize_samples = true;
|
||||||
|
opts_.frame_opts.samp_freq = 16000;
|
||||||
|
opts_.mel_opts.num_bins = 80;
|
||||||
|
whisper_fbank_ =
|
||||||
|
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
||||||
|
}
|
||||||
|
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
if (config_.normalize_samples) {
|
if (config_.normalize_samples) {
|
||||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||||
@@ -117,20 +126,35 @@ class OfflineStream::Impl {
|
|||||||
lowpass_filter_width);
|
lowpass_filter_width);
|
||||||
std::vector<float> samples;
|
std::vector<float> samples;
|
||||||
resampler->Resample(waveform, n, true, &samples);
|
resampler->Resample(waveform, n, true, &samples);
|
||||||
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
|
||||||
samples.size());
|
|
||||||
fbank_->InputFinished();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
if (fbank_) {
|
||||||
fbank_->InputFinished();
|
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||||
|
samples.size());
|
||||||
|
fbank_->InputFinished();
|
||||||
|
} else {
|
||||||
|
whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq,
|
||||||
|
samples.data(), samples.size());
|
||||||
|
whisper_fbank_->InputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
} // if (sampling_rate != opts_.frame_opts.samp_freq)
|
||||||
|
|
||||||
|
if (fbank_) {
|
||||||
|
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
|
fbank_->InputFinished();
|
||||||
|
} else {
|
||||||
|
whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
|
whisper_fbank_->InputFinished();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||||
|
|
||||||
std::vector<float> GetFrames() const {
|
std::vector<float> GetFrames() const {
|
||||||
int32_t n = fbank_->NumFramesReady();
|
int32_t n =
|
||||||
|
fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady();
|
||||||
|
|
||||||
assert(n > 0 && "Please first call AcceptWaveform()");
|
assert(n > 0 && "Please first call AcceptWaveform()");
|
||||||
|
|
||||||
int32_t feature_dim = FeatureDim();
|
int32_t feature_dim = FeatureDim();
|
||||||
@@ -140,7 +164,8 @@ class OfflineStream::Impl {
|
|||||||
float *p = features.data();
|
float *p = features.data();
|
||||||
|
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
const float *f = fbank_->GetFrame(i);
|
const float *f =
|
||||||
|
fbank_ ? fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i);
|
||||||
std::copy(f, f + feature_dim, p);
|
std::copy(f, f + feature_dim, p);
|
||||||
p += feature_dim;
|
p += feature_dim;
|
||||||
}
|
}
|
||||||
@@ -191,6 +216,7 @@ class OfflineStream::Impl {
|
|||||||
private:
|
private:
|
||||||
OfflineFeatureExtractorConfig config_;
|
OfflineFeatureExtractorConfig config_;
|
||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
|
std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
|
||||||
knf::FbankOptions opts_;
|
knf::FbankOptions opts_;
|
||||||
OfflineRecognitionResult r_;
|
OfflineRecognitionResult r_;
|
||||||
ContextGraphPtr context_graph_;
|
ContextGraphPtr context_graph_;
|
||||||
@@ -201,6 +227,10 @@ OfflineStream::OfflineStream(
|
|||||||
ContextGraphPtr context_graph /*= nullptr*/)
|
ContextGraphPtr context_graph /*= nullptr*/)
|
||||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||||
|
|
||||||
|
OfflineStream::OfflineStream(WhisperTag tag,
|
||||||
|
ContextGraphPtr context_graph /*= nullptr*/)
|
||||||
|
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
|
||||||
|
|
||||||
OfflineStream::~OfflineStream() = default;
|
OfflineStream::~OfflineStream() = default;
|
||||||
|
|
||||||
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||||
|
|||||||
@@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig {
|
|||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct WhisperTag {};
|
||||||
|
|
||||||
class OfflineStream {
|
class OfflineStream {
|
||||||
public:
|
public:
|
||||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||||
ContextGraphPtr context_graph = nullptr);
|
ContextGraphPtr context_graph = nullptr);
|
||||||
|
|
||||||
|
explicit OfflineStream(WhisperTag tag,
|
||||||
|
ContextGraphPtr context_graph = nullptr);
|
||||||
~OfflineStream();
|
~OfflineStream();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
|
|||||||
|
|
||||||
bool OfflineTransducerModelConfig::Validate() const {
|
bool OfflineTransducerModelConfig::Validate() const {
|
||||||
if (!FileExists(encoder_filename)) {
|
if (!FileExists(encoder_filename)) {
|
||||||
SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
|
SHERPA_ONNX_LOGE("transducer encoder: %s does not exist",
|
||||||
|
encoder_filename.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!FileExists(decoder_filename)) {
|
if (!FileExists(decoder_filename)) {
|
||||||
SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
|
SHERPA_ONNX_LOGE("transducer decoder: %s does not exist",
|
||||||
|
decoder_filename.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!FileExists(joiner_filename)) {
|
if (!FileExists(joiner_filename)) {
|
||||||
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
|
SHERPA_ONNX_LOGE("transducer joiner: %s does not exist",
|
||||||
|
joiner_filename.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
38
sherpa-onnx/csrc/offline-whisper-decoder.h
Normal file
38
sherpa-onnx/csrc/offline-whisper-decoder.h
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-whisper-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineWhisperDecoderResult {
|
||||||
|
/// The decoded token IDs
|
||||||
|
std::vector<int32_t> tokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OfflineWhisperDecoder {
|
||||||
|
public:
|
||||||
|
virtual ~OfflineWhisperDecoder() = default;
|
||||||
|
|
||||||
|
/** Run beam search given the output from the whisper encoder model.
|
||||||
|
*
|
||||||
|
* @param n_layer_cross_k A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||||
|
* @param n_layer_cross_v A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||||
|
*
|
||||||
|
* @return Return a vector of size `N` containing the decoded results.
|
||||||
|
*/
|
||||||
|
virtual std::vector<OfflineWhisperDecoderResult> Decode(
|
||||||
|
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
|
||||||
93
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
Normal file
93
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::vector<OfflineWhisperDecoderResult>
|
||||||
|
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||||
|
Ort::Value cross_v) {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||||
|
|
||||||
|
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
|
||||||
|
int32_t batch_size = 1;
|
||||||
|
std::array<int64_t, 2> token_shape{
|
||||||
|
batch_size, static_cast<int64_t>(initial_tokens.size())};
|
||||||
|
|
||||||
|
Ort::Value tokens = Ort::Value::CreateTensor(
|
||||||
|
memory_info, initial_tokens.data(), initial_tokens.size(),
|
||||||
|
token_shape.data(), token_shape.size());
|
||||||
|
|
||||||
|
std::array<int64_t, 1> offset_shape{1};
|
||||||
|
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||||
|
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||||
|
|
||||||
|
auto decoder_out = model_->ForwardDecoder(
|
||||||
|
std::move(tokens), std::move(self_kv_cache.first),
|
||||||
|
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||||
|
std::move(offset));
|
||||||
|
|
||||||
|
const auto &logits = std::get<0>(decoder_out);
|
||||||
|
const float *p_logits = logits.GetTensorData<float>();
|
||||||
|
|
||||||
|
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
int32_t vocab_size = logits_shape[2];
|
||||||
|
|
||||||
|
int32_t max_token_id = static_cast<int32_t>(std::distance(
|
||||||
|
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
|
||||||
|
|
||||||
|
int32_t n_text_ctx = model_->TextCtx();
|
||||||
|
|
||||||
|
std::vector<int32_t> predicted_tokens;
|
||||||
|
for (int32_t i = 0; i < n_text_ctx; ++i) {
|
||||||
|
if (max_token_id == model_->EOT()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
predicted_tokens.push_back(max_token_id);
|
||||||
|
|
||||||
|
std::array<int64_t, 2> token_shape{1, 1};
|
||||||
|
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
model_->Allocator(), token_shape.data(), token_shape.size());
|
||||||
|
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
|
||||||
|
p_tokens[0] = max_token_id;
|
||||||
|
|
||||||
|
int64_t *p_offset =
|
||||||
|
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
||||||
|
|
||||||
|
if (i == 0) {
|
||||||
|
*p_offset = initial_tokens.size();
|
||||||
|
} else {
|
||||||
|
*p_offset += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder_out = model_->ForwardDecoder(std::move(tokens),
|
||||||
|
std::move(std::get<1>(decoder_out)),
|
||||||
|
std::move(std::get<2>(decoder_out)),
|
||||||
|
std::move(std::get<3>(decoder_out)),
|
||||||
|
std::move(std::get<4>(decoder_out)),
|
||||||
|
std::move(std::get<5>(decoder_out)));
|
||||||
|
|
||||||
|
const auto &logits = std::get<0>(decoder_out);
|
||||||
|
const float *p_logits = logits.GetTensorData<float>();
|
||||||
|
|
||||||
|
max_token_id = static_cast<int64_t>(std::distance(
|
||||||
|
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<OfflineWhisperDecoderResult> ans(1);
|
||||||
|
ans[0].tokens = std::move(predicted_tokens);
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
29
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
Normal file
29
sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
||||||
|
public:
|
||||||
|
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
|
||||||
|
: model_(model) {}
|
||||||
|
|
||||||
|
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||||
|
Ort::Value cross_v) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineWhisperModel *model_; // not owned
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
|
||||||
46
sherpa-onnx/csrc/offline-whisper-model-config.cc
Normal file
46
sherpa-onnx/csrc/offline-whisper-model-config.cc
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-whisper-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("whisper-encoder", &encoder,
|
||||||
|
"Path to onnx encoder of whisper, e.g., tiny-encoder.onnx, "
|
||||||
|
"medium.en-encoder.onnx.");
|
||||||
|
|
||||||
|
po->Register("whisper-decoder", &decoder,
|
||||||
|
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
|
||||||
|
"medium.en-decoder.onnx.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OfflineWhisperModelConfig::Validate() const {
|
||||||
|
if (!FileExists(encoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(decoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineWhisperModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OfflineWhisperModelConfig(";
|
||||||
|
os << "encoder=\"" << encoder << "\", ";
|
||||||
|
os << "decoder=\"" << decoder << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
30
sherpa-onnx/csrc/offline-whisper-model-config.h
Normal file
30
sherpa-onnx/csrc/offline-whisper-model-config.h
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-whisper-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineWhisperModelConfig {
|
||||||
|
std::string encoder;
|
||||||
|
std::string decoder;
|
||||||
|
|
||||||
|
OfflineWhisperModelConfig() = default;
|
||||||
|
OfflineWhisperModelConfig(const std::string &encoder,
|
||||||
|
const std::string &decoder)
|
||||||
|
: encoder(encoder), decoder(decoder) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||||
213
sherpa-onnx/csrc/offline-whisper-model.cc
Normal file
213
sherpa-onnx/csrc/offline-whisper-model.cc
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-whisper-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineWhisperModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OfflineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.whisper.encoder);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.whisper.decoder);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
|
||||||
|
auto encoder_out = encoder_sess_->Run(
|
||||||
|
{}, encoder_input_names_ptr_.data(), &features, 1,
|
||||||
|
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
|
||||||
|
|
||||||
|
return {std::move(encoder_out[0]), std::move(encoder_out[1])};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||||
|
Ort::Value>
|
||||||
|
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||||
|
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||||
|
Ort::Value n_layer_cross_v, Ort::Value offset) {
|
||||||
|
std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
|
||||||
|
std::move(n_layer_self_k_cache),
|
||||||
|
std::move(n_layer_self_v_cache),
|
||||||
|
std::move(n_layer_cross_k),
|
||||||
|
std::move(n_layer_cross_v),
|
||||||
|
std::move(offset)};
|
||||||
|
|
||||||
|
auto decoder_out = decoder_sess_->Run(
|
||||||
|
{}, decoder_input_names_ptr_.data(), decoder_input.data(),
|
||||||
|
decoder_input.size(), decoder_output_names_ptr_.data(),
|
||||||
|
decoder_output_names_ptr_.size());
|
||||||
|
|
||||||
|
return {std::move(decoder_out[0]), std::move(decoder_out[1]),
|
||||||
|
std::move(decoder_out[2]), std::move(decoder_input[3]),
|
||||||
|
std::move(decoder_input[4]), std::move(decoder_input[5])};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
|
||||||
|
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
|
||||||
|
|
||||||
|
Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
|
||||||
|
Allocator(), shape.data(), shape.size());
|
||||||
|
|
||||||
|
Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
|
||||||
|
Allocator(), shape.data(), shape.size());
|
||||||
|
|
||||||
|
auto n = shape[0] * shape[1] * shape[2] * shape[3];
|
||||||
|
|
||||||
|
float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
|
||||||
|
float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();
|
||||||
|
|
||||||
|
memset(p_k, 0, sizeof(float) * n);
|
||||||
|
memset(p_v, 0, sizeof(float) * n);
|
||||||
|
|
||||||
|
return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *Allocator() const { return allocator_; }
|
||||||
|
|
||||||
|
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
||||||
|
|
||||||
|
int32_t EOT() const { return eot_; }
|
||||||
|
|
||||||
|
int32_t TextCtx() const { return n_text_ctx_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||||
|
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, model_data, model_data_length, sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||||
|
&encoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||||
|
&encoder_output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "---encoder---\n";
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||||
|
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, model_data, model_data_length, sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||||
|
&decoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||||
|
&decoder_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineModelConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_input_names_;
|
||||||
|
std::vector<const char *> encoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_output_names_;
|
||||||
|
std::vector<const char *> encoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_input_names_;
|
||||||
|
std::vector<const char *> decoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_output_names_;
|
||||||
|
std::vector<const char *> decoder_output_names_ptr_;
|
||||||
|
|
||||||
|
// model meta data
|
||||||
|
int32_t n_text_layer_;
|
||||||
|
int32_t n_text_ctx_;
|
||||||
|
int32_t n_text_state_;
|
||||||
|
int32_t sot_;
|
||||||
|
int32_t eot_;
|
||||||
|
int32_t blank_;
|
||||||
|
int32_t translate_;
|
||||||
|
int32_t no_timestamps_;
|
||||||
|
int32_t no_speech_;
|
||||||
|
std::vector<int64_t> sot_sequence_;
|
||||||
|
};
|
||||||
|
|
||||||
|
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
|
||||||
|
Ort::Value features) {
|
||||||
|
return impl_->ForwardEncoder(std::move(features));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||||
|
Ort::Value>
|
||||||
|
OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
|
||||||
|
Ort::Value n_layer_self_k_cache,
|
||||||
|
Ort::Value n_layer_self_v_cache,
|
||||||
|
Ort::Value n_layer_cross_k,
|
||||||
|
Ort::Value n_layer_cross_v,
|
||||||
|
Ort::Value offset) {
|
||||||
|
return impl_->ForwardDecoder(
|
||||||
|
std::move(tokens), std::move(n_layer_self_k_cache),
|
||||||
|
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
|
||||||
|
std::move(n_layer_cross_v), std::move(offset));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
|
||||||
|
return impl_->GetInitialSelfKVCache();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OfflineWhisperModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
|
||||||
|
return impl_->GetInitialTokens();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
|
||||||
|
|
||||||
|
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
85
sherpa-onnx/csrc/offline-whisper-model.h
Normal file
85
sherpa-onnx/csrc/offline-whisper-model.h
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-whisper-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <tuple>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineWhisperModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineWhisperModel(const OfflineModelConfig &config);
|
||||||
|
~OfflineWhisperModel();
|
||||||
|
|
||||||
|
/** Run the encoder model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, C, T). It is changed in-place.
|
||||||
|
* C is 80 and T is 3000.
|
||||||
|
*
|
||||||
|
* @return Return a pair containing:
|
||||||
|
* - n_layer_cross_k: A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||||
|
* - n_layer_cross_v: A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
|
||||||
|
|
||||||
|
/** Run the decoder model.
|
||||||
|
*
|
||||||
|
* @param tokens A int64 tensor of shape (N, num_words)
|
||||||
|
* @param n_layer_self_k_cache A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_text_ctx, n_text_state).
|
||||||
|
* @param n_layer_self_v_cache A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_text_ctx, n_text_state).
|
||||||
|
* @param n_layer_cross_k A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||||
|
* @param n_layer_cross_v A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||||
|
* @param offset A int64 tensor of shape (N,)
|
||||||
|
*
|
||||||
|
* @return Return a tuple containing 6 tensors:
|
||||||
|
*
|
||||||
|
* - logits A 3-D tensor of shape (N, num_words, vocab_size)
|
||||||
|
* - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
|
||||||
|
* - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
|
||||||
|
* - out_n_layer_cross_k Same as n_layer_cross_k
|
||||||
|
* - out_n_layer_cross_v Same as n_layer_cross_v
|
||||||
|
* - out_offset Same as offset
|
||||||
|
*/
|
||||||
|
std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
|
||||||
|
Ort::Value>
|
||||||
|
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||||
|
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||||
|
Ort::Value n_layer_cross_v, Ort::Value offset);
|
||||||
|
|
||||||
|
/** Return the initial self kv cache in a pair
|
||||||
|
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||||
|
* - n_layer_self_v_cache A 4-D tensor of shape
|
||||||
|
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
|
||||||
|
const std::vector<int64_t> &GetInitialTokens() const;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const;
|
||||||
|
int32_t EOT() const;
|
||||||
|
int32_t TextCtx() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||||
@@ -98,11 +98,15 @@ Usage:
|
|||||||
./bin/sherpa-onnx-microphone-offline \
|
./bin/sherpa-onnx-microphone-offline \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--paraformer=/path/to/model.onnx \
|
--paraformer=/path/to/model.onnx \
|
||||||
--num-threads=2 \
|
--num-threads=1
|
||||||
--decoding-method=greedy_search
|
|
||||||
|
|
||||||
Default value for num_threads is 2.
|
(3) Whisper models
|
||||||
Valid values for decoding_method: greedy_search.
|
|
||||||
|
./bin/sherpa-onnx-microphone-offline \
|
||||||
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||||
|
--num-threads=1
|
||||||
|
|
||||||
Please refer to
|
Please refer to
|
||||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ Usage:
|
|||||||
--encoder=/path/to/encoder.onnx \
|
--encoder=/path/to/encoder.onnx \
|
||||||
--decoder=/path/to/decoder.onnx \
|
--decoder=/path/to/decoder.onnx \
|
||||||
--joiner=/path/to/joiner.onnx \
|
--joiner=/path/to/joiner.onnx \
|
||||||
--num-threads=2 \
|
--num-threads=1 \
|
||||||
--decoding-method=greedy_search \
|
--decoding-method=greedy_search \
|
||||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
@@ -33,14 +33,22 @@ Usage:
|
|||||||
./bin/sherpa-onnx-offline \
|
./bin/sherpa-onnx-offline \
|
||||||
--tokens=/path/to/tokens.txt \
|
--tokens=/path/to/tokens.txt \
|
||||||
--paraformer=/path/to/model.onnx \
|
--paraformer=/path/to/model.onnx \
|
||||||
--num-threads=2 \
|
--num-threads=1 \
|
||||||
--decoding-method=greedy_search \
|
--decoding-method=greedy_search \
|
||||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
|
(3) Whisper models
|
||||||
|
|
||||||
|
./bin/sherpa-onnx-offline \
|
||||||
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||||
|
--num-threads=1 \
|
||||||
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
|
|
||||||
Note: It supports decoding multiple files in batches
|
Note: It supports decoding multiple files in batches
|
||||||
|
|
||||||
Default value for num_threads is 2.
|
|
||||||
Valid values for decoding_method: greedy_search.
|
|
||||||
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||||
|
|
||||||
@@ -55,6 +63,7 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
po.Read(argc, argv);
|
po.Read(argc, argv);
|
||||||
if (po.NumArgs() < 1) {
|
if (po.NumArgs() < 1) {
|
||||||
|
fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n");
|
||||||
po.PrintUsage();
|
po.PrintUsage();
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <strstream>
|
#include <strstream>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/base64-decode.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
@@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
|
|||||||
return os << symbol_table.ToString();
|
return os << symbol_table.ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SymbolTable::ApplyBase64Decode() {
|
||||||
|
sym2id_.clear();
|
||||||
|
for (auto &p : id2sym_) {
|
||||||
|
p.second = Base64Decode(p.second);
|
||||||
|
sym2id_[p.second] = p.first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ class SymbolTable {
|
|||||||
/// Return true if there is a given symbol in the symbol table.
|
/// Return true if there is a given symbol in the symbol table.
|
||||||
bool contains(const std::string &sym) const;
|
bool contains(const std::string &sym) const;
|
||||||
|
|
||||||
|
// for tokens.txt from Whisper
|
||||||
|
void ApplyBase64Decode();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Init(std::istream &is);
|
void Init(std::istream &is);
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
|
|||||||
offline-recognizer.cc
|
offline-recognizer.cc
|
||||||
offline-stream.cc
|
offline-stream.cc
|
||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
|
offline-whisper-model-config.cc
|
||||||
online-lm-config.cc
|
online-lm-config.cc
|
||||||
online-recognizer.cc
|
online-recognizer.cc
|
||||||
online-stream.cc
|
online-stream.cc
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) {
|
|||||||
PybindOfflineTransducerModelConfig(m);
|
PybindOfflineTransducerModelConfig(m);
|
||||||
PybindOfflineParaformerModelConfig(m);
|
PybindOfflineParaformerModelConfig(m);
|
||||||
PybindOfflineNemoEncDecCtcModelConfig(m);
|
PybindOfflineNemoEncDecCtcModelConfig(m);
|
||||||
|
PybindOfflineWhisperModelConfig(m);
|
||||||
|
|
||||||
using PyClass = OfflineModelConfig;
|
using PyClass = OfflineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||||
.def(
|
.def(py::init<const OfflineTransducerModelConfig &,
|
||||||
py::init<const OfflineTransducerModelConfig &,
|
const OfflineParaformerModelConfig &,
|
||||||
const OfflineParaformerModelConfig &,
|
const OfflineNemoEncDecCtcModelConfig &,
|
||||||
const OfflineNemoEncDecCtcModelConfig &, const std::string &,
|
const OfflineWhisperModelConfig &, const std::string &,
|
||||||
int32_t, bool, const std::string &, const std::string &>(),
|
int32_t, bool, const std::string &, const std::string &>(),
|
||||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
py::arg("num_threads"), py::arg("debug") = false,
|
||||||
|
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||||
.def_readwrite("transducer", &PyClass::transducer)
|
.def_readwrite("transducer", &PyClass::transducer)
|
||||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||||
|
.def_readwrite("whisper", &PyClass::whisper)
|
||||||
.def_readwrite("tokens", &PyClass::tokens)
|
.def_readwrite("tokens", &PyClass::tokens)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
.def_readwrite("debug", &PyClass::debug)
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
|||||||
24
sherpa-onnx/python/csrc/offline-whisper-model-config.cc
Normal file
24
sherpa-onnx/python/csrc/offline-whisper-model-config.cc
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-whisper-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineWhisperModelConfig(py::module *m) {
|
||||||
|
using PyClass = OfflineWhisperModelConfig;
|
||||||
|
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
||||||
|
.def(py::init<const std::string &, const std::string &>(),
|
||||||
|
py::arg("encoder"), py::arg("decoder"))
|
||||||
|
.def_readwrite("encoder", &PyClass::encoder)
|
||||||
|
.def_readwrite("decoder", &PyClass::decoder)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/offline-whisper-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-whisper-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-whisper-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineWhisperModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2023 by manyeyes
|
# Copyright (c) 2023 by manyeyes
|
||||||
|
# Copyright (c) 2023 Xiaomi Corporation
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
@@ -7,6 +8,7 @@ from _sherpa_onnx import (
|
|||||||
OfflineModelConfig,
|
OfflineModelConfig,
|
||||||
OfflineNemoEncDecCtcModelConfig,
|
OfflineNemoEncDecCtcModelConfig,
|
||||||
OfflineParaformerModelConfig,
|
OfflineParaformerModelConfig,
|
||||||
|
OfflineWhisperModelConfig,
|
||||||
)
|
)
|
||||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||||
from _sherpa_onnx import (
|
from _sherpa_onnx import (
|
||||||
@@ -69,7 +71,7 @@ class OfflineRecognizer(object):
|
|||||||
feature_dim:
|
feature_dim:
|
||||||
Dimension of the feature used to train the model.
|
Dimension of the feature used to train the model.
|
||||||
decoding_method:
|
decoding_method:
|
||||||
Support only greedy_search for now.
|
Valid values: greedy_search, modified_beam_search.
|
||||||
debug:
|
debug:
|
||||||
True to show debug messages.
|
True to show debug messages.
|
||||||
provider:
|
provider:
|
||||||
@@ -137,7 +139,7 @@ class OfflineRecognizer(object):
|
|||||||
feature_dim:
|
feature_dim:
|
||||||
Dimension of the feature used to train the model.
|
Dimension of the feature used to train the model.
|
||||||
decoding_method:
|
decoding_method:
|
||||||
Valid values are greedy_search, modified_beam_search.
|
Valid values are greedy_search.
|
||||||
debug:
|
debug:
|
||||||
True to show debug messages.
|
True to show debug messages.
|
||||||
provider:
|
provider:
|
||||||
@@ -185,14 +187,14 @@ class OfflineRecognizer(object):
|
|||||||
English, etc.
|
English, etc.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
model:
|
||||||
|
Path to ``model.onnx``.
|
||||||
tokens:
|
tokens:
|
||||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
columns::
|
columns::
|
||||||
|
|
||||||
symbol integer_id
|
symbol integer_id
|
||||||
|
|
||||||
model:
|
|
||||||
Path to ``model.onnx``.
|
|
||||||
num_threads:
|
num_threads:
|
||||||
Number of threads for neural network computation.
|
Number of threads for neural network computation.
|
||||||
sample_rate:
|
sample_rate:
|
||||||
@@ -200,7 +202,7 @@ class OfflineRecognizer(object):
|
|||||||
feature_dim:
|
feature_dim:
|
||||||
Dimension of the feature used to train the model.
|
Dimension of the feature used to train the model.
|
||||||
decoding_method:
|
decoding_method:
|
||||||
Valid values are greedy_search, modified_beam_search.
|
Valid values are greedy_search.
|
||||||
debug:
|
debug:
|
||||||
True to show debug messages.
|
True to show debug messages.
|
||||||
provider:
|
provider:
|
||||||
@@ -229,6 +231,68 @@ class OfflineRecognizer(object):
|
|||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_whisper(
|
||||||
|
cls,
|
||||||
|
encoder: str,
|
||||||
|
decoder: str,
|
||||||
|
tokens: str,
|
||||||
|
num_threads: int,
|
||||||
|
decoding_method: str = "greedy_search",
|
||||||
|
debug: bool = False,
|
||||||
|
provider: str = "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
||||||
|
to download pre-trained models for different kinds of whisper models,
|
||||||
|
e.g., tiny, tiny.en, base, base.en, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
Path to the encoder model, e.g., tiny-encoder.onnx,
|
||||||
|
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
||||||
|
decoder_model:
|
||||||
|
Path to the encoder model, e.g., tiny-encoder.onnx,
|
||||||
|
tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
|
||||||
|
tokens:
|
||||||
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
|
columns::
|
||||||
|
|
||||||
|
symbol integer_id
|
||||||
|
|
||||||
|
num_threads:
|
||||||
|
Number of threads for neural network computation.
|
||||||
|
decoding_method:
|
||||||
|
Valid values: greedy_search.
|
||||||
|
debug:
|
||||||
|
True to show debug messages.
|
||||||
|
provider:
|
||||||
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
|
"""
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
model_config = OfflineModelConfig(
|
||||||
|
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=num_threads,
|
||||||
|
debug=debug,
|
||||||
|
provider=provider,
|
||||||
|
model_type="whisper",
|
||||||
|
)
|
||||||
|
|
||||||
|
feat_config = OfflineFeatureExtractorConfig(
|
||||||
|
sampling_rate=16000,
|
||||||
|
feature_dim=80,
|
||||||
|
)
|
||||||
|
|
||||||
|
recognizer_config = OfflineRecognizerConfig(
|
||||||
|
feat_config=feat_config,
|
||||||
|
model_config=model_config,
|
||||||
|
decoding_method=decoding_method,
|
||||||
|
)
|
||||||
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
return self
|
||||||
|
|
||||||
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||||
if contexts_list is None:
|
if contexts_list is None:
|
||||||
return self.recognizer.create_stream()
|
return self.recognizer.create_stream()
|
||||||
|
|||||||
Reference in New Issue
Block a user