Skip to content

Commit aa1b261

Browse files
fyrestone刘宝
and
刘宝
authored
Fix s3 client kwargs (#3316)
* Fix s3 client kwargs * Add type annotation * Fix UT * Fix UT Co-authored-by: 刘宝 <[email protected]>
1 parent 996ce47 commit aa1b261

File tree

4 files changed

+120
-3
lines changed

4 files changed

+120
-3
lines changed

mars/lib/filesystem/base.py

+6
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,9 @@ def parse_from_path(uri: str):
255255
if parsed_uri.password:
256256
options["password"] = parsed_uri.password
257257
return options
258+
259+
@classmethod
def get_storage_options(cls, storage_options: Dict, uri: str) -> Dict:
    """
    Build the keyword arguments used to construct this file system.

    Options parsed from ``uri`` by ``cls.parse_from_path`` are layered on
    top of the caller-supplied ``storage_options``, so a value parsed from
    the URI takes precedence over one passed explicitly under the same key.

    Parameters
    ----------
    storage_options : dict
        Options supplied by the caller. ``None`` is treated as empty.
        The dict is never modified.
    uri : str
        Path or URI handed to ``cls.parse_from_path``.

    Returns
    -------
    dict
        A new dict holding the merged options.
    """
    # Copy first: updating the caller's dict in place would leak values
    # parsed from the URI back into an object the caller still owns.
    merged = dict(storage_options) if storage_options else {}
    merged.update(cls.parse_from_path(uri))
    return merged

mars/lib/filesystem/core.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ def get_fs(path: path_type, storage_options: Dict = None) -> FileSystem:
5454
# local file systems are singletons.
5555
return file_system_type.get_instance()
5656
else:
57-
options = file_system_type.parse_from_path(path)
58-
storage_options.update(options)
57+
storage_options = file_system_type.get_storage_options(
58+
storage_options, path
59+
)
5960
return file_system_type(**storage_options)
6061
elif scheme in _scheme_to_dependencies: # pragma: no cover
6162
dependencies = ", ".join(_scheme_to_dependencies[scheme])

mars/lib/filesystem/s3.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from typing import Dict
1516

1617
"""
1718
An example to read csv from s3
@@ -26,8 +27,9 @@
2627
>>> "endpoint_url": "http://192.168.1.12:9000",
2728
>>> "aws_access_key_id": "<s3 access id>",
2829
>>> "aws_secret_access_key": "<s3 access key>",
30+
>>> "aws_session_token": "<s3 session token>",
2931
>>> }})
30-
>>> # Export environment vars AWS_ENDPOINT_URL / AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.
32+
>>> # Export environment vars AWS_ENDPOINT_URL / AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN.
3133
>>> mdf = md.read_csv("s3://bucket/example.csv", index_col=0)
3234
>>> r = mdf.head(1000).execute()
3335
>>> print(r)
@@ -62,6 +64,16 @@ def parse_from_path(uri: str):
6264
client_kwargs = {k: v for k, v in client_kwargs.items() if v is not None}
6365
return {"client_kwargs": client_kwargs}
6466

67+
@classmethod
def get_storage_options(cls, storage_options: Dict, uri: str) -> Dict:
    """
    Merge caller-supplied options into the defaults for an S3 file system.

    The defaults come from ``cls.parse_from_path(uri)``. Every top-level
    key in ``storage_options`` replaces the default of the same name,
    except ``"client_kwargs"``: that nested dict is merged entry by entry,
    so the caller can override a single client argument while keeping the
    remaining defaults.

    Parameters
    ----------
    storage_options : dict
        Options supplied by the caller. ``None`` is treated as empty.
        The dict is never modified.
    uri : str
        Path or URI handed to ``cls.parse_from_path``.

    Returns
    -------
    dict
        The merged options, with explicit values winning over defaults.
    """
    options = cls.parse_from_path(uri)
    for key, value in (storage_options or {}).items():
        if key == "client_kwargs":
            # Merge instead of replace so defaults the caller did not
            # mention survive; a ``None`` value means "nothing to add".
            options.setdefault("client_kwargs", {}).update(value or {})
        else:
            options[key] = value
    return options
76+
6577
register_filesystem("s3", S3FileSystem)
6678
else:
6779
S3FileSystem = None

mars/lib/filesystem/tests/test_s3.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import pytest
17+
18+
from ....dataframe import read_csv
19+
from ..core import register_filesystem
20+
from ..s3 import S3FileSystem
21+
22+
23+
class KwArgsException(Exception):
    """Exception that carries the keyword arguments it was created with."""

    def __init__(self, kwargs):
        # Forwarding ``kwargs`` keeps ``args`` exactly as the base class
        # would have set it on construction.
        super().__init__(kwargs)
        # Stored so the code catching the exception can inspect them.
        self.kwargs = kwargs
26+
27+
28+
if S3FileSystem is not None:

    class TestS3FileSystem(S3FileSystem):
        """
        S3 file system stub that reports the kwargs it was constructed with.

        After delegating to the real constructor it always raises
        ``KwArgsException`` carrying those kwargs, so a test can check
        what the file system would have been built with.
        """

        # The ``Test`` prefix would otherwise make pytest try to collect
        # this helper as a test class and warn about its ``__init__``.
        __test__ = False

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            raise KwArgsException(kwargs)

else:
    # S3 support is unavailable; the test using this name is skipped.
    TestS3FileSystem = None
37+
38+
39+
@pytest.mark.skipif(S3FileSystem is None, reason="S3 is not supported")
def test_client_kwargs():
    """
    Check how explicit ``client_kwargs`` combine with ``AWS_*`` env vars.

    ``TestS3FileSystem`` is registered for the ``s3`` scheme so that
    ``read_csv`` raises ``KwArgsException`` with the final constructor
    kwargs. The original registration and any pre-existing environment
    values are restored when the test finishes, pass or fail.
    """
    test_kwargs = {
        "endpoint_url": "http://192.168.1.12:9000",
        "aws_access_key_id": "test_id",
        "aws_secret_access_key": "test_key",
        "aws_session_token": "test_session_token",
    }
    test_env = {
        "AWS_ENDPOINT_URL": "a",
        "AWS_ACCESS_KEY_ID": "b",
        "AWS_SECRET_ACCESS_KEY": "c",
        "AWS_SESSION_TOKEN": "d",
    }

    def _assert_true():
        # Pass endpoint_url / aws_access_key_id / aws_secret_access_key / aws_session_token to read_csv.
        with pytest.raises(KwArgsException) as e:
            read_csv(
                "s3://bucket/example.csv",
                index_col=0,
                storage_options={"client_kwargs": test_kwargs},
            )
        assert e.value.kwargs == {
            "client_kwargs": {
                "endpoint_url": "http://192.168.1.12:9000",
                "aws_access_key_id": "test_id",
                "aws_secret_access_key": "test_key",
                "aws_session_token": "test_session_token",
            }
        }

    # Remember what the environment held so it can be put back afterwards.
    saved_env = {name: os.environ.get(name) for name in test_env}
    register_filesystem("s3", TestS3FileSystem)
    try:
        # Explicit kwargs are passed through before the test sets any env vars.
        _assert_true()

        os.environ.update(test_env)

        # With every key given explicitly, the env values must not leak in.
        _assert_true()

        # With a single explicit key, the rest fall back to the env values.
        for k, v in test_kwargs.items():
            with pytest.raises(KwArgsException) as e:
                read_csv(
                    "s3://bucket/example.csv",
                    index_col=0,
                    storage_options={"client_kwargs": {k: v}},
                )
            expect = {
                "endpoint_url": "a",
                "aws_access_key_id": "b",
                "aws_secret_access_key": "c",
                "aws_session_token": "d",
            }
            expect[k] = v
            assert e.value.kwargs == {"client_kwargs": expect}
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
        # Put the real implementation back so later tests do not get the stub.
        register_filesystem("s3", S3FileSystem)

0 commit comments

Comments
 (0)