Support multilingual whisper models (#274)
This commit is contained in:
3
.github/workflows/build-wheels-macos.yaml
vendored
3
.github/workflows/build-wheels-macos.yaml
vendored
@@ -36,6 +36,9 @@ jobs:
|
|||||||
CIBW_ARCHS: "universal2"
|
CIBW_ARCHS: "universal2"
|
||||||
CIBW_BUILD_VERBOSITY: 3
|
CIBW_BUILD_VERBOSITY: 3
|
||||||
|
|
||||||
|
# Don't repair macOS wheels
|
||||||
|
CIBW_REPAIR_WHEEL_COMMAND_MACOS: ""
|
||||||
|
|
||||||
- name: Display wheels
|
- name: Display wheels
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest]
|
os: [macos-latest]
|
||||||
model: ["tiny.en", "base.en", "small.en", "medium.en"]
|
model: ["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.7.6")
|
set(SHERPA_ONNX_VERSION "1.7.7")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ module non-streaming-decode-files
|
|||||||
go 1.12
|
go 1.12
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/youpy/go-wav v0.3.2
|
github.com/youpy/go-wav v0.3.2
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
|||||||
@@ -4,6 +4,6 @@ go 1.12
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
|
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
|
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
|
||||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
|
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ module streaming-decode-files
|
|||||||
go 1.12
|
go 1.12
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/youpy/go-wav v0.3.2
|
github.com/youpy/go-wav v0.3.2
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1 h1:kVAAowsJCJxZzRD++0xzUsJwDAx1FZMgiDjI4NSAWco=
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ=
|
||||||
github.com/k2-fsa/sherpa-onnx-go v1.5.5-alpha.1/go.mod h1:egcXRfYdJvNbw1vMYcvE3dHUPXXP+s4TRm1VRFECZNw=
|
github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5 h1:A7N2uio/qsrtwMO3D2KloLEBlzLsYMRgcKx9jVeq1xk=
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-linux v1.5.5/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5 h1:S8o7rJMXuzf6Fzi7MXKlBPTnv2ic5a5KMn3d9KJ45gQ=
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-macos v1.5.5/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5 h1:7+RyRugpibpA4TvRrvU885qiSkEzntxMo7Aq+xzV3F0=
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4=
|
||||||
github.com/k2-fsa/sherpa-onnx-go-windows v1.5.5/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
|||||||
@@ -11,10 +11,12 @@ fun main() {
|
|||||||
// please refer to
|
// please refer to
|
||||||
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||||
// to download pre-trained models
|
// to download pre-trained models
|
||||||
var modelConfig = OnlineTransducerModelConfig(
|
var modelConfig = OnlineModelConfig(
|
||||||
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
|
transducer = OnlineTransducerModelConfig(
|
||||||
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
|
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
|
||||||
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
|
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
|
||||||
|
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
|
||||||
|
),
|
||||||
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
|
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
|
||||||
numThreads = 1,
|
numThreads = 1,
|
||||||
debug = false,
|
debug = false,
|
||||||
@@ -41,19 +43,19 @@ fun main() {
|
|||||||
var objArray = WaveReader.readWaveFromFile(
|
var objArray = WaveReader.readWaveFromFile(
|
||||||
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
|
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
|
||||||
)
|
)
|
||||||
var samples : FloatArray = objArray[0] as FloatArray
|
var samples: FloatArray = objArray[0] as FloatArray
|
||||||
var sampleRate : Int = objArray[1] as Int
|
var sampleRate: Int = objArray[1] as Int
|
||||||
|
|
||||||
model.acceptWaveform(samples, sampleRate=sampleRate)
|
model.acceptWaveform(samples, sampleRate = sampleRate)
|
||||||
while (model.isReady()) {
|
while (model.isReady()) {
|
||||||
model.decode()
|
model.decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
|
var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
|
||||||
model.acceptWaveform(tailPaddings, sampleRate=sampleRate)
|
model.acceptWaveform(tailPaddings, sampleRate = sampleRate)
|
||||||
model.inputFinished()
|
model.inputFinished()
|
||||||
while (model.isReady()) {
|
while (model.isReady()) {
|
||||||
model.decode()
|
model.decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
println("results: ${model.text}")
|
println("results: ${model.text}")
|
||||||
|
|||||||
@@ -234,6 +234,28 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
|
|||||||
help="Path to whisper decoder model",
|
help="Path to whisper decoder model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-language",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="""It specifies the spoken language in the input audio file.
|
||||||
|
Example values: en, fr, de, zh, jp.
|
||||||
|
Available languages for multilingual models can be found at
|
||||||
|
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||||
|
If not specified, we infer the language from the input audio file.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-task",
|
||||||
|
default="transcribe",
|
||||||
|
choices=["transcribe", "translate"],
|
||||||
|
type=str,
|
||||||
|
help="""For multilingual models, if you specify translate, the output
|
||||||
|
will be in English.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_model_args(parser: argparse.ArgumentParser):
|
def add_model_args(parser: argparse.ArgumentParser):
|
||||||
add_transducer_model_args(parser)
|
add_transducer_model_args(parser)
|
||||||
@@ -813,6 +835,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
tokens=args.tokens,
|
tokens=args.tokens,
|
||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
|
language=args.whisper_language,
|
||||||
|
task=args.whisper_task,
|
||||||
)
|
)
|
||||||
elif args.tdnn_model:
|
elif args.tdnn_model:
|
||||||
assert_file_exists(args.tdnn_model)
|
assert_file_exists(args.tdnn_model)
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ python3 ./python-api-examples/offline-decode-files.py \
|
|||||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||||
|
--whisper-task=transcribe \
|
||||||
--num-threads=1 \
|
--num-threads=1 \
|
||||||
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||||
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||||
@@ -200,6 +201,28 @@ def get_args():
|
|||||||
help="Path to whisper decoder model",
|
help="Path to whisper decoder model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-language",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="""It specifies the spoken language in the input audio file.
|
||||||
|
Example values: en, fr, de, zh, jp.
|
||||||
|
Available languages for multilingual models can be found at
|
||||||
|
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||||
|
If not specified, we infer the language from the input audio file.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-task",
|
||||||
|
default="transcribe",
|
||||||
|
choices=["transcribe", "translate"],
|
||||||
|
type=str,
|
||||||
|
help="""For multilingual models, if you specify translate, the output
|
||||||
|
will be in English.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoding-method",
|
"--decoding-method",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -371,10 +394,10 @@ def main():
|
|||||||
decoder=args.whisper_decoder,
|
decoder=args.whisper_decoder,
|
||||||
tokens=args.tokens,
|
tokens=args.tokens,
|
||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
sample_rate=args.sample_rate,
|
|
||||||
feature_dim=args.feature_dim,
|
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
|
language=args.whisper_language,
|
||||||
|
task=args.whisper_task,
|
||||||
)
|
)
|
||||||
elif args.tdnn_model:
|
elif args.tdnn_model:
|
||||||
assert_file_exists(args.tdnn_model)
|
assert_file_exists(args.tdnn_model)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ for making the onnx export script public.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
@@ -250,6 +251,7 @@ def main():
|
|||||||
# write tokens
|
# write tokens
|
||||||
|
|
||||||
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print(model.dims)
|
print(model.dims)
|
||||||
audio = torch.rand(16000 * 2)
|
audio = torch.rand(16000 * 2)
|
||||||
@@ -306,8 +308,12 @@ def main():
|
|||||||
"n_text_head": model.dims.n_text_head,
|
"n_text_head": model.dims.n_text_head,
|
||||||
"n_text_layer": model.dims.n_text_layer,
|
"n_text_layer": model.dims.n_text_layer,
|
||||||
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
|
"sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))),
|
||||||
"all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))),
|
"all_language_tokens": ",".join(
|
||||||
"all_language_codes": ",".join(tokenizer.all_language_codes),
|
list(map(str, tokenizer.all_language_tokens))
|
||||||
|
), # a list of ids
|
||||||
|
"all_language_codes": ",".join(
|
||||||
|
tokenizer.all_language_codes
|
||||||
|
), # e.g., en, de, zh, fr
|
||||||
"sot": tokenizer.sot,
|
"sot": tokenizer.sot,
|
||||||
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
|
"sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
|
||||||
"eot": tokenizer.eot,
|
"eot": tokenizer.eot,
|
||||||
@@ -413,6 +419,9 @@ def main():
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if 'large' in args.model:
|
||||||
|
# it causes errors for large models, so skip it.
|
||||||
|
return
|
||||||
# Generate int8 quantization models
|
# Generate int8 quantization models
|
||||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,24 @@ def get_args():
|
|||||||
help="Path to the tokens",
|
help="Path to the tokens",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
type=str,
|
||||||
|
help="""The actual spoken language in the audio.
|
||||||
|
Example values, en, de, zh, jp, fr.
|
||||||
|
If None, we will detect the language using the first 30s of the
|
||||||
|
input audio
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
choices=["transcribe", "translate"],
|
||||||
|
type=str,
|
||||||
|
default="transcribe",
|
||||||
|
help="Valid values are: transcribe, translate",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"sound_file",
|
"sound_file",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -74,12 +92,22 @@ class OnnxModel:
|
|||||||
self.sot = int(meta["sot"])
|
self.sot = int(meta["sot"])
|
||||||
self.eot = int(meta["eot"])
|
self.eot = int(meta["eot"])
|
||||||
self.translate = int(meta["translate"])
|
self.translate = int(meta["translate"])
|
||||||
|
self.transcribe = int(meta["transcribe"])
|
||||||
self.no_timestamps = int(meta["no_timestamps"])
|
self.no_timestamps = int(meta["no_timestamps"])
|
||||||
self.no_speech = int(meta["no_speech"])
|
self.no_speech = int(meta["no_speech"])
|
||||||
self.blank = int(meta["blank_id"])
|
self.blank = int(meta["blank_id"])
|
||||||
|
|
||||||
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
||||||
|
|
||||||
|
self.sot_sequence.append(self.no_timestamps)
|
||||||
|
|
||||||
|
self.all_language_tokens = list(
|
||||||
|
map(int, meta["all_language_tokens"].split(","))
|
||||||
|
)
|
||||||
|
self.all_language_codes = meta["all_language_codes"].split(",")
|
||||||
|
self.lang2id = dict(zip(self.all_language_codes, self.all_language_tokens))
|
||||||
|
self.id2lang = dict(zip(self.all_language_tokens, self.all_language_codes))
|
||||||
|
|
||||||
self.is_multilingual = int(meta["is_multilingual"]) == 1
|
self.is_multilingual = int(meta["is_multilingual"]) == 1
|
||||||
|
|
||||||
def init_decoder(self, decoder: str):
|
def init_decoder(self, decoder: str):
|
||||||
@@ -164,6 +192,29 @@ class OnnxModel:
|
|||||||
# logits is changed in-place
|
# logits is changed in-place
|
||||||
logits[self.translate] = float("-inf")
|
logits[self.translate] = float("-inf")
|
||||||
|
|
||||||
|
def detect_language(
    self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor
) -> int:
    """Return the token id of the most likely spoken language.

    The decoder is run for one step on a sequence holding only ``self.sot``.
    Every logit that is not listed in ``self.all_language_tokens`` is set
    to ``-inf`` and the arg-max of what remains is returned.

    Args:
      n_layer_cross_k:
        Cross-attention keys, forwarded unchanged to ``self.run_decoder``.
      n_layer_cross_v:
        Cross-attention values, forwarded unchanged to ``self.run_decoder``.

    Returns:
      The id of the winning language token, always a member of
      ``self.all_language_tokens`` and therefore a key of ``self.id2lang``.
    """
    tokens = torch.tensor([[self.sot]], dtype=torch.int64)
    offset = torch.zeros(1, dtype=torch.int64)
    n_layer_self_k_cache, n_layer_self_v_cache = self.get_self_cache()

    # The updated self-attention caches are not needed after this one step.
    logits, _, _ = self.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    logits = logits.reshape(-1)

    # The mask MUST be boolean. Indexing with an int64 tensor of 0/1 values
    # is integer (gather-style) indexing: it would only overwrite
    # logits[0] and logits[1] instead of every non-language position.
    not_a_language = torch.ones(logits.shape[0], dtype=torch.bool)
    not_a_language[self.all_language_tokens] = False
    logits[not_a_language] = float("-inf")

    lang_id = logits.argmax().item()
    print("detected language: ", self.id2lang[lang_id])
    return lang_id
|
||||||
|
|
||||||
|
|
||||||
def load_tokens(filename):
|
def load_tokens(filename):
|
||||||
tokens = dict()
|
tokens = dict()
|
||||||
@@ -200,7 +251,35 @@ def main():
|
|||||||
mel = mel.t().unsqueeze(0)
|
mel = mel.t().unsqueeze(0)
|
||||||
|
|
||||||
model = OnnxModel(encoder, decoder)
|
model = OnnxModel(encoder, decoder)
|
||||||
|
|
||||||
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||||
|
|
||||||
|
if args.language is not None:
|
||||||
|
if model.is_multilingual is False and args.language != "en":
|
||||||
|
print(f"This model supports only English. Given: {args.language}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.language not in model.lang2id:
|
||||||
|
print(f"Invalid language: {args.language}")
|
||||||
|
print(f"Valid values are: {list(model.lang2id.keys())}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# [sot, lang, task, notimestamps]
|
||||||
|
model.sot_sequence[1] = model.lang2id[args.language]
|
||||||
|
elif model.is_multilingual is True:
|
||||||
|
print("detecting language")
|
||||||
|
lang = model.detect_language(n_layer_cross_k, n_layer_cross_v)
|
||||||
|
model.sot_sequence[1] = lang
|
||||||
|
|
||||||
|
if args.task is not None:
|
||||||
|
if model.is_multilingual is False and args.task != "transcribe":
|
||||||
|
print("This model supports only English. Please use --task=transcribe")
|
||||||
|
return
|
||||||
|
assert args.task in ["transcribe", "translate"], args.task
|
||||||
|
|
||||||
|
if args.task == "translate":
|
||||||
|
model.sot_sequence[2] = model.translate
|
||||||
|
|
||||||
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
||||||
|
|
||||||
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
||||||
@@ -213,6 +292,7 @@ def main():
|
|||||||
n_layer_cross_v=n_layer_cross_v,
|
n_layer_cross_v=n_layer_cross_v,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
)
|
)
|
||||||
|
offset += len(model.sot_sequence)
|
||||||
# logits.shape (batch_size, tokens.shape[1], vocab_size)
|
# logits.shape (batch_size, tokens.shape[1], vocab_size)
|
||||||
logits = logits[0, -1]
|
logits = logits[0, -1]
|
||||||
model.suppress_tokens(logits, is_initial=True)
|
model.suppress_tokens(logits, is_initial=True)
|
||||||
@@ -225,7 +305,6 @@ def main():
|
|||||||
break
|
break
|
||||||
results.append(max_token_id.item())
|
results.append(max_token_id.item())
|
||||||
tokens = torch.tensor([[results[-1]]])
|
tokens = torch.tensor([[results[-1]]])
|
||||||
offset += 1
|
|
||||||
|
|
||||||
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
@@ -235,6 +314,7 @@ def main():
|
|||||||
n_layer_cross_v=n_layer_cross_v,
|
n_layer_cross_v=n_layer_cross_v,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
)
|
)
|
||||||
|
offset += 1
|
||||||
logits = logits[0, -1]
|
logits = logits[0, -1]
|
||||||
model.suppress_tokens(logits, is_initial=False)
|
model.suppress_tokens(logits, is_initial=False)
|
||||||
max_token_id = logits.argmax(dim=-1)
|
max_token_id = logits.argmax(dim=-1)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@
|
|||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
dst = atoi(value.get()); \
|
dst = atoi(value.get()); \
|
||||||
if (dst <= 0) { \
|
if (dst < 0) { \
|
||||||
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
|
SHERPA_ONNX_LOGE("Invalid value %d for %s", dst, src_key); \
|
||||||
exit(-1); \
|
exit(-1); \
|
||||||
} \
|
} \
|
||||||
@@ -77,6 +77,24 @@
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
// Read a vector of strings.
//
// Looks up `src_key` in the model metadata and splits the stored value on ","
// into `dst` via SplitStringToVector().
//
// The expansion site must have `meta_data` and `allocator` in scope; they are
// used here but are not macro parameters.
//
// The process is terminated with exit(-1) after logging an error when
//  - `src_key` is not present in the metadata, or
//  - `dst` is empty after splitting.
//
// NOTE(review): the `false` passed to SplitStringToVector() presumably means
// "do not omit empty strings" -- confirm against its declaration.
|
||||||
|
#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \
|
||||||
|
do { \
|
||||||
|
auto value = \
|
||||||
|
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||||
|
if (!value) { \
|
||||||
|
SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
|
||||||
|
exit(-1); \
|
||||||
|
} \
|
||||||
|
SplitStringToVector(value.get(), ",", false, &dst); \
|
||||||
|
\
|
||||||
|
if (dst.empty()) { \
|
||||||
|
SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
|
||||||
|
src_key); \
|
||||||
|
exit(-1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
// Read a string
|
// Read a string
|
||||||
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
||||||
do { \
|
do { \
|
||||||
|
|||||||
@@ -23,21 +23,227 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Return a copy of `s` with every byte that is not part of a structurally
// well-formed UTF-8 sequence removed.
//
// A sequence is kept when its lead byte announces a length of 1..6 bytes and
// all announced continuation bytes are present and lie in [0x80, 0xc0).
// Otherwise only the lead byte is dropped and scanning resumes at the very
// next byte. Stray continuation bytes and 0xfe/0xff are dropped as well.
//
// Only the byte structure is checked: the legacy 5- and 6-byte forms are
// accepted, and overlong encodings are not rejected.
static std::string FixInvalidUtf8(const std::string &s) {
  const int32_t num_bytes = s.size();

  std::string ans;
  ans.reserve(num_bytes);

  int32_t pos = 0;
  while (pos < num_bytes) {
    const uint8_t lead = s[pos];

    // Sequence length announced by the lead byte; 0 means "not a lead byte".
    int32_t len = 0;
    if (lead < 0x80) {
      len = 1;
    } else if (lead < 0xc0) {
      len = 0;  // continuation byte without a lead byte
    } else if (lead < 0xe0) {
      len = 2;
    } else if (lead < 0xf0) {
      len = 3;
    } else if (lead < 0xf8) {
      len = 4;
    } else if (lead < 0xfc) {
      len = 5;
    } else if (lead < 0xfe) {
      len = 6;
    }

    // The whole sequence must fit into the remaining input ...
    bool ok = (len != 0) && (pos + len <= num_bytes);

    // ... and every trailing byte must be a continuation byte.
    for (int32_t k = 1; ok && k < len; ++k) {
      const uint8_t b = s[pos + k];
      ok = (b >= 0x80) && (b < 0xc0);
    }

    if (!ok) {
      // Drop just this byte; the following bytes are re-examined on their own.
      ++pos;
      continue;
    }

    ans.append(s, pos, len);
    pos += len;
  }

  return ans;
}
|
||||||
|
|
||||||
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||||
const SymbolTable &sym_table) {
|
const SymbolTable &sym_table) {
|
||||||
OfflineRecognitionResult r;
|
OfflineRecognitionResult r;
|
||||||
r.tokens.reserve(src.tokens.size());
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
|
||||||
|
std::string text;
|
||||||
for (auto i : src.tokens) {
|
for (auto i : src.tokens) {
|
||||||
if (!sym_table.contains(i)) {
|
if (!sym_table.contains(i)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto &s = sym_table[i];
|
const auto &s = sym_table[i];
|
||||||
r.text += s;
|
text += s;
|
||||||
r.tokens.push_back(s);
|
r.tokens.push_back(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(fangjun): Fix the following error in offline-stream.cc
|
||||||
|
//
|
||||||
|
// j["text"] = text;
|
||||||
|
|
||||||
|
// libc++abi: terminating with uncaught exception of type
|
||||||
|
// nlohmann::json_abi_v3_11_2::detail::type_error:
|
||||||
|
// [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
r.text = FixInvalidUtf8(text);
|
||||||
|
#else
|
||||||
|
r.text = text;
|
||||||
|
#endif
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,8 +257,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
symbol_table_.ApplyBase64Decode();
|
symbol_table_.ApplyBase64Decode();
|
||||||
|
|
||||||
if (config.decoding_method == "greedy_search") {
|
if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ =
|
decoder_ = std::make_unique<OfflineWhisperGreedySearchDecoder>(
|
||||||
std::make_unique<OfflineWhisperGreedySearchDecoder>(model_.get());
|
config_.model_config.whisper, model_.get());
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Only greedy_search is supported at present for whisper. Given %s",
|
"Only greedy_search is supported at present for whisper. Given %s",
|
||||||
@@ -101,6 +307,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
mel = Transpose12(model_->Allocator(), &mel);
|
mel = Transpose12(model_->Allocator(), &mel);
|
||||||
|
|
||||||
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||||
|
|
||||||
auto results =
|
auto results =
|
||||||
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
|
||||||
|
|
||||||
|
|||||||
@@ -7,17 +7,106 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Detect the spoken language by running one decoder step on the SOT token
// and choosing the language token with the largest logit.
//
// @param cross_k Cross-attention key cache. It is moved into the decoder call
//                and re-assigned from the decoder output before returning, so
//                the caller can keep using it afterwards.
// @param cross_v Cross-attention value cache; handled the same way as cross_k.
// @return The token ID of the language with the highest logit.
//
// Exits the process if the model provides no language token IDs.
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
    Ort::Value &cross_k, Ort::Value &cross_v) const {  // NOLINT
  const auto &all_language_ids = model_->GetAllLanguageIDs();
  if (all_language_ids.empty()) {
    // The list is only filled in for multilingual models; reading element 0
    // of an empty vector below would be undefined behavior.
    SHERPA_ONNX_LOGE(
        "Cannot detect the language: the model has no language tokens");
    exit(-1);
  }

  int64_t token_val = model_->SOT();
  std::array<int64_t, 2> token_shape{1, 1};

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // This tensor wraps token_val without copying it; it is consumed by
  // ForwardDecoder() below while token_val is still in scope.
  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token_val, 1, token_shape.data(), token_shape.size());

  auto self_kv_cache = model_->GetInitialSelfKVCache();

  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      model_->Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  auto decoder_out = model_->ForwardDecoder(
      std::move(tokens), std::move(self_kv_cache.first),
      std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
      std::move(offset));

  // Give the cross-attention caches back to the caller.
  cross_k = std::move(std::get<3>(decoder_out));
  cross_v = std::move(std::get<4>(decoder_out));

  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();

  // Arg-max restricted to the language tokens. The strict '>' keeps the
  // first candidate on ties.
  int32_t lang_id = all_language_ids[0];
  float best_logit = p_logits[lang_id];
  for (int32_t id : all_language_ids) {
    if (p_logits[id] > best_logit) {
      best_logit = p_logits[id];
      lang_id = id;
    }
  }

#if 1
  SHERPA_ONNX_LOGE("Detected language: %s",
                   model_->GetID2Lang().at(lang_id).c_str());
#endif

  return lang_id;
}
|
||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult>
|
std::vector<OfflineWhisperDecoderResult>
|
||||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||||
Ort::Value cross_v) {
|
Ort::Value cross_v) {
|
||||||
auto memory_info =
|
auto memory_info =
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
// For multilingual models, initial_tokens contains [sot, language, task]
|
||||||
|
// - language is English by default
|
||||||
|
// - task is transcribe by default
|
||||||
|
//
|
||||||
|
// For non-multilingual models, initial_tokens contains [sot]
|
||||||
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
|
std::vector<int64_t> initial_tokens = model_->GetInitialTokens();
|
||||||
|
|
||||||
|
if (model_->IsMultiLingual()) {
|
||||||
|
if (!config_.language.empty()) {
|
||||||
|
const auto &lang2id = model_->GetLang2ID();
|
||||||
|
|
||||||
|
if (!lang2id.count(config_.language)) {
|
||||||
|
SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t lang_id = lang2id.at(config_.language);
|
||||||
|
|
||||||
|
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||||
|
initial_tokens[1] = lang_id;
|
||||||
|
} else {
|
||||||
|
int32_t lang_id = DetectLanguage(cross_k, cross_v);
|
||||||
|
|
||||||
|
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||||
|
initial_tokens[1] = lang_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.task == "translate") {
|
||||||
|
initial_tokens[2] = model_->Translate();
|
||||||
|
} else if (config_.task != "transcribe") {
|
||||||
|
// initial_tokens[2] is transcribe by default
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Unsupported task: %s. Valid values are: transcribe, translate.",
|
||||||
|
config_.task.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
initial_tokens.push_back(model_->NoTimeStampsToken());
|
||||||
|
|
||||||
int32_t batch_size = 1;
|
int32_t batch_size = 1;
|
||||||
std::array<int64_t, 2> token_shape{
|
std::array<int64_t, 2> token_shape{
|
||||||
batch_size, static_cast<int64_t>(initial_tokens.size())};
|
batch_size, static_cast<int64_t>(initial_tokens.size())};
|
||||||
@@ -31,11 +120,16 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||||
|
|
||||||
|
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||||
|
|
||||||
auto decoder_out = model_->ForwardDecoder(
|
auto decoder_out = model_->ForwardDecoder(
|
||||||
std::move(tokens), std::move(self_kv_cache.first),
|
std::move(tokens), std::move(self_kv_cache.first),
|
||||||
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||||
std::move(offset));
|
std::move(offset));
|
||||||
|
|
||||||
|
*(std::get<5>(decoder_out).GetTensorMutableData<int64_t>()) =
|
||||||
|
initial_tokens.size();
|
||||||
|
|
||||||
const auto &logits = std::get<0>(decoder_out);
|
const auto &logits = std::get<0>(decoder_out);
|
||||||
const float *p_logits = logits.GetTensorData<float>();
|
const float *p_logits = logits.GetTensorData<float>();
|
||||||
|
|
||||||
@@ -58,18 +152,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
std::array<int64_t, 2> token_shape{1, 1};
|
std::array<int64_t, 2> token_shape{1, 1};
|
||||||
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
|
Ort::Value tokens = Ort::Value::CreateTensor<int64_t>(
|
||||||
model_->Allocator(), token_shape.data(), token_shape.size());
|
model_->Allocator(), token_shape.data(), token_shape.size());
|
||||||
|
|
||||||
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
|
int64_t *p_tokens = tokens.GetTensorMutableData<int64_t>();
|
||||||
p_tokens[0] = max_token_id;
|
p_tokens[0] = max_token_id;
|
||||||
|
|
||||||
int64_t *p_offset =
|
|
||||||
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
|
||||||
|
|
||||||
if (i == 0) {
|
|
||||||
*p_offset = initial_tokens.size();
|
|
||||||
} else {
|
|
||||||
*p_offset += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder_out = model_->ForwardDecoder(std::move(tokens),
|
decoder_out = model_->ForwardDecoder(std::move(tokens),
|
||||||
std::move(std::get<1>(decoder_out)),
|
std::move(std::get<1>(decoder_out)),
|
||||||
std::move(std::get<2>(decoder_out)),
|
std::move(std::get<2>(decoder_out)),
|
||||||
@@ -77,6 +163,11 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
std::move(std::get<4>(decoder_out)),
|
std::move(std::get<4>(decoder_out)),
|
||||||
std::move(std::get<5>(decoder_out)));
|
std::move(std::get<5>(decoder_out)));
|
||||||
|
|
||||||
|
int64_t *p_offset =
|
||||||
|
std::get<5>(decoder_out).GetTensorMutableData<int64_t>();
|
||||||
|
|
||||||
|
*p_offset += 1;
|
||||||
|
|
||||||
const auto &logits = std::get<0>(decoder_out);
|
const auto &logits = std::get<0>(decoder_out);
|
||||||
const float *p_logits = logits.GetTensorData<float>();
|
const float *p_logits = logits.GetTensorData<float>();
|
||||||
|
|
||||||
@@ -85,6 +176,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult> ans(1);
|
std::vector<OfflineWhisperDecoderResult> ans(1);
|
||||||
|
|
||||||
ans[0].tokens = std::move(predicted_tokens);
|
ans[0].tokens = std::move(predicted_tokens);
|
||||||
|
|
||||||
return ans;
|
return ans;
|
||||||
|
|||||||
@@ -8,19 +8,25 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
||||||
public:
|
public:
|
||||||
explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
|
OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config,
|
||||||
: model_(model) {}
|
OfflineWhisperModel *model)
|
||||||
|
: config_(config), model_(model) {}
|
||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||||
Ort::Value cross_v) override;
|
Ort::Value cross_v) override;
|
||||||
|
|
||||||
|
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||||
|
Ort::Value &cross_v) const; // NOLINT
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
OfflineWhisperModelConfig config_;
|
||||||
OfflineWhisperModel *model_; // not owned
|
OfflineWhisperModel *model_; // not owned
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,21 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("whisper-decoder", &decoder,
|
po->Register("whisper-decoder", &decoder,
|
||||||
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
|
"Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
|
||||||
"medium.en-decoder.onnx.");
|
"medium.en-decoder.onnx.");
|
||||||
|
|
||||||
|
po->Register(
|
||||||
|
"whisper-language", &language,
|
||||||
|
"The spoke language in the input audio file. Example values: "
|
||||||
|
"en, de, fr, zh, jp. If it is not given for a multilingual model, we will"
|
||||||
|
" infer the language from the input audio file. "
|
||||||
|
"Please refer to "
|
||||||
|
"https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10"
|
||||||
|
" for valid values. Note that for non-multilingual models, it supports "
|
||||||
|
"only 'en'");
|
||||||
|
|
||||||
|
po->Register("whisper-task", &task,
|
||||||
|
"Valid values: transcribe, translate. "
|
||||||
|
"Note that for non-multilingual models, it supports "
|
||||||
|
"only 'transcribe'");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OfflineWhisperModelConfig::Validate() const {
|
bool OfflineWhisperModelConfig::Validate() const {
|
||||||
@@ -30,6 +45,14 @@ bool OfflineWhisperModelConfig::Validate() const {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (task != "translate" && task != "transcribe") {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"--whisper-task supports only translate and transcribe. Given: %s",
|
||||||
|
task.c_str());
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +61,9 @@ std::string OfflineWhisperModelConfig::ToString() const {
|
|||||||
|
|
||||||
os << "OfflineWhisperModelConfig(";
|
os << "OfflineWhisperModelConfig(";
|
||||||
os << "encoder=\"" << encoder << "\", ";
|
os << "encoder=\"" << encoder << "\", ";
|
||||||
os << "decoder=\"" << decoder << "\")";
|
os << "decoder=\"" << decoder << "\", ";
|
||||||
|
os << "language=\"" << language << "\", ";
|
||||||
|
os << "task=\"" << task << "\")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,26 @@ struct OfflineWhisperModelConfig {
|
|||||||
std::string encoder;
|
std::string encoder;
|
||||||
std::string decoder;
|
std::string decoder;
|
||||||
|
|
||||||
|
// Available languages can be found at
|
||||||
|
// https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||||
|
//
|
||||||
|
// Note: For non-multilingual models, it supports only "en"
|
||||||
|
//
|
||||||
|
// If empty, we will infer it from the input audio file when
|
||||||
|
// the model is multilingual.
|
||||||
|
std::string language;
|
||||||
|
|
||||||
|
// Valid values are transcribe and translate
|
||||||
|
//
|
||||||
|
// Note: For non-multilingual models, it supports only "transcribe"
|
||||||
|
std::string task = "transcribe";
|
||||||
|
|
||||||
OfflineWhisperModelConfig() = default;
|
OfflineWhisperModelConfig() = default;
|
||||||
OfflineWhisperModelConfig(const std::string &encoder,
|
OfflineWhisperModelConfig(const std::string &encoder,
|
||||||
const std::string &decoder)
|
const std::string &decoder,
|
||||||
: encoder(encoder), decoder(decoder) {}
|
const std::string &language,
|
||||||
|
const std::string &task)
|
||||||
|
: encoder(encoder), decoder(decoder), language(language), task(task) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
@@ -88,10 +89,32 @@ class OfflineWhisperModel::Impl {
|
|||||||
|
|
||||||
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
|
||||||
|
|
||||||
|
// Returns all_language_tokens_, the language token IDs. The vector is only
// filled from the "all_language_tokens" metadata when is_multilingual_ is set.
const std::vector<int32_t> &GetAllLanguageIDs() const {
|
||||||
|
return all_language_tokens_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the map from language code to language token ID. It is built from
// all_language_codes_ / all_language_tokens_ for multilingual models only.
const std::unordered_map<std::string, int32_t> &GetLang2ID() const {
|
||||||
|
return lang2id_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the map from language token ID to language code (the inverse of
// lang2id_). It is built for multilingual models only.
const std::unordered_map<int32_t, std::string> &GetID2Lang() const {
|
||||||
|
return id2lang_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns no_timestamps_, which is populated from the "no_timestamps"
// metadata key.
int32_t NoTimeStampsToken() const { return no_timestamps_; }
|
||||||
|
|
||||||
int32_t EOT() const { return eot_; }
|
int32_t EOT() const { return eot_; }
|
||||||
|
|
||||||
|
// Returns sot_, which is populated from the "sot" metadata key.
int32_t SOT() const { return sot_; }
|
||||||
|
|
||||||
int32_t TextCtx() const { return n_text_ctx_; }
|
int32_t TextCtx() const { return n_text_ctx_; }
|
||||||
|
|
||||||
|
// Returns n_vocab_, which is populated from the "n_vocab" metadata key.
int32_t VocabSize() const { return n_vocab_; }
|
||||||
|
|
||||||
|
// Returns translate_, which is populated from the "translate" metadata key.
int32_t Translate() const { return translate_; }
|
||||||
|
|
||||||
|
// Returns whether the model is multilingual. is_multilingual_ is stored as an
// int32_t (from the "is_multilingual" metadata key); any non-zero value
// converts to true here.
bool IsMultiLingual() const { return is_multilingual_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
@@ -116,13 +139,35 @@ class OfflineWhisperModel::Impl {
|
|||||||
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
|
||||||
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
|
||||||
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab");
|
||||||
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
|
SHERPA_ONNX_READ_META_DATA(sot_, "sot");
|
||||||
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
|
SHERPA_ONNX_READ_META_DATA(eot_, "eot");
|
||||||
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
|
SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
|
||||||
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
|
SHERPA_ONNX_READ_META_DATA(translate_, "translate");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual");
|
||||||
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
|
SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
|
||||||
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
|
SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
|
||||||
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
|
SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
|
||||||
|
|
||||||
|
if (is_multilingual_) {
|
||||||
|
SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_,
|
||||||
|
"all_language_tokens");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_,
|
||||||
|
"all_language_codes");
|
||||||
|
if (all_language_tokens_.size() != all_language_codes_.size()) {
|
||||||
|
SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d",
|
||||||
|
static_cast<int32_t>(all_language_tokens_.size()),
|
||||||
|
static_cast<int32_t>(all_language_codes_.size()));
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t i = 0;
|
||||||
|
i != static_cast<int32_t>(all_language_tokens_.size()); ++i) {
|
||||||
|
lang2id_[all_language_codes_[i]] = all_language_tokens_[i];
|
||||||
|
id2lang_[all_language_tokens_[i]] = all_language_codes_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitDecoder(void *model_data, size_t model_data_length) {
|
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||||
@@ -157,16 +202,24 @@ class OfflineWhisperModel::Impl {
|
|||||||
std::vector<std::string> decoder_output_names_;
|
std::vector<std::string> decoder_output_names_;
|
||||||
std::vector<const char *> decoder_output_names_ptr_;
|
std::vector<const char *> decoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<int32_t> all_language_tokens_;
|
||||||
|
std::vector<std::string> all_language_codes_;
|
||||||
|
std::unordered_map<std::string, int32_t> lang2id_;
|
||||||
|
std::unordered_map<int32_t, std::string> id2lang_;
|
||||||
|
|
||||||
// model meta data
|
// model meta data
|
||||||
int32_t n_text_layer_;
|
int32_t n_text_layer_;
|
||||||
int32_t n_text_ctx_;
|
int32_t n_text_ctx_;
|
||||||
int32_t n_text_state_;
|
int32_t n_text_state_;
|
||||||
|
int32_t n_vocab_;
|
||||||
int32_t sot_;
|
int32_t sot_;
|
||||||
int32_t eot_;
|
int32_t eot_;
|
||||||
int32_t blank_;
|
int32_t blank_;
|
||||||
int32_t translate_;
|
int32_t translate_;
|
||||||
|
int32_t transcribe_;
|
||||||
int32_t no_timestamps_;
|
int32_t no_timestamps_;
|
||||||
int32_t no_speech_;
|
int32_t no_speech_;
|
||||||
|
int32_t is_multilingual_;
|
||||||
std::vector<int64_t> sot_sequence_;
|
std::vector<int64_t> sot_sequence_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -176,7 +229,7 @@ OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
|||||||
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
|
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
|
||||||
Ort::Value features) {
|
Ort::Value features) const {
|
||||||
return impl_->ForwardEncoder(std::move(features));
|
return impl_->ForwardEncoder(std::move(features));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,14 +240,15 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
|
|||||||
Ort::Value n_layer_self_v_cache,
|
Ort::Value n_layer_self_v_cache,
|
||||||
Ort::Value n_layer_cross_k,
|
Ort::Value n_layer_cross_k,
|
||||||
Ort::Value n_layer_cross_v,
|
Ort::Value n_layer_cross_v,
|
||||||
Ort::Value offset) {
|
Ort::Value offset) const {
|
||||||
return impl_->ForwardDecoder(
|
return impl_->ForwardDecoder(
|
||||||
std::move(tokens), std::move(n_layer_self_k_cache),
|
std::move(tokens), std::move(n_layer_self_k_cache),
|
||||||
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
|
std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
|
||||||
std::move(n_layer_cross_v), std::move(offset));
|
std::move(n_layer_cross_v), std::move(offset));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
|
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
|
||||||
|
const {
|
||||||
return impl_->GetInitialSelfKVCache();
|
return impl_->GetInitialSelfKVCache();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,8 +260,36 @@ const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
|
|||||||
return impl_->GetInitialTokens();
|
return impl_->GetInitialTokens();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the language token IDs held by the implementation object.
const std::vector<int32_t> &OfflineWhisperModel::GetAllLanguageIDs() const {
|
||||||
|
return impl_->GetAllLanguageIDs();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the language-code -> token-ID map held by the implementation
// object.
const std::unordered_map<std::string, int32_t>
|
||||||
|
&OfflineWhisperModel::GetLang2ID() const {
|
||||||
|
return impl_->GetLang2ID();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the token-ID -> language-code map held by the implementation
// object.
const std::unordered_map<int32_t, std::string>
|
||||||
|
&OfflineWhisperModel::GetID2Lang() const {
|
||||||
|
return impl_->GetID2Lang();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the no-timestamps token ID reported by the implementation object.
int32_t OfflineWhisperModel::NoTimeStampsToken() const {
|
||||||
|
return impl_->NoTimeStampsToken();
|
||||||
|
}
|
||||||
|
|
||||||
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
|
int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
|
||||||
|
|
||||||
|
// Returns the SOT token ID reported by the implementation object.
int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); }
|
||||||
|
|
||||||
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
|
int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
|
||||||
|
|
||||||
|
// Returns the vocabulary size reported by the implementation object.
int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); }
|
||||||
|
|
||||||
|
// Returns the translate-task token ID reported by the implementation object.
int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); }
|
||||||
|
|
||||||
|
// Returns whether the implementation object reports the model as
// multilingual.
bool OfflineWhisperModel::IsMultiLingual() const {
|
||||||
|
return impl_->IsMultiLingual();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -5,7 +5,9 @@
|
|||||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
|
|||||||
* - n_layer_cross_v: A 4-D tensor of shape
|
* - n_layer_cross_v: A 4-D tensor of shape
|
||||||
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||||
*/
|
*/
|
||||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
|
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;
|
||||||
|
|
||||||
/** Run the decoder model.
|
/** Run the decoder model.
|
||||||
*
|
*
|
||||||
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
|
|||||||
Ort::Value>
|
Ort::Value>
|
||||||
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||||
Ort::Value n_layer_cross_v, Ort::Value offset);
|
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
||||||
|
|
||||||
|
int32_t DetectLanguage() const;
|
||||||
|
|
||||||
/** Return the initial self kv cache in a pair
|
/** Return the initial self kv cache in a pair
|
||||||
* - n_layer_self_k_cache A 4-D tensor of shape
|
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||||
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
|
|||||||
* - n_layer_self_v_cache A 4-D tensor of shape
|
* - n_layer_self_v_cache A 4-D tensor of shape
|
||||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||||
*/
|
*/
|
||||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
|
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
|
||||||
const std::vector<int64_t> &GetInitialTokens() const;
|
const std::vector<int64_t> &GetInitialTokens() const;
|
||||||
|
const std::vector<int32_t> &GetAllLanguageIDs() const;
|
||||||
|
const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
|
||||||
|
const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
|
||||||
|
|
||||||
/** Return an allocator for allocating memory
|
/** Return an allocator for allocating memory
|
||||||
*/
|
*/
|
||||||
OrtAllocator *Allocator() const;
|
OrtAllocator *Allocator() const;
|
||||||
|
|
||||||
|
int32_t NoTimeStampsToken() const;
|
||||||
int32_t EOT() const;
|
int32_t EOT() const;
|
||||||
|
int32_t SOT() const;
|
||||||
int32_t TextCtx() const;
|
int32_t TextCtx() const;
|
||||||
|
int32_t VocabSize() const;
|
||||||
|
int32_t Translate() const;
|
||||||
|
bool IsMultiLingual() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
|
|||||||
@@ -14,10 +14,14 @@ namespace sherpa_onnx {
|
|||||||
void PybindOfflineWhisperModelConfig(py::module *m) {
|
void PybindOfflineWhisperModelConfig(py::module *m) {
|
||||||
using PyClass = OfflineWhisperModelConfig;
|
using PyClass = OfflineWhisperModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
|
||||||
.def(py::init<const std::string &, const std::string &>(),
|
.def(py::init<const std::string &, const std::string &,
|
||||||
py::arg("encoder"), py::arg("decoder"))
|
const std::string &, const std::string &>(),
|
||||||
|
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
|
||||||
|
py::arg("task"))
|
||||||
.def_readwrite("encoder", &PyClass::encoder)
|
.def_readwrite("encoder", &PyClass::encoder)
|
||||||
.def_readwrite("decoder", &PyClass::decoder)
|
.def_readwrite("decoder", &PyClass::decoder)
|
||||||
|
.def_readwrite("language", &PyClass::language)
|
||||||
|
.def_readwrite("task", &PyClass::task)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -244,6 +244,8 @@ class OfflineRecognizer(object):
|
|||||||
encoder: str,
|
encoder: str,
|
||||||
decoder: str,
|
decoder: str,
|
||||||
tokens: str,
|
tokens: str,
|
||||||
|
language: str = "en",
|
||||||
|
task: str = "transcribe",
|
||||||
num_threads: int = 1,
|
num_threads: int = 1,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
@@ -268,6 +270,14 @@ class OfflineRecognizer(object):
|
|||||||
|
|
||||||
symbol integer_id
|
symbol integer_id
|
||||||
|
|
||||||
|
language:
|
||||||
|
The spoken language in the audio file. Example values: en, de, zh,
|
||||||
|
ja, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||||
|
for all possible values. Note that for non-multilingual models, the
|
||||||
|
only valid value is 'en'.
|
||||||
|
task:
|
||||||
|
Valid values are: transcribe, translate. Note that for
|
||||||
|
non-multilingual models, the only valid value is 'transcribe'.
|
||||||
num_threads:
|
num_threads:
|
||||||
Number of threads for neural network computation.
|
Number of threads for neural network computation.
|
||||||
decoding_method:
|
decoding_method:
|
||||||
@@ -279,7 +289,12 @@ class OfflineRecognizer(object):
|
|||||||
"""
|
"""
|
||||||
self = cls.__new__(cls)
|
self = cls.__new__(cls)
|
||||||
model_config = OfflineModelConfig(
|
model_config = OfflineModelConfig(
|
||||||
whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder),
|
whisper=OfflineWhisperModelConfig(
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
language=language,
|
||||||
|
task=task,
|
||||||
|
),
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=num_threads,
|
num_threads=num_threads,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
|
|||||||
Reference in New Issue
Block a user