Fuse more ops & Simplify token mapping (#1758)

This commit is contained in:
Lianmin Zheng
2024-10-22 23:20:43 -07:00
committed by GitHub
parent 17536e7e3d
commit ad4125d1a9
9 changed files with 99 additions and 75 deletions

View File

@@ -92,6 +92,11 @@ def set_torch_compile_config():
torch._dynamo.config.accumulated_cache_size_limit = 1024
@torch.compile(dynamic=True)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
class CudaGraphRunner:
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
@@ -112,7 +117,6 @@ class CudaGraphRunner:
self.capture_bs = list(range(1, 32)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.capture_bs = [
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
]
@@ -253,7 +257,7 @@ class CudaGraphRunner:
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
positions=clamp_position(seq_lens),
)
return forward(input_ids, forward_batch.positions, forward_batch)