Support multilingual whisper models (#274)

This commit is contained in:
Fangjun Kuang
2023-08-16 00:28:52 +08:00
committed by GitHub
parent 496c5dd7f5
commit f709c95c5f
24 changed files with 692 additions and 73 deletions

View File

@@ -5,7 +5,9 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
* - n_layer_cross_v: A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state)
*/
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;
/** Run the decoder model.
*
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
Ort::Value>
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset);
Ort::Value n_layer_cross_v, Ort::Value offset) const;
int32_t DetectLanguage() const;
/** Return the initial self kv cache in a pair
* - n_layer_self_k_cache A 4-D tensor of shape
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
* - n_layer_self_v_cache A 4-D tensor of shape
* (n_text_layer, N, n_audio_ctx, n_text_state).
*/
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
const std::vector<int64_t> &GetInitialTokens() const;
const std::vector<int32_t> &GetAllLanguageIDs() const;
const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
int32_t NoTimeStampsToken() const;
int32_t EOT() const;
int32_t SOT() const;
int32_t TextCtx() const;
int32_t VocabSize() const;
int32_t Translate() const;
bool IsMultiLingual() const;
private:
class Impl;