import torch


class AsyncMetricsCollector:
    """Moves rejection/typical-acceptance sampler metrics from the device
    to the CPU, issuing the copies on a dedicated (non-default) Torch stream.
    """

    def _copy_rejsample_metrics_async(self):
        """Start an asynchronous device-to-CPU copy of the sampler metrics.

        The accepted-token and emitted-token counts of
        ``self.spec_decode_sampler`` are copied into the preallocated
        ``self._aggregate_num_accepted_tokens`` /
        ``self._aggregate_num_emitted_tokens`` tensors with
        ``non_blocking=True`` on ``self._copy_stream``. The draft-token count
        is already computed on the CPU, so it is simply assigned.

        Returns:
            A ``torch.vacc.Event`` recorded on the copy stream after the
            copies were enqueued; callers can wait on it to know the copied
            values are ready.
        """
        # The ``torch_vacc`` name is never used below; presumably the import
        # is needed for its side effect of making ``torch.vacc`` available.
        import torch_vacc  # noqa: F401

        assert self._copy_stream is not None
        copy_stream = self._copy_stream

        # Order the copy stream after work already submitted to the current
        # stream (presumably mirrors torch.cuda.Stream.wait_stream semantics).
        copy_stream.wait_stream(torch.vacc.current_stream())

        with torch.vacc.stream(copy_stream):
            sampler = self.spec_decode_sampler
            self._aggregate_num_accepted_tokens.copy_(
                sampler.num_accepted_tokens, non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                sampler.num_emitted_tokens, non_blocking=True)

        # Draft tokens are counted on the CPU, so no device copy is required.
        self._aggregate_num_draft_tokens = sampler.num_draft_tokens

        copies_done = torch.vacc.Event()
        copies_done.record(copy_stream)
        return copies_done