[FIX] Update EOS from config (#2475)

This commit is contained in:
Yang Zheng
2024-12-28 02:59:56 +08:00
committed by GitHub
parent d9e6ee382b
commit 7a7ac6bea1
3 changed files with 30 additions and 13 deletions

View File

@@ -15,7 +15,8 @@
import json
import logging
from enum import IntEnum, auto
from typing import List, Optional, Union
from functools import lru_cache
from typing import List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
@@ -271,6 +272,14 @@ class ModelConfig:
self.quantization,
)
@lru_cache()
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
eos_ids = getattr(self.hf_config, "eos_token_id", None)
if eos_ids:
# it can be either int or list of int
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
return eos_ids
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.