[Feature] improve TBO: two chunk overlap (#8144)
This commit is contained in:
@@ -262,6 +262,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--enable-dp-attention` | Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. | False |
|
| `--enable-dp-attention` | Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. | False |
|
||||||
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
|
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
|
||||||
| `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
|
| `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
|
||||||
|
| `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap. | 0.48 |
|
||||||
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
|
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
|
||||||
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
|
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
|
||||||
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
|
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"disable_radix_cache",
|
"disable_radix_cache",
|
||||||
"enable_dp_attention",
|
"enable_dp_attention",
|
||||||
"enable_two_batch_overlap",
|
"enable_two_batch_overlap",
|
||||||
|
"tbo_token_distribution_threshold",
|
||||||
"enable_dp_lm_head",
|
"enable_dp_lm_head",
|
||||||
"moe_a2a_backend",
|
"moe_a2a_backend",
|
||||||
"deepep_mode",
|
"deepep_mode",
|
||||||
|
|||||||
@@ -420,16 +420,12 @@ class ForwardBatch:
|
|||||||
batch.extend_prefix_lens, dtype=torch.int32
|
batch.extend_prefix_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
ret.extend_num_tokens = batch.extend_num_tokens
|
ret.extend_num_tokens = batch.extend_num_tokens
|
||||||
if support_triton(model_runner.server_args.attention_backend):
|
positions, ret.extend_start_loc = compute_position(
|
||||||
positions, ret.extend_start_loc = compute_position_triton(
|
model_runner.server_args.attention_backend,
|
||||||
ret.extend_prefix_lens,
|
ret.extend_prefix_lens,
|
||||||
ret.extend_seq_lens,
|
ret.extend_seq_lens,
|
||||||
ret.extend_num_tokens,
|
ret.extend_num_tokens,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
positions, ret.extend_start_loc = compute_position_torch(
|
|
||||||
ret.extend_prefix_lens, ret.extend_seq_lens
|
|
||||||
)
|
|
||||||
if ret.positions is None:
|
if ret.positions is None:
|
||||||
ret.positions = positions
|
ret.positions = positions
|
||||||
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
||||||
@@ -882,6 +878,25 @@ class PPProxyTensors:
|
|||||||
return f"PPProxyTensors(tensors={self.tensors})"
|
return f"PPProxyTensors(tensors={self.tensors})"
|
||||||
|
|
||||||
|
|
||||||
|
def compute_position(
|
||||||
|
attn_backend: str,
|
||||||
|
extend_prefix_lens: torch.Tensor,
|
||||||
|
extend_seq_lens: torch.Tensor,
|
||||||
|
extend_seq_lens_sum: int,
|
||||||
|
):
|
||||||
|
if support_triton(attn_backend):
|
||||||
|
positions, extend_start_loc = compute_position_triton(
|
||||||
|
extend_prefix_lens,
|
||||||
|
extend_seq_lens,
|
||||||
|
extend_seq_lens_sum,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
positions, extend_start_loc = compute_position_torch(
|
||||||
|
extend_prefix_lens, extend_seq_lens
|
||||||
|
)
|
||||||
|
return positions, extend_start_loc
|
||||||
|
|
||||||
|
|
||||||
def compute_position_triton(
|
def compute_position_triton(
|
||||||
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -229,6 +229,7 @@ class ServerArgs:
|
|||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
enable_dp_lm_head: bool = False
|
enable_dp_lm_head: bool = False
|
||||||
enable_two_batch_overlap: bool = False
|
enable_two_batch_overlap: bool = False
|
||||||
|
tbo_token_distribution_threshold: float = 0.48
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
torchao_config: str = ""
|
torchao_config: str = ""
|
||||||
@@ -1689,6 +1690,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enabling two micro batches to overlap.",
|
help="Enabling two micro batches to overlap.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tbo-token-distribution-threshold",
|
||||||
|
type=float,
|
||||||
|
default=ServerArgs.tbo_token_distribution_threshold,
|
||||||
|
help="The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-torch-compile",
|
"--enable-torch-compile",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
@@ -17,7 +18,11 @@ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
|||||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
compute_position,
|
||||||
|
)
|
||||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||||
from sglang.srt.operations_strategy import OperationsStrategy
|
from sglang.srt.operations_strategy import OperationsStrategy
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
@@ -58,7 +63,7 @@ def compute_split_seq_index(
|
|||||||
) -> Optional[int]:
|
) -> Optional[int]:
|
||||||
if forward_mode == ForwardMode.EXTEND:
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
assert extend_lens is not None
|
assert extend_lens is not None
|
||||||
return _split_array_by_half_sum(extend_lens)
|
return _split_extend_seqs(extend_lens)
|
||||||
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||||
assert token_num_per_seq is not None
|
assert token_num_per_seq is not None
|
||||||
return (num_tokens // token_num_per_seq) // 2
|
return (num_tokens // token_num_per_seq) // 2
|
||||||
@@ -69,7 +74,43 @@ def compute_split_seq_index(
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
|
||||||
|
if extend_lens is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
|
||||||
|
left_sum = sum(extend_lens[:vanilla_split_seq_index])
|
||||||
|
overall_sum = sum(extend_lens)
|
||||||
|
threshold = global_server_args_dict["tbo_token_distribution_threshold"]
|
||||||
|
assert threshold <= 0.5, f"{threshold=}"
|
||||||
|
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
|
||||||
|
1 - threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_extend_seqs(arr: Sequence[int]) -> int:
|
||||||
|
if _is_two_chunk_split_enabled(arr):
|
||||||
|
return _split_array_by_cum_less_than_half(arr)
|
||||||
|
|
||||||
|
return _split_array_by_balanced_sum(arr)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
|
||||||
|
left_sum = 0
|
||||||
|
overall_sum = sum(arr)
|
||||||
|
half_sum = overall_sum // 2
|
||||||
|
chosen_index = 0
|
||||||
|
|
||||||
|
for i in range(len(arr)):
|
||||||
|
left_sum += arr[i]
|
||||||
|
if left_sum > half_sum:
|
||||||
|
chosen_index = i
|
||||||
|
break
|
||||||
|
|
||||||
|
return chosen_index
|
||||||
|
|
||||||
|
|
||||||
|
def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
|
||||||
overall_sum = sum(arr)
|
overall_sum = sum(arr)
|
||||||
left_sum = 0
|
left_sum = 0
|
||||||
min_diff = float("inf")
|
min_diff = float("inf")
|
||||||
@@ -88,6 +129,34 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
|||||||
return best_index
|
return best_index
|
||||||
|
|
||||||
|
|
||||||
|
def _update_device_and_sum_field_from_cpu_field(
|
||||||
|
batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
|
||||||
|
):
|
||||||
|
cpu_value = getattr(batch, cpu_field, None)
|
||||||
|
old_device_value = getattr(batch, device_field, None)
|
||||||
|
if (
|
||||||
|
cpu_value is None
|
||||||
|
or old_device_value is None
|
||||||
|
or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
new_device_value = (
|
||||||
|
cpu_value
|
||||||
|
if isinstance(cpu_value, torch.Tensor)
|
||||||
|
else torch.tensor(cpu_value, dtype=old_device_value.dtype)
|
||||||
|
).to(device=global_server_args_dict["device"], non_blocking=True)
|
||||||
|
setattr(batch, device_field, new_device_value)
|
||||||
|
|
||||||
|
if sum_field is not None:
|
||||||
|
sum_value = (
|
||||||
|
cpu_value.sum().item()
|
||||||
|
if isinstance(cpu_value, torch.Tensor)
|
||||||
|
else sum(cpu_value)
|
||||||
|
)
|
||||||
|
setattr(batch, sum_field, sum_value)
|
||||||
|
|
||||||
|
|
||||||
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
||||||
if seq_index == 0:
|
if seq_index == 0:
|
||||||
return 0
|
return 0
|
||||||
@@ -181,6 +250,8 @@ def compute_split_token_index(
|
|||||||
) -> int:
|
) -> int:
|
||||||
if forward_mode == ForwardMode.EXTEND:
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
assert extend_seq_lens is not None
|
assert extend_seq_lens is not None
|
||||||
|
if _is_two_chunk_split_enabled(extend_seq_lens):
|
||||||
|
return sum(extend_seq_lens) // 2
|
||||||
return sum(extend_seq_lens[:split_seq_index])
|
return sum(extend_seq_lens[:split_seq_index])
|
||||||
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
||||||
assert token_num_per_seq is not None
|
assert token_num_per_seq is not None
|
||||||
@@ -388,9 +459,15 @@ class TboForwardBatchPreparer:
|
|||||||
|
|
||||||
tbo_split_token_index = cls._compute_split_token_index(batch)
|
tbo_split_token_index = cls._compute_split_token_index(batch)
|
||||||
|
|
||||||
|
is_enable_two_chunk = (
|
||||||
|
batch.forward_mode == ForwardMode.EXTEND
|
||||||
|
and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
|
||||||
|
)
|
||||||
|
|
||||||
if _tbo_debug:
|
if _tbo_debug:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"TboForwardBatchPreparer.prepare "
|
f"TboForwardBatchPreparer.prepare "
|
||||||
|
f"is_enable_two_chunk={is_enable_two_chunk} "
|
||||||
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
||||||
f"tbo_split_token_index={tbo_split_token_index} "
|
f"tbo_split_token_index={tbo_split_token_index} "
|
||||||
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
||||||
@@ -410,7 +487,11 @@ class TboForwardBatchPreparer:
|
|||||||
start_token_index=0,
|
start_token_index=0,
|
||||||
end_token_index=tbo_split_token_index,
|
end_token_index=tbo_split_token_index,
|
||||||
start_seq_index=0,
|
start_seq_index=0,
|
||||||
end_seq_index=batch.tbo_split_seq_index,
|
end_seq_index=(
|
||||||
|
batch.tbo_split_seq_index + 1
|
||||||
|
if is_enable_two_chunk
|
||||||
|
else batch.tbo_split_seq_index
|
||||||
|
),
|
||||||
output_attn_backend=attn_backend_child_a,
|
output_attn_backend=attn_backend_child_a,
|
||||||
out_num_token_non_padded=out_num_token_non_padded_a,
|
out_num_token_non_padded=out_num_token_non_padded_a,
|
||||||
)
|
)
|
||||||
@@ -424,9 +505,79 @@ class TboForwardBatchPreparer:
|
|||||||
out_num_token_non_padded=out_num_token_non_padded_b,
|
out_num_token_non_padded=out_num_token_non_padded_b,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_enable_two_chunk:
|
||||||
|
cls.derive_fields_related_to_seq_len_for_two_chunk(
|
||||||
|
batch,
|
||||||
|
child_a=child_a,
|
||||||
|
child_b=child_b,
|
||||||
|
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||||
|
)
|
||||||
|
|
||||||
assert batch.tbo_children is None
|
assert batch.tbo_children is None
|
||||||
batch.tbo_children = [child_a, child_b]
|
batch.tbo_children = [child_a, child_b]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def derive_fields_related_to_seq_len_for_two_chunk(
|
||||||
|
cls,
|
||||||
|
batch: ForwardBatch,
|
||||||
|
*,
|
||||||
|
child_a: ForwardBatch,
|
||||||
|
child_b: ForwardBatch,
|
||||||
|
tbo_split_seq_index: int,
|
||||||
|
):
|
||||||
|
extend_seq_lens_cpu = batch.extend_seq_lens_cpu
|
||||||
|
overall_seq_lens_sum = sum(extend_seq_lens_cpu)
|
||||||
|
half_seq_lens_sum = overall_seq_lens_sum // 2
|
||||||
|
left_last_seq_token_num = half_seq_lens_sum - sum(
|
||||||
|
extend_seq_lens_cpu[:tbo_split_seq_index]
|
||||||
|
)
|
||||||
|
right_first_seq_token_num = (
|
||||||
|
extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
|
||||||
|
)
|
||||||
|
|
||||||
|
# making deepcopy to be extra safe
|
||||||
|
child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
|
||||||
|
child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
|
||||||
|
child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
|
||||||
|
child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
|
||||||
|
for child in [child_a, child_b]:
|
||||||
|
_update_device_and_sum_field_from_cpu_field(
|
||||||
|
batch=child,
|
||||||
|
cpu_field="extend_seq_lens_cpu",
|
||||||
|
device_field="extend_seq_lens",
|
||||||
|
sum_field="extend_num_tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
child_a.extend_num_tokens == half_seq_lens_sum
|
||||||
|
), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"
|
||||||
|
|
||||||
|
child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
|
||||||
|
child_a.seq_lens_cpu[-1] = (
|
||||||
|
child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
|
||||||
|
)
|
||||||
|
_update_device_and_sum_field_from_cpu_field(
|
||||||
|
batch=child_a,
|
||||||
|
cpu_field="seq_lens_cpu",
|
||||||
|
device_field="seq_lens",
|
||||||
|
sum_field="seq_lens_sum",
|
||||||
|
)
|
||||||
|
|
||||||
|
child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
|
||||||
|
child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
|
||||||
|
_update_device_and_sum_field_from_cpu_field(
|
||||||
|
batch=child_b,
|
||||||
|
cpu_field="extend_prefix_lens_cpu",
|
||||||
|
device_field="extend_prefix_lens",
|
||||||
|
sum_field=None,
|
||||||
|
)
|
||||||
|
_, child_b.extend_start_loc = compute_position(
|
||||||
|
global_server_args_dict["attention_backend"],
|
||||||
|
child_b.extend_prefix_lens,
|
||||||
|
child_b.extend_seq_lens,
|
||||||
|
child_b.extend_num_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ from types import SimpleNamespace
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.two_batch_overlap import compute_split_seq_index
|
from sglang.srt.two_batch_overlap import (
|
||||||
|
compute_split_seq_index,
|
||||||
|
compute_split_token_index,
|
||||||
|
)
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.run_eval import run_eval
|
from sglang.test.run_eval import run_eval
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
@@ -73,35 +76,46 @@ class TestTwoBatchOverlap(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TestTwoBatchOverlapUnitTest(unittest.TestCase):
|
class TestTwoBatchOverlapUnitTest(unittest.TestCase):
|
||||||
# TODO change tests when having 6328
|
def test_compute_split_seq_and_token_index(self):
|
||||||
def test_compute_split_seq_index(self):
|
|
||||||
for num_tokens, expect in [
|
for num_tokens, expect in [
|
||||||
(0, 0),
|
(0, 0),
|
||||||
(100, 50),
|
(100, 50),
|
||||||
(99, 49),
|
(99, 49),
|
||||||
]:
|
]:
|
||||||
actual = compute_split_seq_index(
|
actual = compute_split_seq_index(
|
||||||
forward_mode=ForwardMode.DECODE, num_tokens=num_tokens, extend_lens=None
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
num_tokens=num_tokens,
|
||||||
|
extend_lens=None,
|
||||||
|
token_num_per_seq=1,
|
||||||
)
|
)
|
||||||
self.assertEqual(actual, expect)
|
self.assertEqual(actual, expect)
|
||||||
|
|
||||||
for extend_lens, expect in [
|
for extend_lens, expect in [
|
||||||
([], 0),
|
([], (0, 0)),
|
||||||
([42], 0),
|
([42], (0, 21)),
|
||||||
([42, 999], 1),
|
([42, 999], (1, 520)),
|
||||||
([999, 42], 1),
|
([999, 42], (0, 520)),
|
||||||
([4096, 4096, 4096, 4096], 2),
|
([498, 502], (1, 498)),
|
||||||
([4095, 4096, 4096, 4096, 1], 2),
|
([4096, 4096, 4096, 4096], (2, 8192)),
|
||||||
([1, 4095, 4096, 4096, 4096], 3),
|
([4095, 4096, 4096, 4096, 1], (2, 8191)),
|
||||||
([4097, 4096, 4096, 4095, 1], 2),
|
([1, 4095, 4096, 4096, 4096], (3, 8192)),
|
||||||
([1, 1, 1, 1, 99999], 4),
|
([4097, 4096, 4096, 4095, 1], (2, 8193)),
|
||||||
([99999, 1, 1, 1, 1], 1),
|
([1, 1, 1, 1, 99999], (4, 50001)),
|
||||||
|
([99999, 1, 1, 1, 1], (0, 50001)),
|
||||||
]:
|
]:
|
||||||
actual = compute_split_seq_index(
|
actual_seq_idx = compute_split_seq_index(
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=ForwardMode.EXTEND,
|
||||||
num_tokens=None,
|
num_tokens=None,
|
||||||
extend_lens=extend_lens,
|
extend_lens=extend_lens,
|
||||||
|
token_num_per_seq=None,
|
||||||
)
|
)
|
||||||
|
actual_token_idx = compute_split_token_index(
|
||||||
|
split_seq_index=actual_seq_idx,
|
||||||
|
forward_mode=ForwardMode.EXTEND,
|
||||||
|
extend_seq_lens=extend_lens,
|
||||||
|
token_num_per_seq=None,
|
||||||
|
)
|
||||||
|
actual = (actual_seq_idx, actual_token_idx)
|
||||||
print(f"{extend_lens=} {expect=} {actual=}")
|
print(f"{extend_lens=} {expect=} {actual=}")
|
||||||
self.assertEqual(actual, expect)
|
self.assertEqual(actual, expect)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user