Simplify logits penalizer (#2086)
This commit is contained in:
@@ -1019,7 +1019,7 @@ class ScheduleBatch:
|
|||||||
extend_prefix_lens = self.prefix_lens
|
extend_prefix_lens = self.prefix_lens
|
||||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||||
|
|
||||||
if self.sampling_info is not None:
|
if self.sampling_info:
|
||||||
if self.has_grammar:
|
if self.has_grammar:
|
||||||
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||||
else:
|
else:
|
||||||
@@ -1063,6 +1063,7 @@ class ScheduleBatch:
|
|||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
decoding_reqs=self.decoding_reqs,
|
decoding_reqs=self.decoding_reqs,
|
||||||
|
sampling_info=dataclasses.replace(self.sampling_info),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -1122,20 +1123,6 @@ class ModelWorkerBatch:
|
|||||||
# Sampling info
|
# Sampling info
|
||||||
sampling_info: SamplingBatchInfo
|
sampling_info: SamplingBatchInfo
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
|
|
||||||
|
|
||||||
def to(self, device: str):
|
|
||||||
self.input_ids = self.input_ids.to(device, non_blocking=True)
|
|
||||||
self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
|
|
||||||
self.seq_lens = self.seq_lens.to(device, non_blocking=True)
|
|
||||||
self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
|
|
||||||
self.req_to_token_pool_records = [
|
|
||||||
(x, y.to(device, non_blocking=True))
|
|
||||||
for x, y in self.req_to_token_pool_records
|
|
||||||
]
|
|
||||||
self.sampling_info.to(device)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def write_req_to_token_pool_triton(
|
def write_req_to_token_pool_triton(
|
||||||
|
|||||||
@@ -931,14 +931,14 @@ class Scheduler:
|
|||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
if req.is_retracted:
|
if req.is_retracted:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if req.is_being_chunked <= 0:
|
if req.is_being_chunked <= 0:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_token_ids[i])
|
req.output_ids.append(next_token_id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
@@ -947,7 +947,7 @@ class Scheduler:
|
|||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
req.grammar.accept_token(next_token_ids[i])
|
req.grammar.accept_token(next_token_id)
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
logprob_pt += self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
"""A tensor parallel worker."""
|
"""A tensor parallel worker."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
@@ -138,9 +139,15 @@ class TpModelWorker:
|
|||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
self.model_runner.forward(forward_batch)
|
self.model_runner.forward(forward_batch)
|
||||||
|
|
||||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_generation(
|
||||||
|
self,
|
||||||
|
model_worker_batch: ModelWorkerBatch,
|
||||||
|
launch_event: Optional[threading.Event] = None,
|
||||||
|
):
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
|
if launch_event:
|
||||||
|
launch_event.set()
|
||||||
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
"""A tensor parallel worker."""
|
"""A tensor parallel worker."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -107,7 +108,7 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch, self.launch_event
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the future token ids map
|
# Update the future token ids map
|
||||||
@@ -134,7 +135,6 @@ class TpModelWorkerClient:
|
|||||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||||
copy_event.record()
|
copy_event.record()
|
||||||
|
|
||||||
self.launch_event.set()
|
|
||||||
self.output_queue.put((copy_event, logits_output, next_token_ids))
|
self.output_queue.put((copy_event, logits_output, next_token_ids))
|
||||||
|
|
||||||
def resolve_batch_result(self, bid: int):
|
def resolve_batch_result(self, bid: int):
|
||||||
@@ -159,7 +159,10 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||||
# Push a new batch to the queue
|
# Push a new batch to the queue
|
||||||
self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
|
model_worker_batch.sampling_info = dataclasses.replace(
|
||||||
|
model_worker_batch.sampling_info
|
||||||
|
)
|
||||||
|
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
||||||
|
|
||||||
# Allocate output future objects
|
# Allocate output future objects
|
||||||
bs = len(model_worker_batch.seq_lens)
|
bs = len(model_worker_batch.seq_lens)
|
||||||
|
|||||||
@@ -1,40 +1,34 @@
|
|||||||
import abc
|
import abc
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import typing
|
from typing import List, Set, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class _ReqLike:
|
class _ReqLike:
|
||||||
origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
|
origin_input_ids: List[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class _BatchLike:
|
class _BatchLike:
|
||||||
reqs: typing.List[_ReqLike]
|
reqs: List[_ReqLike]
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return len(self.reqs)
|
return len(self.reqs)
|
||||||
|
|
||||||
|
|
||||||
class BatchedPenalizerOrchestrator:
|
class BatchedPenalizerOrchestrator:
|
||||||
batch: _BatchLike
|
|
||||||
device: str
|
|
||||||
vocab_size: int
|
|
||||||
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
batch: _BatchLike,
|
batch: _BatchLike,
|
||||||
device: str,
|
device: str,
|
||||||
Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
|
Penalizers: Set[Type["_BatchedPenalizer"]],
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||||
|
|
||||||
is_required = False
|
is_required = False
|
||||||
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
|
|||||||
is_required |= pen_is_required
|
is_required |= pen_is_required
|
||||||
self.is_required = is_required
|
self.is_required = is_required
|
||||||
|
|
||||||
|
input_ids = [
|
||||||
|
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
|
||||||
|
for req in self.reqs()
|
||||||
|
]
|
||||||
if self.is_required:
|
if self.is_required:
|
||||||
self.cumulate_input_tokens(
|
self.cumulate_input_tokens(input_ids=input_ids)
|
||||||
input_ids=[req.origin_input_ids for req in self.reqs()]
|
|
||||||
)
|
|
||||||
|
|
||||||
def reqs(self):
|
def reqs(self):
|
||||||
return self.batch.reqs
|
return self.batch.reqs
|
||||||
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
|
|||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return self.batch.batch_size()
|
return self.batch.batch_size()
|
||||||
|
|
||||||
def cumulate_input_tokens(
|
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
|
||||||
self,
|
|
||||||
input_ids: typing.Union[
|
|
||||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
|
||||||
],
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Feed the input tokens to the penalizers.
|
Feed the input tokens to the penalizers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
|
input_ids (List[torch.Tensor]): The input tokens.
|
||||||
"""
|
"""
|
||||||
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
||||||
|
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
||||||
|
|
||||||
def cumulate_output_tokens(
|
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||||
self,
|
|
||||||
output_ids: typing.Union[
|
|
||||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
|
||||||
],
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Feed the output tokens to the penalizers.
|
Feed the output tokens to the penalizers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
|
output_ids (torch.Tensor): The output tokens.
|
||||||
"""
|
"""
|
||||||
if not self.is_required:
|
if not self.is_required:
|
||||||
return
|
return
|
||||||
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
|
|||||||
|
|
||||||
def filter(
|
def filter(
|
||||||
self,
|
self,
|
||||||
indices_to_keep: typing.List[int],
|
indices_to_keep: List[int],
|
||||||
indices_tensor_to_keep: torch.Tensor = None,
|
indices_tensor_to_keep: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Filter the penalizers based on the indices to keep in the batch.
|
Filter the penalizers based on the indices to keep in the batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
|
indices_to_keep (List[int]): List of indices to keep in the batch.
|
||||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||||
"""
|
"""
|
||||||
if not self.is_required:
|
if not self.is_required:
|
||||||
@@ -174,32 +160,18 @@ class _TokenIDs:
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
||||||
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
|
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
|
||||||
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
orchestrator: BatchedPenalizerOrchestrator
|
|
||||||
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
|
|
||||||
cached_counts: torch.Tensor = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
orchestrator: BatchedPenalizerOrchestrator,
|
orchestrator: BatchedPenalizerOrchestrator,
|
||||||
token_ids: typing.Union[
|
token_ids: Union[torch.Tensor, List[torch.Tensor]],
|
||||||
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
|
||||||
],
|
|
||||||
):
|
):
|
||||||
self.orchestrator = orchestrator
|
self.orchestrator = orchestrator
|
||||||
|
|
||||||
if not isinstance(token_ids[0], torch.Tensor):
|
|
||||||
token_ids = [
|
|
||||||
torch.tensor(
|
|
||||||
data=ids, dtype=torch.int64, device=self.orchestrator.device
|
|
||||||
)
|
|
||||||
for ids in token_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
self.token_ids = token_ids
|
self.token_ids = token_ids
|
||||||
|
self.cached_counts = None
|
||||||
|
|
||||||
def occurrence_count(self) -> torch.Tensor:
|
def occurrence_count(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -213,19 +185,13 @@ class _TokenIDs:
|
|||||||
|
|
||||||
token_ids = self.token_ids
|
token_ids = self.token_ids
|
||||||
|
|
||||||
if isinstance(token_ids, torch.Tensor):
|
if isinstance(token_ids, list):
|
||||||
token_ids = token_ids.unsqueeze(1)
|
# TODO: optimize this part
|
||||||
|
|
||||||
# needs to be long to be used as index in scatter_add
|
|
||||||
if token_ids.dtype != torch.int64:
|
|
||||||
token_ids = token_ids.to(torch.int64)
|
|
||||||
|
|
||||||
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
sequences=token_ids,
|
sequences=token_ids,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
padding_value=self.orchestrator.vocab_size,
|
padding_value=self.orchestrator.vocab_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cached_counts = torch.zeros(
|
self.cached_counts = torch.zeros(
|
||||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
@@ -237,6 +203,16 @@ class _TokenIDs:
|
|||||||
)[
|
)[
|
||||||
:, : self.orchestrator.vocab_size
|
:, : self.orchestrator.vocab_size
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
# TODO: optimize this part. We do not need to create this big tensor every time.
|
||||||
|
# We can directly apply the results on the logits.
|
||||||
|
self.cached_counts = torch.zeros(
|
||||||
|
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
|
||||||
|
device=self.orchestrator.device,
|
||||||
|
)
|
||||||
|
self.cached_counts[
|
||||||
|
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
|
||||||
|
] = 1
|
||||||
|
|
||||||
return self.cached_counts
|
return self.cached_counts
|
||||||
|
|
||||||
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
An abstract class for a batched penalizer.
|
An abstract class for a batched penalizer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
orchestrator: BatchedPenalizerOrchestrator
|
|
||||||
_is_prepared: bool = False
|
|
||||||
|
|
||||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||||
self.orchestrator = orchestrator
|
self.orchestrator = orchestrator
|
||||||
|
self._is_prepared = False
|
||||||
|
|
||||||
def is_prepared(self) -> bool:
|
def is_prepared(self) -> bool:
|
||||||
return self._is_prepared
|
return self._is_prepared
|
||||||
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
|
|
||||||
return self._apply(logits=logits)
|
return self._apply(logits=logits)
|
||||||
|
|
||||||
def filter(
|
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
|
||||||
):
|
|
||||||
if not self.is_prepared():
|
if not self.is_prepared():
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _filter(
|
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import typing
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||||
|
|
||||||
|
|
||||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||||
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _teardown(self):
|
def _teardown(self):
|
||||||
del self.frequency_penalties
|
|
||||||
del self.cumulated_frequency_penalties
|
|
||||||
|
|
||||||
self.frequency_penalties = None
|
self.frequency_penalties = None
|
||||||
self.cumulated_frequency_penalties = None
|
self.cumulated_frequency_penalties = None
|
||||||
|
|
||||||
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|||||||
logits -= self.cumulated_frequency_penalties
|
logits -= self.cumulated_frequency_penalties
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def _filter(
|
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
|
||||||
):
|
|
||||||
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
||||||
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
||||||
indices_tensor_to_keep
|
indices_tensor_to_keep
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import typing
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||||
|
|
||||||
|
|
||||||
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||||
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _teardown(self):
|
def _teardown(self):
|
||||||
del self.min_new_tokens
|
|
||||||
del self.stop_token_penalties
|
|
||||||
del self.len_output_tokens
|
|
||||||
|
|
||||||
self.min_new_tokens = None
|
self.min_new_tokens = None
|
||||||
self.stop_token_penalties = None
|
self.stop_token_penalties = None
|
||||||
self.len_output_tokens = None
|
self.len_output_tokens = None
|
||||||
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|||||||
logits[mask] += self.stop_token_penalties[mask]
|
logits[mask] += self.stop_token_penalties[mask]
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def _filter(
|
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
|
||||||
):
|
|
||||||
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
||||||
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
||||||
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import typing
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||||
|
|
||||||
|
|
||||||
class BatchedPresencePenalizer(_BatchedPenalizer):
|
class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||||
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _teardown(self):
|
def _teardown(self):
|
||||||
del self.presence_penalties
|
|
||||||
del self.cumulated_presence_penalties
|
|
||||||
|
|
||||||
self.presence_penalties = None
|
self.presence_penalties = None
|
||||||
self.cumulated_presence_penalties = None
|
self.cumulated_presence_penalties = None
|
||||||
|
|
||||||
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
|||||||
logits -= self.cumulated_presence_penalties
|
logits -= self.cumulated_presence_penalties
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def _filter(
|
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
|
||||||
):
|
|
||||||
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
||||||
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
||||||
indices_tensor_to_keep
|
indices_tensor_to_keep
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import typing
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..orchestrator import _BatchedPenalizer, _TokenIDs
|
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||||
|
|
||||||
|
|
||||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||||
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _teardown(self):
|
def _teardown(self):
|
||||||
del self.repetition_penalties
|
|
||||||
del self.cumulated_repetition_penalties
|
|
||||||
|
|
||||||
self.repetition_penalties = None
|
self.repetition_penalties = None
|
||||||
self.cumulated_repetition_penalties = None
|
self.cumulated_repetition_penalties = None
|
||||||
|
|
||||||
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|||||||
logits * self.cumulated_repetition_penalties,
|
logits * self.cumulated_repetition_penalties,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _filter(
|
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
|
||||||
):
|
|
||||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||||
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
||||||
indices_tensor_to_keep
|
indices_tensor_to_keep
|
||||||
|
|||||||
@@ -27,10 +27,10 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
|
grammars: Optional[List] = None
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: Optional[torch.Tensor] = None
|
vocab_mask: Optional[torch.Tensor] = None
|
||||||
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||||
grammars: Optional[List] = None
|
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||||
@@ -211,25 +211,3 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
return SamplingBatchInfo(
|
|
||||||
temperatures=self.temperatures,
|
|
||||||
top_ps=self.top_ps,
|
|
||||||
top_ks=self.top_ks,
|
|
||||||
min_ps=self.min_ps,
|
|
||||||
is_all_greedy=self.is_all_greedy,
|
|
||||||
need_min_p_sampling=self.need_min_p_sampling,
|
|
||||||
vocab_size=self.vocab_size,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
def to(self, device: str):
|
|
||||||
for item in [
|
|
||||||
"temperatures",
|
|
||||||
"top_ps",
|
|
||||||
"top_ks",
|
|
||||||
"min_ps",
|
|
||||||
]:
|
|
||||||
value = getattr(self, item)
|
|
||||||
setattr(self, item, value.to(device, non_blocking=True))
|
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ class SamplingParams:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
min_new_tokens: int = 0,
|
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
@@ -34,6 +33,7 @@ class SamplingParams:
|
|||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
|
min_new_tokens: int = 0,
|
||||||
spaces_between_special_tokens: bool = True,
|
spaces_between_special_tokens: bool = True,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
n: int = 1,
|
n: int = 1,
|
||||||
|
|||||||
@@ -782,7 +782,7 @@ class PortArgs:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_new(server_args) -> "PortArgs":
|
def init_new(server_args) -> "PortArgs":
|
||||||
port = server_args.port + 42
|
port = server_args.port + random.randint(100, 1000)
|
||||||
while True:
|
while True:
|
||||||
if is_port_available(port):
|
if is_port_available(port):
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import typing
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import (
|
|||||||
class MockSamplingParams:
|
class MockSamplingParams:
|
||||||
frequency_penalty: float = 0.0
|
frequency_penalty: float = 0.0
|
||||||
min_new_tokens: int = 0
|
min_new_tokens: int = 0
|
||||||
stop_token_ids: typing.List[int] = None
|
stop_token_ids: List[int] = None
|
||||||
presence_penalty: float = 0.0
|
presence_penalty: float = 0.0
|
||||||
repetition_penalty: float = 1.0
|
repetition_penalty: float = 1.0
|
||||||
|
|
||||||
@@ -24,12 +24,12 @@ class MockSamplingParams:
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MockTokenizer:
|
class MockTokenizer:
|
||||||
eos_token_id: int
|
eos_token_id: int
|
||||||
additional_stop_token_ids: typing.Optional[typing.List[int]] = None
|
additional_stop_token_ids: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MockReq:
|
class MockReq:
|
||||||
origin_input_ids: typing.List[int]
|
origin_input_ids: List[int]
|
||||||
sampling_params: MockSamplingParams
|
sampling_params: MockSamplingParams
|
||||||
tokenizer: MockTokenizer
|
tokenizer: MockTokenizer
|
||||||
|
|
||||||
@@ -42,8 +42,8 @@ class StepType(enum.Enum):
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Step:
|
class Step:
|
||||||
type: StepType
|
type: StepType
|
||||||
token_ids: typing.List[int]
|
token_ids: List[int]
|
||||||
expected_tensors: typing.Dict[str, torch.Tensor]
|
expected_tensors: Dict[str, torch.Tensor]
|
||||||
# assume initial logits are all 1
|
# assume initial logits are all 1
|
||||||
expected_logits: torch.Tensor
|
expected_logits: torch.Tensor
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ class Step:
|
|||||||
class Subject:
|
class Subject:
|
||||||
sampling_params: MockSamplingParams
|
sampling_params: MockSamplingParams
|
||||||
# first step must be input, which will be converted to Req
|
# first step must be input, which will be converted to Req
|
||||||
steps: typing.List[Step]
|
steps: List[Step]
|
||||||
eos_token_id: int = -1
|
eos_token_id: int = -1
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -66,7 +66,7 @@ class Subject:
|
|||||||
f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
|
f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def tensor_keys(self, i: int = 0) -> typing.Set[str]:
|
def tensor_keys(self, i: int = 0) -> Set[str]:
|
||||||
return set(self.steps[i].expected_tensors.keys())
|
return set(self.steps[i].expected_tensors.keys())
|
||||||
|
|
||||||
def to_req(self) -> MockReq:
|
def to_req(self) -> MockReq:
|
||||||
@@ -80,7 +80,7 @@ class Subject:
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Case:
|
class Case:
|
||||||
enabled: bool
|
enabled: bool
|
||||||
test_subjects: typing.List[Subject]
|
test_subjects: List[Subject]
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# each test_subjects.steps should have the same expected_tensors.keys()
|
# each test_subjects.steps should have the same expected_tensors.keys()
|
||||||
@@ -90,12 +90,12 @@ class Case:
|
|||||||
f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
|
f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def tensor_keys(self, i: int = 0) -> typing.List[str]:
|
def tensor_keys(self, i: int = 0) -> List[str]:
|
||||||
return set(self.test_subjects[i].tensor_keys())
|
return set(self.test_subjects[i].tensor_keys())
|
||||||
|
|
||||||
|
|
||||||
class BaseBatchedPenalizerTest(unittest.TestCase):
|
class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||||
Penalizer: typing.Type[_BatchedPenalizer]
|
Penalizer: Type[_BatchedPenalizer]
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
vocab_size = 5
|
vocab_size = 5
|
||||||
|
|
||||||
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
return torch.tensor(data, **kwargs, device=self.device)
|
return torch.tensor(data, **kwargs, device=self.device)
|
||||||
|
|
||||||
def create_test_subjects(self) -> typing.List[Subject]:
|
def create_test_subjects(self) -> List[Subject]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def create_test_cases(self):
|
def create_test_cases(self):
|
||||||
@@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
|||||||
|
|
||||||
def _create_penalizer(
|
def _create_penalizer(
|
||||||
self, case: Case
|
self, case: Case
|
||||||
) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
|
) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
|
||||||
orchestrator = BatchedPenalizerOrchestrator(
|
orchestrator = BatchedPenalizerOrchestrator(
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
|
batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
|
||||||
@@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
|||||||
if i < len(subject.steps)
|
if i < len(subject.steps)
|
||||||
]
|
]
|
||||||
|
|
||||||
inputs: typing.List[typing.List[int]] = []
|
inputs: List[List[int]] = []
|
||||||
outputs: typing.List[typing.List[int]] = []
|
outputs: List[List[int]] = []
|
||||||
for subject in filtered_subjects:
|
for subject in filtered_subjects:
|
||||||
step = subject.steps[i]
|
step = subject.steps[i]
|
||||||
if step.type == StepType.INPUT:
|
if step.type == StepType.INPUT:
|
||||||
inputs.append(step.token_ids)
|
raise NotImplementedError()
|
||||||
outputs.append([])
|
|
||||||
else:
|
else:
|
||||||
inputs.append([])
|
inputs.append([])
|
||||||
outputs.append(step.token_ids)
|
outputs.append(step.token_ids)
|
||||||
|
|
||||||
if any(inputs):
|
|
||||||
orchestrator.cumulate_input_tokens(inputs)
|
|
||||||
|
|
||||||
if any(outputs):
|
if any(outputs):
|
||||||
orchestrator.cumulate_output_tokens(outputs)
|
for j in range(max(len(x) for x in outputs)):
|
||||||
|
tmp_outputs = torch.tensor(
|
||||||
|
[x[j] for x in outputs],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=orchestrator.device,
|
||||||
|
)
|
||||||
|
orchestrator.cumulate_output_tokens(tmp_outputs)
|
||||||
|
|
||||||
if penalizer.is_required():
|
if penalizer.is_required():
|
||||||
self.assertTrue(penalizer.is_prepared())
|
self.assertTrue(penalizer.is_prepared())
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
|
||||||
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
@@ -68,7 +73,7 @@ class TestSRTBackend(unittest.TestCase):
|
|||||||
# Run twice to capture more bugs
|
# Run twice to capture more bugs
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
accuracy, latency = test_hellaswag_select()
|
accuracy, latency = test_hellaswag_select()
|
||||||
assert accuracy > 0.71, f"{accuracy=}"
|
self.assertGreater(accuracy, 0.71)
|
||||||
|
|
||||||
def test_gen_min_new_tokens(self):
|
def test_gen_min_new_tokens(self):
|
||||||
test_gen_min_new_tokens()
|
test_gen_min_new_tokens()
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import typing
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -48,7 +48,11 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
|||||||
),
|
),
|
||||||
Step(
|
Step(
|
||||||
type=StepType.OUTPUT,
|
type=StepType.OUTPUT,
|
||||||
token_ids=[1, 2, 2],
|
token_ids=[
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
], # This is the output ids of one request in three steps.
|
||||||
expected_tensors={
|
expected_tensors={
|
||||||
"frequency_penalties": self.tensor(
|
"frequency_penalties": self.tensor(
|
||||||
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
||||||
@@ -76,7 +80,7 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_test_subjects(self) -> typing.List[Subject]:
|
def create_test_subjects(self) -> List[Subject]:
|
||||||
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
|
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
|
||||||
self.disabled = self._create_subject(frequency_penalty=0.0)
|
self.disabled = self._create_subject(frequency_penalty=0.0)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import typing
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -143,7 +143,7 @@ class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_test_subjects(self) -> typing.List[Subject]:
|
def create_test_subjects(self) -> List[Subject]:
|
||||||
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
|
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
|
||||||
self.disabled = self._create_subject(min_new_tokens=0.0)
|
self.disabled = self._create_subject(min_new_tokens=0.0)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import typing
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_test_subjects(self) -> typing.List[Subject]:
|
def create_test_subjects(self) -> List[Subject]:
|
||||||
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
|
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
|
||||||
self.disabled = self._create_subject(presence_penalty=0.0)
|
self.disabled = self._create_subject(presence_penalty=0.0)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import typing
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_test_subjects(self) -> typing.List[Subject]:
|
def create_test_subjects(self) -> List[Subject]:
|
||||||
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
|
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
|
||||||
self.disabled = self._create_subject(repetition_penalty=1.0)
|
self.disabled = self._create_subject(repetition_penalty=1.0)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user