code refactoring and add CI (#11)

This commit is contained in:
Fangjun Kuang
2022-10-12 11:27:05 +08:00
committed by GitHub
parent d9b84d5526
commit 77ccd625b8
9 changed files with 267 additions and 121 deletions

View File

@@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch(
auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
memory_info);
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
@@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch(
auto logits = model->joiner_forward(cur_encoder_out,
projected_decoder_out_vector,
std::vector<int64_t> {1, 1, 1, projected_encoder_out_dim2},
std::vector<int64_t> {1, 1, 1, projected_decoder_out_dim},
std::vector<int64_t> {1, projected_encoder_out_dim2},
std::vector<int64_t> {1, projected_decoder_out_dim},
memory_info);
Ort::Value &logits_tensor = logits[0];
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3];
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));