Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: User script import now respects script_dir #1542

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions olive/common/import_lib.py
Copy link
Contributor

@jambayk jambayk Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tezheng commenting on this file so we can have a thread.

Your proposal for allowing user_script paths relative to the directory containing the config json makes a lot of sense. It would help a lot with the portability of olive projects.
However, apart from user_scripts, we also have many other resources, such as data dirs, script dirs, etc., that have the same issue. Workflows can also be run on remote systems such as Docker and AzureML, both of which require resolvable paths for mounting. Because of this, we originally decided on paths that are either absolute or relative to the current directory, so that every resource is resolvable and independent of the others.

After a discussion with the team, we agree that paths relative to the working directory (the directory containing the config json) should also be supported. I have created a work item for this, and we can discuss the priority of this feature with you offline on Teams. This would require a more involved update of the entire codebase so that all resources have the same behavior (and user script is not a special case) and compatibility with remote systems is not broken.

FYI, this is how we infer resource types. This would need to be updated to check for paths relative to the working dir. There are also other Path.resolve() instances throughout the codebase that we would need to update accordingly.

Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@

@functools.lru_cache
def import_module_from_file(module_path: Union[Path, str], module_name: Optional[str] = None):
module_path = Path(module_path).resolve()
if not module_path.exists():
raise ValueError(f"{module_path} doesn't exist")

module_path = Path(module_path)
if module_name is None:
if module_path.is_dir():
module_name = module_path.name
Expand All @@ -24,7 +21,18 @@ def import_module_from_file(module_path: Union[Path, str], module_name: Optional
else:
module_name = module_path.stem

spec = importlib.util.spec_from_file_location(module_name, module_path)
# Try to find the module in sys.path
spec = importlib.util.find_spec(module_name)
if not spec:
# If not found, try to load the module from the file
module_path = module_path.resolve()
if not module_path.exists():
raise ValueError(f"{module_path} doesn't exist")

spec = importlib.util.spec_from_file_location(module_name, module_path)
if not spec:
raise ValueError(f"Could not load module at {module_path}")

new_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(new_module)
return new_module
Expand Down
143 changes: 124 additions & 19 deletions test/unit_test/common/test_import_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -15,30 +16,37 @@
@patch("olive.common.import_lib.sys.path")
@patch("olive.common.import_lib.importlib.util")
def test_import_user_module_user_script_is_file(mock_importlib_util, mock_sys_path):
"""Test import_user_module when user_script is a file in script_dir."""
# setup
user_script = "user_script_a.py"
script_dir = "script_dir_a"

Path(script_dir).mkdir(parents=True, exist_ok=True)
script_dir_path = Path(script_dir).resolve()

with open(user_script, "w") as _:
# put user_script in script_dir
user_script_path = script_dir_path / user_script
with open(user_script_path, "w") as _:
pass

# mock
mock_spec = MagicMock()
mock_importlib_util.spec_from_file_location.return_value = mock_spec
mock_importlib_util.find_spec.return_value = mock_spec
expected_res = MagicMock()
mock_importlib_util.module_from_spec.return_value = expected_res

# execute
actual_res = import_user_module(user_script, script_dir)

# assert
script_dir_path = Path(script_dir).resolve()
mock_sys_path.append.assert_called_once_with(str(script_dir_path))
assert actual_res == expected_res

user_script_path = Path(user_script).resolve()
mock_importlib_util.spec_from_file_location.assert_called_once_with("user_script_a", user_script_path)
# script_dir will be added to sys.path
mock_sys_path.append.assert_called_once_with(str(script_dir_path))
# mock_importlib_util can find the user_script
mock_importlib_util.find_spec.assert_called_once_with("user_script_a")
mock_importlib_util.spec_from_file_location.assert_not_called()
mock_importlib_util.module_from_spec.assert_called_once_with(mock_spec)
mock_spec.loader.exec_module.assert_called_once_with(expected_res)

# cleanup
if os.path.exists(script_dir_path):
Expand All @@ -50,34 +58,39 @@ def test_import_user_module_user_script_is_file(mock_importlib_util, mock_sys_pa
@patch("olive.common.import_lib.sys.path")
@patch("olive.common.import_lib.importlib.util")
def test_import_user_module_user_script_is_dir(mock_importlib_util, mock_sys_path):
"""Test import_user_module when.

- script_dir is None
- user_script is a dir
"""
# setup
user_script = "user_script_b"
script_dir = "script_dir_b"

Path(script_dir).mkdir(parents=True, exist_ok=True)
Path(user_script).mkdir(parents=True, exist_ok=True)
user_script_path = Path(user_script).resolve()
with open(user_script_path / "__init__.py", "w") as _:
pass

mock_spec = MagicMock()
mock_importlib_util.find_spec.return_value = None
mock_importlib_util.spec_from_file_location.return_value = mock_spec
expected_res = MagicMock()
mock_importlib_util.module_from_spec.return_value = expected_res

# execute
actual_res = import_user_module(user_script, script_dir)
actual_res = import_user_module(user_script, script_dir=None)

# assert
script_dir_path = Path(script_dir).resolve()
mock_sys_path.append.assert_called_once_with(str(script_dir_path))
assert actual_res == expected_res

user_script_path = Path(user_script).resolve()
user_script_path_init = user_script_path / "__init__.py"
mock_importlib_util.spec_from_file_location.assert_called_once_with("user_script_b", user_script_path_init)
mock_sys_path.append.assert_not_called()
mock_importlib_util.find_spec.assert_called_once_with("user_script_b")
mock_importlib_util.spec_from_file_location.assert_called_once_with(
"user_script_b", (user_script_path / "__init__.py").resolve()
)
mock_importlib_util.module_from_spec.assert_called_once_with(mock_spec)
mock_spec.loader.exec_module.assert_called_once_with(expected_res)

# cleanup
if os.path.exists(script_dir_path):
shutil.rmtree(script_dir_path)
if os.path.exists(user_script_path):
shutil.rmtree(user_script_path)

Expand All @@ -102,8 +115,100 @@ def test_import_user_module_user_script_exception():

# execute
with pytest.raises(ValueError) as errinfo: # noqa: PT011
import_user_module(user_script)
import_user_module(user_script, script_dir=None)

# assert
user_script_path = Path(user_script).resolve()
assert str(errinfo.value) == f"{user_script_path} doesn't exist"


@patch("olive.common.import_lib.sys.path")
@patch("olive.common.import_lib.importlib.util")
def test_import_user_module_script_dir_none_and_user_script_exists(mock_importlib_util, mock_sys_path):
    """Load an existing script by file path when name-based lookup fails.

    Scenario:
    - script_dir is None
    - find_spec is mocked to return None (module not found by name)
    - the user_script file exists on disk
    """
    with TemporaryDirectory(prefix="not_in_sys_path_dir") as temp_dir:
        # arrange: an empty script inside a fresh temporary directory
        script_name = "user_script_e.py"
        script_file = Path(temp_dir) / script_name
        script_file.touch()

        # arrange: name lookup misses, so the file-location spec must be used
        file_spec = MagicMock()
        loaded_module = MagicMock()
        mock_importlib_util.find_spec.return_value = None
        mock_importlib_util.spec_from_file_location.return_value = file_spec
        mock_importlib_util.module_from_spec.return_value = loaded_module

        # act
        result = import_user_module(script_file, script_dir=None)

        # assert: the module built from the file-location spec is returned
        assert result == loaded_module
        # no script_dir was given, so sys.path must stay untouched
        mock_sys_path.append.assert_not_called()
        mock_importlib_util.find_spec.assert_called_once_with("user_script_e")
        mock_importlib_util.spec_from_file_location.assert_called_once_with("user_script_e", script_file.resolve())
        mock_importlib_util.module_from_spec.assert_called_once_with(file_spec)
        file_spec.loader.exec_module.assert_called_once_with(loaded_module)


@patch("olive.common.import_lib.sys.path")
@patch("olive.common.import_lib.importlib.util")
def test_import_user_module_script_dir_none_and_user_script_not_exists(mock_importlib_util, mock_sys_path):
    """Raise ValueError when the script is neither importable by name nor on disk.

    Scenario:
    - script_dir is None
    - find_spec is mocked to return None (module not found by name)
    - the user_script file does not exist
    """
    # arrange: a script name with no backing file, and a failing name lookup
    missing_script = "nonexistent_script.py"
    mock_importlib_util.find_spec.return_value = None

    # act
    with pytest.raises(ValueError) as errinfo:  # noqa: PT011
        import_user_module(missing_script, script_dir=None)

    # assert: the error message names the resolved path that was probed
    expected_message = f"{Path(missing_script).resolve()} doesn't exist"
    assert str(errinfo.value) == expected_message
    mock_sys_path.append.assert_not_called()
    mock_importlib_util.find_spec.assert_called_once_with("nonexistent_script")
    # the failure happens before any spec or module is built from the file
    mock_importlib_util.spec_from_file_location.assert_not_called()
    mock_importlib_util.module_from_spec.assert_not_called()


def test_import_user_module_user_script_in_sys_path():
    """Import a script by name from a directory that is already on sys.path.

    Scenario (no mocking):
    - the directory containing user_script has been inserted into sys.path
    - script_dir is None
    - the returned module's __file__ is the script inside that directory
    """
    with TemporaryDirectory(prefix="temp_sys_path_dir") as temp_dir:
        # arrange: an empty script inside the resolved temporary directory
        sys_path_dir = Path(temp_dir).resolve()
        script_name = "user_script_f.py"
        script_file = sys_path_dir / script_name
        script_file.touch()

        # arrange: make the directory importable for the duration of the test
        import sys

        sys_path_entry = str(sys_path_dir)
        sys.path.insert(0, sys_path_entry)

        try:
            # act: only the bare file name is passed, not the full path
            module = import_user_module(script_name, script_dir=None)

            # assert
            assert module.__file__ == str(script_file)
        finally:
            # always undo the sys.path change so other tests are unaffected
            sys.path.remove(sys_path_entry)
Loading