implement model runner v2 basic framework (#5051)
### What this PR does / why we need it?
This PR aim to implement model runner v2 basic framework in vllm-ascend,
the e2e function is not guaranteed by this pr.
### Does this PR introduce _any_ user-facing change?
use envs.VLLM_USE_V2_MODEL_RUNNER to decide if choose model_runenr_v2.
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
33
vllm_ascend/worker/v2/utils.py
Normal file
33
vllm_ascend/worker/v2/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_cuda_wrapper():
|
||||
ori_event = torch.cuda.Event
|
||||
ori_stream = torch.cuda.Stream
|
||||
ori_default_stream = torch.cuda.default_stream
|
||||
ori_current_stream = torch.cuda.current_stream
|
||||
ori_graph_pool_handle = torch.cuda.graph_pool_handle
|
||||
ori_cuda_graph_cls = torch.cuda.CUDAGraph
|
||||
ori_cuda_graph_func = torch.cuda.graph
|
||||
try:
|
||||
torch.cuda.Event = torch.npu.Event
|
||||
torch.cuda.Stream = torch.npu.Stream
|
||||
torch.cuda.default_stream = torch.npu.default_stream
|
||||
torch.cuda.current_stream = torch.npu.current_stream
|
||||
torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
|
||||
torch.cuda.CUDAGraph = torch.npu.NpuGraph
|
||||
torch.cuda.graph = torch.npu.graph
|
||||
yield
|
||||
finally:
|
||||
# revert back torch cuda properties, so it will still raise error
|
||||
# to call cuda ops in npu environment.
|
||||
torch.cuda.Event = ori_event
|
||||
torch.cuda.Stream = ori_stream
|
||||
torch.cuda.default_stream = ori_default_stream
|
||||
torch.cuda.current_stream = ori_current_stream
|
||||
torch.cuda.graph_pool_handle = ori_graph_pool_handle
|
||||
torch.cuda.CUDAGraph = ori_cuda_graph_cls
|
||||
torch.cuda.graph = ori_cuda_graph_func
|
||||
Reference in New Issue
Block a user