Support LoRA in Completion API (#2243)
Co-authored-by: root <bjmsong@126.com>
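
For illustration, a minimal client-side sketch of the new field, assuming an SGLang server launched with a LoRA adapter registered under the hypothetical name "my-lora"; the base URL and model name are likewise illustrative:

    import requests

    # Call the OpenAI-compatible completions endpoint and select a LoRA
    # adapter for this request via the new `lora_path` field.
    resp = requests.post(
        "http://localhost:30000/v1/completions",
        json={
            "model": "default",
            "prompt": "The capital of France is",
            "max_tokens": 16,
            "lora_path": "my-lora",  # hypothetical adapter name
        },
    )
    print(resp.json()["choices"][0]["text"])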
@@ -486,6 +486,7 @@ def v1_generate_request(
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []
+    lora_paths = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -496,6 +497,7 @@ def v1_generate_request(
         )
 
         prompts.append(request.prompt)
+        lora_paths.append(request.lora_path)
         if request.echo and request.logprobs:
             current_logprob_start_len = 0
         else:
@@ -534,6 +536,7 @@ def v1_generate_request(
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
+        lora_paths = lora_paths[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -549,6 +552,7 @@ def v1_generate_request(
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
         rid=request_ids,
+        lora_path=lora_paths,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -166,6 +166,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     json_schema: Optional[str] = None
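
Per the new annotation, `lora_path` accepts either a single adapter name or a list with one entry per prompt in a batched request (a `None` entry appears to leave that prompt on the base model). A hedged sketch of the batched form, with illustrative adapter names:

    # Batched request: one adapter choice per prompt.
    payload = {
        "model": "default",
        "prompt": ["Translate to French: hello", "Translate to German: hello"],
        "max_tokens": 8,
        "lora_path": ["french-lora", None],  # hypothetical names; None = base model
    }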