Support multilingual whisper models (#274)
This commit is contained in:
3
.github/workflows/build-wheels-macos.yaml
vendored
3
.github/workflows/build-wheels-macos.yaml
vendored
@@ -36,6 +36,9 @@ jobs:
|
||||
CIBW_ARCHS: "universal2"
|
||||
CIBW_BUILD_VERBOSITY: 3
|
||||
|
||||
# Don't repair macOS wheels
|
||||
CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
|
||||
|
||||
- name: Display wheels
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest]
|
||||
model: ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||
model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||
project(sherpa-onnx)
|
||||
|
||||
set(SHERPA_ONNX_VERSION "1.7.6")
|
||||
set(SHERPA_ONNX_VERSION "1.7.7")
|
||||
|
||||
# Disable warning about
|
||||
#
|
||||
|
||||
@@ -3,7 +3,7 @@ module non-streaming-decode-files
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/youpy/go-wav v0.3.2
|
||||
)
|
||||
|
||||
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
||||
@@ -4,6 +4,6 @@ go 1.12
|
||||
|
||||
require (
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
|
||||
@@ -3,7 +3,7 @@ module streaming-decode-files
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/youpy/go-wav v0.3.2
|
||||
)
|
||||
|
||||
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
|
||||
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
|
||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
|
||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
|
||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
||||
@@ -11,10 +11,12 @@ fun main() {
|
||||
// please refer to
|
||||
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
// to dowload pre-trained models
|
||||
var modelConfig = OnlineTransducerModelConfig(
|
||||
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
|
||||
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
|
||||
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
|
||||
var modelConfig = OnlineModelConfig(
|
||||
transducer = OnlineTransducerModelConfig(
|
||||
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
|
||||
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
|
||||
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
|
||||
),
|
||||
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
|
||||
numThreads = 1,
|
||||
debug = false,
|
||||
@@ -41,19 +43,19 @@ fun main() {
|
||||
var objArray = WaveReader.readWaveFromFile(
|
||||
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
|
||||
)
|
||||
var samples : FloatArray = objArray[0] as FloatArray
|
||||
var sampleRate : Int = objArray[1] as Int
|
||||
var samples: FloatArray = objArray[0] as FloatArray
|
||||
var sampleRate: Int = objArray[1] as Int
|
||||
|
||||
model.acceptWaveform(samples, sampleRate=sampleRate)
|
||||
model.acceptWaveform(samples, sampleRate = sampleRate)
|
||||
while (model.isReady()) {
|
||||
model.decode()
|
||||
model.decode()
|
||||
}
|
||||
|
||||
var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
|
||||
model.acceptWaveform(tailPaddings, sampleRate=sampleRate)
|
||||
model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
|
||||
model.inputFinished()
|
||||
while (model.isReady()) {
|
||||
model.decode()
|
||||
model.decode()
|
||||
}
|
||||
|
||||
println("results: ${model.text}")
|
||||
|
||||
@@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
|
||||
help="Path to whisper decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-language",
|
||||
default="",
|
||||
type=str,
|
||||
help="""It specifies the spoken language in the input audio file.
|
||||
Example values: en, fr, de, zh, jp.
|
||||
Available languages for multilingual models can be found at
|
||||
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
If not specified, we infer the language from the input audio file.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-task",
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
type=str,
|
||||
help="""For multilingual models, if you specify translate, the output
|
||||
will be in English.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def add_model_args(parser: argparse.ArgumentParser):
|
||||
add_transducer_model_args(parser)
|
||||
@@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
decoding_method=args.decoding_method,
|
||||
language=args.whisper_language,
|
||||
task=args.whisper_task,
|
||||
)
|
||||
elif args.tdnn_model:
|
||||
assert_file_exists(args.tdnn_model)
|
||||
|
||||
@@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \
|
||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||
--whisper-task=transcribe \
|
||||
--num-threads=1 \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||
@@ -200,6 +201,28 @@ def get_args():
|
||||
help="Path to whisper decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-language",
|
||||
default="",
|
||||
type=str,
|
||||
help="""It specifies the spoken language in the input audio file.
|
||||
Example values: en, fr, de, zh, jp.
|
||||
Available languages for multilingual models can be found at
|
||||
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
If not specified, we infer the language from the input audio file.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-task",
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
type=str,
|
||||
help="""For multilingual models, if you specify translate, the output
|
||||
will be in English.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@@ -371,10 +394,10 @@ def main():
|
||||
decoder=args.whisper_decoder,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
language=args.whisper_language,
|
||||
task=args.whisper_task,
|
||||
)
|
||||
elif args.tdnn_model:
|
||||
assert_file_exists(args.tdnn_model)
|
||||
|
||||
@@ -11,6 +11,7 @@ for making the onnx export script public.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -250,6 +251,7 @@ def main():
|
||||
# write tokens
|
||||
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
||||
|
||||
model.eval()
|
||||
print(model.dims)
|
||||
audio = torch.rand(16000 * 2)
|
||||
@@ -306,8 +308,12 @@ def main():
|
||||
"n_text_head": model.dims.n_text_head,
|
||||
"n_text_layer": model.dims.n_text_layer,
|
||||
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
|
||||
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
|
||||
"all_language_codes": ",".join(tokenizer.all_language_codes),
|
||||
"all_language_tokens": ",".join(
|
||||
list(map(str, tokenizer.all_language_tokens))
|
||||
), # a list of ids
|
||||
"all_language_codes": ",".join(
|
||||
tokenizer.all_language_codes
|
||||
), # e.g., en, de, zh, fr
|
||||
"sot": tokenizer.sot,
|
||||
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
|
||||
"eot": tokenizer.eot,
|
||||
@@ -413,6 +419,9 @@ def main():
|
||||
},
|
||||
)
|
||||
|
||||
if 'large' in args.model:
|
||||
# it causes errors for large models, so skip it.
|
||||
return
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
|
||||
@@ -38,6 +38,24 @@ def get_args():
|
||||
help="Path to the tokens",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
help="""The actual spoken language in the audio.
|
||||
Example values, en, de, zh, jp, fr.
|
||||
If None, we will detect the language using the first 30s of the
|
||||
input audio
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
choices=["transcribe", "translate"],
|
||||
type=str,
|
||||
default="transcribe",
|
||||
help="Valid values are: transcribe, translate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
@@ -74,12 +92,22 @@ class OnnxModel:
|
||||
self.sot = int(meta["sot"])
|
||||
self.eot = int(meta["eot"])
|
||||
self.translate = int(meta["translate"])
|
||||
self.transcribe = int(meta["transcribe"])
|
||||
self.no_timestamps = int(meta["no_timestamps"])
|
||||
self.no_speech = int(meta["no_speech"])
|
||||
self.blank = int(meta["blank_id"])
|
||||
|
||||
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
||||
|
||||
self.sot_sequence.append(self.no_timestamps)
|
||||
|
||||
self.all_language_tokens = list(
|
||||
map(int, meta["all_language_tokens"].split(","))
|
||||
)
|
||||
self.all_language_codes = meta["all_language_codes"].split(",")
|
||||
self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
|
||||
self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
|
||||
|
||||
self.is_multilingual = int(meta["is_multilingual"]) == 1
|
||||
|
||||
def init_decoder(self, decoder: str):
|
||||
@@ -164,6 +192,29 @@ class OnnxModel:
|
||||
# logits is changed in-place
|
||||
logits[self.translate] = float("-inf")
|
||||
|
||||
def detect_language(
|
||||
self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
|
||||
) -> int:
|
||||
tokens = torch.tensor([[self.sot]], dtype=torch.int64)
|
||||
offset = torch.zeros(1, dtype=torch.int64)
|
||||
n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()
|
||||
|
||||
logits, n_layer_self_k_cache, n_layer_self_v_cache = self.run_decoder(
|
||||
tokens=tokens,
|
||||
n_layer_self_k_cache=n_layer_self_k_cache,
|
||||
n_layer_self_v_cache=n_layer_self_v_cache,
|
||||
n_layer_cross_k=n_layer_cross_k,
|
||||
n_layer_cross_v=n_layer_cross_v,
|
||||
offset=offset,
|
||||
)
|
||||
logits = logits.reshape(-1)
|
||||
mask = torch.ones(logits.shape[0], dtype=torch.int64)
|
||||
mask[self.all_language_tokens] = 0
|
||||
logits[mask] = float("-inf")
|
||||
lang_id = logits.argmax().item()
|
||||
print("detected language: ", self.id2lang[lang_id])
|
||||
return lang_id
|
||||
|
||||
|
||||
def load_tokens(filename):
|
||||
tokens = dict()
|
||||
@@ -200,7 +251,35 @@ def main():
|
||||
mel = mel.t().unsqueeze(0)
|
||||
|
||||
model = OnnxModel(encoder, decoder)
|
||||
|
||||
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||
|
||||
if args.language is not None:
|
||||
if model.is_multilingual is False and args.language != "en":
|
||||
print(f"This model supports only English. Given: {args.language}")
|
||||
return
|
||||
|
||||
if args.language not in model.lang2id:
|
||||
print(f"Invalid language: {args.language}")
|
||||
print(f"Valid values are: {list(model.lang2id.keys())}")
|
||||
return
|
||||
|
||||
# [sot, lang, task, notimestamps]
|
||||
model.sot_sequence[1] = model.lang2id[args.language]
|
||||
elif model.is_multilingual is True:
|
||||
print("detecting language")
|
||||
lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
|
||||
model.sot_sequence[1] = lang
|
||||
|
||||
if args.task is not None:
|
||||
if model.is_multilingual is False and args.task != "transcribe":
|
||||
print("This model supports only English. Please use --task=transcribe")
|
||||
return
|
||||
assert args.task in ["transcribe", "translate"], args.task
|
||||
|
||||
if args.task == "translate":
|
||||
model.sot_sequence[2] = model.translate
|
||||
|
||||
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
||||
|
||||
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
||||
@@ -213,6 +292,7 @@ def main():
|
||||
n_layer_cross_v=n_layer_cross_v,
|
||||
offset=offset,
|
||||
)
|
||||
offset += len(model.sot_sequence)
|
||||
# logits.shape (batch_size, tokens.shape[1], vocab_size)
|
||||
logits = logits[0, -1]
|
||||
model.suppress_tokens(logits, is_initial=True)
|
||||
@@ -225,7 +305,6 @@ def main():
|
||||
break
|
||||
results.append(max_token_id.item())
|
||||
tokens = torch.tensor([[results[-1]]])
|
||||
offset += 1
|
||||
|
||||
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||
tokens=tokens,
|
||||
@@ -235,6 +314,7 @@ def main():
|
||||
n_layer_cross_v=n_layer_cross_v,
|
||||
offset=offset,
|
||||
)
|
||||
offset += 1
|
||||
logits = logits[0, -1]
|
||||
model.suppress_tokens(logits, is_initial=False)
|
||||
max_token_id = logits.argmax(dim=-1)
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
} \
|
||||
\
|
||||
dst = atoi(value.get()); \
|
||||
if (dst <= 0) { \
|
||||
if (dst < 0) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
@@ -77,6 +77,24 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// read a vector of strings
|
||||
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
|
||||
do { \
|
||||
auto value = \
|
||||
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||
if (!value) { \
|
||||
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
SplitStringToVector(value.get(), ",", false, &dst); \
|
||||
\
|
||||
if (dst.empty()) { \
|
||||
SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
|
||||
src_key); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Read a string
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
||||
do { \
|
||||
|
||||
@@ -23,21 +23,227 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::string FixInvalidUtf8(const std::string &s) {
|
||||
int32_t s_size = s.size();
|
||||
|
||||
std::string ans;
|
||||
ans.reserve(s_size);
|
||||
|
||||
for (int32_t i = 0; i < s_size;) {
|
||||
uint8_t c = s[i];
|
||||
if (c < 0x80) {
|
||||
// valid
|
||||
ans.append(1, c);
|
||||
++i;
|
||||
continue;
|
||||
} else if ((c >= 0xc0) && (c < 0xe0)) {
|
||||
// beginning of two bytes
|
||||
if ((i + 1) > (s_size - 1)) {
|
||||
// no subsequent byte. invalid!
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
// valid 2-byte utf-8
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
i += 2;
|
||||
continue;
|
||||
} else if ((c >= 0xe0) && (c < 0xf0)) {
|
||||
// beginning of 3 bytes
|
||||
if ((i + 2) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
i += 3;
|
||||
continue;
|
||||
} else if ((c >= 0xf0) && (c < 0xf8)) {
|
||||
// 4 bytes
|
||||
if ((i + 3) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
i += 4;
|
||||
continue;
|
||||
} else if ((c >= 0xf8) && (c < 0xfc)) {
|
||||
// 5 bytes
|
||||
if ((i + 4) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next4 = s[i + 4];
|
||||
if (!(next4 >= 0x80 && next4 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
ans.append(1, next4);
|
||||
i += 5;
|
||||
continue;
|
||||
} else if ((c >= 0xfc) && (c < 0xfe)) {
|
||||
// 6 bytes
|
||||
if ((i + 5) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next4 = s[i + 4];
|
||||
if (!(next4 >= 0x80 && next4 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next5 = s[i + 5];
|
||||
if (!(next5 >= 0x80 && next5 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
ans.append(1, next4);
|
||||
ans.append(1, next5);
|
||||
i += 6;
|
||||
continue;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
|
||||
std::string text;
|
||||
for (auto i : src.tokens) {
|
||||
if (!sym_table.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &s = sym_table[i];
|
||||
r.text += s;
|
||||
text += s;
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Fix the following error in offline-stream.cc
|
||||
//
|
||||
// j["text"] = text;
|
||||
|
||||
// libc++abi: terminating with uncaught exception of type
|
||||
// nlohmann::json_abi_v3_11_2::detail::type_error:
|
||||
// [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
|
||||
|
||||
#if 0
|
||||
r.text = FixInvalidUtf8(text);
|
||||
#else
|
||||
r.text = text;
|
||||
#endif
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
symbol_table_.ApplyBase64Decode();
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
|
||||
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
|
||||
config_.model_config.whisper, model_.get());
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only greedy_search is supported at present for whisper. Given %s",
|
||||
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
mel = Transpose12(model_->Allocator(), &mel);
|
||||
|
||||
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||
|
||||
auto results =
|
||||
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
||||
|
||||
|
||||
@@ -7,17 +7,106 @@
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
|
||||
Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT
|
||||
int64_t token_val = model_->SOT();
|
||||
std::array<int64_t, 2> token_shape{1, 1};
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value tokens = Ort::Value::CreateTensor(
|
||||
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
std::array<int64_t, 1> offset_shape{1};
|
||||
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
|
||||
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||
|
||||
auto decoder_out = model_->ForwardDecoder(
|
||||
std::move(tokens), std::move(self_kv_cache.first),
|
||||
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||
std::move(offset));
|
||||
|
||||
cross_k = std::move(std::get<3>(decoder_out));
|
||||
cross_v = std::move(std::get<4>(decoder_out));
|
||||
|
||||
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
const auto &all_language_ids = model_->GetAllLanguageIDs();
|
||||
|
||||
int32_t lang_id = all_language_ids[0];
|
||||
float this_logit = p_logits[lang_id];
|
||||
|
||||
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
|
||||
int32_t id = all_language_ids[i];
|
||||
float p = p_logits[id];
|
||||
|
||||
if (p > this_logit) {
|
||||
this_logit = p;
|
||||
lang_id = id;
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
SHERPA_ONNX_LOGE("Detected language: %s",
|
||||
model_->GetID2Lang().at(lang_id).c_str());
|
||||
#endif
|
||||
|
||||
return lang_id;
|
||||
}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult>
|
||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
// For multilingual models, initial_tokens contains [sot, language, task]
|
||||
// - language is English by default
|
||||
// - task is transcribe by default
|
||||
//
|
||||
// For non-multilingual models, initial_tokens contains [sot]
|
||||
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
|
||||
|
||||
if (model_->IsMultiLingual()) {
|
||||
if (!config_.language.empty()) {
|
||||
const auto &lang2id = model_->GetLang2ID();
|
||||
|
||||
if (!lang2id.count(config_.language)) {
|
||||
SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t lang_id = lang2id.at(config_.language);
|
||||
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
} else {
|
||||
int32_t lang_id = DetectLanguage(cross_k, cross_v);
|
||||
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
}
|
||||
|
||||
if (config_.task == "translate") {
|
||||
initial_tokens[2] = model_->Translate();
|
||||
} else if (config_.task != "transcribe") {
|
||||
// initial_tokens[2] is transcribe by default
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported task: %s. Valid values are: transcribe, translate.",
|
||||
config_.task.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
initial_tokens.push_back(model_->NoTimeStampsToken());
|
||||
|
||||
int32_t batch_size = 1;
|
||||
std::array<int64_t, 2> token_shape{
|
||||
batch_size, static_cast<int64_t>(initial_tokens.size())};
|
||||
@@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
auto decoder_out = model_->ForwardDecoder(
|
||||
std::move(tokens), std::move(self_kv_cache.first),
|
||||
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||
std::move(offset));
|
||||
|
||||
*(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
|
||||
initial_tokens.size();
|
||||
|
||||
const auto &logits = std::get<0>(decoder_out);
|
||||
const float *p_logits = logits.GetTensorData<float>();
|
||||
|
||||
@@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
std::array<int64_t, 2> token_shape{1, 1};
|
||||
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
|
||||
model_->Allocator(), token_shape.data(), token_shape.size());
|
||||
|
||||
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
|
||||
p_tokens[0] = max_token_id;
|
||||
|
||||
int64_t *p_offset =
|
||||
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
||||
|
||||
if (i == 0) {
|
||||
*p_offset = initial_tokens.size();
|
||||
} else {
|
||||
*p_offset += 1;
|
||||
}
|
||||
|
||||
decoder_out = model_->ForwardDecoder(std::move(tokens),
|
||||
std::move(std::get<1>(decoder_out)),
|
||||
std::move(std::get<2>(decoder_out)),
|
||||
@@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
std::move(std::get<4>(decoder_out)),
|
||||
std::move(std::get<5>(decoder_out)));
|
||||
|
||||
int64_t *p_offset =
|
||||
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
||||
|
||||
*p_offset += 1;
|
||||
|
||||
const auto &logits = std::get<0>(decoder_out);
|
||||
const float *p_logits = logits.GetTensorData<float>();
|
||||
|
||||
@@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult> ans(1);
|
||||
|
||||
ans[0].tokens = std::move(predicted_tokens);
|
||||
|
||||
return ans;
|
||||
|
||||
@@ -8,19 +8,25 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
||||
public:
|
||||
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
|
||||
: model_(model) {}
|
||||
OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
|
||||
OfflineWhisperModel *model)
|
||||
: config_(config), model_(model) {}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) override;
|
||||
|
||||
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||
Ort::Value &cross_v) const; // NOLINT
|
||||
|
||||
private:
|
||||
OfflineWhisperModelConfig config_;
|
||||
OfflineWhisperModel *model_; // not owned
|
||||
};
|
||||
|
||||
|
||||
@@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("whisper-decoder", &decoder,
|
||||
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
|
||||
"medium.en-decoder.onnx.");
|
||||
|
||||
po->Register(
|
||||
"whisper-language", &language,
|
||||
"The spoke language in the input audio file. Example values: "
|
||||
"en, de, fr, zh, jp. If it is not given for a multilingual model, we will"
|
||||
" infer the language from the input audio file. "
|
||||
"Please refer to "
|
||||
"https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
|
||||
" for valid values. Note that for non-multilingual models, it supports "
|
||||
"only 'en'");
|
||||
|
||||
po->Register("whisper-task", &task,
|
||||
"Valid values: transcribe, translate. "
|
||||
"Note that for non-multilingual models, it supports "
|
||||
"only 'transcribe'");
|
||||
}
|
||||
|
||||
bool OfflineWhisperModelConfig::Validate() const {
|
||||
@@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (task != "translate" && task != "transcribe") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"--whisper-task supports only translate and transcribe. Given: %s",
|
||||
task.c_str());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {
|
||||
|
||||
os << "OfflineWhisperModelConfig(";
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "decoder=\"" << decoder << "\")";
|
||||
os << "decoder=\"" << decoder << "\", ";
|
||||
os << "language=\"" << language << "\", ";
|
||||
os << "task=\"" << task << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
|
||||
// Available languages can be found at
|
||||
// https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
//
|
||||
// Note: For non-multilingual models, it supports only "en"
|
||||
//
|
||||
// If empty, we will infer it from the input audio file when
|
||||
// the model is multilingual.
|
||||
std::string language;
|
||||
|
||||
// Valid values are transcribe and translate
|
||||
//
|
||||
// Note: For non-multilingual models, it supports only "transcribe"
|
||||
std::string task = "transcribe";
|
||||
|
||||
OfflineWhisperModelConfig() = default;
|
||||
OfflineWhisperModelConfig(const std::string &encoder,
|
||||
const std::string &decoder)
|
||||
: encoder(encoder), decoder(decoder) {}
|
||||
const std::string &decoder,
|
||||
const std::string &language,
|
||||
const std::string &task)
|
||||
: encoder(encoder), decoder(decoder), language(language), task(task) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl {
|
||||
|
||||
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
||||
|
||||
const std::vector<int32_t> &GetAllLanguageIDs() const {
|
||||
return all_language_tokens_;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, int32_t> &GetLang2ID() const {
|
||||
return lang2id_;
|
||||
}
|
||||
|
||||
const std::unordered_map<int32_t, std::string> &GetID2Lang() const {
|
||||
return id2lang_;
|
||||
}
|
||||
|
||||
int32_t NoTimeStampsToken() const { return no_timestamps_; }
|
||||
|
||||
int32_t EOT() const { return eot_; }
|
||||
|
||||
int32_t SOT() const { return sot_; }
|
||||
|
||||
int32_t TextCtx() const { return n_text_ctx_; }
|
||||
|
||||
int32_t VocabSize() const { return n_vocab_; }
|
||||
|
||||
int32_t Translate() const { return translate_; }
|
||||
|
||||
bool IsMultiLingual() const { return is_multilingual_; }
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
@@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl {
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
||||
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
||||
SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab");
|
||||
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
|
||||
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
|
||||
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
|
||||
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
|
||||
SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe");
|
||||
SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual");
|
||||
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
|
||||
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
|
||||
|
||||
if (is_multilingual_) {
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_,
|
||||
"all_language_tokens");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_,
|
||||
"all_language_codes");
|
||||
if (all_language_tokens_.size() != all_language_codes_.size()) {
|
||||
SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d",
|
||||
static_cast<int32_t>(all_language_tokens_.size()),
|
||||
static_cast<int32_t>(all_language_codes_.size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for (int32_t i = 0;
|
||||
i != static_cast<int32_t>(all_language_tokens_.size()); ++i) {
|
||||
lang2id_[all_language_codes_[i]] = all_language_tokens_[i];
|
||||
id2lang_[all_language_tokens_[i]] = all_language_codes_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||
@@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl {
|
||||
std::vector<std::string> decoder_output_names_;
|
||||
std::vector<const char *> decoder_output_names_ptr_;
|
||||
|
||||
std::vector<int32_t> all_language_tokens_;
|
||||
std::vector<std::string> all_language_codes_;
|
||||
std::unordered_map<std::string, int32_t> lang2id_;
|
||||
std::unordered_map<int32_t, std::string> id2lang_;
|
||||
|
||||
// model meta data
|
||||
int32_t n_text_layer_;
|
||||
int32_t n_text_ctx_;
|
||||
int32_t n_text_state_;
|
||||
int32_t n_vocab_;
|
||||
int32_t sot_;
|
||||
int32_t eot_;
|
||||
int32_t blank_;
|
||||
int32_t translate_;
|
||||
int32_t transcribe_;
|
||||
int32_t no_timestamps_;
|
||||
int32_t no_speech_;
|
||||
int32_t is_multilingual_;
|
||||
std::vector<int64_t> sot_sequence_;
|
||||
};
|
||||
|
||||
@@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
||||
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
|
||||
Ort::Value features) {
|
||||
Ort::Value features) const {
|
||||
return impl_->ForwardEncoder(std::move(features));
|
||||
}
|
||||
|
||||
@@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
|
||||
Ort::Value n_layer_self_v_cache,
|
||||
Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v,
|
||||
Ort::Value offset) {
|
||||
Ort::Value offset) const {
|
||||
return impl_->ForwardDecoder(
|
||||
std::move(tokens), std::move(n_layer_self_k_cache),
|
||||
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
|
||||
std::move(n_layer_cross_v), std::move(offset));
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
|
||||
const {
|
||||
return impl_->GetInitialSelfKVCache();
|
||||
}
|
||||
|
||||
@@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
|
||||
return impl_->GetInitialTokens();
|
||||
}
|
||||
|
||||
const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const {
|
||||
return impl_->GetAllLanguageIDs();
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, int32_t>
|
||||
&OfflineWhisperModel::GetLang2ID() const {
|
||||
return impl_->GetLang2ID();
|
||||
}
|
||||
|
||||
const std::unordered_map<int32_t, std::string>
|
||||
&OfflineWhisperModel::GetID2Lang() const {
|
||||
return impl_->GetID2Lang();
|
||||
}
|
||||
|
||||
int32_t OfflineWhisperModel::NoTimeStampsToken() const {
|
||||
return impl_->NoTimeStampsToken();
|
||||
}
|
||||
|
||||
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
|
||||
|
||||
int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); }
|
||||
|
||||
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
|
||||
|
||||
int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
|
||||
|
||||
int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
|
||||
|
||||
bool OfflineWhisperModel::IsMultiLingual() const {
|
||||
return impl_->IsMultiLingual();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
|
||||
* - n_layer_cross_v: A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;
|
||||
|
||||
/** Run the decoder model.
|
||||
*
|
||||
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
|
||||
Ort::Value>
|
||||
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset);
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
||||
|
||||
int32_t DetectLanguage() const;
|
||||
|
||||
/** Return the initial self kv cache in a pair
|
||||
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
|
||||
* - n_layer_self_v_cache A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
|
||||
const std::vector<int64_t> &GetInitialTokens() const;
|
||||
const std::vector<int32_t> &GetAllLanguageIDs() const;
|
||||
const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
|
||||
const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
int32_t NoTimeStampsToken() const;
|
||||
int32_t EOT() const;
|
||||
int32_t SOT() const;
|
||||
int32_t TextCtx() const;
|
||||
int32_t VocabSize() const;
|
||||
int32_t Translate() const;
|
||||
bool IsMultiLingual() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
|
||||
@@ -14,10 +14,14 @@ namespace sherpa_onnx {
|
||||
void PybindOfflineWhisperModelConfig(py::module *m) {
|
||||
using PyClass = OfflineWhisperModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &>(),
|
||||
py::arg("encoder"), py::arg("decoder"))
|
||||
.def(py::init<const std::string &, const std::string &,
|
||||
const std::string &, const std::string &>(),
|
||||
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
|
||||
py::arg("task"))
|
||||
.def_readwrite("encoder", &PyClass::encoder)
|
||||
.def_readwrite("decoder", &PyClass::decoder)
|
||||
.def_readwrite("language", &PyClass::language)
|
||||
.def_readwrite("task", &PyClass::task)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -244,6 +244,8 @@ class OfflineRecognizer(object):
|
||||
encoder: str,
|
||||
decoder: str,
|
||||
tokens: str,
|
||||
language: str = "en",
|
||||
task: str = "transcribe",
|
||||
num_threads: int = 1,
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
@@ -268,6 +270,14 @@ class OfflineRecognizer(object):
|
||||
|
||||
symbol integer_id
|
||||
|
||||
language:
|
||||
The spoken language in the audio file. Example values: en, de, zh,
|
||||
jp, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
for all possible values. Note that for non-multilingual models, the
|
||||
only valid value is 'en'.
|
||||
task:
|
||||
Valid values are: transcribe, translate. Note that for
|
||||
non-multilingual models, the only valid value is 'transcribe'.
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
decoding_method:
|
||||
@@ -279,7 +289,12 @@ class OfflineRecognizer(object):
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
|
||||
whisper=OfflineWhisperModelConfig(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
language=language,
|
||||
task=task,
|
||||
),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
|
||||
Reference in New Issue
Block a user