Support cuda graph for DP attention (#2061)
This commit is contained in:
@@ -83,9 +83,6 @@ class TpModelWorkerClient:
|
||||
def get_tp_cpu_group(self):
|
||||
return self.worker.get_tp_cpu_group()
|
||||
|
||||
def get_tp_device_group(self):
|
||||
return self.worker.get_tp_device_group()
|
||||
|
||||
def get_memory_pool(self):
|
||||
return (
|
||||
self.worker.model_runner.req_to_token_pool,
|
||||
@@ -96,7 +93,7 @@ class TpModelWorkerClient:
|
||||
with torch.cuda.stream(self.forward_stream):
|
||||
self.forward_thread_func_()
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def forward_thread_func_(self):
|
||||
while True:
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
|
||||
Reference in New Issue
Block a user