Fix a possible out-of-memory bug during decode (#36)
This commit is contained in:
@@ -34,8 +34,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
reqs.append(req)
|
||||
|
||||
# Prefill
|
||||
batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
batch.init_extend_batch(model.model_config.vocab_size(), None)
|
||||
batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
batch.prepare_for_extend(model.model_config.vocab_size, None)
|
||||
logits, _ = model.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
print("extend logits (first)", logits)
|
||||
@@ -47,8 +47,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
req.prefix_indices = model.req_to_token_pool.req_to_token[
|
||||
batch.req_pool_indices[i], :cut_num
|
||||
]
|
||||
batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
batch.init_extend_batch(model.model_config.vocab_size(), None)
|
||||
batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
|
||||
batch.prepare_for_extend(model.model_config.vocab_size, None)
|
||||
logits, _ = model.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
|
||||
@@ -59,7 +59,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
|
||||
# Decode
|
||||
for i in range(6):
|
||||
batch.update_for_decode(next_token_ids.cpu().numpy())
|
||||
batch.prepare_for_decode(next_token_ids.cpu().numpy())
|
||||
logits = model.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids, next_token_probs = batch.sample(logits)
|
||||
|
||||
|
||||
132
test/srt/test_robust.py
Normal file
132
test/srt/test_robust.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import argparse
|
||||
import random
|
||||
import string
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
TOKENIZER = None
|
||||
RANDOM_PREFILL_LEN = None
|
||||
RANDOM_DECODE_LEN = None
|
||||
|
||||
|
||||
def gen_prompt(token_num):
    """Build a random alphanumeric prompt tokenizing to at least ``token_num`` tokens.

    When the global RANDOM_PREFILL_LEN flag is set, the target length itself
    is randomized in [1, token_num] first.
    """
    if RANDOM_PREFILL_LEN:
        # --random-prefill-len: draw the target length instead of using it as-is.
        token_num = random.randint(1, token_num)

    alphabet = string.ascii_letters + string.digits
    prompt = "".join(random.choices(alphabet, k=token_num))
    # Pad one character at a time until the tokenizer yields enough tokens.
    while len(TOKENIZER(prompt).input_ids) < token_num:
        prompt += random.choice(alphabet)

    return prompt
|
||||
|
||||
|
||||
def robust_test_dfs(s, d, args, leaf_states):
    """Depth-first stress test: repeatedly extend and fork the program state.

    At depth 0 the state is terminated with the literal "END" and collected
    into ``leaf_states``. Otherwise the state is extended with a random
    prompt, forked ``args.num_fork`` ways, and each fork receives another
    prompt plus a generation call before recursing one level deeper.
    """
    if d == 0:
        # Leaf: mark the end of the trace and record the finished state.
        s += "END"
        leaf_states.append(s)
        return

    s += gen_prompt(args.len_prefill)
    children = s.fork(args.num_fork)
    for child in children:
        child += gen_prompt(args.len_prefill)
        if RANDOM_DECODE_LEN:
            decode_len = random.randint(1, args.len_decode)
        else:
            decode_len = args.len_decode
        # ignore_eos forces the backend to emit exactly decode_len tokens.
        child += sgl.gen(
            max_tokens=decode_len,
            ignore_eos=True,
        )

    for child in children:
        robust_test_dfs(child, d - 1, args, leaf_states)
|
||||
|
||||
|
||||
def robust_test_bfs(s, args, leaf_states):
    """Breadth-first stress test: fork the frontier level by level.

    Runs ``args.depth`` rounds; each round extends every frontier state with
    a random prompt, forks it ``args.num_fork`` ways, and gives each fork
    another prompt plus a generation call. After the final round every
    surviving state is terminated with "END" and appended to ``leaf_states``.
    """
    frontier = [s]
    next_frontier = []
    for _ in range(args.depth):
        for state in frontier:
            state += gen_prompt(args.len_prefill)
            children = state.fork(args.num_fork)
            for child in children:
                child += gen_prompt(args.len_prefill)
                if RANDOM_DECODE_LEN:
                    decode_len = random.randint(1, args.len_decode)
                else:
                    decode_len = args.len_decode
                # ignore_eos forces the backend to emit exactly decode_len tokens.
                child += sgl.gen(
                    max_tokens=decode_len,
                    ignore_eos=True,
                )
            next_frontier.extend(children)

        frontier, next_frontier = next_frontier, []

    for state in frontier:
        state += "END"
        leaf_states.append(state)
|
||||
|
||||
|
||||
@sgl.function
def robust_test(s, args):
    """sglang entry point: run the BFS or DFS stress test and return the leaves."""
    collected = []
    if args.mode == "bfs":
        robust_test_bfs(s, args, collected)
    else:
        robust_test_dfs(s, args.depth, args, collected)
    return collected
|
||||
|
||||
|
||||
def main(args):
    """Run the stress-test batch and dump every leaf trace to a text file.

    Each leaf state must end with the "END" sentinel appended by the test
    drivers; the sentinel is stripped before writing.
    """
    backend = select_sglang_backend(args)

    # One independent program invocation per requested request.
    batch_args = [{"args": args} for _ in range(args.num_req)]

    states = robust_test.run_batch(
        batch_args, temperature=0, backend=backend, num_threads=args.parallel
    )

    with open(f"tmp_robust_{args.mode}.txt", "w") as out:
        for state in states:
            for leaf in state.ret_value:
                # Every leaf trace is terminated by the "END" sentinel.
                assert leaf.text().endswith("END")
                out.write(leaf.text()[:-3] + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-req", type=int, default=2)
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--num-fork", type=int, default=2)
    parser.add_argument("--len-prefill", type=int, default=128)
    parser.add_argument("--len-decode", type=int, default=128)
    parser.add_argument("--random-prefill-len", action="store_true")
    parser.add_argument("--random-decode-len", action="store_true")
    parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
    parser.add_argument("--tokenizer", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = add_common_sglang_args_and_parse(parser)
    # fmt: on

    # Publish run-wide settings consumed by gen_prompt() and the test drivers.
    RANDOM_PREFILL_LEN = args.random_prefill_len
    RANDOM_DECODE_LEN = args.random_decode_len
    TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    # Seed once so prompt generation is reproducible across runs.
    random.seed(args.seed)

    main(args)
|
||||
Reference in New Issue
Block a user