Fix a bug for multilingual ASR (#281)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.7.8")
|
set(SHERPA_ONNX_VERSION "1.7.9")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -136,8 +136,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
|
auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
int32_t vocab_size = logits_shape[2];
|
int32_t vocab_size = logits_shape[2];
|
||||||
|
|
||||||
int32_t max_token_id = static_cast<int32_t>(std::distance(
|
const float *p_start = p_logits + (logits_shape[1] - 1) * vocab_size;
|
||||||
p_logits, std::max_element(p_logits, p_logits + vocab_size)));
|
|
||||||
|
int32_t max_token_id = static_cast<int32_t>(
|
||||||
|
std::distance(p_start, std::max_element(p_start, p_start + vocab_size)));
|
||||||
|
|
||||||
int32_t n_text_ctx = model_->TextCtx();
|
int32_t n_text_ctx = model_->TextCtx();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user