From de167cf5faff8977e9decf573c1a313b146cc48a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 10 May 2025 21:54:46 -0700 Subject: [PATCH] Fix request abortion (#6184) --- .github/workflows/pr-test.yml | 4 +- python/sglang/srt/entrypoints/http_server.py | 11 +++ python/sglang/srt/managers/schedule_batch.py | 45 +++++++++---- python/sglang/srt/managers/scheduler.py | 67 ++++++++----------- .../scheduler_output_processor_mixin.py | 35 ++++++---- .../sglang/srt/managers/tokenizer_manager.py | 47 ++++++++++--- python/sglang/srt/sampling/sampling_params.py | 2 + python/sglang/test/send_one.py | 16 ++++- test/srt/test_bench_serving.py | 2 +- test/srt/test_flashmla.py | 3 +- 10 files changed, 148 insertions(+), 84 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 45759c0be..f910d3a65 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -56,7 +56,7 @@ jobs: strategy: fail-fast: false matrix: - part: [0, 1, 2, 3, 4, 5, 6, 7] + part: [0, 1, 2, 3, 4, 5, 6, 7, 8] steps: - name: Checkout code uses: actions/checkout@v4 @@ -69,7 +69,7 @@ jobs: timeout-minutes: 30 run: | cd test/srt - python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 8 + python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9 unit-test-backend-2-gpu: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1ac29e8d5..6e4a1f3f5 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import ( from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( + AbortReq, CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, @@ -539,6 +540,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request): return Response(status_code=200) +@app.post("/abort_request") +async def abort_request(obj: AbortReq, request: Request): + """Abort a request.""" + try: + _global_state.tokenizer_manager.abort_request(rid=obj.rid) + return Response(status_code=200) + except Exception as e: + return _create_error_response(e) + + @app.post("/parse_function_call") async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request): """ diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ee9f40719..1a5865830 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,8 +1,5 @@ from __future__ import annotations -import hashlib -from enum import Enum, auto - # Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch It will be transformed from CPU scheduler to GPU model runner. - ForwardBatch is managed by `model_runner.py::ModelRunner`. It contains low-level tensor data. Most of the data consists of GPU tensors. + +TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future. 
""" import copy import dataclasses +import hashlib import logging import threading +from enum import Enum, auto from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union import numpy as np @@ -134,9 +135,9 @@ class FINISH_LENGTH(BaseFinishReason): class FINISH_ABORT(BaseFinishReason): - def __init__(self, message="Unknown error", status_code=None, err_type=None): + def __init__(self, message=None, status_code=None, err_type=None): super().__init__(is_error=True) - self.message = message + self.message = message or "Aborted" self.status_code = status_code self.err_type = err_type @@ -441,11 +442,13 @@ class Req: # Check finish self.tokenizer = None self.finished_reason = None + # Whether this request has finished output + self.finished_output = None # If we want to abort the request in the middle of the event loop, set this to true # Note: We should never set finished_reason in the middle, the req will get filtered and never respond self.to_abort = False # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop - self.to_abort_message: str = "Unknown error" + self.to_abort_message: str = None self.stream = stream self.eos_token_ids = eos_token_ids @@ -546,8 +549,6 @@ class Req: self.bootstrap_room: Optional[int] = bootstrap_room self.disagg_kv_sender: Optional[BaseKVSender] = None - # used for warmup because we don't have a pair yet when init - self.skip_kv_transfer: bool = False # the start index of the sent kv cache # We want to send it chunk by chunk for chunked prefill. # After every chunk forward, we do the following: @@ -555,15 +556,15 @@ class Req: # start_send_idx = len(req.fill_ids) self.start_send_idx: int = 0 - self.metadata_buffer_index: int = -1 - # The first output_id transferred from prefill instance. - self.transferred_output_id: Optional[int] = None - # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap # This is because kv is not ready in `process_prefill_chunk`. # We use `tmp_end_idx` to store the end index of the kv cache to send. self.tmp_end_idx: int = -1 + self.metadata_buffer_index: int = -1 + # The first output_id transferred from prefill instance. 
+ self.transferred_output_id: Optional[int] = None + @property def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) @@ -697,13 +698,29 @@ class Req: self.req_pool_idx = None self.already_computed = 0 + def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator): + token_indices = req_to_token_pool.req_to_token[ + self.req_pool_idx, : self.seqlen - 1 + ] + self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices) + + def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator): + token_indices = req_to_token_pool.req_to_token[ + self.req_pool_idx, : self.seqlen - 1 + ] + token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices) + del self.kv_cache_cpu + def __repr__(self): return ( f"Req(rid={self.rid}, " - f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})" + f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, " + f"{self.grammar=}, " + f"{self.sampling_params=})" ) +# Batch id bid = 0 @@ -1447,7 +1464,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and not self.reqs[i] in chunked_req_to_exclude + and self.reqs[i] not in chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 82f8c9ad8..601f3450a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -20,7 +20,6 @@ import signal import sys import threading import time -import warnings from collections import defaultdict, deque from concurrent import futures from dataclasses import dataclass @@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats -from sglang.srt.model_executor.forward_batch_info import ( - ForwardBatch, - ForwardMode, - PPProxyTensors, -) +from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -135,6 +130,7 @@ from sglang.srt.utils import ( broadcast_pyobj, configure_logger, crash_on_warnings, + disable_request_logging, get_bool_env_var, get_zmq_socket, kill_itself_when_parent_died, @@ -907,19 +903,6 @@ class Scheduler( fake_input_ids = [1] * seq_length recv_req.input_ids = fake_input_ids - # Handle custom logit processor passed to the request - custom_logit_processor = recv_req.custom_logit_processor - if ( - not self.server_args.enable_custom_logit_processor - and custom_logit_processor is not None - ): - logger.warning( - "The SGLang server is not configured to enable custom logit processor." - "The custom logit processor passed in will be ignored." - "Please set --enable-custom-logits-processor to enable this feature." 
-            )
-            custom_logit_processor = None
-
         if recv_req.bootstrap_port is None:
             # Use default bootstrap port
             recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +918,7 @@ class Scheduler(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
-            custom_logit_processor=custom_logit_processor,
+            custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
@@ -1246,9 +1229,7 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
@@ -1256,9 +1237,7 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
             self.enable_metrics
@@ -1774,24 +1753,27 @@ class Scheduler(
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
                     self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
-            f"{self.cur_batch.reqs=}, "
-            f"{self.token_to_kv_pool_allocator.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
-        )
-        # Wait for some time so that the parent process can print the error.
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
+                f"{self.cur_batch.reqs=}, "
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )
+
+        pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
         print(file=sys.stderr, flush=True)
         print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)

@@ -1923,25 +1905,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in sorted(to_del, reverse=True):
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
+            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-        for req in self.running_batch.reqs:
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index c87c1b264..859a5520b 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -15,6 +15,8 @@ if TYPE_CHECKING:
         Scheduler,
     )

+DEFAULT_FORCE_STREAM_INTERVAL = 50
+

 class SchedulerOutputProcessorMixin:
     """
@@ -512,19 +514,26 @@
             if self.model_config.is_multimodal_gen and req.to_abort:
                 continue

-            if (
-                req.finished()
-                # If stream, follow the given stream_interval
-                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                # always increase one-by-one.
-                or (
-                    not req.stream
-                    and len(req.output_ids) % 50 == 0
-                    and not self.model_config.is_multimodal_gen
-                )
-            ):
+            if req.finished():
+                if req.finished_output:
+                    # With the overlap schedule, a finished request can reach this branch twice
+                    # because of the one extra delayed token. This "continue" prevents a duplicate output.
+                    continue
+                req.finished_output = True
+                should_output = True
+            else:
+                if req.stream:
+                    stream_interval = (
+                        req.sampling_params.stream_interval or self.stream_interval
+                    )
+                    should_output = len(req.output_ids) % stream_interval == 0
+                else:
+                    should_output = (
+                        len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
+                        and not self.model_config.is_multimodal_gen
+                    )
+
+            if should_output:
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index f5f0f4187..dea49e9be 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -288,6 +288,7 @@ class TokenizerManager:
                 ),
                 self._handle_batch_output,
             ),
+            (AbortReq, self._handle_abort_req),
             (OpenSessionReqOutput, self._handle_open_session_req_output),
             (
                 UpdateWeightFromDiskReqOutput,
@@ -341,13 +342,14 @@
             ]
         )

+        # For PD disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        # for disaggregtion, start kv boostrap server on prefill
+        # Start the KV bootstrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -482,6 +484,14 @@
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logit-processor` to enable this feature."
+ ) sampling_params = SamplingParams(**obj.sampling_params) sampling_params.normalize(self.tokenizer) @@ -570,9 +580,9 @@ class TokenizerManager: tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], created_time: Optional[float] = None, ): + self.send_to_scheduler.send_pyobj(tokenized_obj) state = ReqState([], False, asyncio.Event(), obj, created_time=created_time) self.rid_to_state[obj.rid] = state - self.send_to_scheduler.send_pyobj(tokenized_obj) async def _wait_one_response( self, @@ -587,10 +597,11 @@ class TokenizerManager: await asyncio.wait_for(state.event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): + # Abort the request for disconnected requests (non-streaming, waiting queue) self.abort_request(obj.rid) + # Use exception to kill the whole call stack and asyncio task raise ValueError( - "Request is disconnected from the client side. " - f"Abort request {obj.rid}" + f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}" ) continue @@ -605,7 +616,6 @@ class TokenizerManager: else: msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}" logger.info(msg) - del self.rid_to_state[obj.rid] # Check if this was an abort/error created by scheduler if isinstance(out["meta_info"].get("finish_reason"), dict): @@ -625,10 +635,11 @@ class TokenizerManager: yield out else: if request is not None and await request.is_disconnected(): + # Abort the request for disconnected requests (non-streaming, running) self.abort_request(obj.rid) + # Use exception to kill the whole call stack and asyncio task raise ValueError( - "Request is disconnected from the client side. " - f"Abort request {obj.rid}" + f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}" ) async def _handle_batch_request( @@ -728,7 +739,6 @@ class TokenizerManager: def abort_request(self, rid: str): if rid not in self.rid_to_state: return - del self.rid_to_state[rid] req = AbortReq(rid) self.send_to_scheduler.send_pyobj(req) @@ -964,7 +974,7 @@ class TokenizerManager: def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. async def abort_request(): - await asyncio.sleep(1) + await asyncio.sleep(2) if obj.is_single: self.abort_request(obj.rid) else: @@ -1035,6 +1045,9 @@ class TokenizerManager: for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: + logger.error( + f"Received output for {rid=} but the state was deleted in TokenizerManager." 
+ ) continue # Build meta_info and return value @@ -1098,6 +1111,7 @@ class TokenizerManager: meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] state.finished_time = time.time() meta_info["e2e_latency"] = state.finished_time - state.created_time + del self.rid_to_state[rid] state.out_list.append(out_dict) state.event.set() @@ -1246,6 +1260,9 @@ class TokenizerManager: # Schedule the task to run in the background without awaiting it asyncio.create_task(asyncio.to_thread(background_task)) + def _handle_abort_req(self, recv_obj): + self.rid_to_state.pop(recv_obj.rid) + def _handle_open_session_req_output(self, recv_obj): self.session_futures[recv_obj.session_id].set_result( recv_obj.session_id if recv_obj.success else None @@ -1325,3 +1342,15 @@ class _Communicator(Generic[T]): self._result_values.append(recv_obj) if len(self._result_values) == self._fan_out: self._result_event.set() + + +# Note: request abort handling logic +# We should handle all of the following cases correctly. +# +# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state | +# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- | +# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req | +# | http | yes | running | background task | fast api | del in _handle_batch_output | +# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req | +# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output | +# diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 7c77a204f..4c505fe7a 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -50,6 +50,7 @@ class SamplingParams: spaces_between_special_tokens: bool = True, no_stop_trim: bool = False, custom_params: Optional[Dict[str, Any]] = None, + stream_interval: Optional[int] = None, ) -> None: self.max_new_tokens = max_new_tokens self.stop_strs = stop @@ -75,6 +76,7 @@ class SamplingParams: self.spaces_between_special_tokens = spaces_between_special_tokens self.no_stop_trim = no_stop_trim self.custom_params = custom_params + self.stream_interval = stream_interval # Process some special cases if 0 <= self.temperature < _SAMPLING_EPS: diff --git a/python/sglang/test/send_one.py b/python/sglang/test/send_one.py index 0ca4fe3a1..f3f542a34 100644 --- a/python/sglang/test/send_one.py +++ b/python/sglang/test/send_one.py @@ -27,6 +27,7 @@ class BenchArgs: "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:" ) image: bool = False + many_images: bool = False stream: bool = False @staticmethod @@ -48,6 +49,7 @@ class BenchArgs: parser.add_argument("--return-logprob", action="store_true") parser.add_argument("--prompt", type=str, default=BenchArgs.prompt) parser.add_argument("--image", action="store_true") + parser.add_argument("--many-images", action="store_true") parser.add_argument("--stream", action="store_true") @classmethod @@ -62,6 +64,17 @@ def send_one_prompt(args): "Human: Describe this image in a very short sentence.\n\nAssistant:" ) image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + elif args.many_images: + args.prompt = ( + "Human: I have one reference image and many images." 
+ "Describe their relationship in a very short sentence.\n\nAssistant:" + ) + image_data = [ + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + ] else: image_data = None @@ -74,9 +87,6 @@ def send_one_prompt(args): "Write in a format of json.\nAssistant:" ) json_schema = "$$ANY$$" - json_schema = ( - '{"type": "object", "properties": {"population": {"type": "integer"}}}' - ) else: json_schema = None diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index d86f2d81b..7aaa7ab7c 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -190,7 +190,7 @@ class TestBenchServing(CustomTestCase): f"### test_vlm_online_latency\n" f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n' ) - self.assertLess(res["median_e2e_latency_ms"], 16000) + self.assertLess(res["median_e2e_latency_ms"], 16500) if os.getenv("SGLANG_AMD_CI") == "1": self.assertLess(res["median_ttft_ms"], 150) # TODO: not set yet, need AMD machine diff --git a/test/srt/test_flashmla.py b/test/srt/test_flashmla.py index f546322a7..b80762465 100644 --- a/test/srt/test_flashmla.py +++ b/test/srt/test_flashmla.py @@ -3,7 +3,6 @@ Usage: python3 test/srt/test_flashmla.py """ -import os import unittest from types import SimpleNamespace @@ -61,7 +60,7 @@ class TestFlashMLAAttnBackend(unittest.TestCase): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["accuracy"], 0.60) class TestFlashMLAAttnLatency(unittest.TestCase):