Support data parallelism (static) (#480)

Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Ying Sheng authored on 2024-05-27 21:24:10 -07:00, committed by GitHub
parent 565d727409
commit 0463f7fb52
32 changed files with 580 additions and 181 deletions


@@ -0,0 +1,791 @@
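"""A tensor-parallel model worker.

ModelTpServer hosts one tensor-parallel shard of the model on a single GPU and
is driven step-by-step over rpyc. ModelTpClient wraps one or more such servers
(one per tp rank) behind a uniform async `step` interface.
"""
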
import asyncio
import logging
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List
import rpyc
import torch
from rpyc.utils.classic import obtain
from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import (
    Batch,
    FinishReason,
    ForwardMode,
    Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
start_rpyc_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.model_tp")
class ModelTpServer:
def __init__(
self,
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args,
):
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
suppress_other_loggers()
# Copy arguments
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
model_overide_args=model_overide_args,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=model_port_args.nccl_port,
server_args=server_args,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
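
        # Token budgets: by default a prefill batch may use up to 1/6 of the
        # KV pool (capped at 65536 tokens) but never less than one full
        # context length, and the number of concurrently running requests is
        # loosely capped at half the pool size.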
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = max(
self.model_config.context_len,
(
min(self.max_total_num_tokens // 6, 65536)
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
),
)
        self.max_running_requests = (
            self.max_total_num_tokens // 2
            if server_args.max_running_requests is None
            else server_args.max_running_requests
        )
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
set_random_seed(server_args.random_seed)
# Print info
logger.info(
f"[gpu_id={self.gpu_id}] "
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"context_len={self.model_config.context_len}, "
)
if self.tp_rank == 0:
logger.info(f"server_args: {server_args.print_mode_args()}")
# Init cache
self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = ScheduleHeuristic(
self.schedule_heuristic,
self.max_running_requests,
self.max_prefill_tokens,
self.max_total_num_tokens,
self.tree_cache,
)
self.req_to_token_pool = self.model_runner.req_to_token_pool
self.token_to_kv_pool = self.model_runner.token_to_kv_pool
# Init running status
self.forward_queue: List[Req] = []
self.running_batch: Batch = None
self.out_pyobjs = []
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
)
self.jump_forward_cache = JumpForwardCache()
# Init new token estimation
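        # The scheduler assumes each running request will still consume about
        # new_token_ratio * (max_new_tokens - generated) KV slots. The ratio
        # decays toward min_new_token_ratio while decoding succeeds and is
        # bumped back up whenever a decode step runs out of memory.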
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
self.new_token_ratio = min(
global_config.base_new_token_ratio * server_args.schedule_conservativeness,
1.0,
)
self.min_new_token_ratio = min(
global_config.base_min_new_token_ratio
* server_args.schedule_conservativeness,
1.0,
)
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
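
    # rpyc exposes methods whose names start with `exposed_` to remote
    # callers, so this is the worker's single RPC entrypoint: ingest incoming
    # requests, run one scheduling/forward step, and return finished outputs.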
def exposed_step(self, recv_reqs):
if self.tp_size * self.dp_size != 1:
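            # In the multi-process case arguments arrive as rpyc netrefs;
            # obtain() copies them into local objects before use.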
recv_reqs = obtain(recv_reqs)
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
else:
raise ValueError(f"Invalid request: {recv_req}")
# Forward
self.forward_step()
except Exception:
logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
# Return results
ret = self.out_pyobjs
self.out_pyobjs = []
return ret
@torch.inference_mode()
def forward_step(self):
new_batch = self.get_new_fill_batch()
if new_batch is not None:
# Run a new fill batch
self.forward_fill_batch(new_batch)
self.cache_filled_batch(new_batch)
if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
else:
# Run decode batch
if self.running_batch is not None:
                # Run several decode batches back-to-back to amortize scheduling overhead
for _ in range(10):
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)
# Print stats
if self.tp_rank == 0:
if self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu_id={self.gpu_id}] "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
if self.running_batch.is_empty():
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.reqs[0].stream:
break
else:
# Check the available size
available_size = (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_tokens:
                    warnings.warn(
                        "KV cache pool leak detected! "
                        f"available_size={available_size}, "
                        f"max_total_num_tokens={self.max_total_num_tokens}"
                    )
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
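            # Derive pad token ids from the image hash so prompts containing
            # different images do not collide in the radix cache.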
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
req.origin_input_ids, req.image_offset = (
self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
)
# Truncate prompts that are too long
req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
req.sampling_params.max_new_tokens,
self.model_config.context_len - 1 - len(req.origin_input_ids),
self.max_total_num_tokens - 128 - len(req.origin_input_ids),
)
self.forward_queue.append(req)
def get_new_fill_batch(self):
if (
self.running_batch is not None
and len(self.running_batch.reqs) > self.max_running_requests
):
return None
# Compute matched prefix length
for req in self.forward_queue:
assert (
len(req.output_ids) == 0
), "The output ids should be empty when prefilling"
req.input_ids = req.origin_input_ids + req.prev_output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
req.prefix_indices = prefix_indices
req.last_node = last_node
# Get priority queue
self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
# Add requests if there is available space
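        # A request is admitted only if its uncached input tokens plus its
        # estimated future decode tokens fit in the free-plus-evictable KV
        # space, and the batch's total input tokens stay under the prefill
        # budget.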
can_run_list = []
new_batch_total_tokens = 0
new_batch_input_tokens = 0
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
for r in self.running_batch.reqs
]
)
for req in self.forward_queue:
if req.return_logprob and req.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
if req.extend_input_len < 2:
delta = 2 - req.extend_input_len
req.extend_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.max_new_tokens() > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1
if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and req.extend_input_len + new_batch_input_tokens
< self.max_prefill_tokens
):
delta = self.tree_cache.inc_lock_ref(req.last_node)
available_size += delta
if not (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
):
# Undo locking
delta = self.tree_cache.dec_lock_ref(req.last_node)
available_size += delta
break
else:
# Add this request to the running batch
can_run_list.append(req)
new_batch_total_tokens += (
req.extend_input_len + req.max_new_tokens()
)
new_batch_input_tokens += req.extend_input_len
else:
break
if len(can_run_list) == 0:
return None
# Print stats
if self.tp_rank == 0:
running_req = (
0 if self.running_batch is None else len(self.running_batch.reqs)
)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (
hit_tokens + new_batch_input_tokens
) / 10**9
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
logger.info(
f"new fill batch. #seq: {len(can_run_list)}. "
f"#cached_token: {hit_tokens}. "
f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%. "
)
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
# )
# Return the new batch
new_batch = Batch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
)
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
return new_batch
def forward_fill_batch(self, batch: Batch):
# Build batch tensors
batch.prepare_for_extend(
self.model_config.vocab_size, self.int_token_logit_bias
)
if batch.extend_num_tokens != 0:
# Forward
logits, (
prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs,
decode_top_logprobs,
last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
if prefill_token_logprobs is not None:
prefill_token_logprobs = prefill_token_logprobs.tolist()
normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
next_token_ids, _ = batch.sample(logits)
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None:
last_token_logprobs = last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
# Check finish condition
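        # `pt` walks the flattened per-batch logprob arrays; each request owns
        # a slice of length extend_input_len.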
pt = 0
for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1
req.output_ids = [next_token_ids[i]]
req.check_finished()
if req.return_logprob:
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
if req.prefill_token_logprobs is None:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
zip(
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :],
)
)
if req.logprob_start_len == 0:
req.prefill_token_logprobs = [
(None, req.input_ids[0])
] + req.prefill_token_logprobs
if req.last_update_decode_tokens != 0:
req.decode_token_logprobs.extend(
list(
zip(
prefill_token_logprobs[
pt
+ req.extend_input_len
- req.last_update_decode_tokens : pt
+ req.extend_input_len
- 1
],
req.input_ids[-req.last_update_decode_tokens + 1 :],
)
)
)
req.decode_token_logprobs.append(
(last_token_logprobs[i], next_token_ids[i])
)
if req.top_logprobs_num > 0:
if req.prefill_top_logprobs is None:
req.prefill_top_logprobs = prefill_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
if req.last_update_decode_tokens != 0:
req.decode_top_logprobs.extend(
prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
)
req.decode_top_logprobs.append(decode_top_logprobs[i])
pt += req.extend_input_len
self.handle_finished_requests(batch)
def cache_filled_batch(self, batch: Batch):
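        # After prefill, insert the newly computed tokens into the radix tree
        # (keeping them in the memory pool) and update the request's cached
        # prefix to the longer match.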
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
del_in_memory_pool=False,
old_last_node=req.last_node,
)
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
def forward_decode_batch(self, batch: Batch):
        # If the next decode step would run out of KV cache memory, retract
        # some requests to the prefill queue and schedule more conservatively.
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
retracted_reqs = batch.retract_decode()
logger.info(
"decode out of memory happened, "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self.forward_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
self.min_new_token_ratio,
)
if not self.disable_regex_jump_forward:
# check for jump-forward
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
self.forward_queue.extend(jump_forward_reqs)
if batch.is_empty():
return
# Update batch tensors
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
batch.prepare_for_decode()
# Forward
logits, (
_,
_,
_,
decode_top_logprobs,
last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.tolist()
        # Transfer only the selected next-token logprobs to CPU to reduce overhead.
if last_logprobs is not None:
new_token_logprobs = last_logprobs[
torch.arange(len(batch.reqs)), next_token_ids
].tolist()
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
if req.return_logprob:
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(decode_top_logprobs[i])
self.handle_finished_requests(batch)
def handle_finished_requests(self, batch: Batch):
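        # A request is emitted when it finishes, or incrementally for
        # streaming requests: on the first generated token and then every
        # `stream_interval` decode steps.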
output_rids = []
prev_output_strs = []
output_tokens = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
output_finished = []
finished_indices = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if req.finished:
finished_indices.append(i)
else:
unfinished_indices.append(i)
if req.finished or (
(
req.stream
and (
self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1
)
)
):
output_rids.append(req.rid)
prev_output_strs.append(req.prev_output_str)
output_tokens.append(req.output_ids)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": FinishReason.to_str(req.finish_reason),
"hit_stop_str": req.hit_stop_str,
}
if req.return_logprob:
(
meta_info["prefill_token_logprobs"],
meta_info["decode_token_logprobs"],
meta_info["prefill_top_logprobs"],
meta_info["decode_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.prefill_token_logprobs,
req.decode_token_logprobs,
req.prefill_top_logprobs,
req.decode_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
output_finished.append(req.finished)
# Send to detokenizer
if output_rids:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
prev_output_strs,
output_tokens,
output_hit_stop_str,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished,
)
)
# Remove finished reqs
if finished_indices:
# Update radix cache
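            # Cache each finished request's tokens for future prefix reuse and
            # drop the admission-time lock so its KV entries become evictable.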
req_pool_indices_cpu = batch.req_pool_indices.tolist()
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
)
self.tree_cache.dec_lock_ref(req.last_node)
# Update batch tensors
if unfinished_indices:
batch.filter_batch(unfinished_indices)
else:
batch.reqs = []
def flush_cache(self):
if len(self.forward_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.regex_fsm_cache.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
else:
warnings.warn(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.forward_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
def abort_request(self, recv_req):
# Delete requests in the waiting queue
to_del = None
for i, req in enumerate(self.forward_queue):
if req.rid == recv_req.rid:
to_del = i
break
if to_del is not None:
del self.forward_queue[to_del]
# Delete requests in the running batch
if self.running_batch:
for req in self.running_batch.reqs:
if req.rid == recv_req.rid:
req.finished = True
req.finish_reason = FinishReason.ABORT
break
class ModelTpService(rpyc.Service):
exposed_ModelTpServer = ModelTpServer
class ModelTpClient:
def __init__(
self,
gpu_ids: List[int],
server_args: ServerArgs,
model_port_args: ModelPortArgs,
model_overide_args,
):
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
self.tp_size = server_args.tp_size
if self.tp_size * server_args.dp_size == 1:
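            # Single GPU, no replication: run the server in-process and skip
            # rpyc; a trivial async wrapper keeps the interface identical to
            # the multi-process path.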
# Init model
assert len(gpu_ids) == 1
self.model_server = ModelTpService().exposed_ModelTpServer(
                gpu_ids[0],  # gpu_id
                0,  # tp_rank
server_args,
model_port_args,
model_overide_args,
)
# Wrap functions
def async_wrap(f):
async def _func(*args, **kwargs):
return f(*args, **kwargs)
return _func
self.step = async_wrap(self.model_server.exposed_step)
else:
with ThreadPoolExecutor(self.tp_size) as executor:
# Launch model processes
                # executor.map returns a one-shot iterator; materialize it so
                # it can be read more than once below.
                rets = list(
                    executor.map(
                        lambda args: start_rpyc_process(*args),
                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
                    )
                )
                self.model_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]
# Init model
def init_model(i):
return self.model_services[i].ModelTpServer(
gpu_ids[i],
i,
server_args,
model_port_args,
model_overide_args,
)
                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
# Wrap functions
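            # Broadcast each call to every tp rank as an rpyc async request,
            # wait for all ranks off the event loop, and return rank 0's
            # result (all ranks compute the same outputs).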
def async_wrap(func_name):
fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
async def _func(*args, **kwargs):
tasks = [f(*args, **kwargs) for f in fs]
await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
return obtain(tasks[0].value)
return _func
self.step = async_wrap("step")