Format code & move functions (#155)
This commit is contained in:
@@ -193,6 +193,7 @@ def match_chat_ml(model_path: str):
|
|||||||
if "qwen" in model_path and "chat" in model_path:
|
if "qwen" in model_path and "chat" in model_path:
|
||||||
return get_chat_template("chatml")
|
return get_chat_template("chatml")
|
||||||
|
|
||||||
|
|
||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
def match_chat_yi(model_path: str):
|
def match_chat_yi(model_path: str):
|
||||||
model_path = model_path.lower()
|
model_path = model_path.lower()
|
||||||
|
|||||||
@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
|
|||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||||
]
|
]
|
||||||
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
|
logprobs_cumsum = torch.cumsum(
|
||||||
|
prefill_logprobs, dim=0, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
start = input_metadata.extend_start_loc.clone()
|
start = input_metadata.extend_start_loc.clone()
|
||||||
end = start + input_metadata.extend_seq_lens - 2
|
end = start + input_metadata.extend_seq_lens - 2
|
||||||
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||||
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
||||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
|
sum_logp = (
|
||||||
|
logprobs_cumsum[end]
|
||||||
|
- logprobs_cumsum[start]
|
||||||
|
+ prefill_logprobs[start]
|
||||||
|
)
|
||||||
normalized_logprobs = sum_logp / (
|
normalized_logprobs = sum_logp / (
|
||||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,14 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
|
|
||||||
|
|
||||||
class RadixAttention(nn.Module):
|
class RadixAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
|
||||||
self,
|
|
||||||
num_heads,
|
|
||||||
head_dim,
|
|
||||||
scaling,
|
|
||||||
num_kv_heads,
|
|
||||||
layer_id
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_q_head_num = num_heads
|
self.tp_q_head_num = num_heads
|
||||||
self.tp_k_head_num = num_kv_heads
|
self.tp_k_head_num = num_kv_heads
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ class BatchStrOut:
|
|||||||
class FlushCacheReq:
|
class FlushCacheReq:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DetokenizeReqInput:
|
class DetokenizeReqInput:
|
||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ import rpyc
|
|||||||
import torch
|
import torch
|
||||||
from rpyc.utils.classic import obtain
|
from rpyc.utils.classic import obtain
|
||||||
from rpyc.utils.server import ThreadedServer
|
from rpyc.utils.server import ThreadedServer
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
|
||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
|
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
@@ -391,8 +391,12 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
logprobs = None
|
logprobs = None
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
# Forward
|
# Forward
|
||||||
logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
|
logits, (
|
||||||
self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
|
prefill_logprobs,
|
||||||
|
normalized_logprobs,
|
||||||
|
last_logprobs,
|
||||||
|
) = self.model_runner.forward(
|
||||||
|
batch, ForwardMode.EXTEND, batch.return_logprob
|
||||||
)
|
)
|
||||||
if prefill_logprobs is not None:
|
if prefill_logprobs is not None:
|
||||||
logprobs = prefill_logprobs.cpu().tolist()
|
logprobs = prefill_logprobs.cpu().tolist()
|
||||||
@@ -407,7 +411,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
if last_logprobs is not None:
|
if last_logprobs is not None:
|
||||||
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
last_logprobs = (
|
||||||
|
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
||||||
|
)
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
pt = 0
|
pt = 0
|
||||||
@@ -482,7 +488,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
if last_logprobs is not None:
|
if last_logprobs is not None:
|
||||||
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
|
last_logprobs = last_logprobs[
|
||||||
|
torch.arange(len(reqs)), next_token_ids
|
||||||
|
].tolist()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
|
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
|
||||||
@@ -620,15 +628,16 @@ class ModelRpcClient:
|
|||||||
self.step = async_wrap("step")
|
self.step = async_wrap("step")
|
||||||
|
|
||||||
|
|
||||||
def start_model_process(port):
|
def _init_service(port):
|
||||||
def _init_service(port):
|
t = ThreadedServer(
|
||||||
t = ThreadedServer(
|
ModelRpcServer(),
|
||||||
ModelRpcServer(),
|
port=port,
|
||||||
port=port,
|
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
|
||||||
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
|
)
|
||||||
)
|
t.start()
|
||||||
t.start()
|
|
||||||
|
|
||||||
|
|
||||||
|
def start_model_process(port):
|
||||||
proc = multiprocessing.Process(target=_init_service, args=(port,))
|
proc = multiprocessing.Process(target=_init_service, args=(port,))
|
||||||
proc.start()
|
proc.start()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from vllm.model_executor.model_loader import _set_default_torch_dtype
|
|||||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||||
|
|
||||||
import sglang
|
import sglang
|
||||||
QUANTIONCONFIG_MAPPING = {'awq': AWQConfig,
|
|
||||||
'gptq': GPTQConfig}
|
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
|
||||||
|
|
||||||
logger = logging.getLogger("model_runner")
|
logger = logging.getLogger("model_runner")
|
||||||
|
|
||||||
@@ -283,9 +283,13 @@ class ModelRunner:
|
|||||||
self.model_config.hf_config, "quantization_config", None
|
self.model_config.hf_config, "quantization_config", None
|
||||||
)
|
)
|
||||||
if hf_quant_config is not None:
|
if hf_quant_config is not None:
|
||||||
quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
|
quant_config_class = QUANTIONCONFIG_MAPPING.get(
|
||||||
|
hf_quant_config["quant_method"]
|
||||||
|
)
|
||||||
if quant_config_class is None:
|
if quant_config_class is None:
|
||||||
raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
|
raise ValueError(
|
||||||
|
f"Unsupported quantization method: {hf_quant_config['quant_method']}"
|
||||||
|
)
|
||||||
quant_config = quant_config_class.from_config(hf_quant_config)
|
quant_config = quant_config_class.from_config(hf_quant_config)
|
||||||
logger.info(f"quant_config: {quant_config}")
|
logger.info(f"quant_config: {quant_config}")
|
||||||
linear_method = quant_config.get_linear_method()
|
linear_method = quant_config.get_linear_method()
|
||||||
|
|||||||
@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
|
|||||||
2 * [intermediate_size],
|
2 * [intermediate_size],
|
||||||
bias=False,
|
bias=False,
|
||||||
gather_output=False,
|
gather_output=False,
|
||||||
linear_method=linear_method
|
linear_method=linear_method,
|
||||||
)
|
)
|
||||||
self.c_proj = RowParallelLinear(
|
self.c_proj = RowParallelLinear(
|
||||||
intermediate_size,
|
intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
linear_method=linear_method
|
linear_method=linear_method,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
|
|||||||
layer_id: int = 0,
|
layer_id: int = 0,
|
||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
linear_method: Optional[LinearMethodBase] = None
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
|
|||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
self.c_attn = QKVParallelLinear(
|
self.c_attn = QKVParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
bias=True,
|
bias=True,
|
||||||
linear_method=linear_method
|
linear_method=linear_method,
|
||||||
)
|
)
|
||||||
self.c_proj = RowParallelLinear(
|
self.c_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
input_is_parallel=True,
|
input_is_parallel=True,
|
||||||
linear_method=linear_method
|
linear_method=linear_method,
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
linear_method=linear_method
|
linear_method=linear_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
|
self.mlp = QWenMLP(
|
||||||
|
config.hidden_size,
|
||||||
|
config.intermediate_size // 2,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList(
|
||||||
[QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
|
[
|
||||||
|
QWenBlock(config, i, linear_method=linear_method)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
|||||||
@@ -4,14 +4,17 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from sglang.srt.models.llava import (
|
||||||
|
LlavaLlamaForCausalLM,
|
||||||
|
clip_vision_embed_forward,
|
||||||
|
monkey_path_clip_vision_embed_forward,
|
||||||
|
)
|
||||||
from transformers import CLIPVisionModel, LlavaConfig
|
from transformers import CLIPVisionModel, LlavaConfig
|
||||||
from vllm.model_executor.weight_utils import (
|
from vllm.model_executor.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
hf_model_weights_iterator,
|
hf_model_weights_iterator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
|
|
||||||
|
|
||||||
|
|
||||||
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
super().__init__(self.config)
|
super().__init__(self.config)
|
||||||
|
|
||||||
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
|
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
|
||||||
self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "") # Everything after "./"
|
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
|
||||||
|
"./", ""
|
||||||
|
) # Everything after "./"
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(
|
||||||
self,
|
self,
|
||||||
@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
):
|
):
|
||||||
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
|
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
|
||||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||||
model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
|
model_name_or_path,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
subfolder=self.vision_tower_subfolder,
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
self.vision_tower.eval()
|
self.vision_tower.eval()
|
||||||
@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
|
|
||||||
monkey_path_clip_vision_embed_forward()
|
monkey_path_clip_vision_embed_forward()
|
||||||
|
|
||||||
|
|
||||||
class YiVLMultiModalProjector(nn.Module):
|
class YiVLMultiModalProjector(nn.Module):
|
||||||
def __init__(self, config: LlavaConfig):
|
def __init__(self, config: LlavaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
|
self.linear_1 = nn.Linear(
|
||||||
|
config.vision_config.hidden_size, config.text_config.hidden_size
|
||||||
|
)
|
||||||
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
|
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
|
||||||
self.act = nn.GELU()
|
self.act = nn.GELU()
|
||||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size
|
||||||
|
)
|
||||||
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
|
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
|
||||||
|
|
||||||
def forward(self, image_features):
|
def forward(self, image_features):
|
||||||
@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
|
|||||||
hidden_states = self.ln_2(hidden_states)
|
hidden_states = self.ln_2(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
EntryClass = YiVLForCausalLM
|
|
||||||
|
EntryClass = YiVLForCausalLM
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ chat_template_name = None
|
|||||||
# FIXME: Remove this once we drop support for pydantic 1.x
|
# FIXME: Remove this once we drop support for pydantic 1.x
|
||||||
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
||||||
|
|
||||||
|
|
||||||
def jsonify_pydantic_model(obj: BaseModel):
|
def jsonify_pydantic_model(obj: BaseModel):
|
||||||
if IS_PYDANTIC_1:
|
if IS_PYDANTIC_1:
|
||||||
return obj.json(ensure_ascii=False)
|
return obj.json(ensure_ascii=False)
|
||||||
@@ -165,7 +166,7 @@ async def v1_completions(raw_request: Request):
|
|||||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||||
|
|
||||||
if not stream_buffer: # The first chunk
|
if not stream_buffer: # The first chunk
|
||||||
if request.echo:
|
if request.echo:
|
||||||
# Prepend prompt in response text.
|
# Prepend prompt in response text.
|
||||||
text = request.prompt + text
|
text = request.prompt + text
|
||||||
@@ -219,7 +220,9 @@ async def v1_completions(raw_request: Request):
|
|||||||
token_logprob_pos = prompt_tokens
|
token_logprob_pos = prompt_tokens
|
||||||
|
|
||||||
logprobs = (
|
logprobs = (
|
||||||
await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
|
await make_openai_style_logprobs(
|
||||||
|
ret["meta_info"]["token_logprob"][token_logprob_pos:]
|
||||||
|
)
|
||||||
if request.logprobs is not None
|
if request.logprobs is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class ServerArgs:
|
|||||||
"--max-prefill-num-token",
|
"--max-prefill-num-token",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.max_prefill_num_token,
|
default=ServerArgs.max_prefill_num_token,
|
||||||
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
|
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tp-size",
|
"--tp-size",
|
||||||
|
|||||||
@@ -259,4 +259,4 @@ def load_image(image_file):
|
|||||||
else:
|
else:
|
||||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import argparse
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def test_decode(url, return_logprob):
|
def test_decode(url, return_logprob):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
@@ -27,6 +28,7 @@ def test_decode(url, return_logprob):
|
|||||||
)
|
)
|
||||||
print(response.json())
|
print(response.json())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import json
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def test_decode_stream(url, return_logprob):
|
def test_decode_stream(url, return_logprob):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
@@ -39,7 +40,7 @@ def test_decode_stream(url, return_logprob):
|
|||||||
assert data["meta_info"]["prompt_logprob"] is not None
|
assert data["meta_info"]["prompt_logprob"] is not None
|
||||||
assert data["meta_info"]["token_logprob"] is not None
|
assert data["meta_info"]["token_logprob"] is not None
|
||||||
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
||||||
if prev == 0: # Skip prompt logprobs
|
if prev == 0: # Skip prompt logprobs
|
||||||
prev = data["meta_info"]["prompt_tokens"]
|
prev = data["meta_info"]["prompt_tokens"]
|
||||||
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
|
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
|
||||||
print(f"{token_txt}\t{logprob}", flush=True)
|
print(f"{token_txt}\t{logprob}", flush=True)
|
||||||
@@ -50,6 +51,7 @@ def test_decode_stream(url, return_logprob):
|
|||||||
prev = len(output)
|
prev = len(output)
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
|
|||||||
@@ -64,9 +64,8 @@ def test_completion_stream(args, echo, logprobs):
|
|||||||
first = False
|
first = False
|
||||||
if logprobs:
|
if logprobs:
|
||||||
print(
|
print(
|
||||||
f"{r.choices[0].text:12s}\t"
|
f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
|
||||||
f"{r.choices[0].logprobs.token_logprobs}",
|
flush=True,
|
||||||
flush=True
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(r.choices[0].text, end="", flush=True)
|
print(r.choices[0].text, end="", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user