Support cuda graph for DP attention (#2061)

This commit is contained in:
Ke Bao
2024-11-18 08:29:20 +08:00
committed by GitHub
parent 11f881d173
commit 62832bb272
9 changed files with 88 additions and 26 deletions

View File

@@ -83,9 +83,6 @@ class TpModelWorkerClient:
# Thin pass-through: returns whatever the wrapped worker's
# ``get_tp_cpu_group()`` returns, without modification.
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()
# Thin pass-through: returns whatever the wrapped worker's
# ``get_tp_device_group()`` returns, without modification.
def get_tp_device_group(self):
return self.worker.get_tp_device_group()
def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
@@ -96,7 +93,7 @@ class TpModelWorkerClient:
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
@torch.inference_mode()
@torch.no_grad()
def forward_thread_func_(self):
while True:
model_worker_batch, future_token_ids_ct = self.input_queue.get()