Replace Clone() with View() (#432)
Co-authored-by: hiedean <hiedean@tju.edu.cn>
This commit is contained in:
@@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
||||||
|
|
||||||
Ort::Value logit = model_->RunJoiner(
|
Ort::Value logit = model_->RunJoiner(
|
||||||
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
|
std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
|
|||||||
@@ -67,13 +67,13 @@ class OnlineRnnLM::Impl {
|
|||||||
return {std::move(out[0]), std::move(next_states)};
|
return {std::move(out[0]), std::move(next_states)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() const {
|
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(init_states_.size());
|
ans.reserve(init_states_.size());
|
||||||
for (const auto &s : init_states_) {
|
for (auto &s : init_states_) {
|
||||||
ans.emplace_back(Clone(allocator_, &s));
|
ans.emplace_back(View(&s));
|
||||||
}
|
}
|
||||||
return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)};
|
return {View(&init_scores_.value), std::move(ans)};
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
}
|
}
|
||||||
if (is_batch_decoder_out_cached) {
|
if (is_batch_decoder_out_cached) {
|
||||||
auto &r = result->front();
|
auto &r = result->front();
|
||||||
std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
std::vector<int64_t> decoder_out_shape =
|
||||||
|
r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
decoder_out_shape[0] = batch_size;
|
decoder_out_shape[0] = batch_size;
|
||||||
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size());
|
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
|
||||||
|
decoder_out_shape.data(), decoder_out_shape.size());
|
||||||
UseCachedDecoderOut(*result, &decoder_out);
|
UseCachedDecoderOut(*result, &decoder_out);
|
||||||
} else {
|
} else {
|
||||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||||
@@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
Ort::Value cur_encoder_out =
|
Ort::Value cur_encoder_out =
|
||||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||||
Ort::Value logit = model_->RunJoiner(
|
Ort::Value logit = model_->RunJoiner(
|
||||||
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
|
std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
const float *p_logit = logit.GetTensorData<float>();
|
const float *p_logit = logit.GetTensorData<float>();
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
cur_encoder_out =
|
cur_encoder_out =
|
||||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||||
Ort::Value logit = model_->RunJoiner(
|
Ort::Value logit = model_->RunJoiner(
|
||||||
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
|
std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
|
|||||||
@@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl {
|
|||||||
// - attn_cache
|
// - attn_cache
|
||||||
// - conv_cache
|
// - conv_cache
|
||||||
// - offset
|
// - offset
|
||||||
std::vector<Ort::Value> GetInitStates() const {
|
std::vector<Ort::Value> GetInitStates() {
|
||||||
std::vector<Ort::Value> ans;
|
std::vector<Ort::Value> ans;
|
||||||
ans.reserve(3);
|
ans.reserve(3);
|
||||||
ans.push_back(Clone(Allocator(), &attn_cache_));
|
ans.push_back(View(&attn_cache_));
|
||||||
ans.push_back(Clone(Allocator(), &conv_cache_));
|
ans.push_back(View(&conv_cache_));
|
||||||
|
|
||||||
int64_t offset_shape = 1;
|
int64_t offset_shape = 1;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user