[MISC] Clean up torch_npu (#688)
torch_npu 2.5.1 support autoload now. This patch does: 1. remove useless torch_npu import 2. replace `torch_npu.npu` to `torch.npu`. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -18,7 +18,6 @@
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
|
||||
@@ -36,7 +35,7 @@ def init_tensors(self,
|
||||
if isinstance(device_type, torch.device):
|
||||
device_type = device_type.type
|
||||
if device_type == 'npu':
|
||||
self._copy_stream = torch_npu.npu.Stream()
|
||||
self._copy_stream = torch.npu.Stream()
|
||||
|
||||
|
||||
def maybe_collect_rejsample_metrics(
|
||||
|
||||
Reference in New Issue
Block a user