Add "spaces_between_special_tokens" argument to SamplingParams (#392)
This commit is contained in:
@@ -107,6 +107,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
"text": s.text_,
|
"text": s.text_,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||||
|
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||||
**sampling_params.to_srt_kwargs(),
|
**sampling_params.to_srt_kwargs(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -115,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
"text": s.text_,
|
"text": s.text_,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||||
|
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||||
"dtype": "int",
|
"dtype": "int",
|
||||||
**sampling_params.to_srt_kwargs(),
|
**sampling_params.to_srt_kwargs(),
|
||||||
},
|
},
|
||||||
@@ -145,6 +147,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
"text": s.text_,
|
"text": s.text_,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||||
|
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||||
**sampling_params.to_srt_kwargs(),
|
**sampling_params.to_srt_kwargs(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -153,6 +156,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
"text": s.text_,
|
"text": s.text_,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||||
|
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||||
"dtype": "int",
|
"dtype": "int",
|
||||||
**sampling_params.to_srt_kwargs(),
|
**sampling_params.to_srt_kwargs(),
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class GlobalConfig:
|
|||||||
|
|
||||||
# Output configs
|
# Output configs
|
||||||
self.skip_special_tokens_in_output = True
|
self.skip_special_tokens_in_output = True
|
||||||
|
self.spaces_between_special_tokens_in_out = True
|
||||||
|
|
||||||
# Optimization configs
|
# Optimization configs
|
||||||
self.eager_fill_image = False
|
self.eager_fill_image = False
|
||||||
|
|||||||
@@ -38,10 +38,11 @@ class DetokenizerManager:
|
|||||||
if isinstance(recv_obj, BatchTokenIDOut):
|
if isinstance(recv_obj, BatchTokenIDOut):
|
||||||
output_tokens = recv_obj.output_tokens
|
output_tokens = recv_obj.output_tokens
|
||||||
|
|
||||||
# TODO(lmzheng): handle skip_special_tokens per request
|
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
||||||
output_strs = self.tokenizer.batch_decode(
|
output_strs = self.tokenizer.batch_decode(
|
||||||
output_tokens,
|
output_tokens,
|
||||||
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
||||||
|
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trim stop str
|
# Trim stop str
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ class BatchTokenIDOut:
|
|||||||
output_and_jump_forward_strs: List[str]
|
output_and_jump_forward_strs: List[str]
|
||||||
hit_stop_str: List[Optional[str]]
|
hit_stop_str: List[Optional[str]]
|
||||||
skip_special_tokens: List[bool]
|
skip_special_tokens: List[bool]
|
||||||
|
spaces_between_special_tokens: List[bool]
|
||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
finished: List[bool]
|
finished: List[bool]
|
||||||
|
|
||||||
|
|||||||
@@ -549,6 +549,7 @@ class ModelRpcServer:
|
|||||||
output_and_jump_forward_strs = []
|
output_and_jump_forward_strs = []
|
||||||
output_hit_stop_str = []
|
output_hit_stop_str = []
|
||||||
output_skip_special_tokens = []
|
output_skip_special_tokens = []
|
||||||
|
output_spaces_between_special_tokens = []
|
||||||
output_meta_info = []
|
output_meta_info = []
|
||||||
output_finished = []
|
output_finished = []
|
||||||
finished_indices = []
|
finished_indices = []
|
||||||
@@ -575,6 +576,9 @@ class ModelRpcServer:
|
|||||||
output_skip_special_tokens.append(
|
output_skip_special_tokens.append(
|
||||||
req.sampling_params.skip_special_tokens
|
req.sampling_params.skip_special_tokens
|
||||||
)
|
)
|
||||||
|
output_spaces_between_special_tokens.append(
|
||||||
|
req.sampling_params.spaces_between_special_tokens
|
||||||
|
)
|
||||||
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"prompt_tokens": req.prompt_tokens,
|
"prompt_tokens": req.prompt_tokens,
|
||||||
@@ -609,6 +613,7 @@ class ModelRpcServer:
|
|||||||
output_and_jump_forward_strs,
|
output_and_jump_forward_strs,
|
||||||
output_hit_stop_str,
|
output_hit_stop_str,
|
||||||
output_skip_special_tokens,
|
output_skip_special_tokens,
|
||||||
|
output_spaces_between_special_tokens,
|
||||||
output_meta_info,
|
output_meta_info,
|
||||||
output_finished,
|
output_finished,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class SamplingParams:
|
|||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
skip_special_tokens: bool = True,
|
skip_special_tokens: bool = True,
|
||||||
|
spaces_between_special_tokens: bool = True,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -29,6 +30,7 @@ class SamplingParams:
|
|||||||
self.max_new_tokens = max_new_tokens
|
self.max_new_tokens = max_new_tokens
|
||||||
self.ignore_eos = ignore_eos
|
self.ignore_eos = ignore_eos
|
||||||
self.skip_special_tokens = skip_special_tokens
|
self.skip_special_tokens = skip_special_tokens
|
||||||
|
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.regex = regex
|
self.regex = regex
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user