[Misc] Remove some parts of metrics patch (#603)

### What this PR does / why we need it?
Remove some parts of metrics patch, since the `cuda` hard code has been
fixed by https://github.com/vllm-project/vllm/pull/14411.

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2025-04-22 18:45:21 +08:00
committed by GitHub
parent cf6ab42ee2
commit 4a0ce3660e
4 changed files with 69 additions and 40 deletions

View File

@@ -15,46 +15,13 @@
# limitations under the License.
#
from typing import Callable, Optional, Union
from typing import Callable
import torch
import torch_npu
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
from vllm.spec_decode.metrics import AsyncMetricsCollector
Timer = Callable[[], float]
# TODO: revert this patch when the cuda hard code is removed in vllm
# init_tensors: Modified the hard-coded cuda judgment logic to npu;
# maybe_collect_rejsample_metrics: Removed the check for current_platform.is_cuda_alike()
def init_tensors(self,
rank: int,
device_type: Union[torch.device, str] = 'npu') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
device_type = device_type.type
if device_type == 'npu':
self._copy_stream = torch_npu.npu.Stream()
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
ready_event = self._in_flight_copy
self._in_flight_copy = None
return self._collect_rejsample_metrics(k, ready_event)
# Otherwise, check if we should start a new copy.
if self._should_collect_rejsample_metrics(self._timer()):
assert self._in_flight_copy is None
self._in_flight_copy = self._copy_rejsample_metrics_async()
return None
def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
"""
@@ -83,6 +50,4 @@ def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
return aggregate_metrics_ready
AsyncMetricsCollector.init_tensors = init_tensors
AsyncMetricsCollector.maybe_collect_rejsample_metrics = maybe_collect_rejsample_metrics
AsyncMetricsCollector._copy_rejsample_metrics_async = _copy_rejsample_metrics_async