summaryrefslogtreecommitdiff
path: root/gnu/packages/machine-learning.scm
diff options
context:
space:
mode:
authorNicolas Graves <ngraves@ngraves.fr>2024-09-09 12:08:03 +0200
committerRicardo Wurmus <rekado@elephly.net>2025-04-07 16:44:28 +0200
commit3d4fc384f6da9d0b2ae9155f9e8232ff2d044ad4 (patch)
tree4ef8f3f882e9d511b096c53a506cec5fef2ce168 /gnu/packages/machine-learning.scm
parentdadb51bb82c9b1b3a2a09ee34dbbdc5f5f0ddf08 (diff)
gnu: Add python-safetensors.
* gnu/packages/machine-learning.scm (python-safetensors): New variable. Signed-off-by: Ricardo Wurmus <rekado@elephly.net> Change-Id: I90a1684d06756ce87ca0862d745a75be5919f0b2
Diffstat (limited to 'gnu/packages/machine-learning.scm')
-rw-r--r--gnu/packages/machine-learning.scm100
1 files changed, 100 insertions, 0 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 9d3f4060170..9d2b0b32178 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -1331,6 +1331,106 @@ storing tensors safely, named safetensors. They aim to be safer than their
@code{PyTorch} counterparts.")
(license license:asl2.0)))
+(define-public python-safetensors
+ (package
+ (name "python-safetensors")
+ (version "0.4.3")
+ (source
+ (origin
+ (method url-fetch)
+ (uri (pypi-uri "safetensors" version))
+ (sha256
+ (base32 "1hhiwy67jarm70l0k26fs1cjhzkgzrh79q14bklj2yp0qi8gr19g"))
+ (modules '((guix build utils)
+ (ice-9 ftw)))
+ (snippet
+ #~(begin ;Only keep bindings.
+ (for-each
+ (lambda (file)
+ (unless (member file '("." ".." "bindings" "PKG-INFO"))
+ (delete-file-recursively file)))
+ (scandir "."))
+ (for-each
+ (lambda (file)
+ (unless (member file '("." ".."))
+ (rename-file (string-append "bindings/python/" file)
+ file)))
+ (scandir "bindings/python"))))))
+ (build-system cargo-build-system)
+ (arguments
+ (list
+ #:modules '((guix build cargo-build-system)
+ (guix build utils)
+ (ice-9 regex)
+ (ice-9 textual-ports)
+ (srfi srfi-26))
+ #:phases
+ #~(modify-phases %standard-phases
+ (add-after 'unpack-rust-crates 'inject-safetensors
+ (lambda _
+ (substitute* "Cargo.toml"
+ (("\\[dependencies\\]")
+ (format #f "[dependencies]~%safetensors = ~s"
+ #$(package-version rust-safetensors))))
+ (call-with-input-file "Cargo.toml"
+ (lambda (port)
+ (let* ((content (get-string-all port))
+ (top-match (string-match
+ "\\[dependencies.safetensors"
+ content)))
+ (call-with-output-file "Cargo.toml"
+ (cut display (match:prefix top-match) <>)))))))
+ (add-before 'check 'install-rust-library
+ (lambda _
+ (copy-file "target/release/libsafetensors_rust.so"
+ "py_src/safetensors/_safetensors_rust.so")))
+ (replace 'check
+ (lambda* (#:key tests? #:allow-other-keys)
+ (when tests?
+ (setenv "PYTHONPATH" (string-append (getcwd) "/py_src"))
+ (invoke "python3"
+ "-m" "pytest"
+ "-n" "auto"
+ "--dist=loadfile"
+ "-s" "-v" "./tests/"
+ ;; Missing jax and tensorflow dependency
+ "--ignore=./tests/test_flax_comparison.py"
+ "--ignore=./tests/test_tf_comparison.py"))))
+ (add-after 'install 'install-python
+ (lambda _
+ (let* ((pversion #$(version-major+minor
+ (package-version python)))
+ (lib (string-append #$output "/lib/python" pversion
+ "/site-packages/"))
+ (info (string-append lib "safetensors-"
+ #$(package-version this-package)
+ ".dist-info")))
+ (mkdir-p info)
+ (copy-file "PKG-INFO" (string-append info "/METADATA"))
+ (copy-recursively
+ "py_src/safetensors"
+ (string-append lib "safetensors"))))))
+ #:cargo-inputs
+ `(("rust-pyo3" ,rust-pyo3-0.21)
+ ("rust-memmap2" ,rust-memmap2-0.9)
+ ("rust-safetensors" ,rust-safetensors)
+ ("rust-serde-json" ,rust-serde-json-1))))
+ (inputs
+ (list rust-safetensors))
+ (native-inputs
+ (list python-h5py
+ python-minimal
+ python-numpy
+ python-pytest
+ python-pytest-xdist
+ python-pytorch))
+ (home-page "https://huggingface.co/docs/safetensors")
+ (synopsis "Simple and safe way to store and distribute tensors")
+ (description "This package provides a fast (zero-copy) and safe
+(dedicated) format for storing tensors safely. This package builds upon
+@code{rust-safetensors} and provides Python bindings.")
+ (license license:asl2.0)))
+
(define-public python-sentencepiece
(package
(name "python-sentencepiece")