commit 19b87011540058bbb7cdb4d89cc89c2abde50154
parent ded377ddc835df7b6ab850607a3c913bd7f2ec37
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 13 Oct 2024 15:55:19 -0700
Fix dependencies (#496)
Diffstat:
5 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -13,7 +13,7 @@ dependencies:
- h5py
- jupyter
- matplotlib
- - numpy
+ - numpy<2 # Until PyTorch 2.3
- onnx
- onnxruntime!=1.16.0
- pip
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -14,7 +14,7 @@ dependencies:
- h5py
- jupyter
- matplotlib
- - numpy
+ - numpy<2 # Until PyTorch 2.3
- onnx
- onnxruntime!=1.16.0
- pip
diff --git a/nam/_version.py b/nam/_version.py
@@ -1 +1 @@
-__version__ = "0.10.0"
+__version__ = "0.10.1"
diff --git a/requirements.txt b/requirements.txt
@@ -6,7 +6,7 @@ auraloss==0.3.0 # 0.4.0 changes API for MRSTFT loss
black
flake8
matplotlib
-numpy
+numpy<2 # Until PyTorch 2.3
onnx
onnxruntime!=1.16.0 # 1.16.0 has a bug to avoid!
pip
diff --git a/setup.py b/setup.py
@@ -7,6 +7,7 @@ from setuptools import setup, find_packages
def get_additional_requirements():
+ additional_requirements = []
# Issue 294
try:
import transformers
@@ -14,9 +15,29 @@ def get_additional_requirements():
# This may not be unnecessarily straict a requirement, but I'd rather
# fix this promptly than leave a chance that it wouldn't be fixed
# properly.
- return ["transformers>=4"]
+ additional_requirements.append("transformers>=4")
except ModuleNotFoundError:
- return []
+ pass
+
+ # Issue 494
+ def get_numpy_requirement() -> str:
+ need_numpy_1 = True # Until proven otherwise
+ try:
+ import torch
+
+ version_split = torch.__version__.split(".")
+ major = int(version_split[0])
+ if major >= 2:
+ minor = int(version_split[1])
+ if minor >= 3: # Hooray, PyTorch 2.3+!
+ need_numpy_1 = False
+ except ModuleNotFoundError:
+ # Until I see PyTorch 2.3 come out:
+ pass
+ return "numpy<2"if need_numpy_1 else "numpy"
+ additional_requirements.append(get_numpy_requirement())
+
+ return additional_requirements
main_ns = {}
@@ -27,9 +48,8 @@ with open(ver_path) as ver_file:
requirements = [
"auraloss==0.3.0",
"matplotlib",
- "numpy",
"onnx",
- "onnxruntime",
+ "onnxruntime!=1.16.0", # Has a bug to avoid
"pydantic>=2.0.0",
"pytorch_lightning",
"scipy",