Update version to v0.1.13 (#280)
This commit is contained in:
@@ -28,8 +28,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
|
||||
reqs = []
|
||||
for i in range(len(prompts)):
|
||||
req = Req(i, None, None)
|
||||
req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
|
||||
input_ids = tokenizer.encode(prompts[i])[:cut_num]
|
||||
req = Req(i, prompts[i], input_ids)
|
||||
req.sampling_params = sampling_params
|
||||
reqs.append(req)
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
# Decode
|
||||
for i in range(6):
|
||||
batch.prepare_for_decode(next_token_ids.cpu().numpy())
|
||||
logits = model.forward(batch, ForwardMode.DECODE)
|
||||
logits, _ = model.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
|
||||
print(
|
||||
|
||||
Reference in New Issue
Block a user