Support multilingual whisper models (#274)
This commit is contained in:
@@ -5,7 +5,9 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -30,7 +32,7 @@ class OfflineWhisperModel {
|
||||
* - n_layer_cross_v: A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state)
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) const;
|
||||
|
||||
/** Run the decoder model.
|
||||
*
|
||||
@@ -58,7 +60,9 @@ class OfflineWhisperModel {
|
||||
Ort::Value>
|
||||
ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
|
||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset);
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
||||
|
||||
/** Presumably detects the spoken language and reports it as an int32_t.
 *
 * NOTE(review): the declaration takes no arguments, so the data it inspects
 * is not visible here. The result is presumably a language token ID (one of
 * the values in GetAllLanguageIDs()) -- confirm against the implementation.
 */
int32_t DetectLanguage() const;
|
||||
|
||||
/** Return the initial self kv cache in a pair
|
||||
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||
@@ -66,14 +70,23 @@ class OfflineWhisperModel {
|
||||
* - n_layer_self_v_cache A 4-D tensor of shape
|
||||
* (n_text_layer, N, n_audio_ctx, n_text_state).
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() const;
|
||||
/** Return a read-only reference to a vector of int64_t token IDs.
 *
 * NOTE(review): presumably the token sequence the decoder is primed with
 * before decoding starts -- confirm. The result is a reference, so it
 * should not be used after this object is destroyed (TODO confirm owner).
 */
const std::vector<int64_t> &GetInitialTokens() const;
|
||||
/** Return a read-only reference to a vector of int32_t language IDs.
 *
 * NOTE(review): presumably the token IDs of every language the model
 * supports -- confirm against the implementation.
 */
const std::vector<int32_t> &GetAllLanguageIDs() const;
|
||||
/** Return a read-only reference to a string -> int32_t map.
 *
 * NOTE(review): presumably keyed by language name/code with the language
 * token ID as the value -- confirm the exact key format with the
 * implementation.
 */
const std::unordered_map<std::string, int32_t> &GetLang2ID() const;
|
||||
/** Return a read-only reference to an int32_t -> string map.
 *
 * NOTE(review): presumably the inverse of GetLang2ID(), i.e. language token
 * ID to language name/code -- confirm against the implementation.
 */
const std::unordered_map<int32_t, std::string> &GetID2Lang() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
/** NOTE(review): presumably the ID of the "no timestamps" special token --
 * confirm against the implementation.
 */
int32_t NoTimeStampsToken() const;
|
||||
/** NOTE(review): presumably the ID of the end-of-transcript (EOT) special
 * token -- confirm against the implementation.
 */
int32_t EOT() const;
|
||||
/** NOTE(review): presumably the ID of the start-of-transcript (SOT) special
 * token -- confirm against the implementation.
 */
int32_t SOT() const;
|
||||
/** NOTE(review): presumably the model's n_text_ctx value (size of the
 * decoder's text context) -- confirm against the implementation.
 */
int32_t TextCtx() const;
|
||||
/** NOTE(review): presumably the number of entries in the model's token
 * vocabulary -- confirm against the implementation.
 */
int32_t VocabSize() const;
|
||||
/** Returns an int32_t, not a flag.
 *
 * NOTE(review): presumably the ID of the "translate" task token rather than
 * a request to perform translation -- confirm against the implementation.
 */
int32_t Translate() const;
|
||||
/** NOTE(review): presumably true when the loaded model is a multilingual
 * whisper model and false for an English-only one -- confirm how the flag
 * is derived in the implementation.
 */
bool IsMultiLingual() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
|
||||
Reference in New Issue
Block a user