Add more support for Intel Gaudi accelerators (#2357)

This commit is contained in:
Qun Yang
2024-12-06 17:16:33 +08:00
committed by GitHub
parent 34b364e073
commit 37ee906f61
8 changed files with 88 additions and 14 deletions

View File

@@ -993,7 +993,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1055,7 +1055,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
@@ -1130,7 +1130,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)