[FIX] Update EOS from config (#2475)

This commit is contained in:
Yang Zheng
2024-12-28 02:59:56 +08:00
committed by GitHub
parent d9e6ee382b
commit 7a7ac6bea1
3 changed files with 30 additions and 13 deletions

View File

@@ -15,7 +15,8 @@
import json import json
import logging import logging
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import List, Optional, Union from functools import lru_cache
from typing import List, Optional, Set, Union
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
@@ -271,6 +272,14 @@ class ModelConfig:
self.quantization, self.quantization,
) )
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
    """Return the model's EOS token id(s) from the HF config as a set.

    ``eos_token_id`` on a HF ``PretrainedConfig`` may be a single int or a
    list of ints; normalize it to a set for O(1) membership checks by the
    scheduler's stop-condition logic.

    Returns:
        A non-empty set of EOS token ids, or ``None`` when the config
        declares no EOS token (attribute missing, or an empty list).
    """
    # Memoize on the instance rather than with functools.lru_cache: an
    # lru_cache on a method keys the cache on ``self`` and keeps every
    # instance alive for the lifetime of the cache (ruff B019).
    if not hasattr(self, "_hf_eos_token_id_cache"):
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # eos_token_id can be either an int or a list of ints.
            # Compare against None, not truthiness: token id 0 is a valid
            # EOS id and must yield {0}, not be dropped.
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            # An empty list means "no EOS tokens": report None, not set().
            eos_ids = eos_ids or None
        self._hf_eos_token_id_cache = eos_ids
    return self._hf_eos_token_id_cache
def get_hf_text_config(config: PretrainedConfig): def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models. """Get the "sub" config relevant to llm for multi modal models.

View File

@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import dataclasses import dataclasses
import logging import logging
from typing import List, Optional, Tuple, Union from typing import List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -209,6 +209,7 @@ class Req:
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None, input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
eos_token_ids: Optional[Set[int]] = None,
): ):
# Input and output info # Input and output info
self.rid = rid self.rid = rid
@@ -236,6 +237,7 @@ class Req:
self.finished_reason = None self.finished_reason = None
self.to_abort = False self.to_abort = False
self.stream = stream self.stream = stream
self.eos_token_ids = eos_token_ids
# For incremental decoding # For incremental decoding
# ----- | --------- read_ids -------| # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@ class Req:
last_token_id = self.output_ids[-1] last_token_id = self.output_ids[-1]
matched_eos = False if not self.sampling_params.ignore_eos:
matched_eos = False
# Check stop token ids # Check stop token ids
if self.sampling_params.stop_token_ids: if self.sampling_params.stop_token_ids:
matched_eos = last_token_id in self.sampling_params.stop_token_ids matched_eos = last_token_id in self.sampling_params.stop_token_ids
if self.tokenizer is not None: if self.eos_token_ids:
matched_eos |= last_token_id == self.tokenizer.eos_token_id matched_eos |= last_token_id in self.eos_token_ids
if self.tokenizer.additional_stop_token_ids: if self.tokenizer is not None:
matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids matched_eos |= last_token_id == self.tokenizer.eos_token_id
if matched_eos and not self.sampling_params.ignore_eos: if self.tokenizer.additional_stop_token_ids:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) matched_eos |= (
return last_token_id in self.tokenizer.additional_stop_token_ids
)
if matched_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return
# Check stop strings # Check stop strings
if len(self.sampling_params.stop_strs) > 0: if len(self.sampling_params.stop_strs) > 0:

View File

@@ -517,6 +517,7 @@ class Scheduler:
stream=recv_req.stream, stream=recv_req.stream,
lora_path=recv_req.lora_path, lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds, input_embeds=recv_req.input_embeds,
eos_token_ids=self.model_config.get_hf_eos_token_id(),
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer