# Source: sglang/python/sglang/srt/managers/controller/tp_worker.py
# Snapshot: 2024-07-29 02:38:31 -07:00 (822 lines, 30 KiB, Python)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""A tensor parallel worker."""
import logging
import multiprocessing
import pickle
import time
import warnings
from typing import List, Optional
import torch
import torch.distributed as dist
from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.controller.infer_batch import (
FINISH_ABORT,
BaseFinishReason,
Batch,
ForwardMode,
Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelTpServer:
    """A single tensor-parallel rank of the model server.

    Holds the model runner, tokenizer, radix prefix cache, scheduler and the
    request queues for one GPU. :meth:`exposed_step` is the per-iteration
    entry point: it ingests new requests, runs one forward step (a prefill
    batch, or a few decode steps of the running batch) and returns the output
    objects produced during that step.
    """

    def __init__(
        self,
        gpu_id: int,
        tp_rank: int,
        server_args: ServerArgs,
        nccl_port: int,
        model_overide_args: dict,
    ):
        """Load the model and tokenizer and initialize all scheduling state.

        Args:
            gpu_id: GPU index this rank runs on.
            tp_rank: Rank of this worker within the tensor-parallel group.
            server_args: Server configuration.
            nccl_port: Port handed to ModelRunner.
            model_overide_args: Overrides forwarded to ModelConfig (the
                misspelled name is kept for caller compatibility).
        """
        suppress_other_loggers()

        # Copy arguments
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=nccl_port,
            server_args=server_args,
        )
        if is_multimodal_model(server_args.model_path):
            # Multimodal models load a processor and reuse its tokenizer.
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )

        # Capacity limits
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = (
            16384
            if server_args.max_prefill_tokens is None
            else server_args.max_prefill_tokens
        )
        # Default to half the KV token capacity, but never more than the
        # request-to-token pool can hold (minus one slot).
        self.max_running_requests = min(
            (
                self.max_total_num_tokens // 2
                if server_args.max_running_requests is None
                else server_args.max_running_requests
            ),
            self.model_runner.req_to_token_pool.size - 1,
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        # A prompt must fit both the context window and the KV token pool.
        self.max_req_input_len = min(
            self.model_config.context_len - 1,
            self.max_total_num_tokens - 1,
        )
        set_random_seed(server_args.random_seed)

        # Print info
        logger.info(
            f"[gpu={self.gpu_id}] "
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init cache
        self.tree_cache = RadixCache(
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        # Cumulative token counts (stored in units of 1e9 tokens) for the
        # cache-hit-rate statistic logged per prefill batch.
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = ScheduleHeuristic(
            self.schedule_heuristic,
            self.max_running_requests,
            self.max_prefill_tokens,
            self.max_total_num_tokens,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Optional[Batch] = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation: the fraction of each running request's
        # remaining max_new_tokens reserved when admitting new prefills.
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.min_new_token_ratio = min(
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.new_token_ratio = self.min_new_token_ratio
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

    def exposed_step(self, recv_reqs):
        """Dispatch incoming requests, run one forward step, return outputs.

        Args:
            recv_reqs: Iterable of TokenizedGenerateReqInput, FlushCacheReq
                or AbortReq objects.

        Returns:
            The output objects (e.g. BatchTokenIDOut) accumulated in this step.

        Raises:
            ValueError: For an unknown request type. Any exception is logged
                with its traceback and re-raised.
        """
        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
            raise

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        """Run one scheduling step.

        A new prefill batch takes priority and is merged into the running
        batch afterwards. Otherwise up to ``num_continue_decode_steps`` decode
        iterations of the running batch are run. When nothing is running, the
        KV pool is checked for leaks and the new-token ratio is reset.
        """
        new_batch = self.get_new_prefill_batch()

        if new_batch is not None:
            # Run a new prefill batch
            self.forward_prefill_batch(new_batch)
            self.cache_filled_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run a decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(global_config.num_continue_decode_steps):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    # Print stats
                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
                        self.print_stats()

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    # Stop early once streaming requests have produced output,
                    # presumably so it reaches the client sooner.
                    if self.out_pyobjs and self.running_batch.has_stream():
                        break
            else:
                self.check_memory()
                self.new_token_ratio = global_config.init_new_token_ratio

    def print_stats(self):
        """Log KV-cache usage and decode throughput since the previous call.

        Resets the generated-token counter and the measurement window start.
        """
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        # Read the clock once so the measured window and the next window's
        # start agree, and guard against a zero-length window.
        now = time.time()
        elapsed = now - self.last_stats_tic
        throughput = self.num_generated_tokens / elapsed if elapsed > 0 else 0.0
        self.num_generated_tokens = 0
        self.last_stats_tic = now
        logger.info(
            f"[gpu={self.gpu_id}] Decode batch. "
            f"#running-req: {len(self.running_batch.reqs)}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
            f"#queue-req: {len(self.forward_queue)}"
        )

    def check_memory(self):
        """Warn if free plus evictable KV slots do not equal total capacity.

        Called from forward_step when nothing is running, where every slot
        should be free or held only by the evictable radix cache.
        """
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a Req from a tokenized generate request and enqueue it.

        Pads image inputs for multimodal models and copies the sampling and
        logprob options. Attaches regex FSM / jump-forward state when a regex
        is given. Truncates over-long prompts and caps ``max_new_tokens`` so
        that prompt plus output fits the per-request budget.
        """
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            # Pad token ids come from different bit ranges of the image hash,
            # each reduced into the vocabulary range.
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            (
                req.origin_input_ids,
                req.image_offset,
            ) = self.model_runner.model.pad_input_ids(
                req.origin_input_ids_unpadded,
                req.pad_value,
                req.pixel_values.shape,
                req.image_size,
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        # Cap generation so prompt + output stays within max_req_input_len - 1.
        # Clamp at 0: a prompt truncated to max_req_input_len would otherwise
        # get max_new_tokens == -1, and get_new_prefill_batch would reserve
        # fewer KV slots (extend_input_len + max_new_tokens) than the prefill
        # actually writes.
        requested = req.sampling_params.max_new_tokens
        req.sampling_params.max_new_tokens = max(
            0,
            min(
                requested if requested is not None else 1 << 30,
                self.max_req_input_len - 1 - len(req.origin_input_ids),
            ),
        )
        self.forward_queue.append(req)

    def get_new_prefill_batch(self) -> Optional[Batch]:
        """Select queued requests for the next prefill batch.

        Each queued request is matched against the radix cache, and the queue
        is ordered by the scheduler. Requests are then admitted in order while
        their estimated KV demand fits the available memory and the prefill
        token budget. Admission stops at the first request that does not fit.
        The remaining max_new_tokens of running requests, scaled by
        ``new_token_ratio``, are reserved first.

        Returns:
            The new Batch, or None if nothing can be scheduled.
        """
        running_bs = (
            len(self.running_batch.reqs) if self.running_batch is not None else 0
        )
        if running_bs >= self.max_running_requests:
            return None

        # Compute matched prefix length
        for req in self.forward_queue:
            req.input_ids = req.origin_input_ids + req.output_ids
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                # Tokens after logprob_start_len must be recomputed to get logprobs.
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            # Reserve an estimated share of the running requests' future tokens.
            available_size -= sum(
                (r.sampling_params.max_new_tokens - len(r.output_ids))
                * self.new_token_ratio
                for r in self.running_batch.reqs
            )

        for req in self.forward_queue:
            if req.return_logprob and req.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            # The first request is admitted even if it alone exceeds
            # max_prefill_tokens.
            if (
                req.extend_input_len
                + req.sampling_params.max_new_tokens
                + new_batch_total_tokens
                < available_size
                and (
                    req.extend_input_len + new_batch_input_tokens
                    <= self.max_prefill_tokens
                    or len(can_run_list) == 0
                )
            ):
                # Locking the matched prefix changes the evictable size, so the
                # budget is re-checked with the adjusted available_size.
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len
                    + req.sampling_params.max_new_tokens
                    + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.sampling_params.max_new_tokens
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break

            if running_bs + len(can_run_list) >= self.max_running_requests:
                break

        if len(can_run_list) == 0:
            return None

        # Print stats
        if self.tp_rank == 0:
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            # A batch may contribute no tokens at all (e.g. empty prompts with
            # max_new_tokens == 0); avoid dividing by a zero total.
            total = self.tree_cache_metrics["total"]
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / total if total > 0 else 0.0
            )
            logger.info(
                f"[gpu={self.gpu_id}] Prefill batch. "
                f"#new-seq: {len(can_run_list)}, "
                f"#new-token: {new_batch_input_tokens}, "
                f"#cached-token: {hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                f"#running-req: {running_bs}, "
                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
            )

        # Return the new batch
        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_prefill_batch(self, batch: Batch):
        """Run the extend (prefill) pass for a batch and sample one token each.

        Appends the sampled token to every request and updates its finish
        state. Logprobs are recorded for requests that asked for them. The
        batch is then passed to handle_finished_requests.
        """
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        # Stays None when there is nothing to run through the model, so the
        # logprob bookkeeping below never touches an unbound name.
        output = None
        if batch.extend_num_tokens != 0:
            # Forward and sample the next tokens
            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
            next_token_ids = batch.sample(output.next_token_logits)

            # Move logprobs to cpu
            if output.next_token_logprobs is not None:
                # Pick the logprob of each sampled token.
                output.next_token_logprobs = output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()
                output.input_token_logprobs = output.input_token_logprobs.tolist()
                output.normalized_prompt_logprobs = (
                    output.normalized_prompt_logprobs.tolist()
                )

            next_token_ids = next_token_ids.tolist()
        else:
            # No tokens to extend: emit EOS for every request.
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish conditions
        pt = 0  # offset of this request's tokens in the flattened extend output
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_ids[i])
            req.check_finished()

            if req.return_logprob and output is not None:
                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
        """Record prefill logprobs for request ``i`` of a batch.

        Args:
            i: Index of ``req`` within the batch.
            req: The request being updated.
            pt: Offset of this request's tokens in the flattened
                ``output.input_token_logprobs``.
            next_token_ids: Sampled token ids (Python ints), one per request.
            output: Model output whose logprob fields were moved to lists.
        """
        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
            # The logprob at position t is paired with token t + 1.
            req.input_token_logprobs = list(
                zip(
                    output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
                    req.input_ids[-req.extend_input_len + 1 :],
                )
            )
            if req.logprob_start_len == 0:
                # The very first prompt token has no preceding context.
                req.input_token_logprobs = [
                    (None, req.input_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Tokens added outside normal decoding (last_update_decode_tokens)
            # were re-run in this prefill; credit their logprobs to the output.
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + req.extend_input_len
                            - req.last_update_decode_tokens : pt
                            + req.extend_input_len
                            - 1
                        ],
                        req.input_ids[-req.last_update_decode_tokens + 1 :],
                    )
                )
            )

        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

    def cache_filled_batch(self, batch: Batch):
        """Insert each request's tokens into the radix cache after prefill.

        The last token is excluded. Presumably this is because its KV entry
        has not been written yet. KV slots stay allocated
        (``del_in_memory_pool=False``). Each request's prefix indices and last
        node are updated to the newly cached ones.
        """
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
        """Run one decode iteration on the running batch.

        If KV memory is short, requests are retracted back to the queue and
        ``new_token_ratio`` is raised; otherwise the ratio decays toward its
        minimum. Regex jump-forward candidates are moved back to the queue
        (unless disabled). One token is then sampled per request and finish
        state is updated.
        """
        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # Check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward and sample the next tokens
        output = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids = batch.sample(output.next_token_logits)

        # Move logprobs to cpu
        if output.next_token_logprobs is not None:
            next_token_logprobs = output.next_token_logprobs[
                torch.arange(len(next_token_ids), device=next_token_ids.device),
                next_token_ids,
            ].tolist()

        next_token_ids = next_token_ids.tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(output.output_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        """Emit outputs and drop finished requests from ``batch``.

        A BatchTokenIDOut is queued for every finished request. It is also
        queued for streaming requests on their first token and on every
        ``stream_interval``-th decode step. Finished requests are written to
        the radix cache, their node lock is released, and the batch is
        filtered to the unfinished ones.
        """
        output_rids = []
        output_vids = []
        decoded_texts = []
        output_read_ids = []
        output_read_offsets = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished():
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished() or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                output_vids.append(req.vid)
                decoded_texts.append(req.decoded_text)
                read_ids, read_offset = req.init_incremental_detokenize()
                output_read_ids.append(read_ids)
                output_read_offsets.append(read_offset)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
                    "completion_tokens": len(req.output_ids),
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finished_reason),
                }
                if req.return_logprob:
                    (
                        meta_info["input_token_logprobs"],
                        meta_info["output_token_logprobs"],
                        meta_info["input_top_logprobs"],
                        meta_info["output_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.input_token_logprobs,
                        req.output_token_logprobs,
                        req.input_top_logprobs,
                        req.output_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished_reason.append(req.finished_reason)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_vids,
                    decoded_texts,
                    output_read_ids,
                    output_read_offsets,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished_reason,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

                self.tree_cache.dec_lock_ref(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []

    def flush_cache(self):
        """Reset all caches and memory pools when the server is idle.

        Does nothing except warn if any request is queued or running.
        """
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def abort_request(self, recv_req):
        """Abort the request whose rid equals ``recv_req.rid``.

        The first matching request is removed from the waiting queue. A
        matching request in the running batch is marked FINISH_ABORT;
        presumably handle_finished_requests then drops it.
        """
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.forward_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break
def run_tp_server(
    gpu_id: int,
    tp_rank: int,
    server_args: ServerArgs,
    nccl_port: int,
    model_overide_args: dict,
):
    """Process entry point for one tensor-parallel worker.

    Builds a ModelTpServer for ``tp_rank``, then loops forever. Each iteration
    receives the request list broadcast over the TP CPU group and steps the
    model with it. Any exception is logged with its traceback and re-raised.
    """
    try:
        server = ModelTpServer(
            gpu_id, tp_rank, server_args, nccl_port, model_overide_args
        )
        cpu_group = server.model_runner.tp_group.cpu_group
        while True:
            # This rank only receives; the outputs of exposed_step are discarded.
            server.exposed_step(broadcast_recv_input(None, tp_rank, cpu_group))
    except Exception:
        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
        raise
def launch_tp_servers(
    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
):
    """Start one ``run_tp_server`` process per tensor-parallel rank.

    Args:
        gpu_ids: Indexable by TP rank, giving that rank's GPU id.
        tp_rank_range: Iterable of TP ranks to launch.
        server_args, nccl_port, model_overide_args: Passed unchanged to
            ``run_tp_server``.

    Returns:
        The started ``multiprocessing.Process`` objects, in launch order.
    """

    def _spawn(rank):
        # Each rank's process is started before the next one is created.
        worker = multiprocessing.Process(
            target=run_tp_server,
            args=(gpu_ids[rank], rank, server_args, nccl_port, model_overide_args),
        )
        worker.start()
        return worker

    return [_spawn(rank) for rank in tp_rank_range]
def broadcast_recv_input(data, rank, dist_group):
    """Broadcast a picklable object from rank 0 to all ranks in ``dist_group``.

    Wire format: a 1-element int64 tensor holding the pickled payload size,
    followed (only when the size is non-zero) by a uint8 tensor holding the
    payload.

    Args:
        data: Object to send (rank 0 only; ignored elsewhere). An empty or
            otherwise falsy value is sent as size 0 with no payload.
        rank: This process's rank; rank 0 is the sender.
        dist_group: ``torch.distributed`` process group to broadcast over.

    Returns:
        On rank 0, ``data`` unchanged. On other ranks, the unpickled object, or
        ``[]`` when rank 0 sent nothing.
    """
    if rank == 0:
        if not data:
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(tensor_size, src=0, group=dist_group)
        else:
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)
            # Wrap the bytes directly instead of building a per-byte Python
            # list. bytearray supplies the writable buffer frombuffer expects.
            tensor_data = torch.frombuffer(
                bytearray(serialized_data), dtype=torch.uint8
            )
            tensor_size = torch.tensor([size], dtype=torch.long)
            dist.broadcast(tensor_size, src=0, group=dist_group)
            dist.broadcast(tensor_data, src=0, group=dist_group)
        return data

    tensor_size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(tensor_size, src=0, group=dist_group)
    size = tensor_size.item()
    if size == 0:
        return []

    tensor_data = torch.empty(size, dtype=torch.uint8)
    dist.broadcast(tensor_data, src=0, group=dist_group)
    serialized_data = tensor_data.cpu().numpy().tobytes()
    # SECURITY: pickle.loads trusts rank 0 of this process group; never expose
    # this group to untrusted peers.
    return pickle.loads(serialized_data)