Skip to content

Commit 01732df

Browse files
committed
feat: pull recursively to include subject references
Signed-off-by: Tomas Coufal <[email protected]>
1 parent ccc28b3 commit 01732df

File tree

4 files changed

+55
-2
lines changed

4 files changed

+55
-2
lines changed

oras/provider.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -862,11 +862,13 @@ def pull(self, *args, **kwargs) -> List[str]:
862862
refresh_headers = kwargs.get("refresh_headers")
863863
if refresh_headers is None:
864864
refresh_headers = True
865-
container = self.get_container(kwargs["target"])
865+
target: str = kwargs["target"]
866+
container = self.get_container(target)
866867
self.load_configs(container, configs=kwargs.get("config_path"))
867868
manifest = self.get_manifest(container, allowed_media_type, refresh_headers)
868869
outdir = kwargs.get("outdir") or oras.utils.get_tmpdir()
869870
overwrite = kwargs.get("overwrite", True)
871+
include_subject = kwargs.get("include_subject", False)
870872

871873
files = []
872874
for layer in manifest.get("layers", []):
@@ -900,6 +902,16 @@ def pull(self, *args, **kwargs) -> List[str]:
900902
self.download_blob(container, layer["digest"], outfile)
901903
logger.info(f"Successfully pulled {outfile}.")
902904
files.append(outfile)
905+
906+
if include_subject and manifest.get('subject', False):
907+
separator = "@" if "@" in target else ":"
908+
repo, _tag = target.rsplit(separator, 1)
909+
subject_digest = manifest['subject']['digest']
910+
new_kwargs = kwargs
911+
new_kwargs['target'] = f'{repo}@{subject_digest}'
912+
913+
files += self.pull(*args, **kwargs)
914+
903915
return files
904916

905917
@decorator.ensure_container

oras/tests/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ def target(registry):
5353
return f"{registry}/dinosaur/artifact:v1"
5454

5555

56+
@pytest.fixture
57+
def derived_target(registry):
58+
return f"{registry}/dinosaur/artifact:v1-derived"
59+
60+
5661
@pytest.fixture
5762
def target_dir(registry):
5863
return f"{registry}/dinosaur/directory:v1"

oras/tests/derived-artifact.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
referred artifact is greeting extinct creatures

oras/tests/test_oras.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
import oras.client
11+
import oras.provider
1112

1213
here = os.path.abspath(os.path.dirname(__file__))
1314

@@ -67,6 +68,40 @@ def test_basic_push_pull(tmp_path, registry, credentials, target):
6768
assert res.status_code == 201
6869

6970

71+
72+
@pytest.mark.with_auth(False)
73+
def test_push_pull_attached_artifacts(tmp_path, registry, credentials, target, derived_target):
74+
"""
75+
Basic tests for oras (without authentication)
76+
"""
77+
client = oras.client.OrasClient(hostname=registry, insecure=True)
78+
79+
artifact = os.path.join(here, "artifact.txt")
80+
assert os.path.exists(artifact)
81+
82+
res = client.push(files=[artifact], target=target)
83+
assert res.status_code in [200, 201]
84+
85+
derived_artifact = os.path.join(here, "derived-artifact.txt")
86+
assert os.path.exists(derived_artifact)
87+
88+
manifest = client.remote.get_manifest(target)
89+
subject = oras.provider.Subject.from_manifest(manifest)
90+
res = client.push(files=[derived_artifact], target=derived_target, subject=subject)
91+
assert res.status_code in [200, 201]
92+
93+
# Test pulling elsewhere
94+
files = sorted(client.pull(target=derived_target, outdir=tmp_path, include_subject=True))
95+
assert len(files) == 2
96+
assert os.path.basename(files[0]) == "artifact.txt"
97+
assert os.path.basename(files[1]) == "derived-artifact.txt"
98+
assert str(tmp_path) in files[0]
99+
assert str(tmp_path) in files[1]
100+
assert os.path.exists(files[0])
101+
assert os.path.exists(files[1])
102+
103+
104+
70105
@pytest.mark.with_auth(False)
71106
def test_get_delete_tags(tmp_path, registry, credentials, target):
72107
"""
@@ -87,7 +122,7 @@ def test_get_delete_tags(tmp_path, registry, credentials, target):
87122
assert not client.delete_tags(target, "v1-boop-boop")
88123
assert "v1" in client.delete_tags(target, "v1")
89124
tags = client.get_tags(target)
90-
assert not tags
125+
assert "v1" not in tags
91126

92127

93128
def test_get_many_tags():

0 commit comments

Comments
 (0)