[refactor] refactor model runner capture model (#5230)
### What this PR does / why we need it?
Refactor the `capture_model` method in model_runner to directly reuse
the method from vLLM.
Currently, most of the logic in the `capture_model` method is similar to
the corresponding code in vLLM. Using the vLLM method directly reduces the
maintenance cost of the vllm-ascend code. The changes are as follows:
1. Refactor the `capture_model` function so that it directly inherits the
   community (upstream vLLM) method.
2. Refactor the `initialize_aclgraph_capture` function by moving it into
   `initialize_attn_backend`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -54,10 +54,10 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
|
||||
|
||||
return wrapped
|
||||
|
||||
original_capture = NPUModelRunner._capture_model
|
||||
original_capture = NPUModelRunner.capture_model
|
||||
|
||||
with patch.object(NPUModelRunner,
|
||||
'_capture_model',
|
||||
'capture_model',
|
||||
new=capture_model_wrapper(original_capture)):
|
||||
prompts = [
|
||||
"Hello, my name is", "The president of the United States is",
|
||||
@@ -73,7 +73,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
|
||||
vllm_model = VllmRunner(snapshot_download(model))
|
||||
_ = vllm_model.generate(prompts, sampling_params)
|
||||
|
||||
assert capture_called.value == 1, "_capture_model was not called during test"
|
||||
assert capture_called.value == 1, "capture_model was not called during test"
|
||||
assert capture_mem_before.value != -1, "capture_mem_before not set"
|
||||
assert capture_mem_after.value != -1, "capture_mem_after not set"
|
||||
|
||||
@@ -93,7 +93,7 @@ def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
|
||||
max_capture_mem_gib = baseline_capture_mem * capture_mem_tolerance
|
||||
max_mem_expected = max_capture_mem_gib * (1024**3)
|
||||
assert mem_used_by_capture < max_mem_expected, (
|
||||
f"_capture_model used more memory than expected. "
|
||||
f"capture_model used more memory than expected. "
|
||||
f"Used: {mem_used_by_capture / (1024**3):.2f} GiB, "
|
||||
f"Expected: < {max_capture_mem_gib:.2f} GiB")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'
|
||||
|
||||
Reference in New Issue
Block a user