Fix cuda graph with flashinfer (#675)
This commit is contained in:
@@ -64,7 +64,7 @@ def main(args):
|
|||||||
@sgl.function
|
@sgl.function
|
||||||
def few_shot_gsm8k(s, question):
|
def few_shot_gsm8k(s, question):
|
||||||
s += few_shot_examples + question
|
s += few_shot_examples + question
|
||||||
s += sgl.gen("answer", max_tokens=256, stop="Question")
|
s += sgl.gen("answer", max_tokens=512, stop="Question")
|
||||||
|
|
||||||
#####################################
|
#####################################
|
||||||
########## SGL Program End ##########
|
########## SGL Program End ##########
|
||||||
|
|||||||
@@ -150,8 +150,8 @@ class CudaGraphRunner:
|
|||||||
index = bisect.bisect_left(self.batch_size_list, raw_bs)
|
index = bisect.bisect_left(self.batch_size_list, raw_bs)
|
||||||
bs = self.batch_size_list[index]
|
bs = self.batch_size_list[index]
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens.zero_()
|
self.seq_lens.fill_(1)
|
||||||
self.position_ids_offsets.fill_(1)
|
self.position_ids_offsets.zero_()
|
||||||
self.out_cache_loc.zero_()
|
self.out_cache_loc.zero_()
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user