123 lines
4.5 KiB
Python
123 lines
4.5 KiB
Python
"""A collection of functions which are triggered automatically by finder when
|
|
PyTorch package is included.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from cx_Freeze._compat import IS_MINGW, IS_WINDOWS
|
|
|
|
if TYPE_CHECKING:
|
|
from cx_Freeze.finder import ModuleFinder
|
|
from cx_Freeze.module import Module
|
|
|
|
# 1) Test in Windows from pypi
|
|
# pip install torch
|
|
#
|
|
# 2) Test in Windows from pytorch
|
|
#   see https://pytorch.org/get-started/locally/ for the install command
|
|
#
|
|
# 3) Test in Linux from pypi includes nvidia packages
|
|
# pip install torch
|
|
|
|
|
|
def load_torch(finder: ModuleFinder, module: Module) -> None:
    """Hook for PyTorch. Tested in Windows and Linux.

    Called by the finder when the ``torch`` package is found. It:

    * sets ``module.in_file_system`` to 2 (the hook's "optimized mode"),
      unless the package is listed in ``zip_exclude_packages``, in which case
      it uses 1;
    * recompiles the package ``__init__`` with every ``CUDA_PATH`` string
      renamed to ``NO_CUDA_PATH`` so the ``CUDA_PATH_Vxx_x`` installation
      directory is ignored;
    * registers the files in the pip ``nvidia/*/lib`` directories and the
      ``*.dll`` / ``*.so*`` files in ``torch/lib`` as fixed libraries, and
      copies ``torch/bin`` when present (torch 2.1+);
    * declares hidden imports and excludes the C header/template packages;
    * copies the config modules and the ``torch.ao`` / ``torch.nn`` /
      ``torch.functional`` sources used by ``torch.jit._overload``.

    Args:
        finder: the ModuleFinder collecting modules for the frozen app.
        module: the ``torch`` package; ``module.file`` is expected to be its
            ``__init__.py`` located directly inside site-packages.
    """
    module_path = module.file.parent
    site_packages_path = module_path.parent

    # Activate the optimized mode (in_file_system=2) by default; listing the
    # package in zip_exclude_packages opts out of it.
    if module.name in finder.zip_exclude_packages:
        print(f"WARNING: {module.name} hook optimizations disabled.")
        module.in_file_system = 1
    else:
        if module.name in finder.zip_include_packages:
            print(f"WARNING: {module.name} hook optimizations enabled.")
        module.in_file_system = 2

    # patch the code to ignore CUDA_PATH_Vxx_x installation directory
    code_string = module.file.read_text(encoding="utf_8")
    code_string = code_string.replace("CUDA_PATH", "NO_CUDA_PATH")
    module.code = compile(
        code_string,
        module.file.as_posix(),
        "exec",
        dont_inherit=True,
        optimize=finder.optimize,
    )

    # include the cuda libraries (pip "nvidia" packages) as fixed libraries;
    # the pattern also matches directories such as __pycache__, which are
    # not libraries and must not be registered
    source_lib = site_packages_path / "nvidia"
    if source_lib.exists():
        target_lib = f"lib/{source_lib.name}"
        for source in source_lib.glob("*/lib/*"):
            if not source.is_file():
                continue
            relative = source.relative_to(source_lib).as_posix()
            finder.lib_files[source] = f"{target_lib}/{relative}"

    # include the shared libraries in 'lib' as fixed libraries; each one is
    # both registered in lib_files and passed to include_files
    source_lib = module_path / "lib"
    if source_lib.exists():
        target_lib = f"lib/{module.name}/lib"
        extension = "*.dll" if (IS_WINDOWS or IS_MINGW) else "*.so*"
        for source in source_lib.glob(extension):
            target = f"{target_lib}/{source.name}"
            finder.lib_files[source] = target
            finder.include_files(source, target)

    # include the binaries (torch 2.1+)
    source_bin = module_path / "bin"
    if source_bin.exists():
        finder.include_files(source_bin, f"lib/{module.name}/bin")

    # hidden modules
    finder.include_module("torch._C")
    finder.include_module("torch._VF")
    finder.include_package("torch.distributions")
    finder.include_package("torch.testing")

    # exclude C files
    finder.exclude_module("torch.include")
    finder.exclude_module("torch.share")
    finder.exclude_module("torchgen.packaged.ATen.templates")

    # torch 2.2
    finder.include_module("torch.return_types")

    # config modules are also copied as plain files next to the package
    for config_file in (
        "_dynamo/config.py",
        "_functorch/config.py",
        "_inductor/config.py",
        "_lazy/config.py",
        "distributed/_spmd/config.py",
        "fx/config.py",
        "fx/experimental/_config.py",
    ):
        config = module_path / config_file
        if config.exists():
            finder.include_files(config, f"lib/{module.name}/{config_file}")

    # include source files for torch.jit._overload
    for subpackage in ("torch/ao", "torch/nn"):
        for source in (site_packages_path / subpackage).rglob("*.py"):
            target = "lib" / source.relative_to(site_packages_path)
            finder.include_files(source, target)
    source = site_packages_path / "torch/functional.py"
    if source.exists():
        target = "lib" / source.relative_to(site_packages_path)
        finder.include_files(source, target)
|
|
|
|
|
|
def load_torch__dynamo_skipfiles(finder: ModuleFinder, module: Module) -> None:
    """Patch to work with Python 3.11+.

    Recompiles the module source so that the expression
    ``_strip_init_py(m.__file__)`` becomes
    ``_strip_init_py(getattr(m, "__file__", ""))``. A module object without a
    ``__file__`` attribute then yields ``""`` instead of raising
    AttributeError.
    """
    unpatched = "return _strip_init_py(m.__file__)"
    patched = 'return _strip_init_py(getattr(m, "__file__", ""))'
    source_path = module.file
    fixed_source = source_path.read_text(encoding="utf_8").replace(
        unpatched, patched
    )
    module.code = compile(
        fixed_source,
        source_path.as_posix(),
        "exec",
        dont_inherit=True,
        optimize=finder.optimize,
    )
|