[lint] improve ruff check (#11922)
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
@@ -27,7 +27,9 @@ repos:
|
|||||||
rev: v0.11.7
|
rev: v0.11.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--select=F401,F821, --fixable=F401]
|
args:
|
||||||
|
- --select=F401,F821
|
||||||
|
- --fix
|
||||||
files: ^(benchmark/|docs/|examples/|python/sglang/)
|
files: ^(benchmark/|docs/|examples/|python/sglang/)
|
||||||
exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
|
exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
|
|||||||
@@ -167,6 +167,7 @@ class MiniMaxText01LightningAttention(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
slope_rate: Optional[torch.Tensor] = None,
|
slope_rate: Optional[torch.Tensor] = None,
|
||||||
|
do_eval: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if (not self.training) and (not do_eval):
|
if (not self.training) and (not do_eval):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -10,6 +11,8 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
|
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -302,6 +305,7 @@ class MiniMaxText01LightningAttention(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
slope_rate: Optional[torch.Tensor] = None,
|
slope_rate: Optional[torch.Tensor] = None,
|
||||||
|
do_eval: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if (not self.training) and (not do_eval):
|
if (not self.training) and (not do_eval):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import argparse
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -39,6 +40,8 @@ from sglang.srt.server_args import ServerArgs
|
|||||||
from sglang.srt.utils import is_blackwell, kill_process_tree
|
from sglang.srt.utils import is_blackwell, kill_process_tree
|
||||||
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProfileLinks(BaseModel):
|
class ProfileLinks(BaseModel):
|
||||||
"""Pydantic model for profile trace links."""
|
"""Pydantic model for profile trace links."""
|
||||||
|
|||||||
@@ -77,8 +77,8 @@ class CommonKVManager(BaseKVManager):
|
|||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self._register_to_bootstrap()
|
self._register_to_bootstrap()
|
||||||
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
self.transfer_infos = {}
|
||||||
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
self.decode_kv_args_table = {}
|
||||||
self.pp_group = get_pp_group()
|
self.pp_group = get_pp_group()
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import struct
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
|
|||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
import openai.types.responses as openai_responses_types
|
import openai.types.responses as openai_responses_types
|
||||||
|
import orjson
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
@@ -1063,7 +1064,7 @@ class OpenAIServingResponses(OpenAIServingChat):
|
|||||||
):
|
):
|
||||||
function_name = previous_item.recipient[len("browser.") :]
|
function_name = previous_item.recipient[len("browser.") :]
|
||||||
action = None
|
action = None
|
||||||
parsed_args = ororjson.loads(previous_item.content[0].text)
|
parsed_args = orjson.loads(previous_item.content[0].text)
|
||||||
if function_name == "search":
|
if function_name == "search":
|
||||||
action = openai_responses_types.response_function_web_search.ActionSearch(
|
action = openai_responses_types.response_function_web_search.ActionSearch(
|
||||||
type="search",
|
type="search",
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
if init_new_workspace:
|
if init_new_workspace:
|
||||||
self.workspace_buffer = torch.empty(
|
self.workspace_buffer = torch.empty(
|
||||||
global_config.flashinfer_workspace_size,
|
envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
|
||||||
dtype=torch.uint8,
|
dtype=torch.uint8,
|
||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||||
|
FlashInferMlaAttnBackend,
|
||||||
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.speculative.spec_info import SpecInput
|
from sglang.srt.speculative.spec_info import SpecInput
|
||||||
@@ -66,7 +69,7 @@ global_workspace_buffer = None
|
|||||||
|
|
||||||
class FlashInferMhaChunkKVRunner:
|
class FlashInferMhaChunkKVRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
|
self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
|
||||||
):
|
):
|
||||||
# Parse Constants
|
# Parse Constants
|
||||||
self.num_local_heads = (
|
self.num_local_heads = (
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ from triton_kernels.matmul_ogs import (
|
|||||||
PrecisionConfig,
|
PrecisionConfig,
|
||||||
matmul_ogs,
|
matmul_ogs,
|
||||||
)
|
)
|
||||||
from triton_kernels.numerics import InFlexData
|
from triton_kernels.numerics import InFlexData, MicroscalingCtx
|
||||||
|
from triton_kernels.quantization import downcast_to_mxfp
|
||||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
||||||
from triton_kernels.swiglu import swiglu_fn
|
from triton_kernels.swiglu import swiglu_fn
|
||||||
|
|
||||||
@@ -119,14 +120,14 @@ def triton_kernel_fused_experts(
|
|||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
|
||||||
assert per_channel_quant == False, "per_channel_quant is not supported"
|
assert per_channel_quant is False, "per_channel_quant is not supported"
|
||||||
assert expert_map == None, "expert_map is not supported"
|
assert expert_map is None, "expert_map is not supported"
|
||||||
assert w1_scale == None, "w1_scale is not supported"
|
assert w1_scale is None, "w1_scale is not supported"
|
||||||
assert w2_scale == None, "w2_scale is not supported"
|
assert w2_scale is None, "w2_scale is not supported"
|
||||||
assert a1_scale == None, "a1_scale is not supported"
|
assert a1_scale is None, "a1_scale is not supported"
|
||||||
assert a2_scale == None, "a2_scale is not supported"
|
assert a2_scale is None, "a2_scale is not supported"
|
||||||
assert block_shape == None, "block_shape is not supported"
|
assert block_shape is None, "block_shape is not supported"
|
||||||
|
|
||||||
# type check
|
# type check
|
||||||
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
||||||
@@ -143,7 +144,7 @@ def triton_kernel_fused_experts(
|
|||||||
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
||||||
|
|
||||||
# feature check
|
# feature check
|
||||||
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
|
assert inplace is False, "Inplace is not supported in new triton MoE kernel"
|
||||||
|
|
||||||
M, K = hidden_states.shape
|
M, K = hidden_states.shape
|
||||||
E, _, N = w1.shape
|
E, _, N = w1.shape
|
||||||
@@ -264,14 +265,14 @@ def triton_kernel_fused_experts_with_bias(
|
|||||||
gemm1_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
gemm1_clamp_limit: Optional[float] = None,
|
gemm1_clamp_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
|
||||||
assert per_channel_quant == False, "per_channel_quant is not supported"
|
assert per_channel_quant is False, "per_channel_quant is not supported"
|
||||||
assert expert_map == None, "expert_map is not supported"
|
assert expert_map is None, "expert_map is not supported"
|
||||||
assert w1_scale == None, "w1_scale is not supported"
|
assert w1_scale is None, "w1_scale is not supported"
|
||||||
assert w2_scale == None, "w2_scale is not supported"
|
assert w2_scale is None, "w2_scale is not supported"
|
||||||
assert a1_scale == None, "a1_scale is not supported"
|
assert a1_scale is None, "a1_scale is not supported"
|
||||||
assert a2_scale == None, "a2_scale is not supported"
|
assert a2_scale is None, "a2_scale is not supported"
|
||||||
assert block_shape == None, "block_shape is not supported"
|
assert block_shape is None, "block_shape is not supported"
|
||||||
|
|
||||||
# type check
|
# type check
|
||||||
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
||||||
@@ -290,7 +291,7 @@ def triton_kernel_fused_experts_with_bias(
|
|||||||
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
||||||
|
|
||||||
# feature check
|
# feature check
|
||||||
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
|
assert inplace is False, "Inplace is not supported in new triton MoE kernel"
|
||||||
|
|
||||||
E, _, _ = w1.shape
|
E, _, _ = w1.shape
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,13 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
|||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
|
||||||
|
CompressedTensors24,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
|
||||||
|
W4A16SPARSE24_SUPPORTED_BITS,
|
||||||
|
CompressedTensorsW4A16Sparse24,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
|
||||||
WNA16_SUPPORTED_BITS,
|
WNA16_SUPPORTED_BITS,
|
||||||
CompressedTensorsWNA16,
|
CompressedTensorsWNA16,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
||||||
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -341,7 +343,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
|||||||
# For chunk prefill req, we do not need to allocate mamba cache,
|
# For chunk prefill req, we do not need to allocate mamba cache,
|
||||||
# We could use allocated mamba cache instead.
|
# We could use allocated mamba cache instead.
|
||||||
def alloc(
|
def alloc(
|
||||||
self, need_size: int, reqs: Optional[List["Req"]] = None
|
self, need_size: int, reqs: Optional[List[Req]] = None
|
||||||
) -> Optional[List[int]]:
|
) -> Optional[List[int]]:
|
||||||
select_index = super().alloc(need_size)
|
select_index = super().alloc(need_size)
|
||||||
if select_index == None:
|
if select_index == None:
|
||||||
|
|||||||
@@ -110,6 +110,9 @@ def convert_bin_to_safetensor_file(
|
|||||||
|
|
||||||
dirname = os.path.dirname(sf_filename)
|
dirname = os.path.dirname(sf_filename)
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||||
|
|
||||||
# check file size
|
# check file size
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import tqdm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
@@ -3499,7 +3500,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
# temporarily only support DeepSeek V3/R1
|
# temporarily only support DeepSeek V3/R1
|
||||||
weight_block_size = [128, 128]
|
weight_block_size = [128, 128]
|
||||||
|
|
||||||
for layer_id in trange(
|
for layer_id in tqdm.trange(
|
||||||
self.config.num_hidden_layers + int(is_nextn),
|
self.config.num_hidden_layers + int(is_nextn),
|
||||||
desc="quant attn to fp8 ue8m0",
|
desc="quant attn to fp8 ue8m0",
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.attention import vision_utils
|
from sglang.srt.layers.attention import vision_utils
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Inference-only OPT model compatible with HuggingFace weights."""
|
"""Inference-only OPT model compatible with HuggingFace weights."""
|
||||||
|
import logging
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
@@ -46,6 +47,9 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
kv_cache_scales_loader,
|
kv_cache_scales_loader,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import add_prefix, make_layers
|
from sglang.srt.utils import add_prefix, make_layers
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_activation(name="relu"):
|
def get_activation(name="relu"):
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ import tempfile
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import types
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
@@ -55,6 +56,7 @@ from json import JSONDecodeError
|
|||||||
from multiprocessing.reduction import ForkingPickler
|
from multiprocessing.reduction import ForkingPickler
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
@@ -62,6 +64,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@@ -91,6 +94,9 @@ from typing_extensions import Literal
|
|||||||
from sglang.srt.environ import envs
|
from sglang.srt.environ import envs
|
||||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
show_time_cost = False
|
show_time_cost = False
|
||||||
@@ -1076,7 +1082,7 @@ def monkey_patch_vllm_gguf_config():
|
|||||||
|
|
||||||
def get_quant_method_with_embedding_replaced(
|
def get_quant_method_with_embedding_replaced(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
return GGUFLinearMethod(self)
|
return GGUFLinearMethod(self)
|
||||||
elif isinstance(layer, VocabParallelEmbedding):
|
elif isinstance(layer, VocabParallelEmbedding):
|
||||||
@@ -1956,7 +1962,9 @@ def direct_register_custom_op(
|
|||||||
if fake_impl is not None:
|
if fake_impl is not None:
|
||||||
my_lib._register_fake(op_name, fake_impl)
|
my_lib._register_fake(op_name, fake_impl)
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
if "Tried to register an operator" in str(e) and "multiple times" in str(e):
|
if "Tried to register an operator" in str(error) and "multiple times" in str(
|
||||||
|
error
|
||||||
|
):
|
||||||
# Silently ignore duplicate registration errors
|
# Silently ignore duplicate registration errors
|
||||||
# This can happen in multi-engine scenarios
|
# This can happen in multi-engine scenarios
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import ast
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user