[Misc] Fix metrics, weight update lock, request logging (#2543)
python/sglang/srt/aio_rwlock.py (new file, 94 lines)
@@ -0,0 +1,94 @@
+import asyncio
+
+
+class RWLock:
+    """
+    A Read-Write Lock for asyncio:
+      - Multiple readers can hold the lock in parallel if no writer holds it.
+      - A writer has exclusive access.
+    """
+
+    def __init__(self):
+        self._readers = 0  # How many readers currently hold the lock
+        self._writer_active = False
+        self._lock = asyncio.Lock()  # Internal mutex to protect state
+        # Conditions associated with _lock:
+        self._readers_ok = asyncio.Condition(self._lock)  # Notify blocked readers
+        self._writers_ok = asyncio.Condition(self._lock)  # Notify blocked writers
+
+        # Expose two async context-manager helpers:
+        self.reader_lock = self._ReaderLock(self)
+        self.writer_lock = self._WriterLock(self)
+
+    async def _acquire_reader(self):
+        """
+        Wait until there is no active writer.
+        Then increment the count of active readers.
+        """
+        async with self._lock:
+            # If a writer is active, wait until it's done.
+            while self._writer_active:
+                await self._readers_ok.wait()
+            self._readers += 1
+
+    async def _release_reader(self):
+        """
+        Decrement the count of active readers.
+        If this was the last active reader, wake up a possible waiting writer.
+        """
+        async with self._lock:
+            self._readers -= 1
+            # If no more readers, a writer could proceed.
+            if self._readers == 0:
+                self._writers_ok.notify()
+
+    async def _acquire_writer(self):
+        """
+        Wait until there is no active writer and no active readers.
+        Then mark a writer as active.
+        """
+        async with self._lock:
+            while self._writer_active or self._readers > 0:
+                await self._writers_ok.wait()
+            self._writer_active = True
+
+    async def _release_writer(self):
+        """
+        Mark the writer as done and notify readers and writers.
+        """
+        async with self._lock:
+            self._writer_active = False
+            # Allow any waiting readers to proceed:
+            self._readers_ok.notify_all()
+            # Allow next waiting writer to proceed:
+            self._writers_ok.notify()
+
+    class _ReaderLock:
+        """
+        A simple async context manager that acquires a reader lock
+        on entering and releases it on exit.
+        """
+
+        def __init__(self, parent: "RWLock"):
+            self._parent = parent
+
+        async def __aenter__(self):
+            await self._parent._acquire_reader()
+
+        async def __aexit__(self, exc_type, exc_val, exc_tb):
+            await self._parent._release_reader()
+
+    class _WriterLock:
+        """
+        A simple async context manager that acquires a writer lock
+        on entering and releases it on exit.
+        """
+
+        def __init__(self, parent: "RWLock"):
+            self._parent = parent
+
+        async def __aenter__(self):
+            await self._parent._acquire_writer()
+
+        async def __aexit__(self, exc_type, exc_val, exc_tb):
+            await self._parent._release_writer()
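For context on how this lock is used further down in this diff: every generate request holds reader_lock (many in parallel), while a weight update takes writer_lock, so it waits for in-flight requests to drain and then runs exclusively. A minimal usage sketch under that assumption (the coroutines below are illustrative, not part of the commit):

    import asyncio

    from sglang.srt.aio_rwlock import RWLock

    async def main():
        rwlock = RWLock()

        async def handle_request(i: int):
            async with rwlock.reader_lock:  # many requests proceed concurrently
                await asyncio.sleep(0.1)
                print(f"request {i} done")

        async def update_weights():
            async with rwlock.writer_lock:  # waits for readers, then runs alone
                print("weights updated exclusively")

        await asyncio.gather(*(handle_request(i) for i in range(3)), update_weights())

    asyncio.run(main())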
python/sglang/srt/configs/model_config.py
@@ -124,8 +124,12 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

+        # Verify quantization
+        self._verify_quantization()
+
+        # Multimodal attrs
+        self.image_token_id = getattr(self.hf_config, "image_token_id", None)

     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,11 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_flashinfer_available,
-    should_use_tensor_core,
-)
+from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -731,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
             mask=mask,
         )
         tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
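This function moves here from sglang.srt.utils (it is removed from that module in the utils hunk at the end of this diff). In short: an explicit SGLANG_FLASHINFER_USE_TENSOR_CORE setting wins; otherwise on flashinfer <= 0.1.6 the library's compiled-kernel probe decides; otherwise fp8 KV caches always use tensor cores and fp16/bf16 caches use them when the GQA group size exceeds 4. A hypothetical check of the fallback heuristic (assuming the function is in scope, no environment override is set, and the flashinfer probe is unavailable):

    import torch

    # 32 query heads vs 8 KV heads -> GQA group size 8 > 4 -> tensor cores
    print(should_use_tensor_core(torch.float16, 32, 8))   # True
    # MHA (group size 1) with fp16 -> no tensor cores
    print(should_use_tensor_core(torch.float16, 32, 32))  # False
    # fp8 KV cache -> always tensor cores
    print(should_use_tensor_core(torch.float8_e4m3fn, 32, 32))  # True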
python/sglang/srt/managers/schedule_batch.py
@@ -479,8 +479,22 @@ class Req:
         return True

+    def reset_for_retract(self):
+        self.prefix_indices = []
+        self.last_node = None
+        self.extend_input_len = 0
+        self.is_retracted = True
+
+        # For incremental logprobs
+        # TODO: Fix the `logprob_start_len`
+        self.last_update_decode_tokens = 0
+        self.logprob_start_len = 10**9
+
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+        return (
+            f"rid(n={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+        )


 bid = 0
@@ -894,15 +908,7 @@ class ScheduleBatch:
             )
             residual_size = max(0, residual_size)
             self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

-            req.prefix_indices = []
-            req.last_node = None
-            req.extend_input_len = 0
-            req.is_retracted = True
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
+            req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)
python/sglang/srt/managers/scheduler.py
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import List, Optional
+from typing import Callable, Dict, List, Optional, Tuple

 import psutil
 import setproctitle
@@ -260,7 +260,7 @@ class Scheduler:
         self.current_stream = torch.get_device_module(self.device).current_stream()

         # Session info
-        self.sessions = {}
+        self.sessions: Dict[str, Session] = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
python/sglang/srt/managers/tokenizer_manager.py
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import uvloop
@@ -30,6 +30,7 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

+from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket, kill_process_tree
+from sglang.srt.utils import (
+    dataclass_to_string_truncated,
+    get_zmq_socket,
+    kill_process_tree,
+)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -82,6 +87,9 @@ class ReqState:
     created_time: float
     first_token_time: Optional[float] = None

+    # For streaming output
+    last_output_offset: int = 0
+

 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
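The new last_output_offset field supports incremental streaming: rather than re-sending the whole accumulated output each time, the manager can slice off only the part after the last offset. A hedged standalone sketch of that pattern (not the commit's code; TokenizerManager wires this into its response loop):

    full_text = ""
    last_output_offset = 0

    def next_delta(new_full_text: str) -> str:
        # Return only the not-yet-streamed suffix and advance the offset.
        global full_text, last_output_offset
        full_text = new_full_text
        delta = full_text[last_output_offset:]
        last_output_offset = len(full_text)
        return delta

    assert next_delta("Hello") == "Hello"
    assert next_delta("Hello, world") == ", world"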
@@ -120,6 +128,7 @@ class TokenizerManager:
         self.is_generation = self.model_config.is_generation
         self.context_len = self.model_config.context_len
+        self.image_token_id = self.model_config.image_token_id

         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -152,9 +161,12 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

-        # For update model weights
-        self.model_update_lock = asyncio.Lock()
-        self.model_update_result = None
+        # The event to notify the weight sync is finished.
+        self.model_update_lock = RWLock()
+        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+            None
+        )
+        self.asyncio_tasks = set()

         # For session info
         self.session_futures = {}  # session_id -> asyncio event
@@ -181,9 +193,6 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()

-        while self.model_update_lock.locked():
-            await asyncio.sleep(0.001)
-
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
                 "This model does not appear to be an embedding model by default. "
@@ -191,17 +200,24 @@ class TokenizerManager:
             )

         obj.normalize_batch_and_arguments()
-        is_single = obj.is_single
-        if is_single:
-            tokenized_obj = await self._tokenize_one_request(obj)
-            self.send_to_scheduler.send_pyobj(tokenized_obj)
-            async for response in self._wait_one_response(obj, request, created_time):
-                yield response
-        else:
-            async for response in self._handle_batch_request(
-                obj, request, created_time
-            ):
-                yield response
+
+        if self.server_args.log_requests:
+            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+        async with self.model_update_lock.reader_lock:
+            is_single = obj.is_single
+            if is_single:
+                tokenized_obj = await self._tokenize_one_request(obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                async for response in self._wait_one_response(
+                    obj, request, created_time
+                ):
+                    yield response
+            else:
+                async for response in self._handle_batch_request(
+                    obj, request, created_time
+                ):
+                    yield response

     async def _tokenize_one_request(
         self,
@@ -215,7 +231,7 @@ class TokenizerManager:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
                     "input_embeds is provided while disable_radix_cache is False. "
-                    "Please add `--disable-radix-cach` when you launch the server "
+                    "Please add `--disable-radix-cache` when you launch the server "
                     "if you want to use input_embeds as inputs."
                 )
             input_embeds = obj.input_embeds
@@ -301,8 +317,8 @@ class TokenizerManager:
                 state.out_list = []
                 if state.finished:
                     if self.server_args.log_requests:
-                        # Log requests
-                        logger.info(f"in={obj}, out={out}")
+                        msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                        logger.info(msg)
                     del self.rid_to_state[obj.rid]
                     yield out
                     break
@@ -423,55 +439,52 @@ class TokenizerManager:
         self,
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)

-        if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                # wait for the previous generation requests to finish
-                for i in range(3):
-                    while len(self.rid_to_state) > 0:
-                        await asyncio.sleep(0.001)
-                    # FIXME: We add some sleep here to avoid some race conditions.
-                    # We can use a read-write lock as a better fix.
-                    await asyncio.sleep(0.01)
-                self.send_to_scheduler.send_pyobj(obj)
-                self.model_update_result = asyncio.Future()
-
-                if self.server_args.dp_size == 1:
-                    result = await self.model_update_result
-                    if result.success:
-                        self.server_args.model_path = obj.model_path
-                        self.server_args.load_format = obj.load_format
-                        self.model_path = obj.model_path
-                    return result.success, result.message
-                else:  # self.server_args.dp_size > 1
-                    self.model_update_tmp = []
-                    result = await self.model_update_result
-
-                    all_success = all([r.success for r in result])
-                    if all_success is True:
-                        self.server_args.model_path = obj.model_path
-                        self.server_args.load_format = obj.load_format
-                        self.model_path = obj.model_path
-                    all_message = [r.message for r in result]
-                    all_message = " | ".join(all_message)
-                    return all_success, all_message
-        else:
-            return False, "Another update is in progress. Please try again later."
+        if True:
+            # Hold the lock if it is not async. This means that weight sync
+            # cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)
+
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message

     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> bool:
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
         self.send_to_scheduler.send_pyobj(obj)
@@ -487,25 +500,22 @@ class TokenizerManager:
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

-        if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                self.send_to_scheduler.send_pyobj(obj)
-                self.parameter_update_result = asyncio.Future()
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                result = await self.parameter_update_result
-                return result.success, result.message
-        else:
-            logger.error("Another parameter update is in progress in tokenizer manager")
-            return (
-                False,
-                "Another parameter update is in progress. Please try again later.",
-            )
+        # Hold the writer lock so that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            self.send_to_scheduler.send_pyobj(obj)
+            self.parameter_update_result: Awaitable[
+                UpdateWeightsFromDistributedReqOutput
+            ] = asyncio.Future()
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for update weights from distributed"
+            result = await self.parameter_update_result
+            return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -564,11 +574,11 @@ class TokenizerManager:
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        loop.create_task(self.handle_loop())
+        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

         signal_handler = SignalHandler(self)
         loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        loop.create_task(self.sigterm_watchdog())
+        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
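Keeping the created tasks in self.asyncio_tasks is not cosmetic: the event loop holds only weak references to tasks, so a fire-and-forget task with no other referent can be garbage collected before it finishes. A minimal standalone illustration of the pattern (not commit code):

    import asyncio

    background_tasks = set()

    async def watchdog():
        await asyncio.sleep(1.0)

    async def main():
        task = asyncio.create_task(watchdog())
        background_tasks.add(task)  # strong reference keeps the task alive
        task.add_done_callback(background_tasks.discard)  # clean up when done
        await asyncio.sleep(2.0)

    asyncio.run(main())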
python/sglang/srt/mem_cache/memory_pool.py
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
     ):
         super().__init__(size, dtype, device)
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
+        self._create_buffers()
+
+    def _create_buffers(self):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]

+    def _clear_buffers(self):
+        del self.k_buffer
+        del self.v_buffer
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
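Moving the allocation into _create_buffers (with shapes read from self) pairs with the new _clear_buffers, so the pool can release its KV cache and later re-create it with identical geometry. For a rough sense of scale, the pool holds 2 x layer_num tensors of shape (size + 1, head_num, head_dim); a back-of-the-envelope check with hypothetical 7B-class numbers and fp16:

    # Hypothetical sizing; 2 bytes per fp16 element, K and V buffers per layer.
    size, head_num, head_dim, layer_num = 100_000, 32, 128, 32
    total_bytes = 2 * layer_num * (size + 1) * head_num * head_dim * 2
    print(f"{total_bytes / 1e9:.1f} GB")  # ~52.4 GB for 100k cached tokens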
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):


 class MLATokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):


 class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
python/sglang/srt/server.py
@@ -311,6 +311,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
+        logger.error(f"Error: {e}")
         return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
python/sglang/srt/utils.py
@@ -14,6 +14,7 @@
 """Common utilities."""

 import base64
+import dataclasses
 import ipaddress
 import itertools
 import json
@@ -1238,49 +1239,37 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


-def should_use_tensor_core(
-    kv_cache_dtype: torch.dtype,
-    num_attention_heads: int,
-    num_kv_heads: int,
-) -> bool:
-    """
-    Determine whether to use tensor cores for attention computation.
-
-    Args:
-        kv_cache_dtype: Data type of the KV cache
-        num_attention_heads: Number of attention heads
-        num_kv_heads: Number of key/value heads
-
-    Returns:
-        bool: Whether to use tensor cores
-    """
-    # Try to use environment variable first
-    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
-    if env_override is not None:
-        return env_override.lower() == "true"
-
-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
-        else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
-    # Calculate GQA group size
-    gqa_group_size = num_attention_heads // num_kv_heads
-
-    # Determine based on dtype and GQA group size
-    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        return True
-    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-        return gqa_group_size > 4
-    else:
-        return False
+def dataclass_to_string_truncated(data, max_length=2048):
+    if isinstance(data, str):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return f'"{data[:half_length]} ... {data[-half_length:]}"'
+        else:
+            return f'"{data}"'
+    elif isinstance(data, (list, tuple)):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return str(data[:half_length]) + " ... " + str(data[-half_length:])
+        else:
+            return str(data)
+    elif isinstance(data, dict):
+        return (
+            "{"
+            + ", ".join(
+                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                for k, v in data.items()
+            )
+            + "}"
+        )
+    elif dataclasses.is_dataclass(data):
+        fields = dataclasses.fields(data)
+        return (
+            f"{data.__class__.__name__}("
+            + ", ".join(
+                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
+                for f in fields
+            )
+            + ")"
+        )
+    else:
+        return str(data)
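This helper is what the new Receive/Finish request logs call, so large prompts and outputs are elided rather than dumped whole. A quick illustration with a deliberately tiny max_length (values chosen for demonstration only):

    print(dataclass_to_string_truncated("a" * 10, max_length=8))
    # "aaaa ... aaaa"
    print(dataclass_to_string_truncated({"ids": list(range(5)), "text": "hi"}))
    # {ids: [0, 1, 2, 3, 4], text: "hi"}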