Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)

Co-authored-by: Sehoon Kim <kssteven418@gmail.com>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
Lianmin Zheng
2025-03-06 06:13:59 -08:00
committed by GitHub
parent 3a3918121f
commit bc1534ff32
11 changed files with 304 additions and 106 deletions

View File

@@ -396,16 +396,10 @@ class CudaGraphRunner:
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
global global_graph_memory_pool
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
global_graph_memory_pool = graph.pool()
return graph, out