1. remove some useless test func and file 2. fix format.sh problem 3. enable full test for singlecard and multicard 4. move long term test to long_term folder. For this kind of test, it only runs by labeled and daily test. Include: spec decode、accuracy test ## After refactor: There are 4 test modules - `singlecard`: contains the test running on one NPU. It'll be run for each PR and daily test. - `multicard`: contains the test running on multi NPUs. It'll be run for each PR and daily test. - `long_term`: contains the test that cost much time(Now include `spec decode` and `accuracy` test). It'll be run for the PR with `long-term-test` labeled and daily test. - `e2e`: contains the test for doc and pd feature. It'll be run for the PR with `pd-test` labeled and daily test. ## Todo: 1. some test are skipped, they should be fixed and reenabled in the future. 2. pyhccl test for multicard doesn't work at all. It should be enabled as well. 3. ensure long-term-test pass by daily test. ### Know issue Now, `ready` labels is required to start pd test or long term test. And when `long-term-test` or `pd-test` is labeled after another one, the old labeled test will be re-run again. So the labeled test should be ran in the following step: 1. decide which test need run, then label it. `long-term-test` or `pd-test` or both. 2. add `ready-for-test` label, then the test will be ran. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
119 lines
3.3 KiB
Python
119 lines
3.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
"""
|
|
Test the piecewise compilation with a simple model so that we
|
|
can exactly calculate the expected output and side effects.
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
from torch import nn
|
|
from torch.library import Library
|
|
from vllm.compilation.counter import compilation_counter
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
|
set_current_vllm_config)
|
|
from vllm.utils import direct_register_custom_op
|
|
|
|
global_counter = 0
|
|
|
|
# create a library to hold the custom op
|
|
silly_lib = Library("silly", "FRAGMENT") # noqa
|
|
|
|
|
|
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
out: torch.Tensor) -> None:
|
|
global global_counter
|
|
global_counter += 1
|
|
print(f"{global_counter=}")
|
|
out.copy_(q)
|
|
out[0] += 1
|
|
|
|
|
|
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
out: torch.Tensor) -> None:
|
|
return
|
|
|
|
|
|
direct_register_custom_op(
|
|
op_name="attention",
|
|
op_func=silly_attention,
|
|
mutates_args=["out"],
|
|
fake_impl=silly_attention_fake,
|
|
dispatch_key="PrivateUse1",
|
|
target_lib=silly_lib,
|
|
)
|
|
|
|
|
|
@support_torch_compile
|
|
class SillyModel(nn.Module):
|
|
|
|
def __init__(self,
|
|
*,
|
|
vllm_config: VllmConfig,
|
|
prefix: str = "",
|
|
**kwargs) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Overall effect:
|
|
x += 1
|
|
x[0] += 2
|
|
global_counter += 2
|
|
"""
|
|
x = x + 1
|
|
x = x + 2
|
|
out = torch.empty_like(x)
|
|
torch.ops.silly.attention(x, x, x, out)
|
|
x = out
|
|
x = x - 2
|
|
x = x - 1
|
|
out = torch.empty_like(x)
|
|
torch.ops.silly.attention(x, x, x, out)
|
|
x = out
|
|
x = x + 1
|
|
return x
|
|
|
|
|
|
@pytest.mark.skipif(True, reason="requires unreleased components")
|
|
def test_simple_piecewise_compile():
|
|
|
|
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
|
level=CompilationLevel.PIECEWISE,
|
|
use_inductor=False,
|
|
use_cudagraph=True,
|
|
splitting_ops=["silly.attention"],
|
|
cudagraph_copy_inputs=True,
|
|
cudagraph_capture_sizes=[1, 2],
|
|
))
|
|
vllm_config.compilation_config.pass_config.enable_fusion = False
|
|
with set_current_vllm_config(vllm_config):
|
|
model = SillyModel(vllm_config=vllm_config, prefix="")
|
|
|
|
inputs = torch.randn(100).npu()
|
|
|
|
with compilation_counter.expect(
|
|
num_graphs_seen=1, # one graph for the model
|
|
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
|
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
|
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
|
num_cudagraph_caputured=
|
|
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
|
):
|
|
|
|
model(inputs)
|
|
|
|
model(torch.randn(2).npu())
|
|
model(torch.randn(1).npu())
|
|
|
|
input = torch.zeros(2).npu()
|
|
global global_counter
|
|
global_counter = 0
|
|
output = model(input)
|
|
assert global_counter == 2
|
|
assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0]))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_simple_piecewise_compile()
|