Sync from v0.13
This commit is contained in:
10
tests/vllm_test_utils/setup.py
Normal file
10
tests/vllm_test_utils/setup.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name="vllm_test_utils",
|
||||
version="0.1",
|
||||
packages=["vllm_test_utils"],
|
||||
)
|
||||
11
tests/vllm_test_utils/vllm_test_utils/__init__.py
Normal file
11
tests/vllm_test_utils/vllm_test_utils/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
vllm_utils is a package for vLLM testing utilities.
|
||||
It does not import any vLLM modules.
|
||||
"""
|
||||
|
||||
from .blame import BlameResult, blame
|
||||
from .monitor import MonitoredValues, monitor
|
||||
|
||||
__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
|
||||
56
tests/vllm_test_utils/vllm_test_utils/blame.py
Normal file
56
tests/vllm_test_utils/vllm_test_utils/blame.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import sys
|
||||
import traceback
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BlameResult:
|
||||
found: bool = False
|
||||
trace_stack: str = ""
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def blame(func: Callable) -> Generator[BlameResult, None, None]:
|
||||
"""
|
||||
Trace the function calls to find the first function that satisfies the
|
||||
condition. The trace stack will be stored in the result.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
with blame(lambda: some_condition()) as result:
|
||||
# do something
|
||||
|
||||
if result.found:
|
||||
print(result.trace_stack)
|
||||
"""
|
||||
result = BlameResult()
|
||||
|
||||
def _trace_calls(frame, event, arg=None):
|
||||
nonlocal result
|
||||
if event in ["call", "return"]:
|
||||
# for every function call or return
|
||||
try:
|
||||
# Temporarily disable the trace function
|
||||
sys.settrace(None)
|
||||
# check condition here
|
||||
if not result.found and func():
|
||||
result.found = True
|
||||
result.trace_stack = "".join(traceback.format_stack())
|
||||
# Re-enable the trace function
|
||||
sys.settrace(_trace_calls)
|
||||
except NameError:
|
||||
# modules are deleted during shutdown
|
||||
pass
|
||||
return _trace_calls
|
||||
|
||||
try:
|
||||
sys.settrace(_trace_calls)
|
||||
yield result
|
||||
finally:
|
||||
sys.settrace(None)
|
||||
75
tests/vllm_test_utils/vllm_test_utils/monitor.py
Normal file
75
tests/vllm_test_utils/vllm_test_utils/monitor.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import sys
|
||||
import traceback
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MonitoredValues(Generic[_T]):
|
||||
values: list[_T] = dataclasses.field(default_factory=list)
|
||||
trace_stacks: list[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def monitor(
|
||||
measure_func: Callable[[], _T],
|
||||
) -> Generator[MonitoredValues[_T], None, None]:
|
||||
"""
|
||||
Trace the function calls to continuously monitor the change of
|
||||
a value.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
def measure_func():
|
||||
... # measure the current value
|
||||
return current_value
|
||||
|
||||
|
||||
with monitor(measure_func) as monitored_values:
|
||||
# do something
|
||||
|
||||
monitored_values.values # all changes of the values
|
||||
monitored_values.trace_stacks # trace stacks of every change
|
||||
```
|
||||
"""
|
||||
monitored_values = MonitoredValues[_T]()
|
||||
|
||||
def _trace_calls(frame, event, arg=None):
|
||||
nonlocal monitored_values
|
||||
if event in ["line"]:
|
||||
# triggered by every line of Python code.
|
||||
# only Python functions will trigger it,
|
||||
# c/cpp functions will not trigger it.
|
||||
try:
|
||||
# Temporarily disable the trace function
|
||||
sys.settrace(None)
|
||||
# do a measurement
|
||||
current_value = measure_func()
|
||||
if (
|
||||
len(monitored_values.values) == 0
|
||||
or current_value != monitored_values.values[-1]
|
||||
):
|
||||
monitored_values.values.append(current_value)
|
||||
monitored_values.trace_stacks.append(
|
||||
"".join(traceback.format_stack())
|
||||
)
|
||||
# Re-enable the trace function
|
||||
sys.settrace(_trace_calls)
|
||||
except NameError:
|
||||
# modules are deleted during shutdown
|
||||
pass
|
||||
return _trace_calls
|
||||
|
||||
try:
|
||||
sys.settrace(_trace_calls)
|
||||
yield monitored_values
|
||||
finally:
|
||||
sys.settrace(None)
|
||||
Reference in New Issue
Block a user