567 lines
20 KiB
Python
567 lines
20 KiB
Python
import dataclasses
|
|
import pprint
|
|
from functools import partial
|
|
import re
|
|
import os
|
|
from threading import Lock
|
|
import urllib
|
|
import time
|
|
from typing import List, Optional, Union
|
|
|
|
from pydantic import BaseModel
|
|
import absl.logging
|
|
from tqdm import tqdm, trange
|
|
import numpy as np
|
|
import mlxu
|
|
from ml_collections import ConfigDict
|
|
import uvicorn
|
|
from fastapi import FastAPI
|
|
import gradio as gr
|
|
import requests
|
|
from requests.exceptions import Timeout, ConnectionError
|
|
|
|
|
|
class InferenceRequest(BaseModel):
    # Request body shared by the /loglikelihood, /loglikelihood-rolling,
    # /generate and /greedy-until endpoints. Every field is optional at the
    # schema level and defaults to None; each endpoint reads only the subset
    # of fields it needs.

    # Conditioning strings, one per example.
    prefix_text: Union[List[str], None] = None
    # Strings to be scored, one per example.
    text: Union[List[str], None] = None
    # Stop sequences for greedy decoding: either one string per example or a
    # list of strings per example.
    until: Union[List[str], List[List[str]], None] = None
    # Sampling temperature; the server substitutes its configured default
    # when this is None.
    temperature: Union[float, None] = None
|
|
|
|
|
|
class ChatRequest(BaseModel):
    # Request body for the /chat endpoint.

    # The new user message (required).
    prompt: str
    # Conversation text accumulated so far; empty for the first turn.
    context: str = ''
    # Sampling temperature; the server substitutes its configured default
    # when this is None.
    temperature: Union[float, None] = None
|
|
|
|
|
|
class LMServer(object):
    """HTTP server for serving language models.

    The class wires a FastAPI app (plus a Gradio chat UI mounted at ``/``) to
    four model hooks: ``loglikelihood``, ``loglikelihood_rolling``,
    ``generate`` and ``greedy_until``. The hooks raise NotImplementedError
    here and are expected to be overridden by a subclass.

    All model calls made on behalf of a request are serialized through
    ``self.lock``.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default server configuration.

        Args:
            updates: optional mapping/ConfigDict whose entries override the
                defaults.

        Returns:
            A ConfigDict with the defaults below, overlaid with ``updates``.
        """
        config = ConfigDict()
        config.host = '0.0.0.0'
        config.port = 5007
        config.batch_size = 1
        config.logging = False
        # Comma-separated task names to run once before serving, or 'all'.
        config.pre_compile = 'loglikelihood'
        config.default_temperature = 1.0
        config.greedy_until_max_length = 5000
        # Strings wrapped around every prefix / text before it reaches the model.
        config.prepend_to_prefix = ''
        config.append_to_prefix = ''
        config.prepend_to_text = ''
        config.append_to_text = ''
        # Strings used to assemble the chat prompt in process_chat().
        config.chat_prepend_text = ''
        config.chat_user_prefix = ''
        config.chat_user_suffix = ''
        config.chat_lm_prefix = ''
        config.chat_lm_suffix = ''
        # Markdown shown at the top of the Gradio chat page.
        config.notes = ''

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config):
        """Create the FastAPI app, register all routes and mount the chat UI."""
        self.config = self.get_default_config(config)
        self.lock = Lock()
        self.app = FastAPI()
        self.app.post('/loglikelihood')(self.serve_loglikelihood)
        self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
        self.app.post('/generate')(self.serve_generate)
        self.app.post('/greedy-until')(self.serve_greedy_until)
        self.app.post('/chat')(self.serve_chat)
        self.app.get('/ready')(self.serve_ready)
        self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')

    # ------------------------------------------------------------------
    # Model hooks: to be overridden by subclasses.
    # ------------------------------------------------------------------

    @staticmethod
    def loglikelihood(prefix_text, text):
        """Score ``text`` given ``prefix_text``; returns (log_likelihood, is_greedy)."""
        raise NotImplementedError()

    @staticmethod
    def loglikelihood_rolling(text):
        """Score ``text`` on its own; returns (log_likelihood, is_greedy)."""
        raise NotImplementedError()

    @staticmethod
    def generate(text, temperature):
        """Sample a continuation for every string in ``text``."""
        raise NotImplementedError()

    @staticmethod
    def greedy_until(prefix_text, until, max_length):
        """Decode continuations of ``prefix_text`` given ``until`` and ``max_length``."""
        raise NotImplementedError()

    # ------------------------------------------------------------------
    # Internal helpers.
    # ------------------------------------------------------------------

    @staticmethod
    def to_list(x):
        """Convert a numpy array to a plain list; return anything else unchanged."""
        if isinstance(x, np.ndarray):
            return x.tolist()
        return x

    @staticmethod
    def _require(value, field_name):
        """Return ``value``, or raise HTTP 422 if the request field is missing."""
        if value is None:
            # Local import: the module only imports FastAPI itself at the top.
            from fastapi import HTTPException
            raise HTTPException(
                status_code=422,
                detail=f"Field '{field_name}' is required for this endpoint.",
            )
        return value

    def _log(self, title, payload):
        """Pretty-print ``payload`` under a banner when config.logging is on."""
        if self.config.logging:
            absl.logging.info(
                '\n========= ' + title + ' ========= \n'
                + pprint.pformat(payload) + '\n'
            )

    def _pad_batch(self, batch):
        """Return ``batch`` padded with 'a' entries up to config.batch_size.

        The model hooks are always called with exactly ``batch_size`` inputs;
        callers slice the padded results back to the real batch length.
        """
        missing = self.config.batch_size - len(batch)
        if missing > 0:
            return batch + ['a'] * missing
        return batch

    # ------------------------------------------------------------------
    # HTTP endpoints.
    # ------------------------------------------------------------------

    def serve_ready(self):
        """Readiness probe."""
        return 'Ready!\n'

    def serve_loglikelihood(self, data: InferenceRequest):
        """Handle POST /loglikelihood.

        Requires ``data.text``; a missing ``data.prefix_text`` is replaced by
        one empty prefix per text. Returns a dict with ``prefix_text``,
        ``text``, ``log_likelihood`` and ``is_greedy``.

        Raises:
            fastapi.HTTPException: 422 when ``text`` is missing.
        """
        self._require(data.text, 'text')

        with self.lock:
            self._log('Serving Log Likelihood Request', data)

            if data.prefix_text is None:
                data.prefix_text = ['' for _ in data.text]

            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]
            text = [
                self.config.prepend_to_text + t + self.config.append_to_text
                for t in data.text
            ]

            step = self.config.batch_size
            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), step, ncols=0):
                batch_text = text[i:i + step]
                batch_size = len(batch_text)

                batch_log_likelihood, batch_is_greedy = self.loglikelihood(
                    self._pad_batch(prefix_text[i:i + step]),
                    self._pad_batch(batch_text),
                )
                # Drop the results that belong to padding entries.
                log_likelihood.extend(self.to_list(batch_log_likelihood)[:batch_size])
                is_greedy.extend(self.to_list(batch_is_greedy)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'text': data.text,
                'log_likelihood': log_likelihood,
                'is_greedy': is_greedy,
            }
            self._log('Output', output)

        return output

    def serve_loglikelihood_rolling(self, data: InferenceRequest):
        """Handle POST /loglikelihood-rolling.

        Requires ``data.text``. Returns a dict with ``text``,
        ``log_likelihood`` and ``is_greedy``.

        Raises:
            fastapi.HTTPException: 422 when ``text`` is missing.
        """
        self._require(data.text, 'text')

        with self.lock:
            self._log('Serving Log Likelihood Rolling Request', data)

            text = [
                self.config.prepend_to_text + t + self.config.append_to_text
                for t in data.text
            ]

            step = self.config.batch_size
            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), step, ncols=0):
                batch_text = text[i:i + step]
                batch_size = len(batch_text)

                batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
                    self._pad_batch(batch_text)
                )
                # Drop the results that belong to padding entries.
                log_likelihood.extend(self.to_list(batch_log_likelihood)[:batch_size])
                is_greedy.extend(self.to_list(batch_is_greedy)[:batch_size])

            output = {
                'text': data.text,
                'log_likelihood': log_likelihood,
                'is_greedy': is_greedy,
            }
            self._log('Output', output)

        return output

    def serve_generate(self, data: InferenceRequest):
        """Handle POST /generate.

        Requires ``data.prefix_text``; a missing temperature falls back to
        ``config.default_temperature``. Returns a dict with ``prefix_text``,
        ``output_text`` and ``temperature``.

        Raises:
            fastapi.HTTPException: 422 when ``prefix_text`` is missing.
        """
        self._require(data.prefix_text, 'prefix_text')

        with self.lock:
            self._log('Serving Generate Request', data)

            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]

            if data.temperature is None:
                data.temperature = self.config.default_temperature

            step = self.config.batch_size
            output_text = []
            for i in trange(0, len(prefix_text), step, ncols=0):
                batch_prefix_text = prefix_text[i:i + step]
                batch_size = len(batch_prefix_text)

                batch_output_text = self.generate(
                    self._pad_batch(batch_prefix_text),
                    temperature=data.temperature,
                )
                # Drop the results that belong to padding entries.
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'output_text': output_text,
                'temperature': data.temperature,
            }
            self._log('Output', output)

        return output

    def serve_greedy_until(self, data: InferenceRequest):
        """Handle POST /greedy-until.

        Requires ``data.prefix_text`` and ``data.until``. Batches are passed to
        ``greedy_until`` unpadded, with ``config.greedy_until_max_length`` as
        the length limit. Returns a dict with ``prefix_text``, ``until``,
        ``max_length`` and ``output_text``.

        Raises:
            fastapi.HTTPException: 422 when ``prefix_text`` or ``until`` is
                missing.
        """
        self._require(data.prefix_text, 'prefix_text')
        self._require(data.until, 'until')

        with self.lock:
            self._log('Serving Greedy Until Request', data)

            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]
            until = data.until
            max_length = self.config.greedy_until_max_length

            step = self.config.batch_size
            output_text = []
            for i in range(0, len(prefix_text), step):
                batch_prefix_text = prefix_text[i:i + step]
                batch_until = until[i:i + step]
                batch_size = len(batch_prefix_text)

                batch_output_text = self.greedy_until(
                    batch_prefix_text, batch_until, max_length
                )
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'until': data.until,
                'max_length': max_length,
                'output_text': output_text,
            }
            self._log('Output', output)

        return output

    def process_chat(self, prompt, context, temperature):
        """Run one chat turn and return ``(response, new_context)``.

        The prompt sent to ``generate`` is ``chat_prepend_text`` followed by
        the previous context, the user message wrapped in the configured user
        prefix/suffix, and the LM prefix. The returned context additionally
        contains the response and the LM suffix (but not ``chat_prepend_text``).
        """
        context = (
            context + self.config.chat_user_prefix
            + prompt + self.config.chat_user_suffix
            + self.config.chat_lm_prefix
        )
        # Serialize with the other endpoints, which all hold self.lock while
        # calling into the model.
        with self.lock:
            response = self.generate(
                [self.config.chat_prepend_text + context],
                temperature=float(temperature),
            )[0]
        context = context + response + self.config.chat_lm_suffix
        return response, context

    def serve_chat(self, data: ChatRequest):
        """Handle POST /chat; returns ``response``, ``context`` and ``temperature``."""
        if data.temperature is None:
            data.temperature = self.config.default_temperature
        response, context = self.process_chat(
            data.prompt, data.context,
            temperature=data.temperature,
        )
        return {
            'response': response,
            'context': context,
            'temperature': data.temperature,
        }

    def create_chat_app(self):
        """Build the Gradio chat UI that is mounted at ``/``.

        ``context_state`` holds a two-element list
        ``[context_before_last_turn, context_after_last_turn]``. Keeping the
        earlier context around is what lets "Regenerate" re-run the last user
        message from the same starting point.
        """
        with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
            gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
            gr.Markdown(self.config.notes)
            chatbot = gr.Chatbot(label='Chat history')
            msg = gr.Textbox(
                placeholder='Type your message here...',
                show_label=False
            )
            with gr.Row():
                send = gr.Button('Send')
                regenerate = gr.Button('Regenerate', interactive=False)
                clear = gr.Button('Reset')

            temp_slider = gr.Slider(
                label='Temperature', minimum=0, maximum=2.0,
                value=self.config.default_temperature
            )

            context_state = gr.State(['', ''])

            def user_fn(user_message, history, context):
                # Lock the controls, append the user turn, and promote the
                # latest context to be the starting point of the new turn.
                return {
                    msg: gr.update(value='', interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                    chatbot: history + [[user_message, None]],
                    context_state: [context[1], context[1]],
                }

            def model_fn(history, context, temperature):
                # Answer the last user message starting from context[0], then
                # unlock the controls.
                history[-1][1], new_context = self.process_chat(
                    history[-1][0], context[0], temperature
                )
                return {
                    msg: gr.update(value='', interactive=True),
                    clear: gr.update(interactive=True),
                    send: gr.update(interactive=True),
                    chatbot: history,
                    context_state: [context[0], new_context],
                    regenerate: gr.update(interactive=True),
                }

            def regenerate_fn():
                # Only locks the controls; model_fn does the actual work.
                return {
                    msg: gr.update(value='', interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                }

            def clear_fn():
                return {
                    chatbot: None,
                    msg: '',
                    context_state: ['', ''],
                    regenerate: gr.update(interactive=False),
                }

            # Pressing Enter in the textbox and clicking "Send" behave the same.
            for trigger in (msg.submit, send.click):
                trigger(
                    user_fn,
                    inputs=[msg, chatbot, context_state],
                    outputs=[msg, clear, send, chatbot, context_state, regenerate],
                    queue=False
                ).then(
                    model_fn,
                    inputs=[chatbot, context_state, temp_slider],
                    outputs=[msg, clear, send, chatbot, context_state, regenerate],
                    queue=True
                )

            regenerate.click(
                regenerate_fn,
                inputs=None,
                outputs=[msg, clear, send, regenerate],
                queue=False
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True
            )
            clear.click(
                clear_fn,
                inputs=None,
                outputs=[chatbot, msg, context_state, regenerate],
                queue=False
            )

        gradio_chatbot.queue(concurrency_count=1)
        return gradio_chatbot

    def run(self):
        """Optionally warm up the model hooks, then serve with uvicorn.

        ``config.pre_compile`` is either ``''`` (skip), ``'all'``, or a
        comma-separated list of ``loglikelihood``, ``generate``,
        ``greedy_until`` and ``chat``. Whitespace around entries is ignored.

        Raises:
            ValueError: if ``pre_compile`` names an unknown task.
        """
        if self.config.pre_compile != '':
            if self.config.pre_compile == 'all':
                pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
            else:
                # Tolerate 'a, b' and trailing commas.
                pre_compile = [
                    task.strip()
                    for task in self.config.pre_compile.split(',')
                    if task.strip()
                ]

            pre_compile_data = ['a' for _ in range(self.config.batch_size)]
            for task in pre_compile:
                if task == 'loglikelihood':
                    # This task warms up both scoring hooks.
                    self.loglikelihood(pre_compile_data, pre_compile_data)
                    self.loglikelihood_rolling(pre_compile_data)
                elif task == 'generate':
                    self.generate(pre_compile_data, 1.0)
                elif task == 'greedy_until':
                    self.greedy_until(
                        pre_compile_data, pre_compile_data,
                        self.config.greedy_until_max_length
                    )
                elif task == 'chat':
                    self.process_chat('a', 'a', 1.0)
                else:
                    raise ValueError(f'Invalid precompile task: {task}!')

        uvicorn.run(self.app, host=self.config.host, port=self.config.port)
|
|
|
|
|
|
class LMClient(object):
    """A simple client for the LM server.

    Inputs are split into chunks of ``config.batch_size`` and sent to the
    server one HTTP request per chunk. With ``config.dummy`` set, no request
    is made and every method returns placeholder values of the same shape as
    a real result.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default client configuration.

        Args:
            updates: optional mapping/ConfigDict whose entries override the
                defaults.

        Returns:
            A ConfigDict with ``url``, ``batch_size``, ``wait_for_ready`` and
            ``dummy``.
        """
        config = ConfigDict()
        config.url = 'http://localhost:5007'
        config.batch_size = 1
        config.wait_for_ready = True
        config.dummy = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config=None):
        """Store the config and, if requested, block until the server answers."""
        self.config = self.get_default_config(config)
        if self.config.wait_for_ready:
            self.wait_for_ready()

    def _endpoint(self, name):
        """Return the URL of endpoint ``name`` relative to ``config.url``."""
        # Import the submodule explicitly: a bare `import urllib` does not
        # guarantee that `urllib.parse` is loaded.
        from urllib.parse import urljoin
        return urljoin(self.config.url, name)

    def _post(self, name, payload):
        """POST ``payload`` as JSON to endpoint ``name`` and return the decoded reply.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.post(self._endpoint(name), json=payload)
        # Fail here with the HTTP status instead of a KeyError later on.
        response.raise_for_status()
        return response.json()

    def wait_for_ready(self):
        """Poll the ``ready`` endpoint every 10 seconds until it responds.

        Returns immediately in dummy mode.
        """
        if self.config.dummy:
            return
        while True:
            try:
                # The timeout keeps a stalled connection from blocking forever
                # and makes the Timeout branch below reachable.
                requests.get(self._endpoint('ready'), timeout=10)
                return
            except (Timeout, ConnectionError):
                time.sleep(10)

    @staticmethod
    def batched(iterator, batch_size):
        """Yield consecutive lists of up to ``batch_size`` items from ``iterator``."""
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def loglikelihood(self, prefix, text):
        """Score each ``text`` given the matching ``prefix``.

        Returns:
            ``(log_likelihood, is_greedy)``, two lists with one entry per
            text. Dummy mode returns ``-1.0`` / ``False`` for every entry.
        """
        prefix, text = list(prefix), list(text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]

        log_likelihood = []
        is_greedy = []

        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(text, self.config.batch_size)
        ))
        for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'loglikelihood',
                {'prefix_text': batch_prefix, 'text': batch_text},
            )
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])

        return log_likelihood, is_greedy

    def loglikelihood_rolling(self, text):
        """Score each ``text`` on its own.

        Returns:
            ``(log_likelihood, is_greedy)``, two lists with one entry per
            text. Dummy mode returns ``-1.0`` / ``False`` for every entry.
        """
        text = list(text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]

        log_likelihood = []
        is_greedy = []
        batched_iterator = list(self.batched(text, self.config.batch_size))
        for batch_text in tqdm(batched_iterator, ncols=0):
            response = self._post('loglikelihood-rolling', {'text': batch_text})
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])
        return log_likelihood, is_greedy

    def greedy_until(self, prefix, until):
        """Greedily decode a continuation for each ``prefix``.

        Args:
            prefix: iterable of prompt strings.
            until: iterable with one entry per prefix; each entry is either a
                string or a sequence of strings.

        Returns:
            List of generated strings. Dummy mode returns ``'dummy text '``
            followed by the entry itself (string) or its first element.
        """
        prefix, until = list(prefix), list(until)
        if self.config.dummy:
            results = []
            for u in until:
                if isinstance(u, str):
                    results.append('dummy text ' + u)
                else:
                    results.append('dummy text ' + u[0])
            return results

        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(until, self.config.batch_size),
        ))
        output_text = []
        for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'greedy-until',
                {'prefix_text': batch_prefix, 'until': batch_until},
            )
            output_text.extend(response['output_text'])
        return output_text

    def generate(self, prefix, temperature=None):
        """Sample a continuation for each ``prefix``.

        ``temperature`` is forwarded as-is; ``None`` is sent as JSON null.
        Dummy mode returns one empty string per prefix.
        """
        prefix = list(prefix)
        if self.config.dummy:
            return ['' for _ in prefix]

        output_text = []
        batched_iterator = list(self.batched(prefix, self.config.batch_size))
        for batch_prefix in tqdm(batched_iterator, ncols=0):
            response = self._post(
                'generate',
                {
                    'prefix_text': batch_prefix,
                    'temperature': temperature,
                },
            )
            output_text.extend(response['output_text'])
        return output_text

    def chat(self, prompt, context, temperature=None):
        """Send one chat turn and return ``(response, new_context)``.

        Dummy mode returns an empty response and the unchanged ``context`` so
        callers can unpack the result the same way in both modes.
        """
        if self.config.dummy:
            return '', context
        response = self._post(
            'chat',
            {
                'prompt': prompt,
                'context': context,
                'temperature': temperature,
            },
        )
        return response['response'], response['context']
|