Support multilingual whisper models (#274)

Author: Fangjun Kuang
Date: 2023-08-16 00:28:52 +08:00
Committed by: GitHub
parent 496c5dd7f5
commit f709c95c5f
24 changed files with 692 additions and 73 deletions

View File

@@ -36,6 +36,9 @@ jobs:
       CIBW_ARCHS: "universal2"
       CIBW_BUILD_VERBOSITY: 3
+      # Don't repair macOS wheels
+      CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
     - name: Display wheels
       shell: bash
       run: |

View File

@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [macos-latest]
-        model: ["tiny.en", "base.en", "small.en", "medium.en"]
+        model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
     steps:
       - uses: actions/checkout@v2

View File

@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)
-set(SHERPA_ONNX_VERSION "1.7.6")
+set(SHERPA_ONNX_VERSION "1.7.7")

 # Disable warning about
 #

View File

@@ -3,7 +3,7 @@ module non-streaming-decode-files
 go 1.12

 require (
-    github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+    github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
     github.com/spf13/pflag v1.0.5
     github.com/youpy/go-wav v0.3.2
 )

View File

@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -4,6 +4,6 @@ go 1.12
 require (
     github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
-    github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+    github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
     github.com/spf13/pflag v1.0.5
 )

View File

@@ -1,12 +1,12 @@
 github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
 github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
 github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=

View File

@@ -3,7 +3,7 @@ module streaming-decode-files
 go 1.12

 require (
-    github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
+    github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
     github.com/spf13/pflag v1.0.5
     github.com/youpy/go-wav v0.3.2
 )

View File

@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
-github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
-github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
-github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
-github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
+github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
+github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
+github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
+github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -11,10 +11,12 @@ fun main() {
     // please refer to
     // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
     // to download pre-trained models
-    var modelConfig = OnlineTransducerModelConfig(
-        encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
-        decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
-        joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
+    var modelConfig = OnlineModelConfig(
+        transducer = OnlineTransducerModelConfig(
+            encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
+            decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
+            joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
+        ),
         tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
         numThreads = 1,
         debug = false,
@@ -41,19 +43,19 @@ fun main() {
     var objArray = WaveReader.readWaveFromFile(
         filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
     )
-    var samples : FloatArray = objArray[0] as FloatArray
-    var sampleRate : Int = objArray[1] as Int
+    var samples: FloatArray = objArray[0] as FloatArray
+    var sampleRate: Int = objArray[1] as Int

-    model.acceptWaveform(samples, sampleRate=sampleRate)
+    model.acceptWaveform(samples, sampleRate = sampleRate)
     while (model.isReady()) {
         model.decode()
     }

     var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
-    model.acceptWaveform(tailPaddings, sampleRate=sampleRate)
+    model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
     model.inputFinished()
     while (model.isReady()) {
         model.decode()
     }
     println("results: ${model.text}")

View File

@@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
help="Path to whisper decoder model", help="Path to whisper decoder model",
) )
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
def add_model_args(parser: argparse.ArgumentParser): def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser) add_transducer_model_args(parser)
@@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             tokens=args.tokens,
             num_threads=args.num_threads,
             decoding_method=args.decoding_method,
+            language=args.whisper_language,
+            task=args.whisper_task,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)

View File

@@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \
   --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
   --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
   --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+  --whisper-task=transcribe \
   --num-threads=1 \
   ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
   ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
@@ -200,6 +201,28 @@ def get_args():
help="Path to whisper decoder model", help="Path to whisper decoder model",
) )
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@@ -371,10 +394,10 @@ def main():
             decoder=args.whisper_decoder,
             tokens=args.tokens,
             num_threads=args.num_threads,
-            sample_rate=args.sample_rate,
-            feature_dim=args.feature_dim,
             decoding_method=args.decoding_method,
             debug=args.debug,
+            language=args.whisper_language,
+            task=args.whisper_task,
         )
     elif args.tdnn_model:
         assert_file_exists(args.tdnn_model)

View File

@@ -11,6 +11,7 @@ for making the onnx export script public.
""" """
import argparse import argparse
import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@@ -250,6 +251,7 @@ def main():
     # write tokens
     tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)

     model.eval()
     print(model.dims)
     audio = torch.rand(16000 * 2)
@@ -306,8 +308,12 @@ def main():
"n_text_head": model.dims.n_text_head, "n_text_head": model.dims.n_text_head,
"n_text_layer": model.dims.n_text_layer, "n_text_layer": model.dims.n_text_layer,
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))), "sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))), "all_language_tokens": ",".join(
"all_language_codes": ",".join(tokenizer.all_language_codes), list(map(str, tokenizer.all_language_tokens))
), # a list of ids
"all_language_codes": ",".join(
tokenizer.all_language_codes
), # e.g., en, de, zh, fr
"sot": tokenizer.sot, "sot": tokenizer.sot,
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot), "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
"eot": tokenizer.eot, "eot": tokenizer.eot,
@@ -413,6 +419,9 @@ def main():
         },
     )

+    if 'large' in args.model:
+        # it causes errors for large models, so skip it.
+        return

     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
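The metadata written above is what the C++ runtime later reads back from the exported encoder. A minimal sketch for inspecting it, assuming onnxruntime is installed; tiny-encoder.onnx is a placeholder for a model exported by this script:

import onnxruntime

# Placeholder path; point this at an encoder exported by the script above.
sess = onnxruntime.InferenceSession("tiny-encoder.onnx")
meta = sess.get_modelmeta().custom_metadata_map

print(meta["sot_sequence"])  # a comma-separated list of token ids
if int(meta["is_multilingual"]):
    # only multilingual models carry the language tables
    print(meta["all_language_codes"].split(",")[:5])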

View File

@@ -38,6 +38,24 @@ def get_args():
help="Path to the tokens", help="Path to the tokens",
) )
parser.add_argument(
"--language",
type=str,
help="""The actual spoken language in the audio.
Example values, en, de, zh, jp, fr.
If None, we will detect the language using the first 30s of the
input audio
""",
)
parser.add_argument(
"--task",
choices=["transcribe", "translate"],
type=str,
default="transcribe",
help="Valid values are: transcribe, translate",
)
parser.add_argument( parser.add_argument(
"sound_file", "sound_file",
type=str, type=str,
@@ -74,12 +92,22 @@ class OnnxModel:
self.sot = int(meta["sot"]) self.sot = int(meta["sot"])
self.eot = int(meta["eot"]) self.eot = int(meta["eot"])
self.translate = int(meta["translate"]) self.translate = int(meta["translate"])
self.transcribe = int(meta["transcribe"])
self.no_timestamps = int(meta["no_timestamps"]) self.no_timestamps = int(meta["no_timestamps"])
self.no_speech = int(meta["no_speech"]) self.no_speech = int(meta["no_speech"])
self.blank = int(meta["blank_id"]) self.blank = int(meta["blank_id"])
self.sot_sequence = list(map(int, meta["sot_sequence"].split(","))) self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
self.sot_sequence.append(self.no_timestamps)
self.all_language_tokens = list(
map(int, meta["all_language_tokens"].split(","))
)
self.all_language_codes = meta["all_language_codes"].split(",")
self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
self.is_multilingual = int(meta["is_multilingual"]) == 1 self.is_multilingual = int(meta["is_multilingual"]) == 1
def init_decoder(self, decoder: str): def init_decoder(self, decoder: str):
@@ -164,6 +192,29 @@ class OnnxModel:
         # logits is changed in-place
         logits[self.translate] = float("-inf")

+    def detect_language(
+        self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
+    ) -> int:
+        tokens = torch.tensor([[self.sot]], dtype=torch.int64)
+        offset = torch.zeros(1, dtype=torch.int64)
+        n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()
+
+        logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
+            tokens=tokens,
+            n_layer_self_k_cache=n_layer_self_k_cache,
+            n_layer_self_v_cache=n_layer_self_v_cache,
+            n_layer_cross_k=n_layer_cross_k,
+            n_layer_cross_v=n_layer_cross_v,
+            offset=offset,
+        )
+        logits = logits.reshape(-1)
+        # Keep only the language tokens; a bool mask is needed for this kind
+        # of indexing, so mask out everything else with -inf.
+        mask = torch.ones(logits.shape[0], dtype=torch.bool)
+        mask[self.all_language_tokens] = False
+        logits[mask] = float("-inf")
+        lang_id = logits.argmax().item()
+
+        print("detected language: ", self.id2lang[lang_id])
+
+        return lang_id

 def load_tokens(filename):
     tokens = dict()
@@ -200,7 +251,35 @@ def main():
     mel = mel.t().unsqueeze(0)

     model = OnnxModel(encoder, decoder)
     n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)

+    if args.language is not None:
+        if model.is_multilingual is False and args.language != "en":
+            print(f"This model supports only English. Given: {args.language}")
+            return
+
+        if args.language not in model.lang2id:
+            print(f"Invalid language: {args.language}")
+            print(f"Valid values are: {list(model.lang2id.keys())}")
+            return
+
+        # [sot, lang, task, notimestamps]
+        model.sot_sequence[1] = model.lang2id[args.language]
+    elif model.is_multilingual is True:
+        print("detecting language")
+        lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
+        model.sot_sequence[1] = lang
+
+    if args.task is not None:
+        if model.is_multilingual is False and args.task != "transcribe":
+            print("This model supports only English. Please use --task=transcribe")
+            return
+
+        assert args.task in ["transcribe", "translate"], args.task
+        if args.task == "translate":
+            model.sot_sequence[2] = model.translate
+
     n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

     tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
@@ -213,6 +292,7 @@ def main():
         n_layer_cross_v=n_layer_cross_v,
         offset=offset,
     )
+    offset += len(model.sot_sequence)

     # logits.shape (batch_size, tokens.shape[1], vocab_size)
     logits = logits[0, -1]
     model.suppress_tokens(logits, is_initial=True)
@@ -225,7 +305,6 @@ def main():
             break
         results.append(max_token_id.item())
         tokens = torch.tensor([[results[-1]]])
-        offset += 1

         logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
             tokens=tokens,
@@ -235,6 +314,7 @@ def main():
             n_layer_cross_v=n_layer_cross_v,
             offset=offset,
         )
+        offset += 1
         logits = logits[0, -1]
         model.suppress_tokens(logits, is_initial=False)
         max_token_id = logits.argmax(dim=-1)
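To make the token bookkeeping above concrete, here is a sketch of the shape the initial token sequence ends up with for a multilingual model; the ids below are illustrative stand-ins for values that really come from the exported metadata:

# Illustrative ids only; real values come from meta["sot_sequence"],
# meta["all_language_tokens"], meta["translate"], meta["no_timestamps"].
sot, transcribe, translate, no_timestamps = 50258, 50359, 50358, 50363
lang2id = {"en": 50259, "de": 50261, "ja": 50266}

sot_sequence = [sot, lang2id["en"], transcribe]  # parsed from the metadata
sot_sequence.append(no_timestamps)               # appended in __init__ above

sot_sequence[1] = lang2id["de"]  # --language=de, or the detected language
sot_sequence[2] = translate      # --task=translate
# => [sot, lang, task, no_timestamps], the layout the comments above assume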

View File

@@ -37,7 +37,7 @@
   }                                                             \
                                                                 \
   dst = atoi(value.get());                                      \
-  if (dst <= 0) {                                               \
+  if (dst < 0) {                                                \
     SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key);  \
     exit(-1);                                                   \
   }                                                             \
@@ -77,6 +77,24 @@
     }                                                                   \
   } while (0)

+// read a vector of strings
+#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key)             \
+  do {                                                                  \
+    auto value =                                                        \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
+    if (!value) {                                                       \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);   \
+      exit(-1);                                                         \
+    }                                                                   \
+    SplitStringToVector(value.get(), ",", false, &dst);                 \
+                                                                        \
+    if (dst.empty()) {                                                  \
+      SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!",        \
+                       value.get(), src_key);                           \
+      exit(-1);                                                         \
+    }                                                                   \
+  } while (0)

 // Read a string
 #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
   do {                                               \

View File

@@ -23,21 +23,227 @@
 namespace sherpa_onnx {

+static std::string FixInvalidUtf8(const std::string &s) {
+  int32_t s_size = s.size();
+
+  std::string ans;
+  ans.reserve(s_size);
+
+  for (int32_t i = 0; i < s_size;) {
+    uint8_t c = s[i];
+    if (c < 0x80) {
+      // valid
+      ans.append(1, c);
+      ++i;
+      continue;
+    } else if ((c >= 0xc0) && (c < 0xe0)) {
+      // beginning of two bytes
+      if ((i + 1) > (s_size - 1)) {
+        // no subsequent byte. invalid!
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      // valid 2-byte utf-8
+      ans.append(1, c);
+      ans.append(1, next);
+      i += 2;
+      continue;
+    } else if ((c >= 0xe0) && (c < 0xf0)) {
+      // beginning of 3 bytes
+      if ((i + 2) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      i += 3;
+      continue;
+    } else if ((c >= 0xf0) && (c < 0xf8)) {
+      // 4 bytes
+      if ((i + 3) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next3 = s[i + 3];
+      if (!(next3 >= 0x80 && next3 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      ans.append(1, next3);
+      i += 4;
+      continue;
+    } else if ((c >= 0xf8) && (c < 0xfc)) {
+      // 5 bytes
+      if ((i + 4) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next3 = s[i + 3];
+      if (!(next3 >= 0x80 && next3 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next4 = s[i + 4];
+      if (!(next4 >= 0x80 && next4 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      ans.append(1, next3);
+      ans.append(1, next4);
+      i += 5;
+      continue;
+    } else if ((c >= 0xfc) && (c < 0xfe)) {
+      // 6 bytes
+      if ((i + 5) > (s_size - 1)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next = s[i + 1];
+      if (!(next >= 0x80 && next < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next2 = s[i + 2];
+      if (!(next2 >= 0x80 && next2 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next3 = s[i + 3];
+      if (!(next3 >= 0x80 && next3 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next4 = s[i + 4];
+      if (!(next4 >= 0x80 && next4 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      uint8_t next5 = s[i + 5];
+      if (!(next5 >= 0x80 && next5 < 0xc0)) {
+        // invalid
+        i += 1;
+        continue;
+      }
+
+      ans.append(1, c);
+      ans.append(1, next);
+      ans.append(1, next2);
+      ans.append(1, next3);
+      ans.append(1, next4);
+      ans.append(1, next5);
+      i += 6;
+      continue;
+    } else {
+      i += 1;
+    }
+  }
+
+  return ans;
+}
+
 static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
                                         const SymbolTable &sym_table) {
   OfflineRecognitionResult r;
   r.tokens.reserve(src.tokens.size());

+  std::string text;
   for (auto i : src.tokens) {
     if (!sym_table.contains(i)) {
       continue;
     }

     const auto &s = sym_table[i];
-    r.text += s;
+    text += s;
     r.tokens.push_back(s);
   }

+  // TODO(fangjun): Fix the following error in offline-stream.cc
+  //
+  // j["text"] = text;
+  // libc++abi: terminating with uncaught exception of type
+  // nlohmann::json_abi_v3_11_2::detail::type_error:
+  // [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
+#if 0
+  r.text = FixInvalidUtf8(text);
+#else
+  r.text = text;
+#endif
+
   return r;
 }
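For context on the TODO above: Whisper's byte-level tokens can split a multi-byte UTF-8 character, so a greedy decode may stop on an incomplete byte sequence, which nlohmann::json then rejects when serializing the text. A small Python illustration of the failure mode:

b = "你好".encode("utf-8")  # b'\xe4\xbd\xa0\xe5\xa5\xbd', two 3-byte characters
truncated = b[:4]           # cut inside the second character

try:
    truncated.decode("utf-8")
except UnicodeDecodeError as e:
    print(e)  # "unexpected end of data", the same class of error as in the TODO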
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     symbol_table_.ApplyBase64Decode();

     if (config.decoding_method == "greedy_search") {
-      decoder_ =
-          std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
+      decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
+          config_.model_config.whisper, model_.get());
     } else {
       SHERPA_ONNX_LOGE(
           "Only greedy_search is supported at present for whisper. Given %s",
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
     mel = Transpose12(model_->Allocator(), &mel);

     auto cross_kv = model_->ForwardEncoder(std::move(mel));

     auto results =
         decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));

View File

@@ -7,17 +7,106 @@
 #include <algorithm>
 #include <utility>

+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
 namespace sherpa_onnx {

+int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
+    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
+  int64_t token_val = model_->SOT();
+  std::array<int64_t, 2> token_shape{1, 1};
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  Ort::Value tokens = Ort::Value::CreateTensor(
+      memory_info, &token_val, 1, token_shape.data(), token_shape.size());
+
+  auto self_kv_cache = model_->GetInitialSelfKVCache();
+
+  std::array<int64_t, 1> offset_shape{1};
+  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
+      model_->Allocator(), offset_shape.data(), offset_shape.size());
+  *(offset.GetTensorMutableData<int64_t>()) = 0;
+
+  auto decoder_out = model_->ForwardDecoder(
+      std::move(tokens), std::move(self_kv_cache.first),
+      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
+      std::move(offset));
+
+  cross_k = std::move(std::get<3>(decoder_out));
+  cross_v = std::move(std::get<4>(decoder_out));
+
+  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
+
+  int32_t vocab_size = model_->VocabSize();
+
+  const auto &all_language_ids = model_->GetAllLanguageIDs();
+
+  int32_t lang_id = all_language_ids[0];
+  float this_logit = p_logits[lang_id];
+
+  for (int32_t i = 1; i != all_language_ids.size(); ++i) {
+    int32_t id = all_language_ids[i];
+    float p = p_logits[id];
+
+    if (p > this_logit) {
+      this_logit = p;
+      lang_id = id;
+    }
+  }
+#if 1
+  SHERPA_ONNX_LOGE("Detected language: %s",
+                   model_->GetID2Lang().at(lang_id).c_str());
+#endif
+
+  return lang_id;
+}
+
 std::vector<OfflineWhisperDecoderResult>
 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                           Ort::Value cross_v) {
   auto memory_info =
       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

-  auto self_kv_cache = model_->GetInitialSelfKVCache();
+  // For multilingual models, initial_tokens contains [sot, language, task]
+  //   - language is English by default
+  //   - task is transcribe by default
+  //
+  // For non-multilingual models, initial_tokens contains [sot]
   std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
+
+  if (model_->IsMultiLingual()) {
+    if (!config_.language.empty()) {
+      const auto &lang2id = model_->GetLang2ID();
+
+      if (!lang2id.count(config_.language)) {
+        SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
+        exit(-1);
+      }
+
+      int32_t lang_id = lang2id.at(config_.language);
+
+      // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
+      initial_tokens[1] = lang_id;
+    } else {
+      int32_t lang_id = DetectLanguage(cross_k, cross_v);
+
+      // 0: sot, 1: lang_id, 2: task, 3: no_timestamps
+      initial_tokens[1] = lang_id;
+    }
+
+    if (config_.task == "translate") {
+      initial_tokens[2] = model_->Translate();
+    } else if (config_.task != "transcribe") {
+      // initial_tokens[2] is transcribe by default
+      SHERPA_ONNX_LOGE(
+          "Unsupported task: %s. Valid values are: transcribe, translate.",
+          config_.task.c_str());
+    }
+  }
+
+  initial_tokens.push_back(model_->NoTimeStampsToken());
+
   int32_t batch_size = 1;
   std::array<int64_t, 2> token_shape{
       batch_size, static_cast<int64_t>(initial_tokens.size())};
@@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
       model_->Allocator(), offset_shape.data(), offset_shape.size());
   *(offset.GetTensorMutableData<int64_t>()) = 0;

+  auto self_kv_cache = model_->GetInitialSelfKVCache();
+
   auto decoder_out = model_->ForwardDecoder(
       std::move(tokens), std::move(self_kv_cache.first),
       std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
       std::move(offset));

+  *(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
+      initial_tokens.size();
+
   const auto &logits = std::get<0>(decoder_out);
   const float *p_logits = logits.GetTensorData<float>();
@@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
     std::array<int64_t, 2> token_shape{1, 1};
     Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
         model_->Allocator(), token_shape.data(), token_shape.size());

     int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
     p_tokens[0] = max_token_id;

-    int64_t *p_offset =
-        std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
-
-    if (i == 0) {
-      *p_offset = initial_tokens.size();
-    } else {
-      *p_offset += 1;
-    }
-
     decoder_out = model_->ForwardDecoder(std::move(tokens),
                                          std::move(std::get<1>(decoder_out)),
                                          std::move(std::get<2>(decoder_out)),
@@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
                                          std::move(std::get<4>(decoder_out)),
                                          std::move(std::get<5>(decoder_out)));

+    int64_t *p_offset =
+        std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
+    *p_offset += 1;
+
     const auto &logits = std::get<0>(decoder_out);
     const float *p_logits = logits.GetTensorData<float>();
@@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
   }

   std::vector<OfflineWhisperDecoderResult> ans(1);
   ans[0].tokens = std::move(predicted_tokens);

   return ans;
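The subtle part of this rewrite is the offset schedule for the decoder's self-attention KV cache: the first call consumes the whole initial token sequence at offset 0, the offset is then set to the sequence length, and each later single-token call advances it by one after the call returns. A minimal sketch of the schedule, in illustrative Python rather than the C++ above:

initial_tokens = ["sot", "lang", "task", "no_timestamps"]  # multilingual case

offset = 0  # first call: the full initial sequence at offset 0
# forward_decoder(initial_tokens, offset)
offset = len(initial_tokens)  # what the code writes into std::get<5>(...)

for last_token in ["t1", "t2", "t3"]:  # then one token per call
    # forward_decoder([last_token], offset)
    offset += 1  # mirrors *p_offset += 1 after each ForwardDecoder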

View File

@@ -8,19 +8,25 @@
 #include <vector>

 #include "sherpa-onnx/csrc/offline-whisper-decoder.h"
+#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
 #include "sherpa-onnx/csrc/offline-whisper-model.h"

 namespace sherpa_onnx {

 class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
  public:
-  explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
-      : model_(model) {}
+  OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
+                                    OfflineWhisperModel *model)
+      : config_(config), model_(model) {}

   std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
                                                   Ort::Value cross_v) override;

+  int32_t DetectLanguage(Ort::Value &cross_k,         // NOLINT
+                         Ort::Value &cross_v) const;  // NOLINT
+
  private:
+  OfflineWhisperModelConfig config_;
   OfflineWhisperModel *model_;  // not owned
 };

View File

@@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
po->Register("whisper-decoder", &decoder, po->Register("whisper-decoder", &decoder,
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, " "Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
"medium.en-decoder.onnx."); "medium.en-decoder.onnx.");
po->Register(
"whisper-language", &language,
"The spoke language in the input audio file. Example values: "
"en, de, fr, zh, jp. If it is not given for a multilingual model, we will"
" infer the language from the input audio file. "
"Please refer to "
"https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
" for valid values. Note that for non-multilingual models, it supports "
"only 'en'");
po->Register("whisper-task", &task,
"Valid values: transcribe, translate. "
"Note that for non-multilingual models, it supports "
"only 'transcribe'");
} }
bool OfflineWhisperModelConfig::Validate() const { bool OfflineWhisperModelConfig::Validate() const {
@@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
     return false;
   }

+  if (task != "translate" && task != "transcribe") {
+    SHERPA_ONNX_LOGE(
+        "--whisper-task supports only translate and transcribe. Given: %s",
+        task.c_str());
+
+    return false;
+  }
+
   return true;
 }
} }
@@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {
os << "OfflineWhisperModelConfig("; os << "OfflineWhisperModelConfig(";
os << "encoder=\"" << encoder << "\", "; os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\")"; os << "decoder=\"" << decoder << "\", ";
os << "language=\"" << language << "\", ";
os << "task=\"" << task << "\")";
return os.str(); return os.str();
} }

View File

@@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
   std::string encoder;
   std::string decoder;

+  // Available languages can be found at
+  // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+  //
+  // Note: For non-multilingual models, it supports only "en"
+  //
+  // If empty, we will infer it from the input audio file when
+  // the model is multilingual.
+  std::string language;
+
+  // Valid values are transcribe and translate
+  //
+  // Note: For non-multilingual models, it supports only "transcribe"
+  std::string task = "transcribe";
+
   OfflineWhisperModelConfig() = default;
-  OfflineWhisperModelConfig(const std::string &encoder,
-                            const std::string &decoder)
-      : encoder(encoder), decoder(decoder) {}
+  OfflineWhisperModelConfig(const std::string &encoder,
+                            const std::string &decoder,
+                            const std::string &language,
+                            const std::string &task)
+      : encoder(encoder), decoder(decoder), language(language), task(task) {}

   void Register(ParseOptions *po);
   bool Validate() const;

View File

@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <string>
 #include <tuple>
+#include <unordered_map>
 #include <utility>

 #include "sherpa-onnx/csrc/macros.h"
@@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl {
   const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }

+  const std::vector<int32_t> &GetAllLanguageIDs() const {
+    return all_language_tokens_;
+  }
+
+  const std::unordered_map<std::string, int32_t> &GetLang2ID() const {
+    return lang2id_;
+  }
+
+  const std::unordered_map<int32_t, std::string> &GetID2Lang() const {
+    return id2lang_;
+  }
+
+  int32_t NoTimeStampsToken() const { return no_timestamps_; }
+
   int32_t EOT() const { return eot_; }
+  int32_t SOT() const { return sot_; }
   int32_t TextCtx() const { return n_text_ctx_; }
+  int32_t VocabSize() const { return n_vocab_; }
+  int32_t Translate() const { return translate_; }
+
+  bool IsMultiLingual() const { return is_multilingual_; }

  private:
   void InitEncoder(void *model_data, size_t model_data_length) {
     encoder_sess_ = std::make_unique<Ort::Session>(
@@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl {
     SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
     SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
     SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
+    SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab");
     SHERPA_ONNX_READ_META_DATA(sot_, "sot");
     SHERPA_ONNX_READ_META_DATA(eot_, "eot");
     SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
     SHERPA_ONNX_READ_META_DATA(translate_, "translate");
+    SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe");
+    SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual");
     SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
     SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
     SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
+
+    if (is_multilingual_) {
+      SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_,
+                                     "all_language_tokens");
+      SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_,
+                                            "all_language_codes");
+      if (all_language_tokens_.size() != all_language_codes_.size()) {
+        SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d",
+                         static_cast<int32_t>(all_language_tokens_.size()),
+                         static_cast<int32_t>(all_language_codes_.size()));
+        exit(-1);
+      }
+
+      for (int32_t i = 0;
+           i != static_cast<int32_t>(all_language_tokens_.size()); ++i) {
+        lang2id_[all_language_codes_[i]] = all_language_tokens_[i];
+        id2lang_[all_language_tokens_[i]] = all_language_codes_[i];
+      }
+    }
   }

   void InitDecoder(void *model_data, size_t model_data_length) {
@@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl {
   std::vector<std::string> decoder_output_names_;
   std::vector<const char *> decoder_output_names_ptr_;

+  std::vector<int32_t> all_language_tokens_;
+  std::vector<std::string> all_language_codes_;
+  std::unordered_map<std::string, int32_t> lang2id_;
+  std::unordered_map<int32_t, std::string> id2lang_;
+
   // model meta data
   int32_t n_text_layer_;
   int32_t n_text_ctx_;
   int32_t n_text_state_;
+  int32_t n_vocab_;
   int32_t sot_;
   int32_t eot_;
   int32_t blank_;
   int32_t translate_;
+  int32_t transcribe_;
   int32_t no_timestamps_;
   int32_t no_speech_;
+  int32_t is_multilingual_;
   std::vector<int64_t> sot_sequence_;
 };
@@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
 OfflineWhisperModel::~OfflineWhisperModel() = default;

 std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
-    Ort::Value features) {
+    Ort::Value features) const {
   return impl_->ForwardEncoder(std::move(features));
 }
@@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
                                     Ort::Value n_layer_self_v_cache,
                                     Ort::Value n_layer_cross_k,
                                     Ort::Value n_layer_cross_v,
-                                    Ort::Value offset) {
+                                    Ort::Value offset) const {
   return impl_->ForwardDecoder(
       std::move(tokens), std::move(n_layer_self_k_cache),
       std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
       std::move(n_layer_cross_v), std::move(offset));
 }

-std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
+std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
+    const {
   return impl_->GetInitialSelfKVCache();
 }
@@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
   return impl_->GetInitialTokens();
 }

+const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const {
+  return impl_->GetAllLanguageIDs();
+}
+
+const std::unordered_map<std::string, int32_t>
+    &OfflineWhisperModel::GetLang2ID() const {
+  return impl_->GetLang2ID();
+}
+
+const std::unordered_map<int32_t, std::string>
+    &OfflineWhisperModel::GetID2Lang() const {
+  return impl_->GetID2Lang();
+}
+
+int32_t OfflineWhisperModel::NoTimeStampsToken() const {
+  return impl_->NoTimeStampsToken();
+}
+
 int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }

+int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); }
+
 int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }

+int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
+
+int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
+
+bool OfflineWhisperModel::IsMultiLingual() const {
+  return impl_->IsMultiLingual();
+}
+
 }  // namespace sherpa_onnx

View File

@@ -5,7 +5,9 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
 #include <memory>
+#include <string>
 #include <tuple>
+#include <unordered_map>
 #include <utility>
 #include <vector>
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
    *   - n_layer_cross_v: A 4-D tensor of shape
    *                      (n_text_layer, N, n_audio_ctx, n_text_state)
    */
-  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
+  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;

   /** Run the decoder model.
    *
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
              Ort::Value>
   ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
                  Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
-                 Ort::Value n_layer_cross_v, Ort::Value offset);
+                 Ort::Value n_layer_cross_v, Ort::Value offset) const;
+
+  int32_t DetectLanguage() const;

   /** Return the initial self kv cache in a pair
    *   - n_layer_self_k_cache A 4-D tensor of shape
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
    *   - n_layer_self_v_cache A 4-D tensor of shape
    *                          (n_text_layer, N, n_audio_ctx, n_text_state).
    */
-  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
+  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;

   const std::vector<int64_t> &GetInitialTokens() const;

+  const std::vector<int32_t> &GetAllLanguageIDs() const;
+  const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
+  const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
+
   /** Return an allocator for allocating memory
    */
   OrtAllocator *Allocator() const;

+  int32_t NoTimeStampsToken() const;
   int32_t EOT() const;
+  int32_t SOT() const;
   int32_t TextCtx() const;
+  int32_t VocabSize() const;
+  int32_t Translate() const;
+  bool IsMultiLingual() const;

  private:
   class Impl;

View File

@@ -14,10 +14,14 @@ namespace sherpa_onnx {
 void PybindOfflineWhisperModelConfig(py::module *m) {
   using PyClass = OfflineWhisperModelConfig;
   py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
-      .def(py::init<const std::string &, const std::string &>(),
-           py::arg("encoder"), py::arg("decoder"))
+      .def(py::init<const std::string &, const std::string &,
+                    const std::string &, const std::string &>(),
+           py::arg("encoder"), py::arg("decoder"), py::arg("language"),
+           py::arg("task"))
       .def_readwrite("encoder", &PyClass::encoder)
       .def_readwrite("decoder", &PyClass::decoder)
+      .def_readwrite("language", &PyClass::language)
+      .def_readwrite("task", &PyClass::task)
       .def("__str__", &PyClass::ToString);
 }
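A minimal sketch of what the updated binding accepts from Python, assuming the class is re-exported at the package top level like the rest of the Python API (the paths are placeholders):

import sherpa_onnx

config = sherpa_onnx.OfflineWhisperModelConfig(
    encoder="tiny-encoder.onnx",  # placeholder path
    decoder="tiny-decoder.onnx",  # placeholder path
    language="de",  # "" lets a multilingual model auto-detect the language
    task="transcribe",  # or "translate" for English output
)
print(config)  # __str__ is bound to OfflineWhisperModelConfig::ToString()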

View File

@@ -244,6 +244,8 @@ class OfflineRecognizer(object):
         encoder: str,
         decoder: str,
         tokens: str,
+        language: str = "en",
+        task: str = "transcribe",
         num_threads: int = 1,
         decoding_method: str = "greedy_search",
         debug: bool = False,
@@ -268,6 +270,14 @@ class OfflineRecognizer(object):
             symbol integer_id
+          language:
+            The spoken language in the audio file. Example values: en, de, zh,
+            ja, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+            for all possible values. Note that for non-multilingual models, the
+            only valid value is 'en'.
+          task:
+            Valid values are: transcribe, translate. Note that for
+            non-multilingual models, the only valid value is 'transcribe'.
           num_threads:
             Number of threads for neural network computation.
           decoding_method:
@@ -279,7 +289,12 @@ class OfflineRecognizer(object):
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
model_config = OfflineModelConfig( model_config = OfflineModelConfig(
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder), whisper=OfflineWhisperModelConfig(
encoder=encoder,
decoder=decoder,
language=language,
task=task,
),
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
debug=debug, debug=debug,
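Putting the new keyword arguments together, a hypothetical call for a multilingual model could look like the following; the sherpa-onnx-whisper-tiny file names mirror the base.en layout used elsewhere in this commit and are assumptions, not part of the diff:

import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="./sherpa-onnx-whisper-tiny/tiny-encoder.onnx",  # assumed layout
    decoder="./sherpa-onnx-whisper-tiny/tiny-decoder.onnx",  # assumed layout
    tokens="./sherpa-onnx-whisper-tiny/tiny-tokens.txt",     # assumed layout
    language="",  # empty: detect the language from the audio
    task="transcribe",  # or "translate" to get English output
)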