Add timestamps for offline paraformer (#310)
This commit is contained in:
27
.github/scripts/test-offline-transducer.sh
vendored
27
.github/scripts/test-offline-transducer.sh
vendored
@@ -123,3 +123,30 @@ time $EXE \
|
|||||||
$repo/test_wavs/8k.wav
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
rm -rf $repo
|
rm -rf $repo
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run Paraformer (Chinese) with timestamps"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--paraformer=$repo/model.int8.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/2.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|||||||
@@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
|
|||||||
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
|
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
|
||||||
const_cast<char *>(r->text)[text.size()] = 0;
|
const_cast<char *>(r->text)[text.size()] = 0;
|
||||||
|
|
||||||
|
if (!result.timestamps.empty()) {
|
||||||
|
r->timestamps = new float[result.timestamps.size()];
|
||||||
|
std::copy(result.timestamps.begin(), result.timestamps.end(),
|
||||||
|
r->timestamps);
|
||||||
|
r->count = result.timestamps.size();
|
||||||
|
} else {
|
||||||
|
r->timestamps = nullptr;
|
||||||
|
r->count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DestroyOfflineRecognizerResult(
|
void DestroyOfflineRecognizerResult(
|
||||||
const SherpaOnnxOfflineRecognizerResult *r) {
|
const SherpaOnnxOfflineRecognizerResult *r) {
|
||||||
delete[] r->text;
|
delete[] r->text;
|
||||||
|
delete[] r->timestamps;
|
||||||
delete r;
|
delete r;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
|
|||||||
|
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
||||||
const char *text;
|
const char *text;
|
||||||
|
|
||||||
|
// Pointer to continuous memory which holds timestamps
|
||||||
|
//
|
||||||
|
// It is NULL if the model does not support timestamps
|
||||||
|
float *timestamps;
|
||||||
|
|
||||||
|
// number of entries in timestamps
|
||||||
|
int32_t count;
|
||||||
// TODO(fangjun): Add more fields
|
// TODO(fangjun): Add more fields
|
||||||
} SherpaOnnxOfflineRecognizerResult;
|
} SherpaOnnxOfflineRecognizerResult;
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,11 @@ namespace sherpa_onnx {
|
|||||||
struct OfflineParaformerDecoderResult {
|
struct OfflineParaformerDecoderResult {
|
||||||
/// The decoded token IDs
|
/// The decoded token IDs
|
||||||
std::vector<int64_t> tokens;
|
std::vector<int64_t> tokens;
|
||||||
|
|
||||||
|
// it contains the start time of each token in seconds
|
||||||
|
//
|
||||||
|
// len(timestamps) == len(tokens)
|
||||||
|
std::vector<float> timestamps;
|
||||||
};
|
};
|
||||||
|
|
||||||
class OfflineParaformerDecoder {
|
class OfflineParaformerDecoder {
|
||||||
@@ -28,7 +33,8 @@ class OfflineParaformerDecoder {
|
|||||||
* @return Return a vector of size `N` containing the decoded results.
|
* @return Return a vector of size `N` containing the decoded results.
|
||||||
*/
|
*/
|
||||||
virtual std::vector<OfflineParaformerDecoderResult> Decode(
|
virtual std::vector<OfflineParaformerDecoderResult> Decode(
|
||||||
Ort::Value log_probs, Ort::Value token_num) = 0;
|
Ort::Value log_probs, Ort::Value token_num,
|
||||||
|
Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -5,13 +5,18 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
|
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
std::vector<OfflineParaformerDecoderResult>
|
std::vector<OfflineParaformerDecoderResult>
|
||||||
OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
|
OfflineParaformerGreedySearchDecoder::Decode(
|
||||||
Ort::Value /*token_num*/) {
|
Ort::Value log_probs, Ort::Value /*token_num*/,
|
||||||
|
Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/
|
||||||
|
) {
|
||||||
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
int32_t batch_size = shape[0];
|
int32_t batch_size = shape[0];
|
||||||
int32_t num_tokens = shape[1];
|
int32_t num_tokens = shape[1];
|
||||||
@@ -25,12 +30,43 @@ OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
|
|||||||
for (int32_t k = 0; k != num_tokens; ++k) {
|
for (int32_t k = 0; k != num_tokens; ++k) {
|
||||||
auto max_idx = static_cast<int64_t>(
|
auto max_idx = static_cast<int64_t>(
|
||||||
std::distance(p, std::max_element(p, p + vocab_size)));
|
std::distance(p, std::max_element(p, p + vocab_size)));
|
||||||
if (max_idx == eos_id_) break;
|
if (max_idx == eos_id_) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
results[i].tokens.push_back(max_idx);
|
results[i].tokens.push_back(max_idx);
|
||||||
|
|
||||||
p += vocab_size;
|
p += vocab_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (us_cif_peak) {
|
||||||
|
int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||||
|
|
||||||
|
const auto *peak = us_cif_peak.GetTensorData<float>() + i * dim;
|
||||||
|
std::vector<float> timestamps;
|
||||||
|
timestamps.reserve(results[i].tokens.size());
|
||||||
|
|
||||||
|
// 10.0: frameshift is 10 milliseconds
|
||||||
|
// 6: LfrWindowSize
|
||||||
|
// 3: us_cif_peak is upsampled by a factor of 3
|
||||||
|
// 1000: milliseconds to seconds
|
||||||
|
float scale = 10.0 * 6 / 3 / 1000;
|
||||||
|
|
||||||
|
for (int32_t k = 0; k != dim; ++k) {
|
||||||
|
if (peak[k] > 1 - 1e-4) {
|
||||||
|
timestamps.push_back(k * scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timestamps.pop_back();
|
||||||
|
|
||||||
|
if (timestamps.size() == results[i].tokens.size()) {
|
||||||
|
results[i].timestamps = std::move(timestamps);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
|
||||||
|
static_cast<int32_t>(results[i].tokens.size()),
|
||||||
|
static_cast<int32_t>(timestamps.size()));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return results;
|
return results;
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
|
|||||||
: eos_id_(eos_id) {}
|
: eos_id_(eos_id) {}
|
||||||
|
|
||||||
std::vector<OfflineParaformerDecoderResult> Decode(
|
std::vector<OfflineParaformerDecoderResult> Decode(
|
||||||
Ort::Value log_probs, Ort::Value /*token_num*/) override;
|
Ort::Value log_probs, Ort::Value token_num,
|
||||||
|
Ort::Value us_cif_peak = Ort::Value(nullptr)) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t eos_id_;
|
int32_t eos_id_;
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
@@ -36,16 +37,13 @@ class OfflineParaformerModel::Impl {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||||
Ort::Value features_length) {
|
Ort::Value features_length) {
|
||||||
std::array<Ort::Value, 2> inputs = {std::move(features),
|
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||||
std::move(features_length)};
|
std::move(features_length)};
|
||||||
|
|
||||||
auto out =
|
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
output_names_ptr_.data(), output_names_ptr_.size());
|
|
||||||
|
|
||||||
return {std::move(out[0]), std::move(out[1])};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t VocabSize() const { return vocab_size_; }
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
@@ -119,7 +117,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
|
|||||||
|
|
||||||
OfflineParaformerModel::~OfflineParaformerModel() = default;
|
OfflineParaformerModel::~OfflineParaformerModel() = default;
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
|
std::vector<Ort::Value> OfflineParaformerModel::Forward(
|
||||||
Ort::Value features, Ort::Value features_length) {
|
Ort::Value features, Ort::Value features_length) {
|
||||||
return impl_->Forward(std::move(features), std::move(features_length));
|
return impl_->Forward(std::move(features), std::move(features_length));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
|
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
@@ -35,13 +34,17 @@ class OfflineParaformerModel {
|
|||||||
* valid frames in `features` before padding.
|
* valid frames in `features` before padding.
|
||||||
* Its dtype is int32_t.
|
* Its dtype is int32_t.
|
||||||
*
|
*
|
||||||
* @return Return a pair containing:
|
* @return Return a vector containing:
|
||||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size)
|
* - log_probs: A 3-D tensor of shape (N, T', vocab_size)
|
||||||
* - token_num: A 1-D tensor of shape (N, T') containing number
|
* - token_num: A 1-D tensor of shape (N, T') containing number
|
||||||
* of valid tokens in each utterance. Its dtype is int64_t.
|
* of valid tokens in each utterance. Its dtype is int64_t.
|
||||||
|
* If it is a model supporting timestamps, then there are additional two
|
||||||
|
* outputs:
|
||||||
|
* - us_alphas
|
||||||
|
* - us_cif_peak
|
||||||
*/
|
*/
|
||||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||||
Ort::Value features_length);
|
Ort::Value features_length);
|
||||||
|
|
||||||
/** Return the vocabulary size of the model
|
/** Return the vocabulary size of the model
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert(
|
|||||||
const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
|
const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
|
||||||
OfflineRecognitionResult r;
|
OfflineRecognitionResult r;
|
||||||
r.tokens.reserve(src.tokens.size());
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
r.timestamps = src.timestamps;
|
||||||
|
|
||||||
std::string text;
|
std::string text;
|
||||||
|
|
||||||
@@ -184,7 +185,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
|||||||
// i.e., -23.025850929940457f
|
// i.e., -23.025850929940457f
|
||||||
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> t{nullptr, nullptr};
|
std::vector<Ort::Value> t;
|
||||||
try {
|
try {
|
||||||
t = model_->Forward(std::move(x), std::move(x_length));
|
t = model_->Forward(std::move(x), std::move(x_length));
|
||||||
} catch (const Ort::Exception &ex) {
|
} catch (const Ort::Exception &ex) {
|
||||||
@@ -193,7 +194,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
std::vector<OfflineParaformerDecoderResult> results;
|
||||||
|
if (t.size() == 2) {
|
||||||
|
results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
|
||||||
|
} else {
|
||||||
|
results =
|
||||||
|
decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3]));
|
||||||
|
}
|
||||||
|
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
auto r = Convert(results[i], symbol_table_);
|
auto r = Convert(results[i], symbol_table_);
|
||||||
|
|||||||
@@ -349,6 +349,23 @@ class SherpaOnnxOfflineRecongitionResult {
|
|||||||
return String(cString: result.pointee.text)
|
return String(cString: result.pointee.text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var count: Int32 {
|
||||||
|
return result.pointee.count
|
||||||
|
}
|
||||||
|
|
||||||
|
var timestamps: [Float] {
|
||||||
|
if let p = result.pointee.timestamps {
|
||||||
|
var timestamps: [Float] = []
|
||||||
|
for index in 0..<count {
|
||||||
|
timestamps.append(p[Int(index)])
|
||||||
|
}
|
||||||
|
return timestamps
|
||||||
|
} else {
|
||||||
|
let timestamps: [Float] = []
|
||||||
|
return timestamps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
init(result: UnsafePointer<SherpaOnnxOfflineRecognizerResult>!) {
|
init(result: UnsafePointer<SherpaOnnxOfflineRecognizerResult>!) {
|
||||||
self.result = result
|
self.result = result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,21 +13,45 @@ extension AVAudioPCMBuffer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func run() {
|
func run() {
|
||||||
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
|
|
||||||
let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
|
|
||||||
let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
|
|
||||||
|
|
||||||
let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
|
var recognizer: SherpaOnnxOfflineRecognizer
|
||||||
encoder: encoder,
|
var modelConfig: SherpaOnnxOfflineModelConfig
|
||||||
decoder: decoder
|
var modelType = "whisper"
|
||||||
)
|
// modelType = "paraformer"
|
||||||
|
|
||||||
let modelConfig = sherpaOnnxOfflineModelConfig(
|
if modelType == "whisper" {
|
||||||
tokens: tokens,
|
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
|
||||||
whisper: whisperConfig,
|
let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
|
||||||
debug: 0,
|
let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
|
||||||
modelType: "whisper"
|
|
||||||
)
|
let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
|
||||||
|
encoder: encoder,
|
||||||
|
decoder: decoder
|
||||||
|
)
|
||||||
|
|
||||||
|
modelConfig = sherpaOnnxOfflineModelConfig(
|
||||||
|
tokens: tokens,
|
||||||
|
whisper: whisperConfig,
|
||||||
|
debug: 0,
|
||||||
|
modelType: "whisper"
|
||||||
|
)
|
||||||
|
} else if modelType == "paraformer" {
|
||||||
|
let model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
|
||||||
|
let tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
|
||||||
|
let paraformerConfig = sherpaOnnxOfflineParaformerModelConfig(
|
||||||
|
model: model
|
||||||
|
)
|
||||||
|
|
||||||
|
modelConfig = sherpaOnnxOfflineModelConfig(
|
||||||
|
tokens: tokens,
|
||||||
|
paraformer: paraformerConfig,
|
||||||
|
debug: 0,
|
||||||
|
modelType: "paraformer"
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
print("Please specify a supported modelType \(modelType)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
let featConfig = sherpaOnnxFeatureConfig(
|
let featConfig = sherpaOnnxFeatureConfig(
|
||||||
sampleRate: 16000,
|
sampleRate: 16000,
|
||||||
@@ -38,7 +62,7 @@ func run() {
|
|||||||
modelConfig: modelConfig
|
modelConfig: modelConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
let recognizer = SherpaOnnxOfflineRecognizer(config: &config)
|
recognizer = SherpaOnnxOfflineRecognizer(config: &config)
|
||||||
|
|
||||||
let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
|
let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
|
||||||
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
|
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
|
||||||
@@ -55,6 +79,10 @@ func run() {
|
|||||||
let array: [Float]! = audioFileBuffer?.array()
|
let array: [Float]! = audioFileBuffer?.array()
|
||||||
let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate))
|
let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate))
|
||||||
print("\nresult is:\n\(result.text)")
|
print("\nresult is:\n\(result.text)")
|
||||||
|
if result.timestamps.count != 0 {
|
||||||
|
print("\ntimestamps is:\n\(result.timestamps)")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@main
|
@main
|
||||||
|
|||||||
Reference in New Issue
Block a user