Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 762313b

Browse files
committedDec 3, 2024·
Switch from entrypoints to importlib.metadata
Since the minimum Python version supported now directly supports querying entry points using the standard library, write a wrapper around to support both upstream APIs, and make use of it, rather than the external entrypoints package.
1 parent 5384731 commit 762313b

File tree

11 files changed

+49
-12
lines changed

11 files changed

+49
-12
lines changed
 

‎MANIFEST.in

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ include *.toml
2020

2121
include .bumpversion.cfg
2222

23+
include papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA
24+
include papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt
25+
2326
# Documentation
2427
prune docs
2528

‎papermill/engines.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
from functools import wraps
55

66
import dateutil
7-
import entrypoints
87

98
from .clientwrap import PapermillNotebookClient
109
from .exceptions import PapermillException
1110
from .iorw import write_ipynb
1211
from .log import logger
13-
from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args
12+
from .utils import get_entrypoints_group, merge_kwargs, nb_kernel_name, nb_language, remove_args
1413

1514

1615
class PapermillEngines:
@@ -33,7 +32,7 @@ def register_entry_points(self):
3332
3433
Load handlers provided by other packages
3534
"""
36-
for entrypoint in entrypoints.get_group_all("papermill.engine"):
35+
for entrypoint in get_entrypoints_group("papermill.engine"):
3736
self.register(entrypoint.name, entrypoint.load())
3837

3938
def get_engine(self, name=None):

‎papermill/iorw.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import warnings
66
from contextlib import contextmanager
77

8-
import entrypoints
98
import nbformat
109
import requests
1110
import yaml
@@ -18,7 +17,7 @@
1817
missing_environment_variable_generator,
1918
)
2019
from .log import logger
21-
from .utils import chdir
20+
from .utils import chdir, get_entrypoints_group
2221
from .version import version as __version__
2322

2423
try:
@@ -116,7 +115,7 @@ def register(self, scheme, handler):
116115

117116
def register_entry_points(self):
118117
# Load handlers provided by other packages
119-
for entrypoint in entrypoints.get_group_all("papermill.io"):
118+
for entrypoint in get_entrypoints_group("papermill.io"):
120119
self.register(entrypoint.name, entrypoint.load())
121120

122121
def get_handler(self, path, extensions=None):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Metadata-Version: 2.3
2+
Name: foo
3+
Version: 0.0.1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[papermill.tests.fake]
2+
foo = bar

‎papermill/tests/test_engines.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,8 @@ def test_registering_entry_points(self):
492492
fake_entrypoint = Mock(load=Mock())
493493
fake_entrypoint.name = "fake-engine"
494494

495-
with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
495+
entry_points = {"papermill.engine": [fake_entrypoint]}
496+
with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points:
496497
self.papermill_engines.register_entry_points()
497-
mock_get_group_all.assert_called_once_with("papermill.engine")
498+
mock_entry_points.assert_called_once()
498499
self.assertEqual(self.papermill_engines.get_engine("fake-engine"), fake_entrypoint.load.return_value)

‎papermill/tests/test_iorw.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ def test_entrypoint_register(self):
104104
fake_entrypoint = Mock(load=Mock())
105105
fake_entrypoint.name = "fake-from-entry-point://"
106106

107-
with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
107+
entry_points = {"papermill.io": [fake_entrypoint]}
108+
with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points:
108109
self.papermill_io.register_entry_points()
109-
mock_get_group_all.assert_called_once_with("papermill.io")
110+
mock_entry_points.assert_called_once()
110111
fake_ = self.papermill_io.get_handler("fake-from-entry-point://")
111112
assert fake_ == fake_entrypoint.load.return_value
112113

‎papermill/tests/test_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import warnings
23
from pathlib import Path
34
from tempfile import TemporaryDirectory
@@ -10,6 +11,7 @@
1011
from ..utils import (
1112
any_tagged_cell,
1213
chdir,
14+
get_entrypoints_group,
1315
merge_kwargs,
1416
remove_args,
1517
retry,
@@ -58,3 +60,14 @@ def test_chdir():
5860
assert Path.cwd() == Path(temp_dir)
5961

6062
assert Path.cwd() == old_cwd
63+
64+
65+
def test_get_entrypoints_group():
66+
# We don't need to mock anything here, there is just enough metadata
67+
# present to give us one entry point.
68+
sys.path.insert(0, Path(__file__).parent / "fixtures")
69+
# We need to cast to a list here, 3.8/3.9 and 3.10+ return different
70+
# types.
71+
eps = list(get_entrypoints_group("papermill.tests.fake"))
72+
sys.path.pop()
73+
assert eps[0].name == "foo"

‎papermill/utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44
from contextlib import contextmanager
55
from functools import wraps
6+
from importlib.metadata import entry_points
67

78
from .exceptions import PapermillParameterOverwriteWarning
89

@@ -190,3 +191,20 @@ def chdir(path):
190191
yield
191192
finally:
192193
os.chdir(old_dir)
194+
195+
196+
def get_entrypoints_group(group):
197+
"""Return a given group of entrypoints.
198+
199+
Since the importlib.metadata entry points API is very simple in 3.8 and
200+
more complete in 3.10+, we need to support both. This function can be
201+
removed when 3.10 is the minimum supported version, and replaced
202+
with ``entry_points(group=group)``.
203+
"""
204+
eps = entry_points()
205+
if hasattr(eps, "select"):
206+
# New and shiny Python 3.10+ API
207+
return eps.select(group=group)
208+
else:
209+
# Python 3.8 and 3.9
210+
return eps.get(group, [])

‎requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ nbformat >= 5.2.0
44
nbclient >= 0.2.0
55
tqdm >= 4.32.2
66
requests
7-
entrypoints
87
tenacity >= 5.0.2
98
aiohttp >=3.9.0; python_version=="3.12"
109
ansicolors

‎requirements/docs.txt

-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@ myst-parser>=2.0.0
77
moto>=4.2.8
88
sphinx-copybutton>=0.5.2
99
nbformat
10-
entrypoints

0 commit comments

Comments
 (0)
Please sign in to comment.