[FIX] Update EOS from config (#2475)
This commit is contained in:
@@ -15,7 +15,8 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Optional, Union
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@@ -271,6 +272,14 @@ class ModelConfig:
|
||||
self.quantization,
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
||||
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
||||
if eos_ids:
|
||||
# it can be either int or list of int
|
||||
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
||||
return eos_ids
|
||||
|
||||
|
||||
def get_hf_text_config(config: PretrainedConfig):
|
||||
"""Get the "sub" config relevant to llm for multi modal models.
|
||||
|
||||
@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -209,6 +209,7 @@ class Req:
|
||||
lora_path: Optional[str] = None,
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
eos_token_ids: Optional[Set[int]] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -236,6 +237,7 @@ class Req:
|
||||
self.finished_reason = None
|
||||
self.to_abort = False
|
||||
self.stream = stream
|
||||
self.eos_token_ids = eos_token_ids
|
||||
|
||||
# For incremental decoding
|
||||
# ----- | --------- read_ids -------|
|
||||
@@ -395,18 +397,23 @@ class Req:
|
||||
|
||||
last_token_id = self.output_ids[-1]
|
||||
|
||||
matched_eos = False
|
||||
if not self.sampling_params.ignore_eos:
|
||||
matched_eos = False
|
||||
|
||||
# Check stop token ids
|
||||
if self.sampling_params.stop_token_ids:
|
||||
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
||||
if self.tokenizer is not None:
|
||||
matched_eos |= last_token_id == self.tokenizer.eos_token_id
|
||||
if self.tokenizer.additional_stop_token_ids:
|
||||
matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
|
||||
if matched_eos and not self.sampling_params.ignore_eos:
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
||||
return
|
||||
# Check stop token ids
|
||||
if self.sampling_params.stop_token_ids:
|
||||
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
||||
if self.eos_token_ids:
|
||||
matched_eos |= last_token_id in self.eos_token_ids
|
||||
if self.tokenizer is not None:
|
||||
matched_eos |= last_token_id == self.tokenizer.eos_token_id
|
||||
if self.tokenizer.additional_stop_token_ids:
|
||||
matched_eos |= (
|
||||
last_token_id in self.tokenizer.additional_stop_token_ids
|
||||
)
|
||||
if matched_eos:
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
||||
return
|
||||
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
|
||||
@@ -517,6 +517,7 @@ class Scheduler:
|
||||
stream=recv_req.stream,
|
||||
lora_path=recv_req.lora_path,
|
||||
input_embeds=recv_req.input_embeds,
|
||||
eos_token_ids=self.model_config.get_hf_eos_token_id(),
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user