Fix the possible bug of decode out of memory (#36)

Liangsheng Yin
2024-01-20 03:01:15 +08:00
committed by GitHub
parent 199e82a15d
commit 40ab1f0129
7 changed files with 274 additions and 46 deletions


@@ -34,8 +34,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
         reqs.append(req)

     # Prefill
-    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
-    batch.init_extend_batch(model.model_config.vocab_size(), None)
+    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
+    batch.prepare_for_extend(model.model_config.vocab_size, None)
     logits, _ = model.forward(batch, ForwardMode.EXTEND)
     next_token_ids, next_token_probs = batch.sample(logits)
     print("extend logits (first)", logits)
@@ -47,8 +47,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
         req.prefix_indices = model.req_to_token_pool.req_to_token[
             batch.req_pool_indices[i], :cut_num
         ]
-    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
-    batch.init_extend_batch(model.model_config.vocab_size(), None)
+    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
+    batch.prepare_for_extend(model.model_config.vocab_size, None)
     logits, _ = model.forward(batch, ForwardMode.EXTEND)
     next_token_ids, next_token_probs = batch.sample(logits)
@@ -59,7 +59,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
     # Decode
     for i in range(6):
-        batch.update_for_decode(next_token_ids.cpu().numpy())
+        batch.prepare_for_decode(next_token_ids.cpu().numpy())
        logits = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)
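
For reference, a minimal sketch of the renamed batch API as exercised in this test after the change: prefill with Batch.init_new / prepare_for_extend, then step-by-step decoding with prepare_for_decode. It assumes model, reqs, Batch, and ForwardMode are set up exactly as in the surrounding test; this is only the call sequence shown in the hunks above, not a definitive API reference.

    # Prefill: construct the batch with the renamed constructor and run one extend pass.
    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.prepare_for_extend(model.model_config.vocab_size, None)  # vocab_size is accessed as an attribute now
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)

    # Decode: feed sampled tokens back in with the renamed prepare_for_decode.
    for _ in range(6):
        batch.prepare_for_decode(next_token_ids.cpu().numpy())
        logits = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)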