Fix a possible out-of-memory bug during decode (#36)
This commit is contained in:
@@ -34,8 +34,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
reqs.append(req)
|
||||
|
||||
# Prefill
|
||||
- batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
- batch.init_extend_batch(model.model_config.vocab_size(), None)
|
||||
+ batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
+ batch.prepare_for_extend(model.model_config.vocab_size, None)
|
||||
logits, _ = model.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
print("extend logits (first)", logits)
|
||||
@@ -47,8 +47,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
req.prefix_indices = model.req_to_token_pool.req_to_token[
|
||||
batch.req_pool_indices[i], :cut_num
|
||||
]
|
||||
- batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
- batch.init_extend_batch(model.model_config.vocab_size(), None)
|
||||
+ batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
+ batch.prepare_for_extend(model.model_config.vocab_size, None)
|
||||
logits, _ = model.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
|
||||
@@ -59,7 +59,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
|
||||
# Decode
|
||||
for i in range(6):
|
||||
- batch.update_for_decode(next_token_ids.cpu().numpy())
|
||||
+ batch.prepare_for_decode(next_token_ids.cpu().numpy())
|
||||
logits = model.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user