diff --git a/.github/workflows/build-wheels-macos.yaml b/.github/workflows/build-wheels-macos.yaml index 40450781..100c518e 100644 --- a/.github/workflows/build-wheels-macos.yaml +++ b/.github/workflows/build-wheels-macos.yaml @@ -36,6 +36,9 @@ jobs: CIBW_ARCHS: "universal2" CIBW_BUILD_VERBOSITY: 3 + # Don't repair macOS wheels + CIBW_REPAIR_WHEEL_COMMAND_MACOS: "" + - name: Display wheels shell: bash run: | diff --git a/.github/workflows/export-whisper-to-onnx.yaml b/.github/workflows/export-whisper-to-onnx.yaml index 476bff52..4d6f9531 100644 --- a/.github/workflows/export-whisper-to-onnx.yaml +++ b/.github/workflows/export-whisper-to-onnx.yaml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [macos-latest] - model: ["tiny.en", "base.en", "small.en", "medium.en"] + model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"] steps: - uses: actions/checkout@v2 diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e209969..9ca4fbeb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.7.6") +set(SHERPA_ONNX_VERSION "1.7.7") # Disable warning about # diff --git a/go-api-examples/non-streaming-decode-files/go.mod b/go-api-examples/non-streaming-decode-files/go.mod index bd52f1ea..516e08b1 100644 --- a/go-api-examples/non-streaming-decode-files/go.mod +++ b/go-api-examples/non-streaming-decode-files/go.mod @@ -3,7 +3,7 @@ module non-streaming-decode-files go 1.12 require ( - github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 + github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 github.com/spf13/pflag v1.0.5 github.com/youpy/go-wav v0.3.2 ) diff --git a/go-api-examples/non-streaming-decode-files/go.sum b/go-api-examples/non-streaming-decode-files/go.sum index a680aa92..a1565ae5 100644 --- a/go-api-examples/non-streaming-decode-files/go.sum +++ b/go-api-examples/non-streaming-decode-files/go.sum @@ -2,14 
+2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco= -github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw= -github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk= -github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= -github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ= -github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= -github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0= -github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= +github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ= +github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= github.com/pkg/errors v0.9.1 
h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/go-api-examples/real-time-speech-recognition-from-microphone/go.mod b/go-api-examples/real-time-speech-recognition-from-microphone/go.mod index 98e5b36f..b10d2e39 100644 --- a/go-api-examples/real-time-speech-recognition-from-microphone/go.mod +++ b/go-api-examples/real-time-speech-recognition-from-microphone/go.mod @@ -4,6 +4,6 @@ go 1.12 require ( github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 - github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 + github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 github.com/spf13/pflag v1.0.5 ) diff --git a/go-api-examples/real-time-speech-recognition-from-microphone/go.sum b/go-api-examples/real-time-speech-recognition-from-microphone/go.sum index 5f9677e6..a7332280 100644 --- a/go-api-examples/real-time-speech-recognition-from-microphone/go.sum +++ b/go-api-examples/real-time-speech-recognition-from-microphone/go.sum @@ -1,12 +1,12 @@ github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc= github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es= -github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco= -github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw= -github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk= -github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= -github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ= -github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= 
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0= -github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= +github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ= +github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= diff --git a/go-api-examples/streaming-decode-files/go.mod b/go-api-examples/streaming-decode-files/go.mod index 80dce05c..278520e9 100644 --- a/go-api-examples/streaming-decode-files/go.mod +++ b/go-api-examples/streaming-decode-files/go.mod @@ -3,7 +3,7 @@ module streaming-decode-files go 1.12 require ( - github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 + github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 github.com/spf13/pflag v1.0.5 github.com/youpy/go-wav v0.3.2 ) diff --git a/go-api-examples/streaming-decode-files/go.sum b/go-api-examples/streaming-decode-files/go.sum index a680aa92..a1565ae5 100644 --- a/go-api-examples/streaming-decode-files/go.sum +++ b/go-api-examples/streaming-decode-files/go.sum @@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew 
v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco= -github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw= -github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk= -github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= -github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ= -github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= -github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0= -github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= +github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ= +github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt index e62ad488..811b3d68 100644 --- a/kotlin-api-examples/Main.kt +++ b/kotlin-api-examples/Main.kt @@ -11,10 +11,12 @@ fun main() { // please refer to // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html // to dowload pre-trained models - var modelConfig = OnlineTransducerModelConfig( - encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", - decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", - joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", + var modelConfig = OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", + decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", + joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", + ), tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", numThreads = 1, debug = false, @@ -41,19 +43,19 @@ fun main() { var objArray = WaveReader.readWaveFromFile( filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav", ) - var samples : FloatArray = objArray[0] as FloatArray - var sampleRate : Int = objArray[1] as Int + var samples: FloatArray = objArray[0] as FloatArray + var sampleRate: Int = objArray[1] as Int - model.acceptWaveform(samples, sampleRate=sampleRate) + model.acceptWaveform(samples, sampleRate = sampleRate) while (model.isReady()) { - model.decode() + model.decode() } var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds - model.acceptWaveform(tailPaddings, sampleRate=sampleRate) + model.acceptWaveform(tailPaddings, sampleRate = sampleRate) model.inputFinished() while (model.isReady()) { - model.decode() + model.decode() } 
println("results: ${model.text}") diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index cbfaa760..7aef58d7 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser): help="Path to whisper decoder model", ) + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + def add_model_args(parser: argparse.ArgumentParser): add_transducer_model_args(parser) @@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: tokens=args.tokens, num_threads=args.num_threads, decoding_method=args.decoding_method, + language=args.whisper_language, + task=args.whisper_task, ) elif args.tdnn_model: assert_file_exists(args.tdnn_model) diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index f3f7949c..c53e1048 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \ --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ --num-threads=1 \ 
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ @@ -200,6 +201,28 @@ def get_args(): help="Path to whisper decoder model", ) + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + parser.add_argument( "--decoding-method", type=str, @@ -371,10 +394,10 @@ def main(): decoder=args.whisper_decoder, tokens=args.tokens, num_threads=args.num_threads, - sample_rate=args.sample_rate, - feature_dim=args.feature_dim, decoding_method=args.decoding_method, debug=args.debug, + language=args.whisper_language, + task=args.whisper_task, ) elif args.tdnn_model: assert_file_exists(args.tdnn_model) diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py index 1cdbaaf0..fbb0b132 100755 --- a/scripts/whisper/export-onnx.py +++ b/scripts/whisper/export-onnx.py @@ -11,6 +11,7 @@ for making the onnx export script public. 
""" import argparse +import os from pathlib import Path from typing import Any, Dict, Optional @@ -250,6 +251,7 @@ def main(): # write tokens tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual) + model.eval() print(model.dims) audio = torch.rand(16000 * 2) @@ -306,8 +308,12 @@ def main(): "n_text_head": model.dims.n_text_head, "n_text_layer": model.dims.n_text_layer, "sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))), - "all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))), - "all_language_codes": ",".join(tokenizer.all_language_codes), + "all_language_tokens": ",".join( + list(map(str, tokenizer.all_language_tokens)) + ), # a list of ids + "all_language_codes": ",".join( + tokenizer.all_language_codes + ), # e.g., en, de, zh, fr "sot": tokenizer.sot, "sot_index": tokenizer.sot_sequence.index(tokenizer.sot), "eot": tokenizer.eot, @@ -413,6 +419,9 @@ def main(): }, ) + if 'large' in args.model: + # it causes errors for large models, so skip it. + return # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py index 38a5a8a3..03e5e32c 100755 --- a/scripts/whisper/test.py +++ b/scripts/whisper/test.py @@ -38,6 +38,24 @@ def get_args(): help="Path to the tokens", ) + parser.add_argument( + "--language", + type=str, + help="""The actual spoken language in the audio. + Example values, en, de, zh, jp, fr. 
+ If None, we will detect the language using the first 30s of the + input audio + """, + ) + + parser.add_argument( + "--task", + choices=["transcribe", "translate"], + type=str, + default="transcribe", + help="Valid values are: transcribe, translate", + ) + parser.add_argument( "sound_file", type=str, @@ -74,12 +92,22 @@ class OnnxModel: self.sot = int(meta["sot"]) self.eot = int(meta["eot"]) self.translate = int(meta["translate"]) + self.transcribe = int(meta["transcribe"]) self.no_timestamps = int(meta["no_timestamps"]) self.no_speech = int(meta["no_speech"]) self.blank = int(meta["blank_id"]) self.sot_sequence = list(map(int, meta["sot_sequence"].split(","))) + self.sot_sequence.append(self.no_timestamps) + + self.all_language_tokens = list( + map(int, meta["all_language_tokens"].split(",")) + ) + self.all_language_codes = meta["all_language_codes"].split(",") + self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens)) + self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes)) + self.is_multilingual = int(meta["is_multilingual"]) == 1 def init_decoder(self, decoder: str): @@ -164,6 +192,29 @@ class OnnxModel: # logits is changed in-place logits[self.translate] = float("-inf") + def detect_language( + self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor + ) -> int: + tokens = torch.tensor([[self.sot]], dtype=torch.int64) + offset = torch.zeros(1, dtype=torch.int64) + n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache() + + logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder( + tokens=tokens, + n_layer_self_k_cache=n_layer_self_k_cache, + n_layer_self_v_cache=n_layer_self_v_cache, + n_layer_cross_k=n_layer_cross_k, + n_layer_cross_v=n_layer_cross_v, + offset=offset, + ) + logits = logits.reshape(-1) + mask = torch.ones(logits.shape[0], dtype=torch.int64) + mask[self.all_language_tokens] = 0 + logits[mask] = float("-inf") + lang_id = logits.argmax().item() + print("detected 
language: ", self.id2lang[lang_id]) + return lang_id + def load_tokens(filename): tokens = dict() @@ -200,7 +251,35 @@ def main(): mel = mel.t().unsqueeze(0) model = OnnxModel(encoder, decoder) + n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel) + + if args.language is not None: + if model.is_multilingual is False and args.language != "en": + print(f"This model supports only English. Given: {args.language}") + return + + if args.language not in model.lang2id: + print(f"Invalid language: {args.language}") + print(f"Valid values are: {list(model.lang2id.keys())}") + return + + # [sot, lang, task, notimestamps] + model.sot_sequence[1] = model.lang2id[args.language] + elif model.is_multilingual is True: + print("detecting language") + lang = model.detect_language(n_layer_cross_k, n_layer_cross_v) + model.sot_sequence[1] = lang + + if args.task is not None: + if model.is_multilingual is False and args.task != "transcribe": + print("This model supports only English. Please use --task=transcribe") + return + assert args.task in ["transcribe", "translate"], args.task + + if args.task == "translate": + model.sot_sequence[2] = model.translate + n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache() tokens = torch.tensor([model.sot_sequence], dtype=torch.int64) @@ -213,6 +292,7 @@ def main(): n_layer_cross_v=n_layer_cross_v, offset=offset, ) + offset += len(model.sot_sequence) # logits.shape (batch_size, tokens.shape[1], vocab_size) logits = logits[0, -1] model.suppress_tokens(logits, is_initial=True) @@ -225,7 +305,6 @@ def main(): break results.append(max_token_id.item()) tokens = torch.tensor([[results[-1]]]) - offset += 1 logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder( tokens=tokens, @@ -235,6 +314,7 @@ def main(): n_layer_cross_v=n_layer_cross_v, offset=offset, ) + offset += 1 logits = logits[0, -1] model.suppress_tokens(logits, is_initial=False) max_token_id = logits.argmax(dim=-1) diff --git a/sherpa-onnx/csrc/macros.h 
b/sherpa-onnx/csrc/macros.h index 93e0accd..fae92dc6 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -37,7 +37,7 @@ } \ \ dst = atoi(value.get()); \ - if (dst <= 0) { \ + if (dst < 0) { \ SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \ exit(-1); \ } \ @@ -77,6 +77,24 @@ } \ } while (0) +// read a vector of strings +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + SplitStringToVector(value.get(), ",", false, &dst); \ + \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \ + src_key); \ + exit(-1); \ + } \ + } while (0) + // Read a string #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ do { \ diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index efe9da28..08b2a5d5 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -23,21 +23,227 @@ namespace sherpa_onnx { +static std::string FixInvalidUtf8(const std::string &s) { + int32_t s_size = s.size(); + + std::string ans; + ans.reserve(s_size); + + for (int32_t i = 0; i < s_size;) { + uint8_t c = s[i]; + if (c < 0x80) { + // valid + ans.append(1, c); + ++i; + continue; + } else if ((c >= 0xc0) && (c < 0xe0)) { + // beginning of two bytes + if ((i + 1) > (s_size - 1)) { + // no subsequent byte. invalid! 
+ i += 1; + continue; + } + uint8_t next = s[i + 1]; + if (!(next >= 0x80 && next < 0xc0)) { + // invalid + i += 1; + continue; + } + // valid 2-byte utf-8 + ans.append(1, c); + ans.append(1, next); + i += 2; + continue; + } else if ((c >= 0xe0) && (c < 0xf0)) { + // beginning of 3 bytes + if ((i + 2) > (s_size - 1)) { + // invalid + i += 1; + continue; + } + + uint8_t next = s[i + 1]; + if (!(next >= 0x80 && next < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next2 = s[i + 2]; + if (!(next2 >= 0x80 && next2 < 0xc0)) { + // invalid + i += 1; + continue; + } + + ans.append(1, c); + ans.append(1, next); + ans.append(1, next2); + i += 3; + continue; + } else if ((c >= 0xf0) && (c < 0xf8)) { + // 4 bytes + if ((i + 3) > (s_size - 1)) { + // invalid + i += 1; + continue; + } + + uint8_t next = s[i + 1]; + if (!(next >= 0x80 && next < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next2 = s[i + 2]; + if (!(next2 >= 0x80 && next2 < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next3 = s[i + 3]; + if (!(next3 >= 0x80 && next3 < 0xc0)) { + // invalid + i += 1; + continue; + } + ans.append(1, c); + ans.append(1, next); + ans.append(1, next2); + ans.append(1, next3); + i += 4; + continue; + } else if ((c >= 0xf8) && (c < 0xfc)) { + // 5 bytes + if ((i + 4) > (s_size - 1)) { + // invalid + i += 1; + continue; + } + + uint8_t next = s[i + 1]; + if (!(next >= 0x80 && next < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next2 = s[i + 2]; + if (!(next2 >= 0x80 && next2 < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next3 = s[i + 3]; + if (!(next3 >= 0x80 && next3 < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next4 = s[i + 4]; + if (!(next4 >= 0x80 && next4 < 0xc0)) { + // invalid + i += 1; + continue; + } + ans.append(1, c); + ans.append(1, next); + ans.append(1, next2); + ans.append(1, next3); + ans.append(1, next4); + i += 5; + continue; + } else if ((c >= 0xfc) && (c < 0xfe)) { + // 6 bytes + 
if ((i + 5) > (s_size - 1)) { + // invalid + i += 1; + continue; + } + + uint8_t next = s[i + 1]; + if (!(next >= 0x80 && next < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next2 = s[i + 2]; + if (!(next2 >= 0x80 && next2 < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next3 = s[i + 3]; + if (!(next3 >= 0x80 && next3 < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next4 = s[i + 4]; + if (!(next4 >= 0x80 && next4 < 0xc0)) { + // invalid + i += 1; + continue; + } + + uint8_t next5 = s[i + 5]; + if (!(next5 >= 0x80 && next5 < 0xc0)) { + // invalid + i += 1; + continue; + } + ans.append(1, c); + ans.append(1, next); + ans.append(1, next2); + ans.append(1, next3); + ans.append(1, next4); + ans.append(1, next5); + i += 6; + continue; + } else { + i += 1; + } + } + return ans; +} + static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, const SymbolTable &sym_table) { OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); + std::string text; for (auto i : src.tokens) { if (!sym_table.contains(i)) { continue; } const auto &s = sym_table[i]; - r.text += s; + text += s; r.tokens.push_back(s); } + // TODO(fangjun): Fix the following error in offline-stream.cc + // + // j["text"] = text; + + // libc++abi: terminating with uncaught exception of type + // nlohmann::json_abi_v3_11_2::detail::type_error: + // [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86 + +#if 0 + r.text = FixInvalidUtf8(text); +#else + r.text = text; +#endif + return r; } @@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { symbol_table_.ApplyBase64Decode(); if (config.decoding_method == "greedy_search") { - decoder_ = - std::make_unique(model_.get()); + decoder_ = std::make_unique( + config_.model_config.whisper, model_.get()); } else { SHERPA_ONNX_LOGE( "Only greedy_search is supported at present for whisper. 
Given %s", @@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { mel = Transpose12(model_->Allocator(), &mel); auto cross_kv = model_->ForwardEncoder(std::move(mel)); + auto results = decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second)); diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc index 1b221300..036fab5b 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc @@ -7,17 +7,106 @@ #include #include +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + namespace sherpa_onnx { +int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage( + Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT + int64_t token_val = model_->SOT(); + std::array token_shape{1, 1}; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + Ort::Value tokens = Ort::Value::CreateTensor( + memory_info, &token_val, 1, token_shape.data(), token_shape.size()); + + auto self_kv_cache = model_->GetInitialSelfKVCache(); + + std::array offset_shape{1}; + Ort::Value offset = Ort::Value::CreateTensor( + model_->Allocator(), offset_shape.data(), offset_shape.size()); + *(offset.GetTensorMutableData()) = 0; + + auto decoder_out = model_->ForwardDecoder( + std::move(tokens), std::move(self_kv_cache.first), + std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), + std::move(offset)); + + cross_k = std::move(std::get<3>(decoder_out)); + cross_v = std::move(std::get<4>(decoder_out)); + + const float *p_logits = std::get<0>(decoder_out).GetTensorData(); + int32_t vocab_size = model_->VocabSize(); + const auto &all_language_ids = model_->GetAllLanguageIDs(); + + int32_t lang_id = all_language_ids[0]; + float this_logit = p_logits[lang_id]; + + for (int32_t i = 1; i != all_language_ids.size(); ++i) { + int32_t 
id = all_language_ids[i]; + float p = p_logits[id]; + + if (p > this_logit) { + this_logit = p; + lang_id = id; + } + } +#if 1 + SHERPA_ONNX_LOGE("Detected language: %s", + model_->GetID2Lang().at(lang_id).c_str()); +#endif + + return lang_id; +} + std::vector OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, Ort::Value cross_v) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - auto self_kv_cache = model_->GetInitialSelfKVCache(); - + // For multilingual models, initial_tokens contains [sot, language, task] + // - language is English by default + // - task is transcribe by default + // + // For non-multilingual models, initial_tokens contains [sot] std::vector initial_tokens = model_->GetInitialTokens(); + + if (model_->IsMultiLingual()) { + if (!config_.language.empty()) { + const auto &lang2id = model_->GetLang2ID(); + + if (!lang2id.count(config_.language)) { + SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str()); + exit(-1); + } + + int32_t lang_id = lang2id.at(config_.language); + + // 0: sot, 1: lang_id, 2: task, 3: no_timestamps + initial_tokens[1] = lang_id; + } else { + int32_t lang_id = DetectLanguage(cross_k, cross_v); + + // 0: sot, 1: lang_id, 2: task, 3: no_timestamps + initial_tokens[1] = lang_id; + } + + if (config_.task == "translate") { + initial_tokens[2] = model_->Translate(); + } else if (config_.task != "transcribe") { + // initial_tokens[2] is transcribe by default + SHERPA_ONNX_LOGE( + "Unsupported task: %s. 
Valid values are: transcribe, translate.", + config_.task.c_str()); + } + } + + initial_tokens.push_back(model_->NoTimeStampsToken()); + int32_t batch_size = 1; std::array token_shape{ batch_size, static_cast(initial_tokens.size())}; @@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, model_->Allocator(), offset_shape.data(), offset_shape.size()); *(offset.GetTensorMutableData()) = 0; + auto self_kv_cache = model_->GetInitialSelfKVCache(); + auto decoder_out = model_->ForwardDecoder( std::move(tokens), std::move(self_kv_cache.first), std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), std::move(offset)); + *(std::get<5>(decoder_out).GetTensorMutableData()) = + initial_tokens.size(); + const auto &logits = std::get<0>(decoder_out); const float *p_logits = logits.GetTensorData(); @@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, std::array token_shape{1, 1}; Ort::Value tokens = Ort::Value::CreateTensor( model_->Allocator(), token_shape.data(), token_shape.size()); + int64_t *p_tokens = tokens.GetTensorMutableData(); p_tokens[0] = max_token_id; - int64_t *p_offset = - std::get<5>(decoder_out).GetTensorMutableData(); - - if (i == 0) { - *p_offset = initial_tokens.size(); - } else { - *p_offset += 1; - } - decoder_out = model_->ForwardDecoder(std::move(tokens), std::move(std::get<1>(decoder_out)), std::move(std::get<2>(decoder_out)), @@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, std::move(std::get<4>(decoder_out)), std::move(std::get<5>(decoder_out))); + int64_t *p_offset = + std::get<5>(decoder_out).GetTensorMutableData(); + + *p_offset += 1; + const auto &logits = std::get<0>(decoder_out); const float *p_logits = logits.GetTensorData(); @@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, } std::vector ans(1); + ans[0].tokens = std::move(predicted_tokens); return ans; diff --git 
a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h index 98e515b9..b74bd94a 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h @@ -8,19 +8,25 @@ #include #include "sherpa-onnx/csrc/offline-whisper-decoder.h" +#include "sherpa-onnx/csrc/offline-whisper-model-config.h" #include "sherpa-onnx/csrc/offline-whisper-model.h" namespace sherpa_onnx { class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { public: - explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model) - : model_(model) {} + OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config, + OfflineWhisperModel *model) + : config_(config), model_(model) {} std::vector Decode(Ort::Value cross_k, Ort::Value cross_v) override; + int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT + Ort::Value &cross_v) const; // NOLINT + private: + OfflineWhisperModelConfig config_; OfflineWhisperModel *model_; // not owned }; diff --git a/sherpa-onnx/csrc/offline-whisper-model-config.cc b/sherpa-onnx/csrc/offline-whisper-model-config.cc index 1a469e67..c22c80e4 100644 --- a/sherpa-onnx/csrc/offline-whisper-model-config.cc +++ b/sherpa-onnx/csrc/offline-whisper-model-config.cc @@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) { po->Register("whisper-decoder", &decoder, "Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, " "medium.en-decoder.onnx."); + + po->Register( + "whisper-language", &language, + "The spoken language in the input audio file. Example values: " + "en, de, fr, zh, ja. If it is not given for a multilingual model, we will" + " infer the language from the input audio file. " + "Please refer to " + "https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10" + " for valid values. 
Note that for non-multilingual models, it supports " + "only 'en'"); + + po->Register("whisper-task", &task, + "Valid values: transcribe, translate. " + "Note that for non-multilingual models, it supports " + "only 'transcribe'"); } bool OfflineWhisperModelConfig::Validate() const { @@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const { return false; } + if (task != "translate" && task != "transcribe") { + SHERPA_ONNX_LOGE( + "--whisper-task supports only translate and transcribe. Given: %s", + task.c_str()); + + return false; + } + return true; } @@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const { os << "OfflineWhisperModelConfig("; os << "encoder=\"" << encoder << "\", "; - os << "decoder=\"" << decoder << "\")"; + os << "decoder=\"" << decoder << "\", "; + os << "language=\"" << language << "\", "; + os << "task=\"" << task << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-whisper-model-config.h b/sherpa-onnx/csrc/offline-whisper-model-config.h index 03e53372..b2d4dc51 100644 --- a/sherpa-onnx/csrc/offline-whisper-model-config.h +++ b/sherpa-onnx/csrc/offline-whisper-model-config.h @@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig { std::string encoder; std::string decoder; + // Available languages can be found at + // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + // + // Note: For non-multilingual models, it supports only "en" + // + // If empty, we will infer it from the input audio file when + // the model is multilingual. 
+ std::string language; + + // Valid values are transcribe and translate + // + // Note: For non-multilingual models, it supports only "transcribe" + std::string task = "transcribe"; + OfflineWhisperModelConfig() = default; OfflineWhisperModelConfig(const std::string &encoder, - const std::string &decoder) - : encoder(encoder), decoder(decoder) {} + const std::string &decoder, + const std::string &language, + const std::string &task) + : encoder(encoder), decoder(decoder), language(language), task(task) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc index 31739384..57cc55e9 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.cc +++ b/sherpa-onnx/csrc/offline-whisper-model.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "sherpa-onnx/csrc/macros.h" @@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl { const std::vector &GetInitialTokens() const { return sot_sequence_; } + const std::vector &GetAllLanguageIDs() const { + return all_language_tokens_; + } + + const std::unordered_map &GetLang2ID() const { + return lang2id_; + } + + const std::unordered_map &GetID2Lang() const { + return id2lang_; + } + + int32_t NoTimeStampsToken() const { return no_timestamps_; } + int32_t EOT() const { return eot_; } + int32_t SOT() const { return sot_; } + int32_t TextCtx() const { return n_text_ctx_; } + int32_t VocabSize() const { return n_vocab_; } + + int32_t Translate() const { return translate_; } + + bool IsMultiLingual() const { return is_multilingual_; } + private: void InitEncoder(void *model_data, size_t model_data_length) { encoder_sess_ = std::make_unique( @@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl { SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer"); SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx"); SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state"); + SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab"); 
SHERPA_ONNX_READ_META_DATA(sot_, "sot"); SHERPA_ONNX_READ_META_DATA(eot_, "eot"); SHERPA_ONNX_READ_META_DATA(blank_, "blank_id"); SHERPA_ONNX_READ_META_DATA(translate_, "translate"); + SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe"); + SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual"); SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps"); SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech"); SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence"); + + if (is_multilingual_) { + SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_, + "all_language_tokens"); + SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_, + "all_language_codes"); + if (all_language_tokens_.size() != all_language_codes_.size()) { + SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d", + static_cast(all_language_tokens_.size()), + static_cast(all_language_codes_.size())); + exit(-1); + } + + for (int32_t i = 0; + i != static_cast(all_language_tokens_.size()); ++i) { + lang2id_[all_language_codes_[i]] = all_language_tokens_[i]; + id2lang_[all_language_tokens_[i]] = all_language_codes_[i]; + } + } } void InitDecoder(void *model_data, size_t model_data_length) { @@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl { std::vector decoder_output_names_; std::vector decoder_output_names_ptr_; + std::vector all_language_tokens_; + std::vector all_language_codes_; + std::unordered_map lang2id_; + std::unordered_map id2lang_; + // model meta data int32_t n_text_layer_; int32_t n_text_ctx_; int32_t n_text_state_; + int32_t n_vocab_; int32_t sot_; int32_t eot_; int32_t blank_; int32_t translate_; + int32_t transcribe_; int32_t no_timestamps_; int32_t no_speech_; + int32_t is_multilingual_; std::vector sot_sequence_; }; @@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) OfflineWhisperModel::~OfflineWhisperModel() = default; std::pair OfflineWhisperModel::ForwardEncoder( - Ort::Value features) { + Ort::Value features) const { 
return impl_->ForwardEncoder(std::move(features)); } @@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v, - Ort::Value offset) { + Ort::Value offset) const { return impl_->ForwardDecoder( std::move(tokens), std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache), std::move(n_layer_cross_k), std::move(n_layer_cross_v), std::move(offset)); } -std::pair OfflineWhisperModel::GetInitialSelfKVCache() { +std::pair OfflineWhisperModel::GetInitialSelfKVCache() + const { return impl_->GetInitialSelfKVCache(); } @@ -206,8 +260,36 @@ const std::vector &OfflineWhisperModel::GetInitialTokens() const { return impl_->GetInitialTokens(); } +const std::vector &OfflineWhisperModel::GetAllLanguageIDs() const { + return impl_->GetAllLanguageIDs(); +} + +const std::unordered_map + &OfflineWhisperModel::GetLang2ID() const { + return impl_->GetLang2ID(); +} + +const std::unordered_map + &OfflineWhisperModel::GetID2Lang() const { + return impl_->GetID2Lang(); +} + +int32_t OfflineWhisperModel::NoTimeStampsToken() const { + return impl_->NoTimeStampsToken(); +} + int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); } +int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); } + int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); } +int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); } + +bool OfflineWhisperModel::IsMultiLingual() const { + return impl_->IsMultiLingual(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-whisper-model.h b/sherpa-onnx/csrc/offline-whisper-model.h index 4353e42f..86038875 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.h +++ b/sherpa-onnx/csrc/offline-whisper-model.h @@ -5,7 +5,9 @@ #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ #include +#include #include +#include 
#include #include @@ -30,7 +32,7 @@ class OfflineWhisperModel { * - n_layer_cross_v: A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state) */ - std::pair ForwardEncoder(Ort::Value features); + std::pair ForwardEncoder(Ort::Value features) const; /** Run the decoder model. * @@ -58,7 +60,9 @@ class OfflineWhisperModel { Ort::Value> ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache, Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k, - Ort::Value n_layer_cross_v, Ort::Value offset); + Ort::Value n_layer_cross_v, Ort::Value offset) const; + + int32_t DetectLanguage() const; /** Return the initial self kv cache in a pair * - n_layer_self_k_cache A 4-D tensor of shape @@ -66,14 +70,23 @@ class OfflineWhisperModel { * - n_layer_self_v_cache A 4-D tensor of shape * (n_text_layer, N, n_audio_ctx, n_text_state). */ - std::pair GetInitialSelfKVCache(); + std::pair GetInitialSelfKVCache() const; const std::vector &GetInitialTokens() const; + const std::vector &GetAllLanguageIDs() const; + const std::unordered_map &GetLang2ID() const; + const std::unordered_map &GetID2Lang() const; /** Return an allocator for allocating memory */ OrtAllocator *Allocator() const; + + int32_t NoTimeStampsToken() const; int32_t EOT() const; + int32_t SOT() const; int32_t TextCtx() const; + int32_t VocabSize() const; + int32_t Translate() const; + bool IsMultiLingual() const; private: class Impl; diff --git a/sherpa-onnx/python/csrc/offline-whisper-model-config.cc b/sherpa-onnx/python/csrc/offline-whisper-model-config.cc index 27470492..8872484d 100644 --- a/sherpa-onnx/python/csrc/offline-whisper-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-whisper-model-config.cc @@ -14,10 +14,14 @@ namespace sherpa_onnx { void PybindOfflineWhisperModelConfig(py::module *m) { using PyClass = OfflineWhisperModelConfig; py::class_(*m, "OfflineWhisperModelConfig") - .def(py::init(), - py::arg("encoder"), py::arg("decoder")) + .def(py::init(), + py::arg("encoder"), 
py::arg("decoder"), py::arg("language"), + py::arg("task")) .def_readwrite("encoder", &PyClass::encoder) .def_readwrite("decoder", &PyClass::decoder) + .def_readwrite("language", &PyClass::language) + .def_readwrite("task", &PyClass::task) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index c87f34cf..26ee9b27 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -244,6 +244,8 @@ class OfflineRecognizer(object): encoder: str, decoder: str, tokens: str, + language: str = "en", + task: str = "transcribe", num_threads: int = 1, decoding_method: str = "greedy_search", debug: bool = False, @@ -268,6 +270,14 @@ class OfflineRecognizer(object): symbol integer_id + language: + The spoken language in the audio file. Example values: en, de, zh, + jp, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + for all possible values. Note that for non-multilingual models, the + only valid value is 'en'. + task: + Valid values are: transcribe, translate. Note that for + non-multilingual models, the only valid value is 'transcribe'. num_threads: Number of threads for neural network computation. decoding_method: @@ -279,7 +289,12 @@ class OfflineRecognizer(object): """ self = cls.__new__(cls) model_config = OfflineModelConfig( - whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder), + whisper=OfflineWhisperModelConfig( + encoder=encoder, + decoder=decoder, + language=language, + task=task, + ), tokens=tokens, num_threads=num_threads, debug=debug,