[FIX] Update EOS from config (#2475)
This commit is contained in:
@@ -15,7 +15,8 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Optional, Union
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@@ -271,6 +272,14 @@ class ModelConfig:
|
||||
self.quantization,
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
||||
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
||||
if eos_ids:
|
||||
# it can be either int or list of int
|
||||
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
||||
return eos_ids
|
||||
|
||||
|
||||
def get_hf_text_config(config: PretrainedConfig):
|
||||
"""Get the "sub" config relevant to llm for multi modal models.
|
||||
|
||||
Reference in New Issue
Block a user