Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

Liangsheng Yin
2024-09-13 20:27:53 -07:00
committed by GitHub
parent f3d32f888a
commit 70b6802982
32 changed files with 103 additions and 224 deletions

@@ -41,7 +41,6 @@ class SamplingBatchInfo:
         # Vocab bias and min_ps are not supported in CUDA graph
         return (
             self.logit_bias is None
-            and self.vocab_mask is None
             and self.linear_penalties is None
             and self.scaling_penalties is None
             and not self.need_min_p_sampling
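
Note: dropping vocab_mask from this check means batches that carry a vocab mask no longer force a fallback to eager mode. What makes this safe is the pre-allocated mask buffer introduced in dummy_one below: a CUDA graph replays fixed kernels against fixed tensor addresses, so every tensor the captured kernels touch must be allocated once and only refreshed in place. A minimal sketch of that constraint (buffer names and shapes are illustrative, not from the repo; requires a CUDA device):

    import torch

    max_bs, vocab_size = 8, 32000
    static_logits = torch.zeros((max_bs, vocab_size), device="cuda")
    static_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool, device="cuda")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # the captured kernel reads the fixed addresses of static_logits / static_mask
        masked = static_logits.masked_fill(static_mask, float("-inf"))

    # per batch: refresh the buffers in place, then replay the same kernels
    static_mask.copy_(torch.rand((max_bs, vocab_size), device="cuda") > 0.5)
    g.replay()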
@@ -50,9 +49,11 @@ class SamplingBatchInfo:
     @classmethod
     def dummy_one(cls, max_bs: int, vocab_size: int):
         ret = cls(vocab_size=vocab_size)
-        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
-        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
-        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        with torch.device("cuda"):
+            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
+            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
+            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
+            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
         return ret

     def __getitem__(self, key):
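
The with torch.device("cuda") block (a context manager available in PyTorch 2.x) sets the default device for the factory calls inside it, so the per-call device="cuda" arguments can be dropped. The dummy batch also now pre-allocates a full (max_bs, vocab_size) boolean vocab_mask, so the buffer already exists at graph-capture time. A quick self-contained check of the context-manager behavior (requires a CUDA device; shapes are illustrative):

    import torch

    with torch.device("cuda"):
        temperatures = torch.ones((4, 1), dtype=torch.float)
        vocab_mask = torch.zeros((4, 32000), dtype=torch.bool)

    assert temperatures.device.type == "cuda"
    assert vocab_mask.device.type == "cuda"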
@@ -64,6 +65,7 @@ class SamplingBatchInfo:
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
+                vocab_mask=self.vocab_mask[key],
             )
         else:
             raise NotImplementedError
@@ -77,6 +79,11 @@ class SamplingBatchInfo:
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks
+
+        if other.vocab_mask is None:
+            self.vocab_mask[:bs].fill_(False)
+        else:
+            self.vocab_mask[:bs] = other.vocab_mask

     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         device = "cuda"
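
This copy path writes into slices of the pre-allocated tensors instead of rebinding the attributes, so a captured graph keeps reading the same storage; when the incoming batch has no vocab mask, the buffer is cleared with fill_(False), which acts as a no-op mask. A standalone sketch of the same pattern (names and shapes are illustrative; runs on CPU):

    import torch

    max_bs, vocab_size = 8, 32000
    vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)  # persistent buffer

    def assign(bs, incoming_mask):
        # mutate the buffer in place; its storage (and address) never changes
        if incoming_mask is None:
            vocab_mask[:bs].fill_(False)   # no constraints: nothing masked
        else:
            vocab_mask[:bs] = incoming_mask

    assign(2, torch.ones((2, vocab_size), dtype=torch.bool))
    assert vocab_mask[:2].all() and not vocab_mask[2:].any()
    assign(2, None)
    assert not vocab_mask.any()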