Re-implement LM rescore for online transducer (#1231)
Co-authored-by: Martins Kronis <martins.kuznecovs@tilde.lv>
This commit is contained in:
@@ -22,13 +22,17 @@ class OnlineRnnLM : public OnlineLM {
|
||||
|
||||
explicit OnlineRnnLM(const OnlineLMConfig &config);
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
|
||||
// init scores for classic rescore
|
||||
std::vector<Ort::Value> GetInitStates() override;
|
||||
|
||||
/** ScoreToken a batch of sentences.
|
||||
// init scores for shallow fusion
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() override;
|
||||
|
||||
/** ScoreToken a batch of sentences (shallow fusion).
|
||||
*
|
||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||
* @param states It contains the states for the LM model
|
||||
* @return Return a pair containingo
|
||||
* @return Return a pair containing
|
||||
* - log_prob of NN LM
|
||||
* - updated states
|
||||
*
|
||||
@@ -36,13 +40,23 @@ class OnlineRnnLM : public OnlineLM {
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
|
||||
Ort::Value x, std::vector<Ort::Value> states) override;
|
||||
|
||||
/** This function updates lm_lob_prob and nn_lm_scores of hyp
|
||||
/** This function updates hyp.lm_lob_prob of hyps (classic rescore).
|
||||
*
|
||||
* @param scale LM score
|
||||
* @param context_size Context size of the transducer decoder model
|
||||
* @param hyps It is changed in-place.
|
||||
*
|
||||
*/
|
||||
void ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps) override;
|
||||
|
||||
/** This function updates lm_lob_prob and nn_lm_scores of hyp (shallow fusion).
|
||||
*
|
||||
* @param scale LM score
|
||||
* @param hyps It is changed in-place.
|
||||
*
|
||||
*/
|
||||
void ComputeLMScore(float scale, Hypothesis *hyp) override;
|
||||
void ComputeLMScoreSF(float scale, Hypothesis *hyp) override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
|
||||
Reference in New Issue
Block a user