Fuse more ops & Simplify token mapping (#1758)
@@ -92,6 +92,11 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024
 
 
+@torch.compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
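Side note (not part of the diff): a minimal standalone sketch of what the new compiled helper computes; the example tensor below is made up for illustration.

    import torch

    @torch.compile(dynamic=True)
    def clamp_position(seq_lens):
        # Position of the last token per sequence: seq_len - 1, floored at 0.
        # torch.compile(dynamic=True) allows the subtract, clamp, and int64
        # cast to be fused while still handling varying batch sizes.
        return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

    seq_lens = torch.tensor([0, 1, 5, 8], dtype=torch.int32)
    print(clamp_position(seq_lens))  # tensor([0, 0, 4, 7])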
@@ -112,7 +117,6 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
         self.capture_bs = [
             bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
         ]
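For illustration only (not in the commit): the non-torch.compile branch captures batch sizes 1, 2, 4 and multiples of 8 up to 160, then filters them against the request-to-token pool size; the pool size below is a hypothetical value.

    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
    # [1, 2, 4, 8, 16, 24, ..., 160]
    pool_size = 100  # hypothetical req_to_token_pool.size for this sketch
    capture_bs = [bs for bs in capture_bs if bs <= pool_size]
    # [1, 2, 4, 8, 16, ..., 96]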
@@ -253,7 +257,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
-            positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
+            positions=clamp_position(seq_lens),
         )
         return forward(input_ids, forward_batch.positions, forward_batch)
 
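For reference (not part of the diff), the replaced expression and the new helper compute the same tensor; continuing the earlier sketch, where seq_lens is an integer tensor of sequence lengths:

    positions_old = torch.clamp((seq_lens - 1), min=0).to(torch.int64)
    positions_new = clamp_position(seq_lens)
    assert torch.equal(positions_old, positions_new)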