add modified beam search (#69)
This commit is contained in:
@@ -461,24 +461,6 @@ OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
|
||||
return {std::move(encoder_out[0]), std::move(next_states)};
|
||||
}
|
||||
|
||||
Ort::Value OnlineZipformerTransducerModel::BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) {
|
||||
int32_t batch_size = static_cast<int32_t>(results.size());
|
||||
std::array<int64_t, 2> shape{batch_size, context_size_};
|
||||
Ort::Value decoder_input =
|
||||
Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
|
||||
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||
|
||||
for (const auto &r : results) {
|
||||
const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
|
||||
const int64_t *end = r.tokens.data() + r.tokens.size();
|
||||
std::copy(begin, end, p);
|
||||
p += context_size_;
|
||||
}
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
Ort::Value OnlineZipformerTransducerModel::RunDecoder(
|
||||
Ort::Value decoder_input) {
|
||||
auto decoder_out = decoder_sess_->Run(
|
||||
|
||||
Reference in New Issue
Block a user