Add address sanitizer and undefined behavior sanitizer (#951)
This commit is contained in:
@@ -17,7 +17,7 @@ namespace sherpa_onnx {
|
||||
static void UseCachedDecoderOut(
|
||||
const std::vector<int32_t> &hyps_row_splits,
|
||||
const std::vector<OnlineTransducerDecoderResult> &results,
|
||||
int32_t context_size, Ort::Value *decoder_out) {
|
||||
Ort::Value *decoder_out) {
|
||||
std::vector<int64_t> shape =
|
||||
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -80,7 +80,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (encoder_out_shape[0] != result->size()) {
|
||||
if (static_cast<int32_t>(encoder_out_shape[0]) !=
|
||||
static_cast<int32_t>(result->size())) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
@@ -117,8 +118,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
if (t == 0) {
|
||||
UseCachedDecoderOut(hyps_row_splits, *result, model_->ContextSize(),
|
||||
&decoder_out);
|
||||
UseCachedDecoderOut(hyps_row_splits, *result, &decoder_out);
|
||||
}
|
||||
|
||||
Ort::Value cur_encoder_out =
|
||||
@@ -136,10 +136,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
int32_t p_logit_items = vocab_size * num_hyps;
|
||||
std::vector<float> logit_with_temperature(p_logit_items);
|
||||
{
|
||||
std::copy(p_logit,
|
||||
p_logit + p_logit_items,
|
||||
std::copy(p_logit, p_logit + p_logit_items,
|
||||
logit_with_temperature.begin());
|
||||
for (float& elem : logit_with_temperature) {
|
||||
for (float &elem : logit_with_temperature) {
|
||||
elem /= temperature_scale_;
|
||||
}
|
||||
LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
|
||||
@@ -226,7 +225,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
cur.push_back(std::move(hyps));
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
} // for (int32_t t = 0; t != num_frames; ++t)
|
||||
} // for (int32_t t = 0; t != num_frames; ++t)
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
@@ -242,7 +241,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
|
||||
OnlineTransducerDecoderResult *result) {
|
||||
if (result->tokens.size() == model_->ContextSize()) {
|
||||
if (static_cast<int32_t>(result->tokens.size()) == model_->ContextSize()) {
|
||||
result->decoder_out = Ort::Value{nullptr};
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user