Add more support for Intel Gaudi accelerators (#2357)
This commit is contained in:
@@ -993,7 +993,7 @@ class Scheduler:
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
@@ -1055,7 +1055,7 @@ class Scheduler:
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
else: # embedding or reward model
|
||||
@@ -1130,7 +1130,7 @@ class Scheduler:
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
self.stream_output(batch.reqs)
|
||||
|
||||
Reference in New Issue
Block a user