Add timestamps for offline paraformer (#310)
This commit is contained in:
27
.github/scripts/test-offline-transducer.sh
vendored
27
.github/scripts/test-offline-transducer.sh
vendored
@@ -123,3 +123,30 @@ time $EXE \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run Paraformer (Chinese) with timestamps"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14
|
||||
log "Start testing ${repo_url}"
|
||||
repo=$(basename $repo_url)
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--paraformer=$repo/model.int8.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/2.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
@@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
|
||||
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
|
||||
const_cast<char *>(r->text)[text.size()] = 0;
|
||||
|
||||
if (!result.timestamps.empty()) {
|
||||
r->timestamps = new float[result.timestamps.size()];
|
||||
std::copy(result.timestamps.begin(), result.timestamps.end(),
|
||||
r->timestamps);
|
||||
r->count = result.timestamps.size();
|
||||
} else {
|
||||
r->timestamps = nullptr;
|
||||
r->count = 0;
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
void DestroyOfflineRecognizerResult(
|
||||
const SherpaOnnxOfflineRecognizerResult *r) {
|
||||
delete[] r->text;
|
||||
delete[] r->timestamps;
|
||||
delete r;
|
||||
}
|
||||
|
||||
@@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
||||
const char *text;
|
||||
|
||||
// Pointer to continuous memory which holds timestamps
|
||||
//
|
||||
// It is NULL if the model does not support timestamps
|
||||
float *timestamps;
|
||||
|
||||
// number of entries in timestamps
|
||||
int32_t count;
|
||||
// TODO(fangjun): Add more fields
|
||||
} SherpaOnnxOfflineRecognizerResult;
|
||||
|
||||
|
||||
@@ -14,6 +14,11 @@ namespace sherpa_onnx {
|
||||
struct OfflineParaformerDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
// it contains the start time of each token in seconds
|
||||
//
|
||||
// len(timestamps) == len(tokens)
|
||||
std::vector<float> timestamps;
|
||||
};
|
||||
|
||||
class OfflineParaformerDecoder {
|
||||
@@ -28,7 +33,8 @@ class OfflineParaformerDecoder {
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineParaformerDecoderResult> Decode(
|
||||
Ort::Value log_probs, Ort::Value token_num) = 0;
|
||||
Ort::Value log_probs, Ort::Value token_num,
|
||||
Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -5,13 +5,18 @@
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineParaformerDecoderResult>
|
||||
OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
|
||||
Ort::Value /*token_num*/) {
|
||||
OfflineParaformerGreedySearchDecoder::Decode(
|
||||
Ort::Value log_probs, Ort::Value /*token_num*/,
|
||||
Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/
|
||||
) {
|
||||
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = shape[0];
|
||||
int32_t num_tokens = shape[1];
|
||||
@@ -25,12 +30,43 @@ OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
|
||||
for (int32_t k = 0; k != num_tokens; ++k) {
|
||||
auto max_idx = static_cast<int64_t>(
|
||||
std::distance(p, std::max_element(p, p + vocab_size)));
|
||||
if (max_idx == eos_id_) break;
|
||||
if (max_idx == eos_id_) {
|
||||
break;
|
||||
}
|
||||
|
||||
results[i].tokens.push_back(max_idx);
|
||||
|
||||
p += vocab_size;
|
||||
}
|
||||
|
||||
if (us_cif_peak) {
|
||||
int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
|
||||
const auto *peak = us_cif_peak.GetTensorData<float>() + i * dim;
|
||||
std::vector<float> timestamps;
|
||||
timestamps.reserve(results[i].tokens.size());
|
||||
|
||||
// 10.0: frameshift is 10 milliseconds
|
||||
// 6: LfrWindowSize
|
||||
// 3: us_cif_peak is upsampled by a factor of 3
|
||||
// 1000: milliseconds to seconds
|
||||
float scale = 10.0 * 6 / 3 / 1000;
|
||||
|
||||
for (int32_t k = 0; k != dim; ++k) {
|
||||
if (peak[k] > 1 - 1e-4) {
|
||||
timestamps.push_back(k * scale);
|
||||
}
|
||||
}
|
||||
timestamps.pop_back();
|
||||
|
||||
if (timestamps.size() == results[i].tokens.size()) {
|
||||
results[i].timestamps = std::move(timestamps);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
|
||||
static_cast<int32_t>(results[i].tokens.size()),
|
||||
static_cast<int32_t>(timestamps.size()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
|
||||
@@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
|
||||
: eos_id_(eos_id) {}
|
||||
|
||||
std::vector<OfflineParaformerDecoderResult> Decode(
|
||||
Ort::Value log_probs, Ort::Value /*token_num*/) override;
|
||||
Ort::Value log_probs, Ort::Value token_num,
|
||||
Ort::Value us_cif_peak = Ort::Value(nullptr)) override;
|
||||
|
||||
private:
|
||||
int32_t eos_id_;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
@@ -36,16 +37,13 @@ class OfflineParaformerModel::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return {std::move(out[0]), std::move(out[1])};
|
||||
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
@@ -119,7 +117,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
|
||||
|
||||
OfflineParaformerModel::~OfflineParaformerModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
|
||||
std::vector<Ort::Value> OfflineParaformerModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -35,13 +34,17 @@ class OfflineParaformerModel {
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int32_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* @return Return a vector containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size)
|
||||
* - token_num: A 1-D tensor of shape (N, T') containing number
|
||||
* of valid tokens in each utterance. Its dtype is int64_t.
|
||||
* If it is a model supporting timestamps, then there are additional two
|
||||
* outputs:
|
||||
* - us_alphas
|
||||
* - us_cif_peak
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length);
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length);
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
|
||||
@@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert(
|
||||
const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps = src.timestamps;
|
||||
|
||||
std::string text;
|
||||
|
||||
@@ -184,7 +185,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
// i.e., -23.025850929940457f
|
||||
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> t{nullptr, nullptr};
|
||||
std::vector<Ort::Value> t;
|
||||
try {
|
||||
t = model_->Forward(std::move(x), std::move(x_length));
|
||||
} catch (const Ort::Exception &ex) {
|
||||
@@ -193,7 +194,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
return;
|
||||
}
|
||||
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
std::vector<OfflineParaformerDecoderResult> results;
|
||||
if (t.size() == 2) {
|
||||
results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
|
||||
} else {
|
||||
results =
|
||||
decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3]));
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
|
||||
@@ -349,6 +349,23 @@ class SherpaOnnxOfflineRecongitionResult {
|
||||
return String(cString: result.pointee.text)
|
||||
}
|
||||
|
||||
var count: Int32 {
|
||||
return result.pointee.count
|
||||
}
|
||||
|
||||
var timestamps: [Float] {
|
||||
if let p = result.pointee.timestamps {
|
||||
var timestamps: [Float] = []
|
||||
for index in 0..<count {
|
||||
timestamps.append(p[Int(index)])
|
||||
}
|
||||
return timestamps
|
||||
} else {
|
||||
let timestamps: [Float] = []
|
||||
return timestamps
|
||||
}
|
||||
}
|
||||
|
||||
init(result: UnsafePointer<SherpaOnnxOfflineRecognizerResult>!) {
|
||||
self.result = result
|
||||
}
|
||||
|
||||
@@ -13,21 +13,45 @@ extension AVAudioPCMBuffer {
|
||||
}
|
||||
|
||||
func run() {
|
||||
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
|
||||
let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
|
||||
let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
|
||||
|
||||
let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
|
||||
encoder: encoder,
|
||||
decoder: decoder
|
||||
)
|
||||
var recognizer: SherpaOnnxOfflineRecognizer
|
||||
var modelConfig: SherpaOnnxOfflineModelConfig
|
||||
var modelType = "whisper"
|
||||
// modelType = "paraformer"
|
||||
|
||||
let modelConfig = sherpaOnnxOfflineModelConfig(
|
||||
tokens: tokens,
|
||||
whisper: whisperConfig,
|
||||
debug: 0,
|
||||
modelType: "whisper"
|
||||
)
|
||||
if modelType == "whisper" {
|
||||
let encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"
|
||||
let decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"
|
||||
let tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"
|
||||
|
||||
let whisperConfig = sherpaOnnxOfflineWhisperModelConfig(
|
||||
encoder: encoder,
|
||||
decoder: decoder
|
||||
)
|
||||
|
||||
modelConfig = sherpaOnnxOfflineModelConfig(
|
||||
tokens: tokens,
|
||||
whisper: whisperConfig,
|
||||
debug: 0,
|
||||
modelType: "whisper"
|
||||
)
|
||||
} else if modelType == "paraformer" {
|
||||
let model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx"
|
||||
let tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt"
|
||||
let paraformerConfig = sherpaOnnxOfflineParaformerModelConfig(
|
||||
model: model
|
||||
)
|
||||
|
||||
modelConfig = sherpaOnnxOfflineModelConfig(
|
||||
tokens: tokens,
|
||||
paraformer: paraformerConfig,
|
||||
debug: 0,
|
||||
modelType: "paraformer"
|
||||
)
|
||||
} else {
|
||||
print("Please specify a supported modelType \(modelType)")
|
||||
return
|
||||
}
|
||||
|
||||
let featConfig = sherpaOnnxFeatureConfig(
|
||||
sampleRate: 16000,
|
||||
@@ -38,7 +62,7 @@ func run() {
|
||||
modelConfig: modelConfig
|
||||
)
|
||||
|
||||
let recognizer = SherpaOnnxOfflineRecognizer(config: &config)
|
||||
recognizer = SherpaOnnxOfflineRecognizer(config: &config)
|
||||
|
||||
let filePath = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
|
||||
let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
|
||||
@@ -55,6 +79,10 @@ func run() {
|
||||
let array: [Float]! = audioFileBuffer?.array()
|
||||
let result = recognizer.decode(samples: array, sampleRate: Int(audioFormat.sampleRate))
|
||||
print("\nresult is:\n\(result.text)")
|
||||
if result.timestamps.count != 0 {
|
||||
print("\ntimestamps is:\n\(result.timestamps)")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@main
|
||||
|
||||
Reference in New Issue
Block a user