Skip to content

Commit bdf735d

Browse files
authored
Merge pull request #22 from dlr-eoc/provider
exposed onnxruntime execution provider
2 parents 91938f2 + 13cafef commit bdf735d

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

CHANGELOG.rst

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Changelog
22
=========
33

4+
[0.2.2] (2024-10-21)
5+
--------------------
6+
Added
7+
*******
8+
- expose onnxruntime execution provider
9+
410
[0.2.1] (2023-12-05)
511
--------------------
612
Added

ukis_csmask/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.2.1"
1+
__version__ = "0.2.2"

ukis_csmask/mask.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
invalid_buffer=4,
3030
intra_op_num_threads=0,
3131
inter_op_num_threads=0,
32+
providers=None,
3233
):
3334
"""
3435
:param img: Input satellite image of shape (rows, cols, bands). (ndarray).
@@ -42,6 +43,8 @@ def __init__(
4243
:param invalid_buffer: Number of pixels that should be buffered around invalid areas. (int).
4344
:param intra_op_num_threads: Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose. (int).
4445
:param inter_op_num_threads: Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose. (int).
46+
:param providers: onnxruntime session providers. Default is None to let onnxruntime choose. (list).
47+
>>> providers = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
4548
"""
4649
# consistency checks on input image
4750
if isinstance(img, np.ndarray) is False:
@@ -98,9 +101,8 @@ def __init__(
98101
so = onnxruntime.SessionOptions()
99102
so.intra_op_num_threads = intra_op_num_threads
100103
so.inter_op_num_threads = inter_op_num_threads
101-
self.sess = onnxruntime.InferenceSession(
102-
model_file, sess_options=so, providers=onnxruntime.get_available_providers()
103-
)
104+
providers = onnxruntime.get_available_providers() if providers is None else providers
105+
self.sess = onnxruntime.InferenceSession(model_file, sess_options=so, providers=providers)
104106

105107
self.img = img
106108
self.band_order = band_order

0 commit comments

Comments
 (0)