From aee4f523cfd92f844208118e42dcc6bfeb271d08 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 12 May 2024 04:54:07 -0700 Subject: [PATCH] Fix logit processor bugs (#427) --- README.md | 5 -- python/sglang/api.py | 2 + python/sglang/backend/anthropic.py | 9 +- python/sglang/backend/openai.py | 2 +- python/sglang/backend/runtime_endpoint.py | 1 - python/sglang/srt/layers/logits_processor.py | 31 ++++--- python/sglang/srt/managers/io_struct.py | 1 - .../sglang/srt/managers/router/infer_batch.py | 34 ++++---- .../sglang/srt/managers/router/model_rpc.py | 82 ++++++++---------- .../srt/managers/router/model_runner.py | 28 +++--- .../sglang/srt/managers/tokenizer_manager.py | 49 +++++++++-- python/sglang/srt/openai_protocol.py | 1 + python/sglang/srt/server.py | 75 ++-------------- python/sglang/srt/server_args.py | 2 +- python/sglang/srt/utils.py | 6 +- python/sglang/test/test_programs.py | 2 +- test/{killall_python.sh => killall_sglang.sh} | 0 test/lang/example_image.png | Bin 0 -> 57365 bytes test/lang/test_openai_backend.py | 13 +-- test/lang/test_openai_spec.py | 68 --------------- test/srt/example_image.png | 1 + test/srt/test_httpserver_decode.py | 2 +- test/srt/test_httpserver_decode_stream.py | 1 + test/srt/test_httpserver_llava.py | 4 +- test/srt/test_httpserver_reuse.py | 2 +- test/srt/test_openai_server.py | 2 +- 26 files changed, 166 insertions(+), 257 deletions(-) rename test/{killall_python.sh => killall_sglang.sh} (100%) create mode 100644 test/lang/example_image.png delete mode 100644 test/lang/test_openai_spec.py create mode 120000 test/srt/example_image.png diff --git a/README.md b/README.md index ecbdbbfde..b1e372ae8 100644 --- a/README.md +++ b/README.md @@ -297,7 +297,6 @@ curl http://localhost:30000/generate \ Learn more about the argument format [here](docs/sampling_params.md). ### OpenAI Compatible API - In addition, the server supports an experimental OpenAI-compatible API. ```python @@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md). ## Benchmark And Performance - - Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1 ![llama_7b](assets/llama_7b.jpg) @@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157 } ``` -[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104) - - We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql). diff --git a/python/sglang/api.py b/python/sglang/api.py index 087bac031..f2b92a960 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -1,5 +1,6 @@ """Some Public API Definitions""" +import os import re from typing import Callable, List, Optional, Union @@ -31,6 +32,7 @@ def function( def Runtime(*args, **kwargs): # Avoid importing unnecessary dependency + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" from sglang.srt.server import Runtime return Runtime(*args, **kwargs) diff --git a/python/sglang/backend/anthropic.py b/python/sglang/backend/anthropic.py index 851bc176a..330b2a412 100644 --- a/python/sglang/backend/anthropic.py +++ b/python/sglang/backend/anthropic.py @@ -14,7 +14,7 @@ except ImportError as e: class Anthropic(BaseBackend): - def __init__(self, model_name): + def __init__(self, model_name, *args, **kwargs): super().__init__() if isinstance(anthropic, Exception): @@ -22,6 +22,7 @@ class Anthropic(BaseBackend): self.model_name = model_name self.chat_template = get_chat_template("claude") + self.client = anthropic.Anthropic(*args, **kwargs) def get_chat_template(self): return self.chat_template @@ -41,7 +42,7 @@ class Anthropic(BaseBackend): else: system = "" - ret = anthropic.Anthropic().messages.create( + ret = self.client.messages.create( model=self.model_name, system=system, messages=messages, @@ -66,11 +67,11 @@ class Anthropic(BaseBackend): else: system = "" - with anthropic.Anthropic().messages.stream( + with self.client.messages.stream( model=self.model_name, system=system, messages=messages, **sampling_params.to_anthropic_kwargs(), ) as stream: for text in stream.text_stream: - yield text, {} + yield text, {} \ No newline at end of file diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index 3c0210975..5ac1d9447 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -228,7 +228,7 @@ class OpenAI(BaseBackend): prompt_tokens.append(ret_token) decision = choices[np.argmax(scores)] - return decision, scores, scores + return decision, scores, None, None def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs): diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index c4adf6987..efa0cd5a4 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -220,7 +220,6 @@ class RuntimeEndpoint(BaseBackend): "sampling_params": {"max_new_tokens": 0}, "return_logprob": True, "logprob_start_len": max(prompt_len - 2, 0), - "return_text_in_logprobs": True, } self._add_images(s, data) res = http_request( diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 668cd3390..53d6620e9 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module): for i in range(all_logprobs.shape[0]): k = input_metadata.top_logprobs_nums[i] t = all_logprobs[i].topk(k) - v_cpu = t.values.cpu().tolist() - p_cpu = t.indices.cpu().tolist() + v_cpu = t.values.tolist() + p_cpu = t.indices.tolist() decode_top_logprobs.append(list(zip(v_cpu, p_cpu))) return None, decode_top_logprobs else: prefill_top_logprobs, decode_top_logprobs = [], [] pt = 0 # NOTE: the GPU-CPU overhead can be reduced - extend_seq_lens_cpu = input_metadata.extend_seq_lens - for i in range(len(input_metadata.extend_seq_lens)): + extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy() + for i in range(len(extend_seq_lens_cpu)): if extend_seq_lens_cpu[i] == 0: + prefill_top_logprobs.append([]) + decode_top_logprobs.append([]) continue k = input_metadata.top_logprobs_nums[i] t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k) - vs_cpu = t.values.cpu().tolist() - ps_cpu = t.indices.cpu().tolist() + vs_cpu = t.values.tolist() + ps_cpu = t.indices.tolist() prefill_top_logprobs.append( [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] ) decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1]))) + pt += extend_seq_lens_cpu[i] return prefill_top_logprobs, decode_top_logprobs def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata): @@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module): all_logits = all_logits[:, : self.config.vocab_size] all_logprobs = all_logits.float() - all_logits = None + del all_logits all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) - prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( - all_logprobs, input_metadata - ) + return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) + if return_top_logprob: + prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( + all_logprobs, input_metadata + ) + else: + prefill_top_logprobs = decode_top_logprobs = None if input_metadata.forward_mode == ForwardMode.DECODE: last_logprobs = all_logprobs return last_logits, ( None, None, - decode_top_logprobs, None, + decode_top_logprobs, last_logprobs, ) else: @@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module): ) return last_logits, ( prefill_token_logprobs, + normalized_prompt_logprobs, prefill_top_logprobs, decode_top_logprobs, - normalized_prompt_logprobs, last_logprobs, ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 6e64380c9..a99498949 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -25,7 +25,6 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output stream: bool = False - # TODO: make all parameters a Union[List[T], T] to allow for batched requests def post_init(self): is_single = isinstance(self.text, str) diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 3920fe039..1f655513a 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from enum import Enum, auto +from enum import IntEnum, auto from typing import List import numpy as np @@ -9,15 +9,15 @@ from sglang.srt.managers.router.radix_cache import RadixCache from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool -class ForwardMode(Enum): +class ForwardMode(IntEnum): PREFILL = auto() EXTEND = auto() DECODE = auto() -class FinishReason(Enum): - LENGTH = auto() +class FinishReason(IntEnum): EOS_TOKEN = auto() + LENGTH = auto() STOP_STR = auto() @@ -31,6 +31,7 @@ class Req: # Since jump forward may retokenize the prompt with partial outputs, # we maintain the original prompt length to report the correct usage. self.prompt_tokens = len(input_ids) + # The number of decoded tokens for token usage report. Note that # this does not include the jump forward tokens. self.completion_tokens_wo_jump_forward = 0 @@ -41,12 +42,11 @@ class Req: self.image_offset = 0 self.pad_value = None + # Sampling parameters self.sampling_params = None - self.return_logprob = False - self.logprob_start_len = 0 - self.top_logprobs_num = 0 self.stream = False + # Check finish self.tokenizer = None self.finished = False self.finish_reason = None @@ -56,13 +56,17 @@ class Req: self.prefix_indices = [] self.last_node = None + # Logprobs + self.return_logprob = False + self.logprob_start_len = 0 + self.top_logprobs_num = 0 + self.normalized_prompt_logprob = None self.prefill_token_logprobs = None self.decode_token_logprobs = None - self.normalized_prompt_logprob = None self.prefill_top_logprobs = None self.decode_top_logprobs = None - # For constrained decoding + # Constrained decoding self.regex_fsm = None self.regex_fsm_state = 0 self.jump_forward_map = None @@ -165,8 +169,8 @@ class Batch: out_cache_cont_end: torch.Tensor = None # for processing logprobs - top_logprobs_nums: List[int] = None return_logprob: bool = False + top_logprobs_nums: List[int] = None # for multimodal pixel_values: List[torch.Tensor] = None @@ -321,8 +325,8 @@ class Batch: ) retracted_reqs = [] - seq_lens_np = self.seq_lens.cpu().numpy() - req_pool_indices_np = self.req_pool_indices.cpu().numpy() + seq_lens_cpu = self.seq_lens.cpu().numpy() + req_pool_indices_cpu = self.req_pool_indices.cpu().numpy() while self.token_to_kv_pool.available_size() < len(self.reqs): idx = sorted_indices.pop() req = self.reqs[idx] @@ -338,8 +342,8 @@ class Batch: # TODO: apply more fine-grained retraction token_indices = self.req_to_token_pool.req_to_token[ - req_pool_indices_np[idx] - ][: seq_lens_np[idx]] + req_pool_indices_cpu[idx] + ][: seq_lens_cpu[idx]] self.token_to_kv_pool.dec_refs(token_indices) self.filter_batch(sorted_indices) @@ -363,7 +367,7 @@ class Batch: # insert the old request into tree_cache token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1] if req_pool_indices_cpu is None: - req_pool_indices_cpu = self.req_pool_indices.cpu().tolist() + req_pool_indices_cpu = self.req_pool_indices.tolist() req_pool_idx = req_pool_indices_cpu[i] indices = self.req_to_token_pool.req_to_token[ req_pool_idx, : len(token_ids_in_memory) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index f837d9029..44b9d0210 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -36,7 +36,9 @@ from sglang.srt.utils import ( set_random_seed, ) + logger = logging.getLogger("model_rpc") +vllm_default_logger.setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN) @@ -54,9 +56,6 @@ class ModelRpcServer: self.tp_size = server_args.tp_size self.schedule_heuristic = server_args.schedule_heuristic self.disable_regex_jump_forward = server_args.disable_regex_jump_forward - vllm_default_logger.setLevel( - level=getattr(logging, server_args.log_level.upper()) - ) # Init model and tokenizer self.model_config = ModelConfig( @@ -65,7 +64,7 @@ class ModelRpcServer: context_length=server_args.context_length, ) - # for model end global settings + # For model end global settings server_args_dict = { "enable_flashinfer": server_args.enable_flashinfer, "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32, @@ -164,7 +163,7 @@ class ModelRpcServer: logger.info("Cache flushed successfully!") else: warnings.warn( - "Cache not flushed because there are pending requests. " + f"Cache not flushed because there are pending requests. " f"#queue-req: {len(self.forward_queue)}, " f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" ) @@ -386,12 +385,12 @@ class ModelRpcServer: f"#running_req: {running_req}. " f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%." ) - logger.debug( - f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " - f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " - f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " - f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. " - ) + #logger.debug( + # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " + # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " + # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " + # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. " + #) new_batch = Batch.init_new( can_run_list, @@ -408,47 +407,41 @@ class ModelRpcServer: self.model_config.vocab_size, self.int_token_logit_bias ) - prefill_token_logprobs = None if batch.extend_num_tokens != 0: # Forward logits, ( prefill_token_logprobs, + normalized_prompt_logprobs, prefill_top_logprobs, decode_top_logprobs, - normalized_prompt_logprobs, last_logprobs, ) = self.model_runner.forward(batch, ForwardMode.EXTEND) if prefill_token_logprobs is not None: - prefill_token_logprobs = prefill_token_logprobs.cpu().tolist() - normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist() + prefill_token_logprobs = prefill_token_logprobs.tolist() + normalized_prompt_logprobs = normalized_prompt_logprobs.tolist() next_token_ids, _ = batch.sample(logits) - next_token_ids = next_token_ids.cpu().tolist() + + # Only transfer the selected logprobs of the next token to CPU to reduce overhead. + if last_logprobs is not None: + last_token_logprobs = ( + last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist() + ) + + next_token_ids = next_token_ids.tolist() else: next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) - ( - logits, - prefill_token_logprobs, - normalized_prompt_logprobs, - last_logprobs, - ) = (None,) * 4 - - # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead. - reqs = batch.reqs - last_token_logprobs = None - if last_logprobs is not None: - last_token_logprobs = ( - last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist() - ) # Check finish condition pt = 0 - for i, req in enumerate(reqs): + for i, req in enumerate(batch.reqs): req.completion_tokens_wo_jump_forward += 1 req.output_ids = [next_token_ids[i]] req.check_finished() - if prefill_token_logprobs is not None: + if req.return_logprob: + req.normalized_prompt_logprob = normalized_prompt_logprobs[i] + # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. req.prefill_token_logprobs = list( zip( @@ -463,12 +456,14 @@ class ModelRpcServer: req.decode_token_logprobs = [ (last_token_logprobs[i], next_token_ids[i]) ] + + if req.top_logprobs_num > 0: req.prefill_top_logprobs = prefill_top_logprobs[i] if req.logprob_start_len == 0: req.prefill_top_logprobs = [None] + req.prefill_top_logprobs req.decode_top_logprobs = [decode_top_logprobs[i]] - req.normalized_prompt_logprob = normalized_prompt_logprobs[i] - pt += req.extend_input_len + + pt += req.extend_input_len self.handle_finished_requests(batch) @@ -520,29 +515,29 @@ class ModelRpcServer: logits, ( _, _, - decode_top_logprobs, _, + decode_top_logprobs, last_logprobs, ) = self.model_runner.forward(batch, ForwardMode.DECODE) next_token_ids, _ = batch.sample(logits) - next_token_ids = next_token_ids.cpu().tolist() + next_token_ids = next_token_ids.tolist() # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead. - reqs = batch.reqs - new_token_logprobs = None if last_logprobs is not None: new_token_logprobs = last_logprobs[ - torch.arange(len(reqs)), next_token_ids + torch.arange(len(batch.reqs)), next_token_ids ].tolist() # Check finish condition - for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)): + for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() - if new_token_logprobs is not None: + if req.return_logprob: req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id)) + + if req.top_logprobs_num > 0: req.decode_top_logprobs.append(decode_top_logprobs[i]) self.handle_finished_requests(batch) @@ -590,8 +585,7 @@ class ModelRpcServer: + len(req.output_ids) - req.prompt_tokens, "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "finish_reason": str(req.finish_reason), - "hit_stop_str": req.hit_stop_str, + "finish_reason": str(req.finish_reason), # FIXME: convert to the correct string } if req.return_logprob: ( @@ -628,7 +622,7 @@ class ModelRpcServer: # Remove finished reqs if finished_indices: # Update radix cache - req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist() + req_pool_indices_cpu = batch.req_pool_indices.tolist() for i in finished_indices: req = batch.reqs[i] req_pool_idx = req_pool_indices_cpu[i] diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 8d6851caf..d24db31c7 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -29,7 +29,7 @@ QUANTIZATION_CONFIG_MAPPING = { logger = logging.getLogger("model_runner") # for server args in model endpoints -global_server_args_dict: dict = None +global_server_args_dict = {} @lru_cache() @@ -86,8 +86,8 @@ class InputMetadata: out_cache_cont_end: torch.Tensor = None other_kv_index: torch.Tensor = None - top_logprobs_nums: List[int] = None return_logprob: bool = False + top_logprobs_nums: List[int] = None # for flashinfer qo_indptr: torch.Tensor = None @@ -107,18 +107,20 @@ class InputMetadata: (self.batch_size + 1,), dtype=torch.int32, device="cuda" ) self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0) + self.kv_last_page_len = torch.ones( + (self.batch_size,), dtype=torch.int32, device="cuda" + ) + req_pool_indices_cpu = self.req_pool_indices.cpu().numpy() + seq_lens_cpu = self.seq_lens.cpu().numpy() self.kv_indices = torch.cat( [ self.req_to_token_pool.req_to_token[ - self.req_pool_indices[i].item(), : self.seq_lens[i].item() + req_pool_indices_cpu[i]: seq_lens_cpu[i] ] for i in range(self.batch_size) ], dim=0, ).contiguous() - self.kv_last_page_len = torch.ones( - (self.batch_size,), dtype=torch.int32, device="cuda" - ) workspace_buffer = torch.empty( 32 * 1024 * 1024, dtype=torch.int8, device="cuda" @@ -195,15 +197,15 @@ class InputMetadata: req_pool_indices[0], seq_lens[0] - 1 ].item() else: - seq_lens_np = seq_lens.cpu().numpy() - prefix_lens_np = prefix_lens.cpu().numpy() - position_ids_offsets_np = position_ids_offsets.cpu().numpy() + seq_lens_cpu = seq_lens.cpu().numpy() + prefix_lens_cpu = prefix_lens.cpu().numpy() + position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() positions = torch.tensor( np.concatenate( [ np.arange( - prefix_lens_np[i] + position_ids_offsets_np[i], - seq_lens_np[i] + position_ids_offsets_np[i], + prefix_lens_cpu[i] + position_ids_offsets_cpu[i], + seq_lens_cpu[i] + position_ids_offsets_cpu[i], ) for i in range(batch_size) ], @@ -229,9 +231,9 @@ class InputMetadata: out_cache_loc=out_cache_loc, out_cache_cont_start=out_cache_cont_start, out_cache_cont_end=out_cache_cont_end, - top_logprobs_nums=top_logprobs_nums, - return_logprob=return_logprob, other_kv_index=other_kv_index, + return_logprob=return_logprob, + top_logprobs_nums=top_logprobs_nums, ) if forward_mode == ForwardMode.EXTEND: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 78241b74b..a95606de1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -185,7 +185,10 @@ class TokenizerManager: while True: await event.wait() - yield state.out_list[-1] + yield self.convert_logprob_style(state.out_list[-1], + obj.return_logprob, + obj.top_logprobs_num, + obj.return_text_in_logprobs) state.out_list = [] if state.finished: del self.rid_to_state[rid] @@ -231,16 +234,16 @@ class TokenizerManager: rid = obj.rid[i] state = self.rid_to_state[rid] await state.event.wait() - output_list.append(state.out_list[-1]) + output_list.append( + self.convert_logprob_style(state.out_list[-1], + obj.return_logprob[i], + obj.top_logprobs_num[i], + obj.return_text_in_logprobs)) assert state.finished del self.rid_to_state[rid] yield output_list - async def detokenize(self, obj: DetokenizeReqInput): - token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids) - return [t.decode() if isinstance(t, bytes) else t for t in token_texts] - async def flush_cache(self): flush_cache_req = FlushCacheReq() self.send_to_router.send_pyobj(flush_cache_req) @@ -267,3 +270,37 @@ class TokenizerManager: state.event.set() else: raise ValueError(f"Invalid object: {recv_obj}") + + def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs): + if return_logprob: + ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens( + ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs + ) + ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens( + ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs + ) + if top_logprobs_num > 0: + ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens( + ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs + ) + ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens( + ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + ) + return ret + + def detokenize_logprob_tokens(self, token_logprobs, decode_to_text): + if not decode_to_text: + return [(logprob, token_id, None) for logprob, token_id in token_logprobs] + + token_ids = [tid for _, tid in token_logprobs] + token_texts = self.tokenizer.batch_decode(token_ids) + return [ + (logprob, token_id, token_text) + for (logprob, token_id), token_text, in zip(token_logprobs, token_texts) + ] + + def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text): + for i, t in enumerate(top_logprobs): + if t: + top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text) + return top_logprobs diff --git a/python/sglang/srt/openai_protocol.py b/python/sglang/srt/openai_protocol.py index 1cf1fed73..95b92948d 100644 --- a/python/sglang/srt/openai_protocol.py +++ b/python/sglang/srt/openai_protocol.py @@ -1,3 +1,4 @@ +"""pydantic models for OpenAI API protocol""" import time from typing import Dict, List, Optional, Union diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index c1b7780cc..7a22632f4 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -10,7 +10,7 @@ import threading import time from typing import List, Optional, Union -# Fix a Python bug +# Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import aiohttp @@ -53,10 +53,10 @@ from sglang.srt.managers.router.manager import start_router_process from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( - enable_show_time_cost, allocate_init_ports, - jsonify_pydantic_model, assert_pkg_version, + enable_show_time_cost, + jsonify_pydantic_model, get_exception_traceback, API_KEY_HEADER_NAME, APIKeyValidatorMiddleware @@ -99,12 +99,6 @@ async def flush_cache(): ) -async def stream_generator(obj: GenerateReqInput): - async for out in tokenizer_manager.generate_request(obj): - await handle_token_logprobs_results(obj, out) - yield out - - @app.post("/generate") async def generate_request(obj: GenerateReqInput): obj.post_init() @@ -112,69 +106,16 @@ async def generate_request(obj: GenerateReqInput): if obj.stream: async def stream_results(): - async for out in stream_generator(obj): + async for out in tokenizer_manager.generate_request(obj): yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" return StreamingResponse(stream_results(), media_type="text/event-stream") ret = await tokenizer_manager.generate_request(obj).__anext__() - await handle_token_logprobs_results(obj, ret) - return ret -async def detokenize_logprob_tokens(token_logprobs, decode_to_text): - if not decode_to_text: - return [(logprob, token_id, None) for logprob, token_id in token_logprobs] - - token_ids = [tid for _, tid in token_logprobs] - token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids)) - return [ - (logprob, token_id, token_text) - for (logprob, token_id), token_text, in zip(token_logprobs, token_texts) - ] - - -async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text): - for i, t in enumerate(top_logprobs): - if top_logprobs[i] is not None: - top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text) - return top_logprobs - - -async def handle_token_logprobs_results(obj: GenerateReqInput, ret): - """Handle the token logprobs results, convert token ids to text if needed. - - Args: - obj (GenerateReqInput): The request object. - ret (Union[Dict, List[Dict]]): The response object. - """ - # NOTE: This is because the multiple requests in one http request. - - async def convert_style(r, return_text): - r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens( - r["meta_info"]["prefill_token_logprobs"], return_text - ) - r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens( - r["meta_info"]["decode_token_logprobs"], return_text - ) - r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens( - r["meta_info"]["prefill_top_logprobs"], return_text - ) - r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens( - r["meta_info"]["decode_top_logprobs"], return_text - ) - - if isinstance(obj.text, str): - if obj.return_logprob: - await convert_style(ret, obj.return_text_in_logprobs) - else: - for i, r in enumerate(ret): - if obj.return_logprob[i]: - await convert_style(r, obj.return_text_in_logprobs) - - @app.post("/v1/completions") async def v1_completions(raw_request: Request): request_json = await raw_request.json() @@ -203,10 +144,10 @@ async def v1_completions(raw_request: Request): if adapted_request.stream: - async def gnerate_stream_resp(): + async def generate_stream_resp(): stream_buffer = "" n_prev_token = 0 - async for content in stream_generator(adapted_request): + async for content in tokenizer_manager.generate_request(adapted_request): text = content["text"] prompt_tokens = content["meta_info"]["prompt_tokens"] completion_tokens = content["meta_info"]["completion_tokens"] @@ -266,7 +207,7 @@ async def v1_completions(raw_request: Request): yield f"data: {jsonify_pydantic_model(chunk)}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream") + return StreamingResponse(generate_stream_resp(), media_type="text/event-stream") # Non-streaming response. ret = await generate_request(adapted_request) @@ -384,7 +325,7 @@ async def v1_chat_completions(raw_request: Request): is_first = True stream_buffer = "" - async for content in stream_generator(adapted_request): + async for content in tokenizer_manager.generate_request(adapted_request): if is_first: # First chunk with role is_first = False diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 78b8537c4..0bbbb3e3c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -241,7 +241,7 @@ class ServerArgs: def print_mode_args(self): return ( f"enable_flashinfer={self.enable_flashinfer}, " - f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}" + f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, " f"disable_radix_cache={self.disable_radix_cache}, " f"disable_regex_jump_forward={self.disable_regex_jump_forward}, " f"disable_disk_cache={self.disable_disk_cache}, " diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 774bdf2c9..56f408db0 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1,3 +1,5 @@ +"""Common utilities.""" + import base64 import os import random @@ -13,6 +15,7 @@ import numpy as np import pydantic import requests import torch +from fastapi.responses import JSONResponse from packaging import version as pkg_version from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware @@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): response = await call_next(request) return response + # FIXME: Remove this once we drop support for pydantic 1.x IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1 @@ -310,4 +314,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1 def jsonify_pydantic_model(obj: BaseModel): if IS_PYDANTIC_1: return obj.json(ensure_ascii=False) - return obj.model_dump_json() + return obj.model_dump_json() \ No newline at end of file diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 32c319166..853ed9846 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True): def test_image_qa(): @sgl.function def image_qa(s, question): - s += sgl.user(sgl.image("test_image.png") + question) + s += sgl.user(sgl.image("example_image.png") + question) s += sgl.assistant(sgl.gen("answer")) state = image_qa.run( diff --git a/test/killall_python.sh b/test/killall_sglang.sh similarity index 100% rename from test/killall_python.sh rename to test/killall_sglang.sh diff --git a/test/lang/example_image.png b/test/lang/example_image.png new file mode 100644 index 0000000000000000000000000000000000000000..851d085605c01fe89d887e9ca9fe6a33bcc8c93b GIT binary patch literal 57365 zcmb4qWmFsQ6D}>Zg`&Zo60}Gt4#nEwRtOMWe#MFfcdbBim*5n4C=fJAaS2dd0t5)s z;>C;h^8au@+>dwmJ!j9JcXnsa?!3FR&ph*Q;olk#nVOQS63(MXk8u8b*l_+W;mG6Q zK7Rba@gR5);|czgCwO>Io88sPzf|81giinhk zmYR~5oRW(2e;0X#`*0546M`pC2q=jOi7EfT<=+n+0RE%E$FFf80dXD!9^nEW{p-VF z!ohj;&PAOHE&XL!%>aUVZ8g7bv@=`#QY zP>2)90jmc47I%{DQjfo=E~~U5}|hG#ny&R>|v}UsAB53f7*5 z4aam`V$jgu=@VLU#Df9GM>vmh{~!Ag$9PYkK6-}#Fe?GTd9a83828E3rw?lXbMaw^ zNB#sTL;-#Zpk!qe6xN}73wdwpM$Ino{weWuzCu!6_dhfa(S!D504@MW4(ErI@YXE* zx`*W=bTd8e1jTHg?GU?|uvNB-$xFW5G5F%=I)hvhCEFEXp^i}S|C*w6;=7S(qx#8a zOg$oAdqVKd0aLiyoWUcosiN}Osm@n@%(fi1lF(S0tI?~k4wY^n-)*%3mTpWG9pH25 zbxhjCixH;>Ia^e!hc&Yn(9JEIqUd*}-7oyyiqCQV&H5O`F!_LD9XisN?lWh0Tcd_> zeh*ICtyC^U+vff^yoEqrLIK-WOI&w+cMp26GU?QzaCnh@l~ZIl{j9!|hQ6s`YMrs0 z+oenzioU`7zn{2T!?G@e@D+N(rpBJ;PWihG2xcstpZ28GAQryzJXNLlj39}|`oi3Rrb#^txG zK=taivD3tTgfHp$fhTJ*BOPB^&p9mCoH+%`!>PW$A#0J9J~s`}@+YQHw~q zwnpgg)B5S}SN?A0daBRTj{o7*{jK$?FA|a1a7;ceKVucxbpDko4{Y2%L-g&MeObn+drH}4+wMmcFQ}y5YDbf9^SXg*gbAZnyrWm&gb;pf(|Gm z`Nex4@TSy$au)c2Gm@fVY*CAzdIixtq@o-?dtP2WRuKOo$uJ(fqnF#R2AVIgeeJz2 zCZfy5RxdRMk^YARE~YXRD}Ah)F*IJUncPvZ5z;{&UJdE1Ze9$FPXRJJd;l&iI-Qwv z_C@&n`Gsc|@miX^s)5+r>(fQQ%m5`?f;l2dXA33D>JU9mWvZzdo?zRbS99i5X&EOk zP{|ubM2!<%ON!UHj?E_hJuB9kWsq^hkuhDr%`fq}h3`|05aCTg zQORtlFa|QU;~RbyAgZ`ur@yuS$XQUq>~!cv@tl?4gup#cTO+c^uEi#j2BO_L++^c7 zon_vz@*&15?8S=MvdtfvrH`qS5TDt-ZET5L9LLBUPp}oNl7>RRMWrLmP43=MR{=!n z@el+SSxUFOx&)**XP^J!tjBWx!?|J1RzLBxF$@~)PkG$j%Uk^8%0H^QL()k_Y)2__ zYLaDJ(U@#<&dD5PTtD1JH4UecNn^aJROy$+eqT(#f2<%oX6WCu-GqA5xCsDa4-Yi& zqi`30Y&hPCxK&Lshgrrb##5rQ7EeQ7(;Nz15IF6g-~A_**@hR<^U-Cc1!^F$^>)cM z;f6k}i&|AOngH3~>w7g9CnAKW?3~X+0T3h_oL*=IRMfuuhf_VTWIF__A#!8I(m)AM zA7>Nfqzi3oJ|P?s+5nJaFY|sv(`w?}MHPpi;vT;28Zjy$Ht{!F%P3)`<$00mqV*FR zdIxzH>Fu=STjIM;7_5xR_j{uLc{LO!NlQg(O0Aqqv+}Lw`jhXQmpI>n00SWnyQytO z<^h`>{7u!x9l$@FV%_Ti$Y|;xeuNRLXAvo2kdeh9Y?(v-gC0$c{f_2EM3%yH3n}>J ztX#3T^K1hI?aXr%RhNTdNUn33LgQq^SiM9}IDBD|M?buC3v}hdhv+2|TKQa>VtPrC{viP(L3sV}u&9Hi5`5JSvWdDdvyM4E_D=xx>129iM zkA_wp)i>J=q{wo(+rPy(k@;PIGxceIZb{PN)0!~zst+p{!Ph}RfWQaLrN2Y>pF*_! zyiJCTZ$NrxZ1YD>wZ(_Hkt?SRF|O!Hh5@%iI+7_Y1>Hl5ZQ26=8`E83rPM2hqWMi+ zp;gS(d|w^ZBgu}={2h3>BKv3N{vIFmWo;yEjhP-`&&d$In|;q9^l4MP0?cQevoy^i zrH$mD#Khe#_iB9k`xLbQYcKhSBc9p5Pv0IbZlOVq*#8_dESeB6O1fj|qU)gU_r_I? zPu5=RUbQ{EsAE@dZf=0fh?oN_h%=GHCFS!xK#t|~TZfEiyzCu05&k`w%Lky%Cl{6~ zMGVi07@BI+IM$4<$2Aa|#~i49d@=~%ZN626w4t=kHZgNR(_&XDdPQqfZA!D@uJ8I8 z)!kohVtj~NdRNwy?sI0w|AI6+Ra(It1QHuITXdk`8wWbP#85OVLfx>Tgk4nx%KFaW zS$)|hg3<9dz1&CG4=|&^I4eH8F7#K|vuX*pHLz_uWP0qA#b`+LItxpx4-^IDBdnT` z!W^tynfs|@%L`Q0v*L$u%$6Q*cuBE-zsfGi*!yYrw9b1ID?BBLIQ7e4#)h$IqWzOB#lEV%#2#EV*Bl=BDe9s# zgl(1Z+Pm1}C<(`Ud>NQG>3v+KfmNE2tD45V+Ov&_l+Z zleUWJwgKCGk9k(UTzk5x`eRq;#0+;}=PfZ{RApu`&f~>TD@~_d|#J6XDFn3v);SwwBefdUY2toq7n@}z+ zrBdhnwR%`Vc4g;S_s#T%y=&H^EkI$kZ~PPY-3$Ai6Oa+L3_2kZZcjz))>tc`ohn>6 zKvRmJ<@Y@5+)EsfUz1z586U|bAnnXVQs=!G~OgKcPV!B9O+*tL7&skB6-3?GVB-WBh!7WM>!($w-~!Gas40iQ*mRZDsXjYhNd z�*{=qGkPn*hQ6dbTesr;OK~_SMnA?> zT2cssE-cjq?O~)yMY)WlmiNKusGKO|rrC%2!ook>$x>oyb__tO;lm^Lgnv=-?^}WZ zs2)sLU-i2KZN!qPt(p+e9edYEwV!-S0SH14I@e{eb>!VJ3bAC9myABm@u(=i>MD)h^T%1^_qI3-bNbvnjG)T8O5aI`+Ot0eDE= zAX(n!xAMMDO8YBq!*r3Pyb(?S2HancS>FF}PWKATj67iJ&g>pGJ&Jw9-M?0TS1mU1 zfAHX!PGTpQu+3(FWj>>D23Za$szeDpYbmpm_qRuO-!u;Q=c}N@D?gxv!zP?wyDCbF zErWFuR`IgMP^D#dQr!``sf{OYrah6ypjT-Lr05zRybB8+UI7otNDK+w ziwZELNB`0{K9POX4h|qmN~SD&{^vLbU%^yCBOJ|QM(#j80Q!og@w062DnqPlG}Tut z;6)U5yVI1iKwqOt1kW`8C>OU_vnktaN>Zi8L7j(Nq)GUbX2dO?@d2~6OJ38s(J<;V>xS;A}tgBN+5kqf}cq!}p4wLe&qjVQv zf5nlHY0&?pU$1}dhuTxcN>I@n@k|i+qNMwLYNBQxbNtdI&eN~Ni;cy_pnKNqD=X+XX%3EidI66|VL{|C| zhQ)&<%~`G(kdMVr4$%GG`Tl$Jsim1De9yZdbvrS?R*XlHd^GHi(FB$j2Xp|6)G{w| z5}7!AvQ267k&hx<73}Ic)pLgS!g?Zof-B}qRL{k}+hIam6AE6sQhe`_-q0K94=vg& zDN1>ZXNUi$*?-dEekNXI^l8gSb~KG$($#N={zpjIdc>_DHre`@IFzuDcj(0|*?_gS zT{iVGUu!hrlwZ$;s$tn`cNjG8bj+!#(9DHMwK;h1644jd)p>X{f%vKRdI*ziQ>EU& zwB1*qZ(z65W8(gWbE`Nr8qmmotB&{g(hS<*vprCb7tTBnEHkO?I zR6!zts!o0_!|Y#k0d4S_ya6LWb;f17u^lR`bh&YE-BH>g2tyq;eB~t~)66)wSAPG) z;nlf#8ME9Nni0GCuEWJym0EU=O^bdr-&~nNp2pgQ_Nf6Q)9@4BaU?tl|F?JUU%&7? zJ_QqyF&Db{7%5Y~?EIx{o=Ty`s%ui-SoU^ocUNVav>{r}93rk=iJ)ty(X5#gX03bS z7gPa#LGjKS3TuxKft<~SAiN4tR0Nr(#6;ePTvAV<*ALD-Tw8aLW5 zC#zZ4x)I<*>c%Tj#iYU}cnK&!=q6>(Lo{kU8xZK6(Hb#WgMYw#8) zLL~>tS3wkcj27`2rZ)1N@yw2Zlkb*Weq0puBbI8iKTdpEa2V+Ujfir8skz79DNsq( z9>OU)!@50E>|IDjhi3SZ1AK(^8B_Z}=2#vdm94ib7_R&5?%h9}u}U^;zKx=ji(L20 ztoZw|#P|k%u@O$3of6vRR%@GUls-OXJwI-;)rgr$&aYBXjj@jds(&6CE7C@mY|S=1 z{g%TwaXjLdVS1%VWI3=HKU8BifS&LDsE*(P|IZ^!6woeQ`%(()0N%Ui}%Itnmq}crvMMAPa&To z4|Y0Kx&SP1{7lW{)>NY-({6(e3F~L^vkRFzIQoYZIQS3ei-s{2YOYRtJJ9#z3pv%h{<1sOe12V#$>&OIt)EFWYIc*7;oUWH zdZ!S0^un>=V7gMiQ>2mF$YT8m;Etax$15{b1Xv)JU`LZp?D_I@4IsOQ5B0%$z!H8) z#ck)glHK}Nrkq$b%g7ZMG%=(sdkzPOmiuQgQ&8ahI(Y;uc3M=*qu1T1 zD4ytgn)#o&>hp z(~6WOwyukePmch%ESWX-iYf7SO69N$d}L(~@V?Kj7W(zScYiqs?ew~Ut1D_&t}N#o zE_S&<9(a3WOsymJ;Ndh~tPM5GWtavJ_1B;5yxwnLVOsKp9k(rF<$e~#U{%&l*)!Mqk&RNt6eY`4SX_TYolnl@qj3Fl2N`NB40|9ZM+4YXM1g$%1w(Eo}+y0o;no;AkF+8UJ#T*b>DV^adNAq?qh6{{1y{Gw@3@`Oai-&zcn`@m6HXK5J=$ zCWRfNF!k-|YWVHatEITq4wbm+X(jZ<{Y*+rRPFe^r(0w3dv71X0dv?N;0FqrPr=pO zQ;0c-sM(JvUoxc>kZCp&`YG zuL%F5sgawfr;-JX&ehBVK87)_7Cp~(8Z&bB>o^o^xw!Bz(~5q6tJhxmh&m^>J((@; zWrAO_?<~JfiwbWagZ*+we`^jKWov6kMGGJ)`h9q!=b%gJb}_Vxn676gu_RDz(TX?n zM_>K!PJ!&Y1p?8S?UC^xoRMig`BOZ_>|3;jhT4~)(7j7IM~u1l2X~U=K`+1oZJ?F< zxC@_e>_)0y7S9y<{`H}B203IjQ&vyZY_%Zl`3Z<0MlFM7_nl_v^A_HABfveYR2dTMiTXJUhcdRZeyNciV!|XVw{cDay*^g&0B<#|i zS4N|6$~Fsi6QIrNq1;U<^>|udD=rTC{ne31s_$8dmq~{_XweQFzb`-0mA4fJ&Y3Se zjGjxg1uo%X+B>~UmPjT6XA!12X5v)cc;;M%4@v1SB`GCmg}Fr+?E-Xx6gysurk=`| z;N9>Ak!=qCvESS8o`SiN&HlSq4*Hxf)eSfk*-pCBih+CT``1g(Q;CtER3;e|-c7!< zv53oXA0+qXw{1zLH30JVFdt*Zc#V(5`c_u@6t_?PI<+=ayb7wOfQxGwMt?~`p(Q}m z5>0zcLY!j4rJbW4F0Sqy{LgOeVlTT;z?!X9Woq7(p2>@{npV<5EHe{>y#)BCN#uOf4YN7prPB1uez2C@|C z7&4DUF9LSlyU=xDg=T1Of0Vrw9(L&m;yu!&F}7S>L9{GzK;3reN- z^^QaxzWB2sj-<^Ho9m2b;l*8~Stq`_%S(gj)nu#IEAUfbme(JXKYy_q-BxH|I5r$I zzBk$2p|30@IxNiuB3q z4!Q9K%t+eUH`~}8LAIjg)|sZr%L_Lli7^L_{VR?3KV<(SW0_TpiIH_kMx*5K(`x=+ zHOH^0#l=XY%uG%7*0-->cs>8&Oq{&-s~B6fhTfMcdWJKeE_Zg?R7{;nEMl%3U7T)Q z?TBUB<&s42yvm1qO(bce;DS^lU8K(m``9I^B%1)ftqE~#8ZOym=^HEGY_V^LXMFHo zuftZox=Yb6#$uXnI^WLVx;7fWda?BdriGCGe%h+!z+$+#Caa3DmEJ;8Gskn4dX>lp ztaGX6+i=$pf9#2+*>k;r@|1%hC-Gd1K>DyoksBI2*q2a@7OB>_nvzampr&RaUm+>8 zabuMxS-l=zarL2VUs(Mhvrb{p%uI;y~pmUqh=Um?pUI5&3KTkRz=1RLqO=pMm$6(N^L(&D+Ps^ z$5&@8zrbLl$wj_F?%e`ky?G~jY#Aqkiga43}qG4S&;qbFCwF-shkH}KK0@TV4 zVjxi4|0hT3CuXPa0GeqnP~Jh@$z^%t57h^FgDC`NwaQfb(q@hzE? z1ZRPKssMm&s9KX(>dDe;j7FAILW}gQ$qcXiYl;2aLS1f$Ap25=))yLN@vQehIH!GF z3WIdp_D$fG+c@p>00d(e@L)#`zOHI{>47na|8QW0Ki^sIZr5lj#+^OZ@yRyI)$H4T@?ZP`(kn%LmGm2KgcwT1Xn|=Lrw~hnqc>oo@ zJkTc*Aa055(&V|N)dvMtbZYDrjDyE%N{lL3K9i=;~_DbifYOr z!+dY={V=W|l`8A5Z97z=np|Gm$R@llx|nz~tWpn>C}(pg6l-<$XP};W8asAaW9Z+k z!yuI2Q0v3>q}1rYH-Z~n1BgMGZdgLItph=Z2iH(aZZA90U@a+O27$!FbR~?YC8$xo z8+|0!POr@HR)_s3(_PoE(|Hq~)|?&Ci8d7i8}YHA;9BlldwXR$^mWwK%pb3T$ztf* zZP!`3*+<0mh5DC7w6=B81zT=#N+H5|J*e1$Yu{j@WZLQVDt;>2O7HExL(;?;vK+CQ ztVG7FOrYF9AM(pW(;%08vN#npl#prkoX$Fe;YNz5`;XNO%&^II(|(lFkAol6uepp& zb`M5>HOYWYkFyWQE%i7bV}uK~ELd}PhN<-@Hn|DDc~IZ$6~d?_YUv3zhrP7pe`9#-Bg92eK}D_J!Pi`iXX3CMvB~4?$I1*%-0<5Ti-mGy^7fdl zl$GP6^AAy+HAfMm2tvN4#s!mjb~-0o#=)?z5Qoe1@vqZIhF}A_d!d#u%PZTfw>v=) zo+TFTPF+*h6E@5Md^)L6-M&8@PE7R-Xm)PEUy5Q#w0mTUk!~-Pf1UJRIZSefdbz)V zmFAi6Rfnh?YQR$^x+Glm#YYo%G8dtp{vgqVXOxm&syJ9uC3w|qRKFP5Vq(fI*q@f)@WRCT zRivdfcv~RHAapPLtL1AaX$`UxbIpG6?b%@;VADFt%{A3!#(iD=@iXEn&Z{fEz@@3E zYJPaP>uqI7Bq>t)x<|px?lKOD+75BQv4M)-{KJvY2Do=@S5pD|RSMaxQoTyLueFJ3E0um4cCG~8kJOK0=FHe@I)!Q%qlRBPay?`Li z8K*3}zik_EQOWQMJ_B|ITkHhD$CNJwCLPQ@6gN?S>U?sqQ8}t{y@|V05UtTBNh3%s z)hzF_9&PceN%qyMAXj_J?j0=PE1me&J7;5kB7^70Fh+wk`4B(uzC8Yt!tWxQ;t?(E zBFc36)an>kdgl6uF$JeMzWB>8x~1fiX|PZskin%)X)%0HhtAS$_abYXVJY`?^>*YJ z5Vca4G8-LH*0NT~K1l|r5+|$8%-J0?a%)twzGB`!ppjJri-B5xB*bsvwK<}vcwK}h zkZCf@4o{yaYl}to`rDbo$>0G-0#AGWbJ^sYN|tgfk8$jB;UYfM+w zmd?0>FvBL|b~p5M#>?rpma}s94`eeCfVS?WxlHl$Id`0iEx6XPk1sSrTPQ*+mnsc=?V8@I4s{6b z0Tc0K%@WOlxw4$tAy=k*%BtVzH^G(QB_j=nSdg7zUyM;tzfYGC%2$lsO)Ead;Btm! ziUUasHjhm``8+T1VIb02$}yKHvst;-uYHiZ_B*?((etw>7#Q%--1_vIW~_tfYdb-Q zomUE#n-&1@2%|ol8?Wh^b_p46RT?l+m^=~IY%m->79W3u_tAyYG&#f3*a@uX{>Qpw z^gq-8tThs6OnW0Xpj3ml&Kl9u7T<@M&GX~rF_qIFUfYbxt)bJ_Z05!WSC<7yI7Taf zUpu}vGW?WypOi7-Okn>3k4U8*7KrMaNI9~{P0NXs8fr*l`Ci%LALu!(`FR4%CPjX= zFe5wX5@&zI{wxdDJRRQ>Ha@ds9O=hgOEKIVU-@DS<3i6kZPP#Z1@14O*RNu7{?@eg zZTgvy1a_edwSP8`@GLvrjrfVX3~|k79Vm8YSF?{5WKE&Em6&qN$M1&ze&M&P$mSpg`!P8&j^oG%kR$ts$hX*Nm6s>3xWG56!%wt8+(K_FV!ZYRfWfKKN}*>yud46=)Ya#*M0?v#huh_b-v={xJ@ky z|X?HD7VkQ?wrwmJqxT&qMI@hUKX-e(yaA6usTCp!R z1oS5>9AFx)oPvq6-sOmKY%eW)uwlJxGvrbYmm)#xiQ{cwHs)yxAAgI$w))#^;-d#M zJi=aRr>R7j<<_jM{@TnR^K(S3igNTwYNYh{kJBDA1OQwBM*+pUnXtJ74%*BF9kVEQ z%sXoKHTXLV<4m-3u|owi<>;1y(_pmnFU9o~UaZi{_s?>1ujG?sB#hiQrgh|HWomv} zajAzf^qDYrSfu33O5=5}H74L?Uy z+hQDT5R8pxpro4P(CUr93MIo>tcMsF>W^XP)#f}qso&ld$ZUj*c-JG<&`)DNYs5d^ z1aDR;7qzxIYC3?|7Jo^wjq=goS4__*B+`|YaeyatMG%HbLfd>%8HiI&p74 ze`p@s%Rz`OAx<)cSf7E?=GPLE&Cq?VGOzTf{t{r5!ZF1gC173A>if2;uFv1Ge>0ph zDv$7;rM26#fPVsz(IBO@1O2oVh{%xpQ+gGCoU)jK1wWJMw9ox42h$R;*lp-o z{h?L461r&frNTSz4OA+%ajugx$fZGEpEDwBC}3q*uxYTrvm#P2i$9;KK$rc&8|Fnd zS+-mw{H-_R#cK`<$pK_G7n&Z zP*i%~){%ic3k5tmq2n}jI=)@h&Ccx~PJL7N&H%5Mx&_le9R21anm@Di#a+*)E`Itm zaOjlNhd~I3zh!t%+?@y~R}0hQ;uwqt7KTN{k91aK+GVZ51#M8;&Cw~2HdQ6#u^Vo* ztQ;_k;Cpn!F&xCv6j7P~$r+dtuYJ;$zDNB)*TQDDScab++wGo1?k5}EQ? z{XQQyV65s=u$>pPeb@Njuy$TGvtW?wfn5-2nm&~1nI@_IHEI&YW2n`_aQISym%F_z z>rT|&Q1(%3`2jKFOvq>8*;oEkvJLe}vn-~PPxX=bab|}O*}wr=^K(Vx!sZ!{RJ9Ri zN+D3YJCAcueTIOVRNt3}&Y*^)^mVcBC`ffGoJ|0>xv7R0UJR(`W)*e8jbG9*6g0C043ph_{p1d_U zoAJ_@d9P&I${l0ESa;#5;47h|e|C#yH@o#@cF9al=U-B9Qe`zb_&MmoT=GEKGtTO+ zS*JfJD(a#m{i*|%E^szaDTM-}qc93`4rOL7_-V5TmK_U^F6U~~^ZJb*kO@+oCVct* z9>=2`VSw@wZ40$7lfd6lC|%%31MWbIqJq zaE`)gXBQQ>*bVcO*0C0WhD+J7(w!iwz>H?WSbprf3Hk*W6Ib1qfNlJR zj7yfVejI+gyzFX3kX}v6WR-nuoFF;p-`~g^gWzj~7~mV{@{;(Jnn{+qInY^iSjWs! z?fW^!#wERXQ-|2_!o(8mjB;tz`hRIfjQ%-<<-8oW95>WQEbDhAYMhbdw z#gHD*8oWJ|P~C557eL)19E)R3FC%!a@ODQ_Trew!nc ziAu>5$kE+_n+lt3ouUG|u&Y9`l?SAf(UN?{+kp48u%nH^^Cnr+q3 zJt_TR6^uT}0b{mx_F zw9q86-*bu&|F)dxOpD}^|B?(i2GOR;&Lb@x5Vkq`idLXYLqtV8jCubJ@f%6=8eGg@ z)sZiAn&QCmA49Nag=~-xeRkRr{e}dxUvz75uUcdVGPTy{9r?Qc%&$^=o#vyk)1>}l z0SMQ+Tb!tlb0o`HEbz7MYAHL5AjLR-BB>939ilk2NO~_&uDjiWJ{c_nh!{wVkx5TR zT|e$Pv8)iodZ>joy^@%bof0uI{-IdbtnKbd`nRX)m}L8to&s8PxNEGXBVeio@$D#O z)y3mdTN?SJvD!wu+l`<(o1B?zcjH~%IpHkYq*i1(fX6pKm}HvZW74HUNUE+sgNAHo ze`)a$ZS|+hY;g*{yHL$l@Ew)&a)~gDR_B_Bh+s;XqEyQSbK0!7jXA$YV?5KZI^bX> zoUhHHlXqFhkPpo#S!+LR4;UvL`O?dNTCkbIR8Pos9K9KNTHJE>#;%vjhZV(oH*2>l za(Aj@#NOeEP|clzDzkV?sevb+yFfY4y#0S`X*fvOf;YJ$3d#N`51-TcR!Fk%C~rgk z1YxEAG3Kdd48Ebp$NE?9jP(M4X2z3{LT9Hjn~$hG8dbkj4h8M9%)lrsli;yp{1V_* z?sKwu?rAyNzaZ`8jTc2*(Q)FKi%_N3og$~E40jdOiyH@H2mNm?iiCR8pSr#9Q9dq{ z%1zu>L}uZv_2h>we>Zwj9kI&B#md|pFS5DTo51V~_o8k3-p%V8sn|n`0{Qr;2R_?H znx6T}O*3RW&Cb?nK*~_;_e%md8pOPh z#nriSGmg*1~AVn9d72i;wXM9#LI5jytsKQLSJogXhb^M@Cjbkgb(T@!&_6WtH zu@<9{hs8MarOjMk!TmEtig+DM@nl4EYd_VA-N%6NP^|7e1dlICZNJ#hv2kd9*jllf z$-4b}#HNu?*{Y!#|M3SCeh;=5MXca$gH4?5yULg1iSeH8q%joaN=iTl!;O`ZxQh&# zh1W?G`1q%`y($5>!$UdcH}^7UNdpUA+P~V(7(n@)7KyVROHKw^bUMu5Aj2l`C1xN7 z>-f}|*5%DfwU}X=^kM)4;1KW9*d)%z6b%$XfbXHiShDA$Egne><`Gpq_VQs*7F$?dFzIgouwhfPd9`AaKh z2a6)Up(q{kD7L4$AK)5M2RMS2T8(SO?k5!-3-7-}PpSW2Y<^iW02hA9Mr~i|G(i{hV5<**cei7} zY^he|C+Ad*Z3?ZGxnzH?17=#{U!|Fe_7;HASm@fBaDVsm2hC-P&0}x>2Mr46cyU1k zim<5`jjdhR&hyt{lyxkern4VcsbAy4iY~Ie>fe%GAz*fK3L}vx6JDQ^;^d-_+$XEc zOo!cF1gasHstL-LQKnSP9E^r7%kN!~2ijhrrFupZ@i)?B^n6MS+b2Dv+k4|h`6*-5 z``b3M?2WS{>Tg9?I}vlQPKgDo)cK2c3m{q-^TF3sfCsA2?+vq3k=rKgRc;4*p>^cK=>N;4jnbtjuC z^aT+swv2)h@2)>7>yYQjj{JFxkSPPf)eo@H9=t&TvG5v&4cLHhgjFQ9e8Q!ld)YNl zH3G?WGgQrxwrlTQx?r=kQInul+h+InT(s^Ww4-PYSqVq@DXwgz8(TGk&HmxE zHJm3qG~c}*#n|;?2R%ceYWWLpkFj(Rc6=sDe$2bsd4R z^zbmlrSlnAX64>?G`!?Pd*2RV1*VObqw&;B)C2ej-sj2g#unKK3=#>+CpBT8vR4L! zHHZwZkU#ZNe;|zB&L9fQiB`d$EuO30p?-kQ@XrR5-fl6v{zW`oZLc~mGr>_QgB;z< z;HcCLy)hdr`98Zk0WM6{M&sq1CiLe{7wY(~(;Ry#ig%{@$)nQJoLA~uWo?y02S(eu zOHTE$KG(NR&%QSwo;WHQy(G7;Zu#QbT0dA;PB0%1Q==VFB~I{2$~{;uNQLl)r@NOl zs~5-d8RW0zruWu|`bs}Np~3SHUmS@o+HZwvOX)aP0x^{Aq8xBMslxUUn0@UN1E4v7 zuV}I|>nq!xmhx5romGdxNJC;CF|yNVu*$ z8Rwb!Fb!Hyk;TBJuTdsYD8pJTS7B7w9bP+Y`VXgw+t>%v=)&DITcbh*+fy>C0QjTcm%zMkZn@qXy!6Q9Ds-x^({OT^=)4+%V?`)45CFvEs2C zSo81Jj>uVjj+~;&fOkC0i>k0h`c>wqaL&KhN3o=&i~$|HwV<>`V!Oq(Ga$BU8#07k zI3=+#NQ+oO>^W|3TT6L1EwKrYnDb#wix#P+F~gQ}b``?vzq7_VDnDb+>CMMx$P`vu z3@KO9Q5_u>T{Oe`d{!2z?Acw{uT@6KoGRSSEfoXbKuql$yy_0W#c}6V>YFD-Mt)@F z1P15HX#4yBIn09l*Kz(mSs6B2L9(B52ZveYJ1kBc_ak9tHSR;%Z}dp$U4BUiF5NMu z#`_S;v+=^Xa!kfDz^Fa315G=)Wjmsr;1ud0AM86`z{6VltLnh@yXv%m+_YUbS{x)q z&(fapP%@0ZZ!A6#Cl(c&Szc`DA(zN0`@HUTqV49mMs z4kFy3*CE&Qe(Tj{)tkRnmJHV#O{lzfdD1o9xR@U5u>Juo2t8$oW^h0(aLqbKOy)D; zsCB3H9YWBVFN~16lu2NoM|=uW2u#UfhR)3oiSL^IZoQ*AZ@lB;>Fi&ov=5M!Z<3bE z-rUmT*wVV57?}5SM%$zwe(AsD>24JHQ~#%B5ho&!FV!HxLhDA|TJv_Oo8}k)YOfC6 zm4J5oLY0q-v*h9f&w#;kF&d?4wKMQZkKnaRNp#Y}~e=K*k#sehUMRcxc1S1(sI zS4$jzytu&z!)0h&M?Q~fh6%@OLJ56k==)P$QQT&J6&+@`I{|9UQo>Ny7T*JulYc<8 zuG+b@io^QfG87C-(E*38pmU}Sloko7d_DlLCaKJm4Y zD6Q5A0UVKDa32+mRB@~A9#?*Jaa;Oo)BMnIE-*=aNi!fuR7&hgbWAWnj*p@vAtJeE}){NNA0-j@RP&}T1~!kXqFt3z&1zSe9bityc;{)%b$JR43;I3b*K zD8ija*2TXNI+UJ0J;nB;jbap@l(>Q_g(DRlU*)(dm?qpPjwK&@o(70exNOKfsopro z__k~+{KF9tKq~O9elv%p!FK){ktj&D4;xe)QQPc_O^5U%YYXnza<4udGdwaSI2*=N z4bCQp%tiXro@~qH0-l#T-jrd3EAR_bvQy)Uv3T%wZ&vho2YBrauL1;9n_9^#=O+N z(cB;Z0Ol<@ivfZ(0-n=uFA#7v=~U_b^%Zt?OI2S@Ki=Hk_LjH@WIHQ~!Dp4syVbW$ zu`o8|I@H1}#A&5$NRi&H{%*yq(@r!cVYl`!Tmbv2R5qe?on?|y<|W_m)_~K*BWs6` zFOeU+Q=o&%7S>%RvTvu}%Aj_U4~L;A0P^z(pXVhNGCWaHBFmfEq3x+hH}sbX2?b5j z-d?WSVWi?_JMWSq9C=Hsn#6U+oD$*?`&gG%Na|?rN14Mh$b^V;#=WF1L4DMg+hdP# z8kYF7PHx|V67JtjQ|%8lCXi!Wy})B~aK|$AdTPmwDXo{c^86 zB#}YOc>E~iR|X)sO-gK|6mJ+2A`#^h9J}2`^{UO#^F9sP{-U;(w4HUhrN+5<(BsTK z(f%4rM$d;&*kIiTE;Hy==mUuoXAe?s*=ESicU!7DngJc1!_E)DtCYZP&vP4i3RLJ)(8G0(zx9Qu1 z)_B&FZr0nfo8iY8E^?5MFApy0kF`}^Jd}g2iELQ5pQGZ~=DZq#$Zodtjz@`;nK|s_ zT{#!s#;E(Fqa$f|XEEUsI5{`O$g<$~ilpsYHYsXr4W3hRL|$#j+j+ShMhl0L7rou4 zFUKO=HQ;h@{iwM|r}1OcPesLX1#Y8cbjX(vw0_kEZ8FNQBR3V6)N};ILB-vY_OE4I zYaf+IAGK8`+OBR#zQDI|n|d+CS5TEaq89Q_#r1YqtiDt)E44AAgPJp9mz}EXz_(j1 zfgxX#oK@oP#5)yN^H*Ur2BTahC+g>uw7ugX0o=s>BwVZ~^P6XZm%jChw+ zrS$7=t+h_$N7Lk0>M#qm%c4HhsT9qXY_bA*cdkiCnh#cK&1qpkx<1?UV=4i|aB@H` z-AVb@s7s*CvtJr+Q;$q^ToL#@OClo6lt_oC+9CI9GiNiNhkBasYK9WsCk1T5j#uB}={ZOUBKbr|hrSPf9!mJCFA;y{fRUFt@f5#Q?bRD@U~7BI41F7_(?y_Fee znAGS@yBzp~=wDS&qeMl-$Z_vi4RO?pMvhKJH5ZJm$P(do2<1>gFHN#u()22<+kowK(~gWA z^m58%auJY2-Y@*rW3zkoE358Imovp6Fi7K(RQGX9x|pX`K+lmJWE4IR3Jc8e$h)nT zSbGoIkXl?Dox!u-DV#>Z1kn8Hp6~PLuD@8McaTSPzw-UdJx+{lj8DzePD*wA#H$~| z^Bx}}S0?dcvF?bsmni(=p;tDe9@8n9H@koJYF1dYaJJf8As@Au+8QO=cF^ZqJZpf< zHwFinfZ<|I`=mqOtnqo*QZ=>O9hSvm8!VX)W01%qA@++d=#Nnsvsm_4?O8n3$7wTP zYJmGpxM@u%ujfLKsNn>2Um9$lQUu)ZGJ5@MQ#un-SsXM~lFj_jOh(1mKU+R>-=&fF-e%M;h{tzUY3!nZ_BPly(;|Ta7map)=KL8ecq-=h9&cghSelroS2gO zIiYNdUs73MS$1@;G`uNMkX-SEE$XG&kni;RQ9g>TG=`*us;-P9Z@T5tpe0EBE!)%A z`#(KyD`tU}sV%-D_Jgz8ZH^IS#OUp|U-HC@?^EWs8KQKJz31FwHLf@TE2{yaj*2PT zszjZ?rj~6=(vjrNyYeW>r}ZBuV;0{nJ~f1UjK=K`!*ht)IUrm4xn)t|s4W^=CBjUr zs}o52);cSSrMu80zw{+*9iM2OgKZY#bF4O-ZNhYi9O21+LykW852%|bGU@kdy`KZ^ zmZJ`^#gW)dx}=BV?y8ciuKnsts=ubG!>dtb>C3f8Ygm>)cjjQNn`*00zr8`Q+Tz@< zcL@WGf+<4*jC#17!cJrLYR%f_%e4N^XIWzEi-dS_t_0g{+%))=h!q$>zJ1e|Ho3$_kg+YYZy$7sxkjTL*tV^Qb(wJ}Q7+NlE$ZF5 z^mIq)Rm=2Nz1pXPf$EK$PqbLA@vPdIm_cIa9RWi$KmTd+^ z2+c2Gd~H`qGv%`1E73%obclPr+@v)|ivu>{TE;fqgF6O=yiAZ&0t<~b#3npa;M1pwzAe8W5p${W>2j>n)A}H?_%`J-#_veCbZ3mv zkN1N@x|2d)?eWTd6K3ga5X|I2bsWG`wN#VNDZA3CM@mlKCJkX~HsDK3PalM%Lu6Zb z@PcocdA}LnqHV4>6`gf;L3!3(&y$9no6$VWL)sHayyS~4#-%rjIL{HX5*|`;ORL(V zjix5)np1DXXgHUBQKZFiK+v7#W2Ip~o zwk_^r$e$2f&v2y(*gQ_@Ie~avn_IxV+}am$oT_^LMLd4r?za@!zF zj9GBVd7I0-O{@`IkT1i28^k1BP9U?=RPN`qIhv{rXCcV(XImoU43^87YV-@|&3w+Q zK$<&iq>@G2G1D5DuiDDq+pY-{(g!YDxnw*%cYL7rFYQ{V^#02)x^Glk?j~_0L}}47 z;i}}Hohs7tl3&RFwRI6aXuisM9ao68SnelQhvDR4?+4OBwp|Fj&xht$YKEbz0|K4F5%Fm*^W-QeBDL-RVxXnF(&w&rBP9`BFo;X zi)4p4xtB23I~z1@p#`!#f+#$+V`yo?WJsJ5DigNK9>EgGsG=l{giDx)qheJIZH~Mm z$CT>kA;ycXZoE%oWF?5uJ*rV;*<*;0DP=~M%tga^NO7oW+mvSFVZ|YK{{VL(>>nC; zciUQd7R5^Y>0m* zH9+c2of}}DIS|$PA*Z6-MV3jPTyiH19_ZAn%*&DHIM>MP716n&vX%4j? zPaqK|JHRNr_ipSeNs_q;kdF@@=?j$GR}+qDsNppbRG;6y&eZcQ3c%b-V^v4R&@^$k z2e*DEl?KaCH2322$D1J`qK#()BZSJb5sO;*U(kw%LN|aT^;J3pz8f2sN~x$_vc=YfB+=m0TT8m{{TYzchD}5vZc!R zVutznyAM7slZd?trMrDzn)^#=>qXiywDq1FqS(c zZSN2lvM;ez*0HwBS6eO?LL<3IQ-12WHqcUB)jD;y+$3m$+?z%7muBzZRc@)gO|{Bc zjJoS=x*Xz*=lJRSNSl!RW?IbjZlSc?ZVgM90wOq|!zbQRR87deo^`DQr?#mEOjVEoSiNvP4f@4)mMg*)DYxsMAaXC7J-tkgfH;TtPu=;>&q2$OXG{_M+^?ayw&?^z_!B@O`u` zgGNocjsqfd`74eN^}ss?6L-d^`RfB*ylY!VWpBvbvWKZA;$LW&qq5qO*z)ASWK*z_ zgmIPli9H`JW?fZwxLzA^q1N$AGo36p@d*(1bmRP0ZtZhZ&1PzNz03&^GDaS45M57F z@2snQ)XYwv{F{UbP!Z84jANk+VxptzCG;(PhfLF?y`Rkb53E}YW~d>nW85veiqy*l z^Gp{6TB{^BvZf%+M%fa*;9C zQMmFPi_4Zn(k1Uu1>wC*aGor<**xKvA}ohj$wgXa$F&TUw-7vt>2J-E$U%J>d{ynK zHFepRo?GsnzPdNKP78p`AehMw`;<3)g}>_h`)ex4Nx8UftNu-yFC({K)D&9uJ33oV z!MR@TuxDILndaoZ%^%$yc}p6DylW`>e^0DgSo|K`f$=$QkrS-t%sG3!xYa-Z0QnPl z*R~;(b+Z+!Th0`4l^k-6;i4krxmO`Q*)OF#J9kqUmqzb0hQoA_EJL_9Ofp30OWeFN zA8VSoZ1j7TnXM)Ed6CC;9n^hHh-;<(WVbF2Hzn1K`^suLs)?5VuG$52ZCYPS72*ts zE)yk6j!_(X*5hNt66z6fx=Y%wF6~`49P6E?Mbg`%Vn=wGc&<-=6Cj6_v21y<15+*>>BsyQEcnPHBBFrEUKJ?<`xzkstfX85h~1IXd<%83sCM zPTK5z+^)9COOx}sj&%J|8j6zj&h;g?t`gmS*He0v&Z?z#$rLKPOqn%{n(~pDKTGAx z=`+SPp446uBIaV!+m9F397JkPM$ohtmA17q$)|7Hvfxf~{{UE~aN7R!hTc5B=@nKk z>Gz0~hsudIUD6@@MO?~gqru^j^AmicBqf*GuNxnvkL`<`cCbqJ+BF=>EWX2WEYVQZ zh7GbaY_h7Qmf}QCZ>N1}7#N7&q3`C9sySQM9yX4%7SuWVn{2>j#oJ5=vA?ohY zR98n4<%5-v;FzREA@ss-EQk2<#dLWH}Lb}FjxW-|Iz_kPbm>{mKo%_7TlOS}uL z*f(G`My_nA@s)RWpK2<0`j^y|Ry`KZHpx(~*K>VFf<389*%sJmGp6Cgavqhkw`Jl< zZ;6zbF>*-{y^bT5UVAw3sMMOnQ$w@dWm;_oFmaCa(8!x{?(2PAq+Qy@;;P-cViWGj zl<}LdJn^X;t{HD5C0^B4RaM@;x;@dmIn~dG~Iu53V&)>0K2*;+E+*N5O>xe;#91na z3WeB-+3Ni@V?yxpI8B<*vrB}4yOg5OLQk@&pSHec*`$0Sq35^1dE4eT@^*QJ!Nq?_ zrCOG|lznGohA)u|;@fS#?>Ic7Ut~tJ8%4M6%l(#pOs4v_{{Tp)i;()nmwA&3@GMQa z@i(GF2n7eh3jXgD>-s@?<%B5D2 zz7l9iVZ23<<&U&J@cTpeYMW?{$~kH}-Yz?o`uld!`zO{GRJg91I7;hCZ9N6@0bT3o z$A0h8RPD>arUk=M{5z!fi*?dhLB|%QAiW&@X}DAs6XjK?uL9kCx|3%UH00+^XW5k$ z*>QKT2-Q_LePmzGl=H%eX}Y&&ngS)?Pt8b9MxWyITqcFVP*7YGn}3~AE0%`h@(Zs` z_xYM#7@8iYB+cK!LtVs|5`DZam(xk1B3qWpD2S*kil&HYi`c2=-M_QVRBq99PNcN! zi<9^LJw6mecQKIj3JAOXOS^K=6JOE+w9O}T19>f%NG6J4k4ND+h^niCC@Lo8Ue7nv zR+fzZmFM>)O-ZI4DC^j2%jo!gxETJ$uEPt=ed7NBqb_>Kbi16c?$#D#%xT7?5fK($ z(#Pj}G+CszoQDm;d8B$Eil-3Qc_mz_wqFLlvD@2`bsy?H=>#~*hxj=*;d`aZUpk;Q zhQpUi+azioFL!5E>0}+tD%?n>+N!TuqOR{UcJEZ}Pg*5bT(QNJ#hT5dT!2v5QD?}w z!#FLBmzyCO+)PE!o$S;c4;!&Gj7~r%-Q@>|&q-c1Oqq-!!x3!-e7HEPycf=tH5*Y& zh~to+Y*mmg;DzY=&K`L|3M-CqeFg1|4ij`M3 zE5r>uZUQmm6qHWXUYbA9o+)y<-EL-WGTfHMwU58dqu%|$o~4Fn%M#_hjl?Zmkj;+` zw_7F%h;kqN)moOfo09$!ZS(uz9T`e}m<|sLIJ8_mg~4+n>Go3@Id{{W?1Zjnk97kY=Ps(prs`nBTe5_NxKRkl(bA<55YADv#4 zc8x~c8S*@Y3qM%-s!rlX9?vjGltr?xVh=YAh;bU?{?PmrHL)zbiBL*~2oCd9lt)NK zcheZrM-}D7vF}u!y4LK$bA-vbUV#zczOk8=TXCj~w4Sv{9C9DEMfRy4OulVISoUeC zu14C%vrTYGJFJra>a%p_z_o;xL@Ohpi~XXY{@Tx$?%qC5Tv(;A8WGm#TOr7Zii)gd zlWyh8(20H8D6)7pYBD|V11zrZZC7{bqZl3yml0zJrlV=OQB|;2th(x!(YG?krjeK6 ze93G{*RDc6UNBwnc&R|Q$jdB8o8GP6Z~l%uWOQ(kR9pzm_o$3yAKtYs7y3Rl^iuM- zTbrLWfj@$*N#a`w`-)NNQ?ZgQmOE|Hn_)I?A0l;!AH5y*IL z;v@B|Ra-m>^xQ!lB;4{3HD>`Nd|Me#ET1e-qjDIWeX2n-(`?&aogZWop8nBCiwFUm+e-|i{NvlRfTAp zy10tF`)+R{eq=4w?0gCIw1Pz&LA|+m`DqRsed^v{5Zkij8n^l|X<;g|<%x}LQl7St z1qH~%jwB>P{cRIO0!?Lsfoo(IB64`VXtC8~3+X1n3~o zCDU0oZJ#_Nw;~sg?}*k)(C3aNy?=dfJD2oU$(c1Iv8yG*BwWB%q%lSO=}BHcqf1{% zxg7R(l-t;QS|AahQlEPA_37 zk>WIKY_-VKI&MdeINPqwZOFxl^A8Qd%C-Lh{{Uu=)ogGiz_wo?%eTH@qsQspR1noK-K-G@LF69}f1W~e<*_LX1%~N`(V%4R&iIjqn$gEVGGV)q_dT#2L zsrrFNEQpD7do-<+W17!FSRQWqrLK187FwS7qnZxv5q$f;I^;x!&Tgt@Wo{+PD=#9N z+UU0K&~(s~&ItIc-6K`5(BiVy(jHPA=5N`nKWhqut~BZJ!+42#Li#y>&#tSTp*HzE z%cB_xdspJGN^DJRS#&FNrpeasx-gLhVpawdd-COG?5n>1F5ope_v!L&PjKtxwYIietTr)vIXtqWjk9*?r!t`}_=Dhv4d z%vLgNMmXim^QvrikYl(nyB5y19ERDBQDn8o9I|AGiS>Vmp>#E}Ew=Y?2;qqFUBm=& z5FEKv^HhD0Zu)ajUf5tF1YmLY+YxZlRY;3+{_334HrW`=nJJeCh~wNg3zUn>7dN^! z9B}3EC$)}8Y-k)oj`i9?GSVuWg@a9}QIh$QD_FMZ>FqIZg2J@fw7{|~tKz~ti20K* zlh0Ic!0hp_7ah;E($5DoXp}NG8WG*Szf!7}s&BEJ8yicIn+rtcNxW{{D)r%tfOkB!Sk~i;2=K8t@h`w-DMBIzp z*+Wsa6|{yo9@61)Iz)(nTA{Uc8;$E(g}L~LX-GnC3cL%LDyo<5Ri+xtOxzi2*2IVW zU`(^(BJ~&ZuAfWO_wY%H(D1}NuftmAzxC`8c`xytkNBP#ip7JjJNh6`s;C3 zm3)lE$IS0hZ8eRTBR4V?dPYfvbS=tbzqXDnnvxdO-)2pZit0GKjmT<*jVGl^Js%}w zO`z)sb&cvo)j~+@7#;py|t5nqQ`j4H*v9NsWraPD?ncj>&3#=V?`Z5Y-<;1yq&MQyOnpB z#lOw_w@^;dwSr$sSZ&jBge*|^a+>=?@l+3l7S$tJ=eeYr0Ml@-$e`>7m{!d_Yq=A0 zs`>Z&cklGneu1~@-K<^Wjf%xKTTO@ny&EC+9#m*if=?W+-4ENjzKSJo|QlNGe9 z%v}RvJA=^WTpfDhAYsN5@(q`GvFQ@Wz24@kBUfv?-lElzu5#xjM4d zs>#-FMhjm_N?^QA6SwU?*1FP`y`ouFSd@c=l-;>kw-F-v_wOH%m(^7Ls~Pd~VmOa< zxxW|feV(ecJQ;O>+^v@_FM3se6wp&`9PRDhxVt_?75b}H-o3CrqIGPPl0xwVst?C) z5j#D-$>vJaopN5Hmao-hjyYrd`3`NO+}E0pQXG=$a+7g~9KGzSefbVRkWHr!Go#Eb zusJ*W-8Bnpxqtx1miFXxky{wYnYPC#%i8(BR;idQ z7s$9w;jrVP+kC=Jl#=2*^+uy5y5xpk;*oHOv23Kdh?mW5d5WOzTKdl^<%bowt>WW& zjFN1gb;17tKRTst+_E`Dqdy>VAIna7es@UIo*hCW@wj=9DKQSlyRjjcRfvUk8HBQetg$^0y`nSN7lb38+bMNFWmOZ`}(Hk zH7z<|9Vl!Nq7;b=3%T4A&^z~zOX)>T%euF|&$el&HgjlLT6WmfCim|*7jnAl_9pW$ zO8#U|LMNK0Rmk<2qP9WccNz;Lao0`^`3uaJc29dF+utur{MFs@t!ao5S+~V%ELK~O z8Ii%ZSrit~4^T{To}7x*j;{rJdx+FA;#b#ccV8lm*M%3ooHuq$e zEYX-3oUG2K{magXpsK6q=Kw_bm1|e+%SVQ>(s#=vq6)L)Dzu7c-F!^_O)p9L45>TlR zJ7v6^Utg-W&eJwz-h`d3E>_%zOnZmJP}?4!85Tr8WQwTWZ=&t@D7V+F(gtI(NN&xL z^>E(m68X5-cWF0L6@RG4likS532DcI^=_AL^MQ`>$$K~d06#b8sV$Xl%USFzQuxlD zx<0}p;QhKms>PBU(Nzo)cPcLA_|wp+v>T@;d>?si(78*IKA$(~-Ba4y)8mBP*2A#V zn7%(%SY^p~9-9Kz4sGJv&ZN-tz7MuILN^OeUS8VwR|hWheMH%NJhT&dhB|Lh+$~Q& zb5b$y7YK3Qiktf?h}cHhr|er^N4v5IBHZevd!4xpa!F%;;@$5eDyXlo;#EbzwyYAb zTEOpig7ND`!Wg|Cb7Dj;3z=eHXY7?`O>U}tSk=^HGi!lmZ6YH@9n}0=#2rcMQT_Ck z9k(|I?hK59jz$@igq*#;imWDSyOB6PF(NKlL%ziq%6?~{wqSl2wv1bXp zVRrRAU$AQ3@u4R15~{Cbu=3NwWxJK*{ewkUZ9v>MOhkln?^UwugfefDZK5RZT)ojM z5Y!u2h%7k}BNNI=yR~^}7WsaWba4EkaK3d9BzWV@Z+^9Vh^;0ed`+Z{kaCYz#)vvp{kX^Oww|Sh^caZ{H;{~c zy`CLJ*0$^YH6E7}H*hptVb$b6{vo}WyjKUmopQ$6HJBHY)VHzuvz`8RY+8w#_>(bN z$Ggfr)aH1-4m+YS*&-q3F z%bo5VDEijo`I(|9K3rbUtcz7OqE+%cghgv*@ypq*BjB;3+|?S0u2QX|8-iGlBT!yg z+AK`%8j{-2Jd-NqZ(WSUyWXvhJ9w5-2~(*>2Y=?Pt`?WJwjy3GBzX6#zR~1XTUdx< zA|sDYO0Kk+JMnOF9_}^9+sUR`{KVXJlgB*VSg>K+Vv3N2cE=xfn_r?!7%q5}zZ84S zg@P{X9@BcrrA&F3@g7{Vr`O$kXtUcSKQ>x2+ew=CWxBhqHd|Ym^MvNF5@ z(zQsaLfbXZPN$(^+ZZz4@|_G9f+nn(OYoT^kW zBQc|{!Ep_yYunKX%*2D|Whv1+)^tJe6BK1tZ{h7Ni6^(SimHEYDVHzA%TT(w^7PYC^==r`PA<$q}9 zHM=PBQC)F2`p|dot8|JvJ_TyZ(6_aZpWYCD1A2sYJX_t4a)U9qAjNzX(SB*C@BH;# z+dKkx!L_qxB3YK(ZZci-=e{j7itg>lseXj5kF_5UYZGuLMz(1QcSxoR*}We?lrai^sk9E99633^#QuJ{VZWcx@F9+Y~*Vdae zx82&l-Sp()o=vHvZaS(W-dLw#r?ymUt-J{DuUKymy?oQL5STXFazZ|U(|gkaR6$+a zlhwYdxqcbDe%SV-2KANF(RtHSvMBSsNM0uSxkQL@`K3oz4-U6=T!(Htw?Mn7ZQ$Ue z3u=8^28Yb2!ZcMr-%e@rQ!1@PryDlle`a6N_cnF@<*C((lVln1jO>9XqK@v(;ETCS zoD4yl!RQWl5bqBN!V<$?nYgY=9)G1cr!8*q924qV=lf^lqa) zN4ASDyJcibyj$i$k03}&8`%pn%aoiiyH#D&Tg9=*gB{Xs*I~FV@dUYvvUACnUe7Sq zD_cnJE|)IjQcH(|+i*XuHCCV2GNee4kIdEj^d|oR zI($jGkDSsvCrpEHxeg$xc8sY?uCLOht5P;iuW5@EnAaBXdT`YXdqFAiK;P?8LKXd$ zG?p-`?NT+Fvg8eHQ~EERyk6FR&}XYUb(KYp0vYa+Gxe=^gXhlOp7OD9~g zkvGefnV9{W<(kHPwyCGoTzL|A&0A%;JnW~f?cSekv2Dlf{A|6E636dUWJ|WQzda@> zX@7xYv&zqmyezc$%@+Rvih^$58kU=DoEYut*GRfXCD5F8&#Ks$>rmFD@Qd0`tjZ3L zx9R&+uQ&eE=Nac)A9#62r;5gJq=}|p<}Q0io7GI+SpKlw?1`sbVbUi=w0 zv9t@TmrYr=7dWI@6LFh!7XGbWJ2cn>`eFlYKdO%RC0<{?V6N9jVMlTvfcZJd(SFX}WEWf)R z_p?i!!M41E!+VP(o5hLG!enFAPdWKE>ekg>4ceDX>PhisNVYg}3|oOE9YlnptJtVU zqq%*1_ip{M`dh)S@jm2ZHu+~49>j#C%HkBF{FwPNBH{engShksx3u3Fhan=AII{jY zli2s09_wTGt6uES!RWSev$yrWp|_30Qp}Dm*zKi@jWZ%gD+FHAb7k(;BdN9}XVW6V zf2ExxbDZX0G{i!9NU{8$SmG^@vs%BwS6xNgKAo`G-?=w;EhG3zJ|aRgGRR{VCC!NB zU#m|VipGJ^8bUUJ(KcN@83dcQMHY%kSs!$7; z=GB!IM}Y4&dyCCYRIwGtlj6{NNGWps+RvLriY2CY@eli9mi|B9LLcze^e*z%5w3~n zW^I+0$B0F5WQ20+$`haAqp&D^bcZ3xb6JqtV&fzxCGn*l>cRomcUPO`Y3IrG{{XyP z)ASVqIm+fVae8>3lL*}A@t6#d7se>9{{Z`SpV~=IMx*dgxx%>V>6Y0M8hPlBtHP_) zf=Hje7Of3MW`^~Wa7Ks2xQ}QnZT6n;QYW2jEciR>mW`PrFjpjRijTw#);{sUCKpz(d}gz>7P5|A#^T0o&=tj04f*d>>|c>;)7eiN&8)RG zwf1-gB0PDb-?TzpGAiW!MQ_xzX8KE-OES$*)CO&3g>9%3R1td|f@!%G7jmMH$h((+ z(^+cJ)7oz8Qzc5{*zIpQX_U_dau!_aaJ{Zt+B;U#=Y6$Yp~g1F;U=dvWINRqQBhGo zR8*{iuC1~(=AhNK`Q+X+EdDp-3nrW<`EnBbHSdS-R%CVmd8RJwzk7>Tl zYBS>^Fgm2miRlID{S;7+jpEKDkqJECjER559;!=SX)v_utv5ZCMB}nR>Y{ZoP?zQ} z(NX$ZVGI)HETX(F2>U`^LtQabA%?8evh1#j$!S?}#>K(-#eB>AYWVP1+Ju?SrqAWN zZnpGXx@mm1Dfen+f8M{8vTh9s=AzLa0K&of5og`iqDw8>dr8xI@%e5wz*PJ+r018@UL- z^_wB;s5Ox*sz!;GCzqIoxQ-QBdNN#Z2rW4qSfi{rHcMg?#6=L2K&C27yVi)CzAe=t z_Q;8}%^ysB>AQVXTUgmfXjji7veOoOd<$N%l@8ZB)ces(BH;45?=~aURa_5Bf35Ww zoBsgIT>k(KX?U6`=vMiS1F4*xQ44zV?pka8K)C+^zd8Q^jI9KPe7%HJPa_#<%gE{d zCasxsaR{-Y>*lI*V}xAT<$SqWNioh_xZ-fBiixR*mcP|6O2+94KI6`jb9bs#f`WBh=_;Kqvu7n1O-Ptt^>6P74W2io^rTc( zByq&Zvh&qXrL>j4MCfeAZRbabJtN9BsbyEbWQ(h#^&7#HwAK7Yl~wC5=)0 zXut67Q=58Ppb&PZd%G&il{|z?o4Z+~!WMg<1;t~7Rn4ZPUPc8HGaEOq-9pC%nFXF7S1jxweZNmAu)Z{sq_J@)j!7pf}Q)CRL zmH2x}L-Ncy4Wb`pd8+r=w$`Ct*lniYLR9Q8ioSP9yEsMU-XS0FDDPpK`RfL(7ES*E zBPW-ce8pVWw`=OHTC2%TtsY*K`1Y~%PhSwf0Mwm#A3APTsu|0Mly*@Iu|!|*7dBOF zTkj)%(<&t3lcGM=X04_?-HQCBxW%`9DZ9^DJ{i>JukzD&nmXp4AI*%gW0segFGm(~ z$E0#1{wfO9+al9XNC|N!&9uo=$Q+KoWuSI-1H&VU7GsQM(pb@Elxq2ToY5|zk&m+5 zBh&S(=vicND@hfyP_o5vi82vbH8l1wXi38ol+q6neMCNO^wU&r2IsOXOeU?2H2sn@ zh>9wynkHQ*>kpH1{C;Zo(l%`bY~D8fyNrooo7z<2@A`2B?5YjjTW7m7D&KN$>BlZn z9%nKyn@Gt{MznHu7Zy`YTD85yT~2!GGT~$KVTMA?NRKni)ygCFy{xLJ(AL(8hoetm zNkj*5xE~o%u9w&bu#)$T$FW`h3thyH6AFsILM2#QdqBIVUwy)ASh(UMd1O@Gx~sI& z@m4P-yqvKm+ld^b__{JJG-Hkpat=PO%2s>=hhdXA9C(Sb7Q}eOScteJJw@))i5zhf z9%n?TGT~>8-20U911E==RSdqjWnI&mdg#JkD>Q6dcCASKi+qT{=M6;x)!e>u^KPO7 z)y}zYO~j479ece#stVK1VB4Zxu9IRURWAvl7m1bpU*yZ)_G|0ajZP_v9iU$koeyx6 zAQ4bw8JMT56|A$T^#;u?u^Di!dZ0*KA1=^{{cNh)@SfKRVcEocG-E}ddp;5(ZdFFp zc7H;mzP7-ud2|&YRv#(_Yf#Ym9j@RL|C1I+Y_L z9a!(OE=c9CM%Nk$NJ>75xLnKHojWYPsEue3fwXZC4pewCpUb>CG+$fUq^?l(mYHvi zcJ@Or#gR9Nm)529@-J|87wIbt7S**hR|IsXfP-qwrS{?ZU$j(b9_6+B%yW@^zYd&& zMU%w&3qN<5sT;1+r%28&oq&y^b|j(^TDKDODmv++nbIc(lqMHdn&vW>KB<^&Q0McwON z>+#i-a*Jl!W{{P2AVh8x3Sug~W#$q3RZ%{g#dD`Ua=A;m+#FvUkl{Tg#PbGAzXbT# z^j2o8aIwdUX2w3F8bbB`ARA^&Yk>=2U1BObxg2>SI*N(sMBLIZXJ@8>uU;2BTGnrI zZ@P`^+A?Zj4-(*rzR3BL=&Gfd4qTo`9WDa{$~cZZG4@1%ik#g9TDHh#k;j)R<)O4@ zR*>mY))rr?p>bZd~<8iz`WSy)NVAz@_{{S}hAJgaf==)~dYSpE0n$#UjbGN-I zt`wG{&2Rz6{X{6cs=m-0Ik@dtRarFExWo!UZ+*f}R)G_Gn*5)qu{|}VA%923(%CfR z*|wI$E*N3>MX6V_?bkV|YdU?Wlru?-6Sb`KJ2I$`gZ{X&Z1-r6-gQj27`C|(+jZ%J zrigpdQ-1G}Xp?LH5gR;#X}I0Aefka0THG8aA(H|^HvG9s?w+i9TlR9whOR%+?X)|@ zMPT-$W_gq9ZP<>6{{XTp`)Fcs^lR-YzxG{Er)B)EI-j&F8}02qxK5+e)@9_S1txnX zy~7sYX~7oU?&Krg`LuuF=eJjOU9>ota(vZ?9L7r}zj3syxl?zgQ9D!PN-%f&J2uLV zIMZpTUPVpW1=R=dqEtfv07tIdt?sLVVz)?A=<$223aR>PXf-=-U1k3OV^-b1lj<$H z@fPa_B-nBWOgRmU%{&XrU`f3f^D0xH81kJ}igpR4{3W(1<5t<)26m&;L0Fo<$J>lQlm7;C4BbCs&Op8v)ZzkJpoZ%1mW9@rYm!djNW<=)-CBXnZSu%V@jh8Ts^$BuDR^XQ zDR&Firc4`CuM6O&vJn#!RbMyJyY$xs;oFX#&rj+$FJsC)iu-HYeVaDTE1hM2$8al~&Ivb&jwB*p)?V!c{{W&5 zSldt(ERsy*BBHmOe7(E5xw}4^`lD8dIg1DTCeoLUcV{`43tQ8q+uRo1af^opxw(W% zym^Na>Z;3YHYqj@aXuuwVcDb_iL5t|HE9($^?GQ}W|pm{XbY2}Y_?tE0-~d>vI?T4 zK5VynzAsf622UC#Xm;xgpAMB~#NG2Xu^Cv84Uh>ER|b@&Twxlr}#X>KYM zOR6sLBSaJ-MJ@fm7GIvKy`buyL44HLol7F~BPZs%jT%iR$UzBEM0ajfOm`;g1-M&f zSS1^3<_`~;n{gNRS1;Zr+@LcM@EGiG-HL9gyp?w;W68@SXFe4=?`QXx9$SgtZNb=f z>%x?#$vs&T<|)6okdm#S@X~h_%TCA3cW>b-=ls>0ty?P9SNO*5t4`hSjG`eEk)Sx> z{9MSJs>5%=durS<{{XRf4>ETL`Pa0SY8^A^6b}b263Qxum$j;|omGS{(zQKr*fxzF zB=(|@h*aIDnSPPf>y#xOM4j;;u9ffJ8wsN=hN@6#Ho|zr$8gYix*4qXA z>Ru;<9f_`+2>qI|oZF?5uQ0q{R-@eTgSIp)x--S|oU5sKXAk~U5HxoZe@cT3Uc=(r z@$qPA%8#;Ylv@X9s3Iu%v|>J^DiC|ZHa_*?TmI9P9RC1>EU(GIY(8-+{{Sjx)->Ph zOIAHgJG^|AaV(d%TWOW&B5b|VsK>T71eCCt=%>B8XaM{muB(;|MiKc7`O_zM^Qa4? zqmTXFXGy5&4ay?I+Ug739L?A}FOMzAhvzpzx8DQX2Brc_nVj)n>N-s#{gq0LcWzpY zywbpb?#KC47k1sKy*Cvfngjf++dXSYv6HE96-R3u>6aE_EB^qqlV+&lnP^LNk|p1h zm)S%55fYrk{*N7*M(rJ***>-AVz2F}i|2%e=FZ8h)8v-#VF&qDU|t<iB(`a>%=5 zcpPlTIN~}RL?_O(BCqYIar_nbcW-bUYf0#qOgcy`AcUVzb(t4gY9ejkCBoob-|tk! zHBO=Njpr^=x9>?+{q+-g!C80=z8+<6M#g7l8g|~wh!ut^f9&uZ6#Xi-7b9kxdx?HS(N zc0R0OmK(nhh{wCF%WHpE5U5TfZgu1=zl^sSSyoj;-)oNc8%-0Oj~Zlh9A3mn z-5&KuTs)J`y@@m;=O~eM^8Wyay=Wq2jl$(3BFVquuh%r&5%H$@mA<6W5AC0BmTP1L z7m73@L5@X>c>tS*^!I**qc&g+rrw6dzu^f`NSk`wC7v9%Tpppvz4dALgUlC~!b0zP z;5zhJ)`Q-l(yF1E0w?It^wqDotqp9|;yRAUEgPD%9qjs2$!v_9TN5};)fT%9qQgzS8fHAhEyC&cM=E1)zZ9uR@Ikmk zi``^P5$5+uN{@6#(P3>lhgwKTl2|MqWPFj67EgBAvX#A*eio||JBka(= zt-m*CvQ50)H^jEY4d{<2lzhDj@pjdgWmUT7Me*Kk=ch$$*>zioGkiL4Sh(Kf8SS$oe&c%N%&hSTLvLT~#p}XP=sZ;7<|ismQZAhPX2G zC|Me3H&sY8BQC6YXjU}LilpO8_p0o2vXz-ah?0bGEODt=*J2S7%7i0KF3xIrGP!!I zp|!tw)TYto5!^2h{2wS=Ar}dvLCIIEbyu-r-G$~ea7m%&a}s<10OO)J9`UH;-w=n1 z-WrXKryFU^ym{jJ(c}LB{B`tw3&e3j8!W|Vv4!_BIpc#IvdEkH!Xx^{OzqHZms{2P zOv`K;>&9)Qa#Ub>Q9-oxWIdfc z2lTWLhZE$#?O%@AXeKNmp$7=X^BK!O*LgCF6mpE{{RgG z-ZZ_68vz_N^_*gfWGwiDSpIBxRSmAZJ|*az!?6m%YSfVv@pG*=6Bf*VLTZE0`a^6t zB1D&@XT)9oyeR^IwFbUw0Aly1Sl8Yuf6UXmax06Dn6VCB;kGmW8e{N% zo0|(2oxhuIXZ$r4r8bdh+8u$by`Az_m#~1-ae;c4lqp0aMUhrxPj?e^-J2yBr5|dF zoNH`y<+fYf`Hmy}p_bc&k5KtfXPUI;#I71~?}rNdBPp23M$<%O+A1#N zQpia>+v}jZw)8Qi(YGSL**3C@y=B>*wj+lEAl# zGA;m!hmi=54mbh)XN+Xjrj zLIit2Y2;))#et!IbW~cIiEYtXwQY}jdQ0TqEVsnSoaOno?K!o5DW-HJTUWsmC)|nC z11Suln(kfRl=W2|Mh!~O$z9uVHN22bwNwcI07rW{d56HX9gP)Jc#trny`LggZKELC zWNORZwcRsmU|`4+5n|dTMBjItE?Tv(x{^K0aX#Idh;A2JM^!=k5mtbT#>`z$@W4TP zC!55w;P8hM`l; zY;B!t6pq*W&M!)SQSVpHt*C1qn)KR3K3_Q-&WZ~={Ui42jqu{>IYQ^OQd9X;IKCK@ zukRNFPm6gM=@#(;IAEoc3MpinIlnR?3+abCgB(s}JZcKoG7RG$Jlak>Z7(hm{xd2`Zk^sCP}1DFdPbWS+a6r2iCC&*;uEdN zhrRmLn%sGE%j(eQ!iDWpMRs$#RVu<5gue<1^KmQ1cK-m0LZx(Zd3cnAmtxBtvijBU zr(nb__+OT~3zkAXLqk^bkv+^nsKynZ1+||Ty2}>ZA3jBmd044kr;X)B8oM(raU!Vs z@}BOAR7R=rvDqHZa$)D3xkwK)P{Sqs>ZqTlfrt7@c6v^5S47-+KiZOq?B87_pNDk| zh3YI-#K@OYwuP}~3+2AdhY)A7POx3kEJ~XOnc34)22p6)QzVb&AGi?x(XZQYgr?5+ znH^RGD;dq)&AGT$DkkkQym;2Pn{lsXiyF|1JIy`1bzDj5BsXi2Ln7-5ox=v+7S_}9 zV&Na@Dr{H6UiAc((eBc9;V+v=Ez{4<^@g+Lx5=_C$HbnyZ30E47YNO`gY7EQYevys zxhg3LHc58SqXj;AI4WjF~?cr zaRf)+Dg$crSq>%5-Nv6gD&aA;#tf(Z;cddDTNUKOvtkqcI9;ubhq-&AMUURE*w<#p zQ$(*hmrpRzqc+>jg5`wpeDvY{?bPADKX(!naglEyY0>_5FGnq;GGmEivwUsX>%x#E zL+r=Oqi)w}MI4BnEK9sXBSNTGX4~PTkVJi_E(4G1jWvDal-pWP{*d? z*&OU}y2+5b-3N%g!(J=ki2g8swM05vb)U!IoG#NzlQ9Hi6wHIFP zmIK~2b{Ouovignm>A~4*|F?S@mwxqCTjvH{zmXJv;ZiT5C~hJu|6YVf)>y z#-%W*p&^&tWYH6UDyR}~PTw-u<@>#E+S^u^8P~~c?a2EgThqyM_NntXY4_bvqxQR~ zu7D2dyXwz52Ge_$8dn!JPf-f*`4@d??UZfG*{!2nB2IiGUggI{WtUF@e3P)iNqZGh zcXoOClB&NAeG=A{+H+SpY^yfgykQ_gRa8d-5J-2g#ftqkR%w`a`4)1t$SDrrC1mt? znEoZ<9--#tA{=>p+_m*BD^~XlEIe4@FRfLiWvZV+;56KbWI@Dlm~rBFOQ#ylei^#4 z==SSQoQPz+U&0?;lXmXZ6#c)NlGz7^0mps2X0AQa-9bQ}+A8poSFv6AzC`TxdQ!CT znwIp6NpIGA!agMKfxD3V zQKd3N_p38(qiHg9?d6ACH3TilyT3(l%>%SorfFMAUlFU7#G?wyl&uKyOV_wj6Z0Zh zz_sf;dj#q%O5c~X`VYfC|5gcZG770WLf54ZSlaKjTx5QljQxbpSxc@jiH|`<1Ic)N3*oMY+=e& zTeLC9A}8%`e-2dUHrVuvq7mLi5WLyMkRP?Wnz?Xe#UhJvaDpiK%UYyZdP(*WXmBsCxTO%EN)EZ4c$s!?-d__o)GC zmq*ISj^PaEJeaK~yIhzk zg?W!5GGpS88jL*pfH#7pLo<4)O;u`d&i zDW>Db7lcTBT2zKeeXdb+mb+i!Tg1m5#V)-XJh3#KNdzQ0iDFp%>}%!j)2|GyLgMFgIj`j&UDOOI_=sV%kaf*mmkUwqLes@jzi0pe9ude-KHfjbmPROYM**-zh>YKJRTSEhmW<*AuhM@B?<1=-B` z5hln>Cdhfd$rW9;o_mGzZOWx)Xw`~qXWGv8&Gyf4IeVatks@WLgmT~Q5++Oe>gAuO zAhbl_-KIuhvf>dqC6~3#UEg=4MyWW~^|xyV87`*f2>p_)8(eCL)@xUdPIVh+Bn2f+ z`}g*4udjG}6@{5jS5oC@QuFp6z z-hH8Ss7uKR;y(2iLAXCA3?L^~JAl#(lW2#(%+zO7SeOCSd;nNnZ$V%!nB|kF)8me{R`jr^F+8{P^KJggkNeIlGxZ zKKjV}6tFRG*Qp|D?}=+>yhVJO(4NYpWaeUxDT0D2`_`UAnj9Eoc($TXd57X1V<8YSG}$qGB#=z?Ry1vSw-j06OQti>&TA#Z#n;-ul^g%2VwB z07spHmQ~Bf*R7Ee5mh{1GMavJF`yl%_&028W>)tJHj{MJw?V{wO5Y8UKb?^kYbQJ& zc48yG;*TP3+^Y4}MHBU^B2LD9AMN|q`$JRLI`)f@#%SBw7KkSiiH6?AbxB4(8 zqfhCUu3TuX+&Q{<^P3X=R?4nTE-;z^;VCD?PUZWOv}0^+J50ExY|xg@DD;s*aK80^ z%iE@%v$msautkO=)8f|lNBFCk)ioGzHpZr$GhVjoTQA(HW5t;xc{dJ2zumM&`x2cs zr0?^Ucs5C;{{ZR9KXnmZ{3hSrLc*AQTyziJK=*$Lo14kX`ht9NO;_`#>f=_;n~K!a zS)}?&_V#(JBQne+PHa^UPbMB2GnnueHdy_SkQW5I$ zXs;LE8l@iBCHY&E-}r$VioELv0xOH-D)#j3RCQ9W$EL{JSdFe4i^#ZjX4o`_fQ0HZ zbD1gOPQ`8zr`|NId$o@VJNCDb^hi?PY5{ymD!-fg(B9JZ6O-OH!Ol|ms1X_#g?Ve~ znjV2|Y`tk0CO8dGR+dKT#x!xKo-B(cBId}ttMhK(U8lMJcrBLlJOQ_XBwwtIx-`-nHNoV5ZUP#a{1Ns zPqX7*On1JjM(r&{+K!!%zk_9uko!bM_llXW+M~7p&uNAw?F@FrzU=@tZp)$>f6J)| z+J5Bec_=0JYFb;szYyw3!Bd-bykfi|AHw=C1^wAF3ME& zFvTS)cZhL^y_3UFj6Z;_)g$F?3s6jLa(T{gdZV4cUnLJbq7nY`roIo;Fo(ge(E2&Yw1m@q`SO?4@xVg{OWbKxRZ$+T)?3o(6nXN-((_r#+&8K5rX>D*LDCu9mP$^pDjsi4Q4lDR+%$+eQLYu4G}FM-ys?w4Kp3<>#o;EvyvKaIPFsD zW__l@zco57mm|bu6C&T@`nq~)EQvXZW%j86VlC2D zz8{VpKLSBe!v4LrU``DS}-t2@v;2fUH-C z?i!9v*=NL%`$nd!zMbjW`YIb}7TYG9JDcs@?q_y&!Fbtug+V;iF2nUAZuQsXb+wB= z_}Moela^dWJ;lH2{#Exq7fhv%qQO=;d&`croj~_q+2UI-P&FOBq+3noJa8MuQ5!za zD}Pm0Q|U^ave}04^R{{ImmG0qxSd7IlMa1m`F?YL^>?5d4SBQ6hL|-U8;rg&2#GpP zJAC_>&Wc0oP3h{Qb|U`(TI{-WR9lOOK0706x<=_H$a+iX2@XT^X)oE5l~2d-C$#de z@4%{cj?Cxjk|$b4kr1~o;ZD%`Xe6TZFHbdW+0j|l8fAyE5t8$R5#b)|nR9afu>9($ zcD=J^ty8V%ZDh(+*GaIB$R|D{zm!s`X!_=FOIg=(J?(}o){LnkG);44RGhw9SmY!` zIdxT=v+bTuy<0B#QeplLxJ(5FF;jR)dV=rsW7)rdn}6nzzG*2F;@z2bvL>Rr5-^i( z^9d7uUs=K$>~9+->8XbcV%ZhVia{{o`ZDFq)mC(m2>o56Br^gmD3f8f(kKxLW#;s* zRad`jm14QQ8c5>vWTTeu!rkg$N>!g_(bj!Yv)jGu;|;koZL!?N!poPGh<{qSIc(sa zvAA56-dUtr@{BXmo)@>N(Gl~q!6jAh5eMzctd=|?R%8!jIqyh2ADz606 zLnl%;#fX|bmvA9@*)$k%-Ttm<+a2JX5hs|qNXAY>{hN*n4nFQ*R-!GN6Ni3_jps-I z099RavdUX?8^}wmm60xz>FKLAnHZK_zh;*S83>6MLtjUO#uCSE({j*gP{4YNbP=_apj-{;t*1M3yDidb$;Q+oV949`vnWB-+HKZ_ z0KO4p&FS)kS~}%0MpUeF%Zx+ST>iNdsqBWWM3b=Fh$HQja_#QP#PyiTy-0T!+B7m!e4tH)%^R@=&nL5 zA6<5kFCOBo({y3UhTBZrge{adcFQYaHu1Z2^BhH0Mfj0@A19u(#Lo#mOLDj*yMv1= zGA`u;X*VaxySpln+NRYXOQeJXN3JiP^+R3Y(0V&PEAh}Sp4tWEZain_zbF^6)=arg z7vit6L8n0s9T?Zxam_T0-9ul-#{G8F+JZZrNY@DjIQN8yo0^O7)(DsPAvrBc0xP^u z8ZRTAo0Lp@!{%Mf=|rI2krYWfQaVON;rPnP0bcI(PG9ruqlfrya~>m)v{i1#)oqoP ze-USyq{kzf960e^1{-X7OA&1`9ICR?x^bch-JKX-1$9t+RXp8``ghfbDtvKe(H93H zZz_SaMj^CC$y84cZ1$~> zq=gJYdpGj0r_P>g8Wx-v7cv%y-5=toI-^=m)Qlu)HyIMurh%rwja8O}M-kjUWZXTK zlhlqS^QmQHY={>w-vR93T3rmWDcHgi?{NPB80YHp)*`q*vuDYX#XlkF78wxIVqM(tMnozP(K0)_Iwguot;&;V;{PqO5wMkUI$f09-C}} z{{XwBCTn?IJY(487rnz83rmq=xVT?xqi$D6F+V6@YQHVRlJP#x(^2*R09k5j0$eM5 zjSsx@V^z)K?WyDula9w=OQ@WjBll^0wbQJKhJY_OfyqQ`qqL2rVA~pw?h>wlmAhL< z=Sp$Pf{^p2*L;VM{5WnAezil$nBfxQKdns2o_DC55%-F^Xkf)kZ6PpA7xK4x&i?>$ z4YB*hNH4)+T~{ambfItVjW9^;mt5vXmy9-ch%hin*xwZ6)g}}8$4qo!A{{Skhb@tcp_W^aZCYb{6wh3Cc7V|vR zt0|RvaT3DYii+$+?aAzin0l#-i8@9-IQ3HUBowI*yk(IWQB%7VSn{q7HzCd&S#@P! zP>XD5&E(xCN->RT7e|T1c#GX^ij$XbIHlzob-(Wib!z<&rF3Ps%FxZYWcdzUHr(Q4 zRnFM0_C6qtA$fuk%PD-MUsOh&TOnCWJx9sW*3KKZW`!}sZ8GH+$9)JSOul!sQ@VY| z;^UjGwu9&1TdIxH^&(|0*|st(8ZsWB(eNU)BX5;4+EZHvnfPRIGrxs_9? ztSwO+Oy-Gh_C%NWRD1hOg2@*Rhh9&!T!QHA)!cR8X)3QG=YML~Qbkr6g%i}Rw5$78 zqonN){{SR+9nH{BWxik8RHOS(e7Eo*G@kJ))NGBS-J}m>hS>L88oeHfzkyNpWp3i; zk$IUgvVL3cufFLo?WpJHM~+IayxU*i8Yl0nN?xX&H8FZ??kBRQxYHGv+P&hd(dg^l zQj==TU2p44L~ZJ`K>jBAmZK!?F5K9!2p@as-Blq>A|9IOj^;huyt-^X#Rk+SV#$@Y zLL&Sc5)Z~lzL+Mp)Hg!LBxesV>{ufHbQHzK=7~rP8hF;JximgbTdn-k-By;?6)QE= zt=tpJwd+X@XwC)C> z0VFoBk`DC&%b8Q?s}W;HQa1LfxHJW}9-5g9T7xT&ER&B(anCB@dtT_(-xB5H8W!x~ z31mV;kry^2jZm5wX_`x1-{jsP&temlTU@2TYDRqm&-hq%ZaaaHk@l`dRePAqI@mD z_(*ZXVbCWR1VbfMCFbl+-j2lHovT?NZ@NY{nb&h&hU{mf-5qjr1T)?&5g93l??-Q= z<4T^7v@LlZzay<8&TG5lFxxM?%|lw+R-(RXxYDmg7iJ%8)lqwX6i>eMuRzcu9hL7r(y5g||M>@YBs*6kzLpGkKE4$KqSYL5GR4 z#Uyxx6v;R5#~i%L*;J0qH8bqrZd!8g>ZH_g$W_E$>*?(JH?Ba^=Q?;cZ@f!O3%CM*jdl zNQ+LL*TpbYWn19P66)eJR+n12R^1T=bm)6n8SsMARtr9uiFdbMSf29}5vRx`w(6j` z83GO-aZC&CRb`~Ms1d9hj>5|h8y7f=(h_d$HBm%W__4`9H5cun+dcbCGQ?X%Yh14x z55gm&2QKCAr_oj3h>c*kT^LuIn)HAyhC-eq(TeXC6dym${;gKpBW*S-#3*OPeN2%XQzp8J6VJIk z+t!QiTJp~8JWm?Q^6DIG;aZKY#8OE~KBKa>kz`fGo@w!FL%91@F&0CPk06m_%viQX zKm&YkL&oBCk#f<-kJN!_w@ZD==>SF6!5uc##Z_|bUEd{C(zL&{ZYcw8J-V5ok0M?? zq8?M_`NL8S+y@mRDtT(uvZOXlhNQ8~eZd|cP2=;6?&DKQDsfDd*$s|MLOC~)2$N-w zJj4CvcJg=?>Bd79b!GZUR%KsCtXon?kmp-4I_QpPRi~&V@*UUw`;5H zunpyk2A+EPWYbj#neZJIV^rsCUNHhWhkK5OdZ@TPcL z@%13AzflKQEfTrfRKw1}h|9IL}XY z%FSCq*U4~@@0V2(^CB9q`fGm|O(SBlmTO#>YT{y@F|nOzGgDgU%(j#!W6N#TSDh48 zK_o9W{{V`uS&u2sl0sah;eF~ccIf2GlLWk?CCW50IvEqsI3wQ3YURb`x2(DC@+$EQ zwtUIBE*oj(9b5rEI*;(c*3Id*VQHohCT$V7yN`I+&-!TVV$HPgQG z3d+2N+{9+EvaX-fQ>3dNW|fp)_cs3k=Ry|iP@YRns&2C40$&ph*_@JB3B)I ziY4!HKU8a}cGz$tlaptDFWKfH=Jsi_>bVz3I4u!AJw8BjXnwU1T`q$wP(iY`BtlOM zk2qe)^B2{nhf0b*TbLg4q~Zt}4m1VZDJaM+ea$gL2OCsfNqX}OkRrnU++PnV% z*=l2CmgY8H_wuiC;4Js0qzf%Vqx9&ObVvIjf>0eA2S{qgqL)vi6i) zP+zD&Tj6C%E%4#%X&Je0E?qUQJ`ap2(zP8u zp-ykkv=^(dLv)%;naGfSbtB%i7^NyY2#I5lyHB9Qkm+%a&YqH`RUaWG9uVdBsv<;= zV()WN4f0bw7P)bM@{}a+F;fXcuNq%fa!-*heqz4LSLQsl4tZQ6CGDp5Z0EU)k z3+|R6$rsh2@I-wz>{f^^x(@ktEMwc=FPt=IWAJy?cVWq6iv!-~q!g<=wbljf){e1w zBua6?DYCwB<5otT@K@8Z1WEQq^2@zaDuIichay3FCfp+VwL56p<^J%yquZTv7t_0P zR`s!Lb52@WQIBh9AughBi{WX@Y%viC;@8a@Y(T>nF9uy(bi<0+#vgjDYX^awgd$ff zkz@Y=3~O376}G&Xx0lYP$}3`75YdPf9gD#aNXN&@ zcynLVRwpkLE*X7V=ouWb2Fnumsj&hoCYiNb;<(IM5u-4uqCnFKD7(|*sw(euji4+p z68UG9R2XFow9D1LxE&#!DBQ$_ow; zD7WonNh56n!`T+iwI!NI^{J_lU(|Dl@ejGfb-c3W5^!6V67qxdXva!lq0ECBuzDZD4YDPSmjqoGxMUl&-sC6_} z<>~!EPfl=xHNsIulNT3=u@dR#?{jl<*Egr$Tr2anvyM38;ZP={(boYHdrxVZ4K>8_ zo-HV4x0WN)L(Pv>0A8-QsScz|yVKikqU?(HNP@o_FPgU2Ot;V#v0PYA`^B9v#UeJU z@t>f#)s3=!mg=29dyREvfKcCqt*P!M1^Os^V(`}Wv`Fkas7;7CNGb6O)~JaHW40n5 z%}Tqu5m2u+wws$K`wI7GjEgS=S5Z3)Bu-KvkNG7q+CG^*Nv5OEcGyXAhy%oqL%7jB zDckDbR$2(PD0prTBwVCC*tk5@t<(w8YEcn$a=1n8_kPV)#}axO5!*8@67g*Em?8fF z+1&mbMU#sZxh|5U5cH22hwoO?gE1~JnOSk!c!!j@jR|voKDuc#!)Ba_kJ%wAB6i@~ z2kfcnt0L|r$(4+MaCylLqi)|GoY0nBHyrO##+A^23e1F$z6jfhLM(%ujZ}Ro9h=kR zPc>W@jgV;zyIY3D#8(x1nKJ$AwI86ZZg{&<#m5DY7{?K0cn=pamy|4V$GllmvaEbb zLi7|+Yy%k00;(RkKy`a{TscLZHt3N7LN)2HEsh{lXbTni@8DTajGxN zyLwcU9tPV=B_p?LEfseRkp)ubfQ9ld?%M3U54POM`5h^2yAEIaoyDd7^_~1XWUbVX z&0F-6{i!b&+)LSfCNx~ffxy)EfiXl-NfDK`?D)u1jjcOu(^mw2S-k1V2PIJ*RFc>! z_U}#m%d@ts8;64y*BJ!~+aaFfns%rwZv1|QRiQcl4%&boFNAhAs9PcIB}-DT?(@{G zqf1T6GcWy#>m%+;8otr0Ew9_g99VGOjl>tn#U&O(L(|QXdsIUoRg71K!W)wUBs{)6 zLeXRH``@itMpnjQ+8oAClhoE?66PU?A-sK(<@>bRn_E_=1f%NhQ)3?XF39EkU0=gh zUrp2&mBN4bLDJ&iE(g&VS7DgTvIKGQhl$*MoT@VEaw5#Kp$!o7>Af3iwq0&#D&)fm zkG(1Mam6L=R8mhq?fbX!mli@7U7mn)%uRECyk>}e^%VNN~(*r zq9?kjt5)hbW1e#vAN$I!XKIUfYxqbwzRP{9T^^TVCO*YHmQ__GqjJ23?^iuTl_g}4 zaMX{s)w0ztS~ysXgcvu!=PzgJsrgziRT? zoQ%SQ0MQ7FZbBsU(@Ru5Z>>vV>Tie3_D84d zP*mM*3eR|y_10@`u*)b7gKNxWy`-t~d8i*y?e)gcnA5NlAP?o9PH)Xf1@3o7V8mBv z2*q!9=G!Gi-|-?PPeLXeCB}h&vE)evQS4i;^^AS>5ud2Glsw2b1L|%I5CH2=> z7qec#y_x{^aY}akHTlhrI3BNMjeD!vuVO^B0RA=M=&kVE2FK?OaU$g+So4?Jroh{T zJfoE;5Z)ljTO!JUI%(1p<*EB@rpU23I7RJLWKBDCnxnyl`?-59lF=QvWE#%wgt6=r z)dhDhV~_2iQhFth*FCXvRmrjUNU0mM)|*^@8xjTb_iA3(;K{D0IM_biOKKYaS@b0z zv%Zh)UI(pY+THJF_pP|0xjxeW04e~@ljS)%;>2>x*{61`6L-8SBxp`Ji{9PKTVG9h zKWU2)61MS0C`=eUE6}6b-kVpGNPC_pTlaz7CmW)Iwix-0piWrGJ2H|{bxhU$i347Co3x8 zwmikuLqG^K7xH~nxHZd`IeWg%bU2Z7cdNt1;~0^353@`MJS=^h>90v3jv`##)#NU2 zY)FQ?66VVuUiLHr=0u3H5ge&x66F%ho2x2pxe0Rg<6ca7WrugXQ~ zUyW08F6H$ih3&I!I>T8%3RThty2%NrNg`QXca9MOI~4>G5i;G_rMXppL_|OU4GpJS z)jELALwQ7x&Ta18i@(gDPUZa7YF~8B4M@kdC^vjQFWVS~O71lB$TuBVg-RPUq7kVNpFVpGt zr6Vf#r~bah@phuhz*wWo8F) zCq1fn2IX{zibt6tfllZ+L;ETWR1H6~?UY{h)~km;Y@1Wj8bis}K~d$%cVp)BcKZEv zL$@huZ`q~gU`tezirOg$X5m)_7kALLmg4OzRNHqDT-`Aq>zv=VpU&GY;_qgYZJsT> z!_Dx^mY4>5BUMY&H@FvGJE|N+M@3D0Lw{{vTVT6ghRq890H?0+w%bjKON`O3Dh0b0 zQBRk0{U04>fVm58n`qxIJihH{p9hJ9v|BscR%30t?X{N+5*}wF7vluIta_?CsOl2R z)c1Ra%wcLhJvq3NodR)N+=irpN+c?Zp0r<~DzLQpPU@D`5= z*=D@GXse`$1BW1jD~@+axrDxJfecUrjZzm;=8y}qbM zZ-tWLROOvOSoKq=hKFauGg&r+o!lND_|ckQ;jW(Wi@nHg7F;I1KMj$N>BLE6@OyYbE>xD&E>KebS{-1USe(;34Q9W zW?i!CDjd2@<;N9+6UMkEue$N6ZkLy!Q zr18kPhM)>YMPAVnc&X-W#H@`|d6B%Fxk^%1Oaxve{q-JB~FildC^Cch@89CeRW6d`rd=9T#V|XBw{2z>+0;SF+WmitNy*!k+ke1 zXHKysQqa$ygt(W^dSn{B2$VW*<!H zO#lc3GVV_i5fiQ#ox)x8W2UxqCgkF(HcE-t=oI>DUPS#nq_GllyR4t=VV9C0oL%Z4 zy=@}ncxc!>V#3{s#gEDepCtejOTB4*5Gcz+H|vvslWeh!CGQt{f^{apylTP_n;{Y~ z7VwJNV9I?P=xn-|9IKiJENzqPNjf&~CQM9K$6DOEuWs~R=}zU?Z#PL#O>w~^kYw>L z1KOmjtNUmM(8Kd4vea}kaaVUDZ{1qL-K1g+2*_4>nIl4r+McRv3S}RH9QG@fwU<^k zp;Kf{==6i`5rcW+#m%}*ih zA`3!7MBg!fGKNqxklNnGOA^MFYR4(pF$P>BFLt&hO^jKhz2#YN4`!zs*|m-+i;Y7P zAudrGF<>!_Eem#-NFx=&$g$=cA+~Hv{{TM{@lpXI<;u2nJvRZjMWuu(zqfLxRUz{?lL|xj@tiWg4B6p-poyTaJZ}ELx z*<;N{AXB!l2a5z&Ra}g37D!#Ix{PhxD$~$d6=u-FDKF&*FWW|q4-~wwFE`E_P#O^- z$a2ftro=HLU}>0enZP49%K5QmP-*bR5hk*1esObCmFeXY$hv>?S4cu4@p(){d1=7p zn(~~Y*%0)&BTA!9mLzTwWL(w479u5%SyHUEj(kDK=F$P=X|WM5{{XUG>Qp#*yO4=%%U*Px7X-2DDoAl6As$?Zn7x|stAsga z^{4?vjwI$Gmp8Oo6)J=zE%RlMRUiT`k>gz}rNdFj%9SWEs1`(ckJ_bBxJ{99JgD@K zc90B9BFh|xy1+@iaxAju=BC62c}Rq@#EQI_k5H%=puH8l8QAl_NDBf_)YOo5aw>$~Lr)Bz?-h=~}Nx+6;@ z&l7Kh;^O_m%t%MsA9%Z?`o-PGnc3l; z=}NPErIsU*^HVsMc#!5ulGI^tL#T*u52k<}Cl4Hp-7H3w0UjnsSdj9LHOV?TVk6zh z9D1w7?pU{vcbc39&|4S2X!B)J_q#2xQQDlhU1CPMM0cr~0E9dKnq4s#@g7m?A{zG) zM6%|`=FkNduY3^n#rKUbPdy^VAX-#Nt)d$)5+`S?9sa%RShh+z5>w1cXxe?UR?JJz zB#XII^Vie|Wn#`lm$Q{)Zwf7v9iLIj)pmvm?K4j)AlNUEchXDo7h>=As`}G>ETCui zgylKTRkc^>?c2@I2PZq5LE^e{gwGB0Y3w1hh596^GY?;_vyUwMo>C zuHfya4o3-!dXW1ve!_`E4I@(k1 zNJ!HjE->_w7vni``(E`8Yt16%mRU}Y>yquyuJm}(HT18?>8V#C1;=amw%k%O z5-*c>@1mt*?fUr0^ETNHRrvE)u%;Bb;dn>RYw@visaW0;?c?QPl~A3^_e}t4IGz&r zvM(zg!77)D;UlNWM~LHt^=m0|(V&@E=;&l-h%`ARxXIBgF zRrRxLTQ?0OWg_AU3uNIq*>`fshC|-SX}~(t))=vFmm8GHP7IkBViFSNXk&==5q$2P zd!tmgh<8oN+7OG-dbqib~ zc_KpJ*%mtDi}SsIO$ra9yQMw&ADrJEe&R(+Fu^=nh?mWyJuS0|SCS!tBqJv~(-r~{fS zXHaY{i8TybFKC91yg+^0m4j^}ZU=;Mj{%m89GlCl5v9^U2z-+x zk#zTIR!QuQ!NshH9$aZr5h2LlA;+g0jA15+$>L69?u~Yl#E3^8C+hOoTrVleIFC=g zT_58zCkvL84w1x&5t)o(yCl7-8|hTNMteG z%PwK|O1cM2F6nU#y-EV)MALaTO#Glm`b^fIXEI%k2YCSsv_wYBH?0L zb9aWA4=9dDlbK>%xerLvs#!P9^>P|XIC*gZOWGe5oqIf!|NFuaxnT(^A#pveBnoe0P zhQZQ4jNFH!f+I`im&A$iEvxb3ICe7`V;L_2lu19cW2j_YnuSzG%_0mM1|)A46%#g3 zC%8G&%ow&#VW@fp+C}D19>A>-C;)M$w<;N2G-)M9hkFEF1cxha`kHb*H>IzR6nN%S ztRNIpHT_o6Ji_^p-x^T*41I2TyK+37vMJs609Be%?V#}7ElDx`5^2B^f%I^RtyrK# zSfJ�kPk$MiD{?t%d63tAG7%$VlHb;Z>Jhk_e8d7Lj4hX)HDT8=P5{c1Fm8YuwCf z>jZJY|7ehX_Vx<{TTc??E->TFA~iZi)l^2b3CPPOM)3gxF-F-8p;a=LM^$Oo>$0Xs z+fIP9nBKv`<X|k$js5$r(P?uB7YM#$62)_S{@uH_?~k`|X5L769Dh1r@73Ns!$3z8X5cgDZ*OCn7B0e* zoqm+!_QUs6El;g{Z<Rm+~655Asa9hfb(3 zc%*%-A{M?@Gh9_=<{J6teaK{kRG(F|UtuIwdE#z)n^tUQ064P4Pjcf5vDxI@tp&|4 z2`Pp|18Czcl3#P|!0BQHlTkvdJcs#yVEO*>3c=kk&CCbVJ$E)MM%yee1ZiW9i%~f) zQs)6?Rbf=$aW?C#Jg^29?Kq1oq%V}x!a{P%SPJz+RSMHltdG>h>$e<*ygs#8;bEH2 z&sTXXY<##B;C5u?_Omt~!QC1_#546``QSPm2ieOGccAD+mN$}`L9i7&QDs1B zy2Mmff;4cr2lkmh>MkxT+5$z(jDL0-ig%&^JsQ0U?h^N?1iZm!H~yFST4CTxQQb)d z+S2NZc~|)EqoZdpy+F(bCrNWYvk$8dNhvn;669??!H9r-Ah(U~V;sLL zc~llk1D|o$Jg>YOiAdJ)SuZ;Ag8uwPTq0#DhY_17XA6-oz1gfav`A1vFrbHxm}Z23l){ zki`B)zQ%Bw`Ed&@9;C^ER&U0rVf5KCN_Y*2>*PLdvF?#N>gME^Gg713btiRak^l?84SqLAo~|YR>EOiUmY08fVEh205xyuHggt0 z+ESv-5?0y>Vju+9qAYpSS4p`;L>+;yj{2BU*3xyOAMWRX-qeOdqAl9oPB&6>kUG87 z#U~c{@NHr>OAULu+}KGT64P3;?i+>-h)SM@y{DVbZt!!oD{}B37*`uEA2dc7G^Oh{ zxG9S=oiXDMs_7@yNPG@VPy>@h6KomKO@N+hS5IPVGK7e%u*ZrT&J~(*{Z&m~rWmR; zb`<;l;cSpB+$BUf%W^l68Ii$bL6lQK_o z;6NaB38BSY00BE{w|%wYVJxHSEvFw$82%S)`QzzwuwAl-?)?7%eLAxj?Bq4#|w#&FCwAc5C&4y;_PKZY2nVO>_ z%}PIrcAj+!+ZcTW2SWYsOU%J~XfU5>%T@=8QPiFLDi?{18bZ6hoEzflSs~=7%L5qE zG{1xdN=qk3-pV6y&IP4%py!9rd}T9ke&$e7bmOnTshcEUjlE6%u9xJiw`?D8e1=Xk znKpfw7=C}UOSp%J{%|Qx9*RBA6sl*&o#UVd!&Lv?8y-JSD)x_kZu7oubI5bKFo^f9 zGIPBl@Iip77c=mQ)&1{PMpx$3_kVzacWwM?8W_om%bxgir{Zb8zmm|&-TQ7i$U^0k{9#7)@;<&Tt&Vli0Ecb5InO3SM1Z8XG+hVs zb^q`6r33JH30)gS=4P~>pormGi;}mn?U>NpMHGBlji{FbX^}c zfR#>bOjqA>&xw=sTps;8E_&?Nd5hsMlyDC`Sd5(gyF{`Vp8fCS)s~?~$#GCk2o0H8 zEGECHDjIp(_U3Ge3}?0`)L~a6qV0`ZZNYZWSB;!Do^X_K_G=vhTOBd04Lud__%Q3x zxpyXviD!kfCbIA@TxThkuyOz<56jdjjr6eF+cbFS8(x;9g6G4LORjnPCO|FZQ3S#e zZ|Q#X(}~+d1-r8cp&I|}=n$`M%;-?XsicimD~ilW>P736Gn(HfbED6^@AB1Efna3w8U5=N#q_`BAxo)UaTM_&2jO4%=`2NKsfh?ZRFO43N6!mCRBiUB;7^_#z$zO{ zE_U%Ge8GwU&ZPKme(njCGnBajv_U%|6cTk!0AGBcDF!BT%A(n|)l>e(=`iSm6k7A+ zI*FmdGPA6APM1MbC`731;&}{1s&JMOMp~>U2_u{Q#yt!P4iES5F$ozW=QnL^`$l+4 zq7x}`6gnk6nn}bA@F201&HmYNt{!WvOl$%R4`2$`5_{l3Z3Y6|6xvZ|j80u^9}|RqNt9f%p7AxP`-1cS0{YI#qU%~bYTVP zViidG>r`{qABitQG3GP*Y5y0x1zGm9dC0$?wQ7RDJw1J9HBv2s6~ky{~mPo>CMO9hGx?O8=HAiWM1IUo=lB=w~5ka z{j1mAqrT<5g`YOL`XCa_9NZYX{05(+v3JKyZbrpvF&syjh8#1G2hYMWpMPc2!1g~v zJ?{T(y{bN=_j0W(=~m#{ug+fQ-~5_jU7x_tjL^)Dh*1(oQ7Payx{!a1tkf}k@z)VW zs)wTaOpBCsy@r%z6&ow>F*&KcKL*;Ad~9@XlE(hm+CI8HFuL#R-Z}Cs9$6|f3*dZy z%Yua4_m-~-oAF!;#|2aoLby@QQMHQ&4p3YvLS}UzI=NUh5cU)&6wZyUimk}JO7Ly0 zL)I+mVa?gf`QDH1fof4ca`(dvp)6q6>pJ^2-V~1kCAj8g_g@^p0aqHr{221g*!8gh zg49@xqW4OTOn%9Y!MF27@Nf=tI;eW34g-SGsUDc&9=*s4y|MWfIWzc9yYZf7GeHMA zL1{=k-j%0D>_{%QXv35rF4jtHhy9vbG+|ohf7t6`Z3hU=VMSIxOlp$6=Ghd*Zu{J< zjpNr_bxRs3Q))$R3f4TIhL1}0jdeQ!lWuTT?HUhh&wu8BQ>7NKxx75}^{G}0;;gda zvR^>P8aDxJ(X#FXQ0g@@wV2w98G4V2w@^dB`ZulWSc*`$UDgWLK}n*Y**D`&)la`- zs4eEM=S?iaKhQ#mB3$XX>98o#_?s=i_rge2zHwlyRW|uvj?PDCso|-fgtn%lpn^lg z0YTc%Ofr8-D&#_DWHy@ZH|o(nZlLrIZx;rhOaf#)7`aicO8b+x*7`!CXkpP62$&EpZWyR8LjPJXiK!64XLQw~DASDT|n zx!|TPhEL5`kTI8`WJe{mh4;Pblgu|d;icn+9{AbTANBKm48{hVCc-(n-0edvQ-Y;R%kdKQrBdo+4PA9>vwqXT#E0v4`>&luAMIp6P!ZEV zKh?5<8?L*VR(b7-wCd}zsjG(#p{Wd+qm{ia{b+hsFD+%plM%yIE9*Ek(97CtluR0s zmHf5%mgS4PKO%6Q5~>uQNqu;j=E?LxDE6jDm$--kfMohFn@o_P94B7sUGEz3Ghp4j z8?B=wf2J9|${0Btb;w9`8sXarz1Z`|9)b9a$RL6%;3jl>Y>tu!cfV0QfVB0vvum_O z;z8b{5eUT}>yAx7B#ZB=&yMKHu%}D+(!*baJ4nDSZ21?-p!O@m1!Rq1wh*4C2Y9J| z+fF1#D+x2S1Q(!0BoW9jFl)RPSep2Pui$g1PbbV2jhSDAL?x^FX9L0eNQ=nAl8>Gg zp{Bx3=CimHlsU@7mwMq~vhTFn-_@_X-F!H$T0Gre!b{Pb199wP!A6%Ka23XD?dgb_TKt*<$N&2rySGb@PFW>#7DsbuX!ocW@NBR~G(C)l zP_a;Jjv|!+`;{@RXlyf>ZiMFI4CLQ<{MXN-* zuFA~L-5w2#E)(M(4Gf!g-t^Kg0}{0~)nW@*9OURZN<7Nu@vy`mcwnvO@Axz-xm`T>rfs;h*OqLKOTKep zk~wTttMcTVtCn1E#v^Y0EBvc)Y=d{xq)KIo$EB$AZ=oT_>;L#oUU~HR$7gSQOV^G@ zO!R5he6|3|Rk*k^{VMorIPMy5HLag@`9Phm* z7rM(2*HufQY&=&3pVz_#3nHbl9bGX}?`J4h=X_h!caZ351V`K84T@VWTG39AO9Lae zN#u-E#$d;N+K|InM^czAw=W*dS!Y4I9TPPgoV46FW#p57%`gymaWSz?2zW9gQODQp zPPwk^pU)V(=zG4pe~fpMHX8Hx!f^l{lcmBWE4T^@@r~-`B5fbz^rmILhX@1;e!spL zXJA9;K!{wTfcK0o_8kriFM5u@*W)q^s)7erpR?qhBZJF7Mfefb%zi`Tea_{W=%dDYsUP#T>C!% z9vchwP>Y#z9w-~J?`jvw@<4)xf@3LR*XB-o=^ZgB$w3W9&KLq}bq~_gvtWW>Y}^Zp z6JBnLQ_#E{wLC?j8c6`T)29wFN8C!VxUa!*%oiDMiXrhdR1T7s*kd+TzUI@4KS)9C zIByKPgQ5NALai0g^bZB%7d#p}>_v7HY!o@xX_6B7v%T1vKMmsp%O9kCLhC0dr`x#o z7fVzzb!pI`YFEHLp)Lw}#JXh~Vv(1tFP(tAF_%d7?$)-f-Te=3oX7$#c4uZu-$lym z*KX!4{STl!YP_eqT{#)v)el<;u3$Uwr+@7+{8x9XsMM3)E-TVdyzHn~!{?**Ldbqd zFGC~lN5GovY6ueq<3_bD5_F2G9|DqL3Y8ANAPXN#hL_Tm;Os4R`da~USr4os)#Lg; z4(3`H!Xu@KuWbBQ4X@84IJ2!h$wDtl@?*THK|Oic7-;3wfxe6v z>pzb%z{6dV5AWoCcx4Qp>8Yh+ZW3}gM`HHpq7p#k>kElF4v&Znes9t8! ze<+ymZXI%=f7e}#g{-2m{ha1VX5DiGPMp@AUMJrdT0*@JXh7SO8*y^k2QaK)2&L_D zM$Pe($0u@3zr8dZ;Fs@@o4J4lqAthg3A7_mPV0GL2Iqg3iKV|iJwTW$?H;U+bqAl zY?Ygw6ws;{$%G`(rbSJ-i46g@+^A^tk`D2lK<{=g8f`IgHXJH428R#>3z3t%U~w*CY+Za^B1UJPOE?H1L89?| z-r^`FD!ciShzGAQTum(kiCW!@pbm@OqPyn3TF#IO%S0P_f}i|tsdWFL8&>n2zNK3Y zASGDoK^WA&g9zPJowMFi|Jwj?x<3maF=J?G=sJa=k?FG>aqB2*P-jfL|ET!O1!N_80M`g7y>QS9L$^*F)bJubmrCq`?Wqb0Lfg4D`DJ#E% zSxC4AQ;E&sJ-fek^S?&?%H@4y@oL8y``@z0an(%8@ht4HZfv5v!<39@R8}6BYG9MYLOf#GPBalSuPORLTm~;2Lw557} zssGeIiX`P_V9J))ljBb7NC#!BkSXFNFuRRto1seD^KGcbA?O3ww9z&^VP8P5_GC3+ zZX)GwOQO9GX=Y~sa|ybK`WD_CbBq-jnfKe^Z+cP1_>nBUjr;e5J zq}+&nmcU?coq$i}NMU(=r((GDvB$)IUXnr4*9+QdXTKdmTXhricA;3?iU_0`p(3uO z2;L<|lObX0NReMilG89Z%#j)X4>U8UN%33VtQz|NoMl z(e%xRE$95}%51gNlAl^~NXyHERjh-bEICS2s*1YaswK2N4ss37^6m~IF=FIx%yO|y zz~$en=5sHP(0nA~!F~=eJV==a?y$AAS;Lt>ethokyDhXSS|n<}9ig{M6|G(TaNu!$ z*_w{{^W|l*`ssF@52>qG4c(kj$&mLhB(|&Qv73EEcil&-um1%ZIWqI&POxcJxnkkq zy59c*a$4s}vevaQ*p@}!qc-lMe7M8qF67Za_XZaW4%`>!1>yy;MNVlB9ho1`kpm%v zpa|UJWrkejF|~GR){>>ooRWu|TVT$ofP<&FipE7q@zo9rn!=8{=vXhIzvT@urNU*7Jnb!LC4+eNnWBEiti+ zM5PRHAE_n_N|*r9u)+FSk8R;1DSf@#HJQFKj~S*64F*cf?3^o=883NDMLeRarChr!J}p2eYkQY;@)9ct5)A)>Lc3Ye&8b{!9J&23>{FnYeSG*Jqy z$vek==4{Bzk*Uz+9Nf>D3&C6EVZRr?m<6~_Gq-o^@7l##X-X!Uwez>i)87_<%s_om z%5gIo+o?A`lG%`fv=WYL?p8hj+BZCGqGPfGhKEwr=<^2S(vef81X#DhG|oau-o)*$ zpBZC%6*&)j`=DJWA59R&U&OhHQp)KOtyf06>Ow&vXwY>*xu%b~ff-s5{_0K|BaJJQ zw6Xl)G>{hP#pKeMQzf1lm>M|>JW*P8_wdV;P`j|GenXw%ws)dmwYljScOU!rAG3#| zA+lH8aKFj|g{kiolA)#g?Y2$364qZn`6ToK4?^E+(1r=IbIKDsgaFHjM?5WrQBq(T zN;IWqXxZSls%*g8OFdabqa{R@E$HvokhXNmKl%ZPO{V+DDfhke4TCf>-|0iBw^WLXxe%N{jxm;ORzSC*61kr?Rdsd z=^i{A0#&q;smu*&%TF#7U9)bZwW0UX2!R{iJ(_8#G?I2kGU1+ObNA8Acxe0RQ%nxBu$h3d>R#&arbrvzV5baS#M$_A839s~Py=`q}LR Z|89{~%P+df1@nuW_=^g?#tQ!%|389wKb8Oh literal 0 HcmV?d00001 diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py index bb8b9f77e..90d097956 100644 --- a/test/lang/test_openai_backend.py +++ b/test/lang/test_openai_backend.py @@ -28,7 +28,7 @@ class TestOpenAIBackend(unittest.TestCase): if cls.backend is None: cls.backend = OpenAI("gpt-3.5-turbo-instruct") cls.chat_backend = OpenAI("gpt-3.5-turbo") - cls.chat_vision_backend = OpenAI("gpt-4-vision-preview") + cls.chat_vision_backend = OpenAI("gpt-4-turbo") def test_few_shot_qa(self): set_default_backend(self.backend) @@ -88,14 +88,3 @@ if __name__ == "__main__": # t = TestOpenAIBackend() # t.setUp() # t.test_few_shot_qa() - # t.test_mt_bench() - # t.test_select() - # t.test_decode_int() - # t.test_decode_json() - # t.test_expert_answer() - # t.test_tool_use() - # t.test_react() - # t.test_parallel_decoding() - # t.test_parallel_encoding() - # t.test_image_qa() - # t.test_stream() diff --git a/test/lang/test_openai_spec.py b/test/lang/test_openai_spec.py deleted file mode 100644 index fdda19127..000000000 --- a/test/lang/test_openai_spec.py +++ /dev/null @@ -1,68 +0,0 @@ -from sglang import OpenAI, function, gen, set_default_backend - - -@function() -def gen_character_default(s): - s += "Construct a character within the following format:\n" - s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n" - s += "\nPlease generate new Name, Birthday and Job.\n" - s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n") - s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n" - - -@function(api_num_spec_tokens=512) -def gen_character_spec(s): - s += "Construct a character within the following format:\n" - s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n" - s += "\nPlease generate new Name, Birthday and Job.\n" - s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n") - s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n" - - -@function(api_num_spec_tokens=512) -def gen_character_no_stop(s): - s += "Construct a character within the following format:\n" - s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n" - s += "\nPlease generate new Name, Birthday and Job.\n" - s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday") - s += "\nJob:" + gen("job") + "\nWelcome.\n" - - -@function(api_num_spec_tokens=512) -def gen_character_multi_stop(s): - s += "Construct a character within the following format:\n" - s += ( - "Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n" - ) - s += "\nPlease generate new Name, Birthday and Job.\n" - s += "Name:" + gen("name", stop=["\n", "###"]) - s += "###Birthday:" + gen("birthday", stop=["\n", "###"]) - s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n" - - -set_default_backend(OpenAI("gpt-3.5-turbo-instruct")) - -state = gen_character_default.run() -print(state.text()) - -print("=" * 60) - -state = gen_character_no_stop.run() - -print("name###", state["name"]) -print("birthday###:", state["birthday"]) -print("job###", state["job"]) - -print("=" * 60) - -state = gen_character_multi_stop.run() -print(state.text()) - -print("=" * 60) - -state = gen_character_spec.run() -print(state.text()) - -print("name###", state["name"]) -print("birthday###", state["birthday"]) -print("job###", state["job"]) diff --git a/test/srt/example_image.png b/test/srt/example_image.png new file mode 120000 index 000000000..c8a970edd --- /dev/null +++ b/test/srt/example_image.png @@ -0,0 +1 @@ +../lang/example_image.png \ No newline at end of file diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py index 04897b398..7e169f3e4 100644 --- a/test/srt/test_httpserver_decode.py +++ b/test/srt/test_httpserver_decode.py @@ -3,7 +3,6 @@ Usage: python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 python3 test_httpserver_decode.py - Output: The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo """ @@ -23,6 +22,7 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text): "temperature": 0, "max_new_tokens": 32, }, + "stream": False, "return_logprob": return_logprob, "top_logprobs_num": top_logprobs_num, "return_text_in_logprobs": return_text, diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py index 7c2b5da1e..38f090b7d 100644 --- a/test/srt/test_httpserver_decode_stream.py +++ b/test/srt/test_httpserver_decode_stream.py @@ -26,6 +26,7 @@ def test_decode_stream(url, return_logprob, top_logprobs_num): "return_logprob": return_logprob, "top_logprobs_num": top_logprobs_num, "return_text_in_logprobs": True, + "logprob_start_len": 0, }, stream=True, ) diff --git a/test/srt/test_httpserver_llava.py b/test/srt/test_httpserver_llava.py index 0f6571b45..e3cf1b799 100644 --- a/test/srt/test_httpserver_llava.py +++ b/test/srt/test_httpserver_llava.py @@ -34,7 +34,7 @@ async def test_concurrent(args): url + "/generate", { "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", - "image_data": "test_image.png", + "image_data": "example_image.png", "sampling_params": { "temperature": 0, "max_new_tokens": 16, @@ -55,7 +55,7 @@ def test_streaming(args): url + "/generate", json={ "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", - "image_data": "test_image.png", + "image_data": "example_image.png", "sampling_params": { "temperature": 0, "max_new_tokens": 128, diff --git a/test/srt/test_httpserver_reuse.py b/test/srt/test_httpserver_reuse.py index c3f786589..36804e4b7 100644 --- a/test/srt/test_httpserver_reuse.py +++ b/test/srt/test_httpserver_reuse.py @@ -6,10 +6,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington, """ import argparse -import time import requests + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="http://127.0.0.1") diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 4cc50af85..569eaccb7 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -163,7 +163,7 @@ def test_regex(args): regex = ( r"""\{\n""" + r""" "name": "[\w]+",\n""" - + r""" "population": "[\w\d\s]+"\n""" + + r""" "population": [\w\d\s]+\n""" + r"""\}""" )