code refactoring and add CI (#11)
This commit is contained in:
@@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch(
|
||||
auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
|
||||
std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
|
||||
memory_info);
|
||||
|
||||
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
|
||||
int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
|
||||
int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
@@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch(
|
||||
|
||||
auto logits = model->joiner_forward(cur_encoder_out,
|
||||
projected_decoder_out_vector,
|
||||
std::vector<int64_t> {1, 1, 1, projected_encoder_out_dim2},
|
||||
std::vector<int64_t> {1, 1, 1, projected_decoder_out_dim},
|
||||
std::vector<int64_t> {1, projected_encoder_out_dim2},
|
||||
std::vector<int64_t> {1, projected_decoder_out_dim},
|
||||
memory_info);
|
||||
|
||||
Ort::Value &logits_tensor = logits[0];
|
||||
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3];
|
||||
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
|
||||
|
||||
int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));
|
||||
|
||||
Reference in New Issue
Block a user