diff --git a/tokenization_internlm2_fast.py b/tokenization_internlm2_fast.py
index 1506e11..4d9d5f1 100644
--- a/tokenization_internlm2_fast.py
+++ b/tokenization_internlm2_fast.py
@@ -56,14 +56,14 @@ class InternLM2Converter(SpmConverter):
         return unk_id
 
     def decoder(self, replacement, add_prefix_space):
-        return decoders.Sequence(
-            [
-                decoders.Replace("▁", " "),
-                decoders.ByteFallback(),
-                decoders.Fuse(),
-                decoders.Strip(content=" ", left=1),
-            ]
-        )
+        decoders_sequence = [
+            decoders.Replace("▁", " "),
+            decoders.ByteFallback(),
+            decoders.Fuse(),
+        ]
+        if self.proto.normalizer_spec.add_dummy_prefix:
+            decoders_sequence.append(decoders.Strip(content=" ", left=1))
+        return decoders.Sequence(decoders_sequence)
 
     def tokenizer(self, proto):
         model_type = proto.trainer_spec.model_type