diff options
| author | Ricardo Wurmus <rekado@elephly.net> | 2025-11-07 17:45:02 +0100 |
|---|---|---|
| committer | Ricardo Wurmus <rekado@elephly.net> | 2025-11-07 21:52:07 +0100 |
| commit | 02f59daf078af5b54c020a04a4db9b02253e2f64 (patch) | |
| tree | c028955990fd519bdb3709bc5a9a82f10860945c /gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch | |
| parent | 0f2df2dad59e2f5e6da6144b009184a7f26e33b0 (diff) | |
gnu: python-pytorch-for-r-torch: Update to 2.7.1.
* gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch,
gnu/packages/patches/python-pytorch-for-r-torch-system-libraries.patch: Update.
* gnu/packages/patches/python-pytorch-for-r-torch-without-kineto.patch: New file.
* gnu/local.mk (dist_patch_DATA): Record it.
* gnu/packages/machine-learning.scm
(python-pytorch-for-r-torch): Update to 2.7.1.
[source]: Use new patch.
[arguments]: Remove phase 'fix-aten-vec; copy and adjust 'use-system-libraries
phase from python-pytorch.
[inputs]: Inherit all from python-pytorch; replace gloo with gloo-for-r-torch.
[native-inputs]: Inherit all from python-pytorch.
[propagated-inputs]: Inherit all from python-pytorch.
Change-Id: Ib2cf511fc34f609bbc7e92971720b00c4523419f
Diffstat (limited to 'gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch')
| -rw-r--r-- | gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch | 59 |
1 files changed, 35 insertions, 24 deletions
diff --git a/gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch b/gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch index 8515e5ab13a..3862339b141 100644 --- a/gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch +++ b/gnu/packages/patches/python-pytorch-for-r-torch-fix-codegen.patch @@ -6,7 +6,7 @@ is later corrected. codegen_external.py is patched to avoid duplicate functions and add the static keyword as in the existing generated file. diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh -index cc0263dbbf..ac34e84b82 100644 +index cc0263dbb..ac34e84b8 100644 --- a/tools/gen_flatbuffers.sh +++ b/tools/gen_flatbuffers.sh @@ -1,13 +1,13 @@ @@ -32,10 +32,10 @@ index cc0263dbbf..ac34e84b82 100644 -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs" echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h" diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py -index 120520b139..0c8587f02d 100644 +index 5dcf1b284..0e20b0c10 100644 --- a/torch/csrc/jit/tensorexpr/codegen_external.py +++ b/torch/csrc/jit/tensorexpr/codegen_external.py -@@ -16,9 +16,14 @@ def gen_external(native_functions_path, tags_path, external_path): +@@ -21,9 +21,14 @@ def gen_external(native_functions_path, tags_path, external_path): native_functions = parse_native_yaml(native_functions_path, tags_path) func_decls = [] func_registrations = [] @@ -51,7 +51,7 @@ index 120520b139..0c8587f02d 100644 args = schema.arguments # Only supports extern calls for functions with out variants if not schema.is_out_fn(): -@@ -48,7 +53,7 @@ def gen_external(native_functions_path, tags_path, external_path): +@@ -63,7 +68,7 @@ def gen_external(native_functions_path, tags_path, external_path): # print(tensor_decls, name, arg_names) func_decl = f"""\ @@ -61,7 +61,7 @@ index 120520b139..0c8587f02d 100644 void** buf_data, int64_t* buf_ranks, diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py -index 7cfbb803f9..2e69bb1868 100644 +index b42948045..e1cfc73a5 100644 --- a/torchgen/decompositions/gen_jit_decompositions.py +++ b/torchgen/decompositions/gen_jit_decompositions.py @@ -1,8 +1,12 @@ @@ -76,9 +76,9 @@ index 7cfbb803f9..2e69bb1868 100644 +else: + decomposition_table = {} - # from torchgen.code_template import CodeTemplate -@@ -85,7 +89,7 @@ def write_decomposition_util_file(path: str) -> None: + # from torchgen.code_template import CodeTemplate +@@ -86,7 +90,7 @@ def write_decomposition_util_file(path: str) -> None: def main() -> None: @@ -88,40 +88,41 @@ index 7cfbb803f9..2e69bb1868 100644 write_decomposition_util_file(str(upgrader_path)) diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py -index e5287cffc5..57f3c38096 100644 +index 845034cb7..a1c5767c2 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py -@@ -2,10 +2,12 @@ - import os +@@ -6,10 +6,13 @@ import os from enum import Enum + from operator import itemgetter from pathlib import Path +import sys - from typing import Any, Dict, List + from typing import Any -import torch -from torch.jit.generate_bytecode import generate_upgraders_bytecode +if len(sys.argv) < 2 or sys.argv[1] != "dummy": + import torch + from torch.jit.generate_bytecode import generate_upgraders_bytecode - ++ from torchgen.code_template import CodeTemplate from torchgen.operator_versions.gen_mobile_upgraders_constant import ( -@@ -262,7 +264,10 @@ def construct_register_size(register_size_from_yaml: int) -> str: + MOBILE_UPGRADERS_HEADER_DESCRIPTION, +@@ -263,7 +266,10 @@ def construct_register_size(register_size_from_yaml: int) -> str: def construct_version_maps( - upgrader_bytecode_function_to_index_map: Dict[str, Any] + upgrader_bytecode_function_to_index_map: dict[str, Any], ) -> str: - version_map = torch._C._get_operator_version_map() + if len(sys.argv) < 2 or sys.argv[1] != "dummy": + version_map = torch._C._get_operator_version_map() + else: + version_map = {} - sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0]) # type: ignore[no-any-return] - sorted_version_map = {name: lst for name, lst in sorted_version_map_} + sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] + sorted_version_map = dict(sorted_version_map_) -@@ -379,7 +384,10 @@ def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +@@ -375,7 +381,10 @@ def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]: - def main() -> None: + def main() -> None: - upgrader_list = generate_upgraders_bytecode() + if len(sys.argv) < 2 or sys.argv[1] != "dummy": + upgrader_list = generate_upgraders_bytecode() @@ -131,16 +132,24 @@ index e5287cffc5..57f3c38096 100644 for up in sorted_upgrader_list: print("after sort upgrader : ", next(iter(up))) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py -index c6336a6951..34e394d818 100644 +index 56a3d8bf0..ffd0785fd 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py -@@ -18,16 +18,20 @@ you are in the root directory of the Pytorch git repo""" +@@ -1,6 +1,7 @@ + #!/usr/bin/env python3 + import os + import sys ++import importlib + from importlib.util import module_from_spec, spec_from_file_location + from itertools import chain + from pathlib import Path +@@ -18,17 +19,21 @@ you are in the root directory of the Pytorch git repo""" if not file_path.exists(): - raise Exception(err_msg) + raise Exception(err_msg) # noqa: TRY002 --spec = importlib.util.spec_from_file_location(module_name, file_path) +-spec = spec_from_file_location(module_name, file_path) -assert spec is not None --module = importlib.util.module_from_spec(spec) +-module = module_from_spec(spec) -sys.modules[module_name] = module -assert spec.loader is not None -assert module is not None @@ -148,6 +157,7 @@ index c6336a6951..34e394d818 100644 - -bounded_compute_graph_mapping = module.bounded_compute_graph_mapping -shape_compute_graph_mapping = module.shape_compute_graph_mapping +- +if len(sys.argv) < 2 or sys.argv[1] != "dummy": + spec = importlib.util.spec_from_file_location(module_name, file_path) + assert spec is not None @@ -159,9 +169,10 @@ index c6336a6951..34e394d818 100644 + + bounded_compute_graph_mapping = module.bounded_compute_graph_mapping + shape_compute_graph_mapping = module.shape_compute_graph_mapping ++ +else: + bounded_compute_graph_mapping = {} + shape_compute_graph_mapping = {} - SHAPE_HEADER = r""" + /** |
