Fix torch compile for deepseek-v2 (#1442)
This commit is contained in:
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
|
|||||||
def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
||||||
for sub in model._modules.values():
|
for sub in model._modules.values():
|
||||||
if isinstance(sub, CustomOp):
|
if isinstance(sub, CustomOp):
|
||||||
|
# NOTE: FusedMoE torch native implementaiton is not efficient
|
||||||
|
if "FusedMoE" in sub.__class__.__name__:
|
||||||
|
continue
|
||||||
if reverse:
|
if reverse:
|
||||||
sub._forward_method = sub.forward_cuda
|
sub._forward_method = sub.forward_cuda
|
||||||
setattr(sub, "is_torch_compile", False)
|
setattr(sub, "is_torch_compile", False)
|
||||||
@@ -105,7 +108,15 @@ class CudaGraphRunner:
|
|||||||
self.capture_bs = list(range(1, 32)) + [64, 128]
|
self.capture_bs = list(range(1, 32)) + [64, 128]
|
||||||
else:
|
else:
|
||||||
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
|
self.compile_bs = (
|
||||||
|
[
|
||||||
|
bs
|
||||||
|
for bs in self.capture_bs
|
||||||
|
if bs <= self.model_runner.server_args.max_torch_compile_bs
|
||||||
|
]
|
||||||
|
if self.use_torch_compile
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
self.max_bs = max(self.capture_bs)
|
self.max_bs = max(self.capture_bs)
|
||||||
|
|||||||
@@ -653,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ class ServerArgs:
|
|||||||
disable_custom_all_reduce: bool = False
|
disable_custom_all_reduce: bool = False
|
||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
|
max_torch_compile_bs: int = 32
|
||||||
torchao_config: str = ""
|
torchao_config: str = ""
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
enable_mla: bool = False
|
enable_mla: bool = False
|
||||||
@@ -523,6 +524,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Optimize the model with torch.compile. Experimental feature.",
|
help="Optimize the model with torch.compile. Experimental feature.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-torch-compile-bs",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.max_torch_compile_bs,
|
||||||
|
help="Set the maximum batch size when using torch compile.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torchao-config",
|
"--torchao-config",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user