Skip to content

Commit 39e60b4

Browse files
committed
Add runtime type checking
This patch adds beartype for runtime type checking. This gives us the best of both worlds: we do static type checking of our own library with mypy, and we export our static types, but for clients who do not run static type checking of their own code, runtime type checking in our library can help them catch bugs earlier. `tests/test_codex_tool.py::test_bad_argument_type` serves as an example: this fails at initialization time of `CodexTool`, whereas without runtime type checking, this would fail later (e.g., when the user calls the `query` method on the object). Because we're performing runtime type checking, some of the imports that were behind `if TYPE_CHECKING` flags have to be moved to runtime. This patch updates the linter config to allow imports that are only used for type checking. This patch also switches to consistent `from __future__ import annotations` everywhere, stops using type hints deprecated by PEP 585, and uses PEP 585 / PEP 604 syntax everywhere. This patch updates the linter config to match this style. beartype relies on `isinstance` for runtime type checks, which needs to be taken into account when using mocks by overriding the `__class__` attribute. This patch updates the tests accordingly.
1 parent 453263b commit 39e60b4

File tree

12 files changed

+70
-46
lines changed

12 files changed

+70
-46
lines changed

pyproject.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727
dependencies = [
2828
"codex-sdk==0.1.0a9",
2929
"pydantic>=1.9.0, <3",
30+
"beartype>=0.17.0",
3031
]
3132

3233
[project.urls]
@@ -98,4 +99,8 @@ html = "coverage html"
9899
xml = "coverage xml"
99100

100101
[tool.ruff.lint]
101-
ignore = ["FA100", "UP007", "UP006"]
102+
ignore = [
103+
"TCH001", # this package does run-time type checking
104+
"TCH002",
105+
"TCH003"
106+
]

src/cleanlab_codex/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
# SPDX-License-Identifier: MIT
2+
3+
from beartype.claw import beartype_this_package
4+
5+
# this must run before any other imports from the cleanlab_codex package
6+
beartype_this_package()
7+
8+
# ruff: noqa: E402
29
from cleanlab_codex.codex import Codex
310
from cleanlab_codex.codex_tool import CodexTool
411

src/cleanlab_codex/codex.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Optional
4-
53
from cleanlab_codex.internal.project import create_project, query_project
64
from cleanlab_codex.internal.utils import init_codex_client
7-
8-
if TYPE_CHECKING:
9-
from cleanlab_codex.types.entry import Entry, EntryCreate
10-
from cleanlab_codex.types.organization import Organization
5+
from cleanlab_codex.types.entry import Entry, EntryCreate
6+
from cleanlab_codex.types.organization import Organization
117

128

139
class Codex:
@@ -41,7 +37,7 @@ def list_organizations(self) -> list[Organization]:
4137
"""
4238
return self._client.users.myself.organizations.list().organizations
4339

44-
def create_project(self, name: str, organization_id: str, description: Optional[str] = None) -> str:
40+
def create_project(self, name: str, organization_id: str, description: str | None = None) -> str:
4541
"""Create a new Codex project.
4642
4743
Args:
@@ -77,7 +73,7 @@ def create_project_access_key(
7773
self,
7874
project_id: str,
7975
access_key_name: str,
80-
access_key_description: Optional[str] = None,
76+
access_key_description: str | None = None,
8177
) -> str:
8278
"""Create a new access key for a project.
8379
@@ -99,15 +95,15 @@ def query(
9995
self,
10096
question: str,
10197
*,
102-
project_id: Optional[str] = None, # TODO: update to uuid once project IDs are changed to UUIDs
103-
fallback_answer: Optional[str] = None,
98+
project_id: str | None = None, # TODO: update to uuid once project IDs are changed to UUIDs
99+
fallback_answer: str | None = None,
104100
read_only: bool = False,
105-
) -> tuple[Optional[str], Optional[Entry]]:
101+
) -> tuple[str | None, Entry | None]:
106102
"""Query Codex to check if the Codex project contains an answer to this question and add the question to the Codex project for SME review if it does not.
107103
108104
Args:
109105
question (str): The question to ask the Codex API.
110-
project_id (:obj:`int`, optional): The ID of the project to query.
106+
project_id (:obj:`str`, optional): The ID of the project to query.
111107
If the client is authenticated with a user-level API Key, this is required.
112108
If the client is authenticated with a project-level Access Key, this is optional. The client will use the Access Key's project ID by default.
113109
fallback_answer (:obj:`str`, optional): Optional fallback answer to return if Codex is unable to answer the question.

src/cleanlab_codex/codex_tool.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, ClassVar, Optional
3+
from typing import Any, ClassVar
44

55
from cleanlab_codex.codex import Codex
66

@@ -23,8 +23,8 @@ def __init__(
2323
self,
2424
codex_client: Codex,
2525
*,
26-
project_id: Optional[str] = None,
27-
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
26+
project_id: str | None = None,
27+
fallback_answer: str | None = DEFAULT_FALLBACK_ANSWER,
2828
):
2929
self._codex_client = codex_client
3030
self._project_id = project_id
@@ -35,8 +35,8 @@ def from_access_key(
3535
cls,
3636
access_key: str,
3737
*,
38-
project_id: Optional[str] = None,
39-
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
38+
project_id: str | None = None,
39+
fallback_answer: str | None = DEFAULT_FALLBACK_ANSWER,
4040
) -> CodexTool:
4141
"""Creates a CodexTool from an access key. The project ID that the CodexTool will use is the one that is associated with the access key."""
4242
return cls(
@@ -50,8 +50,8 @@ def from_client(
5050
cls,
5151
codex_client: Codex,
5252
*,
53-
project_id: Optional[str] = None,
54-
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
53+
project_id: str | None = None,
54+
fallback_answer: str | None = DEFAULT_FALLBACK_ANSWER,
5555
) -> CodexTool:
5656
"""Creates a CodexTool from a Codex client.
5757
If the Codex client is initialized with a project access key, the CodexTool will use the project ID that is associated with the access key.
@@ -74,16 +74,16 @@ def tool_description(self) -> str:
7474
return self._tool_description
7575

7676
@property
77-
def fallback_answer(self) -> Optional[str]:
77+
def fallback_answer(self) -> str | None:
7878
"""The fallback answer to use if the Codex project cannot answer the question."""
7979
return self._fallback_answer
8080

8181
@fallback_answer.setter
82-
def fallback_answer(self, value: Optional[str]) -> None:
82+
def fallback_answer(self, value: str | None) -> None:
8383
"""Sets the fallback answer to use if the Codex project cannot answer the question."""
8484
self._fallback_answer = value
8585

86-
def query(self, question: str) -> Optional[str]:
86+
def query(self, question: str) -> str | None:
8787
"""Asks an all-knowing advisor this question in cases where it cannot be answered from the provided Context. If the answer is not available, this returns a fallback answer or None.
8888
8989
Args:

src/cleanlab_codex/internal/project.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Optional
4-
5-
if TYPE_CHECKING:
6-
from codex import Codex as _Codex
7-
8-
from cleanlab_codex.types.entry import Entry
3+
from codex import Codex as _Codex
94

5+
from cleanlab_codex.types.entry import Entry
106
from cleanlab_codex.types.project import ProjectConfig
117

128

@@ -17,7 +13,7 @@ def __str__(self) -> str:
1713
return "project_id is required when authenticating with a user-level API Key"
1814

1915

20-
def create_project(client: _Codex, name: str, organization_id: str, description: Optional[str] = None) -> str:
16+
def create_project(client: _Codex, name: str, organization_id: str, description: str | None = None) -> str:
2117
project = client.projects.create(
2218
config=ProjectConfig(),
2319
organization_id=organization_id,
@@ -31,10 +27,10 @@ def query_project(
3127
client: _Codex,
3228
question: str,
3329
*,
34-
project_id: Optional[str] = None,
35-
fallback_answer: Optional[str] = None,
30+
project_id: str | None = None,
31+
fallback_answer: str | None = None,
3632
read_only: bool = False,
37-
) -> tuple[Optional[str], Optional[Entry]]:
33+
) -> tuple[str | None, Entry | None]:
3834
if client.access_key is not None:
3935
project_id = client.projects.access_keys.retrieve_project_id().project_id
4036
elif project_id is None:

src/cleanlab_codex/utils/llamaindex.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
from collections.abc import Callable
34
from inspect import signature
4-
from typing import Any, Callable
5+
from typing import Any
56

67
from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model
78

src/cleanlab_codex/utils/openai.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, List, Literal
3+
from typing import Any, Literal
44

55
from pydantic import BaseModel
66

@@ -12,8 +12,8 @@ class Property(BaseModel):
1212

1313
class FunctionParameters(BaseModel):
1414
type: Literal["object"] = "object"
15-
properties: Dict[str, Property]
16-
required: List[str]
15+
properties: dict[str, Property]
16+
required: list[str]
1717

1818

1919
class Function(BaseModel):
@@ -30,9 +30,9 @@ class Tool(BaseModel):
3030
def format_as_openai_tool(
3131
tool_name: str,
3232
tool_description: str,
33-
tool_properties: Dict[str, Any],
34-
required_properties: List[str],
35-
) -> Dict[str, Any]:
33+
tool_properties: dict[str, Any],
34+
required_properties: list[str],
35+
) -> dict[str, Any]:
3636
return Tool(
3737
function=Function(
3838
name=tool_name,
+6-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
from typing import Callable, Dict, Optional
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
24

35
from smolagents import Tool # type: ignore
46

57

68
class CodexTool(Tool): # type: ignore[misc]
79
def __init__(
810
self,
9-
query: Callable[[str], Optional[str]],
11+
query: Callable[[str], str | None],
1012
tool_name: str,
1113
tool_description: str,
12-
inputs: Dict[str, Dict[str, str]],
14+
inputs: dict[str, dict[str, str]],
1315
):
1416
super().__init__()
1517
self._query = query
@@ -18,5 +20,5 @@ def __init__(
1820
self.inputs = inputs
1921
self.output_type = "string"
2022

21-
def forward(self, question: str) -> Optional[str]:
23+
def forward(self, question: str) -> str | None:
2224
return self._query(question)

tests/fixtures/client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from typing import Generator
1+
from collections.abc import Generator
22
from unittest.mock import MagicMock, patch
33

44
import pytest
55

6+
from cleanlab_codex.internal.utils import _Codex
7+
68

79
@pytest.fixture
810
def mock_client() -> Generator[MagicMock, None, None]:
911
with patch("cleanlab_codex.codex.init_codex_client") as mock_init:
1012
mock_client = MagicMock()
13+
mock_client.__class__ = _Codex
1114
mock_init.return_value = mock_client
1215
yield mock_client

tests/internal/test_utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from cleanlab_codex.internal.utils import MissingAuthKeyError, init_codex_client, is_access_key
6+
from cleanlab_codex.internal.utils import MissingAuthKeyError, _Codex, init_codex_client, is_access_key
77

88
DUMMY_ACCESS_KEY = "sk-1-EMOh6UrRo7exTEbEi8_azzACAEdtNiib2LLa1IGo6kA"
99
DUMMY_API_KEY = "GP0FzPfA7wYy5L64luII2YaRT2JoSXkae7WEo7dH6Bw"
@@ -16,6 +16,7 @@ def test_is_access_key() -> None:
1616

1717
def test_init_codex_client_access_key() -> None:
1818
mock_client = MagicMock()
19+
mock_client.__class__ = _Codex
1920
with patch("cleanlab_codex.internal.utils._Codex", autospec=True, return_value=mock_client) as mock_init:
2021
mock_client.projects.access_keys.retrieve_project_id.return_value = "test_project_id"
2122
client = init_codex_client(DUMMY_ACCESS_KEY)
@@ -25,6 +26,7 @@ def test_init_codex_client_access_key() -> None:
2526

2627
def test_init_codex_client_api_key() -> None:
2728
mock_client = MagicMock()
29+
mock_client.__class__ = _Codex
2830
with patch("cleanlab_codex.internal.utils._Codex", autospec=True, return_value=mock_client) as mock_init:
2931
mock_client.users.myself.api_key.retrieve.return_value = "test_project_id"
3032
client = init_codex_client(DUMMY_API_KEY)
@@ -40,6 +42,7 @@ def test_init_codex_client_no_key() -> None:
4042
def test_init_codex_client_access_key_env_var() -> None:
4143
with patch.dict(os.environ, {"CODEX_ACCESS_KEY": DUMMY_ACCESS_KEY}):
4244
mock_client = MagicMock()
45+
mock_client.__class__ = _Codex
4346
with patch(
4447
"cleanlab_codex.internal.utils._Codex",
4548
autospec=True,
@@ -54,6 +57,7 @@ def test_init_codex_client_access_key_env_var() -> None:
5457
def test_init_codex_client_api_key_env_var() -> None:
5558
with patch.dict(os.environ, {"CODEX_API_KEY": DUMMY_API_KEY}):
5659
mock_client = MagicMock()
60+
mock_client.__class__ = _Codex
5761
with patch(
5862
"cleanlab_codex.internal.utils._Codex",
5963
autospec=True,

tests/test_codex.py

+3
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ def test_create_project_access_key(mock_client: MagicMock) -> None:
8686
codex = Codex("")
8787
access_key_name = "Test Access Key"
8888
access_key_description = "Test Access Key Description"
89+
access_key = MagicMock()
90+
access_key.token.__class__ = str
91+
mock_client.projects.access_keys.create.return_value = access_key
8992
codex.create_project_access_key(FAKE_PROJECT_ID, access_key_name, access_key_description)
9093
mock_client.projects.access_keys.create.assert_called_once_with(
9194
project_id=FAKE_PROJECT_ID,

tests/test_codex_tool.py

+7
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,10 @@ def test_to_smolagents_tool(mock_client: MagicMock) -> None: # noqa: ARG001
3434
assert isinstance(smolagents_tool, Tool)
3535
assert smolagents_tool.name == tool.tool_name
3636
assert smolagents_tool.description == tool.tool_description
37+
38+
39+
def test_bad_argument_type() -> None:
40+
from beartype.roar import BeartypeException
41+
42+
with pytest.raises(BeartypeException):
43+
CodexTool("asdf")

0 commit comments

Comments
 (0)