Sync from v0.13
This commit is contained in:
@@ -1,50 +0,0 @@
|
||||
"""vllm.entrypoints.api_server with some extra logging for testing."""
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import uvicorn
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
import vllm.entrypoints.api_server
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
||||
app = vllm.entrypoints.api_server.app
|
||||
|
||||
|
||||
class AsyncLLMEngineWithStats(AsyncLLMEngine):
    """AsyncLLMEngine that counts aborted requests for test assertions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of requests aborted so far; exposed via testing_stats().
        self._num_aborts = 0

    async def abort(self, request_id: str) -> None:
        """Abort the request, then bump the abort counter."""
        await super().abort(request_id)
        self._num_aborts += 1

    def testing_stats(self) -> Dict[str, Any]:
        """Return test-only statistics (currently just the abort count)."""
        return {"num_aborted_requests": self._num_aborts}
|
||||
|
||||
|
||||
@app.get("/stats")
def stats() -> Response:
    """Get the statistics of the engine."""
    # `engine` is the module-level AsyncLLMEngineWithStats assigned in
    # the __main__ block below.
    return JSONResponse(engine.testing_stats())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    # Add the engine CLI flags (model, tokenizer pool size, ray options, ...).
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
    # Share the engine with the wrapped api_server module so its routes
    # (e.g. /generate) use the same instance as our /stats route.
    vllm.entrypoints.api_server.engine = engine
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
|
||||
@@ -1,108 +0,0 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _query_server(prompt: str, max_tokens: int = 5) -> dict:
    """POST a deterministic generation request to the local test server.

    Raises for non-2xx responses; returns the decoded JSON body.
    """
    payload = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
        "ignore_eos": True,
    }
    response = requests.post("http://localhost:8000/generate", json=payload)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
def _query_server_long(prompt: str) -> dict:
    """Issue a long (500-token) generation request, used for cancellation."""
    return _query_server(prompt, max_tokens=500)
|
||||
|
||||
|
||||
@pytest.fixture
def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
               worker_use_ray: bool):
    """Launch the stats-enabled API server in a subprocess for one test.

    The server is always torn down — terminated and reaped — even when the
    test body raises, so a failing test cannot leak the server process.
    """
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    commands = [
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m", "--host",
        "127.0.0.1", "--tokenizer-pool-size",
        str(tokenizer_pool_size)
    ]
    if engine_use_ray:
        commands.append("--engine-use-ray")
    if worker_use_ray:
        commands.append("--worker-use-ray")
    uvicorn_process = subprocess.Popen(commands)
    try:
        yield
    finally:
        # terminate() alone leaves a zombie until the parent exits; wait()
        # reaps the child so repeated test runs do not accumulate processes.
        uvicorn_process.terminate()
        uvicorn_process.wait()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
@pytest.mark.parametrize("worker_use_ray", [False, True])
@pytest.mark.parametrize("engine_use_ray", [False, True])
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
                    engine_use_ray: bool):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["warm up"] * 1
        result = None
        while not result:
            try:
                for r in pool.map(_query_server, prompts):
                    result = r
                    break
            except requests.exceptions.ConnectionError:
                # Server not up yet; retry after a short pause.
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

    with Pool(32) as pool:
        # Cancel requests: terminating the worker pool drops the HTTP
        # connections mid-request, which the server should count as aborts.
        prompts = ["canceled requests"] * 100
        pool.map_async(_query_server_long, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # check cancellation stats
        # give it some times to update the stats
        time.sleep(1)

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["test prompt after canceled"] * 100
        for result in pool.map(_query_server, prompts):
            assert result
|
||||
@@ -1,96 +0,0 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
||||
|
||||
@dataclass
class RequestOutput:
    """Minimal stand-in for vllm's RequestOutput, used by MockEngine."""
    # Note: tests pass string ids here despite the int annotation.
    request_id: int
    finished: bool = False
|
||||
|
||||
|
||||
class MockEngine:
    """In-process fake of the underlying LLM engine.

    Records how often each entry point is invoked so tests can assert on
    the async engine's background-loop behavior.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None

    async def step_async(self):
        """Return one fake output while a request is active, else nothing."""
        self.step_calls += 1
        if self.request_id:
            return [RequestOutput(request_id=self.request_id)]
        return []

    async def encode_request_async(self, *args, **kwargs):
        pass

    def generate(self, request_id):
        """Mark `request_id` as the currently generating request."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so step_async yields nothing."""
        self.request_id = None

    def add_request(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1

    async def add_request_async(self, **kwargs):
        self.add_request_calls += 1

    def abort_request(self, request_id):
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        return self.request_id is not None
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
    """AsyncLLMEngine whose inner engine is replaced by a MockEngine."""

    def _init_engine(self, *args, **kwargs):
        # Ignore the real engine construction arguments entirely.
        return MockEngine()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_new_requests_event():
    """The background loop should step only while requests are pending."""
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # No requests yet: the loop must be idle.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    # Make the mock report an unfinished request so the loop keeps stepping.
    engine.engine.generate("2")
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    # Once generation stops, the loop must go idle again.
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # Ray-backed construction still exposes the config accessors.
    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None
|
||||
@@ -1,134 +0,0 @@
|
||||
import os
|
||||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
|
||||
__file__))).parent.parent / "examples/template_chatml.jinja"
|
||||
assert chatml_jinja_path.exists()
|
||||
|
||||
# Define models, templates, and their corresponding expected outputs
|
||||
MODEL_TEMPLATE_GENERATON_OUTPUT = [
|
||||
("facebook/opt-125m", None, True,
|
||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||
("facebook/opt-125m", None, False,
|
||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"""),
|
||||
("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of""")
|
||||
]
|
||||
|
||||
TEST_MESSAGES = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'Hi there!'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'What is the capital of'
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class MockTokenizer:
    """Tokenizer stub: only carries the chat_template attribute."""
    # Deliberately unannotated so this stays a class attribute (not a
    # dataclass field); MockTokenizer() takes no arguments.
    chat_template = None
|
||||
|
||||
|
||||
@dataclass
class MockServingChat:
    """Serving-chat stub exposing just the tokenizer attribute."""
    tokenizer: MockTokenizer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_chat_template():
    """Loading the chatml file populates tokenizer.chat_template exactly."""
    # Testing chatml template
    tokenizer = MockTokenizer()
    mock_serving_chat = MockServingChat(tokenizer)
    await OpenAIServingChat._load_chat_template(
        mock_serving_chat, chat_template=chatml_jinja_path)

    template_content = tokenizer.chat_template

    # Test assertions
    assert template_content is not None
    # Hard coded value for template_chatml.jinja
    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_load_chat_template_filelike():
    """A path-like template string that does not exist must raise ValueError."""
    # Testing chatml template
    template = "../../examples/does_not_exist"
    tokenizer = MockTokenizer()

    mock_serving_chat = MockServingChat(tokenizer)

    with pytest.raises(ValueError, match="looks like a file path"):
        await OpenAIServingChat._load_chat_template(mock_serving_chat,
                                                    chat_template=template)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_load_chat_template_literallike():
    """A literal Jinja string is stored on the tokenizer unchanged."""
    # Testing chatml template
    template = "{{ messages }}"
    tokenizer = MockTokenizer()

    mock_serving_chat = MockServingChat(tokenizer)
    await OpenAIServingChat._load_chat_template(mock_serving_chat,
                                                chat_template=template)
    template_content = tokenizer.chat_template

    assert template_content == template
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model,template,add_generation_prompt,expected_output",
    MODEL_TEMPLATE_GENERATON_OUTPUT)
async def test_get_gen_prompt(model, template, add_generation_prompt,
                              expected_output):
    """apply_chat_template with a loaded template yields the expected prompt."""
    # Initialize the tokenizer
    tokenizer = get_tokenizer(tokenizer_name=model)
    mock_serving_chat = MockServingChat(tokenizer)
    # template may be None (use the model's default) or a file path.
    await OpenAIServingChat._load_chat_template(mock_serving_chat,
                                                chat_template=template)

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
        model=model,
        messages=TEST_MESSAGES,
        add_generation_prompt=add_generation_prompt)

    # Call the function and get the result
    result = tokenizer.apply_chat_template(
        conversation=mock_request.messages,
        tokenize=False,
        add_generation_prompt=mock_request.add_generation_prompt)

    # Test assertion
    assert result == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}")
|
||||
@@ -1,157 +0,0 @@
|
||||
# imports for guided decoding tests
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
# using Ray for overall ease of process management, parallel requests,
|
||||
# and debugging.
|
||||
import ray
|
||||
import requests
|
||||
|
||||
MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
class ServerRunner:
    """Runs the OpenAI-compatible API server in a subprocess on one GPU.

    The constructor blocks until the server's /health endpoint responds
    or the startup timeout elapses.
    """

    def __init__(self, args):
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check
        start = time.time()
        # Keep the last connection failure for exception chaining below.
        # The bare except-target `err` is unbound outside its except block
        # (Python deletes it), so referencing it at timeout raised NameError.
        last_err = None
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err
                last_err = err

            time.sleep(0.5)
            if time.time() - start > MAX_SERVER_START_WAIT_S:
                raise RuntimeError(
                    "Server failed to start in time.") from last_err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def server():
    """Start a Ray-managed API server once for the whole test session."""
    ray.init()
    try:
        server_runner = ServerRunner.remote([
            "--model",
            MODEL_NAME,
            # use half precision for speed and memory savings in CI environment
            "--dtype",
            "float16",
            "--max-model-len",
            "2048",
            "--enforce-eager",
            "--engine-use-ray"
        ])
        ray.get(server_runner.ready.remote())
        yield server_runner
    finally:
        # Always tear Ray down, even if startup or the tests fail, so a
        # later session can call ray.init() again cleanly.
        ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def client():
    # Async client pointed at the locally launched server; the API key is a
    # dummy value (auth is not enforced by the test server).
    client = openai.AsyncOpenAI(
        base_url="http://localhost:8000/v1",
        api_key="token-abc123",
    )
    yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI):
    """The server should report MODEL_NAME as its served model."""
    models = await client.models.list()
    models = models.data
    served_model = models[0]
    assert served_model.id == MODEL_NAME
    assert all(model.root == MODEL_NAME for model in models)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_completion(server, client: openai.AsyncOpenAI):
    """Basic completion: text length, finish reason, and token accounting."""
    completion = await client.completions.create(model=MODEL_NAME,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    assert completion.choices[0].finish_reason == "length"
    # 6 prompt tokens assumes this exact prompt + tokenizer; 5 completion
    # tokens because max_tokens is hit (finish_reason == "length").
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
    """Chat completion with logprobs, then a follow-up multi-turn request."""
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": "what is 1+1?"
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert chat_completion.id is not None
    assert chat_completion.choices is not None and len(
        chat_completion.choices) == 1
    assert chat_completion.choices[0].message is not None
    assert chat_completion.choices[0].logprobs is not None
    assert chat_completion.choices[0].logprobs.top_logprobs is not None
    # top_logprobs=5 was requested, so each position must carry 5 entries.
    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
|
||||
@@ -1,67 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import RequestTracker
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_tracker():
    """RequestTracker signals new requests and propagates finish/abort state."""
    tracker = RequestTracker()
    stream_1 = tracker.add_request("1")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    # Fetching new requests clears the event.
    assert not tracker.new_requests_event.is_set()
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
    assert not stream_1.finished

    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.is_set()
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
    assert not finished
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")
    assert not tracker.new_requests_event.is_set()

    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "1" in finished
    assert not new
    assert stream_1.finished

    # Abort a request before it is fetched: it must surface as finished,
    # never as a new request.
    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
    assert not new
    assert stream_4.finished

    # A finished engine output marks the corresponding stream as finished.
    stream_5 = tracker.add_request("5")
    assert tracker.new_requests_event.is_set()
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], [], finished=True))
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.is_set()
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
    assert new[0]["request_id"] == "5"
    assert stream_2.finished
    assert not stream_5.finished
|
||||
@@ -1,50 +1,234 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
import weakref
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
|
||||
from ..conftest import HfRunner, VllmRunner
|
||||
from ..models.utils import check_outputs_equal
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"hmellor/tiny-random-Gemma2ForCausalLM",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
]
|
||||
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
|
||||
|
||||
|
||||
def test_vllm_gc_ed():
    """Verify vllm instance is GC'ed when it is deleted"""
    llm = LLM("hmellor/tiny-random-LlamaForCausalLM")
    # A weakref does not keep the LLM alive, so it goes dead iff the
    # instance is actually collected after `del`.
    weak_llm = weakref.ref(llm)
    del llm
    # If there's any circular reference to vllm, this fails
    # because llm instance is not GC'ed.
    assert weak_llm() is None
|
||||
|
||||
|
||||
def _fix_prompt_embed_outputs(
    vllm_outputs: list[tuple[list[int], str]],
    hf_model: HfRunner,
    example_prompts: list[str],
) -> list[tuple[list[int], str]]:
    """Splice HF prompt token ids / prompt text back into prompt-embed outputs.

    Outputs generated from prompt embeddings lack the original prompt
    tokens, so prepend them to make the outputs comparable with HF's.
    """
    fixed = []
    hf_inputs = hf_model.get_inputs(example_prompts)
    for (token_ids, text), hf_input, prompt in zip(vllm_outputs, hf_inputs,
                                                   example_prompts):
        prompt_ids = hf_input["input_ids"].tolist()[0]
        merged_ids = prompt_ids + token_ids[len(prompt_ids):]
        fixed.append((merged_ids, prompt + text))
    return fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ATTN_BACKEND)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.parametrize("model_executor", ["uni", "mp"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models(
    monkeypatch: pytest.MonkeyPatch,
    hf_runner,
    model: str,
    backend: str,
    max_tokens: int,
    enforce_eager: bool,
    async_scheduling: bool,
    model_executor: str,
    enable_prompt_embeds: bool,
) -> None:
    """Compare greedy vLLM outputs against HF for each model/config combo.

    Fixes two pytest collection errors left by a bad merge: a duplicated
    ``enforce_eager`` parametrize decorator and a ``dtype`` parametrize
    with no matching function argument.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", backend)

        # 5042 tokens for gemma2
        # gemma2 has alternating sliding window size of 4096
        # we need a prompt with more than 4096 tokens to test the sliding window
        prompt = (
            "The following numbers of the sequence "
            + ", ".join(str(i) for i in range(1024))
            + " are:"
        )
        example_prompts = [prompt]

        with hf_runner(model) as hf_model:
            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
            if enable_prompt_embeds:
                with torch.no_grad():
                    prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)

        with VllmRunner(
            model,
            max_model_len=8192,
            enforce_eager=enforce_eager,
            enable_prompt_embeds=enable_prompt_embeds,
            gpu_memory_utilization=0.7,
            async_scheduling=async_scheduling,
            distributed_executor_backend=model_executor,
        ) as vllm_model:
            if enable_prompt_embeds:
                vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
                # Prompt-embed outputs lack prompt tokens; splice them back.
                vllm_outputs = _fix_prompt_embed_outputs(
                    vllm_outputs, hf_model, example_prompts
                )
            else:
                vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

        check_outputs_equal(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model, distributed_executor_backend, attention_backend, test_suite, extra_env",
    [
        ("facebook/opt-125m", "ray", "", "L4", {}),
        ("facebook/opt-125m", "mp", "", "L4", {}),
        ("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
        ("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
        ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
        ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
        ("facebook/opt-125m", "ray", "", "A100", {}),
        ("facebook/opt-125m", "mp", "", "A100", {}),
    ],
)
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed(
    monkeypatch: pytest.MonkeyPatch,
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    distributed_executor_backend: str,
    attention_backend: str,
    test_suite: str,
    extra_env: dict[str, str],
    enable_prompt_embeds: bool,
) -> None:
    """Compare greedy HF vs vLLM outputs with tensor_parallel_size=2.

    NOTE(review): the merged source interleaved two generations of this
    function (a stale HF-first run plus a duplicate comparison loop, env
    vars set only after the models had already run, and `dtype`/
    `max_tokens`/`enforce_eager` parameters that no parametrize supplied).
    This is the reconstructed coherent version — confirm against upstream.
    """
    if test_suite != TARGET_TEST_SUITE:
        pytest.skip(f"Skip test for {test_suite}")

    with monkeypatch.context() as monkeypatch_context:
        if (
            model == "meta-llama/Llama-3.2-1B-Instruct"
            and distributed_executor_backend == "ray"
            and attention_backend == ""
            and test_suite == "L4"
            and enable_prompt_embeds
        ):  # noqa
            pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")

        # Environment must be in place *before* any engine is constructed.
        if attention_backend:
            monkeypatch_context.setenv(
                "VLLM_ATTENTION_BACKEND",
                attention_backend,
            )
        for k, v in extra_env.items():
            monkeypatch_context.setenv(k, v)

        dtype = "half"
        max_tokens = 5

        # NOTE: take care of the order. run vLLM first, and then run HF.
        # vLLM needs a fresh new process without cuda initialization.
        # if we run HF first, the cuda initialization will be done and it
        # will hurt multiprocessing backend with fork method
        # (the default method).
        with vllm_runner(
            model,
            dtype=dtype,
            tensor_parallel_size=2,
            distributed_executor_backend=distributed_executor_backend,
            enable_prompt_embeds=enable_prompt_embeds,
            gpu_memory_utilization=0.7,
        ) as vllm_model:
            if enable_prompt_embeds:
                with hf_runner(model, dtype=dtype) as hf_model:
                    with torch.no_grad():
                        prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
                    vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
                    vllm_outputs = _fix_prompt_embed_outputs(
                        vllm_outputs, hf_model, example_prompts
                    )
                    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
            else:
                vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
                with hf_runner(model, dtype=dtype) as hf_model:
                    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
|
||||
|
||||
|
||||
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
    """Engine errors during model execution must surface to the caller."""
    # Needed to mock an error in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model:
        # Only the v1 LLMEngine exposes the engine_core path mocked below.
        if isinstance(vllm_model.llm.llm_engine, LLMEngine):
            v1_test_failed_model_execution(vllm_model)
|
||||
|
||||
|
||||
def v1_test_failed_model_execution(vllm_model):
    """Inject a RuntimeError into execute_model and assert it propagates."""
    engine = vllm_model.llm.llm_engine
    mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
    engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model

    with pytest.raises(RuntimeError) as exc_info:
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
    # The mocked error must reach the caller with its message intact.
    assert isinstance(exc_info.value, RuntimeError)
    assert "Mocked Critical Error" in str(exc_info.value)
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
It tests chunked prefill. Chunked prefill can be enabled by
|
||||
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
|
||||
prefill requests are chunked.
|
||||
|
||||
Run `pytest tests/models/test_chunked_prefill.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
    enforce_eager: bool,
    tensor_parallel_size: int,
) -> None:
    """Greedy outputs with chunked prefill must match plain HF generation."""
    max_num_seqs = min(chunked_prefill_token_size, 256)
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    # -1 means "chunked prefill disabled"; otherwise the token size becomes
    # the batched-token budget so prefills larger than it get chunked.
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_batched_tokens = chunked_prefill_token_size

    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    # Free HF model memory before constructing the vLLM engine.
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
        tensor_parallel_size=tensor_parallel_size,
        enforce_eager=enforce_eager,
        max_num_seqs=max_num_seqs,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
10
tests/basic_correctness/test_cpu_offload.py
Normal file
10
tests/basic_correctness/test_cpu_offload.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from ..utils import compare_two_settings
|
||||
|
||||
|
||||
def test_cpu_offload():
|
||||
compare_two_settings(
|
||||
"hmellor/tiny-random-LlamaForCausalLM", [], ["--cpu-offload-gb", "1"]
|
||||
)
|
||||
281
tests/basic_correctness/test_cumem.py
Normal file
281
tests/basic_correctness/test_cumem.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
|
||||
from ..utils import create_new_process_for_each_test, requires_fp8
|
||||
|
||||
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
def test_python_error():
|
||||
"""
|
||||
Test if Python error occurs when there's low-level
|
||||
error happening from the C++ side.
|
||||
"""
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
total_bytes = torch.cuda.mem_get_info()[1]
|
||||
alloc_bytes = int(total_bytes * 0.7)
|
||||
tensors = []
|
||||
with allocator.use_memory_pool():
|
||||
# allocate 70% of the total memory
|
||||
x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||
tensors.append(x)
|
||||
# release the memory
|
||||
allocator.sleep()
|
||||
|
||||
# allocate more memory than the total memory
|
||||
y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
|
||||
tensors.append(y)
|
||||
with pytest.raises(RuntimeError):
|
||||
# when the allocator is woken up, it should raise an error
|
||||
# because we don't have enough memory
|
||||
allocator.wake_up()
|
||||
|
||||
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
def test_basic_cumem():
|
||||
# some tensors from default memory pool
|
||||
shape = (1024, 1024)
|
||||
x = torch.empty(shape, device="cuda")
|
||||
x.zero_()
|
||||
|
||||
# some tensors from custom memory pool
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
# custom memory pool
|
||||
y = torch.empty(shape, device="cuda")
|
||||
y.zero_()
|
||||
y += 1
|
||||
z = torch.empty(shape, device="cuda")
|
||||
z.zero_()
|
||||
z += 2
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
def test_cumem_with_cudagraph():
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
weight = torch.eye(1024, device="cuda")
|
||||
with allocator.use_memory_pool(tag="discard"):
|
||||
cache = torch.empty(1024, 1024, device="cuda")
|
||||
|
||||
def model(x):
|
||||
out = x @ weight
|
||||
cache[: out.size(0)].copy_(out)
|
||||
return out + 1
|
||||
|
||||
x = torch.empty(128, 1024, device="cuda")
|
||||
|
||||
# warmup
|
||||
model(x)
|
||||
|
||||
# capture cudagraph
|
||||
model_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(model_graph):
|
||||
y = model(x)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# after waking up, the content in the weight tensor
|
||||
# should be restored, but the content in the cache tensor
|
||||
# should be discarded
|
||||
|
||||
# this operation is also compatible with cudagraph
|
||||
|
||||
x.random_()
|
||||
model_graph.replay()
|
||||
|
||||
# cache content is as expected
|
||||
assert torch.allclose(x, cache[: x.size(0)])
|
||||
|
||||
# output content is as expected
|
||||
assert torch.allclose(y, x + 1)
|
||||
|
||||
|
||||
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
# sleep mode with safetensors
|
||||
"hmellor/tiny-random-LlamaForCausalLM",
|
||||
# sleep mode with pytorch checkpoint
|
||||
"facebook/opt-125m",
|
||||
],
|
||||
)
|
||||
def test_end_to_end(model: str):
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
llm = LLM(model, enable_sleep_mode=True)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
|
||||
# which is difficult to measure in the test. therefore, we only
|
||||
# test sleep level 1 here.
|
||||
llm.sleep(level=1)
|
||||
|
||||
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
|
||||
# now the memory usage is mostly cudagraph memory pool,
|
||||
# and it should be less than the model weights (1B model, 2GiB weights)
|
||||
|
||||
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
|
||||
# is captured but cannot be releasesd from PyTorch due to a known bug,
|
||||
# therefore high memory usage after `llm.sleep` is called is expected.
|
||||
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
|
||||
# in V1.
|
||||
assert used_bytes < 7 * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
|
||||
llm.sleep(level=1)
|
||||
llm.wake_up(tags=["weights"])
|
||||
|
||||
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
|
||||
|
||||
# should just reallocate memory for weights (1B model, ~2GiB weights)
|
||||
assert used_bytes < 10 * GiB_bytes
|
||||
|
||||
# now allocate kv cache memory
|
||||
llm.wake_up(tags=["kv_cache"])
|
||||
output3 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output3[0].outputs[0].text
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_deep_sleep():
|
||||
model = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
llm = LLM(model, enable_sleep_mode=True)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# Put the engine to deep sleep
|
||||
llm.sleep(level=2)
|
||||
|
||||
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
|
||||
assert used_bytes < 3 * GiB_bytes
|
||||
|
||||
llm.wake_up(tags=["weights"])
|
||||
llm.collective_rpc("reload_weights")
|
||||
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
|
||||
assert used_bytes < 4 * GiB_bytes
|
||||
|
||||
# now allocate kv cache and cuda graph memory
|
||||
llm.wake_up(tags=["kv_cache"])
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_deep_sleep_async():
|
||||
async def test():
|
||||
model = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
enable_sleep_mode=True,
|
||||
)
|
||||
|
||||
llm = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
|
||||
async for output in outputs:
|
||||
pass
|
||||
|
||||
# Put the engine to deep sleep
|
||||
await llm.sleep(level=2)
|
||||
|
||||
await llm.wake_up(tags=["weights"])
|
||||
await llm.collective_rpc("reload_weights")
|
||||
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
|
||||
assert used_bytes < 4 * GiB_bytes
|
||||
|
||||
# now allocate kv cache and cuda graph memory
|
||||
await llm.wake_up(tags=["kv_cache"])
|
||||
outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
|
||||
async for output2 in outputs2:
|
||||
pass
|
||||
|
||||
# cmp output
|
||||
assert output.outputs[0].text == output2.outputs[0].text
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
|
||||
@requires_fp8
|
||||
def test_deep_sleep_fp8_kvcache():
|
||||
GiB_bytes = 1 << 30
|
||||
model = "Qwen/Qwen2-0.5B"
|
||||
used_bytes_baseline = current_platform.get_current_memory_usage()
|
||||
|
||||
llm = LLM(model, enable_sleep_mode=True, kv_cache_dtype="fp8")
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# Put the engine to deep sleep
|
||||
llm.sleep(level=2)
|
||||
|
||||
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
|
||||
|
||||
# Rocm uses more memory for CudaGraphs, so we add 2 GiB more for the threshold
|
||||
rocm_extra_mem_bytes = 2 * GiB_bytes if current_platform.is_rocm() else 0
|
||||
mem_threshold_after_sleep = 3 * GiB_bytes + rocm_extra_mem_bytes
|
||||
assert used_bytes < mem_threshold_after_sleep
|
||||
|
||||
llm.wake_up(tags=["weights"])
|
||||
llm.collective_rpc("reload_weights")
|
||||
|
||||
used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline
|
||||
mem_threshold_after_wake_up = 4 * GiB_bytes + rocm_extra_mem_bytes
|
||||
assert used_bytes < mem_threshold_after_wake_up
|
||||
|
||||
# now allocate kv cache and cuda graph memory
|
||||
llm.wake_up(tags=["kv_cache"])
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
@@ -1,223 +0,0 @@
|
||||
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test.
|
||||
|
||||
Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
|
||||
pytest tests/basic_correctness/test_preemption.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
|
||||
ENABLE_ARTIFICIAL_PREEMPT)
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
]
|
||||
|
||||
assert ENABLE_ARTIFICIAL_PREEMPT is True, (
|
||||
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
|
||||
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
|
||||
"tests/basic_correctness/test_preemption.py`")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
|
||||
def test_chunked_prefill_recompute(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
) -> None:
|
||||
"""Ensure that chunked prefill works with preemption."""
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
def test_preemption(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
"""By default, recompute preemption is enabled"""
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
@pytest.mark.parametrize("beam_width", [4])
|
||||
def test_swap(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
"""Use beam search enables swapping."""
|
||||
example_prompts = example_prompts[:1]
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
|
||||
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, _ = hf_outputs[i]
|
||||
vllm_output_ids, _ = vllm_outputs[i]
|
||||
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||
for j in range(len(hf_output_ids)):
|
||||
assert hf_output_ids[j] == vllm_output_ids[j], (
|
||||
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
|
||||
f"vLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
@pytest.mark.parametrize("beam_width", [4])
|
||||
def test_swap_infeasible(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
"""Verify infeasible swap request will be ignored."""
|
||||
BLOCK_SIZE = 16
|
||||
prefill_blocks = 2
|
||||
decode_blocks = max_tokens // BLOCK_SIZE
|
||||
example_prompts = example_prompts[:1]
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
swap_space=10,
|
||||
block_size=BLOCK_SIZE,
|
||||
# Since beam search have more than 1 sequence, prefill + decode blocks
|
||||
# are not enough to finish.
|
||||
num_gpu_blocks_override=prefill_blocks + decode_blocks,
|
||||
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
|
||||
)
|
||||
sampling_params = SamplingParams(n=beam_width,
|
||||
use_beam_search=True,
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
ignore_eos=True)
|
||||
req_outputs = vllm_model.model.generate(
|
||||
example_prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
# Verify the request is ignored and not hang.
|
||||
assert req_outputs[0].outputs[0].finish_reason == "length"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
def test_preemption_infeasible(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
"""Verify infeasible preemption request will be ignored."""
|
||||
BLOCK_SIZE = 16
|
||||
prefill_blocks = 2
|
||||
decode_blocks = max_tokens // BLOCK_SIZE
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
block_size=BLOCK_SIZE,
|
||||
# Not enough gpu blocks to complete a single sequence.
|
||||
# preemption should happen, and the sequence should be
|
||||
# ignored instead of hanging forever.
|
||||
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
|
||||
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
|
||||
)
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
|
||||
req_outputs = vllm_model.model.generate(
|
||||
example_prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
# Verify the request is ignored and not hang.
|
||||
for req_output in req_outputs:
|
||||
outputs = req_output.outputs
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0].finish_reason == "length"
|
||||
30
tests/benchmarks/test_latency_cli.py
Normal file
30
tests/benchmarks/test_latency_cli.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_latency():
|
||||
command = [
|
||||
"vllm",
|
||||
"bench",
|
||||
"latency",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--input-len",
|
||||
"32",
|
||||
"--output-len",
|
||||
"1",
|
||||
"--enforce-eager",
|
||||
"--load-format",
|
||||
"dummy",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
249
tests/benchmarks/test_param_sweep.py
Normal file
249
tests/benchmarks/test_param_sweep.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.benchmarks.sweep.param_sweep import ParameterSweep, ParameterSweepItem
|
||||
|
||||
|
||||
class TestParameterSweepItem:
|
||||
"""Test ParameterSweepItem functionality."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_dict,expected",
|
||||
[
|
||||
(
|
||||
{"compilation_config.use_inductor_graph_partition": False},
|
||||
"--compilation-config.use_inductor_graph_partition=false",
|
||||
),
|
||||
(
|
||||
{"compilation_config.use_inductor_graph_partition": True},
|
||||
"--compilation-config.use_inductor_graph_partition=true",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nested_boolean_params(self, input_dict, expected):
|
||||
"""Test that nested boolean params use =true/false syntax."""
|
||||
item = ParameterSweepItem.from_record(input_dict)
|
||||
cmd = item.apply_to_cmd(["vllm", "serve", "model"])
|
||||
assert expected in cmd
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_dict,expected",
|
||||
[
|
||||
({"enable_prefix_caching": False}, "--no-enable-prefix-caching"),
|
||||
({"enable_prefix_caching": True}, "--enable-prefix-caching"),
|
||||
({"disable_log_stats": False}, "--no-disable-log-stats"),
|
||||
({"disable_log_stats": True}, "--disable-log-stats"),
|
||||
],
|
||||
)
|
||||
def test_non_nested_boolean_params(self, input_dict, expected):
|
||||
"""Test that non-nested boolean params use --no- prefix."""
|
||||
item = ParameterSweepItem.from_record(input_dict)
|
||||
cmd = item.apply_to_cmd(["vllm", "serve", "model"])
|
||||
assert expected in cmd
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_config",
|
||||
[
|
||||
{"cudagraph_mode": "full", "mode": 2, "use_inductor_graph_partition": True},
|
||||
{
|
||||
"cudagraph_mode": "piecewise",
|
||||
"mode": 3,
|
||||
"use_inductor_graph_partition": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_nested_dict_value(self, compilation_config):
|
||||
"""Test that nested dict values are serialized as JSON."""
|
||||
item = ParameterSweepItem.from_record(
|
||||
{"compilation_config": compilation_config}
|
||||
)
|
||||
cmd = item.apply_to_cmd(["vllm", "serve", "model"])
|
||||
assert "--compilation-config" in cmd
|
||||
# The dict should be JSON serialized
|
||||
idx = cmd.index("--compilation-config")
|
||||
assert json.loads(cmd[idx + 1]) == compilation_config
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_dict,expected_key,expected_value",
|
||||
[
|
||||
({"model": "test-model"}, "--model", "test-model"),
|
||||
({"max_tokens": 100}, "--max-tokens", "100"),
|
||||
({"temperature": 0.7}, "--temperature", "0.7"),
|
||||
],
|
||||
)
|
||||
def test_string_and_numeric_values(self, input_dict, expected_key, expected_value):
|
||||
"""Test that string and numeric values are handled correctly."""
|
||||
item = ParameterSweepItem.from_record(input_dict)
|
||||
cmd = item.apply_to_cmd(["vllm", "serve"])
|
||||
assert expected_key in cmd
|
||||
assert expected_value in cmd
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_dict,expected_key,key_idx_offset",
|
||||
[
|
||||
({"max_tokens": 200}, "--max-tokens", 1),
|
||||
({"enable_prefix_caching": False}, "--no-enable-prefix-caching", 0),
|
||||
],
|
||||
)
|
||||
def test_replace_existing_parameter(self, input_dict, expected_key, key_idx_offset):
|
||||
"""Test that existing parameters in cmd are replaced."""
|
||||
item = ParameterSweepItem.from_record(input_dict)
|
||||
|
||||
if key_idx_offset == 1:
|
||||
# Key-value pair
|
||||
cmd = item.apply_to_cmd(["vllm", "serve", "--max-tokens", "100", "model"])
|
||||
assert expected_key in cmd
|
||||
idx = cmd.index(expected_key)
|
||||
assert cmd[idx + 1] == "200"
|
||||
assert "100" not in cmd
|
||||
else:
|
||||
# Boolean flag
|
||||
cmd = item.apply_to_cmd(
|
||||
["vllm", "serve", "--enable-prefix-caching", "model"]
|
||||
)
|
||||
assert expected_key in cmd
|
||||
assert "--enable-prefix-caching" not in cmd
|
||||
|
||||
|
||||
class TestParameterSweep:
|
||||
"""Test ParameterSweep functionality."""
|
||||
|
||||
def test_from_records_list(self):
|
||||
"""Test creating ParameterSweep from a list of records."""
|
||||
records = [
|
||||
{"max_tokens": 100, "temperature": 0.7},
|
||||
{"max_tokens": 200, "temperature": 0.9},
|
||||
]
|
||||
sweep = ParameterSweep.from_records(records)
|
||||
assert len(sweep) == 2
|
||||
assert sweep[0]["max_tokens"] == 100
|
||||
assert sweep[1]["max_tokens"] == 200
|
||||
|
||||
def test_read_from_dict(self):
|
||||
"""Test creating ParameterSweep from a dict format."""
|
||||
data = {
|
||||
"experiment1": {"max_tokens": 100, "temperature": 0.7},
|
||||
"experiment2": {"max_tokens": 200, "temperature": 0.9},
|
||||
}
|
||||
sweep = ParameterSweep.read_from_dict(data)
|
||||
assert len(sweep) == 2
|
||||
|
||||
# Check that items have the _benchmark_name field
|
||||
names = {item["_benchmark_name"] for item in sweep}
|
||||
assert names == {"experiment1", "experiment2"}
|
||||
|
||||
# Check that parameters are preserved
|
||||
for item in sweep:
|
||||
if item["_benchmark_name"] == "experiment1":
|
||||
assert item["max_tokens"] == 100
|
||||
assert item["temperature"] == 0.7
|
||||
elif item["_benchmark_name"] == "experiment2":
|
||||
assert item["max_tokens"] == 200
|
||||
assert item["temperature"] == 0.9
|
||||
|
||||
def test_read_json_list_format(self):
|
||||
"""Test reading JSON file with list format."""
|
||||
records = [
|
||||
{"max_tokens": 100, "temperature": 0.7},
|
||||
{"max_tokens": 200, "temperature": 0.9},
|
||||
]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(records, f)
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
sweep = ParameterSweep.read_json(temp_path)
|
||||
assert len(sweep) == 2
|
||||
assert sweep[0]["max_tokens"] == 100
|
||||
assert sweep[1]["max_tokens"] == 200
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_read_json_dict_format(self):
|
||||
"""Test reading JSON file with dict format."""
|
||||
data = {
|
||||
"experiment1": {"max_tokens": 100, "temperature": 0.7},
|
||||
"experiment2": {"max_tokens": 200, "temperature": 0.9},
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(data, f)
|
||||
temp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
sweep = ParameterSweep.read_json(temp_path)
|
||||
assert len(sweep) == 2
|
||||
|
||||
# Check that items have the _benchmark_name field
|
||||
names = {item["_benchmark_name"] for item in sweep}
|
||||
assert names == {"experiment1", "experiment2"}
|
||||
finally:
|
||||
temp_path.unlink()
|
||||
|
||||
def test_unique_benchmark_names_validation(self):
|
||||
"""Test that duplicate _benchmark_name values raise an error."""
|
||||
# Test with duplicate names in list format
|
||||
records = [
|
||||
{"_benchmark_name": "exp1", "max_tokens": 100},
|
||||
{"_benchmark_name": "exp1", "max_tokens": 200},
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
|
||||
ParameterSweep.from_records(records)
|
||||
|
||||
def test_unique_benchmark_names_multiple_duplicates(self):
|
||||
"""Test validation with multiple duplicate names."""
|
||||
records = [
|
||||
{"_benchmark_name": "exp1", "max_tokens": 100},
|
||||
{"_benchmark_name": "exp1", "max_tokens": 200},
|
||||
{"_benchmark_name": "exp2", "max_tokens": 300},
|
||||
{"_benchmark_name": "exp2", "max_tokens": 400},
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
|
||||
ParameterSweep.from_records(records)
|
||||
|
||||
def test_no_benchmark_names_allowed(self):
|
||||
"""Test that records without _benchmark_name are allowed."""
|
||||
records = [
|
||||
{"max_tokens": 100, "temperature": 0.7},
|
||||
{"max_tokens": 200, "temperature": 0.9},
|
||||
]
|
||||
sweep = ParameterSweep.from_records(records)
|
||||
assert len(sweep) == 2
|
||||
|
||||
def test_mixed_benchmark_names_allowed(self):
|
||||
"""Test that mixing records with and without _benchmark_name is allowed."""
|
||||
records = [
|
||||
{"_benchmark_name": "exp1", "max_tokens": 100},
|
||||
{"max_tokens": 200, "temperature": 0.9},
|
||||
]
|
||||
sweep = ParameterSweep.from_records(records)
|
||||
assert len(sweep) == 2
|
||||
|
||||
|
||||
class TestParameterSweepItemKeyNormalization:
|
||||
"""Test key normalization in ParameterSweepItem."""
|
||||
|
||||
def test_underscore_to_hyphen_conversion(self):
|
||||
"""Test that underscores are converted to hyphens in CLI."""
|
||||
item = ParameterSweepItem.from_record({"max_tokens": 100})
|
||||
cmd = item.apply_to_cmd(["vllm", "serve"])
|
||||
assert "--max-tokens" in cmd
|
||||
|
||||
def test_nested_key_preserves_suffix(self):
|
||||
"""Test that nested keys preserve the suffix format."""
|
||||
# The suffix after the dot should preserve underscores
|
||||
item = ParameterSweepItem.from_record(
|
||||
{"compilation_config.some_nested_param": "value"}
|
||||
)
|
||||
cmd = item.apply_to_cmd(["vllm", "serve"])
|
||||
# The prefix (compilation_config) gets converted to hyphens,
|
||||
# but the suffix (some_nested_param) is preserved
|
||||
assert any("compilation-config.some_nested_param" in arg for arg in cmd)
|
||||
171
tests/benchmarks/test_plot_filters.py
Normal file
171
tests/benchmarks/test_plot_filters.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from vllm.benchmarks.sweep.plot import (
|
||||
PlotEqualTo,
|
||||
PlotFilterBase,
|
||||
PlotFilters,
|
||||
PlotGreaterThan,
|
||||
PlotGreaterThanOrEqualTo,
|
||||
PlotLessThan,
|
||||
PlotLessThanOrEqualTo,
|
||||
PlotNotEqualTo,
|
||||
)
|
||||
|
||||
|
||||
class TestPlotFilters:
    """Test PlotFilter functionality including 'inf' edge case.

    Filter targets are passed as strings (e.g. "5.0", "inf"); the filters
    are expected to coerce them for comparison against numeric columns.
    """

    def setup_method(self):
        """Create sample DataFrames for testing."""
        # DataFrame with numeric values
        self.df_numeric = pd.DataFrame(
            {
                "request_rate": [1.0, 5.0, 10.0, 50.0, 100.0],
                "value": [10, 20, 30, 40, 50],
            }
        )

        # DataFrame with float('inf') - note: string "inf" values are coerced
        # to float when loading data, so we only test with float('inf')
        self.df_inf_float = pd.DataFrame(
            {
                "request_rate": [1.0, 5.0, 10.0, float("inf"), float("inf")],
                "value": [10, 20, 30, 40, 50],
            }
        )

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("5.0", 1),
            ("10.0", 1),
            ("1.0", 1),
        ],
    )
    def test_equal_to_numeric(self, target, expected_count):
        """Test PlotEqualTo with numeric values."""
        filter_obj = PlotEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    def test_equal_to_inf_float(self):
        """Test PlotEqualTo with float('inf')."""
        filter_obj = PlotEqualTo("request_rate", "inf")
        result = filter_obj.apply(self.df_inf_float)
        # Should match both float('inf') entries because float('inf') == float('inf')
        assert len(result) == 2

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("5.0", 4),  # All except 5.0
            ("1.0", 4),  # All except 1.0
        ],
    )
    def test_not_equal_to_numeric(self, target, expected_count):
        """Test PlotNotEqualTo with numeric values."""
        filter_obj = PlotNotEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    def test_not_equal_to_inf_float(self):
        """Test PlotNotEqualTo with float('inf')."""
        filter_obj = PlotNotEqualTo("request_rate", "inf")
        result = filter_obj.apply(self.df_inf_float)
        # Should exclude float('inf') entries
        assert len(result) == 3

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 2),  # 1.0, 5.0
            ("50.0", 3),  # 1.0, 5.0, 10.0
            ("5.0", 1),  # 1.0
        ],
    )
    def test_less_than(self, target, expected_count):
        """Test PlotLessThan with numeric values."""
        filter_obj = PlotLessThan("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 3),  # 1.0, 5.0, 10.0
            ("5.0", 2),  # 1.0, 5.0
        ],
    )
    def test_less_than_or_equal_to(self, target, expected_count):
        """Test PlotLessThanOrEqualTo with numeric values."""
        filter_obj = PlotLessThanOrEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 2),  # 50.0, 100.0
            ("5.0", 3),  # 10.0, 50.0, 100.0
        ],
    )
    def test_greater_than(self, target, expected_count):
        """Test PlotGreaterThan with numeric values."""
        filter_obj = PlotGreaterThan("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 3),  # 10.0, 50.0, 100.0
            ("5.0", 4),  # 5.0, 10.0, 50.0, 100.0
        ],
    )
    def test_greater_than_or_equal_to(self, target, expected_count):
        """Test PlotGreaterThanOrEqualTo with numeric values."""
        filter_obj = PlotGreaterThanOrEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "filter_str,expected_var,expected_target,expected_type",
        [
            ("request_rate==5.0", "request_rate", "5.0", PlotEqualTo),
            ("request_rate!=10.0", "request_rate", "10.0", PlotNotEqualTo),
            ("request_rate<50.0", "request_rate", "50.0", PlotLessThan),
            ("request_rate<=50.0", "request_rate", "50.0", PlotLessThanOrEqualTo),
            ("request_rate>10.0", "request_rate", "10.0", PlotGreaterThan),
            ("request_rate>=10.0", "request_rate", "10.0", PlotGreaterThanOrEqualTo),
            # "inf" stays a plain string at parse time; quoting is stripped.
            ("request_rate==inf", "request_rate", "inf", PlotEqualTo),
            ("request_rate!='inf'", "request_rate", "inf", PlotNotEqualTo),
        ],
    )
    def test_parse_str(self, filter_str, expected_var, expected_target, expected_type):
        """Test parsing filter strings."""
        filter_obj = PlotFilterBase.parse_str(filter_str)
        assert isinstance(filter_obj, expected_type)
        assert filter_obj.var == expected_var
        assert filter_obj.target == expected_target

    def test_parse_str_inf_edge_case(self):
        """Test parsing 'inf' string in filter."""
        filter_obj = PlotFilterBase.parse_str("request_rate==inf")
        assert isinstance(filter_obj, PlotEqualTo)
        assert filter_obj.var == "request_rate"
        assert filter_obj.target == "inf"

    def test_parse_multiple_filters(self):
        """Test parsing multiple filters."""
        # Comma-separated filters parse in order into a PlotFilters sequence.
        filters = PlotFilters.parse_str("request_rate>5.0,value<=40")
        assert len(filters) == 2
        assert isinstance(filters[0], PlotGreaterThan)
        assert isinstance(filters[1], PlotLessThanOrEqualTo)

    def test_parse_empty_filter(self):
        """Test parsing empty filter string."""
        filters = PlotFilters.parse_str("")
        assert len(filters) == 0
|
||||
484
tests/benchmarks/test_random_dataset.py
Normal file
484
tests/benchmarks/test_random_dataset.py
Normal file
@@ -0,0 +1,484 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from typing import Any, NamedTuple, cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
RandomDataset,
|
||||
RandomMultiModalDataset,
|
||||
SampleRequest,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
    """Session-scoped HF tokenizer shared by all tests in this module."""
    # Use a small, commonly available tokenizer
    return AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
|
||||
class Params(NamedTuple):
    """Bundle of keyword arguments forwarded to RandomDataset.sample."""

    # Number of SampleRequests to generate per sample() call.
    num_requests: int
    # Length of the shared prompt prefix, in tokens.
    prefix_len: int
    # Relative spread applied around input_len/output_len when sampling.
    range_ratio: float
    # Target prompt length, in tokens.
    input_len: int
    # Target expected output length, in tokens.
    output_len: int
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def random_dataset_params() -> Params:
    """Default sampling parameters reused across the RandomDataset tests."""
    return Params(
        num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20
    )
|
||||
|
||||
|
||||
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
    """Project a SampleRequest into a comparable tuple.

    Two samples compare equal here exactly when their prompt text, prompt
    length, and expected output length all match.
    """
    prompt_text = req.prompt
    num_prompt_tokens = req.prompt_len
    num_output_tokens = req.expected_output_len
    return (prompt_text, num_prompt_tokens, num_output_tokens)
|
||||
|
||||
|
||||
def _collect_samples(
    dataset: RandomDataset,
    tokenizer: PreTrainedTokenizerBase,
    num_requests: int = 16,
    prefix_len: int = 7,
    range_ratio: float = 0.3,
    input_len: int = 50,
    output_len: int = 20,
) -> list[tuple[str, int, int]]:
    """Sample from *dataset* and return comparable fingerprints.

    Thin wrapper around RandomDataset.sample that reduces each returned
    SampleRequest to the (prompt, prompt_len, expected_output_len) tuple
    used by the determinism assertions below.
    """
    samples = dataset.sample(
        tokenizer=tokenizer,
        num_requests=num_requests,
        prefix_len=prefix_len,
        range_ratio=range_ratio,
        input_len=input_len,
        output_len=output_len,
    )
    return [_fingerprint_sample(s) for s in samples]
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_dataset_same_seed(
    hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
) -> None:
    """Same seed should yield identical outputs, even if global RNGs change.

    This guards against accidental reliance on Python's random or np.random
    in RandomDataset after moving to numpy.default_rng.
    """
    p = random_dataset_params
    common_seed = 123
    dataset_a = RandomDataset(random_seed=common_seed)
    dataset_b = RandomDataset(random_seed=common_seed)
    a = _collect_samples(
        dataset_a,
        hf_tokenizer,
        num_requests=p.num_requests,
        prefix_len=p.prefix_len,
        range_ratio=p.range_ratio,
        input_len=p.input_len,
        output_len=p.output_len,
    )

    # Perturb global RNG state to ensure isolation
    # (if the dataset leaked onto the global generators, run B would differ).
    random.seed(999)
    _ = [random.random() for _ in range(100)]
    np.random.seed(888)
    _ = [np.random.random() for _ in range(100)]

    b = _collect_samples(
        dataset_b,
        hf_tokenizer,
        num_requests=p.num_requests,
        prefix_len=p.prefix_len,
        range_ratio=p.range_ratio,
        input_len=p.input_len,
        output_len=p.output_len,
    )
    assert a == b
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_dataset_different_seeds(
    hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
) -> None:
    """Different seeds should change outputs with overwhelming likelihood."""
    p = random_dataset_params
    seed_a = 0
    dataset_a = RandomDataset(random_seed=seed_a)
    a = _collect_samples(
        dataset_a,
        hf_tokenizer,
        num_requests=p.num_requests,
        prefix_len=p.prefix_len,
        range_ratio=p.range_ratio,
        input_len=p.input_len,
        output_len=p.output_len,
    )

    seed_b = 999
    dataset_b = RandomDataset(random_seed=seed_b)
    # Perturb global RNG with same seed as dataset_a to ensure isolation
    # (if dataset_b read the global generators, it could reproduce run A).
    random.seed(seed_a)
    np.random.seed(seed_a)
    b = _collect_samples(
        dataset_b,
        hf_tokenizer,
        num_requests=p.num_requests,
        prefix_len=p.prefix_len,
        range_ratio=p.range_ratio,
        input_len=p.input_len,
        output_len=p.output_len,
    )
    assert a != b
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# RandomMultiModalDataset tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _mm_fingerprint_sample(
    req: SampleRequest,
) -> tuple[str, int, int, int, list[str]]:
    """Create a compact fingerprint for multimodal samples.

    Includes:
    - prompt string
    - prompt_len
    - expected_output_len
    - count of multimodal items
    - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
    """
    items = req.multi_modal_data or []

    def _item_tag(entry) -> str:
        # Only keep a short identifying prefix to avoid huge strings
        if isinstance(entry, dict) and entry.get("type") == "image_url":
            return f"image:{entry.get('image_url', {}).get('url', '')[:22]}"
        if isinstance(entry, dict) and entry.get("type") == "video_url":
            return f"video:{entry.get('video_url', {}).get('url', '')[:22]}"
        return "unknown:"

    item_prefixes: list[str] = [_item_tag(entry) for entry in items]
    return (
        req.prompt,
        req.prompt_len,
        req.expected_output_len,
        len(items),
        item_prefixes,
    )
|
||||
|
||||
|
||||
def _collect_mm_samples(
    dataset: RandomMultiModalDataset,
    tokenizer: PreTrainedTokenizerBase,
    *,
    num_requests: int = 8,
    prefix_len: int = 3,
    range_ratio: float = 0.0,
    input_len: int = 20,
    output_len: int = 5,
    base_items_per_request: int = 2,
    num_mm_items_range_ratio: float = 0.0,
    limit_mm_per_prompt: dict[str, int] | None = None,
    bucket_config: dict[tuple[int, int, int], float] | None = None,
    enable_multimodal_chat: bool = False,
) -> list[SampleRequest]:
    """Call RandomMultiModalDataset.sample with test-friendly defaults.

    The mutable defaults (limit_mm_per_prompt, bucket_config) are created
    per call via the None-sentinel idiom so callers never share state.
    Bucket config keys are (height, width, num_frames) tuples mapped to
    sampling probabilities.
    """
    if limit_mm_per_prompt is None:
        limit_mm_per_prompt = {"image": 5, "video": 0}
    if bucket_config is None:
        bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
    return dataset.sample(
        tokenizer=tokenizer,
        num_requests=num_requests,
        prefix_len=prefix_len,
        range_ratio=range_ratio,
        input_len=input_len,
        output_len=output_len,
        base_items_per_request=base_items_per_request,
        num_mm_items_range_ratio=num_mm_items_range_ratio,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        enable_multimodal_chat=enable_multimodal_chat,
    )
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Two datasets built with the same seed must produce identical samples."""
    seed = 42
    first = RandomMultiModalDataset(random_seed=seed)
    second = RandomMultiModalDataset(random_seed=seed)
    fingerprints_first = [
        _mm_fingerprint_sample(s) for s in _collect_mm_samples(first, hf_tokenizer)
    ]
    fingerprints_second = [
        _mm_fingerprint_sample(s) for s in _collect_mm_samples(second, hf_tokenizer)
    ]
    assert fingerprints_first == fingerprints_second
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_different_seeds(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Datasets seeded differently should produce different sample streams."""
    low_seed_ds = RandomMultiModalDataset(random_seed=0)
    high_seed_ds = RandomMultiModalDataset(random_seed=999)
    low_fp = [
        _mm_fingerprint_sample(s)
        for s in _collect_mm_samples(low_seed_ds, hf_tokenizer)
    ]
    high_fp = [
        _mm_fingerprint_sample(s)
        for s in _collect_mm_samples(high_seed_ds, hf_tokenizer)
    ]
    assert low_fp != high_fp
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_respects_limits(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Requesting more items than the per-prompt limit must raise ValueError."""
    ds = RandomMultiModalDataset(random_seed=0)
    # Requesting 3 items with a per-prompt limit of 1 should error per current
    # design (dataset refuses to silently clamp below the requested baseline).
    with pytest.raises(ValueError):
        _collect_mm_samples(
            ds,
            hf_tokenizer,
            num_requests=12,
            base_items_per_request=3,
            num_mm_items_range_ratio=0.0,
            limit_mm_per_prompt={"image": 1, "video": 0},
            bucket_config={(32, 32, 1): 1.0},
        )
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_zero_prob_entries_are_removed(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Buckets with probability 0 must never contribute items."""
    ds = RandomMultiModalDataset(random_seed=0)
    # Second bucket has zero probability and should be ignored after
    # normalization
    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=6,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 10, "video": 0},
        bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
    )
    # Only the single-frame (image) bucket remains, so every item is an image.
    for s in samples:
        assert isinstance(s.multi_modal_data, list)
        typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
        for it in typed_mm:
            assert it.get("type") == "image_url"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """base_items_per_request=0 should yield samples with no multimodal data."""
    ds = RandomMultiModalDataset(random_seed=0)
    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=0,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 5, "video": 0},
        bucket_config={(32, 32, 1): 1.0},
    )
    # Expect an empty list (not None) so downstream code can iterate safely.
    for s in samples:
        assert s.multi_modal_data == []
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """With range ratio 0, every prompt gets exactly base_items_per_request items."""
    ds = RandomMultiModalDataset(random_seed=0)
    # Fixed number of images per prompt
    # set num_mm_items_range_ratio to 0.0
    # TODO: modify video values when video sampling is implemented
    samples_fixed_items = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=3,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 3, "video": 0},
        bucket_config={(32, 32, 1): 1.0},
    )
    # Must have 5 requests each with 3 mm items per prompt
    assert len(samples_fixed_items) == 5
    for s in samples_fixed_items:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 3
        for it in mm_data:
            assert it.get("type") == "image_url"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_bucket_config_not_mutated(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """sample() must treat the caller's bucket_config dict as read-only."""
    ds = RandomMultiModalDataset(random_seed=0)
    # This bucket config is not normalized to sum to 1
    # and has more buckets than requested images
    original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
    # Keep a snapshot to compare after sampling
    snapshot = dict(original)

    _ = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=4,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt={"image": 1, "video": 0},
        bucket_config=original,
    )

    # Ensure the original dict content is unchanged
    assert original == snapshot

    # NOTE(review): the section below checks varying item counts and reads
    # like it belongs in test_random_mm_num_items_per_prompt — confirm intent.
    # Vary number of mm items per prompt
    # set num_mm_items_range_ratio to 0.5
    samples_varying_items = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.5,
        limit_mm_per_prompt={"image": 4, "video": 0},
        bucket_config={(32, 32, 1): 1.0},
    )
    # Must have 5 requests each with less than 4 mm items per prompt
    # but at least 1 mm item per prompt
    assert len(samples_varying_items) == 5
    for s in samples_varying_items:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) <= 4
        assert len(mm_data) >= 1
        for it in mm_data:
            assert it.get("type") == "image_url"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Test video sampling functionality in RandomMultiModalDataset."""
    ds = RandomMultiModalDataset(random_seed=42)

    # Test with video bucket configuration
    # (num_frames == 1 means image, > 1 means video)
    bucket_config = {
        (64, 64, 1): 0.3,  # Images
        (64, 64, 8): 0.7,  # Videos
    }

    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )

    assert len(samples) == 5

    # Check that we have both images and videos
    video_count = 0
    image_count = 0

    for s in samples:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 1

        item = mm_data[0]
        if item.get("type") == "video_url":
            video_count += 1
            # Verify video URL format
            url = item.get("video_url", {}).get("url", "")
            assert url.startswith("data:video/mp4;base64,")
        elif item.get("type") == "image_url":
            image_count += 1
            # Verify image URL format
            url = item.get("image_url", {}).get("url", "")
            assert url.startswith("data:image/jpeg;base64,")

    # Should have some videos due to 0.7 probability
    # (seed 42 is fixed, so this mix is deterministic for this test)
    assert video_count > 0
    assert image_count > 0
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Test sampling with only video buckets."""
    ds = RandomMultiModalDataset(random_seed=42)

    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }

    limit_mm_per_prompt = {"image": 0, "video": 1}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )

    assert len(samples) == 3

    # Every item must be a video data URL since the only bucket is a video one.
    for s in samples:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 1

        item = mm_data[0]
        assert item.get("type") == "video_url"
        url = item.get("video_url", {}).get("url", "")
        assert url.startswith("data:video/mp4;base64,")
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Test that video sampling is deterministic with same seed."""
    seed = 123
    ds_a = RandomMultiModalDataset(random_seed=seed)
    ds_b = RandomMultiModalDataset(random_seed=seed)

    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }

    limit_mm_per_prompt = {"image": 0, "video": 1}

    a = _collect_mm_samples(
        ds_a,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )

    b = _collect_mm_samples(
        ds_b,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )

    # Equal fingerprints imply identical prompts and identical video payload
    # prefixes for every generated request.
    fa = [_mm_fingerprint_sample(s) for s in a]
    fb = [_mm_fingerprint_sample(s) for s in b]
    assert fa == fb
|
||||
398
tests/benchmarks/test_random_multimodal_dataset_video.py
Normal file
398
tests/benchmarks/test_random_multimodal_dataset_video.py
Normal file
@@ -0,0 +1,398 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, cast
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import RandomMultiModalDataset, SampleRequest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
    """Use a small, commonly available tokenizer."""
    # Session scope: loading the tokenizer once is enough for all tests here.
    return AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
|
||||
@pytest.fixture
def video_dataset() -> RandomMultiModalDataset:
    """Create a RandomMultiModalDataset instance for testing."""
    # Function scope (fresh dataset per test) keeps per-test RNG state isolated.
    return RandomMultiModalDataset(random_seed=42)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_generate_synthetic_video_different_seeds():
    """Test that different seeds produce different videos."""
    dataset1 = RandomMultiModalDataset(random_seed=123)
    dataset2 = RandomMultiModalDataset(random_seed=456)

    width, height, num_frames = 64, 48, 8

    video1 = dataset1.generate_synthetic_video(width, height, num_frames)
    video2 = dataset2.generate_synthetic_video(width, height, num_frames)

    # Videos should be different due to different seeds
    # (comparing the encoded byte payloads directly)
    assert video1["bytes"] != video2["bytes"]
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_map_config_to_modality(video_dataset: RandomMultiModalDataset):
    """Test modality mapping for different configurations.

    A config tuple is (height, width, num_frames); num_frames == 1 means
    image, num_frames > 1 means video, anything else is invalid.
    """
    # Test image configuration (num_frames = 1)
    assert video_dataset.map_config_to_modality((256, 256, 1)) == "image"
    assert video_dataset.map_config_to_modality((720, 1280, 1)) == "image"

    # Test video configurations (num_frames > 1)
    assert video_dataset.map_config_to_modality((256, 256, 8)) == "video"
    assert video_dataset.map_config_to_modality((720, 1280, 16)) == "video"
    assert video_dataset.map_config_to_modality((64, 64, 32)) == "video"

    # Test invalid configurations
    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
        video_dataset.map_config_to_modality((256, 256, 0))

    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
        video_dataset.map_config_to_modality((256, 256, -1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_generate_mm_item_video(video_dataset: RandomMultiModalDataset):
    """Test generating multimodal items for video configurations."""
    # Test video item generation
    video_config = (64, 48, 8)  # height, width, num_frames
    result = video_dataset.generate_mm_item(video_config)

    # Check the result structure matches OpenAI API format
    assert isinstance(result, dict)
    assert result["type"] == "video_url"
    assert "video_url" in result
    assert "url" in result["video_url"]

    # Check that the URL is a data URL with base64 encoded video
    url = result["video_url"]["url"]
    assert url.startswith("data:video/mp4;base64,")

    # Decode and verify the video content
    base64_data = url.split(",")[1]
    video_bytes = base64.b64decode(base64_data)
    assert len(video_bytes) > 0

    # Verify the video can be decoded.
    # delete=False + manual unlink so the file stays openable by cv2 on
    # platforms where an open NamedTemporaryFile cannot be reopened.
    with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_path = temp_file.name
        temp_file.write(video_bytes)

    try:
        cap = cv2.VideoCapture(temp_path)
        assert cap.isOpened()

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Config is (height=64, width=48, frames=8); OpenCV reports
        # width/height separately, hence the swapped-looking assertions.
        assert frame_count == 8
        assert frame_width == 48
        assert frame_height == 64

        cap.release()
    finally:
        if os.path.exists(temp_path):
            os.unlink(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_generate_mm_item_image(video_dataset: RandomMultiModalDataset):
    """Test generating multimodal items for image configurations."""
    # A single-frame config (height, width, num_frames=1) denotes an image.
    generated = video_dataset.generate_mm_item((64, 48, 1))

    # The returned dict must follow the OpenAI API image_url structure.
    assert isinstance(generated, dict)
    assert generated["type"] == "image_url"
    assert "image_url" in generated
    assert "url" in generated["image_url"]

    # The payload is a base64 data URL carrying a JPEG image.
    data_url = generated["image_url"]["url"]
    assert data_url.startswith("data:image/jpeg;base64,")
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_generate_mm_item_invalid_config(video_dataset: RandomMultiModalDataset):
    """Test error handling for invalid configurations."""
    # num_frames == 0 is neither an image (1) nor a video (>1).
    bad_config = (256, 256, 0)
    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
        video_dataset.generate_mm_item(bad_config)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_sample_with_video_buckets(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test sampling with video bucket configurations."""
    # Configure bucket with video probability > 0
    bucket_config = {
        (64, 64, 1): 0.3,  # Images
        (64, 64, 8): 0.7,  # Videos
    }

    limit_mm_per_prompt = {"image": 5, "video": 3}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=5,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )

    assert len(samples) == 5

    # Check that samples contain both images and videos
    video_count = 0
    image_count = 0

    for sample in samples:
        assert isinstance(sample, SampleRequest)
        assert sample.multi_modal_data is not None
        assert isinstance(sample.multi_modal_data, list)

        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(mm_data) == 2  # base_items_per_request

        for item in mm_data:
            if item["type"] == "video_url":
                video_count += 1
                # Verify video URL format
                url = item["video_url"]["url"]
                assert url.startswith("data:video/mp4;base64,")
            elif item["type"] == "image_url":
                image_count += 1
                # Verify image URL format
                url = item["image_url"]["url"]
                assert url.startswith("data:image/jpeg;base64,")

    # Should have some videos due to 0.7 probability
    # (fixed seed 42 in the fixture makes the mix deterministic)
    assert video_count > 0
    assert image_count > 0
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_sample_video_only_buckets(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test sampling with only video buckets."""
    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }

    limit_mm_per_prompt = {"image": 0, "video": 2}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )

    assert len(samples) == 3

    # With a single video bucket, every generated item must be a video URL.
    for sample in samples:
        assert isinstance(sample, SampleRequest)
        assert sample.multi_modal_data is not None
        assert isinstance(sample.multi_modal_data, list)

        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(mm_data) == 1

        item = mm_data[0]
        assert item["type"] == "video_url"
        url = item["video_url"]["url"]
        assert url.startswith("data:video/mp4;base64,")
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_sample_respects_video_limits(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test that sampling respects video limits per prompt."""
    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }

    # Set very low video limit
    limit_mm_per_prompt = {"image": 0, "video": 1}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )

    assert len(samples) == 3

    for sample in samples:
        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(mm_data) <= 1  # Should respect video limit
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_sample_mixed_buckets_with_zero_probability(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test sampling with mixed buckets including zero probability entries."""
    bucket_config = {
        (64, 64, 1): 0.5,  # Images
        (64, 64, 8): 0.5,  # Videos
        (128, 128, 16): 0.0,  # Zero probability videos (should be ignored)
    }

    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=4,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )

    assert len(samples) == 4

    # Should only see 64x64 videos, not 128x128 videos.
    # Every video payload is decoded with OpenCV to check its real dimensions.
    for sample in samples:
        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        for item in mm_data:
            if item["type"] == "video_url":
                # Decode video to verify dimensions
                url = item["video_url"]["url"]
                base64_data = url.split(",")[1]
                video_bytes = base64.b64decode(base64_data)

                with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:  # noqa
                    temp_path = temp_file.name
                    temp_file.write(video_bytes)

                try:
                    cap = cv2.VideoCapture(temp_path)
                    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    cap.release()

                    # Should be 64x64, not 128x128
                    assert frame_width == 64
                    assert frame_height == 64
                finally:
                    if os.path.exists(temp_path):
                        os.unlink(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_sample_deterministic_with_videos(hf_tokenizer: PreTrainedTokenizerBase):
|
||||
"""Test that sampling with videos is deterministic with same seed."""
|
||||
dataset1 = RandomMultiModalDataset(random_seed=123)
|
||||
dataset2 = RandomMultiModalDataset(random_seed=123)
|
||||
|
||||
bucket_config = {
|
||||
(64, 64, 1): 0.3, # Images
|
||||
(64, 64, 8): 0.7, # Videos
|
||||
}
|
||||
|
||||
limit_mm_per_prompt = {"image": 2, "video": 2}
|
||||
|
||||
samples1 = dataset1.sample(
|
||||
tokenizer=hf_tokenizer,
|
||||
num_requests=3,
|
||||
base_items_per_request=1,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
bucket_config=bucket_config,
|
||||
input_len=20,
|
||||
output_len=5,
|
||||
)
|
||||
|
||||
samples2 = dataset2.sample(
|
||||
tokenizer=hf_tokenizer,
|
||||
num_requests=3,
|
||||
base_items_per_request=1,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
bucket_config=bucket_config,
|
||||
input_len=20,
|
||||
output_len=5,
|
||||
)
|
||||
|
||||
assert len(samples1) == len(samples2)
|
||||
|
||||
# Compare multimodal data
|
||||
for s1, s2 in zip(samples1, samples2):
|
||||
assert s1.multi_modal_data == s2.multi_modal_data
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_sample_different_seeds_produce_different_videos(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
):
|
||||
"""Test that different seeds produce different video content."""
|
||||
dataset1 = RandomMultiModalDataset(random_seed=123)
|
||||
dataset2 = RandomMultiModalDataset(random_seed=456)
|
||||
|
||||
bucket_config = {
|
||||
(64, 64, 8): 1.0, # Only videos
|
||||
}
|
||||
|
||||
limit_mm_per_prompt = {"image": 0, "video": 1}
|
||||
|
||||
samples1 = dataset1.sample(
|
||||
tokenizer=hf_tokenizer,
|
||||
num_requests=2,
|
||||
base_items_per_request=1,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
bucket_config=bucket_config,
|
||||
input_len=20,
|
||||
output_len=5,
|
||||
)
|
||||
|
||||
samples2 = dataset2.sample(
|
||||
tokenizer=hf_tokenizer,
|
||||
num_requests=2,
|
||||
base_items_per_request=1,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
bucket_config=bucket_config,
|
||||
input_len=20,
|
||||
output_len=5,
|
||||
)
|
||||
|
||||
# Video content should be different
|
||||
for s1, s2 in zip(samples1, samples2):
|
||||
mm_data1 = cast(list[dict[str, Any]], s1.multi_modal_data)
|
||||
mm_data2 = cast(list[dict[str, Any]], s2.multi_modal_data)
|
||||
|
||||
assert len(mm_data1) == len(mm_data2) == 1
|
||||
|
||||
url1 = mm_data1[0]["video_url"]["url"]
|
||||
url2 = mm_data2[0]["video_url"]["url"]
|
||||
|
||||
assert url1 != url2 # Different video content
|
||||
77
tests/benchmarks/test_serve_cli.py
Normal file
77
tests/benchmarks/test_serve_cli.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
|
||||
from ..utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve(server):
|
||||
command = [
|
||||
"vllm",
|
||||
"bench",
|
||||
"serve",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--host",
|
||||
server.host,
|
||||
"--port",
|
||||
str(server.port),
|
||||
"--dataset-name",
|
||||
"random",
|
||||
"--random-input-len",
|
||||
"32",
|
||||
"--random-output-len",
|
||||
"4",
|
||||
"--num-prompts",
|
||||
"5",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve_chat(server):
|
||||
command = [
|
||||
"vllm",
|
||||
"bench",
|
||||
"serve",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--host",
|
||||
server.host,
|
||||
"--port",
|
||||
str(server.port),
|
||||
"--dataset-name",
|
||||
"random",
|
||||
"--random-input-len",
|
||||
"32",
|
||||
"--random-output-len",
|
||||
"4",
|
||||
"--num-prompts",
|
||||
"5",
|
||||
"--endpoint",
|
||||
"/v1/chat/completions",
|
||||
"--backend",
|
||||
"openai-chat",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
30
tests/benchmarks/test_throughput_cli.py
Normal file
30
tests/benchmarks/test_throughput_cli.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_throughput():
|
||||
command = [
|
||||
"vllm",
|
||||
"bench",
|
||||
"throughput",
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--input-len",
|
||||
"32",
|
||||
"--output-len",
|
||||
"1",
|
||||
"--enforce-eager",
|
||||
"--load-format",
|
||||
"dummy",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
52
tests/ci_envs.py
Normal file
52
tests/ci_envs.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
These envs only work for a small part of the tests, fix what you need!
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.envs import maybe_convert_bool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
VLLM_CI_NO_SKIP: bool = False
|
||||
VLLM_CI_DTYPE: str | None = None
|
||||
VLLM_CI_HEAD_DTYPE: str | None = None
|
||||
VLLM_CI_HF_DTYPE: str | None = None
|
||||
|
||||
environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# A model family has many models with the same architecture.
|
||||
# By default, a model family tests only one model.
|
||||
# Through this flag, all models can be tested.
|
||||
"VLLM_CI_NO_SKIP": lambda: bool(int(os.getenv("VLLM_CI_NO_SKIP", "0"))),
|
||||
# Allow changing the dtype used by vllm in tests
|
||||
"VLLM_CI_DTYPE": lambda: os.getenv("VLLM_CI_DTYPE", None),
|
||||
# Allow changing the head dtype used by vllm in tests
|
||||
"VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
|
||||
# Allow changing the head dtype used by transformers in tests
|
||||
"VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
|
||||
# Allow control over whether tests use enforce_eager
|
||||
"VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool(
|
||||
os.getenv("VLLM_CI_ENFORCE_EAGER", None)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
# lazy evaluation of environment variables
|
||||
if name in environment_variables:
|
||||
return environment_variables[name]()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(environment_variables.keys())
|
||||
|
||||
|
||||
def is_set(name: str):
|
||||
"""Check if an environment variable is explicitly set."""
|
||||
if name in environment_variables:
|
||||
return name in os.environ
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
5
tests/compile/README.md
Normal file
5
tests/compile/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# compile test folder structure
|
||||
|
||||
- `compile/test_*.py` : various unit tests meant for testing particular code path/features. Future tests are most likely added here. New test files added here will be included in CI automatically
|
||||
- `compile/fullgraph/` : full model tests, including all tests previously in compile/piecewise. These tests do not target particular features. New test files added here will be included in CI automatically
|
||||
- `compile/distributed/` : tests that require multiple GPUs. New test files added here will **NOT** be included in CI automatically as these tests generally need to be manually configured to run in runners with particular number/type of GPUs.
|
||||
111
tests/compile/backend.py
Normal file
111
tests/compile/backend.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
|
||||
import depyf
|
||||
from torch import fx
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx._utils import lazy_format_graph_code
|
||||
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.compilation.pass_manager import with_pattern_match_debug
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger("vllm.tests.compile.backend")
|
||||
|
||||
|
||||
class LazyInitPass(InductorPass):
|
||||
"""
|
||||
If there's a pass that we want to initialize lazily in a test,
|
||||
we can wrap it in LazyInitPass, which will initialize the pass when invoked
|
||||
and then immediately invoke it.
|
||||
"""
|
||||
|
||||
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
|
||||
self.pass_cls = pass_cls
|
||||
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
||||
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.pass_ = self.pass_cls(self.vllm_config)
|
||||
self.pass_(graph)
|
||||
|
||||
|
||||
class TestBackend:
|
||||
"""
|
||||
This class provides a simple Inductor backend that can be used for testing.
|
||||
It takes a list of custom passes and runs them after Inductor's passes.
|
||||
It also saves the graph before and after the custom passes for inspection.
|
||||
|
||||
Inductor config can be modified directly by editing the inductor_config
|
||||
property. This can be helpful for adding passes like the
|
||||
'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
|
||||
Inductor config is default-initialized from VllmConfig.CompilationConfig.
|
||||
"""
|
||||
|
||||
def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
|
||||
self.custom_passes = list(passes)
|
||||
vllm_config = get_current_vllm_config()
|
||||
compile_config = vllm_config.compilation_config
|
||||
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
|
||||
self.inductor_config = deepcopy(compile_config.inductor_compile_config)
|
||||
self.inductor_config["force_disable_caches"] = True
|
||||
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
|
||||
|
||||
if debug_dump_path := vllm_config.compile_debug_dump_path():
|
||||
logger.debug("Dumping depyf output to %s", debug_dump_path)
|
||||
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
|
||||
else:
|
||||
self.debug_ctx = nullcontext()
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs):
|
||||
self.graph_pre_compile = deepcopy(graph)
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
with self.debug_ctx:
|
||||
return compile_fx(
|
||||
graph, example_inputs, config_patches=self.inductor_config
|
||||
)
|
||||
|
||||
@with_pattern_match_debug
|
||||
def post_pass(self, graph: fx.Graph):
|
||||
self.graph_pre_pass = deepcopy(graph)
|
||||
lazy_format_graph_code("graph_pre_pass", graph.owning_module)
|
||||
|
||||
VllmInductorPass.dump_prefix = 0
|
||||
for pass_ in self.custom_passes:
|
||||
pass_(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
VllmInductorPass.dump_prefix = None
|
||||
|
||||
self.graph_post_pass = deepcopy(graph)
|
||||
lazy_format_graph_code("graph_post_pass", graph.owning_module)
|
||||
# assign by reference, will reflect the final state of the graph
|
||||
self.final_graph = graph
|
||||
|
||||
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
|
||||
for op in ops:
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
||||
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
||||
if fully_replaced:
|
||||
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
||||
|
||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||
for op in ops:
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
||||
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
||||
|
||||
def op_count(self, op: OpOverload, before=False) -> int:
|
||||
graph = self.graph_pre_pass if before else self.graph_post_pass
|
||||
return len(list(find_op_nodes(op, graph)))
|
||||
437
tests/compile/distributed/test_async_tp.py
Normal file
437
tests/compile/distributed/test_async_tp.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.collective_fusion import AsyncTPPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
from ...models.registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import (
|
||||
compare_two_settings,
|
||||
create_new_process_for_each_test,
|
||||
multi_gpu_test,
|
||||
)
|
||||
from ..backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
|
||||
class TestMMRSModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
|
||||
)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
"""
|
||||
Forward pass implementing the mm + reduce scatter in the FX graph
|
||||
|
||||
"""
|
||||
# Reshape input
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
|
||||
# matrix multiplication
|
||||
permute = self.gate_proj.permute(1, 0)
|
||||
mm = torch.mm(view, permute)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGMMModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty((hidden_size, hidden_size)), requires_grad=False
|
||||
)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.weight, std=0.02)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
"""
|
||||
Forward pass implementing the mm + all gather in the FX graph
|
||||
"""
|
||||
# Reshape input
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_gather = tensor_model_parallel_all_gather(view, dim=0)
|
||||
permute = self.weight.permute(1, 0)
|
||||
mm = torch.mm(all_gather, permute)
|
||||
return mm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_gather.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_all_gather_matmul.default]
|
||||
|
||||
|
||||
class _BaseScaledMMModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.weight = (
|
||||
torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
# Initialize scale_b for _scaled_mm.
|
||||
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
|
||||
|
||||
|
||||
class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
|
||||
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(
|
||||
fp8_input,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + scaled_mm in the FX graph
|
||||
"""
|
||||
# Reshape input
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
scaled_mm = torch._scaled_mm(
|
||||
all_gather,
|
||||
self.weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=self.scale_b,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
return scaled_mm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_gather.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
|
||||
|
||||
|
||||
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the cutlass_scaled_mm + reduce scatter
|
||||
in the FX graph
|
||||
|
||||
"""
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
|
||||
mm_out = torch.empty(
|
||||
(fp8_input.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
torch.ops._C.cutlass_scaled_mm(
|
||||
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
|
||||
)
|
||||
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
|
||||
return reduce_scatter
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
def forward(self, input: torch.Tensor):
|
||||
"""
|
||||
Forward pass implementing the all gather + cutlass_scaled_mm
|
||||
in the FX graph
|
||||
"""
|
||||
# Reshape input
|
||||
fp8_input = input.to(FP8_DTYPE)
|
||||
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
|
||||
|
||||
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
|
||||
|
||||
mm_out = torch.empty(
|
||||
(all_gather.shape[0], self.weight.shape[1]),
|
||||
dtype=self.dtype,
|
||||
device=all_gather.device,
|
||||
)
|
||||
torch.ops._C.cutlass_scaled_mm(
|
||||
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
|
||||
)
|
||||
return mm_out
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_gather.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
[
|
||||
TestMMRSModel,
|
||||
TestAGMMModel,
|
||||
TestScaledMMRSModel,
|
||||
TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel,
|
||||
TestAGCutlassScaledMMModel,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dynamic", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_async_tp_pass_replace(
|
||||
test_model: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
dynamic: bool,
|
||||
):
|
||||
if (
|
||||
test_model
|
||||
in (
|
||||
TestScaledMMRSModel,
|
||||
TestAGScaledMMModel,
|
||||
TestCutlassScaledMMRSModel,
|
||||
TestAGCutlassScaledMMModel,
|
||||
)
|
||||
and dtype == torch.float16
|
||||
):
|
||||
pytest.skip(
|
||||
"Only bf16 high precision output types are supported for "
|
||||
"per-token (row-wise) scaling"
|
||||
)
|
||||
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn otherwise will have problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(
|
||||
num_processes,
|
||||
test_model,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
dynamic,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def async_tp_pass_on_test_model(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
test_model_cls: torch.nn.Module,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
fuse_gemm_comms=True,
|
||||
),
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
assert (
|
||||
async_tp_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
async_tp_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
|
||||
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||
|
||||
hidden_states = torch.randn(
|
||||
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||
)
|
||||
|
||||
if dynamic:
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
assert async_tp_pass.matched_count == 1
|
||||
|
||||
# In pre-nodes, all gather or reduce scatter should exist,
|
||||
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
|
||||
# In post-nodes, fused_matmul_reduce_scatter or \
|
||||
# fused_all_gather_matmul should exist
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
def test_async_tp_pass_correctness(
|
||||
model_id: str,
|
||||
tp_size: int,
|
||||
async_tp_enabled: bool,
|
||||
distributed_backend: str,
|
||||
eager_mode: bool,
|
||||
num_gpus_available: int,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
pp_size = 1
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
|
||||
common_args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"8",
|
||||
]
|
||||
if eager_mode:
|
||||
common_args.append("--enforce-eager")
|
||||
|
||||
compilation_config = {
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [2, 4, 8],
|
||||
"splitting_ops": [],
|
||||
"pass_config": {"fuse_gemm_comms": async_tp_enabled},
|
||||
}
|
||||
|
||||
async_tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
"--compilation_config",
|
||||
json.dumps(compilation_config),
|
||||
]
|
||||
|
||||
tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
]
|
||||
|
||||
compare_two_settings(model_id, async_tp_args, tp_args, method="generate")
|
||||
332
tests/compile/distributed/test_fusion_all_reduce.py
Normal file
332
tests/compile/distributed/test_fusion_all_reduce.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
from ...utils import has_module_attribute, multi_gpu_test
|
||||
from ..backend import TestBackend
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
    """Toy model exercising the allreduce+rmsnorm fusion pattern.

    Repeats allreduce -> (fused-add) rmsnorm -> matmul four times so the
    fusion pass has multiple match sites (the test asserts matched_count == 4).
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Plain Python lists (not nn.ModuleList) — these are trace-only toys.
        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(x)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = torch.mm(y, self.w[0])
        x2 = tensor_model_parallel_all_reduce(z2)

        # Two-arg RMSNorm call is the fused-add variant (returns output, resid).
        y2, resid = self.norm[1](x2, resid)

        z3 = torch.mm(y2, self.w[1])
        x3 = tensor_model_parallel_all_reduce(z3)

        y3, resid = self.norm[2](x3, resid)

        z4 = torch.mm(y3, self.w[2])
        x4 = tensor_model_parallel_all_reduce(z4)

        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Ops that must be present in the graph before the fusion pass runs."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Ops that must appear after fusion (the fused flashinfer kernel)."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
    """Toy model for the allreduce + rmsnorm + static FP8 quant fusion.

    Same four-stage structure as TestAllReduceRMSNormModel, but each matmul
    goes through Fp8LinearOp with static per-tensor activation quantization.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
        # Per-layer weight scales for the FP8 GEMMs.
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # FP8 weights, transposed to the layout Fp8LinearOp expects.
        self.w = [
            torch.rand(hidden_size, hidden_size)
            .to(dtype=current_platform.fp8_dtype())
            .t()
            for _ in range(3)
        ]

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

        # Static activation (input) scales, one per GEMM.
        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

    def forward(self, hidden_states):
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )

        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )

        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        z4 = self.fp8_linear.apply(
            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
        )
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        """Fused kernel that must appear after the pass."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

    def ops_in_model_before(self):
        """Pre-fusion ops; the quant op depends on whether the custom
        QuantFP8 op is enabled (otherwise the torch impl's reciprocal shows up).
        """
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.static_scaled_fp8_quant.default
            if self.fp8_linear.quant_fp8.enabled()
            else torch.ops.aten.reciprocal.default,
        ]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
    """Toy model for the allreduce + fused-add rmsnorm + NVFP4 quant fusion.

    Mirrors the FP8 model but quantizes activations/weights with
    scaled_fp4_quant and runs the GEMMs through cutlass_scaled_fp4_mm.
    Requires Blackwell (compute capability 10.0).
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]

        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
        # Activation global scales (one per GEMM).
        self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # Weight global scales; only needed to pre-quantize the weights.
        wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # GEMM alpha = 1 / (weight_gscale * act_gscale).
        self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]

        # Pre-quantize all three weights to FP4 (returns (quantized, scale)).
        wq_gen, wscale_gen = zip(
            *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
        )
        self.wq, self.wscale = list(wq_gen), list(wscale_gen)
        print(f"{self.wq=}, {self.wscale=}")

    def forward(self, hidden_states):
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
        z2 = cutlass_scaled_fp4_mm(
            yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
        )

        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
        z3 = cutlass_scaled_fp4_mm(
            yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
        )

        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
        z4 = cutlass_scaled_fp4_mm(
            yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
        )
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        """Fused kernel that must appear after the pass."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

    def ops_in_model_before(self):
        """Pre-fusion ops that the pass is expected to consume."""
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.scaled_fp4_quant.default,
        ]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model, enable_quant_fp8_custom_op",
    [
        (TestAllReduceRMSNormModel, False),
        (TestAllReduceRMSNormStaticQuantFP8Model, True),
        (TestAllReduceRMSNormStaticQuantFP8Model, False),
        (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
    not find_spec("flashinfer")
    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
    reason="flashinfer is not found or flashinfer "
    "is not compiled with trtllm_allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
    test_model: torch.nn.Module,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_rms_norm_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
):
    """Spawn 2 TP ranks and verify AllReduceFusionPass rewrites each toy model.

    The actual assertions live in all_reduce_fusion_pass_on_test_model, which
    runs in each spawned worker process.
    """
    num_processes = 2
    if (
        test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
        and not current_platform.has_device_capability(100)
    ):
        pytest.skip(
            "Skip as nvfp4 is only supported on "
            "devices with compute capability 10.0 (Blackwell)"
        )

    def run_torch_spawn(fn, nprocs):
        # torch.multiprocessing.spawn passes the worker index as the first
        # argument to fn, followed by everything in `args`.
        torch.multiprocessing.spawn(
            fn,
            args=(
                num_processes,
                test_model,
                batch_size,
                seq_len,
                hidden_size,
                dtype,
                enable_rms_norm_custom_op,
                enable_quant_fp8_custom_op,
            ),
            nprocs=nprocs,
        )

    run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def all_reduce_fusion_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: torch.nn.Module,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_rms_norm_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
):
    """Worker body (one per spawned rank) for test_all_reduce_fusion_pass_replace.

    Initializes a 2-way TP group, compiles the toy model through TestBackend
    with the allreduce fusion pass enabled, and asserts the pass matched all
    four allreduce+rmsnorm sites and rewrote the graph.
    """
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # Environment expected by init_distributed_environment (env:// rendezvous).
    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
        )
    )
    vllm_config.compilation_config.pass_config = PassConfig(
        fuse_allreduce_rms=True, eliminate_noops=True
    )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
    vllm_config.parallel_config.rank = local_rank  # Setup rank for debug path

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    vllm_config.model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )
    with set_current_vllm_config(vllm_config):
        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
        noop_pass = NoOpEliminationPass(vllm_config)
        func_pass = FixFunctionalizationPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        backend = TestBackend(
            noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
        )

        token_num = batch_size * seq_len
        model = test_model_cls(hidden_size, token_num)

        hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # Each toy model contains exactly 4 allreduce+rmsnorm sites.
        assert all_reduce_fusion_pass.matched_count == 4, (
            f"{all_reduce_fusion_pass.matched_count=}"
        )
        # fully_replaced=False: the first allreduce feeds a non-fused-add norm,
        # so some pre-fusion ops may legitimately remain.
        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
        backend.check_after_ops(model.ops_in_model_after())
        del all_reduce_fusion_pass
||||
580
tests/compile/distributed/test_fusions_e2e.py
Normal file
580
tests/compile/distributed/test_fusions_e2e.py
Normal file
@@ -0,0 +1,580 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
|
||||
from tests.v1.attention.utils import AttentionBackendEnum
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import flat_product, multi_gpu_test
|
||||
|
||||
def is_blackwell() -> bool:
    """Are we running on Blackwell, a lot of tests depend on it.

    Defined as a def (not a lambda assignment, PEP 8 E731) so the docstring
    is actually attached to the callable.
    """
    return current_platform.is_device_capability_family(100)
|
||||
|
||||
|
||||
class Matches(NamedTuple):
    """Expected pattern-match counts per fusion pass for one test case.

    Each count is compared against the "Replaced N patterns" / "Fused quant
    onto N attention nodes" log lines emitted by the corresponding pass.
    """

    # Attention+quant fusions (logged by fusion_attn.py).
    attention_fusion: int = 0
    # Allreduce+rmsnorm fusions (logged by collective_fusion.py).
    allreduce_fusion: int = 0
    # RMSNorm+group-quant fusions (logged by fusion.py).
    rms_quant_norm_fusion: int = 0
    # Sequence-parallelism rewrites (logged by sequence_parallelism.py).
    sequence_parallel: int = 0
    # Async-TP gemm/collective fusions (logged by collective_fusion.py).
    async_tp: int = 0
||||
|
||||
|
||||
class ModelBackendTestCase(NamedTuple):
    """One e2e fusion test case: model, LLM kwargs, attention backend, and the
    fusion match counts expected for that combination."""

    model_name: str
    # Extra keyword arguments forwarded to the LLM constructor via run_model().
    model_kwargs: dict[str, Any]
    backend: AttentionBackendEnum
    matches: Matches
||||
|
||||
|
||||
# Per-platform test-case tables; populated below based on the current platform.
MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS_GROUP_FP8: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = []  # tp-only

if current_platform.is_cuda():
    MODELS_FP8 = [
        ModelBackendTestCase(
            # Use smaller model for L40s in CI
            model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=32,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
        ModelBackendTestCase(
            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            # TODO FlashInfer attn broken on Hopper with kvcache=fp8:
            # https://github.com/vllm-project/vllm/issues/28568
            backend=AttentionBackendEnum.FLASHINFER
            if is_blackwell()
            else AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=48,
                allreduce_fusion=96,
                sequence_parallel=96,
                async_tp=95,  # mlp is moe, no fusion there
            ),
        ),
    ]

    MODELS_FP4 = [
        ModelBackendTestCase(
            model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.FLASHINFER,
            matches=Matches(
                attention_fusion=32,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
    ]

    # TP only
    MODELS = [
        ModelBackendTestCase(
            model_name="meta-llama/Llama-3.1-8B-Instruct",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=0,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
        ModelBackendTestCase(
            model_name="Qwen/Qwen3-30B-A3B",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=0,
                allreduce_fusion=97,
                sequence_parallel=97,
                async_tp=96,  # MLP is MoE, half the fusions of dense
            ),
        ),
    ]

elif current_platform.is_rocm():
    MODELS_FP8 = [
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(attention_fusion=32),
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.ROCM_ATTN,
            matches=Matches(attention_fusion=32),
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
            matches=Matches(attention_fusion=32),
        ),
    ]

# QuantFP8 custom-op toggle values ("-" disables, "+" enables the custom kernel).
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
|
||||
|
||||
def has_cuda_graph_wrapper_metadata() -> bool:
    """Return True if torch Inductor exposes ``CUDAGraphWrapperMetadata``.

    Needed by the inductor-graph-partition test variants; older torch builds
    lack the attribute. Uses ``hasattr`` (which catches exactly
    ``AttributeError``) instead of a bare attribute access + ``# noqa B018``.
    An ImportError from a missing ``torch._inductor.utils`` still propagates,
    matching the original behavior.
    """
    from importlib import import_module

    return hasattr(import_module("torch._inductor.utils"), "CUDAGraphWrapperMetadata")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize(
    "inductor_graph_partition",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not has_cuda_graph_wrapper_metadata(),
                reason="This test requires"
                "torch._inductor.utils.CUDAGraphWrapperMetadata to run",
            ),
        ),
        False,
    ],
)
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """End-to-end: run a quantized model (TP=1) with attn+quant fusion enabled
    and verify, via worker logs, that the expected number of attention nodes
    had quant fused onto them.
    """
    if backend == AttentionBackendEnum.FLASHINFER and (
        not is_blackwell() or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
        # CUDAGraphMode.NONE here because it derives an attention backend that
        # does not support full cudagraphs
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable, compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    # TP=1 -> exactly one worker log line from the attention fusion pass.
    log_matches = re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 1, log_holder.text
    assert int(log_matches[0]) == matches.attention_fusion
|
||||
|
||||
|
||||
# RMSNorm custom-op toggle values ("-" disables, "+" enables the custom kernel).
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
|
||||
|
||||
|
||||
def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    """Yield every combination of the given op-toggle lists, comma-joined.

    E.g. (["-a", "+a"], ["-b"]) yields "-a,-b" then "+a,-b".
    """
    combos = itertools.product(*custom_ops_lists)
    yield from (",".join(combo) for combo in combos)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Toggle RMSNorm and QuantFP8 for FP8 models
    list(
        flat_product(
            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
        )
    )
    # Toggle RMSNorm for FP4 models and unquant models
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not has_flashinfer()
    or not current_platform.has_device_capability(90),
    reason="allreduce+rmsnorm fusion requires flashinfer",
)
def test_tp2_attn_quant_allreduce_rmsnorm(
    model_name: str,
    model_kwargs: dict,
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """End-to-end TP=2 run with attn+quant and allreduce+rmsnorm fusion enabled;
    verifies per-rank fusion counts through the captured worker logs.
    """
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    if "fp4" in model_name.lower() and not is_blackwell():
        pytest.skip("NVFP4 quant requires Blackwell")

    if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
        # FlashInfer attn fusion requires Blackwell
        matches = matches._replace(attention_fusion=0)

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable, compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        custom_ops=custom_ops_list,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            fuse_attn_quant=True,
            eliminate_noops=True,
            fuse_allreduce_rms=True,
        ),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(
            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
        )
    log_matches = re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    # 2 for each compile range
    # (global compile range can be split due to fuse_allreduce_rmsnorm)
    num_compile_ranges = len(compilation_config.get_compile_ranges())
    assert num_compile_ranges in [1, 2]

    assert len(log_matches) == 2 * num_compile_ranges, log_holder.text

    assert all(int(log_match) == matches.attention_fusion for log_match in log_matches)

    # One allreduce-fusion log line per TP rank.
    log_matches = re.findall(
        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.allreduce_fusion
    assert int(log_matches[1]) == matches.allreduce_fusion

    # If the compile range was split, the pass is skipped once per rank
    # for the out-of-range segment.
    log_matches = re.findall(
        r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range",
        log_holder.text,
    )
    assert len(log_matches) == 2 * (num_compile_ranges - 1), log_holder.text
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Toggle RMSNorm and QuantFP8 for FP8 models
    list(
        flat_product(
            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
        )
    )
    # Toggle RMSNorm for FP4 models and unquant models
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
    model_name: str,
    model_kwargs: dict,
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """End-to-end TP=2 run with attn+quant, sequence-parallelism and async-TP
    fusion passes enabled; verifies per-rank fusion counts via worker logs.
    """
    if is_blackwell():
        # TODO: https://github.com/vllm-project/vllm/issues/27893
        pytest.skip("Blackwell is not supported for AsyncTP pass")

    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    if "fp4" in model_name.lower() and not is_blackwell():
        pytest.skip("NVFP4 quant requires Blackwell")

    if backend == AttentionBackendEnum.FLASHINFER:
        if not has_flashinfer():
            pytest.skip("FlashInfer backend requires flashinfer installed")
        if not is_blackwell():
            # FlashInfer attn fusion requires Blackwell
            matches = matches._replace(attention_fusion=0)

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable, compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        custom_ops=custom_ops_list,
        splitting_ops=splitting_ops,
        # Common
        # Fix: was `level=CompilationMode.VLLM_COMPILE`; every other
        # CompilationConfig in this file uses the `mode` keyword.
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            fuse_attn_quant=True,
            eliminate_noops=True,
            enable_sp=True,
            fuse_gemm_comms=True,
        ),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(
            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
        )
    # One log line per TP rank for each pass below.
    log_matches = re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.attention_fusion
    assert int(log_matches[1]) == matches.attention_fusion

    log_matches = re.findall(
        r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.sequence_parallel
    assert int(log_matches[1]) == matches.sequence_parallel

    log_matches = re.findall(
        r"collective_fusion.py:\d+] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(log_matches) == 2, log_holder.text

    assert int(log_matches[0]) == matches.async_tp
    assert int(log_matches[1]) == matches.async_tp
|
||||
|
||||
|
||||
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
    """Build an LLM with the given compilation config and generate once.

    Accepts either a full CompilationConfig or a bare compilation mode int.
    Also copies the post-init compile_ranges_split_points back onto the
    caller's config so the caller can compute compile ranges afterwards.
    """
    compilation_config = (
        compile_config
        if isinstance(compile_config, CompilationConfig)
        else CompilationConfig(mode=compile_config)
    )

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # temperature=0 -> greedy decoding, deterministic output.
    sampling_params = SamplingParams(temperature=0)
    # Allow override from model_kwargs
    model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
    model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}

    # No cudagraphs by default
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE
    llm = LLM(
        model=model,
        compilation_config=compilation_config,
        **model_kwargs,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Get the compile ranges split points after vllm config post init
    # in order to compute compile ranges correctly
    compilation_config.compile_ranges_split_points = (
        llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
    )
|
||||
|
||||
|
||||
# Group-quantized FP8 test cases (CUDA only), used by test_rms_group_quant.
if current_platform.is_cuda():
    MODELS_GROUP_FP8 = [
        ModelBackendTestCase(
            model_name="Qwen/Qwen3-30B-A3B-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                rms_quant_norm_fusion=48,
            ),
        ),
    ]

# Both custom ops must be enabled for the rmsnorm+group-quant fusion pattern.
CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Test rms norm+group quant_fp8 fusion
    list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
# TODO: remove skip after we fix the fusion thoroughly
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
def test_rms_group_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """End-to-end TP=1 run verifying the rmsnorm+group-quant fusion count
    via the fusion pass's "Replaced N patterns" log line.
    """
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable, compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    log_matches = re.findall(
        r"\[fusion.py:\d+] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(log_matches) == 1, log_holder.text
    assert int(log_matches[0]) == matches.rms_quant_norm_fusion
|
||||
331
tests/compile/distributed/test_sequence_parallelism.py
Normal file
331
tests/compile/distributed/test_sequence_parallelism.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CUDAGraphMode,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from ..backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def forward(self, x):
|
||||
z = torch.relu(x)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = torch.mm(y, self.w[0])
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = torch.mm(y2, self.w[1])
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
|
||||
y3, resid = self.norm[2](x3, resid)
|
||||
|
||||
z4 = torch.mm(y3, self.w[2])
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid)
|
||||
return y4
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
torch.ops.vllm.all_gather.default,
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
]
|
||||
|
||||
def ops_in_model(self):
|
||||
if RMSNorm.enabled():
|
||||
return [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size)
|
||||
.to(dtype=current_platform.fp8_dtype())
|
||||
.t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
torch.ops.vllm.all_gather.default,
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
]
|
||||
|
||||
def ops_in_model(self):
|
||||
if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
|
||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
||||
elif RMSNorm.enabled():
|
||||
return [
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
]
|
||||
elif self.fp8_linear.quant_fp8.enabled():
|
||||
return [
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"test_model_cls, custom_ops",
|
||||
[
|
||||
(TestAllReduceRMSNormModel, "+rms_norm"),
|
||||
(TestAllReduceRMSNormModel, "-rms_norm"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
|
||||
@pytest.mark.parametrize("dynamic", [False, True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_sequence_parallelism_pass(
|
||||
test_model_cls: type[torch.nn.Module],
|
||||
custom_ops: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn otherwise will have problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(
|
||||
num_processes,
|
||||
test_model_cls,
|
||||
custom_ops,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
fuse_norm_quant,
|
||||
dynamic,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def sequence_parallelism_pass_on_test_model(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
test_model_cls: type[torch.nn.Module],
|
||||
custom_ops: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
compilation_config = CompilationConfig(
|
||||
splitting_ops=[], # avoid automatic rms_norm enablement
|
||||
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
|
||||
custom_ops=custom_ops_list,
|
||||
pass_config=PassConfig(
|
||||
enable_sp=True,
|
||||
fuse_norm_quant=fuse_norm_quant,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
) # NoOp needed for fusion
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
|
||||
model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
passes_for_backend: list[VllmInductorPass] = [
|
||||
noop_pass,
|
||||
sequence_parallelism_pass,
|
||||
]
|
||||
|
||||
if fuse_norm_quant:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
|
||||
backend = TestBackend(*passes_for_backend)
|
||||
|
||||
model = test_model_cls(hidden_size)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
if dynamic:
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
assert sequence_parallelism_pass.matched_count == 4
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
for op in model.ops_in_model_before():
|
||||
assert backend.op_count(op, before=True) == 4
|
||||
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
for op in model.ops_in_model_after():
|
||||
assert backend.op_count(op, before=False) == 4
|
||||
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend.graph_post_pass.nodes, op)
|
||||
155
tests/compile/fullgraph/test_basic_correctness.py
Normal file
155
tests/compile/fullgraph/test_basic_correctness.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationMode
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from ...utils import compare_all_settings
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestSetting:
|
||||
model: str
|
||||
model_args: list[str]
|
||||
pp_size: int
|
||||
tp_size: int
|
||||
attn_backend: str
|
||||
method: str
|
||||
|
||||
|
||||
# we cannot afford testing the full Cartesian product
|
||||
# of all models and all modes
|
||||
@pytest.mark.parametrize(
|
||||
"test_setting",
|
||||
[
|
||||
# basic llama model
|
||||
TestSetting(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
model_args=["--max-model-len", "2048"],
|
||||
pp_size=2,
|
||||
tp_size=2,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="generate",
|
||||
),
|
||||
# llama model with quantization
|
||||
TestSetting(
|
||||
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
model_args=["--quantization", "gptq", "--max-model-len", "2048"],
|
||||
pp_size=1,
|
||||
tp_size=1,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="generate",
|
||||
),
|
||||
# MoE model
|
||||
TestSetting(
|
||||
model="ibm/PowerMoE-3b",
|
||||
model_args=["--max-model-len", "2048"],
|
||||
pp_size=1,
|
||||
tp_size=2,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="generate",
|
||||
),
|
||||
# embedding model
|
||||
TestSetting(
|
||||
model="BAAI/bge-multilingual-gemma2",
|
||||
model_args=[
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
],
|
||||
pp_size=1,
|
||||
tp_size=1,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="encode",
|
||||
),
|
||||
TestSetting(
|
||||
model="BAAI/bge-base-en-v1.5",
|
||||
model_args=["--runner", "pooling"],
|
||||
pp_size=1,
|
||||
tp_size=1,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="encode",
|
||||
),
|
||||
# vision language model
|
||||
# See https://github.com/vllm-project/vllm/issues/26716.
|
||||
# TestSetting(
|
||||
# model="microsoft/Phi-3.5-vision-instruct",
|
||||
# model_args=["--trust-remote-code", "--max-model-len", "2048"],
|
||||
# pp_size=2,
|
||||
# tp_size=1,
|
||||
# attn_backend="FLASH_ATTN",
|
||||
# method="generate_with_image",
|
||||
# ),
|
||||
],
|
||||
)
|
||||
def test_compile_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_setting: TestSetting,
|
||||
):
|
||||
# this test is run under multiple suits, with different GPUs.
|
||||
# make sure we only run the test with correct CUDA devices.
|
||||
# don't use "<", as it will duplicate the tests.
|
||||
model = test_setting.model
|
||||
model_args = test_setting.model_args
|
||||
pp_size = test_setting.pp_size
|
||||
tp_size = test_setting.tp_size
|
||||
attn_backend = test_setting.attn_backend
|
||||
method = test_setting.method
|
||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
||||
pytest.skip(
|
||||
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||
f"{cuda_device_count_stateless()}"
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
final_args = [
|
||||
*model_args,
|
||||
"-pp",
|
||||
str(pp_size),
|
||||
"-tp",
|
||||
str(tp_size),
|
||||
"-cc.cudagraph_mode=none",
|
||||
]
|
||||
|
||||
all_args: list[list[str]] = []
|
||||
all_envs: list[dict[str, str] | None] = []
|
||||
|
||||
for comp_mode in [
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
for mode in [CompilationMode.NONE, comp_mode]:
|
||||
all_args.append(
|
||||
final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
|
||||
)
|
||||
|
||||
# inductor will change the output, so we only compare if the output
|
||||
# is close, not exactly the same.
|
||||
compare_all_settings(
|
||||
model,
|
||||
all_args,
|
||||
all_envs,
|
||||
method=method if method != "generate" else "generate_close",
|
||||
)
|
||||
all_envs.clear()
|
||||
all_args.clear()
|
||||
|
||||
for mode in [
|
||||
CompilationMode.NONE,
|
||||
CompilationMode.STOCK_TORCH_COMPILE,
|
||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||
CompilationMode.VLLM_COMPILE,
|
||||
]:
|
||||
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
|
||||
all_envs.append({})
|
||||
all_envs.append({})
|
||||
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
185
tests/compile/fullgraph/test_full_cudagraph.py
Normal file
185
tests/compile/fullgraph/test_full_cudagraph.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import wait_for_gpu_memory_to_clear
|
||||
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporary_environ(env_vars):
|
||||
"""
|
||||
Temporarily set environment variables and restore them afterward.
|
||||
We have to do this vs monkeypatch because monkeypatch doesn't work
|
||||
with "module" scoped fixtures.
|
||||
"""
|
||||
original_env = {k: os.environ.get(k) for k in env_vars}
|
||||
try:
|
||||
os.environ.update(env_vars)
|
||||
yield
|
||||
finally:
|
||||
for k, v in original_env.items():
|
||||
if v is None:
|
||||
os.environ.pop(k, None)
|
||||
else:
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
model_backends_full_cudagraph = []
|
||||
|
||||
# deepseek-ai/DeepSeek-V2-Lite with MLA
|
||||
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
|
||||
for mla_backend in MLA_backends:
|
||||
model_backends_full_cudagraph.append(
|
||||
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
|
||||
)
|
||||
|
||||
# Qwen/Qwen2-1.5B-Instruct with other backends
|
||||
other_backend_configs = [
|
||||
backend_configs[c] for c in backend_configs if c not in MLA_backends
|
||||
]
|
||||
for backend_config in other_backend_configs:
|
||||
model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def llm_pair(request):
|
||||
model, backend_config, use_inductor_graph_partition = request.param
|
||||
backend_config.comp_config["use_inductor_graph_partition"] = (
|
||||
use_inductor_graph_partition
|
||||
)
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition only supported in torch>=2.9")
|
||||
|
||||
# Dynamically skip test if GPU capability is not met
|
||||
if (
|
||||
backend_config.specific_gpu_arch
|
||||
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
|
||||
):
|
||||
if backend_config.specific_gpu_arch == (9, 0):
|
||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||
elif backend_config.specific_gpu_arch == (10, 0):
|
||||
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
|
||||
|
||||
env_vars = {
|
||||
# Force native sampler to avoid potential nondeterminism in FlashInfer
|
||||
# when per-request generators are not used in V1.
|
||||
"VLLM_USE_FLASHINFER_SAMPLER": "0",
|
||||
**backend_config.env_vars,
|
||||
}
|
||||
with temporary_environ(env_vars):
|
||||
full = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.43,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
max_num_seqs=128,
|
||||
compilation_config=CompilationConfig(**backend_config.comp_config),
|
||||
generation_config="vllm",
|
||||
seed=42,
|
||||
)
|
||||
piecewise = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.43,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
max_num_seqs=128,
|
||||
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
|
||||
generation_config="vllm",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# PyTest caches the fixture values so we use weakref.proxy to enable GC
|
||||
yield weakref.proxy(full), weakref.proxy(piecewise)
|
||||
del full
|
||||
del piecewise
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=[0],
|
||||
threshold_ratio=0.1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_pair",
|
||||
[
|
||||
pytest.param((model, backend_config, use_inductor_graph_partition))
|
||||
for model, backend_config in model_backends_full_cudagraph
|
||||
for use_inductor_graph_partition in [True, False]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
class TestFullCUDAGraph:
|
||||
"""
|
||||
Use a class such that an llm pair is constructed once for all
|
||||
batch_size/max_tokens combinations and released immediately after.
|
||||
|
||||
Module-scope fixtures would stick around the whole time,
|
||||
meaning there would be multiple LLM instances hogging memory simultaneously.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("batch_size", "max_tokens"),
|
||||
[
|
||||
(1, 10),
|
||||
(7, 10),
|
||||
(16, 10),
|
||||
(25, 10),
|
||||
(32, 10),
|
||||
(45, 10),
|
||||
(64, 10),
|
||||
(123, 10),
|
||||
(8, 5),
|
||||
(8, 30),
|
||||
],
|
||||
)
|
||||
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
|
||||
"""
|
||||
Test various batch sizes and max_tokens to ensure that the
|
||||
full cudagraph compilation works for padded cases too.
|
||||
"""
|
||||
|
||||
full_cudagraph_llm, piecewise_llm = llm_pair
|
||||
|
||||
prompts = ["the quick brown fox"] * batch_size
|
||||
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
||||
# that can amplify tiny numeric differences across runtimes.
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0, max_tokens=max_tokens, top_p=1.0
|
||||
)
|
||||
|
||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||
|
||||
# Check that all responses are the same
|
||||
for piecewise_res, full_res in zip(piecewise_responses, full_responses):
|
||||
assert (
|
||||
piecewise_res.outputs[0].text.lower()
|
||||
== full_res.outputs[0].text.lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
with (
|
||||
temporary_environ(
|
||||
{
|
||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
}
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
LLM(
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
|
||||
)
|
||||
250
tests/compile/fullgraph/test_full_graph.py
Normal file
250
tests/compile/fullgraph/test_full_graph.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
def models_list(*, all: bool = True, keywords: list[str] | None = None):
|
||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||
("facebook/opt-125m", {}),
|
||||
(
|
||||
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
|
||||
{"dtype": torch.float16},
|
||||
),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||
]
|
||||
|
||||
if all:
|
||||
TEST_MODELS.extend(
|
||||
[
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||
(
|
||||
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||
{"dtype": torch.float16},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# TODO: figure out why this fails.
|
||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||
TEST_MODELS.append(
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
|
||||
)
|
||||
|
||||
if is_quant_method_supported("gptq"):
|
||||
TEST_MODELS.append(
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
|
||||
)
|
||||
|
||||
if is_quant_method_supported("gptq_marlin"):
|
||||
TEST_MODELS.append(
|
||||
(
|
||||
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
|
||||
{"quantization": "gptq_marlin"},
|
||||
)
|
||||
)
|
||||
|
||||
if is_quant_method_supported("gptq_marlin_24"):
|
||||
TEST_MODELS.append(
|
||||
(
|
||||
"alexm-nm/tinyllama-24-marlin24-4bit-g128",
|
||||
{"quantization": "gptq_marlin_24"},
|
||||
)
|
||||
)
|
||||
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
|
||||
)
|
||||
|
||||
if keywords is None:
|
||||
return TEST_MODELS
|
||||
|
||||
# filter by keywords
|
||||
pred = lambda model: any(keyword in model[0] for keyword in keywords)
|
||||
return list(filter(pred, TEST_MODELS))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_mode",
|
||||
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
|
||||
)
|
||||
@pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
|
||||
@create_new_process_for_each_test()
|
||||
def test_full_graph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
compilation_mode: int,
|
||||
):
|
||||
if (
|
||||
"w8a8" in model
|
||||
or "w8w8" in model
|
||||
and current_platform.has_device_capability((10, 0))
|
||||
):
|
||||
# int8 removed on Blackwell:
|
||||
pytest.skip("int8 support removed on Blackwell")
|
||||
|
||||
with monkeypatch.context():
|
||||
print(f"MODEL={model}")
|
||||
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
# TODO(luka) add other supported compilation config scenarios here
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_config, model, model_kwargs",
|
||||
[
|
||||
# additional compile sizes, only some of the models
|
||||
(
|
||||
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
|
||||
*model_info,
|
||||
)
|
||||
for model_info in models_list(all=False)
|
||||
]
|
||||
+ [
|
||||
# RMSNorm + quant fusion, only 8-bit quant models
|
||||
(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+rms_norm"],
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
),
|
||||
*model_info,
|
||||
)
|
||||
for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||
]
|
||||
+ [
|
||||
# Test depyf integration works
|
||||
(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
debug_dump_path=Path(tempfile.gettempdir()),
|
||||
),
|
||||
"facebook/opt-125m",
|
||||
{},
|
||||
),
|
||||
]
|
||||
+ [
|
||||
# graph inductor partition
|
||||
(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
# inductor graph partition uses
|
||||
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
|
||||
use_inductor_graph_partition=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
compile_sizes=[1, 2],
|
||||
),
|
||||
*model_info,
|
||||
)
|
||||
for model_info in models_list(all=False)
|
||||
if is_torch_equal_or_newer("2.9.0.dev")
|
||||
],
|
||||
)
|
||||
# only test some of the models
|
||||
@create_new_process_for_each_test()
|
||||
def test_custom_compile_config(
|
||||
compilation_config: CompilationConfig,
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
):
|
||||
if (
|
||||
"w8a8" in model
|
||||
or "w8w8" in model
|
||||
and current_platform.has_device_capability((10, 0))
|
||||
):
|
||||
# int8 removed on Blackwell:
|
||||
pytest.skip("int8 support removed on Blackwell")
|
||||
|
||||
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||
"2.9.0.dev"
|
||||
):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
print(f"MODEL={model}")
|
||||
run_model(compilation_config, model, **model_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_mode",
|
||||
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model, backend",
|
||||
[
|
||||
("Qwen/Qwen2-0.5B", None), # Standard attention model
|
||||
(
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
), # MLA (Multi-head Latent Attention) model
|
||||
],
|
||||
)
|
||||
def test_fp8_kv_scale_compile(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
compilation_mode: int,
|
||||
model: str,
|
||||
backend: AttentionBackendEnum | None,
|
||||
):
|
||||
if backend:
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
model_kwargs = {
|
||||
"quantization": "fp8",
|
||||
"kv_cache_dtype": "fp8_e4m3",
|
||||
"calculate_kv_scales": True,
|
||||
"max_model_len": 512,
|
||||
}
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
|
||||
compilation_config = (
|
||||
compile_config
|
||||
if isinstance(compile_config, CompilationConfig)
|
||||
else CompilationConfig(mode=compile_config)
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
# Allow override from model_kwargs
|
||||
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
|
||||
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
|
||||
|
||||
# No cudagraphs by default
|
||||
if compilation_config.cudagraph_mode is None:
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
compilation_config=compilation_config,
|
||||
**model_kwargs,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
73
tests/compile/fullgraph/test_multimodal_compile.py
Normal file
73
tests/compile/fullgraph/test_multimodal_compile.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def test_compile():
|
||||
vllm_config = VllmConfig()
|
||||
# Default configuration does not compile mm encoder
|
||||
assert not vllm_config.compilation_config.compile_mm_encoder
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
|
||||
"""Test that Qwen2.5-VL vision submodules are compiled.
|
||||
|
||||
This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
|
||||
Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
|
||||
for compilation by checking that num_models_seen increases by at least 3.
|
||||
"""
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with (
|
||||
# NOTE: Qwen2.5-VL has 35 models in total - the LLM backend
|
||||
# Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
|
||||
# (one for each layer) - in the future, we should fix vLLM compilation
|
||||
# logic to handle this case and only compile the Vision submodules once
|
||||
# and reuse the compiled code for all layers
|
||||
# See https://github.com/vllm-project/vllm/issues/27590
|
||||
compilation_counter.expect(num_models_seen=35),
|
||||
vllm_runner(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
max_model_len=2048,
|
||||
gpu_memory_utilization=0.8,
|
||||
compilation_config={
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_mm_encoder": True,
|
||||
},
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch):
|
||||
"""Test that Qwen2.5-VL vision submodules are not compiled when the
|
||||
config is passed off
|
||||
"""
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with (
|
||||
compilation_counter.expect(num_models_seen=1),
|
||||
vllm_runner(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
max_model_len=2048,
|
||||
gpu_memory_utilization=0.8,
|
||||
compilation_config={
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_mm_encoder": False,
|
||||
},
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
326
tests/compile/fullgraph/test_multiple_graphs.py
Normal file
326
tests/compile/fullgraph/test_multiple_graphs.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test (piecewise) compilation with a simple model where multiple submodules
|
||||
are compiled and graph captured separately.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from .. import silly_attention # noqa: F401
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
HIDDEN_SIZE = 1024
|
||||
RANDOM_SEED = 0
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class ParentModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
||||
self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False)
|
||||
self.rms_norm_weight = nn.Parameter(torch.ones(hidden_size))
|
||||
|
||||
# Initialize to same weights for testing
|
||||
nn.init.xavier_normal_(
|
||||
self.pre_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001,
|
||||
)
|
||||
nn.init.xavier_normal_(
|
||||
self.post_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001,
|
||||
)
|
||||
|
||||
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_f32 = x.float()
|
||||
return (
|
||||
x_f32
|
||||
* torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
|
||||
* self.rms_norm_weight
|
||||
).to(x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.pre_attn(x)
|
||||
x = self.rms_norm_ref(x)
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = self.rms_norm_ref(x)
|
||||
x = self.post_attn(x)
|
||||
return x
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.attn = Attention(mlp_size, hidden_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.attn(x)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttentionTwo(CompiledAttention):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.attn(x) + x
|
||||
|
||||
|
||||
@ignore_torch_compile
|
||||
class SimpleModelWithTwoGraphs(ParentModel):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
# Test will fail without set_model_tag here with error:
|
||||
# "ValueError: too many values to unpack (expected 3)"
|
||||
# This is because CompiledAttention and CompiledAttentionTwo
|
||||
# have different implementations but the same torch.compile
|
||||
# cache dir will be used as default prefix is 'model_tag'
|
||||
with set_model_tag("attn_one"):
|
||||
self.attn_one = CompiledAttention(
|
||||
mlp_size=mlp_size,
|
||||
hidden_size=hidden_size,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.attn_one",
|
||||
)
|
||||
with set_model_tag("attn_two"):
|
||||
self.attn_two = CompiledAttentionTwo(
|
||||
mlp_size=mlp_size,
|
||||
hidden_size=hidden_size,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.attn_two",
|
||||
)
|
||||
|
||||
self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bsz = x.shape[0]
|
||||
# CUDAGraph expects same tensor addresses for each run
|
||||
self.hidden_states[:bsz].copy_(x)
|
||||
x = self.attn_one(self.hidden_states[:bsz])
|
||||
self.hidden_states[:bsz].copy_(x)
|
||||
x = self.attn_two(self.hidden_states[:bsz])
|
||||
return x
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(
|
||||
vllm_config: VllmConfig,
|
||||
model: nn.Module,
|
||||
inputs: torch.Tensor,
|
||||
cudagraph_runtime_mode: CUDAGraphMode,
|
||||
):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(inputs)
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(inputs[:2])
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(inputs[:1])
|
||||
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(inputs[:2])
|
||||
|
||||
output = output.cpu()
|
||||
return output.cpu()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
|
||||
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_multi_graph_piecewise_compile(
|
||||
use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch
|
||||
):
|
||||
# Set the environment variable for this test
|
||||
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
outputs = []
|
||||
|
||||
# vllmcompile compile
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
# Pre-allocate memory for CUDAGraph which expects
|
||||
# static tensor addresses
|
||||
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
|
||||
|
||||
if use_inductor_graph_partition:
|
||||
# Splitting happens at Inductor lowering level,
|
||||
# total piecewise fx graphs is equal to total graphs
|
||||
num_piecewise_fx = 2
|
||||
num_piecewise_capturable_fx = 2
|
||||
else:
|
||||
# attn_one, attn_two each has 3 piecewise graphs
|
||||
# (pre attn, post attn, silly_attention) each
|
||||
num_piecewise_fx = 6
|
||||
# attn_one, attn_two has pre attn and post attn each, total=4
|
||||
num_piecewise_capturable_fx = 4
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2, # two graphs for the model
|
||||
num_piecewise_graphs_seen=num_piecewise_fx,
|
||||
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
|
||||
num_backend_compilations=num_piecewise_capturable_fx,
|
||||
num_cudagraph_captured=8, # num_cudagraph_sizes * num_partitions
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# no compile or cudagraph
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.NONE,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# piecewise compile without CUDA graph
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
cudagraph_mode=CUDAGraphMode.NONE,
|
||||
splitting_ops=["silly::attention"],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=num_piecewise_fx,
|
||||
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
|
||||
num_backend_compilations=num_piecewise_capturable_fx,
|
||||
num_cudagraph_captured=0, # no cudagraph captured
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# Generally don't expect outputs with and without inductor
|
||||
# to be bitwise equivalent
|
||||
assert torch.allclose(outputs[0], outputs[1])
|
||||
|
||||
# Expect bitwise equivalence using inductor w/ and w/o cudagraph
|
||||
assert torch.equal(outputs[0], outputs[2])
|
||||
167
tests/compile/fullgraph/test_simple.py
Normal file
167
tests/compile/fullgraph/test_simple.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from ..silly_attention import get_global_counter, reset_global_counter
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class SillyModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Overall effect:
|
||||
x = 3 * x + 19
|
||||
global_counter += 2
|
||||
"""
|
||||
x = x + 1
|
||||
x = x + 2
|
||||
out = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, out)
|
||||
x = out
|
||||
x = x - 2
|
||||
x = x - 1
|
||||
out = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, out)
|
||||
x = out
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
|
||||
def _run_simple_model(
|
||||
splitting_ops,
|
||||
use_inductor_graph_partition,
|
||||
backend,
|
||||
expected_num_piecewise_graphs_seen,
|
||||
expected_num_piecewise_capturable_graphs_seen,
|
||||
expected_num_backend_compilations,
|
||||
expected_num_cudagraph_captured,
|
||||
):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
backend=backend,
|
||||
splitting_ops=splitting_ops,
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
cudagraph_copy_inputs=True,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SillyModel(vllm_config=vllm_config, prefix="")
|
||||
|
||||
inputs = torch.randn(100).cuda()
|
||||
|
||||
with (
|
||||
compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||
),
|
||||
set_forward_context(None, vllm_config=vllm_config),
|
||||
): # background context
|
||||
# warm up with background context
|
||||
model(inputs)
|
||||
|
||||
# capturing/replaying should under context of cudagraph dispatching
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(torch.randn(2).cuda())
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(torch.randn(1).cuda())
|
||||
|
||||
input = torch.zeros(2).cuda()
|
||||
reset_global_counter()
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(input)
|
||||
assert get_global_counter() == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["inductor", "eager"])
|
||||
@torch.inference_mode()
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_simple_piecewise_compile(backend):
|
||||
_run_simple_model(
|
||||
splitting_ops=["silly::attention"],
|
||||
use_inductor_graph_partition=False,
|
||||
backend=backend,
|
||||
# 2 * num_layers + 1
|
||||
expected_num_piecewise_graphs_seen=5,
|
||||
# 1 + num_layers
|
||||
expected_num_piecewise_capturable_graphs_seen=3,
|
||||
# num_piecewise_capturable_graphs_seen
|
||||
expected_num_backend_compilations=3,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
expected_num_cudagraph_captured=6,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_simple_inductor_graph_partition(monkeypatch):
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
# disable compile cache so that we run separately for different splitting_ops
|
||||
# and get the expected number of cudagraphs captured.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
_run_simple_model(
|
||||
splitting_ops=["silly::attention"],
|
||||
use_inductor_graph_partition=True,
|
||||
backend="inductor",
|
||||
# Since not splitting at fx graph level
|
||||
expected_num_piecewise_graphs_seen=1,
|
||||
# Since not splitting at fx graph level
|
||||
expected_num_piecewise_capturable_graphs_seen=1,
|
||||
# Since not splitting at fx graph level
|
||||
expected_num_backend_compilations=1,
|
||||
# Inductor graph partition still captures 6 graph, same as fx graph partition
|
||||
expected_num_cudagraph_captured=6,
|
||||
)
|
||||
523
tests/compile/fullgraph/test_toy_llama.py
Normal file
523
tests/compile/fullgraph/test_toy_llama.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test the piecewise compilation with a simple model, comparing the output
|
||||
with and without the piecewise compilation.
|
||||
|
||||
This is a tractable model, the weights and computation are specially designed
|
||||
if the config `tractable_init` is set to True. Otherwise, the weights are
|
||||
initialized randomly with a fixed seed.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from .. import silly_attention # noqa: F401
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlamaConfig:
|
||||
hidden_size: int = 128
|
||||
mlp_size: int = 256
|
||||
vocab_size: int = 128
|
||||
num_layers: int = 2
|
||||
init_value: float = 1.0
|
||||
tractable_init: bool = False
|
||||
random_seed: int = 0
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
factors: list[Any] = []
|
||||
for k, v in self.__dict__.items():
|
||||
if k == "random_seed":
|
||||
continue
|
||||
factors.append((k, v))
|
||||
factors.sort()
|
||||
import hashlib
|
||||
|
||||
return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.mlp_size >= self.hidden_size
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, config: LlamaConfig) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_projection = nn.Linear(
|
||||
in_features=config.hidden_size,
|
||||
out_features=config.mlp_size * 2,
|
||||
bias=False,
|
||||
)
|
||||
self.down_projection = nn.Linear(
|
||||
in_features=config.mlp_size,
|
||||
out_features=config.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
if config.tractable_init:
|
||||
nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
|
||||
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
|
||||
nn.init.eye_(self.down_projection.weight.data)
|
||||
else:
|
||||
nn.init.xavier_normal_(
|
||||
self.gate_up_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
nn.init.xavier_normal_(
|
||||
self.down_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# for tractable_init and positive input, this is
|
||||
# essentially an elementwise-square
|
||||
x = self.gate_up_projection(x)
|
||||
x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
|
||||
x = self.down_projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(self, config: LlamaConfig) -> None:
|
||||
super().__init__()
|
||||
self.qkv_projection = nn.Linear(
|
||||
in_features=config.hidden_size,
|
||||
out_features=config.hidden_size * 3,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.output_projection = nn.Linear(
|
||||
in_features=config.hidden_size,
|
||||
out_features=config.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
if config.tractable_init:
|
||||
nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
|
||||
nn.init.eye_(
|
||||
self.qkv_projection.weight.data[
|
||||
config.hidden_size : 2 * config.hidden_size
|
||||
]
|
||||
)
|
||||
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
|
||||
nn.init.eye_(self.output_projection.weight.data)
|
||||
else:
|
||||
nn.init.xavier_normal_(
|
||||
self.qkv_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
nn.init.xavier_normal_(
|
||||
self.output_projection.weight.data,
|
||||
generator=torch.Generator().manual_seed(config.random_seed),
|
||||
gain=0.001,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# for tractable_init, this is:
|
||||
# output = (hidden_states * 3 + positions * 2)
|
||||
qkv = self.qkv_projection(hidden_states)
|
||||
hidden_size = qkv.size(-1) // 3
|
||||
q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
|
||||
|
||||
q = q + positions.unsqueeze(1)
|
||||
k = k + positions.unsqueeze(1)
|
||||
|
||||
attn_output = torch.empty_like(q)
|
||||
torch.ops.silly.attention(q, k, v, attn_output)
|
||||
|
||||
output = self.output_projection(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig) -> None:
|
||||
super().__init__()
|
||||
self.self_attention = LlamaAttention(config)
|
||||
self.mlp = LlamaMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
For tractable computation:
|
||||
- if residual is None, the outputs are:
|
||||
- residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
|
||||
- hidden_states = (residual + 1) ** 2
|
||||
- if residual is not None, the outputs are:
|
||||
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
|
||||
- hidden_states = (residual + 1) ** 2
|
||||
""" # noqa
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states + 1
|
||||
else:
|
||||
hidden_states = hidden_states + residual
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states + 1
|
||||
|
||||
hidden_states = self.self_attention(
|
||||
positions=positions, hidden_states=hidden_states
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states + 1
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
config: LlamaConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_tokens = nn.Embedding(
|
||||
num_embeddings=config.vocab_size,
|
||||
embedding_dim=config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]
|
||||
)
|
||||
|
||||
# this is the initial value of the hidden states
|
||||
self.embedding_tokens.weight.data.fill_(config.init_value)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embedding_tokens(input_ids)
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def tractable_computation(
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
config: LlamaConfig,
|
||||
init_value: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = (
|
||||
torch.ones(
|
||||
input_ids.size(0),
|
||||
config.hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=input_ids.dtype,
|
||||
)
|
||||
* init_value
|
||||
)
|
||||
|
||||
# first layer
|
||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||
hidden_states = (residual + 1) ** 2
|
||||
|
||||
# following layers
|
||||
for _ in range(config.num_layers - 1):
|
||||
hidden_states = hidden_states + residual
|
||||
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
|
||||
hidden_states = (residual + 1) ** 2
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
|
||||
# Start with a fresh copy to make sure there's no cache dir sharing
|
||||
compile_config = deepcopy(compile_config)
|
||||
cudagraph_runtime_mode = compile_config.cudagraph_mode
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=compile_config, additional_config=llama_config
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = (
|
||||
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with set_forward_context({}, vllm_config=vllm_config): # background context
|
||||
B = 16 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||
positions = torch.arange(B).cuda()
|
||||
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(input_ids, positions)
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(input_ids[:2], positions[:2])
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(input_ids[:1], positions[:1])
|
||||
|
||||
input_ids[:2].zero_()
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(input_ids[:2], positions[:2])
|
||||
|
||||
output = output.cpu()
|
||||
|
||||
if llama_config.tractable_init:
|
||||
expected_output = tractable_computation(
|
||||
input_ids[:2], positions[:2], llama_config
|
||||
).cpu()
|
||||
|
||||
assert torch.allclose(output, expected_output)
|
||||
else:
|
||||
return output.cpu()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend, use_inductor_graph_partition",
|
||||
[
|
||||
("eager", False), # No inductor
|
||||
("inductor", False), # Inductor, Dynamo partition
|
||||
("inductor", True), # Inductor, Inductor partition
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_toy_llama(
|
||||
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
|
||||
):
|
||||
# We disable the vLLM compile cache into a new tmp dir for 1 reason:
|
||||
# 1. To make sure we can properly track the number of Inductor compilations.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition only supported in torch>=2.9")
|
||||
|
||||
# compare output with and without piecewise compilation
|
||||
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
|
||||
)
|
||||
|
||||
tractable_config = LlamaConfig(
|
||||
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
|
||||
)
|
||||
|
||||
compile_config_no_compile = CompilationConfig(
|
||||
mode=CompilationMode.NONE,
|
||||
cudagraph_mode=CUDAGraphMode.NONE,
|
||||
backend="eager",
|
||||
)
|
||||
|
||||
compile_config_no_split = CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
backend=backend,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
|
||||
compile_config_split = deepcopy(compile_config_no_split)
|
||||
compile_config_split.splitting_ops = ["silly::attention"]
|
||||
|
||||
outputs = []
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
outputs.append(run_model(llama_config, compile_config_no_compile))
|
||||
|
||||
run_model(tractable_config, compile_config_no_compile)
|
||||
|
||||
if backend == "inductor":
|
||||
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
|
||||
else:
|
||||
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=1,
|
||||
num_piecewise_capturable_graphs_seen=1,
|
||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=2,
|
||||
**kwargs,
|
||||
):
|
||||
outputs.append(run_model(llama_config, compile_config_no_split))
|
||||
|
||||
run_model(tractable_config, compile_config_no_split)
|
||||
|
||||
if use_inductor_graph_partition:
|
||||
num_piecewise_fx = 1
|
||||
num_piecewise_capturable_fx = 1
|
||||
else:
|
||||
num_piecewise_fx = 2 * llama_config.num_layers + 1
|
||||
num_piecewise_capturable_fx = 1 + llama_config.num_layers
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=num_piecewise_fx,
|
||||
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
|
||||
num_backend_compilations=num_piecewise_capturable_fx,
|
||||
# num_cudagraph_sizes * num_partitions
|
||||
num_cudagraph_captured=2 * (1 + llama_config.num_layers),
|
||||
):
|
||||
outputs.append(run_model(llama_config, compile_config_split))
|
||||
run_model(tractable_config, compile_config_split)
|
||||
|
||||
for i in range(1, len(outputs)):
|
||||
assert torch.allclose(outputs[0], outputs[i])
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def benchmark():
|
||||
from triton.testing import do_bench
|
||||
|
||||
# similar to llama 3.1-8B
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
|
||||
)
|
||||
|
||||
# a tiny model to measure the overhead
|
||||
# of piecewise cudagraph
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
|
||||
)
|
||||
|
||||
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
|
||||
|
||||
eager_time = {}
|
||||
full_cudagraph_time = {}
|
||||
piecewise_cudagraph_time = {}
|
||||
|
||||
pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
for piecewise in [False, True]:
|
||||
if piecewise:
|
||||
compilation_config = CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=cudagraph_sizes,
|
||||
)
|
||||
else:
|
||||
compilation_config = CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
cudagraph_capture_sizes=cudagraph_sizes,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=compilation_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = (
|
||||
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
|
||||
.eval()
|
||||
.cuda()
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
|
||||
B = 256 # max batch size
|
||||
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
|
||||
positions = torch.arange(B).cuda().to(torch.bfloat16)
|
||||
|
||||
graphs = {}
|
||||
|
||||
model(input_ids, positions)
|
||||
for b in cudagraph_sizes[::-1]:
|
||||
if not piecewise:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, pool=pool):
|
||||
output = model(input_ids[:b], positions[:b])
|
||||
graphs[b] = (graph, output)
|
||||
else:
|
||||
output = model(input_ids[:b], positions[:b])
|
||||
graphs[b] = (model, output)
|
||||
for b in cudagraph_sizes:
|
||||
if piecewise:
|
||||
# noqa is for `Function definition does not bind loop variable`
|
||||
# it will be problematic if we save the created lambda function
|
||||
# and use it later, because it will look up the name `b` in the
|
||||
# enclosing scope, and the value of `b` will always be 256.
|
||||
# it is fine here, because we only use the lambda function once.
|
||||
runtime = do_bench(
|
||||
lambda: graphs[b][0]( # noqa
|
||||
input_ids[:b], # noqa
|
||||
positions[:b], # noqa
|
||||
)
|
||||
)
|
||||
piecewise_cudagraph_time[b] = runtime
|
||||
else:
|
||||
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
|
||||
eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
|
||||
full_cudagraph_time[b] = runtime
|
||||
eager_time[b] = eager_runtime
|
||||
|
||||
# print in tabular format
|
||||
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
|
||||
for b in cudagraph_sizes:
|
||||
print(
|
||||
f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
|
||||
f"\t{piecewise_cudagraph_time[b]:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Protect against subprocess reimport when using spawn_new_process_for_each_test
|
||||
import os
|
||||
|
||||
if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
|
||||
benchmark()
|
||||
65
tests/compile/silly_attention.py
Normal file
65
tests/compile/silly_attention.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Shared PyTorch custom silly attention for compilation tests.
|
||||
Centralizes custom operation definitions to avoid duplicate registrations.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
# Shared library for all compilation test operations
|
||||
# Using "silly" namespace to match existing test expectations
|
||||
# import this file will automatically register
|
||||
# torch ops for testing (like silly.attention)
|
||||
silly_lib = Library("silly", "FRAGMENT")
|
||||
|
||||
# Global counter that counts the number of times attention is invoked
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def get_global_counter():
|
||||
"""Get the current global counter value"""
|
||||
return _global_counter
|
||||
|
||||
|
||||
def reset_global_counter():
|
||||
"""Reset the global counter to 0"""
|
||||
global _global_counter
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def silly_attention(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
"""
|
||||
Unified attention implementation that depends on
|
||||
all inputs and affects the output.
|
||||
Always increments a global counter that tests can use or ignore.
|
||||
"""
|
||||
global _global_counter
|
||||
|
||||
# Always increment the global counter
|
||||
_global_counter += 1
|
||||
|
||||
# Unified implementation that depends on all inputs
|
||||
out.copy_(q + k + v)
|
||||
|
||||
|
||||
def silly_attention_fake(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
"""Fake implementation for testing"""
|
||||
return
|
||||
|
||||
|
||||
# Register the unified attention operation
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
205
tests/compile/test_aot_compile.py
Normal file
205
tests/compile/test_aot_compile.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
|
||||
def reference_fn(x: torch.Tensor):
|
||||
assert x.shape[0] <= 42
|
||||
assert x.shape[0] % 2 == 0
|
||||
for _ in range(3000):
|
||||
x = x + x.shape[0]
|
||||
return x
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledMod(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return reference_fn(x)
|
||||
|
||||
|
||||
def make_vllm_config() -> VllmConfig:
|
||||
return VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_vllm_config(vllm_config: VllmConfig):
|
||||
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
vllm_config = make_vllm_config()
|
||||
args = (torch.randn(10, 10),)
|
||||
expected = reference_fn(*args)
|
||||
with use_vllm_config(vllm_config):
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "0")
|
||||
with (
|
||||
pytest.raises(RuntimeError, match="Detected recompile"),
|
||||
torch.compiler.set_stance("fail_on_recompile"),
|
||||
):
|
||||
CompiledMod(vllm_config=vllm_config)(*args)
|
||||
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
torch._dynamo.reset()
|
||||
with torch.compiler.set_stance("fail_on_recompile"):
|
||||
actual = CompiledMod(vllm_config=vllm_config)(*args)
|
||||
assert torch.allclose(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
|
||||
args = (torch.randn(10, 10),)
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||
vllm_config = make_vllm_config()
|
||||
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
|
||||
CompiledMod(vllm_config=vllm_config)(*args)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
args = (torch.randn(10, 10),)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
with use_vllm_config(vllm_config):
|
||||
expected = CompiledMod(vllm_config=vllm_config)(*args)
|
||||
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
with use_vllm_config(vllm_config):
|
||||
ret = CompiledMod(vllm_config=vllm_config)(*args)
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test that the shape environment is correctly serialized and preserved
|
||||
when loading from cache.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
args = (torch.randn(10, 10),)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
with use_vllm_config(vllm_config):
|
||||
compiled_mod = CompiledMod(vllm_config=vllm_config)
|
||||
compiled_mod(*args)
|
||||
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
with use_vllm_config(vllm_config):
|
||||
compiled_mod = CompiledMod(vllm_config=vllm_config)
|
||||
compiled_mod(*args)
|
||||
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
@use_vllm_config(make_vllm_config())
|
||||
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test that compiling gpt2 twice results in a cache hit and
|
||||
capture torch dynamic symbol creations to ensure make_symbol
|
||||
not called on cache hit.
|
||||
"""
|
||||
|
||||
import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module
|
||||
from torch.utils._sympy.symbol import make_symbol
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
create_symbol_counter = multiprocessing.Value("i", 0)
|
||||
original_make_symbol = make_symbol
|
||||
|
||||
@functools.wraps(original_make_symbol)
|
||||
def counting_make_symbol(prefix, idx, **kwargs):
|
||||
with create_symbol_counter.get_lock():
|
||||
create_symbol_counter.value += 1
|
||||
return original_make_symbol(prefix, idx, **kwargs)
|
||||
|
||||
symbolic_shapes_module.make_symbol = counting_make_symbol
|
||||
try:
|
||||
with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname:
|
||||
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
# First compilation - initialize model and generate
|
||||
llm_model = LLM(
|
||||
model="gpt2",
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
),
|
||||
max_model_len=256,
|
||||
)
|
||||
|
||||
llm_model.generate("Hello, my name is")
|
||||
assert create_symbol_counter.value == 2
|
||||
create_symbol_counter.value = 0
|
||||
|
||||
# Clean up first model
|
||||
del llm_model
|
||||
|
||||
# Second compilation - should hit cache
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
llm_model = LLM(
|
||||
model="gpt2",
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
),
|
||||
max_model_len=256,
|
||||
)
|
||||
llm_model.generate("Hello, my name is")
|
||||
|
||||
assert create_symbol_counter.value == 0
|
||||
|
||||
finally:
|
||||
# Restore original method
|
||||
symbolic_shapes_module.make_symbol = original_make_symbol
|
||||
174
tests/compile/test_compile_ranges.py
Normal file
174
tests/compile/test_compile_ranges.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import fx as fx
|
||||
from torch import nn
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
import tests.compile.silly_attention # noqa
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.compilation.inductor_pass import (
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
from vllm.config import (
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.compilation import CompilationConfig, CompilationMode
|
||||
from vllm.config.scheduler import SchedulerConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.forward_context import set_forward_context
|
||||
|
||||
BATCH_SIZE = 64
|
||||
MLP_SIZE = 128
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TestModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x * 3
|
||||
return x
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
model(torch.randn(BATCH_SIZE, MLP_SIZE))
|
||||
for batch_size in batch_sizes:
|
||||
model(torch.randn(batch_size, MLP_SIZE))
|
||||
|
||||
|
||||
class PostGradRangeChecker(InductorPass):
|
||||
def __init__(self, ranges: list[Range]):
|
||||
self.ranges = ranges
|
||||
self.num_calls = 0
|
||||
|
||||
def __call__(self, graph: fx.Graph):
|
||||
compile_range = get_pass_context().compile_range
|
||||
assert compile_range in self.ranges, (
|
||||
f"Compile range {compile_range} not in {self.ranges}"
|
||||
)
|
||||
self.num_calls += 1
|
||||
|
||||
def uuid(self) -> str:
|
||||
state: dict[str, Any] = {}
|
||||
return InductorPass.hash_dict(state)
|
||||
|
||||
|
||||
def test_compile_ranges(use_fresh_inductor_cache):
|
||||
post_grad_range_checker = PostGradRangeChecker(
|
||||
[
|
||||
Range(start=1, end=8),
|
||||
Range(start=16, end=16),
|
||||
Range(start=9, end=32),
|
||||
Range(start=64, end=64),
|
||||
Range(start=33, end=8192),
|
||||
]
|
||||
)
|
||||
torch.set_default_device("cuda")
|
||||
vllm_config = VllmConfig(
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_num_batched_tokens=8192,
|
||||
max_model_len=8192,
|
||||
is_encoder_decoder=False,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
compile_ranges_split_points=[8, 32],
|
||||
compile_sizes=[16, 64, 128],
|
||||
inductor_compile_config={
|
||||
"post_grad_custom_post_pass": post_grad_range_checker,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = TestModel(vllm_config=vllm_config, prefix="").eval()
|
||||
# Number of compilations: 3 for each compile range + 2 compile sizes
|
||||
batch_sizes = [1, 4, 16, 24, 48, 64, 8192]
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=1,
|
||||
num_backend_compilations=5,
|
||||
):
|
||||
run_model(vllm_config, model, batch_sizes)
|
||||
assert post_grad_range_checker.num_calls == 5
|
||||
|
||||
|
||||
def test_compile_config_get_compile_ranges():
|
||||
compilation_config = CompilationConfig(
|
||||
compile_ranges_split_points=[8, 32],
|
||||
)
|
||||
VllmConfig(
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_num_batched_tokens=8192,
|
||||
max_model_len=8192,
|
||||
is_encoder_decoder=False,
|
||||
),
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert compilation_config.get_compile_ranges() == [
|
||||
Range(start=1, end=8),
|
||||
Range(start=9, end=32),
|
||||
Range(start=33, end=8192),
|
||||
]
|
||||
|
||||
|
||||
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
|
||||
# To force multiple compilations, we disable the compile cache
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
post_grad_range_checker = PostGradRangeChecker(
|
||||
ranges=[
|
||||
Range(start=1, end=8),
|
||||
Range(start=9, end=8192),
|
||||
]
|
||||
)
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_batched_tokens=8192,
|
||||
max_model_len=8192,
|
||||
is_encoder_decoder=False,
|
||||
)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
def create_vllm_config():
|
||||
return VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
compile_ranges_split_points=[8],
|
||||
inductor_compile_config={
|
||||
"post_grad_custom_post_pass": post_grad_range_checker,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
vllm_config_1 = create_vllm_config()
|
||||
with set_current_vllm_config(vllm_config_1):
|
||||
model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
|
||||
batch_sizes = [1, 16]
|
||||
run_model(vllm_config_1, model1, batch_sizes)
|
||||
assert post_grad_range_checker.num_calls == 2
|
||||
|
||||
post_grad_range_checker.num_calls = 0
|
||||
# Create a new vllm config with the new pass context
|
||||
vllm_config_2 = create_vllm_config()
|
||||
with set_current_vllm_config(vllm_config_2):
|
||||
model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
|
||||
batch_sizes = [4, 32]
|
||||
run_model(vllm_config_2, model2, batch_sizes)
|
||||
# Check that cache is used, so the number of calls
|
||||
# should be 0
|
||||
assert post_grad_range_checker.num_calls == 0
|
||||
404
tests/compile/test_config.py
Normal file
404
tests/compile/test_config.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode, PassConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
|
||||
|
||||
def test_version():
|
||||
# Test the version comparison logic using the private function
|
||||
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
|
||||
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
|
||||
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
|
||||
assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
|
||||
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
||||
|
||||
|
||||
def test_copy_pass():
|
||||
vllm_config = VllmConfig()
|
||||
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||
copied_inductor_pass = copy.deepcopy(inductor_pass)
|
||||
assert (
|
||||
copied_inductor_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
assert (
|
||||
copied_inductor_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
|
||||
|
||||
def test_custom_op():
|
||||
# proper syntax
|
||||
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid syntax '"):
|
||||
_ = CompilationConfig(custom_ops=["quant_fp8"])
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
|
||||
# on the state of the cache directory on the current machine, which
|
||||
# may be influenced by other tests.
|
||||
@pytest.mark.parametrize("val", ["1"])
|
||||
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
|
||||
|
||||
compilation_config = {
|
||||
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
|
||||
}
|
||||
with (
|
||||
compilation_counter.expect(
|
||||
num_cache_entries_updated=0, num_compiled_artifacts_saved=0
|
||||
),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config=compilation_config,
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.parametrize(
|
||||
"cudagraph_mode,num_cudagraph_captured",
|
||||
[
|
||||
(CUDAGraphMode.NONE, 0),
|
||||
(CUDAGraphMode.FULL_DECODE_ONLY, 1),
|
||||
(CUDAGraphMode.PIECEWISE, 13),
|
||||
(CUDAGraphMode.FULL_AND_PIECEWISE, 14),
|
||||
],
|
||||
)
|
||||
def test_use_cudagraphs(
|
||||
vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
|
||||
):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
compilation_config = {
|
||||
"cudagraph_capture_sizes": [100],
|
||||
"cudagraph_mode": cudagraph_mode,
|
||||
}
|
||||
num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
|
||||
with (
|
||||
compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
|
||||
num_cudagraph_captured=num_cudagraph_captured,
|
||||
),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config=compilation_config,
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
def test_stock_torch_compile(vllm_runner, monkeypatch):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with (
|
||||
compilation_counter.expect(stock_torch_compile_count=1),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
def test_no_compilation(vllm_runner, monkeypatch):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
with (
|
||||
compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config={"mode": CompilationMode.NONE},
|
||||
gpu_memory_utilization=0.4,
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
def test_enforce_eager(vllm_runner, monkeypatch):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
with (
|
||||
compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
|
||||
# loading the model causes compilation (if enabled) to happen
|
||||
vllm_runner(
|
||||
"facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
|
||||
) as _,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def test_splitting_ops_dynamic():
|
||||
# Default config
|
||||
config = VllmConfig()
|
||||
# Default V1 config leaves cudagraph mode unset; splitting ops are only
|
||||
# populated when the engine decides to use piecewise compilation.
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
|
||||
# When use_inductor_graph_partition=True
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
splitting_ops=["vllm::unified_attention"],
|
||||
)
|
||||
)
|
||||
# with inductor partition we use splitting_ops directly for
|
||||
# partition rules
|
||||
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
|
||||
|
||||
# When attn_fusion pass enabled.
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
)
|
||||
assert config.compilation_config.splitting_ops == []
|
||||
# cudagraph mode also fall back to FULL
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||
|
||||
# splitting_ops can not contain attention ops when attn_fusion
|
||||
# pass enabled.
|
||||
with pytest.raises(ValidationError):
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
# work around for accessing all attntion ops
|
||||
splitting_ops=CompilationConfig()._attention_ops,
|
||||
)
|
||||
)
|
||||
|
||||
# When both use_inductor_graph_partition and attn_fusion pass enabled.
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
)
|
||||
# With inductor graph partition, attn_fusion and splitting_ops
|
||||
# work together. Default splitting_ops include attention ops.
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
# fuse_attn_quant is directly supported under
|
||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||
# is unchanged.
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
|
||||
def test_moe_splitting_ops_deepep_ht_inductor_partition():
|
||||
# Inductor partition case: user-provided splitting_ops should be
|
||||
# preserved and MoE ops should be appended for DeepEP HT with dp>1.
|
||||
config = VllmConfig(
|
||||
parallel_config=ParallelConfig(
|
||||
all2all_backend="deepep_high_throughput",
|
||||
data_parallel_size=8,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
splitting_ops=[
|
||||
"vllm::unified_attention",
|
||||
"vllm::moe_forward",
|
||||
"vllm::moe_forward_shared",
|
||||
],
|
||||
),
|
||||
)
|
||||
splitting_ops = config.compilation_config.splitting_ops
|
||||
assert splitting_ops == [
|
||||
"vllm::unified_attention",
|
||||
"vllm::moe_forward",
|
||||
"vllm::moe_forward_shared",
|
||||
]
|
||||
|
||||
|
||||
def test_should_split():
|
||||
import torch
|
||||
|
||||
from vllm.compilation.partition_rules import should_split
|
||||
|
||||
graph = torch.fx.Graph()
|
||||
node = torch.fx.Node(
|
||||
graph=graph,
|
||||
name="dummy_node",
|
||||
op="call_function",
|
||||
target=torch.ops.aten.add.default,
|
||||
args=(),
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
# supports OpOverloadPacket
|
||||
splitting_ops = ["aten::add"]
|
||||
assert should_split(node, splitting_ops)
|
||||
|
||||
# supports OpOverload
|
||||
splitting_ops = ["aten::add.default"]
|
||||
assert should_split(node, splitting_ops)
|
||||
|
||||
# supports OpOverload
|
||||
splitting_ops = ["aten::add.Tensor"]
|
||||
assert not should_split(node, splitting_ops)
|
||||
|
||||
q, k, v, out = [torch.randn(1)] * 4
|
||||
|
||||
# supports custom ops as OpOverloadPacket
|
||||
node = torch.fx.Node(
|
||||
graph=graph,
|
||||
name="dummy_node",
|
||||
op="call_function",
|
||||
target=torch.ops.silly.attention,
|
||||
args=(q, k, v, out),
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
splitting_ops = ["silly::attention"]
|
||||
assert should_split(node, splitting_ops)
|
||||
|
||||
# supports custom ops as OpOverload
|
||||
node = torch.fx.Node(
|
||||
graph=graph,
|
||||
name="dummy_node",
|
||||
op="call_function",
|
||||
target=torch.ops.silly.attention.default,
|
||||
args=(q, k, v, out),
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
splitting_ops = ["silly::attention"]
|
||||
assert should_split(node, splitting_ops)
|
||||
|
||||
splitting_ops = ["silly::attention.default"]
|
||||
assert should_split(node, splitting_ops)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.support_static_graph_mode(),
|
||||
reason="Skip if not cudagraph mode supported",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"cudagraph_capture_sizes",
|
||||
"max_cudagraph_capture_size",
|
||||
"tp_size",
|
||||
"enable_sp",
|
||||
"max_num_batched_tokens",
|
||||
"cudagraph_mode",
|
||||
"expected_max_size",
|
||||
),
|
||||
[
|
||||
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
|
||||
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
(
|
||||
[1, 2, 4],
|
||||
8,
|
||||
1,
|
||||
False,
|
||||
2048,
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
ValidationError,
|
||||
),
|
||||
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
|
||||
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
|
||||
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
|
||||
# truncated to nearest multiple of 8 or 16
|
||||
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
|
||||
# max from list
|
||||
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
|
||||
# filtered out 15 due to SP
|
||||
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
# limited by the max_tokens
|
||||
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
# the list should contain at least 1 element when use cudagraph
|
||||
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
|
||||
# the max capturing size should be >= 1 when use cudagraph
|
||||
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
|
||||
],
|
||||
)
|
||||
def test_cudagraph_sizes_post_init(
|
||||
cudagraph_capture_sizes,
|
||||
max_cudagraph_capture_size,
|
||||
tp_size,
|
||||
enable_sp,
|
||||
max_num_batched_tokens,
|
||||
cudagraph_mode,
|
||||
expected_max_size,
|
||||
):
|
||||
ctx = nullcontext()
|
||||
if expected_max_size == ValidationError:
|
||||
ctx = pytest.raises(expected_max_size)
|
||||
|
||||
with (
|
||||
ctx,
|
||||
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
|
||||
):
|
||||
compilation_config = CompilationConfig(
|
||||
cudagraph_capture_sizes=cudagraph_capture_sizes,
|
||||
max_cudagraph_capture_size=max_cudagraph_capture_size,
|
||||
pass_config=PassConfig(
|
||||
enable_sp=enable_sp,
|
||||
fuse_norm_quant=True,
|
||||
fuse_act_quant=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
cudagraph_mode=cudagraph_mode,
|
||||
)
|
||||
engine_args = EngineArgs(
|
||||
model="facebook/opt-125m",
|
||||
tensor_parallel_size=tp_size,
|
||||
max_num_seqs=min(max_num_batched_tokens, 128),
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
|
||||
assert (
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
== expected_max_size
|
||||
)
|
||||
286
tests/compile/test_decorator.py
Normal file
286
tests/compile/test_decorator.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(
|
||||
vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
|
||||
):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(torch.randn(2, MLP_SIZE).cuda())
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(torch.randn(2, MLP_SIZE).cuda())
|
||||
|
||||
output = output.cpu()
|
||||
return output.cpu()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||
def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
|
||||
# disable compile cache so that we can count the number of compilations
|
||||
# appropriately
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
# piecewise
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
expected_num_graphs_seen = 1
|
||||
expected_num_cudagraph_captured = (
|
||||
4 # num_cudagraph_sizes * num cudagraphs to capture
|
||||
)
|
||||
if use_inductor_graph_partition:
|
||||
expected_num_piecewise_graphs_seen = 1
|
||||
expected_num_piecewise_capturable_graphs_seen = 1
|
||||
expected_num_backend_compilations = 1
|
||||
else:
|
||||
expected_num_piecewise_graphs_seen = 3
|
||||
expected_num_piecewise_capturable_graphs_seen = 2
|
||||
expected_num_backend_compilations = 2
|
||||
|
||||
@support_torch_compile
|
||||
class A(nn.Module):
|
||||
def __init__(
|
||||
self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x * 3
|
||||
return x
|
||||
|
||||
@ignore_torch_compile
|
||||
class B(A): ...
|
||||
|
||||
@support_torch_compile
|
||||
class C(B): ...
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
# A has support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=expected_num_graphs_seen,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
# B's ignore_torch_compile should override A's support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
# C's support_torch_compile should override B's ignore_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=expected_num_graphs_seen,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=expected_num_cudagraph_captured,
|
||||
):
|
||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||
@support_torch_compile(
|
||||
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
)
|
||||
class B(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x + x
|
||||
return x
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||
@support_torch_compile(
|
||||
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
|
||||
)
|
||||
class A(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.mod1(x)
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = self.mod2(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||
def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
|
||||
# disable compile cache so that we can count the number of compilations
|
||||
# appropriately
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=True,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
),
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
if use_inductor_graph_partition:
|
||||
expected_num_piecewise_graphs_seen = 2
|
||||
expected_num_piecewise_capturable_graphs_seen = 2
|
||||
expected_num_backend_compilations = 2
|
||||
else:
|
||||
expected_num_piecewise_graphs_seen = 6
|
||||
expected_num_piecewise_capturable_graphs_seen = 4
|
||||
expected_num_backend_compilations = 4
|
||||
|
||||
# A has support_torch_compile but enable_if fn returns False
|
||||
# enalbe_if will be True for B, so we expect mod1 and mod2
|
||||
# to be compiled
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
# 3 piecewise graphs per instance of B()
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num cudagraphable graphs to capture
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
# Set kv_sharing_fast_prefill=False
|
||||
# which will cause A to be compiled and B to not be compiled
|
||||
vllm_config = VllmConfig(
|
||||
cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=False,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
splitting_ops=["silly::attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
|
||||
|
||||
if use_inductor_graph_partition:
|
||||
expected_num_piecewise_graphs_seen = 1
|
||||
expected_num_piecewise_capturable_graphs_seen = 1
|
||||
expected_num_backend_compilations = 1
|
||||
else:
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
expected_num_piecewise_graphs_seen = 7
|
||||
expected_num_piecewise_capturable_graphs_seen = 4
|
||||
expected_num_backend_compilations = 4
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
|
||||
num_backend_compilations=expected_num_backend_compilations,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num cudagraphable graphs to capture
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
219
tests/compile/test_dynamic_shapes_compilation.py
Normal file
219
tests/compile/test_dynamic_shapes_compilation.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.config.compilation import (
|
||||
CompilationMode,
|
||||
DynamicShapesConfig,
|
||||
DynamicShapesType,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
|
||||
def get_test_models():
|
||||
"""Get list of models to test based on PyTorch version"""
|
||||
# TODO "Qwen/Qwen3-4B-Instruct-2507" fails Fix issue and support it.
|
||||
return ["gpt2", "Qwen/Qwen2-7B-Instruct", "meta-llama/Llama-3.1-8B"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", get_test_models())
|
||||
@pytest.mark.parametrize(
|
||||
"shapes_type",
|
||||
[
|
||||
DynamicShapesType.BACKED,
|
||||
DynamicShapesType.UNBACKED,
|
||||
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
|
||||
@pytest.mark.parametrize("evaluate_guards", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
def test_dynamic_shapes_compilation(
|
||||
monkeypatch,
|
||||
model_name,
|
||||
shapes_type,
|
||||
use_aot_compile,
|
||||
use_bytecode_hook,
|
||||
evaluate_guards,
|
||||
):
|
||||
"""Test that all dynamic shapes types compile successfully"""
|
||||
if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
|
||||
|
||||
if evaluate_guards and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("unbacked dynamic shapes do not add guards")
|
||||
|
||||
if evaluate_guards and use_aot_compile:
|
||||
pytest.skip("evaluate_guards requires use_aot_compile=0")
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
|
||||
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
|
||||
|
||||
prompt = "Hello, my name is"
|
||||
|
||||
print(f"Testing {shapes_type.name} dynamic shapes...")
|
||||
|
||||
# Initialize the model with specific dynamic shapes configuration
|
||||
model = LLM(
|
||||
model=model_name,
|
||||
compilation_config={
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"dynamic_shapes_config": {
|
||||
"type": shapes_type.value,
|
||||
"evaluate_guards": evaluate_guards,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
output = model.generate(prompt)
|
||||
result = output[0].outputs[0].text
|
||||
# Example of setting the sampling parameters
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
yes_tokens = tokenizer.encode("yes", add_special_tokens=False)
|
||||
no_tokens = tokenizer.encode("no", add_special_tokens=False)
|
||||
allowed_ids = list(set(yes_tokens + no_tokens))
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=1, temperature=0, allowed_token_ids=allowed_ids
|
||||
)
|
||||
|
||||
output = model.generate(
|
||||
"answer with yes or no is " + result + " rubbish for prompt " + prompt + "?",
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
result = output[0].outputs[0].text
|
||||
assert result == "yes"
|
||||
|
||||
# Clean up GPU memory
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
print("GPU memory cleared")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
|
||||
@pytest.mark.parametrize(
|
||||
"dynamic_shapes_type",
|
||||
[
|
||||
DynamicShapesType.BACKED,
|
||||
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("evaluate_guards", [False, True])
|
||||
def test_model_specialization_with_evaluate_guards(
|
||||
monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
|
||||
):
|
||||
"""Test that evaluate_guards correctly detects shape specialization
|
||||
violations.
|
||||
"""
|
||||
|
||||
if (
|
||||
use_aot_compile == "1"
|
||||
and dynamic_shapes_type == DynamicShapesType.BACKED
|
||||
and evaluate_guards
|
||||
):
|
||||
pytest.skip("evaluate_guards for backed does not work with aot_compile=1")
|
||||
|
||||
@support_torch_compile
|
||||
class ModelWithSizeCheck(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# This will cause specialization - torch.compile will guard on
|
||||
# sx.shape[0]
|
||||
if x.shape[0] >= 10:
|
||||
return x * 10
|
||||
else:
|
||||
return x * 10
|
||||
|
||||
@support_torch_compile
|
||||
class ModelWithOneSizeCheck(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# This will cause 0/1 specializations.
|
||||
if x.shape[0] == 0:
|
||||
return x * 10
|
||||
if x.shape[0] == 1:
|
||||
return x * 10
|
||||
else:
|
||||
return x * 10
|
||||
|
||||
@contextmanager
|
||||
def use_vllm_config(vllm_config: VllmConfig):
|
||||
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
|
||||
yield
|
||||
|
||||
monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
|
||||
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
|
||||
|
||||
# Create vllm config with the desired settings
|
||||
from vllm.config import CompilationMode
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
dynamic_shapes_config=DynamicShapesConfig(
|
||||
type=dynamic_shapes_type,
|
||||
evaluate_guards=evaluate_guards,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def test(model_class, input1, input2, is_01_specialization=False):
|
||||
with (
|
||||
torch.no_grad(),
|
||||
use_vllm_config(vllm_config),
|
||||
tempfile.TemporaryDirectory() as tmpdirname,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||
|
||||
model = model_class(vllm_config=vllm_config).cuda()
|
||||
|
||||
model(input1)
|
||||
|
||||
if evaluate_guards and (
|
||||
not (
|
||||
is_01_specialization
|
||||
and dynamic_shapes_type == DynamicShapesType.BACKED
|
||||
)
|
||||
):
|
||||
# This should fail because guards were added.
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
model(input2)
|
||||
|
||||
# Expected failure - guard was violated
|
||||
error_msg = str(excinfo.value)
|
||||
assert (
|
||||
"GuardManager check failed" in error_msg
|
||||
or "Detected recompile when torch.compile stance" in error_msg
|
||||
), error_msg
|
||||
|
||||
else:
|
||||
model(input2)
|
||||
|
||||
test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
|
||||
test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
|
||||
test(
|
||||
ModelWithOneSizeCheck,
|
||||
torch.randn(20, 10).cuda(),
|
||||
torch.randn(1, 10).cuda(),
|
||||
is_01_specialization=True,
|
||||
)
|
||||
267
tests/compile/test_functionalization.py
Normal file
267
tests/compile/test_functionalization.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
TEST_FP8 = current_platform.supports_fp8()
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class TestSiluMul(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int = 128):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
if TEST_FP8:
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
if TEST_FP8:
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
else:
|
||||
return y
|
||||
|
||||
def example_inputs(self, num_tokens=32, hidden_size=128):
|
||||
return (torch.rand(num_tokens, hidden_size * 2),)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
if TEST_FP8 and do_fusion:
|
||||
return [torch.ops._C.silu_and_mul_quant.default]
|
||||
else:
|
||||
return [torch.ops._C.silu_and_mul.default]
|
||||
|
||||
def ops_not_in_model(self):
|
||||
return []
|
||||
|
||||
|
||||
class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size))
|
||||
)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
|
||||
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
if TEST_FP8:
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
# Reshape input
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
|
||||
# matrix multiplication
|
||||
permute = self.gate_proj.permute(1, 0)
|
||||
mm = torch.mm(view, permute)
|
||||
|
||||
# layer normalization
|
||||
norm_output, residual_output = self.norm(mm, residual)
|
||||
|
||||
if TEST_FP8:
|
||||
# scaled_mm with static input quantization
|
||||
fp8_linear_result = self.fp8_linear.apply(
|
||||
norm_output,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.scale.to(norm_output.device),
|
||||
)
|
||||
|
||||
return fp8_linear_result, residual_output
|
||||
|
||||
else:
|
||||
return norm_output, residual_output
|
||||
|
||||
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size))
|
||||
return (hidden_states, residual)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
if TEST_FP8 and do_fusion:
|
||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
||||
else:
|
||||
return [torch.ops._C.fused_add_rms_norm.default]
|
||||
|
||||
def ops_not_in_model(self):
|
||||
return []
|
||||
|
||||
|
||||
class TestRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, head_dim=64, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
|
||||
def forward(self, positions, q, k):
|
||||
q_rotated, k_rotated = self.rotary_emb(positions, q, k)
|
||||
return q_rotated, k_rotated
|
||||
|
||||
def example_inputs(self, num_tokens=32, head_dim=64):
|
||||
positions = torch.arange(num_tokens, dtype=torch.long)
|
||||
q = torch.randn(num_tokens, head_dim)
|
||||
k = torch.randn(num_tokens, head_dim)
|
||||
return (positions, q, k)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
return [torch.ops._C.rotary_embedding.default]
|
||||
|
||||
def ops_not_in_model(self):
|
||||
return []
|
||||
|
||||
|
||||
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = head_dim * num_heads
|
||||
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
self.hidden_size, self.hidden_size * 3, bias=False
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
|
||||
def forward(self, positions, hidden_states):
|
||||
# Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
|
||||
# -> slice_scatter -> split_with_sizes
|
||||
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
|
||||
q, k, v = torch.split(qkv, split_sizes, dim=-1)
|
||||
|
||||
q_rotated, k_rotated = self.rotary_emb(positions, q, k)
|
||||
|
||||
qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
|
||||
return qkv_updated
|
||||
|
||||
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
|
||||
hidden_size = head_dim * num_heads
|
||||
positions = torch.arange(num_tokens, dtype=torch.long)
|
||||
hidden_states = torch.randn(num_tokens, hidden_size)
|
||||
return (positions, hidden_states)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
return [torch.ops._C.rotary_embedding.default]
|
||||
|
||||
def ops_not_in_model(self):
|
||||
return [torch.ops.aten.slice_scatter.default]
|
||||
|
||||
|
||||
MODELS = [
|
||||
TestSiluMul,
|
||||
TestFusedAddRMSNorm,
|
||||
TestRotaryEmbedding,
|
||||
TestRotaryEmbeddingSliceScatter,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("model_class", MODELS)
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||
def test_fix_functionalization(
|
||||
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
custom_ops=["all"],
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=do_fusion,
|
||||
fuse_act_quant=do_fusion,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
assert RMSNorm.enabled()
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
passes = (
|
||||
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion
|
||||
else [noop_pass, cleanup_pass]
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
model = model_class()
|
||||
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
for op in model.ops_not_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
338
tests/compile/test_fusion.py
Normal file
338
tests/compile/test_fusion.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.plugins
|
||||
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
    """RMSNorm + FP8-quantized-linear stack for the RMSNorm+quant fusion test.

    Chains 4 RMSNorm layers with 3 FP8 linear layers. Depending on
    ``group_shape`` the linears quantize activations blockwise
    (``W8A8BlockFp8LinearOp``) or per-tensor/per-token (``Fp8LinearOp``).
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        group_shape: GroupShape,
        cuda_force_torch: bool,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cuda_force_torch = cuda_force_torch
        # One RMSNorm per activation site in forward().
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        if group_shape.is_per_group():
            # Blockwise quantization: one weight scale per
            # (group, group) tile of the square weight matrix.
            self.wscale = [
                torch.rand(
                    (hidden_size // group_shape[1], hidden_size // group_shape[1]),
                    dtype=torch.float32,
                )
                for _ in range(3)
            ]
        else:
            # Per-tensor / per-token: a single scalar weight scale per linear.
            self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # Static (per-tensor) activation quantization uses precomputed scales.
        static = group_shape == GroupShape.PER_TENSOR
        quant_scale = ScaleDesc(torch.float32, static, group_shape)
        self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        else:
            # Dynamic quantization: input scales are computed at runtime.
            self.scale = [None for _ in range(3)]
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
        ]
        if not group_shape.is_per_group():
            # Non-blockwise path: reuse the first weight, transposed, for all
            # three linears.
            self.w = [self.w[0].t() for _ in range(3)]

        if group_shape.is_per_group():
            self.fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
                act_quant_group_shape=group_shape,
                cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                use_aiter_and_is_supported=False,
            )
            self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
        else:
            # cuda_force_torch=True forces the torch (non-cutlass) fp8 path.
            with override_cutlass_fp8_supported(not cuda_force_torch):
                self.fp8_linear = Fp8LinearOp(
                    act_quant_static=static,
                    act_quant_group_shape=group_shape,
                )
            self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

        self.enable_rms_norm_custom_op = self.norm[0].enabled()
        self.group_shape = group_shape

    def forward(self, x):
        """Run 4 norms interleaved with 3 fp8 linears; return the last norm's output."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )

        y3, resid = self.norm[2](x3, resid)  # use resid here

        x4 = self.fp8_linear.apply(
            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
        )

        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        # Fused rmsnorm+quant ops expected in the graph after the fusion pass.
        return [
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
            FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
        ]

    def ops_in_model_before(self):
        # Quant ops expected before fusion (fully removed by the pass).
        # Without the custom op, torch-native quant is detected via reciprocal.
        return (
            [QUANT_OPS[self.quant_key]]
            if self.enable_quant_fp8_custom_op
            else [torch.ops.aten.reciprocal]
        )

    def ops_in_model_before_partial(self):
        # RMSNorm ops expected before fusion (only partially removed, since
        # the final norm has no following quant to fuse with).
        return (
            [RMS_OP, RMS_ADD_OP]
            if self.enable_rms_norm_custom_op
            else [torch.ops.aten.rsqrt]
        )
|
||||
|
||||
|
||||
# Quantization group shapes exercised by the test: per-token, per-tensor,
# and two blockwise (1, N) variants.
GROUP_SHAPES = [
    GroupShape.PER_TOKEN,
    GroupShape.PER_TENSOR,
    GroupShape(1, 128),
    GroupShape(1, 64),
]
|
||||
|
||||
|
||||
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
    """AITER rms_norm + blockwise FP8 linear stack for the ROCm AITER
    rmsnorm + group-quant fusion pass test."""

    def __init__(self, hidden_size: int, eps: float, **kwargs):
        super().__init__()
        # Blockwise FP8 linear (128x128 weight tiles, 1x128 activation
        # groups) routed through the AITER kernels.
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(128, 128),
            act_quant_group_shape=GroupShape(1, 128),
            cutlass_block_fp8_supported=False,
            use_aiter_and_is_supported=True,
        )
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(3)
        ]

        # One scale per 128x128 weight tile (ceil division).
        scale_hidden_size = (hidden_size + 128 - 1) // 128
        self.wscale = [
            torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
            for _ in range(3)
        ]

        self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
        self.eps = eps

    def forward(self, x):
        """Run 4 AITER rms_norms interleaved with 3 blockwise fp8 linears."""
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)

        x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
        # make sure resid is used for replacement to work
        y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
            x2, resid, self.norm_weight[1], self.eps
        )

        x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])

        y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
            x3, resid, self.norm_weight[2], self.eps
        )

        x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])

        y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
            x4, resid, self.norm_weight[3], self.eps
        )
        return y4

    def ops_in_model_before(self):
        # Standalone AITER norm and quant ops expected before fusion.
        return [
            torch.ops.vllm.rocm_aiter_rms_norm,
            torch.ops.vllm.rocm_aiter_group_fp8_quant,
        ]

    def ops_in_model_before_partial(self):
        # No partially-replaced ops for this model.
        return []

    def ops_in_model_after(self):
        # Fused AITER rmsnorm + group-quant ops expected after fusion.
        return [
            torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
            torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
        ]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
||||
@pytest.mark.parametrize(
|
||||
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
|
||||
list(itertools.product([TestModel], [True, False], [True, False]))
|
||||
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
|
||||
)
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize(
|
||||
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
def test_fusion_rmsnorm_quant(
|
||||
dtype,
|
||||
hidden_size,
|
||||
num_tokens,
|
||||
eps,
|
||||
group_shape,
|
||||
model_class,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
cuda_force_torch,
|
||||
):
|
||||
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||
pytest.skip("AITER is not supported on this GPU.")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
|
||||
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
|
||||
|
||||
# Skip test for 64-bit group shape when running with cutlass or deepgemm
|
||||
if group_shape == GroupShape(1, 64) and (
|
||||
cutlass_block_fp8_supported() or is_deep_gemm_supported()
|
||||
):
|
||||
pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_quant_fp8_custom_op:
|
||||
custom_ops.append("+quant_fp8")
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
),
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
if model_class is TestRmsnormGroupFp8QuantModel:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
|
||||
else:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||
model = model_class(
|
||||
hidden_size=hidden_size,
|
||||
eps=eps,
|
||||
group_shape=group_shape,
|
||||
cuda_force_torch=cuda_force_torch,
|
||||
)
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
result_fused = model_fused(x)
|
||||
|
||||
model_unfused = torch.compile(model, backend=backend2)
|
||||
result_unfused = model_unfused(x)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 3
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
backend.check_before_ops(
|
||||
model.ops_in_model_before_partial(), fully_replaced=False
|
||||
)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# If RMSNorm custom op is disabled (native/torch impl used),
|
||||
# there's a risk that the fused add doesn't get included in the
|
||||
# replacement and only the rms part gets fused with quant.
|
||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||
if (
|
||||
not enable_rms_norm_custom_op
|
||||
and model_class is not TestRmsnormGroupFp8QuantModel
|
||||
):
|
||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||
477
tests/compile/test_fusion_attn.py
Normal file
477
tests/compile/test_fusion_attn.py
Normal file
@@ -0,0 +1,477 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import LazyInitPass, TestBackend
|
||||
from tests.utils import flat_product
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
# Platform-specific FP8 dtype; FP4 values are carried in a uint8 container
# (two packed 4-bit values per byte).
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
class AttentionQuantPatternModel(torch.nn.Module):
    """Base model for AttentionQuantPattern fusion.

    Owns a single vLLM ``Attention`` layer plus the machinery (metadata
    builder, dummy KV cache) needed to run it standalone in a test.
    Subclasses add the post-attention quantized linear that forms the
    fusion pattern.
    """

    def __init__(
        self,
        num_qo_heads: int,
        num_kv_heads: int,
        head_size: int,
        kv_cache_dtype: torch.dtype,
        device: torch.device,
        vllm_config: VllmConfig,
        **kwargs,
    ):
        super().__init__()
        self.num_qo_heads = num_qo_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.device = device
        self.vllm_config = vllm_config

        self.attn = Attention(
            num_heads=self.num_qo_heads,
            head_size=self.head_size,
            scale=1.0 / (self.head_size**0.5),
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix="model.layers.0.self_attn.attn",
        )
        # Move the layer's k/v scales onto the test device.
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        self.block_size = 16

        # Initialize attn MetadataBuilder
        self.builder = self.attn.attn_backend.get_builder_cls()(
            kv_cache_spec=AttentionSpec(
                block_size=self.block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.attn.layer_name],
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
        """Build decode-style attention metadata (one token per sequence)
        and install a dummy KV cache laid out for the selected backend."""

        # Create common attn metadata
        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks
        backend = self.attn.backend

        # TODO(luka) use get_kv_cache_stride_order
        # Create dummy KV cache for the selected backend
        if backend == AttentionBackendEnum.ROCM_ATTN:
            # k/v as 1st dimension
            # HND: [num_blocks, num_kv_heads, block_size, head_size]
            kv_cache = torch.zeros(
                2,
                num_blocks,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
            # k/v as 1st dimension
            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
            kv_cache = torch.zeros(
                2,
                num_blocks,
                self.block_size,
                self.num_kv_heads,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.TRITON_ATTN:
            # k/v as 2nd dimension
            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
            # NOTE(review): the shape below places num_kv_heads before
            # block_size (head-first), which looks like HND rather than the
            # NHD label above — confirm against the backend's
            # get_kv_cache_shape.
            kv_cache = torch.zeros(
                num_blocks,
                2,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            )
        elif backend == AttentionBackendEnum.FLASHINFER:
            # HND storage permuted to present an NHD view.
            kv_cache = torch.zeros(
                num_blocks,
                2,
                self.num_kv_heads,
                self.block_size,
                self.head_size,
                dtype=self.kv_cache_dtype,
                device=self.device,
            ).permute(0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unsupported backend: {backend}")
        self.attn.kv_cache = [kv_cache]

        # Build attn metadata
        self.attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

        return self.attn_metadata
|
||||
|
||||
|
||||
class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionFp8StaticQuantPattern fusion.

    Attention followed by a static-scale FP8 linear; the fusion pass should
    fold the output quantization into the attention op.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.quant_key.scale.static,
            act_quant_group_shape=self.quant_key.scale.group_shape,
        )

        hidden_size = self.num_qo_heads * self.head_size
        # Only build the random default weights when the caller did not pass
        # them in: the previous kwargs.get("w", {...torch.randn...}) form
        # evaluated the default eagerly, allocating the tensors and advancing
        # the global RNG even when "w" was supplied (e.g. when sharing
        # weights with an unfused model instance).
        w = kwargs.get("w")
        if w is None:
            w = {
                # Transposed FP8 weight for the linear layer.
                "weight": torch.randn(hidden_size, hidden_size)
                .to(dtype=FP8_DTYPE, device=self.device)
                .t(),
                # Static weight and input quantization scales.
                "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        return self.fp8_linear.apply(
            input=attn_output,
            weight=self.w["weight"],
            weight_scale=self.w["wscale"],
            input_scale=self.w["scale"],
        )
|
||||
|
||||
|
||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
    """Test model for AttentionNvfp4QuantPattern fusion.

    Attention followed by FP4 quantization + scaled FP4 matmul; the fusion
    pass should fold the output quantization into the attention op.
    """

    quant_key = kNvfp4Quant

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        hidden_size = self.num_qo_heads * self.head_size
        # Only build the random default weights when the caller did not pass
        # them in: the previous kwargs.get("w", {...torch.randint...}) form
        # evaluated the default eagerly, allocating the tensors and advancing
        # the global RNG even when "w" was supplied (e.g. when sharing
        # weights with an unfused model instance).
        w = kwargs.get("w")
        if w is None:
            w = {
                # Packed FP4 weights: two 4-bit values per uint8 byte.
                "weight": torch.randint(
                    256,
                    (hidden_size, hidden_size // 2),
                    dtype=FP4_DTYPE,
                    device=self.device,
                ),
                # Swizzled per-block weight scales (one per 16 elements).
                "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
                    dtype=FP8_DTYPE, device=self.device
                ),
                "wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
                "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
            }
        self.w = w

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Forward pass that creates the pattern to be fused."""
        attn_output = self.attn(q, k, v)
        quant_output, output_block_scale = scaled_fp4_quant(
            attn_output, 1 / self.w["scale"]
        )
        return cutlass_scaled_fp4_mm(
            a=quant_output,
            b=self.w["weight"],
            block_scale_a=output_block_scale,
            block_scale_b=self.w["wscale_swizzled"],
            alpha=self.w["scale"] * self.w["wscale"],
            out_dtype=attn_output.dtype,
        )
|
||||
|
||||
|
||||
# Per-platform parametrization sets for test_attention_quant_pattern.
# The *_FP8 / *_FP4 pairs feed the fp8 / fp4 fusion cases respectively.
MODELS_FP8: list[tuple[str, type]] = []
MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[AttentionBackendEnum] = []

if current_platform.is_cuda():
    HEADS = [(64, 8), (40, 8)]
    MODELS_FP8 = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            TestAttentionFp8StaticQuantPatternModel,
        )
    ]
    MODELS_FP4 = [
        (
            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
            TestAttentionNvfp4QuantPatternModel,
        )
    ]
    BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
    BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]

elif current_platform.is_rocm():
    HEADS = [(32, 8), (40, 8)]
    MODELS_FP8 = [
        ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
    ]
    # BUG FIX: this previously assigned to an unused name `BACKENDS`,
    # leaving BACKENDS_FP8 empty on ROCm so no FP8 fusion cases were
    # generated by the test parametrization.
    BACKENDS_FP8 = [
        AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
        AttentionBackendEnum.ROCM_ATTN,
        AttentionBackendEnum.TRITON_ATTN,
    ]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"backend, model_name, model_class, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||
)
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
def test_attention_quant_pattern(
|
||||
num_qo_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
batch_size: int,
|
||||
dtype: torch.dtype,
|
||||
custom_ops: str,
|
||||
model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: AttentionBackendEnum,
|
||||
dist_init,
|
||||
):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
if backend == AttentionBackendEnum.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(42)
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
dtype=dtype,
|
||||
)
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_num_seqs=1024,
|
||||
max_model_len=model_config.max_model_len,
|
||||
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops_list,
|
||||
),
|
||||
cache_config=CacheConfig(cache_dtype="fp8"),
|
||||
attention_config=AttentionConfig(backend=backend),
|
||||
)
|
||||
|
||||
# Create test inputs
|
||||
q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
|
||||
k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||
v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||
|
||||
# Mark first dimension as dynamic for realistic testing
|
||||
torch._dynamo.mark_dynamic(q, 0)
|
||||
torch._dynamo.mark_dynamic(k, 0)
|
||||
torch._dynamo.mark_dynamic(v, 0)
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
vllm_config_unfused = copy.deepcopy(vllm_config)
|
||||
with (
|
||||
set_current_vllm_config(vllm_config_unfused),
|
||||
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
|
||||
):
|
||||
model_unfused = model_class(
|
||||
num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config_unfused,
|
||||
)
|
||||
model_unfused = model_unfused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
|
||||
|
||||
# Run model directly without fusion
|
||||
# Still compile so query QuantFP8 has closer numerics
|
||||
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v)
|
||||
|
||||
# Run model with attn fusion enabled
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
fuse_attn_quant=True, eliminate_noops=True
|
||||
)
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
|
||||
):
|
||||
model_fused = model_class(
|
||||
num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config,
|
||||
w=model_unfused.w,
|
||||
)
|
||||
model_fused = model_fused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
|
||||
|
||||
# Create test backend with fusion passes enabled
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
||||
|
||||
# Compile model with fusion enabled
|
||||
model_compiled = torch.compile(
|
||||
model_fused, backend=test_backend, fullgraph=True
|
||||
)
|
||||
assert model_compiled.attn._o_scale_float is None
|
||||
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
|
||||
if backend == AttentionBackendEnum.FLASHINFER:
|
||||
# With the Flashinfer backend after the 1st round of the forward
|
||||
# pass, output quant scale should be loaded into the attn layer's
|
||||
# _o_scale_float, the 2nd round should reuse the loaded
|
||||
# _o_scale_float
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
result_fused_2 = model_compiled(q, k, v)
|
||||
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
|
||||
torch.testing.assert_close(
|
||||
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
# Check attn fusion support
|
||||
quant_key: QuantKey = model_class.quant_key
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key)
|
||||
for key, layer in vllm_config.compilation_config.static_forward_context.items()
|
||||
]
|
||||
assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
|
||||
"All layers should support attention fusion"
|
||||
)
|
||||
|
||||
# Check quantization ops in the graph before and after fusion
|
||||
quant_op = (
|
||||
torch.ops.aten.reciprocal
|
||||
if "-quant_fp8" in custom_ops_list
|
||||
else QUANT_OPS[quant_key]
|
||||
)
|
||||
|
||||
# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
|
||||
# Only output quant ops are fused into attention.
|
||||
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
|
||||
|
||||
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
||||
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
||||
|
||||
# Check attention ops in the graph before and after fusion
|
||||
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
|
||||
attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))
|
||||
|
||||
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
|
||||
assert len(attn_nodes_pre) == len(attn_nodes_post), (
|
||||
"Should have same number of attention nodes before and after fusion"
|
||||
)
|
||||
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
|
||||
"Attention should not have output_scale before fusion"
|
||||
)
|
||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
|
||||
"Attention should have output_scale after fusion"
|
||||
)
|
||||
|
||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale before fusion"
|
||||
)
|
||||
if quant_key.dtype == FP8_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale after FP8 fusion"
|
||||
)
|
||||
elif quant_key.dtype == FP4_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
|
||||
"Attention should have output_block_scale after FP4 fusion"
|
||||
)
|
||||
|
||||
# Check that results are close
|
||||
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)
|
||||
124
tests/compile/test_graph_partition.py
Normal file
124
tests/compile/test_graph_partition.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from vllm.compilation.backends import split_graph
|
||||
|
||||
|
||||
def test_getitem_moved_to_producer_subgraph():
    """
    Splitting a graph on a tuple-producing op must hoist the getitem nodes
    into the producer's submodule, so no submodule receives a tuple input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split yields a tuple, so tracing produces real getitem ops.
        # The split becomes the first submodule (the tuple producer).
        halves = torch.split(x, x.shape[0] // 2, dim=0)

        # The remaining ops become the second submodule (tuple consumer).
        left = torch.relu(halves[0])
        right = torch.relu(halves[1])
        return torch.cat([left, right], dim=0)

    sample = torch.randn(4, 3)
    gm = make_fx(model_fn)(sample)

    assert any(
        n.op == "call_function" and n.target == operator.getitem
        for n in gm.graph.nodes
    ), "Test setup failed: graph should contain getitem operations"

    # Split on the tuple producer aten::split.
    split_gm, split_items = split_graph(gm, ["aten::split.Tensor"])
    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    for split_item in split_items:
        # getitem nodes whose input is a submodule placeholder indicate a
        # tuple flowed across the submodule boundary.
        offenders = [
            n
            for n in split_item.graph.graph.nodes
            if n.op == "call_function"
            and n.target == operator.getitem
            and n.args[0].op == "placeholder"
        ]
        assert len(offenders) == 0, (
            f"Submodule {split_item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in offenders]}. "
            "This means tuple inputs were not properly eliminated."
        )

    # The split graph must compute the same result as the original.
    probe = torch.randn(4, 3)
    assert torch.allclose(gm(probe), split_gm(probe)), "Output mismatch"
|
||||
|
||||
|
||||
def test_no_tuple_inputs_with_multiple_consumers():
    """
    When a tuple is consumed by several submodules (split points on both the
    producer and a later op), getitem operations must still be moved so that
    no submodule receives a tuple input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split yields a tuple, creating real getitem operations.
        # The split becomes the first submodule (the tuple producer).
        halves = torch.split(x, x.shape[0] // 2, dim=0)

        # Second submodule: consumes the tuple.
        act_0 = torch.relu(halves[0])
        act_1 = torch.relu(halves[1])

        # Artificial split point (aten::sigmoid) creating a third,
        # independent submodule so the tuple is consumed again later.
        act_0 = torch.sigmoid(act_0)

        # Fourth submodule: consumes the tuple once more.
        return torch.cat([halves[0], halves[1], act_0, act_1])

    sample = torch.randn(4, 3)
    gm = make_fx(model_fn)(sample)

    assert any(
        n.op == "call_function" and n.target == operator.getitem
        for n in gm.graph.nodes
    ), "Test setup failed: graph should contain getitem operations"

    split_gm, split_items = split_graph(gm, ["aten::split.Tensor", "aten::sigmoid"])
    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    for split_item in split_items:
        for node in split_item.graph.graph.nodes:
            receives_tuple = (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            )
            if receives_tuple:
                pytest.fail(
                    f"Submodule {split_item.submod_name} has getitem on "
                    f"placeholder {node.args[0].name}, indicating it receives "
                    "a tuple input"
                )

    # The split graph must compute the same result as the original.
    probe = torch.randn(4, 3)
    assert torch.allclose(gm(probe), split_gm(probe)), "Output mismatch after split"
|
||||
115
tests/compile/test_noop_elimination.py
Normal file
115
tests/compile/test_noop_elimination.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
# Important edge case is when `num_tokens == buffer_size`
|
||||
@pytest.mark.parametrize(
|
||||
("num_tokens", "buffer_size"), [(256, 256), (256, 512), (1024, 1024), (1024, 1025)]
|
||||
)
|
||||
@pytest.mark.parametrize("hidden_size", [64, 4096])
|
||||
def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x += self.pos_embed[: x.shape[0]]
|
||||
# Chain of reshapes
|
||||
y = x.reshape(-1, 128, 32)
|
||||
z = y.reshape(-1, 4096)
|
||||
# No-op reshape
|
||||
a = z.reshape(-1, 4096)
|
||||
# Final reshape that should remain
|
||||
b = a.reshape(-1, 128, 32)
|
||||
# No-op slice
|
||||
c = b[0 : b.shape[0]]
|
||||
# The pass should replace the result of this op with `c`.
|
||||
d = torch.slice_scatter(
|
||||
torch.ones_like(c), # Dummy tensor to be scattered into
|
||||
c, # Source tensor
|
||||
0, # dim
|
||||
0, # start
|
||||
c.shape[0], # end
|
||||
)
|
||||
return d
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass)
|
||||
|
||||
model = Model()
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result = model(x)
|
||||
|
||||
model2 = torch.compile(model, backend=backend)
|
||||
result2 = model2(x)
|
||||
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
|
||||
|
||||
# The no-op reshape and slice should be eliminated.
|
||||
# The initial slice on the positional embedding should remain.
|
||||
# The chain of reshapes should be fused into a single reshape.
|
||||
assert backend.op_count(torch.ops.aten.reshape.default) == 1
|
||||
assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
|
||||
assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
|
||||
|
||||
|
||||
def test_non_noop_slice_preserved():
|
||||
"""Ensure that a slice with end=-1 (dropping last row) is NOT eliminated.
|
||||
|
||||
Regression test for a bug where end=-1 was treated like an inferred
|
||||
dimension (reshape semantics) leading to incorrect elimination.
|
||||
"""
|
||||
torch.set_default_device("cuda")
|
||||
x = torch.randn(16, 16)
|
||||
|
||||
class SliceModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
base = x.clone()
|
||||
src = torch.ones(15, 16)
|
||||
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
||||
return x[0:-1, :], y
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
backend = TestBackend(noop_pass)
|
||||
model = SliceModel()
|
||||
ref = model(x)
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
out = compiled(x)
|
||||
torch.testing.assert_close(ref, out)
|
||||
# The slice should remain (not a no-op).
|
||||
assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
|
||||
assert backend.op_count(torch.ops.aten.slice_scatter.default) == 1
|
||||
83
tests/compile/test_pass_manager.py
Normal file
83
tests/compile/test_pass_manager.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.compilation.inductor_pass import (
|
||||
CallableInductorPass,
|
||||
InductorPass,
|
||||
pass_context,
|
||||
)
|
||||
from vllm.compilation.pass_manager import PostGradPassManager
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
|
||||
|
||||
# dummy custom pass that doesn't inherit
|
||||
def simple_callable(graph: torch.fx.Graph):
|
||||
pass
|
||||
|
||||
|
||||
# Should fail to add directly to the pass manager
|
||||
def test_bad_callable():
|
||||
config = VllmConfig()
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
pass_manager.add(simple_callable)
|
||||
|
||||
|
||||
# Pass that inherits from InductorPass
|
||||
class ProperPass(InductorPass):
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callable",
|
||||
[
|
||||
ProperPass(),
|
||||
# Can also wrap callables in CallableInductorPass for compliance
|
||||
CallableInductorPass(simple_callable),
|
||||
CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
|
||||
],
|
||||
)
|
||||
def test_pass_manager_uuid(callable):
|
||||
# Set the pass context as PassManager uuid uses it
|
||||
with pass_context(Range(start=1, end=8)):
|
||||
# Some passes need dtype to be set
|
||||
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
|
||||
# Check that UUID is different if the same pass is added 2x
|
||||
pass_manager.add(callable)
|
||||
uuid1 = pass_manager.uuid()
|
||||
pass_manager.add(callable)
|
||||
uuid2 = pass_manager.uuid()
|
||||
assert uuid1 != uuid2
|
||||
|
||||
# UUID should be the same as the original one,
|
||||
# as we constructed in the same way.
|
||||
pass_manager2 = PostGradPassManager()
|
||||
pass_manager2.configure(config)
|
||||
pass_manager2.add(callable)
|
||||
assert uuid1 == pass_manager2.uuid()
|
||||
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.compilation_config.pass_config.fuse_norm_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_norm_quant
|
||||
)
|
||||
config2.compilation_config.pass_config.fuse_act_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_act_quant
|
||||
)
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
pass_manager3.add(callable)
|
||||
assert uuid1 != pass_manager3.uuid()
|
||||
196
tests/compile/test_qk_norm_rope_fusion.py
Normal file
196
tests/compile/test_qk_norm_rope_fusion.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.qk_norm_rope_fusion import (
|
||||
FUSED_QK_ROPE_OP,
|
||||
QKNormRoPEFusionPass,
|
||||
)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RSQRT_OP = torch.ops.aten.rsqrt.default
|
||||
INDEX_SELECT_OP = torch.ops.aten.index.Tensor
|
||||
|
||||
|
||||
class QKNormRoPETestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
eps: float,
|
||||
is_neox: bool,
|
||||
vllm_config: VllmConfig,
|
||||
dtype: torch.dtype,
|
||||
prefix: str = "model.layers.0.self_attn.attn",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.q_size = num_heads * head_dim
|
||||
self.kv_size = num_kv_heads * head_dim
|
||||
self.rotary_dim = head_dim
|
||||
self.eps = eps
|
||||
self.dtype = dtype
|
||||
|
||||
# Register layer metadata for the fusion pass via Attention.
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=1.0 / self.head_dim**0.5,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=vllm_config.cache_config,
|
||||
prefix=prefix,
|
||||
attn_type=AttentionType.DECODER,
|
||||
)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=self.eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=self.eps)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000,
|
||||
is_neox_style=is_neox,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.enable_rms_norm_custom_op = self.q_norm.enabled()
|
||||
self.enable_rope_custom_op = self.rotary_emb.enabled()
|
||||
|
||||
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
return q, k, v
|
||||
|
||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
||||
ops = []
|
||||
if self.enable_rms_norm_custom_op:
|
||||
ops.append(RMS_OP)
|
||||
else:
|
||||
ops.append(RSQRT_OP)
|
||||
|
||||
if self.enable_rope_custom_op:
|
||||
if self.rotary_emb.use_flashinfer:
|
||||
ops.append(FLASHINFER_ROTARY_OP)
|
||||
else:
|
||||
ops.append(ROTARY_OP)
|
||||
else:
|
||||
ops.append(INDEX_SELECT_OP)
|
||||
return ops
|
||||
|
||||
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
|
||||
return [FUSED_QK_ROPE_OP]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("is_neox", [True, False])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Only test on cuda and rocm platform",
|
||||
)
|
||||
def test_qk_norm_rope_fusion(
|
||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
||||
):
|
||||
if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
|
||||
pytest.skip("fused_qk_norm_rope custom op not available")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
custom_ops: list[str] = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_rope_custom_op:
|
||||
custom_ops.append("+rotary_embedding")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||
T = 5
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = QKNormRoPETestModel(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
eps=eps,
|
||||
is_neox=is_neox,
|
||||
vllm_config=vllm_config,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = QKNormRoPEFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend_baseline = TestBackend(noop_pass, cleanup_pass)
|
||||
|
||||
qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
|
||||
pos = torch.arange(T, dtype=torch.long, device=qkv.device)
|
||||
qkv_unfused = qkv.clone()
|
||||
pos_unfused = pos.clone()
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv, 0)
|
||||
torch._dynamo.mark_dynamic(pos, 0)
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
q_fused, k_fused, v_fused = model_fused(qkv, pos)
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv_unfused, 0)
|
||||
torch._dynamo.mark_dynamic(pos_unfused, 0)
|
||||
model_unfused = torch.compile(model, backend=backend_baseline)
|
||||
q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
260
tests/compile/test_silu_mul_quant_fusion.py
Normal file
260
tests/compile/test_silu_mul_quant_fusion.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||
from vllm._aiter_ops import IS_AITER_FOUND
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
FUSED_OPS,
|
||||
SILU_MUL_OP,
|
||||
ActivationQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def is_nvfp4_supported():
|
||||
return current_platform.has_device_capability(100)
|
||||
|
||||
|
||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
|
||||
(
|
||||
QUANT_OPS[kFp8StaticTensorSym]
|
||||
if self.enable_quant_fp8_custom_op
|
||||
else torch.ops.aten.reciprocal
|
||||
),
|
||||
]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [FUSED_OPS[kFp8StaticTensorSym]]
|
||||
|
||||
|
||||
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||
super().__init__()
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
silu_and_mul_nvfp4_quant_supported,
|
||||
)
|
||||
|
||||
assert silu_and_mul_nvfp4_quant_supported
|
||||
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
|
||||
# create nvfp4 weight
|
||||
w = torch.rand((hidden_size, hidden_size))
|
||||
self.w, self.w_block_scale, self.w_global_scale = quant_nvfp4_tensor(w)
|
||||
|
||||
# get global scale offline
|
||||
_, _, self.y_global_scale = quant_nvfp4_tensor(self.silu_and_mul(x))
|
||||
|
||||
self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
||||
out = cutlass_scaled_fp4_mm(
|
||||
a=y_quant,
|
||||
b=self.w,
|
||||
block_scale_a=y_block_scale,
|
||||
block_scale_b=self.w_block_scale,
|
||||
alpha=self.alpha,
|
||||
out_dtype=y.dtype,
|
||||
)
|
||||
return out
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
|
||||
QUANT_OPS[kNvfp4Quant],
|
||||
]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [FUSED_OPS[kNvfp4Quant]]
|
||||
|
||||
|
||||
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(128, 128),
|
||||
act_quant_group_shape=GroupShape(1, 128),
|
||||
cutlass_block_fp8_supported=False,
|
||||
use_aiter_and_is_supported=True,
|
||||
)
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
scale_hidden_size = (hidden_size + 128 - 1) // 128
|
||||
self.wscale = torch.rand(
|
||||
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
|
||||
]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [32, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [128, 256])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
|
||||
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
|
||||
+ [
|
||||
(TestSiluMulNvfp4QuantModel, False, False),
|
||||
(TestSiluMulGroupFp8QuantModel, False, False),
|
||||
],
|
||||
)
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.skipif(
|
||||
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
def test_fusion_silu_and_mul_quant(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
model_class: type[
|
||||
TestSiluMulFp8QuantModel
|
||||
| TestSiluMulNvfp4QuantModel
|
||||
| TestSiluMulGroupFp8QuantModel
|
||||
],
|
||||
enable_silu_mul_custom_op: bool,
|
||||
enable_quant_fp8_custom_op: bool,
|
||||
cuda_force_torch: bool,
|
||||
):
|
||||
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||
pytest.skip("AITER is not supported on this GPU.")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
maybe_create_device_identity()
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
custom_ops = []
|
||||
if enable_silu_mul_custom_op:
|
||||
custom_ops.append("+silu_and_mul")
|
||||
if enable_quant_fp8_custom_op:
|
||||
custom_ops.append("+quant_fp8")
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(config):
|
||||
fusion_passes = [ActivationQuantFusionPass(config)]
|
||||
if IS_AITER_FOUND:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||
backend = TestBackend(*passes)
|
||||
model = model_class(
|
||||
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
|
||||
)
|
||||
|
||||
# First dimension dynamic
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result = model(x)
|
||||
|
||||
model2 = torch.compile(model, backend=backend)
|
||||
result2 = model2(x)
|
||||
|
||||
# Check that it gives the same answer
|
||||
if model_class == TestSiluMulFp8QuantModel:
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
elif model_class == TestSiluMulGroupFp8QuantModel:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
assert sum([p.matched_count for p in fusion_passes]) == 1
|
||||
|
||||
# In pre-nodes, quant op should be present and fused kernels should not
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, fused kernels should be present and quant op should not
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
135
tests/compile/test_wrapper.py
Normal file
135
tests/compile/test_wrapper.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
|
||||
|
||||
class MyMod(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
|
||||
if x.size()[0] >= 4:
|
||||
return x * 2
|
||||
else:
|
||||
return x * 100
|
||||
|
||||
|
||||
class MyWrapper(TorchCompileWithNoGuardsWrapper):
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor): # type: ignore[override]
|
||||
# this is the function to be compiled
|
||||
return self.model(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
|
||||
def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch):
|
||||
"""Test basic functionality of TorchCompileWithNoGuardsWrapper."""
|
||||
# Set the environment variable for this test
|
||||
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
|
||||
|
||||
# Create a proper vLLM config instead of mocking
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig()
|
||||
vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
|
||||
vllm_config.compilation_config.backend = "inductor"
|
||||
|
||||
# Test DYNAMO_TRACE_ONCE
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch._dynamo.reset()
|
||||
mod = MyMod()
|
||||
wrapper = MyWrapper(mod)
|
||||
|
||||
# First call should trigger compilation
|
||||
x = torch.tensor([1, 2, 3, 4])
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result1 = wrapper(x)
|
||||
expected1 = torch.tensor([2, 4, 6, 8])
|
||||
assert torch.allclose(result1, expected1), (
|
||||
f"Expected {expected1}, got {result1}"
|
||||
)
|
||||
|
||||
# Second call should use compiled code
|
||||
x2 = torch.tensor([1, 2, 3])
|
||||
result2 = wrapper(x2)
|
||||
expected2 = torch.tensor([2, 4, 6])
|
||||
assert torch.allclose(result2, expected2), (
|
||||
f"Expected {expected2}, got {result2}"
|
||||
)
|
||||
|
||||
# without the wrapper result would be different.
|
||||
result3 = mod(x2)
|
||||
expected3 = torch.tensor([100, 200, 300])
|
||||
|
||||
assert torch.allclose(result3, expected3), (
|
||||
f"Expected {result3}, got {expected3}"
|
||||
)
|
||||
|
||||
# with STOCK_TORCH_COMPILE we do not remove guards.
|
||||
vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
|
||||
torch._dynamo.reset()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod = MyMod()
|
||||
wrapper = MyWrapper(mod)
|
||||
|
||||
# First call should trigger compilation
|
||||
x = torch.tensor([1, 2, 3, 4])
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result1 = wrapper(x)
|
||||
expected1 = torch.tensor([2, 4, 6, 8])
|
||||
assert torch.allclose(result1, expected1), (
|
||||
f"Expected {expected1}, got {result1}"
|
||||
)
|
||||
|
||||
# Second call should triger another compilation
|
||||
x2 = torch.tensor([1, 2, 3])
|
||||
result2 = wrapper(x2)
|
||||
expected2 = torch.tensor([100, 200, 300])
|
||||
assert torch.allclose(result2, expected2), (
|
||||
f"Expected {expected2}, got {result2}"
|
||||
)
|
||||
|
||||
# NO_COMPILATION level not supported.
|
||||
vllm_config.compilation_config.mode = None
|
||||
torch._dynamo.reset()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch._dynamo.reset()
|
||||
mod = MyMod()
|
||||
|
||||
try:
|
||||
wrapper = MyWrapper(mod)
|
||||
except Exception:
|
||||
return
|
||||
raise AssertionError("expected an exception to be raised")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run with both parameter values
|
||||
|
||||
class MockMonkeypatch:
|
||||
def setenv(self, name, value):
|
||||
os.environ[name] = value
|
||||
|
||||
mp = MockMonkeypatch()
|
||||
|
||||
print("Testing with VLLM_USE_BYTECODE_HOOK=False")
|
||||
test_torch_compile_wrapper(False, mp)
|
||||
|
||||
print("Testing with VLLM_USE_BYTECODE_HOOK=True")
|
||||
test_torch_compile_wrapper(True, mp)
|
||||
|
||||
print("All tests passed!")
|
||||
4
tests/config/test_config.yaml
Normal file
4
tests/config/test_config.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
port: 12312
|
||||
served_model_name: mymodel
|
||||
tensor_parallel_size: 2
|
||||
trust_remote_code: true
|
||||
75
tests/config/test_config_generation.py
Normal file
75
tests/config/test_config_generation.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.quantization.quark.utils import deep_compare
|
||||
|
||||
|
||||
def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that configs created with normal (untouched) CUDA_VISIBLE_DEVICES
|
||||
and CUDA_VISIBLE_DEVICES="" are equivalent. This ensures consistent
|
||||
behavior regardless of whether GPU visibility is disabled via empty string
|
||||
or left in its normal state.
|
||||
"""
|
||||
|
||||
def create_config():
|
||||
engine_args = EngineArgs(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||
)
|
||||
return engine_args.create_engine_config()
|
||||
|
||||
# Create config with CUDA_VISIBLE_DEVICES set normally
|
||||
normal_config = create_config()
|
||||
|
||||
# Create config with CUDA_VISIBLE_DEVICES=""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("CUDA_VISIBLE_DEVICES", "")
|
||||
empty_config = create_config()
|
||||
|
||||
normal_config_dict = vars(normal_config)
|
||||
empty_config_dict = vars(empty_config)
|
||||
|
||||
# Remove instance_id before comparison as it's expected to be different
|
||||
normal_config_dict.pop("instance_id", None)
|
||||
empty_config_dict.pop("instance_id", None)
|
||||
|
||||
assert deep_compare(normal_config_dict, empty_config_dict), (
|
||||
'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""'
|
||||
" should be equivalent"
|
||||
)
|
||||
|
||||
|
||||
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
||||
# In testing, this method needs to be nested inside as ray does not
|
||||
# see the test module.
|
||||
def create_config():
|
||||
engine_args = EngineArgs(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||
)
|
||||
return engine_args.create_engine_config()
|
||||
|
||||
config = create_config()
|
||||
parallel_config = config.parallel_config
|
||||
assert parallel_config.ray_runtime_env is None
|
||||
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
runtime_env = {
|
||||
"env_vars": {
|
||||
"TEST_ENV_VAR": "test_value",
|
||||
},
|
||||
}
|
||||
|
||||
config_ref = ray.remote(create_config).options(runtime_env=runtime_env).remote()
|
||||
|
||||
config = ray.get(config_ref)
|
||||
parallel_config = config.parallel_config
|
||||
assert parallel_config.ray_runtime_env is not None
|
||||
assert (
|
||||
parallel_config.ray_runtime_env.env_vars().get("TEST_ENV_VAR") == "test_value"
|
||||
)
|
||||
|
||||
ray.shutdown()
|
||||
166
tests/config/test_config_utils.py
Normal file
166
tests/config/test_config_utils.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config.utils import get_hash_factors, hash_factors, normalize_value
|
||||
|
||||
# Helpers
|
||||
|
||||
|
||||
def endswith_fqname(obj, suffix: str) -> bool:
|
||||
# normalize_value(type) returns fully-qualified name
|
||||
# Compare suffix to avoid brittle import paths.
|
||||
out = normalize_value(obj)
|
||||
return isinstance(out, str) and out.endswith(suffix)
|
||||
|
||||
|
||||
def expected_path(p_str: str = ".") -> str:
|
||||
import pathlib
|
||||
|
||||
p = pathlib.Path(p_str)
|
||||
return p.expanduser().resolve().as_posix()
|
||||
|
||||
|
||||
# Minimal dataclass to test get_hash_factors.
|
||||
# Avoid importing heavy vLLM configs.
|
||||
@dataclass
|
||||
class SimpleConfig:
|
||||
a: object
|
||||
b: object | None = None
|
||||
|
||||
|
||||
class DummyLogprobsMode(Enum):
|
||||
RAW_LOGITS = "raw_logits"
|
||||
|
||||
|
||||
def test_hash_factors_deterministic():
|
||||
"""Test that hash_factors produces consistent SHA-256 hashes"""
|
||||
factors = {"a": 1, "b": "test"}
|
||||
hash1 = hash_factors(factors)
|
||||
hash2 = hash_factors(factors)
|
||||
|
||||
assert hash1 == hash2
|
||||
# Dict key insertion order should not affect the hash.
|
||||
factors_reordered = {"b": "test", "a": 1}
|
||||
assert hash_factors(factors_reordered) == hash1
|
||||
assert len(hash1) == 64
|
||||
assert all(c in "0123456789abcdef" for c in hash1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inp, expected",
|
||||
[
|
||||
(None, None),
|
||||
(True, True),
|
||||
(1, 1),
|
||||
(1.0, 1.0),
|
||||
("x", "x"),
|
||||
(b"ab", "6162"),
|
||||
(bytearray(b"ab"), "6162"),
|
||||
([1, 2], (1, 2)),
|
||||
({"b": 2, "a": 1}, (("a", 1), ("b", 2))),
|
||||
],
|
||||
)
|
||||
def test_normalize_value_matrix(inp, expected):
|
||||
"""Parametric input→expected normalization table."""
|
||||
assert normalize_value(inp) == expected
|
||||
|
||||
|
||||
def test_normalize_value_enum():
|
||||
# Enums normalize to (module.QualName, value).
|
||||
# DummyLogprobsMode uses a string payload.
|
||||
out = normalize_value(DummyLogprobsMode.RAW_LOGITS)
|
||||
assert isinstance(out, tuple)
|
||||
assert out[0].endswith("DummyLogprobsMode")
|
||||
# Expect string payload 'raw_logits'.
|
||||
assert out[1] == "raw_logits"
|
||||
|
||||
|
||||
def test_normalize_value_set_order_insensitive():
|
||||
# Sets are unordered; normalize_value sorts elements for determinism.
|
||||
assert normalize_value({3, 1, 2}) == normalize_value({1, 2, 3})
|
||||
|
||||
|
||||
def test_normalize_value_path_normalization():
|
||||
from pathlib import Path # local import to avoid global dependency
|
||||
|
||||
# Paths expand/resolve to absolute strings.
|
||||
# Stabilizes hashing across working dirs.
|
||||
assert normalize_value(Path(".")) == expected_path(".")
|
||||
|
||||
|
||||
def test_normalize_value_uuid_and_to_json():
|
||||
# Objects may normalize via uuid() or to_json_string().
|
||||
class HasUUID:
|
||||
def uuid(self):
|
||||
return "test-uuid"
|
||||
|
||||
class ToJson:
|
||||
def to_json_string(self):
|
||||
return '{"x":1}'
|
||||
|
||||
assert normalize_value(HasUUID()) == "test-uuid"
|
||||
assert normalize_value(ToJson()) == '{"x":1}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad",
|
||||
[
|
||||
(lambda x: x),
|
||||
(type("CallableInstance", (), {"__call__": lambda self: 0}))(),
|
||||
(lambda: (lambda: 0))(), # nested function instance
|
||||
],
|
||||
)
|
||||
def test_error_cases(bad):
|
||||
"""Inputs expected to raise TypeError."""
|
||||
# Reject functions/lambdas/callable instances
|
||||
# to avoid under-hashing.
|
||||
with pytest.raises(TypeError):
|
||||
normalize_value(bad)
|
||||
|
||||
|
||||
def test_enum_vs_int_disambiguation():
|
||||
# int stays primitive
|
||||
nf_int = normalize_value(1)
|
||||
assert nf_int == 1
|
||||
|
||||
# enum becomes ("module.QualName", value)
|
||||
nf_enum = normalize_value(DummyLogprobsMode.RAW_LOGITS)
|
||||
assert isinstance(nf_enum, tuple) and len(nf_enum) == 2
|
||||
enum_type, enum_val = nf_enum
|
||||
assert enum_type.endswith(".DummyLogprobsMode")
|
||||
assert enum_val == "raw_logits"
|
||||
|
||||
# Build factor dicts from configs with int vs enum
|
||||
f_int = get_hash_factors(SimpleConfig(1), set())
|
||||
f_enum = get_hash_factors(SimpleConfig(DummyLogprobsMode.RAW_LOGITS), set())
|
||||
# The int case remains a primitive value
|
||||
assert f_int["a"] == 1
|
||||
# The enum case becomes a tagged tuple ("module.QualName", "raw_logits")
|
||||
assert isinstance(f_enum["a"], tuple) and f_enum["a"][1] == "raw_logits"
|
||||
# Factor dicts must differ so we don't collide primitives with Enums.
|
||||
assert f_int != f_enum
|
||||
# Hash digests must differ correspondingly
|
||||
assert hash_factors(f_int) != hash_factors(f_enum)
|
||||
|
||||
# Hash functions produce stable hex strings
|
||||
h_int = hash_factors(f_int)
|
||||
h_enum = hash_factors(f_enum)
|
||||
assert isinstance(h_int, str) and len(h_int) == 64
|
||||
assert isinstance(h_enum, str) and len(h_enum) == 64
|
||||
|
||||
|
||||
def test_classes_are_types():
|
||||
"""Types normalize to FQNs; include real vLLM types."""
|
||||
# Only classes allowed; functions/lambdas are rejected.
|
||||
# Canonical form is the fully-qualified name.
|
||||
assert isinstance(normalize_value(str), str)
|
||||
|
||||
class LocalDummy:
|
||||
pass
|
||||
|
||||
assert endswith_fqname(LocalDummy, ".LocalDummy")
|
||||
6
tests/config/test_config_with_model.yaml
Normal file
6
tests/config/test_config_with_model.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
# Same as test_config.yaml but with model specified
|
||||
model: config-model
|
||||
port: 12312
|
||||
served_model_name: mymodel
|
||||
tensor_parallel_size: 2
|
||||
trust_remote_code: true
|
||||
53
tests/config/test_mp_reducer.py
Normal file
53
tests/config/test_mp_reducer.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
|
||||
def test_mp_reducer():
|
||||
"""
|
||||
Test that _reduce_config reducer is registered when AsyncLLM is instantiated
|
||||
without transformers_modules. This is a regression test for
|
||||
https://github.com/vllm-project/vllm/pull/18640.
|
||||
"""
|
||||
|
||||
# Ensure transformers_modules is not in sys.modules
|
||||
if "transformers_modules" in sys.modules:
|
||||
del sys.modules["transformers_modules"]
|
||||
|
||||
with patch("multiprocessing.reducer.register") as mock_register:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model="facebook/opt-125m",
|
||||
max_model_len=32,
|
||||
gpu_memory_utilization=0.1,
|
||||
disable_log_stats=True,
|
||||
)
|
||||
|
||||
async_llm = AsyncLLM.from_engine_args(
|
||||
engine_args,
|
||||
start_engine_loop=False,
|
||||
)
|
||||
|
||||
assert mock_register.called, (
|
||||
"multiprocessing.reducer.register should have been called"
|
||||
)
|
||||
|
||||
vllm_config_registered = False
|
||||
for call_args in mock_register.call_args_list:
|
||||
# Verify that a reducer for VllmConfig was registered
|
||||
if len(call_args[0]) >= 2 and call_args[0][0] == VllmConfig:
|
||||
vllm_config_registered = True
|
||||
|
||||
reducer_func = call_args[0][1]
|
||||
assert callable(reducer_func), "Reducer function should be callable"
|
||||
break
|
||||
|
||||
assert vllm_config_registered, (
|
||||
"VllmConfig should have been registered to multiprocessing.reducer"
|
||||
)
|
||||
|
||||
async_llm.shutdown()
|
||||
25
tests/config/test_multimodal_config.py
Normal file
25
tests/config/test_multimodal_config.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
|
||||
|
||||
def test_mm_encoder_attn_backend_str_conversion():
|
||||
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
|
||||
assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
|
||||
def test_mm_encoder_attn_backend_invalid():
|
||||
with pytest.raises(ValueError):
|
||||
MultiModalConfig(mm_encoder_attn_backend="not_a_backend")
|
||||
|
||||
|
||||
def test_mm_encoder_attn_backend_hash_updates():
|
||||
base_hash = MultiModalConfig().compute_hash()
|
||||
overridden_hash = MultiModalConfig(
|
||||
mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
|
||||
).compute_hash()
|
||||
assert base_hash != overridden_hash
|
||||
1640
tests/conftest.py
1640
tests/conftest.py
File diff suppressed because it is too large
Load Diff
@@ -1,12 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def should_do_global_cleanup_after_test() -> bool:
|
||||
"""Disable the global cleanup fixture for tests in this directory. This
|
||||
provides a ~10x speedup for unit tests that don't load a model to GPU.
|
||||
|
||||
This requires that tests in this directory clean up after themselves if they
|
||||
use the GPU.
|
||||
"""
|
||||
return False
|
||||
@@ -1,41 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from tests.conftest import cleanup
|
||||
from vllm import LLM
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, seed):
|
||||
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, seed)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
test_llm_kwargs, seed):
|
||||
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
test_llm_kwargs, seed)
|
||||
|
||||
|
||||
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
distinct_llm_kwargs, seed):
|
||||
kwargs = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**distinct_llm_kwargs,
|
||||
}
|
||||
|
||||
def generator_inner():
|
||||
llm = LLM(**kwargs)
|
||||
|
||||
set_random_seed(seed)
|
||||
|
||||
yield llm
|
||||
del llm
|
||||
cleanup()
|
||||
|
||||
for llm in generator_inner():
|
||||
yield llm
|
||||
del llm
|
||||
@@ -1,455 +0,0 @@
|
||||
from itertools import cycle
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
"model": "facebook/opt-125m",
|
||||
|
||||
# skip cuda graph creation for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Allow only 5 sequences of ~1024 tokens in worst case.
|
||||
"block_size": 16,
|
||||
"num_gpu_blocks_override": 5 * (64 + 1),
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
"use_v2_block_manager": False
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
|
||||
@pytest.mark.parametrize("batch_size", [10])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
|
||||
test_llm_generator, batch_size):
|
||||
"""Verify block manager v2 produces same outputs as block manager v1, even
|
||||
when there is preemption.
|
||||
|
||||
This constructs two LLM, each with limited number of GPU blocks. The limit
|
||||
is decided such that as the sequences in the batch grow, sequences must be
|
||||
preempted and removed from cache.
|
||||
|
||||
If the output token ids are equivalent, then we have confidence that the KV
|
||||
cache is not corrupted in the v2 block manager.
|
||||
|
||||
NOTE: We want a significant number of generated tokens so that any incorrect
|
||||
KV mapping has time to build up error.
|
||||
"""
|
||||
output_len = 1024
|
||||
temperature = 0.0
|
||||
|
||||
# We want to ensure equality even with preemption.
|
||||
# We force the total block size to be 1 + cdiv(output_len, block_size)
|
||||
# so that only one sequence can fit at a time (once the sequences grow).
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids from block manager v1')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids from block manager v2')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
|
||||
test_token_ids):
|
||||
assert expected_token_ids == actual_token_ids
|
||||
|
||||
assert baseline_token_ids == test_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
"model": "facebook/opt-125m",
|
||||
|
||||
# skip cuda graph creation for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Use a large block size to trigger more copy-on-writes.
|
||||
"block_size": 32,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
"use_v2_block_manager": False
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
|
||||
@pytest.mark.parametrize("batch_size", [10])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
|
||||
test_llm_generator, batch_size):
|
||||
"""Verify beam search equality with block manager v1 and v2.
|
||||
|
||||
This requires copy-on-writes; if the v1 and v2 output is the same, then
|
||||
we have some confidence cow is working.
|
||||
"""
|
||||
output_len = 128
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
use_beam_search=True,
|
||||
best_of=2,
|
||||
)
|
||||
|
||||
print('Getting token ids from block manager v1')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids from block manager v2')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
|
||||
test_token_ids):
|
||||
assert expected_token_ids == actual_token_ids
|
||||
|
||||
assert baseline_token_ids == test_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
"model": "facebook/opt-125m",
|
||||
|
||||
# Our prompts will generate 128 tokens; since the prompts themselves are
|
||||
# small, we don't need much KV space beyond 128.
|
||||
"max_model_len": 160,
|
||||
|
||||
# skip cuda graph creation for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Lookahead scheduling only supported in v2 block manager.
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"block_size": 16,
|
||||
|
||||
# Allow only 2 sequences of ~128 tokens in worst case.
|
||||
# Note 8 = 128/block_size
|
||||
"num_gpu_blocks_override": 2 * (8 + 1),
|
||||
},
|
||||
{
|
||||
"block_size": 8,
|
||||
|
||||
# Allow only 2 sequences of ~128 tokens in worst case.
|
||||
# Note 16 = 128/block_size
|
||||
"num_gpu_blocks_override": 2 * (16 + 1),
|
||||
}
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
"num_lookahead_slots": 0,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[{
|
||||
# We run one test with block_size < lookahead_slots, one test with
|
||||
# block_size > lookahead_slots
|
||||
"num_lookahead_slots": 10,
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size):
|
||||
"""Verify vLLM produces the same output with greedy sampling, when lookahead
|
||||
scheduling is used vs. not.
|
||||
|
||||
Lookahead scheduling is not expected to modify the output, as it simply
|
||||
allocates empty slots ahead of the known token ids in a sliding fashion.
|
||||
|
||||
This test constrains the total number of blocks to force preemption. It also
|
||||
varies the block size so that the lookahead size is less than and greater
|
||||
than the block size.
|
||||
"""
|
||||
output_len = 128
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids without lookahead scheduling')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids with lookahead scheduling')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
|
||||
test_token_ids):
|
||||
assert expected_token_ids == actual_token_ids
|
||||
|
||||
assert baseline_token_ids == test_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
# Use a small model for a fast test.
|
||||
"model": "facebook/opt-125m",
|
||||
|
||||
# skip cuda graph creation for fast test.
|
||||
"enforce_eager": True,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 2,
|
||||
"max_num_seqs": 2,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [
|
||||
{
|
||||
"use_v2_block_manager": False,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"use_v2_block_manager": True,
|
||||
"num_lookahead_slots": 0,
|
||||
},
|
||||
{
|
||||
"use_v2_block_manager": True,
|
||||
"num_lookahead_slots": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
|
||||
test_llm_generator, batch_size):
|
||||
"""Verify that chunked prefill works with BlockManagerV2, with and without
|
||||
lookahead scheduling.
|
||||
"""
|
||||
output_len = 32
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids with BlockManagerV1')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids with BlockManagerV2')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
|
||||
test_token_ids):
|
||||
assert expected_token_ids == actual_token_ids
|
||||
|
||||
assert baseline_token_ids == test_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
"model": "facebook/opt-125m",
|
||||
|
||||
# skip cuda graph creation for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Allow only 5 sequences of ~1024 tokens in worst case.
|
||||
"block_size": 16,
|
||||
"num_gpu_blocks_override": 5 * (64 + 1),
|
||||
|
||||
# Enable prefill cache
|
||||
"enable_prefix_caching": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
"use_v2_block_manager": False
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
|
||||
@pytest.mark.parametrize("batch_size", [10])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
|
||||
baseline_llm_generator, test_llm_generator, batch_size):
|
||||
"""Verify block manager v2 produces same outputs as block manager v1, even
|
||||
when there is preemption.
|
||||
|
||||
This constructs two LLM, each with limited number of GPU blocks. The limit
|
||||
is decided such that as the sequences in the batch grow, sequences must be
|
||||
preempted and removed from cache.
|
||||
|
||||
If the output token ids are equivalent, then we have confidence that the KV
|
||||
cache is not corrupted in the v2 block manager.
|
||||
|
||||
NOTE: We want a significant number of generated tokens so that any incorrect
|
||||
KV mapping has time to build up error.
|
||||
"""
|
||||
output_len = 1024
|
||||
temperature = 0.0
|
||||
|
||||
# We want to ensure equality even with preemption.
|
||||
# We force the total block size to be 1 + cdiv(output_len, block_size)
|
||||
# so that only one sequence can fit at a time (once the sequences grow).
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids from block manager v1')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids from block manager v2')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
|
||||
test_token_ids):
|
||||
assert expected_token_ids == actual_token_ids
|
||||
|
||||
assert baseline_token_ids == test_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
"model": "facebook/opt-125m",
|
||||
|
||||
# skip cuda graph creation for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Allow only 5 sequences of ~1024 tokens in worst case.
|
||||
"block_size": 16,
|
||||
"num_gpu_blocks_override": 5 * (64 + 1),
|
||||
|
||||
# Test APC in v2 block
|
||||
"use_v2_block_manager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
"enable_prefix_caching": False
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}])
|
||||
@pytest.mark.parametrize("batch_size", [10])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
|
||||
test_llm_generator, batch_size):
|
||||
"""Verify block manager v2 with auto prefix caching enabled produces same
|
||||
outputs as auto prefix caching disabled, even when there is preemption.
|
||||
|
||||
This constructs two LLM, each with limited number of GPU blocks. The limit
|
||||
is decided such that as the sequences in the batch grow, sequences must be
|
||||
preempted and removed from cache.
|
||||
|
||||
If the output token ids are equivalent, then we have confidence that auto
|
||||
prefix caching itself at least don't cause result error.
|
||||
"""
|
||||
output_len = 1024
|
||||
temperature = 0.0
|
||||
|
||||
# We want to ensure equality even with preemption.
|
||||
# We force the total block size to be 1 + cdiv(output_len, block_size)
|
||||
# so that only one sequence can fit at a time (once the sequences grow).
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
print('Getting token ids with APC disabled')
|
||||
baseline_token_ids = get_token_ids_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
print('Getting token ids with APC enabled')
|
||||
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
|
||||
test_token_ids):
|
||||
assert expected_token_ids == actual_token_ids
|
||||
|
||||
assert baseline_token_ids == test_token_ids
|
||||
|
||||
|
||||
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
|
||||
for llm in llm_generator:
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
token_ids = [output.outputs[0].token_ids for output in outputs]
|
||||
del llm
|
||||
|
||||
return token_ids
|
||||
@@ -1,103 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.sequence import Logprob, SequenceStatus
|
||||
from vllm.utils import chunk_list
|
||||
|
||||
from ..utils import create_seq_group
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
|
||||
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
|
||||
@pytest.mark.parametrize("watermark", [0.0, 0.5])
|
||||
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
|
||||
num_gpu_blocks: int, watermark: float):
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
watermark=watermark,
|
||||
)
|
||||
num_watermark_blocks = int(watermark * num_gpu_blocks)
|
||||
|
||||
num_output_blocks_per_seq = 1
|
||||
|
||||
# NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
|
||||
# the current implementation assumes all seqs are new prompts / don't have
|
||||
# different output lens.
|
||||
num_output_blocks = num_output_blocks_per_seq
|
||||
|
||||
for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=block_size * num_prompt_blocks,
|
||||
seq_output_lens=[
|
||||
block_size * num_output_blocks_per_seq
|
||||
for _ in range(num_seqs_per_group)
|
||||
],
|
||||
)
|
||||
|
||||
assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
|
||||
|
||||
can_allocate_result = block_manager.can_allocate(seq_group)
|
||||
|
||||
num_required_blocks = num_prompt_blocks + num_output_blocks
|
||||
|
||||
if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
|
||||
assert can_allocate_result == AllocStatus.NEVER
|
||||
elif num_gpu_blocks >= num_required_blocks:
|
||||
assert can_allocate_result == AllocStatus.OK
|
||||
else:
|
||||
assert can_allocate_result == AllocStatus.LATER
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
|
||||
@pytest.mark.parametrize("prompt_len", [1, 7, 8])
|
||||
@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])
|
||||
@pytest.mark.parametrize("num_lookahead_slots", [0, 10])
|
||||
def test_append_slots(block_size, prompt_len, num_slots_to_append,
|
||||
num_lookahead_slots):
|
||||
"""Verify append_slots consumes the correct number of blocks from the block
|
||||
table.
|
||||
"""
|
||||
|
||||
num_gpu_blocks = 1024
|
||||
watermark = 0.1
|
||||
block_manager = BlockSpaceManagerV2(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=0,
|
||||
watermark=watermark,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=prompt_len,
|
||||
seq_output_lens=[0],
|
||||
)
|
||||
|
||||
# Allocate seq
|
||||
assert block_manager.can_allocate(seq_group)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Seq seq to RUNNING
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
# Append tokens to the sequeqnce
|
||||
for token_id in range(num_slots_to_append):
|
||||
seq.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
|
||||
# Append slots for new tokens and lookahead slots.
|
||||
free_blocks_before_append = block_manager.get_num_free_gpu_blocks()
|
||||
block_manager.append_slots(seq, num_lookahead_slots)
|
||||
num_consumed_blocks = (free_blocks_before_append -
|
||||
block_manager.get_num_free_gpu_blocks())
|
||||
|
||||
# Expect consumed blocks to be new blocks required to support the new slots.
|
||||
expected_consumed_blocks = len(
|
||||
chunk_list(
|
||||
list(
|
||||
range(prompt_len + num_slots_to_append + num_lookahead_slots)),
|
||||
block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
|
||||
assert num_consumed_blocks == expected_consumed_blocks
|
||||
@@ -1,575 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.utils import Device, cdiv, chunk_list
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
def test_allocate_naive(block_size: int, sequence_len: int):
|
||||
"""Test the allocation of blocks using the naive allocator.
|
||||
|
||||
This test creates a CpuGpuBlockAllocator with the specified block size and
|
||||
number of blocks. It then allocates multiple BlockTables with varying
|
||||
sequence lengths and verifies that the number of free blocks decreases as
|
||||
expected after each allocation.
|
||||
"""
|
||||
assert block_size > 1
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type="naive",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
|
||||
|
||||
block_tables = []
|
||||
for i in range(5):
|
||||
assert allocator.get_num_free_blocks(
|
||||
device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc
|
||||
|
||||
block_tables.append(
|
||||
BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
))
|
||||
block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
def test_allocate_prefix_caching(block_size: int, sequence_len: int):
|
||||
"""Test the allocation of blocks using the prefix caching allocator.
|
||||
|
||||
This test creates a CpuGpuBlockAllocator with the specified block size and
|
||||
number of blocks, using the prefix caching allocator. It then allocates
|
||||
multiple BlockTables with varying sequence lengths and verifies that the
|
||||
number of free blocks decreases as expected after each allocation.
|
||||
|
||||
The test expects all sequences to share allocations, except for their last
|
||||
block, which may be mutable. It calculates the expected number of immutable
|
||||
and mutable blocks per allocation based on the sequence length and block
|
||||
size.
|
||||
"""
|
||||
assert block_size > 1
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type="prefix_caching",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
chunked_tokens = list(chunk_list(token_ids, block_size))
|
||||
num_mutable_blocks_per_alloc = 0 if len(
|
||||
chunked_tokens[-1]) == block_size else 1
|
||||
num_immutable_blocks_per_alloc = len(
|
||||
chunked_tokens) - num_mutable_blocks_per_alloc
|
||||
|
||||
block_tables = []
|
||||
for alloc_i in range(1, 6):
|
||||
|
||||
block_tables.append(
|
||||
BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
))
|
||||
block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU)
|
||||
|
||||
# Expect all sequences to share allocations, except for their last block
|
||||
# (which may be mutable).
|
||||
assert allocator.get_num_free_blocks(
|
||||
device=Device.GPU) == num_gpu_blocks - (
|
||||
num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc *
|
||||
(alloc_i))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
@pytest.mark.parametrize("device", ["cpu", "gpu"])
|
||||
def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str,
|
||||
device: str):
|
||||
"""Test the allocation and freeing of blocks using different allocators and
|
||||
devices.
|
||||
|
||||
This test creates a CpuGpuBlockAllocator with the specified block size,
|
||||
number of blocks, allocator type, and device. It then allocates a BlockTable
|
||||
multiple times with the same sequence and verifies that the number of free
|
||||
blocks remains consistent after each allocation and freeing.
|
||||
"""
|
||||
device = Device[device.upper()]
|
||||
|
||||
num_device_blocks = 1024
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_device_blocks,
|
||||
num_cpu_blocks=num_device_blocks,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
|
||||
|
||||
block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
block_table.allocate(token_ids=token_ids, device=device)
|
||||
assert allocator.get_num_free_blocks(
|
||||
device) == num_device_blocks - num_blocks_per_alloc
|
||||
assert all(block_id is not None
|
||||
for block_id in block_table.physical_block_ids)
|
||||
|
||||
block_table.free()
|
||||
assert allocator.get_num_free_blocks(device) == num_device_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("append_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_append_token_ids_allocation(block_size: int, sequence_len: int,
|
||||
append_len: int, allocator_type: str):
|
||||
"""Test the allocation behavior when appending token IDs to a BlockTable.
|
||||
|
||||
This test creates a CpuGpuBlockAllocator with the specified block size,
|
||||
number of blocks, and allocator type. It then allocates a BlockTable with an
|
||||
initial sequence and appends additional token IDs to it. The test verifies
|
||||
that the number of allocated blocks before and after appending matches the
|
||||
expected values.
|
||||
"""
|
||||
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
token_ids_to_append = list(range(append_len))
|
||||
|
||||
block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
|
||||
num_expected_blocks_before_append = len(
|
||||
list(chunk_list(token_ids, block_size)))
|
||||
num_expected_appended_blocks = len(
|
||||
list(chunk_list(token_ids + token_ids_to_append,
|
||||
block_size))) - num_expected_blocks_before_append
|
||||
|
||||
block_table.allocate(token_ids=token_ids, device=Device.GPU)
|
||||
|
||||
assert len(
|
||||
block_table.physical_block_ids) == num_expected_blocks_before_append
|
||||
block_table.append_token_ids(token_ids_to_append)
|
||||
assert len(
|
||||
block_table.physical_block_ids
|
||||
) == num_expected_blocks_before_append + num_expected_appended_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("num_empty_slots", [1, 16, 129])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int,
|
||||
num_empty_slots: int,
|
||||
allocator_type: str):
|
||||
"""Test the allocation behavior when ensuring a certain number of empty
|
||||
slots in a BlockTable.
|
||||
|
||||
This test creates a CpuGpuBlockAllocator with the specified block size,
|
||||
number of blocks, and allocator type. It then allocates a BlockTable with an
|
||||
initial sequence and ensures a certain number of empty slots. The test
|
||||
verifies that the number of allocated blocks before and after ensuring empty
|
||||
slots matches the expected values. It also checks that filling up the empty
|
||||
slots does not consume additional blocks.
|
||||
"""
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
|
||||
block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
|
||||
num_expected_blocks_before_append = len(
|
||||
list(chunk_list(token_ids, block_size)))
|
||||
num_expected_appended_blocks = len(
|
||||
list(chunk_list(token_ids + [-1] * num_empty_slots,
|
||||
block_size))) - num_expected_blocks_before_append
|
||||
|
||||
block_table.allocate(token_ids=token_ids, device=Device.GPU)
|
||||
|
||||
# Assert that the empty slots consume the expected number of additional
|
||||
# blocks.
|
||||
assert len(
|
||||
block_table.physical_block_ids) == num_expected_blocks_before_append
|
||||
block_table.ensure_num_empty_slots(num_empty_slots)
|
||||
assert len(
|
||||
block_table.physical_block_ids
|
||||
) == num_expected_blocks_before_append + num_expected_appended_blocks
|
||||
|
||||
# Now, ensure no additional blocks consumed as we fill up the empty slots.
|
||||
num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU)
|
||||
block_table.append_token_ids(token_ids=list(range(num_empty_slots)))
|
||||
assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 9])
|
||||
@pytest.mark.parametrize("append_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("append_size", [1, 4, 129])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
|
||||
append_len: int, allocator_type: str,
|
||||
append_size: int):
|
||||
"""Verify token ids are correctly appended. Appends various amounts of
|
||||
token ids in various append sizes, and verifies the final sequence is
|
||||
correct.
|
||||
"""
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=1024,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
token_ids_to_append = list(range(append_len))
|
||||
|
||||
block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
block_table.allocate(token_ids=token_ids, device=Device.GPU)
|
||||
|
||||
appended_so_far = []
|
||||
for append in chunk_list(token_ids_to_append, append_size):
|
||||
block_table.append_token_ids(append)
|
||||
appended_so_far.extend(append)
|
||||
|
||||
assert block_table._get_all_token_ids() == token_ids + appended_so_far
|
||||
|
||||
assert block_table._get_all_token_ids() == token_ids + token_ids_to_append
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_len", [1, 9, 129])
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_fork(seq_len: int, block_size: int, allocator_type: str):
|
||||
"""Create a sequence using the specified allocator.
|
||||
1. Assert that after forking the sequence, the free block count is the
|
||||
same.
|
||||
2. Assert that the forked sequence has the same physical mappings.
|
||||
3. Then free the original sequence; verify that the free block count is
|
||||
the same.
|
||||
4. Finally, free the forked sequence and verify that the free block
|
||||
count drops to zero.
|
||||
"""
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=0,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(seq_len))
|
||||
|
||||
block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
|
||||
block_table.allocate(token_ids)
|
||||
|
||||
num_free_blocks_before_fork = allocator.get_num_free_blocks(
|
||||
device=Device.GPU)
|
||||
|
||||
forked_block_table = block_table.fork()
|
||||
|
||||
# Expect physical_block_ids and token_ids to match.
|
||||
assert (block_table.physical_block_ids ==
|
||||
forked_block_table.physical_block_ids)
|
||||
assert block_table._get_all_token_ids(
|
||||
) == forked_block_table._get_all_token_ids()
|
||||
|
||||
# Do not expect any additional allocations.
|
||||
assert allocator.get_num_free_blocks(
|
||||
device=Device.GPU) == num_free_blocks_before_fork
|
||||
|
||||
# Free the original blocks. Assert num free blocks does not change, since
|
||||
# refcount is nonzero.
|
||||
block_table.free()
|
||||
assert allocator.get_num_free_blocks(
|
||||
device=Device.GPU) == num_free_blocks_before_fork
|
||||
|
||||
# Expect the forked block table to be unaffected by the free.
|
||||
assert all(block_id is not None
|
||||
for block_id in forked_block_table.physical_block_ids)
|
||||
|
||||
# Free the forked blocks. Assert num free blocks does change, since
|
||||
# refcount is now zero.
|
||||
forked_block_table.free()
|
||||
assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [8])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("append_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("appender", ["forked", "original"])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_cow(block_size: int, sequence_len: int, append_len: int,
|
||||
allocator_type: str, appender: str):
|
||||
"""Fork a sequence; append to the forked sequence; verify there's a CoW.
|
||||
"""
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=0,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
token_ids_to_append = list(range(append_len))
|
||||
|
||||
original_block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
|
||||
num_expected_non_cow_blocks = cdiv(sequence_len, block_size)
|
||||
num_expected_cow_blocks = cdiv(sequence_len + append_len,
|
||||
block_size) - (sequence_len // block_size)
|
||||
|
||||
original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
|
||||
original_block_ids = original_block_table.physical_block_ids
|
||||
|
||||
forked_block_table = original_block_table.fork()
|
||||
|
||||
# Expect no additional allocation (copy on _write_).
|
||||
assert allocator.get_num_free_blocks(
|
||||
Device.GPU) == (num_gpu_blocks - num_expected_non_cow_blocks)
|
||||
|
||||
if appender == "forked":
|
||||
appender_block_table = forked_block_table
|
||||
static_block_table = original_block_table
|
||||
elif appender == "original":
|
||||
appender_block_table = original_block_table
|
||||
static_block_table = forked_block_table
|
||||
else:
|
||||
raise ValueError(f"unknown test config {appender=}")
|
||||
|
||||
# Write tokens.
|
||||
appender_block_table.append_token_ids(token_ids_to_append)
|
||||
|
||||
# Expect the non-appending block table to have no change.
|
||||
assert static_block_table.physical_block_ids == original_block_ids
|
||||
assert appender_block_table.physical_block_ids != original_block_ids
|
||||
|
||||
# Expect the blocks changed during append to have a CoW.
|
||||
assert allocator.get_num_free_blocks(
|
||||
Device.GPU) == num_gpu_blocks - (num_expected_non_cow_blocks +
|
||||
num_expected_cow_blocks)
|
||||
|
||||
cows = allocator.clear_copy_on_writes()
|
||||
if sequence_len % block_size > 0:
|
||||
# If the last block in the sequence is not full, then when appending we
|
||||
# expect a CoW.
|
||||
assert cows
|
||||
|
||||
cow_block_id = sequence_len // block_size
|
||||
expected_src = static_block_table.physical_block_ids[cow_block_id]
|
||||
expected_dst = appender_block_table.physical_block_ids[cow_block_id]
|
||||
|
||||
assert expected_src in cows
|
||||
assert expected_dst in cows[expected_src]
|
||||
else:
|
||||
# Otherwise, there should be no copy-on-write.
|
||||
assert not cows
|
||||
|
||||
static_block_table.free()
|
||||
appender_block_table.free()
|
||||
|
||||
# After free, expect all blocks to be freed.
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [8])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("append_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("lookahead_slots", [1, 16, 129])
|
||||
@pytest.mark.parametrize("appender", ["forked", "original"])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_cow_lookahead_simple(block_size: int, sequence_len: int,
|
||||
append_len: int, lookahead_slots: int,
|
||||
allocator_type: str, appender: str):
|
||||
"""Similar to test_cow, except with lookahead allocation. The assertions are
|
||||
less rigorous due to the complexity of the property under test.
|
||||
"""
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=0,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
token_ids_to_append = list(range(append_len))
|
||||
|
||||
original_block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
|
||||
original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
|
||||
|
||||
# Allocate lookahead slots.
|
||||
original_block_table.ensure_num_empty_slots(lookahead_slots)
|
||||
original_block_ids = original_block_table.physical_block_ids
|
||||
|
||||
forked_block_table = original_block_table.fork()
|
||||
|
||||
if appender == "forked":
|
||||
appender_block_table = forked_block_table
|
||||
static_block_table = original_block_table
|
||||
elif appender == "original":
|
||||
appender_block_table = original_block_table
|
||||
static_block_table = forked_block_table
|
||||
else:
|
||||
raise ValueError(f"unknown test config {appender=}")
|
||||
|
||||
# Write tokens.
|
||||
appender_block_table.append_token_ids(token_ids_to_append)
|
||||
|
||||
# Expect the non-appending block table to have no change.
|
||||
assert static_block_table.physical_block_ids == original_block_ids
|
||||
assert appender_block_table.physical_block_ids != original_block_ids
|
||||
|
||||
cows = allocator.clear_copy_on_writes()
|
||||
|
||||
# Always expect copy-on-write
|
||||
assert cows
|
||||
|
||||
if sequence_len % block_size > 0:
|
||||
# If the last block in the sequence is not full, then when appending we
|
||||
# expect a CoW.
|
||||
assert cows
|
||||
|
||||
cow_block_id = sequence_len // block_size
|
||||
expected_src = static_block_table.physical_block_ids[cow_block_id]
|
||||
expected_dst = appender_block_table.physical_block_ids[cow_block_id]
|
||||
|
||||
assert expected_src in cows
|
||||
assert expected_dst in cows[expected_src]
|
||||
|
||||
static_block_table.free()
|
||||
appender_block_table.free()
|
||||
|
||||
# After free, expect all blocks to be freed.
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [1, 8])
|
||||
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
|
||||
@pytest.mark.parametrize("num_new_tokens", [1, 16, 129])
|
||||
@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
|
||||
num_new_tokens: int,
|
||||
num_lookahead_slots: int,
|
||||
allocator_type: str):
|
||||
"""Verify correct calculation of get_num_blocks_touched_by_append_slots.
|
||||
|
||||
This is done by using copy-on-write, which requires any modified block to
|
||||
be copied before write if the refcount > 1. We set the refcount>1 by forking
|
||||
a sequence, then measure the free blocks before and after an append. If the
|
||||
number of consumed blocks equals what `get_num_blocks_touched_by_append_
|
||||
slots` returns, then the calculation is correct.
|
||||
"""
|
||||
|
||||
num_gpu_blocks = 1024
|
||||
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=0,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
token_ids = list(range(sequence_len))
|
||||
token_ids_to_append = list(range(num_new_tokens))
|
||||
|
||||
block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
block_allocator=allocator,
|
||||
)
|
||||
|
||||
block_table.allocate(token_ids=token_ids, device=Device.GPU)
|
||||
|
||||
# Add lookahead before fork so both sequences have the same lookahead
|
||||
# blocks.
|
||||
block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots)
|
||||
|
||||
# Fork sequence so that every block has refcount > 1.
|
||||
_ = block_table.fork()
|
||||
|
||||
# Determine how many blocks should be touched.
|
||||
expected_num_touched_blocks = (
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
token_ids=token_ids_to_append,
|
||||
num_lookahead_slots=num_lookahead_slots))
|
||||
|
||||
# Measure how many blocks are touched by measuring num_free_blocks before
|
||||
# and after the append.
|
||||
#
|
||||
# We expect append_token_ids to CoW all mutated blocks that have refcount>1.
|
||||
num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU)
|
||||
block_table.append_token_ids(token_ids_to_append, num_lookahead_slots)
|
||||
num_consumed_blocks = (num_free_blocks_before_append -
|
||||
allocator.get_num_free_blocks(Device.GPU))
|
||||
|
||||
# TODO(cade) ensure equality when num_lookahead_slots > 0.
|
||||
# The reason we have < is because lookahead blocks are not copied eagerly;
|
||||
# they are copied on first write. This will cause issues for beam search +
|
||||
# speculative decoding. This is acceptable for now as it is a large effort
|
||||
# to combine the two. To fix this, we can ensure single sequence ownership
|
||||
# of lookahead blocks by appending empty slots to each block, which will
|
||||
# trigger the CoW.
|
||||
#
|
||||
# Until then, we can accept that the consumed tokens are <= the expected
|
||||
# tokens when appending with lookahead.
|
||||
if num_lookahead_slots > 0:
|
||||
assert num_consumed_blocks <= expected_num_touched_blocks
|
||||
else:
|
||||
assert num_consumed_blocks == expected_num_touched_blocks
|
||||
@@ -1,42 +0,0 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.common import RefCounter
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
|
||||
@pytest.mark.parametrize("num_incrs", [1, 100])
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
def test_incr(seed: int, num_incrs: int, num_blocks: int):
|
||||
random.seed(seed)
|
||||
|
||||
all_block_indices = list(range(num_blocks))
|
||||
counter = RefCounter(all_block_indices=all_block_indices)
|
||||
|
||||
block_id = random.randint(0, num_blocks - 1)
|
||||
for i in range(num_incrs):
|
||||
value = counter.incr(block_id)
|
||||
assert value == i + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
|
||||
@pytest.mark.parametrize("num_incrs", [1, 100])
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
def test_incr_decr(seed: int, num_incrs: int, num_blocks: int):
|
||||
random.seed(seed)
|
||||
|
||||
all_block_indices = list(range(num_blocks))
|
||||
counter = RefCounter(all_block_indices=all_block_indices)
|
||||
|
||||
block_id = random.randint(0, num_blocks - 1)
|
||||
for i in range(num_incrs):
|
||||
value = counter.incr(block_id)
|
||||
assert value == i + 1
|
||||
|
||||
for i in range(num_incrs):
|
||||
value = counter.decr(block_id)
|
||||
assert value == num_incrs - (i + 1)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
counter.decr(block_id)
|
||||
@@ -1,93 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.utils import Device, chunk_list
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
|
||||
@pytest.mark.parametrize("num_gpu_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
|
||||
block_size: int, allocator_type: str):
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
cpu_blocks = [
|
||||
allocator.allocate_mutable(prev_block=None, device=Device.CPU)
|
||||
for _ in range(num_cpu_blocks)
|
||||
]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == 0
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
gpu_blocks = [
|
||||
allocator.allocate_mutable(prev_block=None, device=Device.GPU)
|
||||
for _ in range(num_gpu_blocks)
|
||||
]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == 0
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == 0
|
||||
|
||||
_ = [allocator.free(block) for block in cpu_blocks]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == 0
|
||||
|
||||
_ = [allocator.free(block) for block in gpu_blocks]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
|
||||
@pytest.mark.parametrize("num_gpu_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [2])
|
||||
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
|
||||
def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
|
||||
block_size: int, allocator_type: str):
|
||||
allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type=allocator_type,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
unique_token_ids = list(
|
||||
range((num_cpu_blocks + num_gpu_blocks) * block_size))
|
||||
gpu_token_ids = chunk_list(unique_token_ids[:num_gpu_blocks * block_size],
|
||||
block_size)
|
||||
cpu_token_ids = chunk_list(unique_token_ids[num_gpu_blocks * block_size:],
|
||||
block_size)
|
||||
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
cpu_blocks = [
|
||||
allocator.allocate_immutable(prev_block=None,
|
||||
token_ids=token_ids,
|
||||
device=Device.CPU)
|
||||
for token_ids in cpu_token_ids
|
||||
]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == 0
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
|
||||
gpu_blocks = [
|
||||
allocator.allocate_immutable(prev_block=None,
|
||||
token_ids=token_ids,
|
||||
device=Device.GPU)
|
||||
for token_ids in gpu_token_ids
|
||||
]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == 0
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == 0
|
||||
|
||||
_ = [allocator.free(block) for block in cpu_blocks]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == 0
|
||||
|
||||
_ = [allocator.free(block) for block in gpu_blocks]
|
||||
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
|
||||
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
|
||||
@@ -1,102 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
|
||||
|
||||
class TestNaiveBlockAllocator:
|
||||
|
||||
@staticmethod
|
||||
def create_allocate_lambda(allocate_type: str,
|
||||
allocator: NaiveBlockAllocator,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int]):
|
||||
if allocate_type == "immutable":
|
||||
allocate_block = lambda: allocator.allocate_immutable(
|
||||
prev_block=prev_block, token_ids=token_ids)
|
||||
elif allocate_type == "mutable":
|
||||
allocate_block = lambda: allocator.allocate_mutable(prev_block=
|
||||
prev_block)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
return allocate_block
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
|
||||
@pytest.mark.parametrize("num_blocks", [1, 1024])
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
def test_allocate_ooms(allocate_type: str, num_blocks: int,
|
||||
block_size: int):
|
||||
allocator = NaiveBlockAllocator(create_block=NaiveBlock,
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
|
||||
allocate_type,
|
||||
allocator,
|
||||
prev_block=None,
|
||||
token_ids=list(range(block_size)))
|
||||
|
||||
[allocate_block() for _ in range(num_blocks)]
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocate_block()
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
|
||||
@pytest.mark.parametrize("num_blocks", [1, 1024])
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
def test_free_prevents_oom(allocate_type: str, num_blocks: int,
|
||||
block_size: int):
|
||||
allocator = NaiveBlockAllocator(create_block=NaiveBlock,
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
|
||||
allocate_type,
|
||||
allocator,
|
||||
prev_block=None,
|
||||
token_ids=list(range(block_size)))
|
||||
|
||||
blocks = [allocate_block() for _ in range(num_blocks)]
|
||||
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocate_block()
|
||||
|
||||
block_to_free = blocks.pop()
|
||||
|
||||
for _ in range(100):
|
||||
block_id = block_to_free.block_id
|
||||
allocator.free(block_to_free)
|
||||
assert block_to_free.block_id is None
|
||||
|
||||
new_block = allocate_block()
|
||||
assert new_block.block_id == block_id
|
||||
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocate_block()
|
||||
|
||||
block_to_free = new_block
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
|
||||
block_size: int):
|
||||
allocator = NaiveBlockAllocator(create_block=NaiveBlock,
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
|
||||
allocate_type,
|
||||
allocator,
|
||||
prev_block=None,
|
||||
token_ids=list(range(block_size)))
|
||||
|
||||
assert allocator.get_num_free_blocks() == num_blocks
|
||||
|
||||
blocks = [allocate_block() for _ in range(num_blocks)]
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
assert allocator.get_num_free_blocks() == i
|
||||
allocator.free(block)
|
||||
@@ -1,509 +0,0 @@
|
||||
import math
|
||||
import random
|
||||
from typing import List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
|
||||
PrefixCachingBlockAllocator)
|
||||
|
||||
|
||||
class TestPrefixCachingBlock:
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
@pytest.mark.parametrize("is_curr_block_full", [True, False])
|
||||
def test_first_block_has_correct_content_hash(seed: int, block_size: int,
|
||||
is_curr_block_full: bool):
|
||||
"""Verify a block which is first in the sequence has the correct hash.
|
||||
"""
|
||||
random.seed(seed)
|
||||
num_to_fill = block_size if is_curr_block_full else random.randint(
|
||||
0, block_size - 1)
|
||||
token_ids = list(range(num_to_fill))
|
||||
mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)
|
||||
|
||||
block_with_prev = PrefixCachingBlock(
|
||||
prev_block=None,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
prefix_caching_allocator=mock_allocator)
|
||||
|
||||
if is_curr_block_full:
|
||||
# Expect hash since block is full.
|
||||
assert block_with_prev.content_hash == (
|
||||
PrefixCachingBlock.hash_block_tokens(
|
||||
is_first_block=True,
|
||||
prev_block_hash=None,
|
||||
cur_block_token_ids=token_ids))
|
||||
else:
|
||||
# Do not expect hash since block is not full.
|
||||
assert block_with_prev.content_hash is None
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
@pytest.mark.parametrize("is_curr_block_full", [True, False])
|
||||
@pytest.mark.parametrize("prev_block_has_hash", [True, False])
|
||||
def test_nth_block_has_correct_content_hash(seed: int, block_size: int,
|
||||
is_curr_block_full: bool,
|
||||
prev_block_has_hash: bool):
|
||||
"""Verify a block which is not first in the sequence has the correct
|
||||
hash.
|
||||
"""
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
previous_block = MagicMock(spec=PrefixCachingBlock)
|
||||
prev_block_hash = random.randint(0, 1000)
|
||||
previous_block.content_hash = (prev_block_hash
|
||||
if prev_block_has_hash else None)
|
||||
|
||||
num_to_fill = block_size if is_curr_block_full else random.randint(
|
||||
0, block_size - 1)
|
||||
token_ids = list(range(num_to_fill))
|
||||
mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)
|
||||
|
||||
block_with_prev = PrefixCachingBlock(
|
||||
prev_block=previous_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
prefix_caching_allocator=mock_allocator,
|
||||
)
|
||||
|
||||
if is_curr_block_full and prev_block_has_hash:
|
||||
# Expect hash since block is full and previous block has hash.
|
||||
assert (block_with_prev.content_hash ==
|
||||
PrefixCachingBlock.hash_block_tokens(
|
||||
is_first_block=False,
|
||||
prev_block_hash=prev_block_hash,
|
||||
cur_block_token_ids=token_ids))
|
||||
else:
|
||||
# Do not expect hash since block is not full or the previous block
|
||||
# does not have a hash.
|
||||
assert block_with_prev.content_hash is None
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("block_size", [1, 2, 16])
|
||||
@pytest.mark.parametrize("num_tokens", list(range(3)))
|
||||
@pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10])
|
||||
def test_blocks_have_correct_hash_in_chain(block_size: int,
|
||||
num_tokens: int,
|
||||
num_empty_trailing_blocks: int):
|
||||
"""Create two chains of logical blocks with the same contents.
|
||||
Assert the hashes are equal.
|
||||
"""
|
||||
random.seed(0)
|
||||
|
||||
token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)]
|
||||
|
||||
first_chain, second_chain = [
|
||||
TestPrefixCachingBlock.create_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
num_empty_trailing_blocks=num_empty_trailing_blocks)
|
||||
for _ in range(2)
|
||||
]
|
||||
|
||||
for first_chain_block, second_chain_block in zip(
|
||||
first_chain, second_chain):
|
||||
assert (first_chain_block.content_hash ==
|
||||
second_chain_block.content_hash)
|
||||
|
||||
if not first_chain or not second_chain:
|
||||
assert first_chain == second_chain
|
||||
assert num_tokens == 0
|
||||
|
||||
@staticmethod
|
||||
def create_chain(block_size: int,
|
||||
token_ids: List[int],
|
||||
num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]:
|
||||
"""Helper method which creates a chain of blocks.
|
||||
"""
|
||||
blocks = []
|
||||
num_blocks = math.ceil(
|
||||
len(token_ids) / block_size) + num_empty_trailing_blocks
|
||||
|
||||
if num_blocks == 0:
|
||||
return []
|
||||
|
||||
allocator = MagicMock(spec=PrefixCachingBlockAllocator)
|
||||
|
||||
prev_block = None
|
||||
for block_number in range(0, num_blocks):
|
||||
prev_block = PrefixCachingBlock(
|
||||
prev_block=prev_block,
|
||||
token_ids=[],
|
||||
block_size=block_size,
|
||||
prefix_caching_allocator=allocator,
|
||||
)
|
||||
|
||||
tokens_to_append = token_ids[block_number *
|
||||
block_size:(block_number + 1) *
|
||||
block_size]
|
||||
if tokens_to_append:
|
||||
prev_block.append_token_ids(tokens_to_append)
|
||||
|
||||
blocks.append(prev_block)
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
class TestPrefixCachingBlockAllocator:
|
||||
|
||||
@staticmethod
|
||||
def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int]):
|
||||
if allocate_type == "immutable":
|
||||
allocate_block = lambda: allocator.allocate_immutable(
|
||||
prev_block=prev_block, token_ids=token_ids)
|
||||
elif allocate_type == "mutable":
|
||||
allocate_block = lambda: allocator.allocate_mutable(prev_block=
|
||||
prev_block)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
return allocate_block
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1, 1024])
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
def test_allocate_mutable_ooms(num_blocks: int, block_size: int):
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
|
||||
allocate_type="mutable",
|
||||
allocator=allocator,
|
||||
prev_block=None,
|
||||
token_ids=list(range(block_size)),
|
||||
)
|
||||
|
||||
[allocate_block() for _ in range(num_blocks)]
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocate_block()
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1, 1024])
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
def test_allocate_immutable_does_not_oom_single_hash(
|
||||
num_blocks: int, block_size: int):
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
|
||||
allocate_type="immutable",
|
||||
allocator=allocator,
|
||||
prev_block=None,
|
||||
token_ids=list(range(block_size)),
|
||||
)
|
||||
|
||||
blocks = [allocate_block() for _ in range(num_blocks)]
|
||||
|
||||
# Expect no OOM. If these were mutable blocks, this would OOM.
|
||||
non_oom_block = allocate_block()
|
||||
|
||||
# Expect all blocks to have same physical block index.
|
||||
for block in blocks:
|
||||
assert (block.block_id == non_oom_block.block_id)
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1, 1024])
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
def test_allocate_immutable_ooms_many_hash(num_blocks: int,
|
||||
block_size: int):
|
||||
"""Consume all blocks using many different hashes/block content.
|
||||
|
||||
Do this by creating a sequence that is very long.
|
||||
Expect next block to OOM.
|
||||
"""
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
|
||||
# Create token ids that will exhaust all blocks.
|
||||
token_ids = list(range(num_blocks * block_size))
|
||||
|
||||
chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# Expect allocation with unseen hash to fail.
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocator.allocate_immutable(prev_block=chain[-1],
|
||||
token_ids=list(range(block_size)))
|
||||
|
||||
# Expect mutable allocation to fail.
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocator.allocate_mutable(prev_block=chain[-1])
|
||||
|
||||
# Expect allocation of exact same chain to pass.
|
||||
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# Expect physical block indices to be the same in both chains.
|
||||
assert chain and second_chain
|
||||
for first_chain_block, second_chain_block in zip(chain, second_chain):
|
||||
assert (first_chain_block.block_id == second_chain_block.block_id)
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1, 1024])
|
||||
@pytest.mark.parametrize("block_size", [1, 16])
|
||||
def test_free_prevents_oom(num_blocks: int, block_size: int):
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
|
||||
# Create token ids that will exhaust all blocks.
|
||||
token_ids = list(range(num_blocks * block_size))
|
||||
|
||||
chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# Expect mutable allocation to fail.
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocator.allocate_mutable(prev_block=None)
|
||||
|
||||
block_to_free = chain[-1]
|
||||
|
||||
# Expect free/allocate loop to succeed many times.
|
||||
for i in range(100):
|
||||
block_id = block_to_free.block_id
|
||||
allocator.free(block_to_free)
|
||||
assert block_to_free.block_id is None, i
|
||||
|
||||
new_block = allocator.allocate_mutable(prev_block=None)
|
||||
assert new_block.block_id == block_id, i
|
||||
|
||||
with pytest.raises(BlockAllocator.NoFreeBlocksError):
|
||||
allocator.allocate_mutable(prev_block=None)
|
||||
|
||||
block_to_free = new_block
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
|
||||
def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int):
|
||||
random.seed(seed)
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
num_blocks_to_consume = random.randint(1, num_blocks - 1)
|
||||
|
||||
# Create token ids that will exhaust all blocks.
|
||||
token_ids = list(range(num_blocks_to_consume * block_size))
|
||||
|
||||
chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# Free each block in chain, assert num free blocks includes new free
|
||||
# block.
|
||||
for i, block in enumerate(chain):
|
||||
assert allocator.get_num_free_blocks() == (num_blocks -
|
||||
num_blocks_to_consume +
|
||||
i)
|
||||
allocator.free(block)
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
|
||||
def test_get_num_free_blocks_shared(num_blocks: int, block_size: int,
|
||||
seed: int):
|
||||
"""Verify sharing occurs by allocating two sequences that share prefixes
|
||||
and incrementally freeing blocks.
|
||||
"""
|
||||
random.seed(seed)
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
num_blocks_to_consume = random.randint(1, num_blocks - 1)
|
||||
|
||||
# Create token ids that will exhaust all blocks.
|
||||
token_ids = list(range(num_blocks_to_consume * block_size))
|
||||
|
||||
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# Free each block in the first chain. Since all blocks are shared, the
|
||||
# free count should stay constant.
|
||||
for i, block in enumerate(first_chain):
|
||||
assert allocator.get_num_free_blocks() == (num_blocks -
|
||||
num_blocks_to_consume)
|
||||
allocator.free(block)
|
||||
|
||||
# Free each block in the second chain. Since the refcount is now zero,
|
||||
# the free count should increment with each free.
|
||||
for i, block in enumerate(second_chain):
|
||||
assert allocator.get_num_free_blocks() == (num_blocks -
|
||||
num_blocks_to_consume +
|
||||
i)
|
||||
allocator.free(block)
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
|
||||
def test_get_common_computed_block_ids(num_blocks: int, block_size: int,
|
||||
seed: int):
|
||||
"""Verify get_common_computed_block_ids could get correct result
|
||||
by create two immutable chain sharing prefix at specified pos,
|
||||
and compare whether we also could get right result
|
||||
from get_common_computed_block_ids.
|
||||
"""
|
||||
random.seed(seed)
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2,
|
||||
block_size=block_size)
|
||||
num_blocks_to_consume = random.randint(1, num_blocks - 1)
|
||||
|
||||
# Create token ids that will exhaust all blocks.
|
||||
token_ids = list(range(num_blocks_to_consume * block_size))
|
||||
blocks = list(range(num_blocks_to_consume))
|
||||
|
||||
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# mark all blocks in first chain as computed
|
||||
allocator.mark_blocks_as_computed(blocks)
|
||||
|
||||
# After zero_point, second_chain's token_ids would be set -1, which
|
||||
# make it different from here comparing with first_chain
|
||||
zero_point = random.randint(1, len(token_ids) - 1)
|
||||
zero_point_blocks = zero_point // block_size
|
||||
token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point)
|
||||
|
||||
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
first_computed_ids = [
|
||||
first_chain[i].block_id for i in range(num_blocks_to_consume)
|
||||
]
|
||||
second_computed_ids = [
|
||||
second_chain[i].block_id for i in range(num_blocks_to_consume)
|
||||
]
|
||||
res = allocator.get_common_computed_block_ids(
|
||||
[first_computed_ids, second_computed_ids])
|
||||
|
||||
assert (len(res) == zero_point_blocks)
|
||||
|
||||
# Test case where two last accessed times are equal
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("seed", list(range(20)))
|
||||
def test_eviction_order(num_blocks: int, block_size: int, seed: int):
|
||||
"""This test case simulate the two chain created and free in order,
|
||||
and together they would exhaust the initial freed blocks.
|
||||
|
||||
So the next block created after those two chain shall use the block
|
||||
from the first chain as that block has long access time.
|
||||
While first chain has two blocks, it shall pick up the last one, as
|
||||
it has larger token number.
|
||||
"""
|
||||
|
||||
random.seed(seed)
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
num_blocks_to_consume = num_blocks + 1
|
||||
|
||||
token_ids = list(range(num_blocks_to_consume * block_size))
|
||||
|
||||
num_blocks_in_first_chain = 2
|
||||
num_tokens_in_first_chain = block_size * num_blocks_in_first_chain
|
||||
# First chain takes the first block
|
||||
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids[:num_tokens_in_first_chain],
|
||||
allocator=allocator,
|
||||
)
|
||||
# There should only be one block allocated at this point
|
||||
assert allocator.get_num_free_blocks() == (num_blocks -
|
||||
num_blocks_in_first_chain)
|
||||
|
||||
# Set the last accessed time of the first block to 1
|
||||
blocks_ids = [block.block_id for block in first_chain]
|
||||
allocator.mark_blocks_as_accessed(blocks_ids, 1)
|
||||
|
||||
# Second chain takes the rest of the blocks
|
||||
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids[num_tokens_in_first_chain:-block_size],
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# There shouldn't be any blocks left at this point
|
||||
assert allocator.get_num_free_blocks() == (0)
|
||||
|
||||
assert len(first_chain) == num_blocks_in_first_chain
|
||||
last_block_id = first_chain[-1].block_id
|
||||
# Free each block in the first chain.
|
||||
for i, block in enumerate(first_chain):
|
||||
allocator.free(block)
|
||||
|
||||
# Set the last accessed time on all of the blocks in the second chain
|
||||
# to 2
|
||||
blocks_ids = [block.block_id for block in second_chain]
|
||||
allocator.mark_blocks_as_accessed(blocks_ids, 2)
|
||||
|
||||
# Free each block in the second chain.
|
||||
for i, block in enumerate(second_chain):
|
||||
allocator.free(block)
|
||||
|
||||
# Allocate a new block and check that it's the least recently used block
|
||||
# from the first chain.
|
||||
new_block = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids[-block_size:],
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
assert new_block[0].block_id == last_block_id
|
||||
|
||||
@staticmethod
|
||||
def create_immutable_chain(
|
||||
block_size: int,
|
||||
token_ids: List[int],
|
||||
allocator: PrefixCachingBlockAllocator,
|
||||
) -> List[PrefixCachingBlock]:
|
||||
"""Helper method which creates a chain of blocks.
|
||||
"""
|
||||
blocks = []
|
||||
num_blocks = math.ceil(len(token_ids) / block_size)
|
||||
|
||||
if num_blocks == 0:
|
||||
return []
|
||||
|
||||
prev_block = None
|
||||
for block_number in range(0, num_blocks):
|
||||
block_token_ids = token_ids[block_number *
|
||||
block_size:(block_number + 1) *
|
||||
block_size]
|
||||
prev_block = allocator.allocate_immutable(
|
||||
prev_block=prev_block, token_ids=block_token_ids)
|
||||
blocks.append(prev_block)
|
||||
|
||||
return blocks
|
||||
@@ -1,367 +0,0 @@
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
|
||||
UncachedBlockAllocator)
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
|
||||
def test_block_allocator_allocate():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
|
||||
# Allocate all available cpu blocks.
|
||||
num_free = num_cpu_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
for _ in range(num_cpu_blocks):
|
||||
block = cpu_allocator.allocate()
|
||||
num_free -= 1
|
||||
|
||||
assert block not in cpu_allocator.free_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cpu_allocator.allocate()
|
||||
|
||||
|
||||
def test_block_allocator_free():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
|
||||
# Allocate all available cpu blocks.
|
||||
blocks: List[PhysicalTokenBlock] = []
|
||||
for _ in range(num_cpu_blocks):
|
||||
block = cpu_allocator.allocate()
|
||||
blocks.append(block)
|
||||
assert block not in cpu_allocator.free_blocks
|
||||
|
||||
# Free all allocated cpu blocks.
|
||||
num_free = 0
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
for block in blocks:
|
||||
cpu_allocator.free(block)
|
||||
num_free += 1
|
||||
assert block in cpu_allocator.free_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cpu_allocator.free(block)
|
||||
|
||||
|
||||
def test_allocate():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
for i in range(num_gpu_blocks):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
assert block_manager.can_allocate(seq_group)
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
|
||||
|
||||
# Allocate same sequence group to all available gpu blocks.
|
||||
# Use watermark to reserve one gpu block.
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=1 / num_gpu_blocks)
|
||||
for i in range(num_gpu_blocks - 1):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
assert block_manager.can_allocate(seq_group)
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
|
||||
|
||||
|
||||
def test_append_slot_single_seq():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate single seq to gpu block.
|
||||
prompt, seq_group = create_dummy_prompt("1", block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Nothing to append. Sequence has no new logical blocks.
|
||||
assert block_manager.can_append_slots(seq_group)
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert not block_manager.append_slots(prompt)
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_blocks == after_blocks
|
||||
|
||||
# Add block_size number of new tokens and append slot.
|
||||
for i in range(block_size):
|
||||
token_id = i + 5
|
||||
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
|
||||
assert block_manager.can_append_slots(seq_group)
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert not block_manager.append_slots(prompt)
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_blocks - after_blocks == 1
|
||||
|
||||
|
||||
def test_append_slot_cow():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size=block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate prompt to gpu block. There is one slot left in the block.
|
||||
prompt = Sequence(seq_id=1,
|
||||
prompt="one two three",
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
block_size=block_size)
|
||||
|
||||
# Fork the sequence, such that a COW will be required when we append a new
|
||||
# token id.
|
||||
child = prompt.fork(new_seq_id=2)
|
||||
|
||||
# Allocate space for the sequence group.
|
||||
seq_group = SequenceGroup("1", [prompt, child], SamplingParams(),
|
||||
time.time(), time.perf_counter)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Fork and append a new token id. We expect a COW to be scheduled.
|
||||
token_id = 4
|
||||
child.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.fork(prompt, child)
|
||||
|
||||
assert block_manager.can_append_slots(seq_group)
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
|
||||
cows = block_manager.append_slots(child)
|
||||
assert cows
|
||||
for src_block, dst_blocks in cows.items():
|
||||
assert src_block not in dst_blocks
|
||||
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_blocks - after_blocks == 1
|
||||
|
||||
|
||||
def test_fork():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1",
|
||||
block_size - 1,
|
||||
block_size=block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Fork prompt and copy block tables.
|
||||
child = prompt.fork(2)
|
||||
block_manager.fork(prompt, child)
|
||||
assert block_manager.get_block_table(
|
||||
prompt) == block_manager.get_block_table(child)
|
||||
token_id = 4
|
||||
# Append token to child. Block is shared so copy on write occurs.
|
||||
child.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.append_slots(child)
|
||||
assert block_manager.get_block_table(
|
||||
prompt) != block_manager.get_block_table(child)
|
||||
|
||||
|
||||
def test_swap():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
|
||||
prompt.status = SequenceStatus.WAITING
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Emulate a forward pass by appending a single token.
|
||||
# The block manager then knows how many unprocessed
|
||||
# tokens will be written in the next forward pass.
|
||||
token_id = 0
|
||||
prompt.status = SequenceStatus.RUNNING
|
||||
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
|
||||
# Swap seq group from GPU -> CPU.
|
||||
gpu_blocks = block_manager.get_block_table(prompt)
|
||||
assert block_manager.can_swap_out(seq_group)
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_out(seq_group)
|
||||
assert list(mapping.keys()) == gpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
|
||||
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
|
||||
prompt.status = SequenceStatus.SWAPPED
|
||||
|
||||
# Swap seq group from CPU -> GPU.
|
||||
cpu_blocks = block_manager.get_block_table(prompt)
|
||||
assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_in(seq_group)
|
||||
assert list(mapping.keys()) == cpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
|
||||
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
|
||||
|
||||
|
||||
def test_free():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
prompt, seq_group = create_dummy_prompt("1", block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# Free allocated seq.
|
||||
prompt_blocks = len(block_manager.get_block_table(prompt))
|
||||
before_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
block_manager.free(prompt)
|
||||
after_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert after_blocks == before_blocks + prompt_blocks
|
||||
|
||||
# Block table for freed seq is deleted.
|
||||
with pytest.raises(KeyError):
|
||||
block_manager.get_block_table(prompt)
|
||||
|
||||
|
||||
def test_reset():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
num_gpu_blocks = 4
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0)
|
||||
|
||||
# Allocate same seq group on all available gpu blocks.
|
||||
original_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
for i in range(num_gpu_blocks):
|
||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
assert block_manager.get_num_free_gpu_blocks() == 0
|
||||
|
||||
# Resetting block manager frees all allocated blocks.
|
||||
block_manager.reset()
|
||||
assert block_manager.get_num_free_gpu_blocks() == original_blocks
|
||||
|
||||
|
||||
def test_sliding_window_multi_seq():
|
||||
"""
|
||||
Tests that memory allocation and deallocation is handled
|
||||
correctly with multiple sequences that exceed the sliding
|
||||
window's capacity.
|
||||
"""
|
||||
block_size = 1
|
||||
num_cpu_blocks = 8
|
||||
num_gpu_blocks = 8
|
||||
sliding_window = 2
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
sliding_window=sliding_window,
|
||||
watermark=0)
|
||||
|
||||
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
|
||||
|
||||
parent = Sequence(1, "one two three", [0, 1, 2], block_size)
|
||||
seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(),
|
||||
None)
|
||||
block_manager.allocate(seq_group)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# the parent seq has len 3, but since sliding_window is 2,
|
||||
# we will use at most 2 blocks
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window
|
||||
|
||||
# Fork prompt and copy block tables.
|
||||
child = parent.fork(2)
|
||||
block_manager.fork(parent, child)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# forking does not increase memory consumption
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window
|
||||
|
||||
# assert both parent and child share all blocks
|
||||
assert block_manager.get_block_table(
|
||||
parent) == block_manager.get_block_table(child)
|
||||
|
||||
token_id = 4
|
||||
# Append token to child. Block is shared so copy on write occurs.
|
||||
child.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.append_slots(child)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# we will use now one block more. Each seq will use 2 blocks,
|
||||
# but only one can be shared
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window - 1
|
||||
|
||||
token_id = 5
|
||||
parent.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
block_manager.append_slots(parent)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# no change, because both sequences are still just sharing one block
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window - 1
|
||||
|
||||
block_table_parent = block_manager.get_block_table(parent)
|
||||
block_table_child = block_manager.get_block_table(child)
|
||||
|
||||
assert block_table_parent != block_table_child
|
||||
|
||||
# assert both blocks are sharing the second-last block
|
||||
assert block_table_parent[-2] == block_table_child[-2]
|
||||
|
||||
# now let's clean up...
|
||||
block_manager.free(parent)
|
||||
|
||||
# assert the number of blocks allocated is correct
|
||||
# We have freed one seq, reducing the ref count of two blocks by one.
|
||||
# One of the two was only used by the parent seq, so this is now free.
|
||||
# The child seq still consumes sliding_window blocks
|
||||
assert block_manager.get_num_free_gpu_blocks(
|
||||
) == num_gpu_blocks - sliding_window
|
||||
|
||||
# free all blocks
|
||||
block_manager.free(child)
|
||||
|
||||
# assert all blocks are free now
|
||||
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
|
||||
@@ -1,564 +0,0 @@
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
|
||||
from vllm.config import CacheConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.sequence import Logprob, SequenceGroup
|
||||
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
|
||||
def get_sequence_groups(scheduler_output):
|
||||
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
|
||||
|
||||
|
||||
def append_new_token(seq_group, token_id: int):
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
|
||||
|
||||
|
||||
def schedule_and_update_computed_tokens(scheduler):
    """Run one scheduling step, then credit each scheduled group with the
    number of tokens that step computed for it.

    Returns the (metadata list, scheduler outputs) pair unchanged.
    """
    seq_group_metadata_list, scheduler_outputs = scheduler.schedule()
    scheduled_entries = scheduler_outputs.scheduled_seq_groups
    for entry, metadata in zip(scheduled_entries, seq_group_metadata_list):
        entry.seq_group.update_num_computed_tokens(metadata.token_chunk_size)
    return seq_group_metadata_list, scheduler_outputs
|
||||
|
||||
|
||||
def test_simple():
    """Verify basic scheduling works."""
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    # Budget (64) is large enough to hold all 4 prompts of 4 tokens each.
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       num_seq_group,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Schedule seq groups prompts: all 4 prefills fit in one step.
    num_tokens = block_size * num_seq_group
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_tokens
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group
    for s in running:
        append_new_token(s, 1)

    # Schedule seq groups generation: one decode token per group.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_seq_group
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group
|
||||
|
||||
|
||||
def test_chunk():
    """Verify prefills are chunked properly."""
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    # Two 60-token prompts exceed the 64-token budget, forcing chunking.
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # The first prompt fits entirely (60 tokens).
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify the second is chunked to the remaining budget (64 - 60 = 4).
    assert seq_group_meta[1].token_chunk_size == 4
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # One chunked prefill, and one decoding.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # The first one is the remaining prefill chunk (60 - 4 = 56 tokens).
    # Scheduler guarantees prefills are ordered before decodes.
    assert seq_group_meta[0].token_chunk_size == 56
    # The second one is a single decode token for the finished prefill.
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 57
|
||||
|
||||
|
||||
def test_complex():
    """Verify chunked prefill across multiple waves of incoming requests."""
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()

    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)

    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify it is chunked (64 - 60 = 4 tokens of budget left).
    assert seq_group_meta[1].token_chunk_size == 4
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Add 2 more requests.
    for i in range(2, 4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 3
    # The first one is the first chunked prefill (only 7 tokens fit after
    # the other scheduled work).
    assert seq_group_meta[0].token_chunk_size == 7
    # The second one is the second new chunked prefill (remaining 56 tokens).
    assert seq_group_meta[1].token_chunk_size == 56
    # The last one is decode.
    assert seq_group_meta[2].token_chunk_size == 1
    # Two of them are in chunked prefill.
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # The first 2 requests are now in decoding phase.
    append_new_token(running[0], 1)
    assert not running[0].is_prefill()
    append_new_token(running[1], 1)
    assert not running[1].is_prefill()
    # The third request is still in prefill stage.
    assert running[2].is_prefill()
|
||||
|
||||
|
||||
def test_maximal_decoding():
    """Verify decoding requests are prioritized."""
    block_size = 4
    max_seqs = 2
    max_model_len = 2
    # Tiny budget (2 tokens/step) so prefill and decode compete every step.
    max_num_batched_tokens = 2
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=2)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()

    # The first prefill is scheduled (it consumes the whole budget).
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Create one more seq_group.
    _, seq_group = create_dummy_prompt("3", prompt_length=2)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()
    # The first decoding + second request's first prefill chunk is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)

    # Decoding + running prefill is prioritized.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # Only decoding is prioritized; the third request stays queued.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 0
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # After aborting the decoding request, the fcfs new prefill is prioritized.
    scheduler.abort_seq_group(running[0].request_id)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
|
||||
|
||||
|
||||
def test_prompt_limit():
    """Verify max_num_batched_tokens < max_model_len is possible."""
    block_size = 4
    max_seqs = 32
    max_model_len = 64
    max_num_batched_tokens = 32
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # 48-token prompt: fits in the model length but not in one token budget.
    _, seq_group = create_dummy_prompt("1", prompt_length=48)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()

    # The prompt length > max_num_batched_tokens should be still scheduled
    # (chunked to the 32-token budget).
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 32
    assert running[0].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 32
|
||||
|
||||
|
||||
def test_prompt_limit_exceed():
    """Verify a prompt longer than max_model_len is ignored, not scheduled."""
    block_size = 4
    max_seqs = 64
    max_model_len = 32
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # 48-token prompt exceeds max_model_len (32).
    _, seq_group = create_dummy_prompt("2", prompt_length=48)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # The over-length request is reported as ignored rather than scheduled.
    assert len(out.ignored_seq_groups) == 1
    assert out.ignored_seq_groups[0] == seq_group
|
||||
|
||||
|
||||
def test_swap():
    """Verify swapping works with chunked prefill requests"""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # best_of=2 makes the group swappable rather than recomputable.
    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked.
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The last request should be swapped out: mock the block manager so
    # request "1" can no longer append slots.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != {}
    assert out.blocks_to_swap_in == {}

    # Add 1 more task. Swap should be prioritized over new prefill.
    _, seq_group = create_dummy_prompt("2", prompt_length=60)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # Request "1" is swapped back in and scheduled (30 tokens); the new
    # prefill "2" is not.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != {}
    assert out.blocks_to_swap_out == {}
|
||||
|
||||
|
||||
def test_running_prefill_prioritized_over_swap():
    """Verify an in-flight (running) prefill is scheduled ahead of
    swapping a swapped-out request back in."""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # best_of=2 makes the group swappable rather than recomputable.
    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked.
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be swapped out: mock the block manager so request
    # "1" can no longer append slots.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != {}
    assert out.blocks_to_swap_in == {}

    # Add 1 more task. Swap is not possible, so prefill is running.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER

    _, seq_group2 = create_dummy_prompt("2", prompt_length=60)
    scheduler.add_seq_group(seq_group2)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # The new prefill "2" runs its first 30-token chunk; nothing swaps.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == {}
    assert out.blocks_to_swap_out == {}
    assert out.scheduled_seq_groups[0].seq_group == seq_group2

    # Now although swap is possible, running prefill is prioritized.
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # "2" finishes its remaining prefill chunk; still no swap-in.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == {}
    assert out.blocks_to_swap_out == {}
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Decoding is prioritized.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # A single decode token for "2"; still no swap-in.
    assert out.num_batched_tokens == 1
    assert out.blocks_to_swap_in == {}
    assert out.blocks_to_swap_out == {}
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Since we abort the sequence group, we can finally swap.
    scheduler.abort_seq_group(seq_group2.request_id)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != {}
    assert out.blocks_to_swap_out == {}
|
||||
|
||||
|
||||
def test_chunked_prefill_preempt():
    """Verify preempt works with chunked prefill requests"""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # No best_of here, so blocking the request leads to preemption
    # (recompute) rather than swap.
    _, seq_group = create_dummy_prompt("1", prompt_length=60)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked.
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be preempted.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now preempted (no blocks swapped either way).
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out == {}
    assert out.blocks_to_swap_in == {}

    # Make sure we can reschedule preempted request.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens
    # Preemption restarted the prompt, but only the first chunk was credited
    # before, so 30 tokens remain uncomputed.
    assert seq_group.get_num_uncomputed_tokens() == 30

    # We should be able to run prefill twice as it is chunked.
    # NOTE: this redefinition keeps the old name but now allows all appends.
    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return True

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert not seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens
|
||||
|
||||
|
||||
def test_chunked_prefill_max_seqs():
    """Verify max_seqs caps the batch even when token budget remains."""
    block_size = 4
    max_seqs = 2
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("1", prompt_length=65)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    # The first prefill is chunked (65 > 64-token budget).
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 1

    # Add new requests.
    # NOTE(review): str(i) for i in 0..3 reuses request id "1" already used
    # above — presumably the scheduler tolerates this here; verify.
    for i in range(4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=65)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Make sure only 2 requests are scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_batched_tokens == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    append_new_token(running[0], 1)

    # Although we have enough token budget, we can only schedule max_seqs.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 2
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_batched_tokens == 3
    assert len(get_sequence_groups(out)) == max_seqs
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
|
||||
@@ -1,900 +0,0 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.core.scheduler import Scheduler, SchedulingBudget
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
|
||||
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
|
||||
def get_sequence_groups(scheduler_output):
    """Unwrap the scheduled entries into their underlying SequenceGroups."""
    scheduled = scheduler_output.scheduled_seq_groups
    return [entry.seq_group for entry in scheduled]
|
||||
|
||||
|
||||
def append_new_token(out, token_id: int):
    """Append *token_id* (dummy logprob) to every sequence of every
    scheduled group in the scheduler output *out*."""
    for group in get_sequence_groups(out):
        for sequence in group.get_seqs():
            sequence.append_token_id(token_id, {token_id: Logprob(token_id)})
|
||||
|
||||
|
||||
def schedule_and_update_computed_tokens(scheduler):
    """Invoke the scheduler once, then mark each scheduled token chunk as
    computed on its sequence group. Returns the step's (metadata, outputs)."""
    metadata_list, outputs = scheduler.schedule()
    pairs = zip(outputs.scheduled_seq_groups, metadata_list)
    for scheduled_group, metadata in pairs:
        chunk = metadata.token_chunk_size
        scheduled_group.seq_group.update_num_computed_tokens(chunk)
    return metadata_list, outputs
|
||||
|
||||
|
||||
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    """Credit *token_chunk_size* computed tokens to *seq_group*, then append
    *token_id* (dummy logprob) to each of its sequences."""
    seq_group.update_num_computed_tokens(token_chunk_size)
    for sequence in seq_group.get_seqs():
        # Fresh logprob dict per sequence, mirroring real engine behavior.
        sequence.append_token_id(token_id, {token_id: Logprob(token_id)})
|
||||
|
||||
|
||||
def test_scheduler_add_seq_group():
    """Verify add_seq_group grows the unfinished-group count one by one."""
    block_size = 4
    scheduler_config = SchedulerConfig(100, 64, 1)
    cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
    cache_config.num_cpu_blocks = 4
    cache_config.num_gpu_blocks = 4
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add seq group to scheduler.
    num_seq_group = 4
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i), block_size)
        scheduler.add_seq_group(seq_group)
        assert scheduler.get_num_unfinished_seq_groups() == i + 1
|
||||
|
||||
|
||||
def test_scheduler_abort_seq_group():
    """Verify aborting every request empties the unfinished-group count."""
    block_size = 4
    scheduler_config = SchedulerConfig(100, 64, 1)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 4
    cache_config.num_gpu_blocks = 4
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add multiple seq groups to scheduler.
    num_seq_group = 4
    request_ids = set()
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i), block_size)
        scheduler.add_seq_group(seq_group)
        request_ids.add(str(i))

    # Abort all added seq groups.
    assert scheduler.get_num_unfinished_seq_groups() == num_seq_group
    scheduler.abort_seq_group(request_ids)
    assert scheduler.get_num_unfinished_seq_groups() == 0
|
||||
|
||||
|
||||
def test_scheduler_schedule_simple():
    """Verify basic prompt-then-decode scheduling without chunked prefill."""
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Schedule seq groups prompts: all 4 prompts fit in one step.
    num_tokens = block_size * num_seq_group
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_tokens
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group
    append_new_token(out, 1)

    # Schedule seq groups generation: one decode token per group.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert out.num_batched_tokens == num_seq_group
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == num_seq_group
    append_new_token(out, 1)
|
||||
|
||||
|
||||
def test_scheduler_prefill_prioritized():
    """Verify running batched tokens are not applied to prefill requests."""
    block_size = 4
    max_model_len = 30
    max_batched_num_tokens = 30
    scheduler_config = SchedulerConfig(max_batched_num_tokens, 2,
                                       max_model_len)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 2
    cache_config.num_gpu_blocks = 2
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add seq groups to scheduler.
    _, seq_group_a = create_dummy_prompt("1", 1)
    scheduler.add_seq_group(seq_group_a)

    # Schedule seq groups prompts.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a]

    # Add a new prefill request B.
    _, seq_group_b = create_dummy_prompt("2", 30)
    scheduler.add_seq_group(seq_group_b)

    # Verify prefill requests are prioritized: the new prefill B is
    # scheduled ahead of A's pending decode.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_b]
|
||||
|
||||
|
||||
def test_scheduler_schedule_preempt_abort():
    """Verify preemption under block pressure and recovery after abort."""
    block_size = 4
    max_model_len = 16
    scheduler_config = SchedulerConfig(64, 2, max_model_len)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    # Only 2 GPU blocks: both groups fit for prompts, but growth forces
    # preemption later.
    cache_config.num_cpu_blocks = 2
    cache_config.num_gpu_blocks = 2
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add seq groups to scheduler.
    seq_a, seq_group_a = create_dummy_prompt("1", block_size)
    seq_b, seq_group_b = create_dummy_prompt("2", block_size)
    scheduler.add_seq_group(seq_group_a)
    scheduler.add_seq_group(seq_group_b)

    # Schedule seq groups prompts.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
    assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == 2
    assert scheduler.get_num_unfinished_seq_groups() == 2

    # Append "generated" tokens, allowing the sequence to mark prompt tokens as
    # processed.
    append_new_token(out, 1)

    # Schedule seq groups generation and preempt seq group b.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a]
    assert out.num_batched_tokens == 1
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == 1
    # b is preempted, not finished, so both groups still count as unfinished.
    assert scheduler.get_num_unfinished_seq_groups() == 2

    # Abort seq group a. Re-schedule seq group b prompt with recomputation.
    scheduler.abort_seq_group("1")
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_b]
    assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    assert len(seq_group_meta) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 1
|
||||
|
||||
|
||||
def test_scheduler_max_seqs():
    """Verify the max_num_seqs cap limits how many groups run at once."""
    block_size = 4
    num_seq_group = 4
    max_seq_group = 2
    max_model_len = 16
    scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    all_seq_groups: List[SequenceGroup] = []
    # Add seq groups to scheduler.
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
        all_seq_groups.append(seq_group)

    # Append 1 seq group
    scheduler.add_seq_group(all_seq_groups[0])

    # Schedule seq groups prompts.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
    append_new_token(out, 1)

    # Schedule seq groups generation.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
    append_new_token(out, 1)

    # Append 2 more seq group
    scheduler.add_seq_group(all_seq_groups[1])
    scheduler.add_seq_group(all_seq_groups[2])

    # Schedule seq groups prompts.
    # Only 1 seq group should be scheduled since max_seq_group is 2
    # and one is prompting.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
|
||||
|
||||
|
||||
def test_scheduler_delay_factor():
    """Verify delay_factor postpones new prefills.

    With delay_factor=0.5, a newly arrived prompt is not scheduled until at
    least 0.5x the previous prompt's latency has elapsed.
    """
    block_size = 4
    scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # schedule first prompt
    # NOTE(review): create_dummy_prompt returns (Sequence, SequenceGroup);
    # the first local name here is misleading -- it is a Sequence, not
    # metadata. It is immediately overwritten below.
    seq_group_meta, seq_group = create_dummy_prompt("0",
                                                    prompt_length=block_size)
    scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # wait for a second before scheduling next prompt
    time.sleep(1)
    seq_group_meta, seq_group = create_dummy_prompt("1",
                                                    prompt_length=block_size)
    scheduler.add_seq_group(seq_group)

    # second prompt should *not* be scheduled
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups == 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # wait for more than 0.5 second and try again
    time.sleep(0.6)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '1'
    append_new_token(out, 1)
|
||||
|
||||
|
||||
def test_swapped_out_prioritized():
    """Verify swapped-out requests are swapped back in before new prefills."""
    scheduler = initialize_scheduler(max_num_seqs=6)
    # best_of=2 * 3 == 6 sequences.
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
        scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 3
    append_new_token(out, 1)

    # The last request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        # Only request "2" is denied slots, forcing it to be swapped out.
        return seq_group.request_id != "2"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 2
    assert out.num_batched_tokens == 2
    assert out.blocks_to_swap_out != {}
    assert out.blocks_to_swap_in == {}
    append_new_token(out, 1)

    # Add 1 more task. Swap should be prioritized over prefill.
    # Use a fresh request id "3": the original ``str(i)`` reused the loop
    # variable leaked from above (i == 2), creating a duplicate of the
    # swapped-out request's id.
    _, seq_group = create_dummy_prompt("3", prompt_length=60, best_of=2)
    scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    append_new_token(out, 1)
    assert len(out.scheduled_seq_groups) == 3
    # 3 decodes. It is swapped in.
    assert out.num_batched_tokens == 3
    assert out.blocks_to_swap_in != {}
    assert out.blocks_to_swap_out == {}
|
||||
|
||||
|
||||
def initialize_scheduler(*,
                         max_num_seqs=1000,
                         max_token_budget=1000,
                         max_model_len=1000,
                         lora_config=None):
    """Build a Scheduler backed by a tiny cache (8 CPU / 8 GPU blocks).

    Defaults are deliberately generous so individual tests can tighten only
    the limit they exercise.
    """
    cache_config = CacheConfig(4, 1.0, 1, "auto")  # block_size == 4
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs,
                                       max_model_len)
    return Scheduler(scheduler_config, cache_config, lora_config)
|
||||
|
||||
|
||||
def create_token_budget(token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Return a SchedulingBudget; defaults are large enough to be no-ops."""
    return SchedulingBudget(token_budget=token_budget,
                            max_num_seqs=max_num_seqs)
|
||||
|
||||
|
||||
def add_token_budget(budget: SchedulingBudget,
                     num_batched_tokens: int = 0,
                     num_curr_seqs: int = 0):
    """Pre-consume part of *budget* on behalf of a throwaway dummy request."""
    _, filler_group = create_dummy_prompt('10', prompt_length=60)
    request_id = filler_group.request_id
    budget.add_num_batched_tokens(request_id, num_batched_tokens)
    budget.add_num_seqs(request_id, num_curr_seqs)
|
||||
|
||||
|
||||
def test_prefill_schedule_max_prompt_len():
    """
    Test prompt longer than max_prompt_len is aborted.
    """
    scheduler = initialize_scheduler(max_model_len=30)
    # Pass "0" (str), not 0 (int): create_dummy_prompt declares
    # ``request_id: str`` and every other call site passes a string.
    _, seq_group = create_dummy_prompt("0", prompt_length=60)
    waiting = deque([seq_group])
    budget = create_token_budget()
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    # The over-long prompt is ignored (aborted) -- neither scheduled nor
    # left in the waiting queue, and no budget is consumed.
    assert len(output.ignored_seq_groups) == 1
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
def test_prefill_schedule_token_budget():
    """
    Test token budget respected.
    """
    scheduler = initialize_scheduler()
    waiting = deque()
    budget = create_token_budget(token_budget=0)
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        waiting.append(seq_group)

    # 0 token budget == nothing is scheduled.
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 2

    # 60 token budget == 1 request scheduled.
    budget = create_token_budget(token_budget=60)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 1

    # Test when current_batched_tokens respected.
    scheduler = initialize_scheduler()
    waiting = deque()
    budget = create_token_budget(token_budget=60)
    add_token_budget(budget, 30, 0)
    # NOTE(review): ``i`` here is the loop variable leaked from above
    # (== 1), so this request id duplicates an earlier one -- the scheduler
    # is fresh so it works, but TODO confirm a new id was not intended.
    _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
    # Cannot schedule a prompt that doesn't fit the budget.
    waiting.append(seq_group)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 1
    budget = create_token_budget(token_budget=90)
    add_token_budget(budget, 30, 0)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
def test_prefill_schedule_max_seqs():
    """
    Test max seq respected.
    """
    scheduler = initialize_scheduler()
    waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        waiting.append(seq_group)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    # Only 2 of the 3 requests fit under max_num_seqs=2.
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1

    # Verify curr_num_seqs respected.
    waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    add_token_budget(budget, 0, 2)
    # NOTE(review): ``i`` is the leaked loop variable (== 2) -- TODO confirm
    # a fresh request id was not intended here.
    _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
    waiting.append(seq_group)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    # Budget already holds 2 seqs, so nothing more can be admitted.
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1
|
||||
|
||||
|
||||
def test_prefill_schedule_max_lora():
    """
    Test max lora is respected and prioritized.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config)
    waiting = deque()
    budget = create_token_budget(token_budget=120)
    curr_loras = set()
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           lora_request=LoRARequest(
                                               lora_name=str(i),
                                               lora_int_id=i + 1,
                                               lora_local_path="abc"))
        waiting.append(seq_group)
    # Add two more requests to verify lora is prioritized.
    # 0: Lora, 1: Lora, 2: regular, 3: regular
    # In the first iteration, index 0, 2 is scheduled.
    # If a request is not scheduled because it hits max lora, it is
    # prioritized. Verify that.
    for i in range(2, 4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        waiting.append(seq_group)
    # Schedule 2 requests (0 and 2)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, curr_loras)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 2
    assert len(curr_loras) == 1
    # The second lora request is scheduled next as FCFS policy.
    # Reset curr_loras so that it can be scheduled.
    curr_loras = set()
    budget = create_token_budget(token_budget=60)
    remaining_waiting, output = scheduler._schedule_prefills(
        remaining_waiting, budget, curr_loras)
    assert len(output.seq_groups) == 1
    assert output.seq_groups[0].seq_group.request_id == "1"
    assert len(remaining_waiting) == 1
    assert len(curr_loras) == 1
    assert budget.num_batched_tokens == 60
|
||||
|
||||
|
||||
def test_prefill_schedule_no_block_manager_capacity():
    """
    Test sequence cannot be scheduled due to block manager has no capacity.
    """
    scheduler = initialize_scheduler()
    waiting = deque()
    budget = create_token_budget()
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        waiting.append(seq_group)
    scheduler.block_manager.can_allocate = MagicMock()
    scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
    # LATER: nothing is scheduled now, but all requests remain queued.
    # (Variable name fixed from the original typo ``remainig_waiting``.)
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 3

    scheduler = initialize_scheduler()
    waiting = deque()
    budget = create_token_budget()
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        waiting.append(seq_group)
    scheduler.block_manager.can_allocate = MagicMock()
    scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
    # NEVER: all requests are permanently ignored and removed from the queue.
    remaining_waiting, output = scheduler._schedule_prefills(
        waiting, budget, None)
    assert len(output.ignored_seq_groups) == 3
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
def test_decode_schedule_preempted():
    """
    Test decodes cannot be scheduled and preempted.
    """
    scheduler = initialize_scheduler()
    running = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        running.append(seq_group)
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        # Deny slots only for request "1".
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # 1 cannot be scheduled, and the lowest priority (request 2)
    # should be preempted. 1 will also be preempted.
    # (Variable name fixed from the original typo ``remainig_running``.)
    budget = create_token_budget()
    remaining_running, output = scheduler._schedule_running(
        running, budget, curr_loras, policy)
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert len(output.preempted) == 2
    # Verify budgets are updated.
    assert budget.num_batched_tokens == 1
    # NOTE: When enable_chunk is False, num_seqs budget is not updated.
    # assert budget.num_curr_seqs == 1
    # Both should be preempted, not swapped.
    assert output.blocks_to_swap_out == {}
    # Nothing is copied.
    assert output.blocks_to_copy == {}
|
||||
|
||||
|
||||
def test_decode_swap_beam_search():
    """
    Test best_of > 1 swap out blocks
    """
    scheduler = initialize_scheduler()
    running = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    budget = create_token_budget()
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
        scheduler._allocate_and_set_running(seq_group)
        running.append(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        budget.add_num_seqs(seq_group.request_id,
                            seq_group.get_max_num_running_seqs())
        budget.add_num_batched_tokens(
            seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING))

    # The last request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        # Deny slots only for request "2", forcing it to be swapped out.
        return seq_group.request_id != "2"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)
    scheduler.block_manager.swap_out = MagicMock()
    expected_swap_mapping = {"5": "7"}
    scheduler.block_manager.swap_out.return_value = expected_swap_mapping

    # (Variable name fixed from the original typo ``remainig_running``.)
    remaining_running, output = scheduler._schedule_running(
        running, budget, curr_loras, policy)
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 2
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert output.decode_seq_groups[1].seq_group.request_id == "1"
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 1
    # Budget should reflect preempted requests.
    assert budget.num_batched_tokens == 2
    # since there are 2 sequences, 2 should be subtracted.
    assert budget.num_curr_seqs == 4
    # Both should be preempted, not swapped.
    assert output.blocks_to_swap_out == expected_swap_mapping
    # Nothing is copied.
    assert output.blocks_to_copy == {}
|
||||
|
||||
|
||||
def test_schedule_decode_blocks_to_copy_update():
    """
    Verify blocks_to_copy is updated.
    """
    scheduler = initialize_scheduler()
    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    running = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    running.append(seq_group)

    # Stub append_slots to report a copy-on-write mapping.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = {2: [3]}

    budget = create_token_budget()
    remaining_running, output = scheduler._schedule_running(
        running, budget, curr_loras, policy)
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 0
    # Nothing is preempted.
    assert output.blocks_to_swap_out == {}
    # Since append_slots returns the source -> dest mapping, it should be
    # applied to the scheduler output verbatim.
    assert output.blocks_to_copy == {2: [3]}
|
||||
|
||||
|
||||
def test_schedule_swapped_simple():
    """A swapped-out group is swapped back in; the swap-in mapping is the
    exact inverse of the earlier swap-out mapping."""
    scheduler = initialize_scheduler()
    swapped = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    swapped.append(seq_group)

    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 0
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # swap in is the reverse of swap out
    inverted_mapping = {
        gpu_block: cpu_block
        for cpu_block, gpu_block in output.blocks_to_swap_in.items()
    }
    assert blocks_to_swap_out == inverted_mapping
|
||||
|
||||
|
||||
def test_schedule_swapped_max_token_budget():
    """Verify the token budget caps how many swapped groups are resumed."""
    scheduler = initialize_scheduler()
    swapped = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    for _ in range(2):
        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        swapped.append(seq_group)

    # token_budget=1 admits only one decode (1 token per group).
    budget = create_token_budget(token_budget=1)
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0

    # Verify num_batched_tokens are respected.
    budget = create_token_budget(token_budget=1)
    add_token_budget(budget, 1, 0)
    remaining_swapped, output = scheduler._schedule_swapped(
        remaining_swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
def test_schedule_swapped_max_seqs():
    """Verify max_num_seqs caps how many swapped groups are resumed."""
    scheduler = initialize_scheduler()
    swapped = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    for i in range(4):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        swapped.append(seq_group)

    budget = create_token_budget(max_num_seqs=2)
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    # Only 2 of the 4 swapped groups fit under max_num_seqs=2.
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 2
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 2
    assert len(output.prefill_seq_groups) == 0

    # Verify num_curr_seqs are respected.
    remaining_swapped, output = scheduler._schedule_swapped(
        remaining_swapped, budget, curr_loras, policy)
    # Budget is already full, so nothing more is admitted.
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 2
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
def test_schedule_swapped_max_loras():
    """Verify at most max_loras distinct LoRAs are swapped in per step."""
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config)
    swapped = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = set()
    blocks_to_swap_out = {}
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           lora_request=LoRARequest(
                                               lora_name=str(i),
                                               lora_int_id=i + 1,
                                               lora_local_path="abc"))
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        swapped.append(seq_group)

    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    # Only one group fits: the second would need a second LoRA slot.
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert len(curr_loras) == 1
|
||||
|
||||
|
||||
def test_schedule_swapped_cannot_swap_in():
    """When swap-in would fail for now (LATER), nothing is resumed and all
    swapped requests stay queued."""
    scheduler = initialize_scheduler()
    swapped = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    for _ in range(2):
        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        swapped.append(seq_group)

    # Stub can_swap_in to report no capacity for now.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
    # Since we cannot swap in, none of the requests are swapped in.
    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
def test_infeasible_swap():
    """When swap-in can never succeed (NEVER), the swapped requests are
    aborted as infeasible rather than re-queued."""
    scheduler = initialize_scheduler()
    swapped = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    blocks_to_swap_out = {}
    for _ in range(2):
        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        swapped.append(seq_group)

    # Stub can_swap_in to report permanent infeasibility.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
    # Since we cannot swap in, none of the requests are swapped in.
    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 0
    assert len(output.infeasible_seq_groups) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
def test_schedule_swapped_blocks_to_copy():
    """Verify append_slots' copy mapping propagates to blocks_to_copy when a
    swapped group is resumed."""
    scheduler = initialize_scheduler()
    swapped = deque()
    policy = PolicyFactory.get_policy(policy_name="fcfs")
    curr_loras = None
    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    blocks_to_swap_out = {}
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    swapped.append(seq_group)

    # Stub append_slots to report a copy-on-write mapping.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = {2: [3]}

    budget = create_token_budget()
    remaining_swapped, output = scheduler._schedule_swapped(
        swapped, budget, curr_loras, policy)
    assert len(remaining_swapped) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.blocks_to_copy == {2: [3]}
|
||||
|
||||
|
||||
def test_scheduling_budget():
    """Exercise SchedulingBudget bookkeeping: can_schedule limits, and the
    idempotence of add/subtract per request id."""
    TOKEN_BUDGET = 4
    MAX_SEQS = 4
    budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1)
    assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5)
    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5)
    assert budget.remaining_token_budget() == TOKEN_BUDGET

    # Verify add/subtract num batched tokens.
    _, seq_group = create_dummy_prompt("1", 3)
    budget.add_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
    # Verify adding another seq group is no-op.
    budget.add_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 4
    assert budget.num_batched_tokens == 0
    # Subtracting the same request id twice is also a no-op.
    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 4
    assert budget.num_batched_tokens == 0

    # Verify add/subtract max seqs.
    _, seq_group = create_dummy_prompt("1", 3)
    budget.add_num_seqs(seq_group.request_id, 2)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
    assert budget.num_curr_seqs == 2
    # Verify adding another seq group is no-op.
    budget.add_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 2
    budget.subtract_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 0
    budget.subtract_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 0
|
||||
@@ -1,74 +0,0 @@
|
||||
import time
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Logprob, Sequence, SequenceGroup
|
||||
|
||||
|
||||
def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]:
    """Build a throwaway prompt Sequence plus its owning SequenceGroup.

    The prompt's token ids are simply 0..prompt_length-1 and its text is
    those ids joined by spaces. If no block_size is given, the whole prompt
    fits in one block.
    """
    if not block_size:
        block_size = prompt_length

    token_ids = list(range(prompt_length))
    prompt_text = " ".join(str(token) for token in token_ids)
    prompt = Sequence(int(request_id), prompt_text, token_ids, block_size)
    sampling = SamplingParams(use_beam_search=use_beam_search,
                              best_of=best_of)
    seq_group = SequenceGroup(request_id, [prompt], sampling, time.time(),
                              lora_request)

    return prompt, seq_group
|
||||
|
||||
|
||||
def create_seq_group(
        seq_prompt_len: int = 1024,
        seq_output_lens: Iterable[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a SequenceGroup of sequences sharing one all-zero dummy prompt.

    Args:
        seq_prompt_len: Length of the shared prompt (token id 0 repeated).
        seq_output_lens: One entry per sequence; each entry is the number of
            dummy output tokens appended to that sequence.
        request_id: Request id assigned to the group.
        seq_id_start: Sequence id of the first sequence; later sequences get
            consecutive ids.
        sampling_params: Optional sampling params (defaults to
            ``SamplingParams()``).

    Returns:
        The constructed SequenceGroup.
    """
    # Materialize up front: the parameter is typed Iterable, but the body
    # needs len() and (potentially) multiple passes, which would fail for a
    # one-shot iterator such as a generator.
    seq_output_lens = list(seq_output_lens)
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    seqs = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            prompt="",
            prompt_token_ids=prompt_token_ids,
            block_size=16,
        )

        # Append ``output_len`` dummy generated tokens (ids 0..output_len-1).
        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )

    return seq_group
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return the number of blocks needed to hold ``seq_len`` tokens.

    Despite the name, this returns the ceiling of ``seq_len / block_size``
    (a block COUNT), not ``seq_len`` rounded up to a block boundary.
    """
    return (seq_len + block_size - 1) // block_size
|
||||
81
tests/cuda/test_cuda_context.py
Normal file
81
tests/cuda/test_cuda_context.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def check_cuda_context():
    """Probe the CUDA driver for a current context.

    Returns (True, device_id) when a context exists on this thread, and
    (False, None) when there is no context or the driver is unavailable.
    """
    try:
        driver = ctypes.CDLL("libcuda.so")
        device = ctypes.c_int()
        status = driver.cuCtxGetDevice(ctypes.byref(device))
    except Exception:
        # libcuda missing or the call itself failed -> no usable context.
        return False, None
    if status != 0:
        return False, None
    return True, device.value
|
||||
|
||||
|
||||
def run_cuda_test_in_thread(device_input, expected_device_id):
    """Run CUDA context test in separate thread for isolation.

    Returns a ``(success, message)`` tuple so the caller can assert on the
    result outside the worker thread.
    """
    try:
        # New thread should have no CUDA context initially
        valid_before, device_before = check_cuda_context()
        if valid_before:
            return (
                False,
                "CUDA context should not exist in new thread, "
                f"got device {device_before}",
            )

        # Test setting CUDA context
        current_platform.set_device(device_input)

        # Verify context is created correctly
        valid_after, device_id = check_cuda_context()
        if not valid_after:
            return False, "CUDA context should be valid after set_cuda_context"
        if device_id != expected_device_id:
            return False, f"Expected device {expected_device_id}, got {device_id}"

        return True, "Success"
    except Exception as e:
        # Report rather than raise: exceptions must not escape the thread.
        return False, f"Exception in thread: {str(e)}"
|
||||
|
||||
|
||||
class TestSetCudaContext:
    """Test suite for the set_cuda_context function."""

    @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
    @pytest.mark.parametrize(
        argnames="device_input,expected_device_id",
        argvalues=[
            (0, 0),
            (torch.device("cuda:0"), 0),
            ("cuda:0", 0),
        ],
        ids=["int", "torch_device", "string"],
    )
    def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
        """Test setting CUDA context in isolated threads."""
        # A fresh worker thread starts without a CUDA context, which is what
        # the helper relies on for a clean before/after comparison.
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                run_cuda_test_in_thread, device_input, expected_device_id
            )
            success, message = future.result(timeout=30)
            assert success, message

    @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
    def test_set_cuda_context_invalid_device_type(self):
        """Test error handling for invalid device type."""
        with pytest.raises(ValueError, match="Expected a cuda device"):
            current_platform.set_device(torch.device("cpu"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
32
tests/detokenizer/test_disable_detokenization.py
Normal file
32
tests/detokenizer/test_disable_detokenization.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.skip_v1
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
def test_computed_prefix_blocks(model: str):
|
||||
# This test checks if the engine generates completions both with and
|
||||
# without optional detokenization, that detokenization includes text
|
||||
# and no-detokenization doesn't, and that both completions have the same
|
||||
# token_ids.
|
||||
prompt = (
|
||||
"You are a helpful assistant. How do I build a car from cardboard and "
|
||||
"paper clips? Is there an easy to follow video tutorial available "
|
||||
"online for free?"
|
||||
)
|
||||
|
||||
llm = LLM(model=model)
|
||||
sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
|
||||
|
||||
outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||
sampling_params.detokenize = True
|
||||
outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0]
|
||||
|
||||
assert outputs_no_detokenization.text == ""
|
||||
assert outputs_with_detokenization.text != ""
|
||||
assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids
|
||||
52
tests/detokenizer/test_min_tokens.py
Normal file
52
tests/detokenizer/test_min_tokens.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
|
||||
|
||||
PROMPT = "Hello, my name is Lee, and I'm a student in the " + "college of engineering"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"min_tokens,stop,truth",
|
||||
[
|
||||
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
||||
(0, "e", " is L"),
|
||||
(5, "e", " is Lee, and I'm a stud"),
|
||||
],
|
||||
)
|
||||
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
||||
"""Test for a specific min_tokens and stop.
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/22014
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
|
||||
all_prompt_ids = tokenizer(PROMPT, add_special_tokens=False).input_ids
|
||||
|
||||
# The prompt is "Hello, my name is"
|
||||
prompt_token_ids = all_prompt_ids[:4]
|
||||
params = SamplingParams(
|
||||
stop=stop,
|
||||
min_tokens=min_tokens,
|
||||
)
|
||||
request = EngineCoreRequest(
|
||||
request_id="",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
pooling_params=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0.0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
)
|
||||
|
||||
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
|
||||
|
||||
detokenizer.update(all_prompt_ids[4:], False)
|
||||
assert detokenizer.output_text == truth
|
||||
@@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test the different finish_reason="stop" situations during generation:
|
||||
1. One of the provided stop strings
|
||||
2. One of the provided stop tokens
|
||||
@@ -11,7 +13,7 @@ import transformers
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODEL = "facebook/opt-350m"
|
||||
MODEL = "distilbert/distilgpt2"
|
||||
STOP_STR = "."
|
||||
SEED = 42
|
||||
MAX_TOKENS = 1024
|
||||
@@ -19,41 +21,49 @@ MAX_TOKENS = 1024
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_model(vllm_runner):
|
||||
vllm_model = vllm_runner(MODEL)
|
||||
yield vllm_model
|
||||
del vllm_model
|
||||
with vllm_runner(MODEL) as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
|
||||
def test_stop_reason(vllm_model, example_prompts):
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
|
||||
stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
|
||||
llm = vllm_model.model
|
||||
llm = vllm_model.llm
|
||||
|
||||
# test stop token
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_token_ids=[stop_token_id]))
|
||||
outputs = llm.generate(
|
||||
example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
ignore_eos=True,
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_token_ids=[stop_token_id],
|
||||
),
|
||||
)
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "stop"
|
||||
assert output.stop_reason == stop_token_id
|
||||
|
||||
# test stop string
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
seed=SEED, max_tokens=MAX_TOKENS, stop="."))
|
||||
outputs = llm.generate(
|
||||
example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="."
|
||||
),
|
||||
)
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "stop"
|
||||
assert output.stop_reason == STOP_STR
|
||||
|
||||
# test EOS token
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
seed=SEED, max_tokens=MAX_TOKENS))
|
||||
outputs = llm.generate(
|
||||
example_prompts,
|
||||
sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS),
|
||||
)
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "length" or (
|
||||
output.finish_reason == "stop" and output.stop_reason is None)
|
||||
output.finish_reason == "stop" and output.stop_reason is None
|
||||
)
|
||||
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import BaseIncrementalDetokenizer
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def include_stop_str_in_output(request):
|
||||
return request.param
|
||||
|
||||
|
||||
class _DummyDetokenizer(BaseIncrementalDetokenizer):
|
||||
def __init__(self, request: EngineCoreRequest):
|
||||
super().__init__(request)
|
||||
|
||||
def decode_next(self, next_token_id: int) -> str:
|
||||
# Map token id to single ASCII character for deterministic testing.
|
||||
return chr(next_token_id)
|
||||
|
||||
|
||||
def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0):
|
||||
params = SamplingParams(
|
||||
stop=stop,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
min_tokens=min_tokens,
|
||||
)
|
||||
# Keep other fields minimal for unit test purposes.
|
||||
req = EngineCoreRequest(
|
||||
request_id="test",
|
||||
prompt_token_ids=[],
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
pooling_params=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0.0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
)
|
||||
return req
|
||||
|
||||
|
||||
def test_stop_string_while_stop_token_terminates(include_stop_str_in_output: bool):
|
||||
"""
|
||||
This test verifies that the detokenizer correctly handles the case where
|
||||
the generated token sequence contains both:
|
||||
- a stop token
|
||||
- an <eos> token
|
||||
|
||||
The detokenizer should respect the stop string and truncate the output
|
||||
accordingly.
|
||||
|
||||
Imagine the following sequence:
|
||||
- "abcdeZ" is generated, where "Z" is the <eos> token.
|
||||
- "cd" is the stop string.
|
||||
|
||||
If include_stop_str_in_output=False, the detokenizer should truncate the
|
||||
output to "ab" because the stop string "cd" is excluded.
|
||||
If include_stop_str_in_output=True, the detokenizer should include the stop
|
||||
string "cd" in the output, resulting in "abcd".
|
||||
|
||||
|
||||
This verifies the behavioral change introduced in BaseIncrementalDetokenizer
|
||||
where stop-string evaluation occurs before the early-return on
|
||||
stop_terminated.
|
||||
"""
|
||||
|
||||
# Generate text "abcdeZ" and tokenize it.
|
||||
generated_text = "abcde"
|
||||
eos_token = "Z"
|
||||
stop_string = "cd"
|
||||
generated_text = generated_text + eos_token
|
||||
token_ids = [ord(c) for c in generated_text]
|
||||
|
||||
# Create a request with the stop string and initialize the detokenizer.
|
||||
req = _make_request(
|
||||
stop=[stop_string], include_stop_str_in_output=include_stop_str_in_output
|
||||
)
|
||||
detok = _DummyDetokenizer(req)
|
||||
|
||||
# Simulate that the last token ('Z') is a stop token (stop_terminated=True).
|
||||
result = detok.update(new_token_ids=token_ids, stop_terminated=True)
|
||||
|
||||
# The update should not report a stop string
|
||||
assert result == stop_string
|
||||
|
||||
# Output text should reflect stop-string handling:
|
||||
# - include_stop_str_in_output=False => exclude "cd" => "ab"
|
||||
# - include_stop_str_in_output=True => include "cd" => "abcd"
|
||||
expected_text = "abcd" if include_stop_str_in_output else "ab"
|
||||
assert detok.output_text == expected_text
|
||||
|
||||
# The skipped final token should still be recorded in token_ids.
|
||||
assert detok.output_token_ids == token_ids
|
||||
|
||||
# get_next_output_text should return the full text when finished=True.
|
||||
# (Buffering only applies during streaming when finished=False.)
|
||||
assert detok.get_next_output_text(finished=True, delta=False) == expected_text
|
||||
120
tests/detokenizer/test_stop_strings.py
Normal file
120
tests/detokenizer/test_stop_strings.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
MODEL = "meta-llama/llama-2-7b-hf"
|
||||
MAX_TOKENS = 200
|
||||
|
||||
|
||||
def _test_stopping(
|
||||
llm: LLM,
|
||||
expected_output: str,
|
||||
expected_reason: Any,
|
||||
stop: list[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
include_in_output: bool = False,
|
||||
) -> None:
|
||||
output = llm.generate(
|
||||
"A story about vLLM:\n",
|
||||
SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
include_stop_str_in_output=include_in_output,
|
||||
),
|
||||
)[0].outputs[0]
|
||||
|
||||
assert output is not None
|
||||
assert output.text == expected_output
|
||||
assert output.stop_reason == expected_reason
|
||||
|
||||
|
||||
def _stop_basic(llm):
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".",
|
||||
)
|
||||
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".",
|
||||
)
|
||||
|
||||
|
||||
def _stop_multi_tokens(llm):
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo",
|
||||
)
|
||||
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo",
|
||||
)
|
||||
|
||||
|
||||
def _stop_partial_token(llm):
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani",
|
||||
)
|
||||
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani",
|
||||
)
|
||||
|
||||
|
||||
def _stop_token_id(llm):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013,
|
||||
)
|
||||
|
||||
_test_stopping(
|
||||
llm,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_strings():
|
||||
llm = LLM(MODEL, enforce_eager=True)
|
||||
|
||||
_stop_basic(llm)
|
||||
_stop_multi_tokens(llm)
|
||||
_stop_partial_token(llm)
|
||||
# FIXME: this does not respect include_in_output=False
|
||||
# _stop_token_id(llm)
|
||||
168
tests/distributed/conftest.py
Normal file
168
tests/distributed/conftest.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import msgspec
|
||||
import msgspec.msgpack
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory
|
||||
|
||||
from .test_events import SampleBatch
|
||||
|
||||
DP_RANK = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_port():
|
||||
"""Generate a random port number for testing"""
|
||||
return random.randint(10000, 59900)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def publisher_config(random_port, request):
|
||||
"""Create a publisher config with inproc transport"""
|
||||
how = request.param if hasattr(request, "param") else "inproc"
|
||||
|
||||
if how == "inproc":
|
||||
endpoint = f"inproc://test-{random_port}"
|
||||
replay_endpoint = endpoint + "-replay"
|
||||
else:
|
||||
endpoint = f"tcp://*:{random_port}"
|
||||
replay_endpoint = f"tcp://*:{random_port + 100}"
|
||||
|
||||
return KVEventsConfig(
|
||||
enable_kv_cache_events=True,
|
||||
publisher="zmq",
|
||||
endpoint=endpoint,
|
||||
replay_endpoint=replay_endpoint,
|
||||
buffer_steps=100,
|
||||
hwm=1000,
|
||||
topic="test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def publisher(publisher_config):
|
||||
"""Create and return a publisher instance"""
|
||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||
yield pub
|
||||
pub.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def subscriber(publisher_config):
|
||||
"""Create and return a subscriber for testing"""
|
||||
endpoint = publisher_config.endpoint
|
||||
replay_endpoint = publisher_config.replay_endpoint
|
||||
|
||||
if endpoint.startswith("tcp://*"):
|
||||
endpoint = endpoint.replace("*", "127.0.0.1")
|
||||
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
|
||||
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
|
||||
|
||||
sub = MockSubscriber(
|
||||
[endpoint],
|
||||
[replay_endpoint] if replay_endpoint else None,
|
||||
publisher_config.topic,
|
||||
)
|
||||
yield sub
|
||||
sub.close()
|
||||
|
||||
|
||||
class MockSubscriber:
|
||||
"""Helper class to receive and verify published events"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pub_endpoints: str | list[str],
|
||||
replay_endpoints: str | list[str] | None = None,
|
||||
topic: str = "",
|
||||
decode_type=SampleBatch,
|
||||
):
|
||||
self.ctx = zmq.Context.instance()
|
||||
|
||||
# Convert single endpoint to list for consistency
|
||||
if isinstance(pub_endpoints, str):
|
||||
pub_endpoints = [pub_endpoints]
|
||||
if isinstance(replay_endpoints, str):
|
||||
replay_endpoints = [replay_endpoints]
|
||||
|
||||
# Set up subscriber socket - connect to all endpoints
|
||||
self.sub = self.ctx.socket(zmq.SUB)
|
||||
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
|
||||
for endpoint in pub_endpoints:
|
||||
self.sub.connect(endpoint)
|
||||
|
||||
# Set up replay sockets if provided
|
||||
self.replay_sockets = []
|
||||
if replay_endpoints:
|
||||
for replay_endpoint in replay_endpoints:
|
||||
replay = self.ctx.socket(zmq.REQ)
|
||||
replay.connect(replay_endpoint)
|
||||
self.replay_sockets.append(replay)
|
||||
|
||||
self.topic = topic
|
||||
self.topic_bytes = topic.encode("utf-8")
|
||||
self.received_msgs: list[tuple[int, SampleBatch]] = []
|
||||
self.last_seq = -1
|
||||
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
||||
|
||||
def receive_one(self, timeout=1000) -> tuple[int, SampleBatch] | None:
|
||||
"""Receive a single message with timeout"""
|
||||
if not self.sub.poll(timeout):
|
||||
return None
|
||||
|
||||
topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
|
||||
assert topic_bytes == self.topic_bytes
|
||||
|
||||
seq = int.from_bytes(seq_bytes, "big")
|
||||
data = self.decoder.decode(payload)
|
||||
self.last_seq = seq
|
||||
self.received_msgs.append((seq, data))
|
||||
return seq, data
|
||||
|
||||
def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
|
||||
"""Request replay of messages starting from start_seq"""
|
||||
if not self.replay_sockets:
|
||||
raise ValueError("Replay sockets not initialized")
|
||||
if socket_idx >= len(self.replay_sockets):
|
||||
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||
|
||||
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
||||
|
||||
def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
||||
"""Receive replayed messages from a specific replay socket"""
|
||||
if not self.replay_sockets:
|
||||
raise ValueError("Replay sockets not initialized")
|
||||
if socket_idx >= len(self.replay_sockets):
|
||||
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||
|
||||
replay_socket = self.replay_sockets[socket_idx]
|
||||
replayed: list[tuple[int, SampleBatch]] = []
|
||||
while True:
|
||||
try:
|
||||
if not replay_socket.poll(1000):
|
||||
break
|
||||
|
||||
frames = replay_socket.recv_multipart()
|
||||
if not frames or not frames[-1]:
|
||||
# End of replay marker
|
||||
break
|
||||
|
||||
seq_bytes, payload = frames
|
||||
seq = int.from_bytes(seq_bytes, "big")
|
||||
data = self.decoder.decode(payload)
|
||||
replayed.append((seq, data))
|
||||
except zmq.ZMQError as _:
|
||||
break
|
||||
|
||||
return replayed
|
||||
|
||||
def close(self):
|
||||
"""Clean up resources"""
|
||||
self.sub.close()
|
||||
for replay in self.replay_sockets:
|
||||
replay.close()
|
||||
49
tests/distributed/eplb_utils.py
Normal file
49
tests/distributed/eplb_utils.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def distributed_run(fn, world_size, *args):
|
||||
number_of_processes = world_size
|
||||
processes: list[mp.Process] = []
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env["RANK"] = str(i)
|
||||
env["LOCAL_RANK"] = str(i)
|
||||
env["WORLD_SIZE"] = str(number_of_processes)
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = mp.Process(target=fn, args=(env, world_size, *args))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
|
||||
|
||||
def set_env_vars_and_device(env: dict[str, str]) -> None:
|
||||
update_environment_variables(env)
|
||||
local_rank = os.environ["LOCAL_RANK"]
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_distributed_environment()
|
||||
|
||||
# Ensure each worker process has the same random seed
|
||||
random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
@@ -1,59 +0,0 @@
|
||||
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
|
||||
vLLM will allocate all the available memory, so we need to run the tests one
|
||||
by one. The solution is to pass arguments (model name) by environment
|
||||
variables.
|
||||
Run:
|
||||
```sh
|
||||
TEST_DIST_MODEL=facebook/opt-125m pytest \
|
||||
test_basic_distributed_correctness.py
|
||||
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
|
||||
test_basic_distributed_correctness.py
|
||||
```
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
MODELS = [
|
||||
os.environ["TEST_DIST_MODEL"],
|
||||
]
|
||||
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
enforce_eager = False
|
||||
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
|
||||
if backend_by_env_var == "FLASHINFER":
|
||||
enforce_eager = True
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=enforce_eager)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
64
tests/distributed/test_ca_buffer_sharing.py
Normal file
64
tests/distributed/test_ca_buffer_sharing.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# can only run on machines with p2p access across GPUs
|
||||
# can only run with torchrun:
|
||||
# torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py
|
||||
|
||||
import ctypes
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
|
||||
CustomAllreduce,
|
||||
)
|
||||
|
||||
# create a cpu process group for communicating metadata (ipc handle)
|
||||
dist.init_process_group(backend="gloo")
|
||||
rank = local_rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# every process sets its own device (differently)
|
||||
lib = CudaRTLibrary()
|
||||
lib.cudaSetDevice(rank)
|
||||
|
||||
buffer_size_in_bytes = 1024
|
||||
byte_value = 2 # the value we write to the buffer for verification
|
||||
|
||||
pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)
|
||||
|
||||
print(f"Rank {rank} has pointers {pointers}")
|
||||
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if rank == 0:
|
||||
# the first rank tries to write to all buffers
|
||||
for p in pointers:
|
||||
pointer = ctypes.c_void_p(p)
|
||||
lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)
|
||||
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
host_data = (ctypes.c_char * buffer_size_in_bytes)()
|
||||
|
||||
# all ranks read from all buffers, and check if the data is correct
|
||||
for p in pointers:
|
||||
pointer = ctypes.c_void_p(p)
|
||||
lib.cudaMemcpy(host_data, pointer, buffer_size_in_bytes)
|
||||
for i in range(buffer_size_in_bytes):
|
||||
assert ord(host_data[i]) == byte_value, (
|
||||
f"Rank {rank} failed"
|
||||
f" to verify buffer {p}. Expected {byte_value}, "
|
||||
f"got {ord(host_data[i])}"
|
||||
)
|
||||
|
||||
print(f"Rank {rank} verified all buffers")
|
||||
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
CustomAllreduce.free_shared_buffer(pointers)
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
|
||||
vLLM will allocate all the available memory, so we need to run the tests one
|
||||
by one. The solution is to pass arguments (model name) by environment
|
||||
variables.
|
||||
|
||||
Run:
|
||||
```sh
|
||||
TEST_DIST_MODEL=facebook/opt-125m pytest \
|
||||
test_chunked_prefill_distributed.py
|
||||
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
|
||||
test_chunked_prefill_distributed.py
|
||||
```
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
MODELS = [
|
||||
os.environ["TEST_DIST_MODEL"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
) -> None:
|
||||
# Add a chunked prefill config.
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
assert chunked_prefill_token_size != -1
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=max_num_seqs,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
@@ -1,53 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test the communication operators.
|
||||
|
||||
Run `pytest tests/distributed/test_comm_ops.py`.
|
||||
"""
|
||||
import os
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
from vllm.distributed import (
|
||||
broadcast_tensor_dict,
|
||||
get_pp_group,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
|
||||
from ..utils import (
|
||||
init_test_distributed_environment,
|
||||
multi_gpu_test,
|
||||
multi_process_parallel,
|
||||
)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
def all_reduce_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
):
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||
distributed_init_port)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
num_elements = 8
|
||||
all_tensors = [
|
||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
|
||||
(r + 1) for r in range(tensor_parallel_size)
|
||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
|
||||
for r in range(tp_size)
|
||||
]
|
||||
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||
t = all_tensors[rank]
|
||||
t = all_tensors[rank % tp_size]
|
||||
t = tensor_model_parallel_all_reduce(t)
|
||||
assert torch.allclose(t, expected)
|
||||
torch.testing.assert_close(t, expected)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
def reduce_scatter_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
):
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||
distributed_init_port)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
|
||||
num_elements = 8
|
||||
all_tensors = [
|
||||
torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
|
||||
for r in range(tp_size)
|
||||
]
|
||||
|
||||
index = rank % tp_size
|
||||
partition_size = num_elements // tp_size
|
||||
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||
expected = all_reduce[index * partition_size : (index + 1) * partition_size]
|
||||
t = all_tensors[index]
|
||||
t = tensor_model_parallel_reduce_scatter(t, 0)
|
||||
torch.testing.assert_close(t, expected)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def all_gather_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
):
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
num_dimensions = 3
|
||||
tensor_size = list(range(2, num_dimensions + 2))
|
||||
total_size = 1
|
||||
@@ -55,56 +107,169 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
|
||||
total_size *= s
|
||||
for all_gather_dimension in range(num_dimensions):
|
||||
all_tensors = [
|
||||
torch.arange(total_size, dtype=torch.float32,
|
||||
device="cuda").reshape(tensor_size) * (r + 1)
|
||||
for r in range(tensor_parallel_size)
|
||||
torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(
|
||||
tensor_size
|
||||
)
|
||||
* (r + 1)
|
||||
for r in range(tp_size)
|
||||
]
|
||||
expected = torch.cat(all_tensors, dim=all_gather_dimension)
|
||||
t = all_tensors[rank]
|
||||
t = all_tensors[rank % tp_size]
|
||||
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
|
||||
assert torch.allclose(t, expected)
|
||||
torch.testing.assert_close(t, expected)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
def broadcast_tensor_dict_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
):
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||
distributed_init_port)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
test_dict = {
|
||||
# device tensor
|
||||
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||
"b": torch.arange(16, dtype=torch.int8, device="cuda"),
|
||||
# CPU tensor
|
||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||
"c": "test",
|
||||
"d": [1, 2, 3],
|
||||
"e": {
|
||||
"a": 1,
|
||||
"b": 2
|
||||
},
|
||||
"e": {"a": 1, "b": 2},
|
||||
# empty tensor
|
||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||
}
|
||||
|
||||
if rank == 0:
|
||||
if (rank % tp_size) == 0:
|
||||
broadcast_tensor_dict(test_dict, src=0)
|
||||
else:
|
||||
recv_dict = broadcast_tensor_dict(src=0)
|
||||
assert len(recv_dict) == len(test_dict)
|
||||
assert torch.allclose(recv_dict["a"], test_dict["a"])
|
||||
assert torch.allclose(recv_dict["b"], test_dict["b"])
|
||||
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
|
||||
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
|
||||
assert recv_dict["c"] == test_dict["c"]
|
||||
assert recv_dict["d"] == test_dict["d"]
|
||||
assert recv_dict["e"] == test_dict["e"]
|
||||
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
||||
@pytest.mark.parametrize("test_target", [
|
||||
all_reduce_test_worker, all_gather_test_worker,
|
||||
broadcast_tensor_dict_test_worker
|
||||
])
|
||||
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
||||
multi_process_tensor_parallel(tensor_parallel_size, test_target)
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_tensor_dict_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
):
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
|
||||
test_dict = {
|
||||
# device tensor
|
||||
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||
# CPU tensor
|
||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||
"c": "test",
|
||||
"d": [1, 2, 3],
|
||||
"e": {"a": 1, "b": 2},
|
||||
# empty tensor
|
||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||
}
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
recv_dict = get_pp_group().recv_tensor_dict()
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
get_pp_group().send_tensor_dict(test_dict)
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
assert len(recv_dict) == len(test_dict)
|
||||
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
|
||||
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
|
||||
assert recv_dict["c"] == test_dict["c"]
|
||||
assert recv_dict["d"] == test_dict["d"]
|
||||
assert recv_dict["e"] == test_dict["e"]
|
||||
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
):
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
|
||||
size = 64
|
||||
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
get_pp_group().send(test_tensor)
|
||||
|
||||
if not get_pp_group().is_first_rank:
|
||||
torch.testing.assert_close(test_tensor, recv_tensor)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"test_target",
|
||||
[all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker],
|
||||
)
|
||||
def test_multi_process_tensor_parallel(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size: int,
|
||||
test_target: Callable[..., Any],
|
||||
):
|
||||
multi_process_parallel(monkeypatch, tp_size, 1, test_target)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("pp_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]
|
||||
)
|
||||
def test_multi_process_pipeline_parallel(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
pp_size: int,
|
||||
test_target: Callable[..., Any],
|
||||
):
|
||||
multi_process_parallel(monkeypatch, 1, pp_size, test_target)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("pp_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"test_target",
|
||||
[
|
||||
send_recv_test_worker,
|
||||
send_recv_tensor_dict_test_worker,
|
||||
all_reduce_test_worker,
|
||||
all_gather_test_worker,
|
||||
broadcast_tensor_dict_test_worker,
|
||||
],
|
||||
)
|
||||
def test_multi_process_tensor_parallel_pipeline_parallel(
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
test_target: Callable[..., Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)
|
||||
|
||||
296
tests/distributed/test_context_parallel.py
Normal file
296
tests/distributed/test_context_parallel.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
||||
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
|
||||
important to set the distributed backend to "mp" to avoid Ray scheduling
|
||||
all workers in a node other than the head node, which can cause the test
|
||||
to fail.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
|
||||
from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
|
||||
from vllm.config.model import RunnerOption
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
|
||||
logger = init_logger("test_context_parallel")
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
CP_TEST_MODELS = [
|
||||
# TODO support other models
|
||||
# [LANGUAGE GENERATION]
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
||||
]
|
||||
|
||||
# GSM8K eval configuration
|
||||
NUM_QUESTIONS = 256 # Fast eval for CI
|
||||
NUM_SHOTS = 5 # Few-shot examples
|
||||
# tp accuracy with 2% buffer
|
||||
MIN_ACCURACY = {
|
||||
# .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
|
||||
# .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
|
||||
"Qwen/Qwen2.5-1.5B-Instruct": 0.52,
|
||||
}
|
||||
|
||||
|
||||
class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
dcp_size: int
|
||||
cp_kv_cache_interleave_size: int
|
||||
eager_mode: bool
|
||||
chunked_prefill: bool
|
||||
|
||||
|
||||
class CPTestOptions(NamedTuple):
|
||||
multi_node_only: bool
|
||||
attn_backend: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPTestSettings:
|
||||
parallel_setups: list[ParallelSetup]
|
||||
distributed_backends: list[str]
|
||||
runner: RunnerOption
|
||||
test_options: CPTestOptions
|
||||
|
||||
@staticmethod
|
||||
def detailed(
|
||||
*,
|
||||
tp_base: int = 4,
|
||||
pp_base: int = 1,
|
||||
dcp_multipliers: list[float] | None = None,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
multi_node_only: bool = False,
|
||||
runner: RunnerOption = "auto",
|
||||
attn_backend: str | None = None,
|
||||
):
|
||||
parallel_setups = []
|
||||
if dcp_multipliers is None:
|
||||
dcp_multipliers = [
|
||||
0.5,
|
||||
]
|
||||
for eager_mode_val in [False]:
|
||||
for pp_multiplier in [1]:
|
||||
for dcp_multiplier in dcp_multipliers:
|
||||
for chunked_prefill_val in [True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
dcp_size=int(dcp_multiplier * tp_base),
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
)
|
||||
return CPTestSettings(
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp"],
|
||||
runner=runner,
|
||||
test_options=CPTestOptions(
|
||||
multi_node_only=multi_node_only,
|
||||
attn_backend=attn_backend,
|
||||
),
|
||||
)
|
||||
|
||||
def iter_params(self, model_id: str):
|
||||
opts = self.test_options
|
||||
|
||||
for parallel_setup in self.parallel_setups:
|
||||
for backend in self.distributed_backends:
|
||||
yield (
|
||||
model_id,
|
||||
parallel_setup,
|
||||
backend,
|
||||
self.runner,
|
||||
opts,
|
||||
)
|
||||
|
||||
|
||||
CP_TEXT_GENERATION_MODELS = {
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||
CPTestSettings.detailed(dcp_multipliers=[1]),
|
||||
CPTestSettings.detailed(
|
||||
dcp_multipliers=[0.5],
|
||||
cp_kv_cache_interleave_size=64,
|
||||
attn_backend="FLASHMLA",
|
||||
),
|
||||
],
|
||||
"Qwen/Qwen2.5-1.5B-Instruct": [
|
||||
CPTestSettings.detailed(
|
||||
cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
|
||||
),
|
||||
CPTestSettings.detailed(
|
||||
cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _test_cp_gsm8k(
|
||||
model_id: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
runner: RunnerOption,
|
||||
test_options: CPTestOptions,
|
||||
num_gpus_available: int,
|
||||
*,
|
||||
method: Literal["generate"],
|
||||
is_multimodal: bool,
|
||||
):
|
||||
(
|
||||
tp_size,
|
||||
pp_size,
|
||||
dcp_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
eager_mode,
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
|
||||
multi_node_only, attn_backend = test_options
|
||||
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
trust_remote_code = model_info.trust_remote_code
|
||||
tokenizer_mode = model_info.tokenizer_mode
|
||||
hf_overrides = model_info.hf_overrides
|
||||
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
if num_gpus_available < tp_size * pp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||
pytest.skip(
|
||||
"Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend"
|
||||
)
|
||||
if multi_node_only and not VLLM_MULTI_NODE:
|
||||
pytest.skip("Not in multi-node setting")
|
||||
|
||||
server_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--max-num-seqs",
|
||||
"64",
|
||||
]
|
||||
if chunked_prefill:
|
||||
server_args.append("--enable-chunked-prefill")
|
||||
if eager_mode:
|
||||
server_args.append("--enforce-eager")
|
||||
if runner != "auto":
|
||||
server_args.extend(["--runner", runner])
|
||||
if trust_remote_code:
|
||||
server_args.append("--trust-remote-code")
|
||||
if tokenizer_mode:
|
||||
server_args.extend(["--tokenizer-mode", tokenizer_mode])
|
||||
if hf_overrides:
|
||||
server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||
|
||||
server_args.extend(
|
||||
[
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
str(pp_size),
|
||||
"--decode-context-parallel-size",
|
||||
str(dcp_size),
|
||||
"--dcp-kv-cache-interleave-size",
|
||||
str(cp_kv_cache_interleave_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
]
|
||||
)
|
||||
|
||||
server_env = {}
|
||||
if attn_backend:
|
||||
server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_id,
|
||||
server_args,
|
||||
env_dict=server_env,
|
||||
max_wait_seconds=720,
|
||||
) as remote_server:
|
||||
host = f"http://{remote_server.host}"
|
||||
port = remote_server.port
|
||||
|
||||
# Run GSM8K evaluation
|
||||
results = evaluate_gsm8k(
|
||||
num_questions=NUM_QUESTIONS,
|
||||
num_shots=NUM_SHOTS,
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
||||
# Validate accuracy is reasonable
|
||||
accuracy = results["accuracy"]
|
||||
min_accuracy = MIN_ACCURACY[model_id]
|
||||
assert accuracy >= min_accuracy, (
|
||||
f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"model_id",
|
||||
"parallel_setup",
|
||||
"distributed_backend",
|
||||
"runner",
|
||||
"test_options",
|
||||
),
|
||||
[
|
||||
params
|
||||
for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
|
||||
for setting in settings
|
||||
for params in setting.iter_params(model_id)
|
||||
if model_id in CP_TEST_MODELS
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_cp_generation(
|
||||
model_id: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
runner: RunnerOption,
|
||||
test_options: CPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
if (
|
||||
model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
and torch.cuda.get_device_capability() < (9, 0)
|
||||
):
|
||||
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
|
||||
if (
|
||||
model_id == "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
and torch.cuda.get_device_capability() != (9, 0)
|
||||
):
|
||||
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
|
||||
|
||||
_test_cp_gsm8k(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=False,
|
||||
)
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
@@ -6,10 +8,14 @@ import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.device_communicators import custom_all_reduce
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||
|
||||
from ..utils import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_test_distributed_environment,
|
||||
multi_process_parallel,
|
||||
)
|
||||
|
||||
random.seed(42)
|
||||
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
||||
@@ -18,67 +24,109 @@ for i, v in enumerate(test_sizes):
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def graph_allreduce(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
distributed_init_port)
|
||||
def graph_allreduce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size,
|
||||
pp_size,
|
||||
rank,
|
||||
distributed_init_port,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
group = get_tp_group().device_group
|
||||
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
with custom_all_reduce.capture():
|
||||
# use integers so result matches NCCL exactly
|
||||
inp1 = torch.randint(1,
|
||||
16, (sz, ),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device())
|
||||
inp2 = torch.randint(1,
|
||||
16, (sz, ),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device())
|
||||
torch.cuda.synchronize()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
# the input buffer is immediately modified to test
|
||||
# synchronization
|
||||
dist.all_reduce(inp1)
|
||||
out2 = tensor_model_parallel_all_reduce(inp2)
|
||||
dist.all_reduce(inp2)
|
||||
graph.replay()
|
||||
assert torch.allclose(out1, inp1)
|
||||
assert torch.allclose(out2, inp2)
|
||||
# A small all_reduce for warmup.
|
||||
# this is needed because device communicators might be created lazily
|
||||
# (e.g. NCCL). This will ensure that the communicator is initialized
|
||||
# before any communication happens, so that this group can be used for
|
||||
# graph capture immediately.
|
||||
data = torch.zeros(1)
|
||||
data = data.to(device=device)
|
||||
torch.distributed.all_reduce(data, group=group)
|
||||
torch.cuda.synchronize()
|
||||
del data
|
||||
|
||||
# we use the first group to communicate once
|
||||
# and the second group to communicate twice
|
||||
# and so on
|
||||
# this is used to demonstrate that each group can
|
||||
# communicate independently
|
||||
num_communication = rank // tp_size + 1
|
||||
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
with graph_capture(device=device) as graph_capture_context:
|
||||
# use integers so result matches NCCL exactly
|
||||
inp1 = torch.randint(
|
||||
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
inp2 = torch.randint(
|
||||
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||
for i in range(num_communication):
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
# the input buffer is immediately modified to test
|
||||
# synchronization
|
||||
dist.all_reduce(inp1, group=group)
|
||||
out2 = tensor_model_parallel_all_reduce(inp2)
|
||||
dist.all_reduce(inp2, group=group)
|
||||
graph.replay()
|
||||
torch.testing.assert_close(out1, inp1)
|
||||
torch.testing.assert_close(out2, inp2)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def eager_allreduce(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
distributed_init_port)
|
||||
def eager_allreduce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size,
|
||||
pp_size,
|
||||
rank,
|
||||
distributed_init_port,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
|
||||
sz = 1024
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
fa = custom_all_reduce.get_handle()
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = fa.all_reduce_unreg(inp)
|
||||
assert torch.allclose(out, inp * world_size)
|
||||
# we use the first group to communicate once
|
||||
# and the second group to communicate twice
|
||||
# and so on
|
||||
# this is used to demonstrate that each group can
|
||||
# communicate independently
|
||||
num_communication = rank // tp_size + 1
|
||||
sz = 1024
|
||||
fa = get_tp_group().device_communicator.ca_comm
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = inp
|
||||
for _ in range(num_communication):
|
||||
out = fa.all_reduce(out, registered=False)
|
||||
torch.testing.assert_close(out, inp * (tp_size**num_communication))
|
||||
|
||||
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
|
||||
out = fa.all_reduce_unreg(inp)
|
||||
assert torch.allclose(out, inp * world_size)
|
||||
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
|
||||
out = inp
|
||||
for _ in range(num_communication):
|
||||
out = fa.all_reduce(out, registered=False)
|
||||
torch.testing.assert_close(out, inp * (tp_size**num_communication))
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
|
||||
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
|
||||
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
||||
multi_process_tensor_parallel(tensor_parallel_size, test_target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
multi_process_tensor_parallel(2, graph_allreduce)
|
||||
def test_custom_allreduce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size,
|
||||
pipeline_parallel_size,
|
||||
test_target,
|
||||
):
|
||||
world_size = tp_size * pipeline_parallel_size
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||
|
||||
8
tests/distributed/test_distributed_oot.py
Normal file
8
tests/distributed/test_distributed_oot.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server
|
||||
|
||||
|
||||
def test_distributed_oot(dummy_opt_path: str):
|
||||
run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)
|
||||
312
tests/distributed/test_eplb_algo.py
Normal file
312
tests/distributed/test_eplb_algo.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
|
||||
|
||||
|
||||
def test_basic_rebalance():
|
||||
"""Test basic rebalancing functionality"""
|
||||
# Example from https://github.com/deepseek-ai/eplb
|
||||
weight = torch.tensor(
|
||||
[
|
||||
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
|
||||
]
|
||||
)
|
||||
|
||||
num_layers = weight.shape[0]
|
||||
num_replicas = 16
|
||||
num_groups = 4
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify output shapes
|
||||
assert phy2log.shape == (
|
||||
2,
|
||||
16,
|
||||
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
|
||||
assert log2phy.shape[0] == 2, (
|
||||
f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
|
||||
)
|
||||
assert log2phy.shape[1] == 12, (
|
||||
f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
|
||||
)
|
||||
assert logcnt.shape == (
|
||||
2,
|
||||
12,
|
||||
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
|
||||
|
||||
# Verify physical to logical expert mapping range is correct
|
||||
assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), (
|
||||
"Physical to logical mapping should be in range [0, 12)"
|
||||
)
|
||||
|
||||
# Verify expert count reasonableness
|
||||
assert torch.all(logcnt >= 1), "Each logical expert should have at least 1 replica"
|
||||
assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, (
|
||||
f"Total replicas should be {num_replicas * num_layers}"
|
||||
)
|
||||
|
||||
# Verify expected output
|
||||
expected_phy2log = torch.tensor(
|
||||
[
|
||||
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
||||
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
|
||||
]
|
||||
)
|
||||
assert torch.all(phy2log == expected_phy2log)
|
||||
|
||||
expected_logcnt = torch.tensor(
|
||||
[[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]
|
||||
)
|
||||
assert torch.all(logcnt == expected_logcnt)
|
||||
|
||||
|
||||
def test_single_gpu_case():
|
||||
"""Test single GPU case"""
|
||||
weight = torch.tensor([[10, 20, 30, 40]])
|
||||
num_replicas = 4
|
||||
num_groups = 1
|
||||
num_nodes = 1
|
||||
num_gpus = 1
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 4)
|
||||
assert log2phy.shape[0] == 1
|
||||
assert log2phy.shape[1] == 4
|
||||
assert logcnt.shape == (1, 4)
|
||||
|
||||
# Verify all logical experts are mapped
|
||||
assert set(phy2log[0].tolist()) == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_equal_weights():
|
||||
"""Test case with equal weights"""
|
||||
weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]])
|
||||
num_replicas = 8
|
||||
num_groups = 2
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 8)
|
||||
assert logcnt.shape == (1, 8)
|
||||
|
||||
# With equal weights, each expert should have exactly one replica
|
||||
assert torch.all(logcnt == 1), (
|
||||
"With equal weights and no replication, "
|
||||
"each expert should have exactly 1 replica"
|
||||
)
|
||||
|
||||
|
||||
def test_extreme_weight_imbalance():
|
||||
"""Test extreme weight imbalance case"""
|
||||
weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]])
|
||||
num_replicas = 12
|
||||
num_groups = 2
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 12)
|
||||
assert logcnt.shape == (1, 8)
|
||||
|
||||
# Expert with highest weight (index 0) should have more replicas
|
||||
assert logcnt[0, 0] > logcnt[0, 1], (
|
||||
"Expert with highest weight should have more replicas"
|
||||
)
|
||||
|
||||
|
||||
def test_multiple_layers():
|
||||
"""Test multiple layers case"""
|
||||
weight = torch.tensor(
|
||||
[
|
||||
[10, 20, 30, 40, 50, 60], # First layer
|
||||
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
|
||||
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
|
||||
]
|
||||
)
|
||||
num_replicas = 8
|
||||
num_groups = 2
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (3, 8)
|
||||
assert logcnt.shape == (3, 6)
|
||||
|
||||
# Verify expert allocation is reasonable for each layer
|
||||
for layer in range(3):
|
||||
assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), (
|
||||
f"Layer {layer} physical to logical mappingshould be in range [0, 6)"
|
||||
)
|
||||
assert torch.sum(logcnt[layer]) == num_replicas, (
|
||||
f"Layer {layer} total replicas should be {num_replicas}"
|
||||
)
|
||||
|
||||
|
||||
def test_parameter_validation():
    """Check how rebalance_experts reacts to awkward parameter combinations."""
    four_expert_weight = torch.tensor([[10, 20, 30, 40]])

    # num_groups=3 is not divisible by num_nodes=2: no error is expected,
    # because the implementation falls back to global load balancing.
    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
        four_expert_weight, 8, 3, 2, 4
    )
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 4)

    # num_physical_experts=7 is not divisible by num_gpus=4; this is a hard
    # constraint and must trip an assertion inside the policy.
    with pytest.raises(AssertionError):
        DefaultEplbPolicy.rebalance_experts(four_expert_weight, 7, 2, 2, 4)
|
||||
|
||||
|
||||
def test_small_scale_hierarchical():
    """Small hierarchical setup: 4 groups of 2 experts over 2 nodes / 4 GPUs."""
    weight = torch.tensor(
        [
            [100, 50, 200, 75, 150, 25, 300, 80],  # 8 experts
        ]
    )

    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
        weight,
        12,  # num_replicas
        4,  # num_groups (2 experts each)
        2,  # num_nodes
        4,  # num_gpus
    )

    # Basic structural invariants.
    assert phy2log.shape == (1, 12)
    assert logcnt.shape == (1, 8)
    assert torch.sum(logcnt) == 12
    assert torch.all(logcnt >= 1)

    # The hottest expert must be replicated at least twice.
    hottest = torch.argmax(weight[0])
    assert logcnt[0, hottest] >= 2, (
        "Highest weight expert should have multiple replicas"
    )
|
||||
|
||||
|
||||
def test_global_load_balance_fallback():
    """num_groups not divisible by num_nodes -> global balancing fallback."""
    weight = torch.tensor([[10, 20, 30, 40, 50, 60]])
    replicas = 8

    # 3 groups cannot be split evenly across 2 nodes, so the policy should
    # silently fall back to its global load-balancing strategy.
    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
        weight, replicas, 3, 2, 4
    )

    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 6)
    assert torch.sum(logcnt) == replicas
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_device_compatibility(device):
    """The policy should accept inputs living on either CPU or GPU."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    weight = torch.tensor([[10, 20, 30, 40]], device=device)

    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
        weight,
        6,  # num_replicas
        2,  # num_groups
        1,  # num_nodes
        2,  # num_gpus
    )

    # The implementation converts to CPU internally; the call must still
    # succeed and produce correctly shaped outputs for any input device.
    assert phy2log.shape == (1, 6)
    assert logcnt.shape == (1, 4)
|
||||
|
||||
|
||||
def test_additional_cases():
    """Exercise extra edge cases and parameter combinations."""

    # Case 1: a larger 16-expert setup spread over 4 nodes and 8 GPUs.
    big_weight = torch.tensor(
        [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
    )
    phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
        big_weight, 24, 8, 4, 8
    )

    assert phy2log1.shape == (1, 24)
    assert logcnt1.shape == (1, 16)
    assert torch.sum(logcnt1) == 24

    # Case 2: mirrored monotone weight patterns across two layers.
    mirrored_weight = torch.tensor(
        [
            [200, 150, 100, 50, 25, 12],  # Decreasing weights
            [12, 25, 50, 100, 150, 200],  # Increasing weights
        ]
    )
    phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
        mirrored_weight, 10, 3, 1, 2
    )

    assert phy2log2.shape == (2, 10)
    assert logcnt2.shape == (2, 6)

    # Each layer's heaviest expert should be replicated at least twice.
    for layer_idx in range(2):
        heaviest = torch.argmax(mirrored_weight[layer_idx])
        assert logcnt2[layer_idx, heaviest] >= 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual run: rebalance a two-layer workload, dump the resulting
    # physical-to-logical mapping, then run the basic sanity test.
    demo_weight = torch.tensor(
        [
            [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
            [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
        ]
    )

    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
        demo_weight,
        16,  # num_replicas
        4,  # num_groups
        2,  # num_nodes
        8,  # num_gpus
    )
    print(phy2log)

    test_basic_rebalance()
|
||||
607
tests/distributed/test_eplb_execute.py
Normal file
607
tests/distributed/test_eplb_execute.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.distributed.eplb.rebalance_execute import (
|
||||
move_from_buffer,
|
||||
rearrange_expert_weights_inplace,
|
||||
transfer_layer,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_tp_group,
|
||||
)
|
||||
|
||||
from .eplb_utils import distributed_run, set_env_vars_and_device
|
||||
|
||||
|
||||
def create_expert_indices_with_redundancy(
    num_layers: int,
    num_logical_experts: int,
    total_physical_experts: int,
    redundancy_config: list[int],  # redundancy for each logical expert
) -> torch.Tensor:
    """
    Create expert indices with redundancy.

    Args:
        num_layers: number of layers
        num_logical_experts: number of logical experts
        total_physical_experts: total number of physical experts
        redundancy_config: redundancy for each logical expert

    Returns:
        indices: Shape (num_layers, total_physical_experts)
    """
    assert sum(redundancy_config) == total_physical_experts
    assert len(redundancy_config) == num_logical_experts

    # One row of logical ids, each id repeated by its replica count.
    row = torch.repeat_interleave(
        torch.arange(num_logical_experts, dtype=torch.long),
        torch.tensor(redundancy_config),
    )
    indices = row.unsqueeze(0).repeat(num_layers, 1)

    # Independently shuffle each layer's physical slots.
    for layer in range(num_layers):
        indices[layer] = indices[layer][torch.randperm(total_physical_experts)]

    return indices
|
||||
|
||||
|
||||
def create_expert_weights(
    num_layers: int,
    num_local_experts: int,
    hidden_sizes: list[int],
    rank: int,
    device: torch.device,
    physical_to_logical_mapping: torch.Tensor,
) -> list[list[torch.Tensor]]:
    """
    Create fake expert weights tensor for testing.

    Values are generated with `arange` from a base derived from the logical
    expert ID, so every replica of the same logical expert carries identical
    weights and mismatches are easy to spot.

    Args:
        physical_to_logical_mapping: Shape (num_layers, num_local_experts)
            mapping[layer, physical_pos] = logical_expert_id
    """
    all_layers: list[list[torch.Tensor]] = []

    for layer_idx in range(num_layers):
        per_weight: list[torch.Tensor] = []
        for w_idx, size in enumerate(hidden_sizes):
            tensor = torch.zeros(
                num_local_experts, size, device=device, dtype=torch.float32
            )

            for local_idx in range(num_local_experts):
                # This rank owns global physical slots
                # [rank * num_local_experts, (rank + 1) * num_local_experts).
                global_pos = rank * num_local_experts + local_idx
                logical_id = int(
                    physical_to_logical_mapping[layer_idx, global_pos].item()
                )

                # Base value encodes (logical expert, layer, weight matrix)
                # so replicas of one logical expert are identical while
                # everything else is distinct.
                start = logical_id * 1000 + layer_idx * 100 + w_idx * 10
                tensor[local_idx] = torch.arange(
                    start,
                    start + size,
                    device=device,
                    dtype=torch.float32,
                )

            per_weight.append(tensor)
        all_layers.append(per_weight)

    return all_layers
|
||||
|
||||
|
||||
def create_redundancy_config(
    num_logical_experts: int,
    num_physical_experts: int,
) -> list[int]:
    """Create a redundancy configuration.

    Every logical expert gets at least one physical slot; leftover slots are
    handed out one at a time to randomly chosen experts.
    """
    config = [1 for _ in range(num_logical_experts)]
    extra_slots = num_physical_experts - num_logical_experts
    for _ in range(extra_slots):
        config[random.choice(range(num_logical_experts))] += 1
    return config
|
||||
|
||||
|
||||
def verify_expert_weights_after_shuffle(
    expert_weights: list[list[torch.Tensor]],
    new_indices: torch.Tensor,
    hidden_sizes: list[int],
    ep_rank: int,
    num_local_experts: int,
):
    """Verify the weights after shuffling are correct.

    For every local physical expert, recompute the expected `arange` pattern
    from the logical expert id in `new_indices` and compare element-wise.
    """
    for layer, layer_weights in enumerate(expert_weights):
        for weight_idx, hidden_size in enumerate(hidden_sizes):
            weight_tensor = layer_weights[weight_idx]

            for local_expert in range(num_local_experts):
                # Global physical slot owned by this rank.
                global_pos = ep_rank * num_local_experts + local_expert
                expected_logical_expert = new_indices[layer, global_pos].item()

                # Rebuild the expected pattern from the generation scheme
                # used by create_expert_weights.
                actual_weights = weight_tensor[local_expert]
                base = (
                    expected_logical_expert * 1000 + layer * 100 + weight_idx * 10
                )
                expected_weights = torch.arange(
                    base,
                    base + hidden_size,
                    device=actual_weights.device,
                    dtype=actual_weights.dtype,
                )

                torch.testing.assert_close(
                    actual_weights,
                    expected_weights,
                    msg=f"Layer {layer}, weight {weight_idx},"
                    f"local expert {local_expert}: "
                    f"weights do not match. "
                    f"Expected logical expert {expected_logical_expert}",
                )
|
||||
|
||||
|
||||
def verify_redundant_experts_have_same_weights(
    expert_weights: list[list[torch.Tensor]],
    indices: torch.Tensor,
    hidden_sizes: list[int],
    world_size: int,
    num_local_experts: int,
):
    """
    Verify that all replicas of the same logical expert have the same weights.

    Must be called collectively on every rank: it performs an all_gather on
    the default process group to assemble the full weight table, then checks
    replica consistency locally.

    Args:
        expert_weights: per-layer, per-matrix local weights; each tensor has
            shape [num_local_experts, hidden_size].
        indices: Shape (num_layers, world_size * num_local_experts);
            indices[layer, physical_pos] = logical_expert_id.
        hidden_sizes: hidden size of each weight matrix.
        world_size: number of ranks participating in the gather.
        num_local_experts: physical experts hosted per rank.
    """
    num_layers = len(expert_weights)
    total_physical_experts = world_size * num_local_experts

    for layer in range(num_layers):
        # Collect weights for all physical experts for each weight matrix
        all_weights: list[torch.Tensor] = []

        for weight_idx, hidden_size in enumerate(hidden_sizes):
            # Create tensor to store all expert weights
            # Shape: [total_physical_experts, hidden_size]
            gathered_weights = torch.zeros(
                total_physical_experts,
                hidden_size,
                device=expert_weights[layer][weight_idx].device,
                dtype=expert_weights[layer][weight_idx].dtype,
            )

            # Use all_gather to collect expert weights from current node
            # expert_weights[layer][weight_idx] shape:
            # [num_local_experts, hidden_size]
            local_weights = expert_weights[layer][
                weight_idx
            ]  # [num_local_experts, hidden_size]

            # Split tensor along dim 0 into a list for all_gather.
            # torch.chunk returns views into gathered_weights, so all_gather
            # writing into the chunks fills gathered_weights in place.
            gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0)

            torch.distributed.all_gather(
                # Output list: each element corresponds to one rank's weights
                list(gathered_weights_list),
                local_weights,  # Input: current rank's local weights
            )

            all_weights.append(gathered_weights)

        # Verify that all replicas of the same logical expert have the same
        # weights
        logical_expert_weights: dict[int, dict[int, torch.Tensor]] = {}

        for physical_pos in range(total_physical_experts):
            logical_expert_id = int(indices[layer, physical_pos].item())

            if logical_expert_id not in logical_expert_weights:
                # First time encountering this logical expert, save its weights
                logical_expert_weights[logical_expert_id] = {
                    weight_idx: all_weights[weight_idx][physical_pos]
                    for weight_idx in range(len(hidden_sizes))
                }
            else:
                # Verify that current physical expert's weights match the
                # previously saved logical expert weights
                for weight_idx in range(len(hidden_sizes)):
                    torch.testing.assert_close(
                        all_weights[weight_idx][physical_pos],
                        logical_expert_weights[logical_expert_id][weight_idx],
                        msg=f"Layer {layer}, weight {weight_idx},"
                        f"logical expert {logical_expert_id}: "
                        f"Physical expert {physical_pos} has different weights"
                        f"than expected",
                    )
|
||||
|
||||
|
||||
def _test_async_transfer_layer_without_mtp_worker(
    env,
    world_size: int,
    num_layers: int,
    num_local_experts: int,
    num_logical_experts: int,
) -> None:
    """Per-rank worker for the async layer-transfer test.

    Builds random old/new expert placements, transfers each layer through the
    async `transfer_layer` + `move_from_buffer` path on a side CUDA stream,
    then verifies the final placement and replica consistency.
    Runs on every rank spawned by `distributed_run`.
    """
    set_env_vars_and_device(env)
    # Tensor parallel is used as the entry point to expert parallel here.
    ensure_model_parallel_initialized(
        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
    )

    tp_group = get_tp_group()
    ep_group = tp_group.device_group
    ep_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{ep_rank}")

    total_physical_experts = world_size * num_local_experts
    hidden_sizes = [16, 32]  # two weight matrices per expert

    # Random "before" placement with redundancy.
    redundancy_config = create_redundancy_config(
        num_logical_experts,
        total_physical_experts,
    )
    old_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        redundancy_config,
    )

    # Independent random "after" placement.
    new_redundancy_config = create_redundancy_config(
        num_logical_experts,
        total_physical_experts,
    )
    new_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        new_redundancy_config,
    )

    expert_weights = create_expert_weights(
        num_layers,
        num_local_experts,
        hidden_sizes,
        ep_rank,
        device,
        old_indices,
    )

    # One staging buffer, shaped like a single layer's weights, reused for
    # every layer.
    expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
    cuda_stream = torch.cuda.Stream(device=device)

    for layer_idx in range(num_layers):
        # Kick off the async transfer for this layer on the side stream.
        is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
            transfer_layer(
                old_global_expert_indices=old_indices,
                new_global_expert_indices=new_indices,
                expert_weights=expert_weights,
                expert_weights_buffer=expert_buffer,
                ep_group=ep_group,
                layer=layer_idx,
                cuda_stream=cuda_stream,
            )
        )

        # Wait for the transfer stream before consuming the buffer.
        cuda_stream.synchronize()
        move_from_buffer(
            expert_weights=expert_weights[layer_idx],
            expert_weights_buffer=expert_buffer,
            is_unchanged=is_unchanged,
            is_received_locally=is_received_locally,
            experts_recv_loc=experts_recv_loc,
            new_indices=new_indices[layer_idx].tolist(),
            ep_group=ep_group,
        )

    # Every local slot must now hold the weights of its new logical expert...
    verify_expert_weights_after_shuffle(
        expert_weights,
        new_indices,
        hidden_sizes,
        ep_rank,
        num_local_experts,
    )
    # ...and all replicas of one logical expert must agree across ranks.
    verify_redundant_experts_have_same_weights(
        expert_weights,
        new_indices,
        hidden_sizes,
        world_size,
        num_local_experts,
    )
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_with_redundancy(
    env, world_size, num_layers, num_local_experts, num_logical_experts
) -> None:
    """Per-rank worker: rearrange redundant expert weights and verify.

    Generates two random placements (old and new), runs the synchronous
    `rearrange_expert_weights_inplace`, and checks that each slot ends up
    with the correct logical expert's weights and that replicas agree.
    Runs on every rank spawned by `distributed_run`.
    """
    # Initialize model parallel (using tensor parallel as an entrypoint
    # to expert parallel)
    set_env_vars_and_device(env)
    ensure_model_parallel_initialized(
        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
    )

    ep_group = get_tp_group().cpu_group
    ep_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{ep_rank}")

    # Test parameters
    total_physical_experts = world_size * num_local_experts
    hidden_sizes = [32, 64]  # Two different weight matrices

    # Create old expert indices (with redundancy)
    redundancy_config = create_redundancy_config(
        num_logical_experts, total_physical_experts
    )

    old_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        redundancy_config,
    )

    # Create new expert indices (with redundancy)
    new_redundancy_config = create_redundancy_config(
        num_logical_experts, total_physical_experts
    )
    new_indices = create_expert_indices_with_redundancy(
        num_layers,
        num_logical_experts,
        total_physical_experts,
        new_redundancy_config,
    )

    # Create expert weights matching the old placement
    expert_weights = create_expert_weights(
        num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
    )

    # Execute weight rearrangement
    rearrange_expert_weights_inplace(
        old_indices,
        new_indices,
        expert_weights,
        ep_group,
        is_profile=False,
    )

    # Verify the rearrangement result
    verify_expert_weights_after_shuffle(
        expert_weights,
        new_indices,
        hidden_sizes,
        ep_rank,
        num_local_experts,
    )

    verify_redundant_experts_have_same_weights(
        expert_weights,
        new_indices,
        hidden_sizes,
        world_size,
        num_local_experts,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "world_size,num_layers,num_local_experts,num_logical_experts",
    [
        # 2 GPU, 2 experts per GPU
        # 3 logical experts, 4 physical experts, 1 redundant experts
        (2, 1, 2, 3),
        # 2 GPU, 3 experts per GPU
        # 4 logical experts, 6 physical experts, 2 redundant experts
        (2, 2, 3, 4),
        # 2 GPU, 8 experts per GPU
        # 16 logical experts, 16 physical experts, 0 redundant experts
        (2, 4, 8, 16),
        # 4 GPU, 2 experts per GPU
        # 6 logical experts, 8 physical experts, 2 redundant experts
        (4, 1, 2, 6),
        # 4 GPU, 2 experts per GPU
        # 5 logical experts, 8 physical experts, 3 redundant experts
        (4, 2, 2, 5),
        # 4 GPU, 8 experts per GPU
        # 16 logical experts, 32 physical experts, 16 redundant experts
        (4, 8, 8, 16),
    ],
)
def test_rearrange_expert_weights_with_redundancy(
    world_size, num_layers, num_local_experts, num_logical_experts
):
    """Test the functionality of rearranging expert weights with redundancy."""
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    # Fan the worker out to one process per rank.
    worker_args = (num_layers, num_local_experts, num_logical_experts)
    distributed_run(
        _test_rearrange_expert_weights_with_redundancy, world_size, *worker_args
    )
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
    """Per-rank worker: rearranging with identical indices must be a no-op.

    Runs on every rank spawned by `distributed_run`.
    """
    set_env_vars_and_device(env)
    ensure_model_parallel_initialized(
        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
    )

    ep_group = get_tp_group().cpu_group
    ep_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{ep_rank}")

    num_layers = 2
    num_local_experts = 2
    total_physical_experts = world_size * num_local_experts
    num_logical_experts = total_physical_experts // 2  # Some redundancy
    hidden_sizes = [32, 64]

    # Create redundancy configuration: every logical expert has 2 replicas.
    redundancy_config = [2] * num_logical_experts

    # Same indices - no change
    indices = create_expert_indices_with_redundancy(
        num_layers, num_logical_experts, total_physical_experts, redundancy_config
    )

    expert_weights = create_expert_weights(
        num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
    )

    # Save original weights for comparison after the (no-op) rearrangement.
    original_weights = [
        [weight.clone() for weight in layer_weights]
        for layer_weights in expert_weights
    ]

    # Execute rearrangement (should be no change)
    rearrange_expert_weights_inplace(
        indices,
        indices,  # Same indices
        expert_weights,
        ep_group,
        is_profile=False,
    )

    # Verify that the weights have not changed
    for layer in range(num_layers):
        for weight_idx in range(len(hidden_sizes)):
            torch.testing.assert_close(
                expert_weights[layer][weight_idx],
                original_weights[layer][weight_idx],
                # Single-line message: the previous triple-quoted f-string
                # embedded a raw newline and source indentation in the output.
                msg=f"Layer {layer}, weight {weight_idx} should remain unchanged",
            )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "world_size,num_layers,num_local_experts,num_logical_experts",
    [
        (2, 2, 2, 3),
    ],
)
def test_async_transfer_layer_without_mtp(
    world_size: int,
    num_layers: int,
    num_local_experts: int,
    num_logical_experts: int,
):
    """Exercise async EPLB transfer path without MTP/spec decode."""
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    # Fan the worker out to one process per rank.
    extra_args = (num_layers, num_local_experts, num_logical_experts)
    distributed_run(
        _test_async_transfer_layer_without_mtp_worker, world_size, *extra_args
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
    """Rearranging with unchanged indices must leave every weight untouched."""
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    distributed_run(_test_rearrange_expert_weights_no_change, world_size)
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
    """Per-rank worker: profile-mode rearrangement must not touch weights.

    Builds two genuinely different placements, runs the rearrangement with
    `is_profile=True`, and asserts the weights are byte-for-byte unchanged.
    Runs on every rank spawned by `distributed_run`.
    """
    set_env_vars_and_device(env)
    ensure_model_parallel_initialized(
        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
    )

    ep_group = get_tp_group().cpu_group
    ep_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{ep_rank}")

    num_layers = 1
    num_local_experts = 2
    total_physical_experts = world_size * num_local_experts
    num_logical_experts = total_physical_experts // 2
    hidden_sizes = [32]

    # Create different index distributions
    old_redundancy = create_redundancy_config(
        num_logical_experts, total_physical_experts
    )
    new_redundancy = create_redundancy_config(
        num_logical_experts, total_physical_experts
    )

    old_indices = create_expert_indices_with_redundancy(
        num_layers, num_logical_experts, total_physical_experts, old_redundancy
    )
    new_indices = create_expert_indices_with_redundancy(
        num_layers, num_logical_experts, total_physical_experts, new_redundancy
    )

    expert_weights = create_expert_weights(
        num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
    )

    # Save original weights so we can prove nothing was overwritten.
    original_weights = []
    for layer_weights in expert_weights:
        layer_copy = []
        for weight in layer_weights:
            layer_copy.append(weight.clone())
        original_weights.append(layer_copy)

    # Execute profile mode rearrangement
    rearrange_expert_weights_inplace(
        old_indices,
        new_indices,
        expert_weights,
        ep_group,
        is_profile=True,  # Profile mode
    )

    # In profile mode, the weights should remain unchanged
    for layer in range(num_layers):
        for weight_idx in range(len(hidden_sizes)):
            torch.testing.assert_close(
                expert_weights[layer][weight_idx],
                original_weights[layer][weight_idx],
                msg="In profile mode, the weights should remain unchanged",
            )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_profile_mode(world_size):
    """Profile mode must plan the transfer without copying real weights."""
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
|
||||
285
tests/distributed/test_eplb_fused_moe_layer.py
Normal file
285
tests/distributed/test_eplb_fused_moe_layer.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test that the interaction between EPLB and FusedMoE Layer is okay
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
from .eplb_utils import distributed_run, set_env_vars_and_device
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
num_layers: int
|
||||
num_experts: int
|
||||
num_local_experts: int
|
||||
num_topk: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
weight_dtype: torch.dtype
|
||||
weight_scale_dtype: torch.dtype | None
|
||||
column_major_scales: bool
|
||||
|
||||
|
||||
def make_expert_weights(
    layer_idx: int,
    global_expert_idx: int,
    global_num_experts: int,
    tensor_shape: tuple[int, ...],
    tensor_dtype: torch.dtype,
    tensor_device: torch.device,
    is_column_major: bool,
) -> torch.Tensor:
    """Build a predictable 2D weight tensor for one expert.

    Values are a flat `arange` whose offset is unique per (layer, expert)
    pair. When `is_column_major` is set, data is laid out transposed in
    memory and a non-contiguous transposed view is returned.
    """
    assert len(tensor_shape) == 2

    rows, cols = tensor_shape
    alloc_shape = (cols, rows) if is_column_major else (rows, cols)
    out = torch.empty(alloc_shape, dtype=tensor_dtype, device=tensor_device)

    # Unique, reproducible values per (layer, expert) pair.
    start = (layer_idx * global_num_experts + global_expert_idx) * out.numel()
    out.view(-1).copy_(
        torch.arange(
            start,
            start + out.numel(),
            dtype=tensor_dtype,
            device=tensor_device,
        )
    )

    if not is_column_major:
        return out

    transposed = torch.transpose(out, 1, 0)
    assert not transposed.is_contiguous()
    return transposed
|
||||
|
||||
|
||||
def make_fused_moe_layer(
    rank: int,
    layer_idx: int,
    test_config: TestConfig,
) -> FusedMoE:
    """Build a FusedMoE layer with deterministic, expert-identifiable weights.

    Weights and block-quant scales are filled via `make_expert_weights` so
    each (layer, global expert) pair gets a unique `arange` pattern, making
    post-rearrangement verification possible. The layer is moved to this
    rank's CUDA device.
    """
    fml = FusedMoE(
        num_experts=test_config.num_experts,
        top_k=test_config.num_topk,
        hidden_size=test_config.hidden_size,
        intermediate_size=test_config.intermediate_size,
        prefix=f"dummy_layer_{layer_idx}",
        activation="silu",
        is_act_and_mul=True,
        params_dtype=test_config.weight_dtype,
    )

    device = torch.device(f"cuda:{rank}")

    # NOTE(review): local import; could live at module top per project style.
    from functools import partial

    # Pre-bind everything except the per-expert arguments.
    _make_expert_weights = partial(
        make_expert_weights,
        layer_idx=layer_idx,
        global_num_experts=test_config.num_experts,
        tensor_device=device,
    )

    assert isinstance(fml.w13_weight.data, torch.Tensor)
    assert isinstance(fml.w2_weight.data, torch.Tensor)
    # Move the layer's expert weights onto this rank's GPU.
    fml.w13_weight.data = fml.w13_weight.data.to(device=device)
    fml.w2_weight.data = fml.w2_weight.data.to(device=device)
    w13_weight = fml.w13_weight.data
    w2_weight = fml.w2_weight.data
    assert w13_weight.size(0) == test_config.num_local_experts
    # Fill each local expert slot with its global expert's unique pattern.
    for i in range(test_config.num_local_experts):
        g_i = rank * test_config.num_local_experts + i
        w13_weight_e = w13_weight[i]
        w2_weight_e = w2_weight[i]
        w13_weight_e.copy_(
            _make_expert_weights(
                global_expert_idx=g_i,
                tensor_shape=w13_weight_e.shape,
                tensor_dtype=w13_weight_e.dtype,
                is_column_major=False,
            )
        )
        w2_weight_e.copy_(
            _make_expert_weights(
                global_expert_idx=g_i,
                tensor_shape=w2_weight_e.shape,
                tensor_dtype=w2_weight_e.dtype,
                is_column_major=False,
            )
        )

    # Block-quantization scale granularity used below.
    block_size = 16

    def block_quant_scales_shape(
        shape: tuple[int, ...], is_column_major: bool
    ) -> tuple[int, ...]:
        # One scale per block_size x block_size tile; the last two dims are
        # swapped for column-major layout.
        assert len(shape) == 3
        if not is_column_major:
            return (shape[0], shape[1] // block_size, shape[2] // block_size)
        else:
            return (shape[0], shape[2] // block_size, shape[1] // block_size)

    is_column_major = test_config.column_major_scales
    w13_weight_scale_inv = torch.empty(
        block_quant_scales_shape(w13_weight.shape, is_column_major),
        dtype=test_config.weight_dtype,
        device=device,
    )
    w2_weight_scale_inv = torch.empty(
        block_quant_scales_shape(w2_weight.shape, is_column_major),
        dtype=test_config.weight_dtype,
        device=device,
    )

    # Fill per-expert scale tensors with the same deterministic scheme.
    for i in range(test_config.num_local_experts):
        g_i = rank * test_config.num_local_experts + i
        w13_s_e = w13_weight_scale_inv[i]
        w2_s_e = w2_weight_scale_inv[i]
        w13_s_e.copy_(
            _make_expert_weights(
                global_expert_idx=g_i,
                tensor_shape=w13_s_e.shape,
                tensor_dtype=w13_s_e.dtype,
                # Fill data in row-major and then
                # transpose if test_config requires col-major.
                is_column_major=False,
            )
        )
        w2_s_e.copy_(
            _make_expert_weights(
                global_expert_idx=g_i,
                tensor_shape=w2_s_e.shape,
                tensor_dtype=w2_s_e.dtype,
                is_column_major=False,
            )
        )
    if is_column_major:
        # Expose the scales as non-contiguous transposed views.
        w13_weight_scale_inv = torch.transpose(w13_weight_scale_inv, 1, 2)
        w2_weight_scale_inv = torch.transpose(w2_weight_scale_inv, 1, 2)
        assert not w13_weight_scale_inv.is_contiguous()
        assert not w2_weight_scale_inv.is_contiguous()

    # Add scales to the parameter list
    fml.w13_weight_scale_inv = torch.nn.Parameter(
        w13_weight_scale_inv, requires_grad=False
    )
    fml.w2_weight_scale_inv = torch.nn.Parameter(
        w2_weight_scale_inv, requires_grad=False
    )

    return fml
|
||||
|
||||
|
||||
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
    """Per-rank worker body: build FusedMoE layers, shuffle experts via EPLB,
    then verify each local expert tensor holds its new global expert's weights.

    Runs inside a process spawned by `distributed_run`. After
    `rearrange_expert_weights_inplace` applies the random permutation, the
    contents of local expert slot `e` must equal the reference weights
    regenerated for global expert `shuffled_indices[layer][rank*local + e]`.
    """
    # Initialize model parallel (using tensor parallel as an entrypoint
    # to expert parallel)
    set_env_vars_and_device(env)

    vllm_config = VllmConfig()
    vllm_config.parallel_config.tensor_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    with set_current_vllm_config(vllm_config):
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
        )

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()

        # One FusedMoE layer per model layer; weights are regenerated below
        # via make_expert_weights, so they are presumably deterministic per
        # (layer, global expert) — confirm in make_expert_weights.
        fml_layers = [
            make_fused_moe_layer(ep_rank, layer_idx, test_config)
            for layer_idx in range(test_config.num_layers)
        ]
        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

        # Identity physical->logical mapping before the shuffle.
        indices = torch.zeros(
            test_config.num_layers, test_config.num_experts, dtype=torch.long
        )
        for lidx in range(test_config.num_layers):
            indices[lidx] = torch.Tensor(range(test_config.num_experts))

        # Random target placement per layer.
        # NOTE(review): each rank calls torch.randperm independently; this
        # assumes RNG state is synchronized across ranks (or seeded by
        # set_env_vars_and_device) — confirm, otherwise ranks would disagree.
        shuffled_indices = torch.zeros_like(indices)
        for lidx in range(test_config.num_layers):
            shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

        rearrange_expert_weights_inplace(
            indices,
            shuffled_indices,
            rank_expert_weights,
            ep_group,
            is_profile=False,
        )

        # Verify every parameter of every local expert slot.
        num_local_experts = test_config.num_local_experts
        num_global_experts = test_config.num_experts
        for lidx, fml in enumerate(fml_layers):
            for name, w in fml.named_parameters():
                for e in range(num_local_experts):
                    # Global expert now assigned to this rank's slot `e`.
                    g_e = shuffled_indices[lidx][ep_rank * num_local_experts + e]
                    ref = make_expert_weights(
                        layer_idx=lidx,
                        global_expert_idx=int(g_e.item()),
                        global_num_experts=num_global_experts,
                        tensor_shape=w[e].shape,
                        tensor_dtype=w[e].dtype,
                        tensor_device=w[e].device,
                        # Column-major scale tensors are non-contiguous views.
                        is_column_major=not w[e].is_contiguous(),
                    )
                    # Layout (shape + strides) must match before comparing data.
                    assert w[e].shape == ref.shape and w[e].stride() == ref.stride(), (
                        f"w[{e}] {w[e].size()} {w[e].stride()} vs "
                        f"ref {ref.size()} {ref.stride()}"
                    )
                    torch.testing.assert_close(w[e], ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("num_layers", [4])
@pytest.mark.parametrize("num_experts", [16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("column_major_scales", [True, False])
def test_eplb_fml(
    world_size: int,
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    column_major_scales: bool,
):
    """Spawn `world_size` ranks and run the EPLB weight-shuffle check."""
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    # The dtypes are fine as we are essentially just checking data-copies.
    dtype = torch.bfloat16

    cfg = TestConfig(
        num_layers=num_layers,
        num_experts=num_experts,
        num_local_experts=num_experts // world_size,
        num_topk=4,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        weight_dtype=dtype,
        weight_scale_dtype=dtype,
        column_major_scales=column_major_scales,
    )

    distributed_run(_test_eplb_fml, world_size, cfg)
|
||||
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_dp_group,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config,
|
||||
ModelOptNvFp4FusedMoE,
|
||||
)
|
||||
|
||||
from .eplb_utils import distributed_run, set_env_vars_and_device
|
||||
|
||||
|
||||
@dataclass
class TestConfig:
    """Sizes for one EPLB + FusedMoE NVFP4 test run."""

    # Number of stacked FusedMoE layers to build.
    num_layers: int
    # Global expert count across all ranks.
    num_experts: int
    # Experts hosted on each rank (num_experts // world_size).
    num_local_experts: int
    # Experts routed per token (top-k).
    num_topk: int
    # Model hidden dimension.
    hidden_size: int
    # MoE intermediate (FFN) dimension.
    intermediate_size: int
    # Tokens per rank in each forward pass.
    num_tokens: int
|
||||
|
||||
|
||||
def make_fused_moe_layer(
    rank: int,
    layer_idx: int,
    test_config: TestConfig,
) -> FusedMoE:
    """Construct a FusedMoE layer with NVFP4 (ModelOpt) quantized weights.

    Builds the layer on ``cuda:{rank}``, fills its weight and scale tensors
    with random data, runs the quant method's post-load processing, and
    initializes the modular kernel so the layer is callable.

    Args:
        rank: CUDA device index this rank's layer lives on.
        layer_idx: Only used to give the layer a unique ``prefix``.
        test_config: Expert counts and hidden/intermediate dimensions.

    Returns:
        A ready-to-run FusedMoE layer.
    """
    device = torch.device(f"cuda:{rank}")

    # Fix: the original first bound `quant_config = None` and immediately
    # rebound it here — the dead assignment is removed.
    quant_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )

    fml = FusedMoE(
        num_experts=test_config.num_experts,
        top_k=test_config.num_topk,
        hidden_size=test_config.hidden_size,
        intermediate_size=test_config.intermediate_size,
        prefix=f"dummy_layer_{layer_idx}",
        activation="silu",
        is_act_and_mul=True,
        params_dtype=torch.bfloat16,
        quant_config=quant_config,
    )

    nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
    nvfp4_fused_moe.create_weights(
        fml,
        test_config.num_local_experts,
        test_config.hidden_size,
        test_config.intermediate_size,
        params_dtype=torch.uint8,
        global_num_experts=test_config.num_experts,
    )

    fml = fml.to(device)
    # Only the quantized weight tensors are used below; the quant config
    # returned by the helper is deliberately discarded (the original rebound
    # `quant_config` here, shadowing the ModelOptNvFp4Config above, but the
    # rebound value was never used).
    w1_q, w2_q, _ = make_test_quant_config(
        test_config.num_local_experts,
        test_config.intermediate_size,
        test_config.hidden_size,
        in_dtype=torch.bfloat16,
        quant_dtype="nvfp4",
        block_shape=None,
        per_act_token_quant=False,
    )

    fml.w13_weight.data = w1_q
    fml.w2_weight.data = w2_q

    # Randomize all scale tensors; /5 keeps magnitudes small.
    fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
    fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
    fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
    fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
    # Block scales are sampled in default float and cast — presumably their
    # native dtype is not supported by randn_like; confirm against the
    # tensors created by ModelOptNvFp4FusedMoE.create_weights.
    fml.w2_weight_scale.data = (
        torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
    ).to(fml.w2_weight_scale.data.dtype)
    fml.w13_weight_scale.data = (
        torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
    ).to(fml.w13_weight_scale.data.dtype)

    nvfp4_fused_moe.process_weights_after_loading(fml)

    fml.maybe_init_modular_kernel()

    return fml
|
||||
|
||||
|
||||
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
    """Per-rank worker body for the NVFP4 EPLB test.

    Runs a forward pass through every layer, shuffles the expert weights
    with `rearrange_expert_weights_inplace`, installs the matching EPLB
    logical->physical mapping, re-runs the same inputs, and asserts the
    outputs are unchanged (within a loose fp tolerance).
    """
    set_env_vars_and_device(env)

    # Expert parallelism is driven through the data-parallel group here.
    vllm_config = VllmConfig()
    vllm_config.parallel_config.data_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    with set_current_vllm_config(vllm_config):
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
        )

        ep_group = get_dp_group().cpu_group
        ep_rank = torch.distributed.get_rank()

        device = torch.device(f"cuda:{ep_rank}")

        fml_layers = [
            make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
            for layer_idx in range(test_config.num_layers)
        ]
        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

        # Fixed random inputs, reused (cloned) before and after the shuffle.
        hidden_states = []
        router_logits = []
        for layer_idx in range(test_config.num_layers):
            hidden_states.append(
                torch.randn(
                    (test_config.num_tokens, test_config.hidden_size),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )
            router_logits.append(
                torch.randn(
                    (test_config.num_tokens, test_config.num_experts),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )

        # Reference outputs with the identity expert placement.
        out_before_shuffle = []
        with set_forward_context(
            {},
            num_tokens=test_config.num_tokens,
            num_tokens_across_dp=torch.tensor(
                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
            ),
            vllm_config=vllm_config,
        ):
            for lidx, fml in enumerate(fml_layers):
                out_before_shuffle.append(
                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
                )

        # Identity mapping before the shuffle.
        indices = torch.zeros(
            test_config.num_layers, test_config.num_experts, dtype=torch.long
        )
        for lidx in range(test_config.num_layers):
            indices[lidx] = torch.Tensor(range(test_config.num_experts))

        # Random target placement per layer.
        # NOTE(review): assumes RNG state agrees across ranks — confirm.
        shuffled_indices = torch.zeros_like(indices)
        for lidx in range(test_config.num_layers):
            shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

        rearrange_expert_weights_inplace(
            indices,
            shuffled_indices,
            rank_expert_weights,
            ep_group,
            is_profile=False,
        )

        num_global_experts = test_config.num_experts

        # Invert each layer's physical->logical permutation so routing can
        # find the shuffled experts again.
        logical_to_physical_map_list = []
        for lidx, fml in enumerate(fml_layers):
            physical_to_logical_map = shuffled_indices[lidx].to(device)
            logical_to_physical_map = torch.empty(
                (num_global_experts,), dtype=torch.int32, device=device
            )
            logical_to_physical_map[physical_to_logical_map] = torch.arange(
                0, num_global_experts, dtype=torch.int32, device=device
            )
            logical_to_physical_map_list.append(
                logical_to_physical_map.reshape(num_global_experts, 1)
            )

        logical_to_physical_map = torch.stack(logical_to_physical_map_list)

        # Install EPLB state on every layer (one replica per expert).
        for lidx, fml in enumerate(fml_layers):
            logical_replica_count = torch.ones(
                (test_config.num_layers, num_global_experts),
                dtype=torch.int32,
                device=device,
            )
            fml.enable_eplb = True
            fml.set_eplb_state(
                lidx,
                # Expert-load counters, zero-initialized.
                torch.zeros(
                    (test_config.num_layers, num_global_experts),
                    dtype=torch.int32,
                    device=device,
                ),
                logical_to_physical_map,
                logical_replica_count,
            )

        # Same inputs again; routing now goes through the shuffled mapping.
        out_after_shuffle = []
        with set_forward_context(
            {},
            num_tokens=test_config.num_tokens,
            num_tokens_across_dp=torch.tensor(
                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
            ),
            vllm_config=vllm_config,
        ):
            for lidx, fml in enumerate(fml_layers):
                out_after_shuffle.append(
                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
                )

        # Loose tolerance: NVFP4 quantized kernels, not exact arithmetic.
        for lidx in range(test_config.num_layers):
            torch.testing.assert_close(
                out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
            )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
    world_size: int,
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    num_tokens: int,
    backend: str,
    monkeypatch,
):
    """Run the NVFP4 EPLB shuffle/forward equivalence check across ranks."""
    # Select the FlashInfer NVFP4 MoE path for the chosen backend.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    cfg = TestConfig(
        num_layers=num_layers,
        num_experts=num_experts,
        num_local_experts=num_experts // world_size,
        num_topk=4,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_tokens=num_tokens,
    )

    distributed_run(_test_eplb_fml, world_size, cfg)
|
||||
142
tests/distributed/test_eplb_spec_decode.py
Normal file
142
tests/distributed/test_eplb_spec_decode.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import lm_eval
|
||||
import pytest
|
||||
|
||||
from tests.utils import large_gpu_mark
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def get_model_args(
|
||||
model_name: str,
|
||||
spec_model_name: str | None,
|
||||
spec_method: str,
|
||||
tp_size: int,
|
||||
model_max_len: int,
|
||||
use_async: bool = False,
|
||||
) -> dict:
|
||||
speculative_config = {
|
||||
"method": spec_method,
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 1,
|
||||
"max_model_len": model_max_len,
|
||||
}
|
||||
eplb_config = {
|
||||
"num_redundant_experts": tp_size,
|
||||
"window_size": 128,
|
||||
"step_interval": 1024,
|
||||
"log_balancedness": False,
|
||||
}
|
||||
if use_async:
|
||||
eplb_config["use_async"] = True
|
||||
model_args = {
|
||||
"pretrained": model_name,
|
||||
"dtype": "auto",
|
||||
"add_bos_token": True,
|
||||
"tensor_parallel_size": tp_size,
|
||||
"gpu_memory_utilization": 0.7,
|
||||
"speculative_config": speculative_config,
|
||||
"enable_expert_parallel": True,
|
||||
"eplb_config": eplb_config,
|
||||
"enable_eplb": True,
|
||||
"max_model_len": model_max_len,
|
||||
}
|
||||
return model_args
|
||||
|
||||
|
||||
# Module-wide skip: none of these tests run on ROCm yet.
pytestmark = pytest.mark.skipif(
    current_platform.is_rocm(),
    reason="EPLB with Spec Decode is a work in progress on ROCm.",
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_setup",
    [
        # (spec method, target model, draft model, tp_size, expected gsm8k score)
        pytest.param(
            ("mtp", "Qwen/Qwen3-Next-80B-A3B-Instruct", None, 4, 0.86),
            marks=large_gpu_mark(min_gb=80),
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
                0.92,
            ),
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues"),
        ),
    ],
    ids=["qwen3_next_mtp", "llama4_eagle"],
)
def test_eplb_spec_decode(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, int, float],
):
    """
    Test the correctness of EPLB speculative decoding with GSM8K dataset.
    Applicable to MoE models with mtp or eagle spec decode.

    The measured exact-match score must land within RTOL of the expected
    value recorded in the parametrization above.
    """
    method, model_name, spec_model_name, tp_size, expected_gsm8k_value = model_setup

    TASK = "gsm8k"
    FILTER = "exact_match,strict-match"  # lm_eval results key for the metric
    RTOL = 0.03  # tolerance applied as an absolute band on the score

    model_args = get_model_args(
        model_name=model_name,
        spec_model_name=spec_model_name,
        spec_method=method,
        tp_size=tp_size,
        model_max_len=4096,
    )

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=TASK,
        batch_size=64,
        num_fewshot=8,
    )
    measured_value = results["results"][TASK][FILTER]
    # Equivalent to abs(measured_value - expected_gsm8k_value) < RTOL.
    assert (
        measured_value - RTOL < expected_gsm8k_value
        and measured_value + RTOL > expected_gsm8k_value
    ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@large_gpu_mark(min_gb=80)
def test_eplb_spec_decode_qwen3_next_mtp_async() -> None:
    """
    Ensure async EPLB works with MTP speculative decoding for Qwen3-Next.
    """
    task = "gsm8k"
    metric_filter = "exact_match,strict-match"
    tolerance = 0.03
    expected = 0.86

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=get_model_args(
            model_name="Qwen/Qwen3-Next-80B-A3B-Instruct",
            spec_model_name=None,
            spec_method="mtp",
            tp_size=4,
            model_max_len=4096,
            use_async=True,
        ),
        tasks=task,
        batch_size=64,
        num_fewshot=8,
    )

    measured = results["results"][task][metric_filter]
    # Score must land within the tolerance band around the expected value.
    assert abs(measured - expected) < tolerance, (
        f"Expected: {expected} | Measured: {measured}"
    )
|
||||
314
tests/distributed/test_events.py
Normal file
314
tests/distributed/test_events.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
import time
|
||||
|
||||
import msgspec
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_events import (
|
||||
EventBatch,
|
||||
EventPublisherFactory,
|
||||
NullEventPublisher,
|
||||
)
|
||||
|
||||
DP_RANK = 0
|
||||
|
||||
|
||||
class EventSample(
    msgspec.Struct,
    tag=True,  # type: ignore
    array_like=True,  # type: ignore
):
    """Test event for publisher testing.

    Serialized by msgspec with a type tag and array-like (positional)
    encoding. Instances are created by `create_test_events`, which assigns
    sequential ids and "test-{id}" values.
    """

    id: int
    value: str
|
||||
|
||||
|
||||
class SampleBatch(EventBatch):
    """Test event batch for publisher testing.

    Inherits batch metadata (e.g. the `ts` timestamp set in
    `create_test_events`) from EventBatch and carries the payload events.
    """

    events: list[EventSample]
|
||||
|
||||
|
||||
def create_test_events(count: int) -> SampleBatch:
    """Build a timestamped SampleBatch of `count` sequentially numbered events."""
    return SampleBatch(
        ts=time.time(),
        events=[EventSample(id=n, value=f"test-{n}") for n in range(count)],
    )
|
||||
|
||||
|
||||
def test_basic_publishing(publisher, subscriber):
    """A single published batch arrives intact with sequence number 0."""
    sent = create_test_events(5)
    publisher.publish(sent)

    msg = subscriber.receive_one(timeout=1000)
    assert msg is not None, "No message received"

    seq, batch = msg
    assert seq == 0, "Sequence number mismatch"
    assert batch.ts == pytest.approx(sent.ts, abs=0.1), "Timestamp mismatch"
    assert len(batch.events) == len(sent.events), "Number of events mismatch"

    # Events must come back in order with their original contents.
    for idx, ev in enumerate(batch.events):
        assert ev.id == idx, "Event id mismatch"
        assert ev.value == f"test-{idx}", "Event value mismatch"
|
||||
|
||||
|
||||
def test_multiple_events(publisher, subscriber):
    """Ten published batches arrive with consecutive sequence numbers 0..9."""
    for _ in range(10):
        publisher.publish(create_test_events(2))

    inbox = []
    for _ in range(10):
        item = subscriber.receive_one(timeout=100)
        if item:
            inbox.append(item)

    assert len(inbox) == 10, "Number of messages mismatch"
    assert [s for s, _ in inbox] == list(range(10)), "Sequence numbers mismatch"
|
||||
|
||||
|
||||
def test_replay_mechanism(publisher, subscriber):
    """Replay from seq 10 returns only seqs >= 10, with no gaps."""
    for _ in range(19):
        publisher.publish(create_test_events(1))

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(10)

    publisher.publish(create_test_events(1))  # 20th message

    replayed = subscriber.receive_replay()
    assert len(replayed) > 0, "No replayed messages received"

    sequence_numbers = [s for s, _ in replayed]
    assert all(s >= 10 for s in sequence_numbers), "Replayed messages not in order"
    lo, hi = min(sequence_numbers), max(sequence_numbers)
    assert sequence_numbers == list(range(lo, hi + 1)), (
        "Replayed messages not consecutive"
    )
|
||||
|
||||
|
||||
def test_buffer_limit(publisher, subscriber, publisher_config):
    """Replay returns at most `buffer_steps` messages; the oldest are dropped."""
    capacity = publisher_config.buffer_steps

    # Overflow the replay buffer by 10 batches.
    for _ in range(capacity + 10):
        publisher.publish(create_test_events(1))

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(0)

    publisher.publish(create_test_events(1))

    replayed = subscriber.receive_replay()
    assert len(replayed) <= capacity, "Can't replay more than buffer size"

    # The 10 overflowed batches (seqs 0..9) must have been evicted.
    oldest = min(s for s, _ in replayed)
    assert oldest >= 10, "The oldest sequence should be at least 10"
|
||||
|
||||
|
||||
def test_topic_filtering(publisher_config):
    """
    Test that a subscriber only receives messages matching its topic filter.

    Publishes three batches on topic "foo" and checks that a "foo"
    subscriber sees all of them while a "bar" subscriber sees none.
    """
    # Disable the replay socket; only the pub/sub path is under test here.
    publisher_config.replay_endpoint = None

    publisher_config.topic = "foo"
    pub = EventPublisherFactory.create(publisher_config, DP_RANK)

    # Local import mirrors the other tests in this module — presumably to
    # avoid importing conftest helpers at module load; confirm.
    from .conftest import MockSubscriber

    sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
    sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")

    try:
        time.sleep(0.1)  # Give the sockets a moment to connect.

        for _ in range(3):
            pub.publish(create_test_events(1))

        foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is not None for msg in foo_received), (
            "Subscriber with matching topic should receive messages"
        )

        bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is None for msg in bar_received), (
            "Subscriber with non-matching topic should receive no messages"
        )
    finally:
        # Tear down sockets even when an assertion fails.
        pub.shutdown()
        sub_foo.close()
        sub_bar.close()
|
||||
|
||||
|
||||
def test_high_volume(publisher, subscriber):
    """Stress test: publish 10k batches from a thread while draining them.

    Tolerates up to 10% loss (>= 90% must arrive) but requires the sequence
    numbers of whatever arrives to be in order.
    """
    num_batches = 10_000
    events_per_batch = 100

    # Publish events in a separate thread to not block
    def publish_events():
        for i in range(num_batches):
            batch = create_test_events(events_per_batch)
            publisher.publish(batch)
            # Small delay to avoid overwhelming
            if i % 100 == 0:
                time.sleep(0.01)

    received: list[tuple[int, SampleBatch]] = []

    publisher_thread = threading.Thread(target=publish_events)
    publisher_thread.start()

    # Drain concurrently with publishing, with an overall deadline.
    start_time = time.time()
    while len(received) < num_batches:
        if time.time() - start_time > 10:  # Timeout after 10 seconds
            break

        result = subscriber.receive_one(timeout=100)
        if result:
            received.append(result)

    publisher_thread.join()

    assert len(received) >= num_batches * 0.9, "We should have received most messages"

    seqs = [seq for seq, _ in received]
    assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
||||
|
||||
|
||||
def test_null_publisher():
    """NullEventPublisher accepts publish/shutdown calls as silent no-ops."""
    sink = NullEventPublisher(DP_RANK)

    # Neither call should raise.
    sink.publish(create_test_events(5))
    sink.shutdown()
|
||||
|
||||
|
||||
def test_data_parallel_rank_tagging(publisher_config):
    """Test that events are properly tagged with their data parallel rank.

    Creates one publisher per DP rank (0 and 1), subscribes to each rank's
    offset endpoint, and checks the received batches carry the right
    `data_parallel_rank` and event counts.
    """

    publisher_config.topic = "foo"
    pub_0 = EventPublisherFactory.create(publisher_config, DP_RANK)
    pub_1 = EventPublisherFactory.create(publisher_config, DP_RANK + 1)

    # Hardcode the expected endpoints based on port offsetting behavior.
    # Both ranks get offsets according to the _offset_endpoint_port function
    # (in the kv_events implementation — confirm if that helper changes).
    base_endpoint = publisher_config.endpoint
    if "tcp://" in base_endpoint:
        # For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
        expected_endpoint_0 = base_endpoint  # rank 0 gets port + 0 = same port
        expected_endpoint_1 = base_endpoint.replace(
            ":5557", ":5558"
        )  # rank 1 gets port + 1
    else:
        # For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
        expected_endpoint_0 = base_endpoint  # rank 0 gets base
        expected_endpoint_1 = base_endpoint + "_dp1"  # rank 1 gets _dp1

    from .conftest import MockSubscriber

    sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
    sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)

    try:
        time.sleep(0.1)  # Let publishers start up

        # Publish differently sized batches from the two ranks so they can
        # be told apart on the receiving side.
        batch_0 = create_test_events(2)
        batch_1 = create_test_events(3)

        pub_0.publish(batch_0)
        pub_1.publish(batch_1)

        # Receive events from rank 0
        result_0 = sub_0.receive_one(timeout=200)
        assert result_0 is not None, "No message received from rank 0"
        seq_0, received_0 = result_0

        # Receive events from rank 1
        result_1 = sub_1.receive_one(timeout=200)
        assert result_1 is not None, "No message received from rank 1"
        seq_1, received_1 = result_1

        # Verify DP rank tagging
        assert received_0.data_parallel_rank == 0, (
            f"Expected DP rank 0, got {received_0.data_parallel_rank}"
        )
        assert received_1.data_parallel_rank == 1, (
            f"Expected DP rank 1, got {received_1.data_parallel_rank}"
        )

        # Verify event content is correct
        assert len(received_0.events) == 2, "Wrong number of events from rank 0"
        assert len(received_1.events) == 3, "Wrong number of events from rank 1"

    finally:
        pub_0.shutdown()
        pub_1.shutdown()
        sub_0.close()
        sub_1.close()
|
||||
|
||||
|
||||
def test_event_publisher_factory():
    """Test event publisher factory creation behavior under different configurations."""
    from vllm.config.kv_events import KVEventsConfig
    from vllm.distributed.kv_events import ZmqEventPublisher

    # No config at all -> null publisher.
    publisher = EventPublisherFactory.create(None, DP_RANK)
    assert isinstance(publisher, NullEventPublisher)
    publisher.shutdown()

    # KV cache events disabled -> null publisher, regardless of `publisher`.
    config = KVEventsConfig(
        enable_kv_cache_events=False,
        publisher="zmq",  # Even if zmq is specified, should return NullEventPublisher
        endpoint="tcp://localhost:5557",
    )
    publisher = EventPublisherFactory.create(config, DP_RANK)
    assert isinstance(publisher, NullEventPublisher)
    publisher.shutdown()

    # Enabled + "zmq" -> real ZMQ publisher.
    config = KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint="inproc://test-factory-true",
    )
    publisher = EventPublisherFactory.create(config, DP_RANK)
    assert isinstance(publisher, ZmqEventPublisher)
    publisher.shutdown()

    # Unknown publisher name is rejected at config-validation time.
    with pytest.raises(ValueError, match="Input should be"):
        KVEventsConfig(
            enable_kv_cache_events=True,
            publisher="unknown_publisher",
            endpoint="tcp://localhost:5557",
        )

    # Publisher omitted -> defaults to "zmq".
    config = KVEventsConfig(
        enable_kv_cache_events=True,
        # publisher not specified, should default to "zmq"
        endpoint="tcp://localhost:5557",
    )
    publisher = EventPublisherFactory.create(config, DP_RANK)
    assert isinstance(publisher, ZmqEventPublisher)
    publisher.shutdown()
|
||||
231
tests/distributed/test_expert_parallel.py
Normal file
231
tests/distributed/test_expert_parallel.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config.model import RunnerOption
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
|
||||
logger = init_logger("test_expert_parallel")
|
||||
|
||||
|
||||
class ParallelSetup(NamedTuple):
    """One parallelism configuration to exercise for a model."""

    # Tensor-parallel world size for the run.
    tp_size: int
    # When True, the run adds --enforce-eager.
    eager_mode: bool
    # When True, the run adds --enable-chunked-prefill.
    chunked_prefill: bool
|
||||
|
||||
|
||||
class EPTestOptions(NamedTuple):
    """Per-model CLI options shared by all parallel setups of a test."""

    # Adds --trust-remote-code when True.
    trust_remote_code: bool
    # Forwarded via --tokenizer-mode when set.
    tokenizer_mode: str | None
    # Forwarded via --load-format when set.
    load_format: str | None = None
    # Forwarded via --hf-overrides when set.
    hf_overrides: str | None = None
|
||||
|
||||
|
||||
@dataclass
class EPTestSettings:
    """Sweep definition for one model: parallel setups x executor backends."""

    # Combinations of tp size / eager mode / chunked prefill to run.
    parallel_setups: list[ParallelSetup]
    # Distributed executor backends to try (e.g. "mp", "ray").
    distributed_backends: list[str]
    # Runner option forwarded as --runner (unless "auto").
    runner: RunnerOption
    # Shared CLI options for every setup.
    test_options: EPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        runner: RunnerOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: str | None = None,
        load_format: str | None = None,
        hf_overrides: str | None = None,
    ):
        """Thorough sweep: eager/graph x chunked-prefill at tp_base and
        2*tp_base, against both "mp" and "ray" backends."""
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False),
                ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True),
                ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
                ParallelSetup(
                    tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True
                ),
                ParallelSetup(
                    tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False
                ),
            ],
            distributed_backends=["mp", "ray"],
            runner=runner,
            test_options=EPTestOptions(
                trust_remote_code=trust_remote_code,
                tokenizer_mode=tokenizer_mode,
                load_format=load_format,
                hf_overrides=hf_overrides,
            ),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        runner: RunnerOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: str | None = None,
        load_format: str | None = None,
        hf_overrides: str | None = None,
    ):
        """Minimal sweep: a single eager tp_base setup on the "mp" backend."""
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            runner=runner,
            test_options=EPTestOptions(
                trust_remote_code=trust_remote_code,
                tokenizer_mode=tokenizer_mode,
                load_format=load_format,
                hf_overrides=hf_overrides,
            ),
        )

    def iter_params(self, model_name: str):
        """Yield one parameter tuple per (parallel setup, backend) pair,
        suitable for pytest parametrization."""
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (
                    model_name,
                    parallel_setup,
                    distributed_backend,
                    self.runner,
                    opts,
                )
|
||||
|
||||
|
||||
# NOTE: You can adjust tp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# Model id -> sweep settings; presumably expanded via
# EPTestSettings.iter_params when parametrizing the tests below — confirm.
TEST_MODELS = {
    "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(trust_remote_code=True),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4),
}
|
||||
|
||||
|
||||
def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: EPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate"],
):
    """Compare the model with expert parallelism enabled vs. disabled.

    Launches the model twice with the same tensor-parallel size — once with
    VLLM_TEST_ENABLE_EP=1 (using ``distributed_backend``) and once with plain
    TP on the "mp" backend — and checks (via ``compare_two_settings``) that
    the outputs agree. Skips when fewer than ``tp_size`` GPUs are available.
    """
    (
        tp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
        "--load-format",
        "auto",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        # NOTE(review): this appends a second --load-format after the "auto"
        # default above; the later flag wins when the CLI is parsed.
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    # Expert-parallel run: EP enabled via env var.
    ep_env = {
        "VLLM_TEST_ENABLE_EP": "1",
    }

    ep_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without expert parallelism
    tp_env = {
        "VLLM_TEST_ENABLE_EP": "0",
    }

    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    # The original wrapped this in a no-op ``try/except Exception: raise``;
    # the bare re-raise added nothing, so the call is made directly.
    compare_two_settings(
        model_name,
        ep_args,
        tp_args,
        ep_env,
        tp_env,
        method=method,
        max_wait_seconds=360,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "runner", "test_options"),
    [
        case
        for name, settings in TEST_MODELS.items()
        for case in settings.iter_params(name)
    ],
)
@create_new_process_for_each_test()
def test_ep(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: EPTestOptions,
    num_gpus_available,
):
    """Run the EP-vs-TP comparison for every configured model/setup pair."""
    _compare_tp(
        model_name=model_name,
        parallel_setup=parallel_setup,
        distributed_backend=distributed_backend,
        runner=runner,
        test_options=test_options,
        num_gpus_available=num_gpus_available,
        method="generate",
    )
|
||||
244
tests/distributed/test_expert_placement.py
Normal file
244
tests/distributed/test_expert_placement.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||
|
||||
|
||||
def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts):
    """Verify that the expert map follows the round_robin pattern.

    In round-robin placement, rank r owns global experts r, r + ep_size,
    r + 2*ep_size, ... and every other entry of ``expert_map`` must be -1.
    Local IDs for the owned experts must be consecutive starting at 0.
    Supports non-divisible expert counts (lower ranks get one extra expert).
    """
    quotient, remainder = divmod(global_num_experts, ep_size)
    local_num_experts = quotient + 1 if ep_rank < remainder else quotient

    # Global expert IDs this rank is expected to own, in local-ID order.
    expected_expert_ids = [
        ep_rank + offset * ep_size for offset in range(local_num_experts)
    ]

    for global_id in range(global_num_experts):
        if global_id in expected_expert_ids:
            want = expected_expert_ids.index(global_id)
            got = expert_map[global_id]
            assert got == want, (
                f"Global expert {global_id} should map to local expert "
                f"{want}, got {got}"
            )
        else:
            assert expert_map[global_id] == -1, (
                f"Global expert {global_id} should not be mapped to this rank"
            )

    # Owned experts must receive the consecutive local IDs 0..n-1.
    local_expert_ids = [expert_map[gid] for gid in expected_expert_ids]
    expected_local_ids = list(range(local_num_experts))
    assert local_expert_ids == expected_local_ids, (
        f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@pytest.mark.parametrize("world_size", [2, 4])
def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
    """Test round_robin expert placement with various expert counts."""
    # (global_num_experts, ep_size) pairs per world size; both divisible and
    # non-divisible expert counts are covered.
    cases_by_world_size = {
        2: [(4, 2), (8, 2), (9, 2), (16, 2), (17, 2)],
        4: [(8, 4), (16, 4), (18, 4), (32, 4), (33, 4)],
    }
    test_cases = cases_by_world_size.get(world_size, [])

    for test_global_experts, test_ep_size in test_cases:
        # Sanity: the table must match the parametrized world size.
        assert test_ep_size == world_size, (
            f"ep_size {test_ep_size} must equal world_size {world_size}"
        )

        for ep_rank in range(world_size):
            # Lower ranks absorb the remainder when not evenly divisible.
            quotient, remainder = divmod(test_global_experts, test_ep_size)
            expected_test_local = quotient + 1 if ep_rank < remainder else quotient

            test_local_experts, test_expert_map, _ = determine_expert_map(
                ep_size=test_ep_size,
                ep_rank=ep_rank,
                global_num_experts=test_global_experts,
                expert_placement_strategy=expert_placement_strategy,
            )

            assert test_local_experts == expected_test_local, (
                f"For {test_global_experts} experts on {test_ep_size} ranks, "
                f"rank {ep_rank}: expected {expected_test_local} local"
                f"experts, got {test_local_experts}"
            )

            if test_expert_map is not None:
                assert test_expert_map.shape == (test_global_experts,), (
                    f"Expected expert map shape ({test_global_experts},), "
                    f"got {test_expert_map.shape}"
                )
                # Full structural check of the round-robin assignment.
                verify_round_robin_pattern(
                    test_expert_map, ep_rank, test_ep_size, test_global_experts
                )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@pytest.mark.parametrize("world_size", [2, 4])
def test_expert_placement_edge_cases(expert_placement_strategy, world_size):
    """Test edge cases for round_robin expert placement."""
    # Edge case 1: with a single EP rank all experts are local and no
    # mapping tensor is produced.
    num_local, placement_map, _ = determine_expert_map(
        ep_size=1,
        ep_rank=0,
        global_num_experts=8,
        expert_placement_strategy=expert_placement_strategy,
    )
    assert num_local == 8, "For ep_size=1, should get all experts"
    assert placement_map is None, "For ep_size=1, expert_map should be None"

    # Edge case 2: ep_size=0 is invalid and must trip the internal assertion.
    with pytest.raises(AssertionError):
        determine_expert_map(
            ep_size=0,
            ep_rank=0,
            global_num_experts=8,
            expert_placement_strategy=expert_placement_strategy,
        )
|
||||
|
||||
|
||||
def test_determine_expert_map_comprehensive():
    """Test of determine_expert_map function with various configurations.

    Table-driven: each case pins the exact expert_map contents expected
    from round_robin placement for a given (ep_size, ep_rank) pair,
    including non-divisible expert counts (9 experts on 2 ranks).
    """

    # Test cases: (ep_size, ep_rank, global_num_experts,
    # expert_placement_strategy, expected_local, expected_map_pattern)
    # expected_map_pattern[g] is the local ID of global expert g on this
    # rank, or -1 when the expert lives on another rank.
    test_cases = [
        # Round robin placement tests
        (
            2,
            0,
            8,
            "round_robin",
            4,
            [0, -1, 1, -1, 2, -1, 3, -1],
        ),  # rank 0 gets even experts
        (
            2,
            1,
            8,
            "round_robin",
            4,
            [-1, 0, -1, 1, -1, 2, -1, 3],
        ),  # rank 1 gets odd experts
        (
            2,
            0,
            9,
            "round_robin",
            5,
            [0, -1, 1, -1, 2, -1, 3, -1, 4],
        ),  # rank 0 gets 5 experts (even + last)
        (
            2,
            1,
            9,
            "round_robin",
            4,
            [-1, 0, -1, 1, -1, 2, -1, 3, -1],
        ),  # rank 1 gets 4 experts (odd)
        # 4-rank tests
        (
            4,
            0,
            8,
            "round_robin",
            2,
            [0, -1, -1, -1, 1, -1, -1, -1],
        ),  # rank 0 gets experts 0, 4
        (
            4,
            1,
            8,
            "round_robin",
            2,
            [-1, 0, -1, -1, -1, 1, -1, -1],
        ),  # rank 1 gets experts 1, 5
        (
            4,
            2,
            8,
            "round_robin",
            2,
            [-1, -1, 0, -1, -1, -1, 1, -1],
        ),  # rank 2 gets experts 2, 6
        (
            4,
            3,
            8,
            "round_robin",
            2,
            [-1, -1, -1, 0, -1, -1, -1, 1],
        ),  # rank 3 gets experts 3, 7
    ]

    for (
        ep_size,
        ep_rank,
        global_num_experts,
        expert_placement_strategy,
        expected_local,
        expected_map_pattern,
    ) in test_cases:
        local_num_experts, expert_map, _ = determine_expert_map(
            ep_size=ep_size,
            ep_rank=ep_rank,
            global_num_experts=global_num_experts,
            expert_placement_strategy=expert_placement_strategy,
        )

        assert local_num_experts == expected_local, (
            f"ep_size={ep_size}, ep_rank={ep_rank}, "
            f"global_num_experts={global_num_experts}, "
            f"expert_placement_strategy={expert_placement_strategy}: "
            f"expected {expected_local} local experts, got {local_num_experts}"
        )

        # A None pattern means no mapping tensor is expected at all.
        if expected_map_pattern is None:
            assert expert_map is None, "Expected expert_map to be None"
        else:
            assert expert_map is not None, "Expected expert_map to not be None"
            actual_map = expert_map.tolist()
            assert actual_map == expected_map_pattern, (
                f"ep_size={ep_size}, ep_rank={ep_rank}, "
                f"global_num_experts={global_num_experts}, "
                f"expert_placement_strategy={expert_placement_strategy}: "
                f"expected map {expected_map_pattern}, got {actual_map}"
            )
|
||||
78
tests/distributed/test_kvlayout.py
Normal file
78
tests/distributed/test_kvlayout.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config import (
|
||||
DeviceConfig,
|
||||
KVTransferConfig,
|
||||
ModelConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# Name the logger after this module (test_kvlayout); the original said
# "test_expert_parallel", a copy-paste from the sibling test file, which
# mis-attributed this file's log output.
logger = init_logger("test_kvlayout")
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_without_kv_connector():
    """With no KV connector configured, the default NHD layout is returned."""
    cfg = VllmConfig(device_config=DeviceConfig("cpu"))
    with set_current_vllm_config(cfg):
        assert get_kv_connector_cache_layout() == "NHD"
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_with_lmcache_connector():
    """LMCacheConnectorV1 keeps the default NHD cache layout."""
    transfer_cfg = KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
    )
    cfg = VllmConfig(
        device_config=DeviceConfig("cpu"), kv_transfer_config=transfer_cfg
    )
    with set_current_vllm_config(cfg):
        assert get_kv_connector_cache_layout() == "NHD"
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_with_nixl_connector():
    """NixlConnector switches the KV cache layout to HND."""
    transfer_cfg = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    cfg = VllmConfig(
        device_config=DeviceConfig("cpu"),
        model_config=ModelConfig(),
        kv_transfer_config=transfer_cfg,
    )
    with set_current_vllm_config(cfg):
        assert get_kv_connector_cache_layout() == "HND"
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_with_multi_connector():
    """A MultiConnector that contains a NixlConnector selects the HND layout."""
    transfer_cfg = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [
                {"kv_connector": "ExampleConnector", "kv_role": "kv_both"},
                {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
            ]
        },
    )
    cfg = VllmConfig(
        device_config=DeviceConfig("cpu"),
        model_config=ModelConfig(),
        kv_transfer_config=transfer_cfg,
    )
    with set_current_vllm_config(cfg):
        assert get_kv_connector_cache_layout() == "HND"
|
||||
64
tests/distributed/test_multi_node_assignment.py
Normal file
64
tests/distributed/test_multi_node_assignment.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Make sure ray assigns GPU workers to the correct node.
|
||||
|
||||
Run:
|
||||
```sh
|
||||
cd $VLLM_PATH/tests
|
||||
|
||||
pytest distributed/test_multi_node_assignment.py
|
||||
```
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
from vllm import initialize_ray_cluster
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.v1.executor.ray_utils import _wait_until_pg_removed
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test."
)
def test_multi_node_assignment() -> None:
    """Check that ray schedules every GPU worker onto the driver's node.

    Repeats cluster setup/teardown 10 times, each time asserting that all
    workers placed via the placement group report the driver's IP.
    """
    # NOTE: important to keep this class definition here
    # to let ray use cloudpickle to serialize it.
    class Actor:
        def get_ip(self):
            return get_ip()

    # Run several iterations to catch flaky/non-deterministic placement.
    for _ in range(10):
        config = ParallelConfig(1, 2)
        initialize_ray_cluster(config)

        current_ip = get_ip()
        workers = []
        for bundle_id, bundle in enumerate(config.placement_group.bundle_specs):
            # Only GPU bundles host workers; skip CPU-only bundles.
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=config.placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=scheduling_strategy,
            )(Actor).remote()
            # The worker must land on the same node as the driver.
            worker_ip = ray.get(worker.get_ip.remote())
            assert worker_ip == current_ip
            workers.append(worker)

        # Tear down actors and wait for the placement group to disappear
        # before the next iteration re-creates the cluster state.
        for worker in workers:
            ray.kill(worker)

        _wait_until_pg_removed(config.placement_group)
|
||||
437
tests/distributed/test_multiproc_executor.py
Normal file
437
tests/distributed/test_multiproc_executor.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Integration tests for MultiprocExecutor at the executor level.
|
||||
This test directly tests the executor without going through the LLM interface,
|
||||
focusing on executor initialization, RPC calls, and distributed execution.
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
|
||||
|
||||
def create_vllm_config(
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    max_model_len: int = 256,
    gpu_memory_utilization: float = 0.3,
    distributed_executor_backend: str = "mp",
    nnodes: int = 1,
    node_rank: int = 0,
    master_port: int = 0,
) -> VllmConfig:
    """Create a VllmConfig for testing using EngineArgs.

    Multi-node fields (nnodes / node_rank / master_port) are patched onto
    the parallel config afterwards, since EngineArgs does not carry them.
    """
    args = EngineArgs(
        model=MODEL,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        distributed_executor_backend=distributed_executor_backend,
        enforce_eager=True,
    )
    config = args.create_engine_config()

    # Override distributed node settings when a multi-node (or non-zero
    # rank) layout is requested.
    multi_node = nnodes > 1
    if multi_node or node_rank > 0:
        parallel = config.parallel_config
        parallel.nnodes = nnodes
        parallel.node_rank = node_rank
        parallel.master_port = master_port
        if multi_node:
            # Custom all-reduce does not apply across simulated nodes.
            parallel.disable_custom_all_reduce = True

    return config
|
||||
|
||||
|
||||
def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput:
    """Create a minimal SchedulerOutput for testing.

    Simplified stub: a real SchedulerOutput would be built through the
    actual vLLM v1 scheduling API.
    """
    fields = dict(
        scheduled_new_reqs=[],
        scheduled_resumed_reqs=[],
        scheduled_running_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
    )
    return SchedulerOutput(**fields)
|
||||
|
||||
|
||||
def test_multiproc_executor_initialization():
    """Test that MultiprocExecutor can be initialized with proper config."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )

    # Create executor - this should initialize workers
    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Verify executor properties
        assert executor.world_size == 1, "World size should be 1 for single GPU"
        assert executor.local_world_size == 1, "Local world size should be 1"
        assert hasattr(executor, "workers"), "Executor should have workers"
        assert len(executor.workers) == 1, "Should have 1 worker for single GPU"
    finally:
        # Always shut down, even when an assertion fails, so a failing test
        # does not leak worker subprocesses (the original skipped cleanup
        # on assertion failure).
        executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_initialization_tensor_parallel():
    """Test MultiprocExecutor initialization with tensor parallelism."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )

    # Create executor
    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Verify executor properties
        assert executor.world_size == 2, "World size should be 2 for TP=2"
        assert executor.local_world_size == 2, "Local world size should be 2"
        assert len(executor.workers) == 2, "Should have 2 workers for TP=2"

        # Verify output rank calculation
        output_rank = executor._get_output_rank()
        assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1"
    finally:
        # Always shut down, even when an assertion fails, so a failing test
        # does not leak worker subprocesses (the original skipped cleanup
        # on assertion failure).
        executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_collective_rpc():
    """A collective RPC (check_health) reaches every worker without error."""
    config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )
    executor = MultiprocExecutor(vllm_config=config)

    try:
        # check_health fans out an RPC to all workers; only the RPC
        # mechanism is exercised here, not actual model execution.
        executor.check_health()

        assert not executor.is_failed, "Executor should not be in failed state"
    finally:
        executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_failure_callback():
    """Failure callbacks fire immediately once the executor is marked failed."""
    config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )
    executor = MultiprocExecutor(vllm_config=config)

    try:
        invocations = []

        def on_failure():
            invocations.append(True)

        # Registering on a healthy executor must not fire the callback.
        executor.register_failure_callback(on_failure)
        assert len(invocations) == 0, "Callback should not be invoked immediately"

        # Simulate a failure; a subsequent registration fires right away.
        executor.is_failed = True
        executor.register_failure_callback(on_failure)
        assert len(invocations) == 1, (
            "Callback should be invoked when executor is failed"
        )
    finally:
        executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_worker_monitor():
    """Workers stay alive while running and terminate after shutdown."""
    import time

    config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )
    executor = MultiprocExecutor(vllm_config=config)

    try:
        # Every worker process must be up and the executor healthy.
        for worker in executor.workers:
            assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive"

        assert not executor.is_failed, "Executor should not be in failed state"
    finally:
        executor.shutdown()

    # Give the processes a moment to exit, then confirm they all did.
    time.sleep(0.5)
    for worker in executor.workers:
        assert not worker.proc.is_alive(), (
            f"Worker rank {worker.rank} should terminate after shutdown"
        )
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_get_response_message_queues():
    """get_response_mqs returns all queues, or exactly one per chosen rank."""
    config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )
    executor = MultiprocExecutor(vllm_config=config)

    try:
        # No rank filter: one queue per worker.
        every_queue = executor.get_response_mqs()
        assert len(every_queue) == 2, "Should have 2 message queues for 2 workers"

        # Rank filter: a single queue for the requested rank.
        for rank in (0, 1):
            filtered = executor.get_response_mqs(unique_reply_rank=rank)
            assert len(filtered) == 1, f"Should have 1 message queue for rank {rank}"
    finally:
        executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_shutdown_cleanup():
    """shutdown() terminates workers, sets its event, and is idempotent."""
    import time

    config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )
    executor = MultiprocExecutor(vllm_config=config)

    # Sanity: the executor actually spawned workers to clean up.
    assert hasattr(executor, "workers"), "Executor should have workers"
    assert len(executor.workers) > 0, "Should have at least one worker"

    executor.shutdown()

    # Give the processes a moment to exit before verifying cleanup.
    time.sleep(0.5)

    for worker in executor.workers:
        assert not worker.proc.is_alive(), "Worker processes should be terminated"

    assert executor.shutdown_event.is_set(), "Shutdown event should be set"

    # Repeated shutdowns must be safe no-ops.
    executor.shutdown()
    executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
def test_multiproc_executor_pipeline_parallel():
    """MultiprocExecutor with pipeline parallelism (TP=2, PP=2)."""
    config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
    )
    executor = MultiprocExecutor(vllm_config=config)

    try:
        assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2"
        assert len(executor.workers) == 4, "Should have 4 workers"

        # Output is produced by the first rank of the last PP stage
        # (ranks 2-3 form the last stage, so rank 2).
        assert executor._get_output_rank() == 2, (
            "Output rank should be 2 (first rank of last PP stage)"
        )

        # With pipeline parallelism, one batch per stage can be in flight.
        assert executor.max_concurrent_batches == 2, (
            "Max concurrent batches should equal PP size"
        )
    finally:
        executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_properties():
    """Check class-level flags and derived size properties of the executor."""
    config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )
    executor = MultiprocExecutor(vllm_config=config)

    try:
        # Class-level capability flag.
        assert MultiprocExecutor.supports_pp is True, (
            "MultiprocExecutor should support pipeline parallelism"
        )

        # world_size is derived from TP x PP.
        parallel = executor.parallel_config
        assert executor.world_size == (
            parallel.tensor_parallel_size * parallel.pipeline_parallel_size
        ), "World size should equal TP * PP"

        # local_world_size is the per-node share of the world.
        assert executor.local_world_size == (
            parallel.world_size // parallel.nnodes
        ), "Local world size should be world_size / nnodes"
    finally:
        executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
def test_multiproc_executor_multi_node():
    """
    Test MultiprocExecutor with multi-node configuration.
    This simulates 2 nodes with TP=4:
    - Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2
    - Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2
    Total world_size = 4, nnodes = 2
    """
    port = get_open_port()
    # symm_mem does not work for simulating multi instance in single node
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int):
        """Run a single node's executor."""
        executor = None
        try:
            # Set CUDA_VISIBLE_DEVICES for this node
            if node_rank == 0:
                os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

            # Create config for this node
            vllm_config = create_vllm_config(
                tensor_parallel_size=4,  # Total TP across all nodes
                pipeline_parallel_size=1,
                nnodes=2,  # 2 nodes
                node_rank=node_rank,
                master_port=port,  # same port
            )

            # Create executor for this node
            executor = MultiprocExecutor(vllm_config=vllm_config)

            # Verify node-specific properties
            assert executor.world_size == 4, (
                f"World size should be 4 on node {node_rank}"
            )
            assert executor.local_world_size == 2, (
                f"Local world size should be 2 on node {node_rank}"
            )
            assert len(executor.workers) == 2, (
                f"Should have 2 local workers on node {node_rank}"
            )

            # Verify worker ranks are correct for this node
            expected_ranks = [node_rank * 2, node_rank * 2 + 1]
            actual_ranks = sorted([w.rank for w in executor.workers])
            assert actual_ranks == expected_ranks, (
                f"Node {node_rank} should have workers "
                f"with ranks {expected_ranks}, got {actual_ranks}"
            )
            # Verify all workers are alive
            for worker in executor.workers:
                assert worker.proc.is_alive(), (
                    f"Worker rank {worker.rank} should be alive on node {node_rank}"
                )
            # Put success result in queue BEFORE shutdown to avoid hanging
            result_queue.put({"node": node_rank, "success": True})
            import time

            time.sleep(2)
            executor.shutdown()
        except Exception as e:
            # Put failure result in queue
            result_queue.put({"node": node_rank, "success": False, "error": str(e)})
            raise e
        finally:
            if executor is not None:
                executor.shutdown()

    # Create a queue to collect results from both processes
    result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue()

    # Start both node processes
    processes = []
    for node_rank in range(2):
        p = multiprocessing.Process(
            target=run_node,
            args=(node_rank, result_queue, port),
            name=f"Node{node_rank}",
        )
        p.start()
        processes.append(p)

    # Wait for both processes to complete
    all_completed = True
    for p in processes:
        p.join(timeout=60)
        if p.is_alive():
            p.terminate()
            p.join(timeout=20)
            if p.is_alive():
                p.kill()
                p.join()
            all_completed = False

    # Check results from both nodes. Bound the number of queue polls so a
    # child that crashed before enqueueing a result cannot hang this loop
    # forever (the original `while len(results) < 2` never gave up).
    results: list[dict[str, int | bool]] = []
    for _ in range(30):
        if len(results) >= 2:
            break
        try:
            results.append(result_queue.get(timeout=1))
        except Exception:
            continue
    assert all_completed, "Not all processes completed successfully"
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    assert results[0]["success"], f"Node 0 failed: {results[0]}"
    assert results[1]["success"], f"Node 1 failed: {results[1]}"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user