From 4c3325fcd8e7ebf9b298b1f56b7ee1e3901cbf83 Mon Sep 17 00:00:00 2001 From: QiyuanChen Date: Thu, 8 Aug 2024 09:56:37 +0800 Subject: [PATCH 1/7] feat: :sparkles: Add OpenMLDataset class for loading datasets from OpenML --- torch_frame/datasets/openml_dataset.py | 85 ++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 torch_frame/datasets/openml_dataset.py diff --git a/torch_frame/datasets/openml_dataset.py b/torch_frame/datasets/openml_dataset.py new file mode 100644 index 00000000..5f40cca5 --- /dev/null +++ b/torch_frame/datasets/openml_dataset.py @@ -0,0 +1,85 @@ +import os + +import openml +import pandas as pd + +import torch_frame +from torch_frame import stype +from torch_frame.utils.infer_stype import infer_series_stype + + +class OpenMLDataset(torch_frame.data.Dataset): + """ + A dataset class for loading datasets from OpenML, designed to integrate with the torch_frame library. + More information about OpenML can be found at https://www.openml.org/. + + Parameters: + - dataset_id (int): The ID of the dataset to be loaded from OpenML. + - cache_dir (str, optional): The directory where the dataset is cached. If None, the default cache directory is used. + """ + + def __init__(self, dataset_id: int, cache_dir: str = None): + if cache_dir is not None: + openml.config.set_root_cache_directory(os.path.expanduser(cache_dir)) + self.dataset_id = dataset_id + self._openml_dataset = openml.datasets.get_dataset( + self.dataset_id, + download_data=True, + download_qualities=True, + download_features_meta_data=True, + ) + # Get dataset info from OpenML + self.dataset_info = self._openml_dataset.qualities + target_col = self._openml_dataset.default_target_attribute + X, y, self.categorical_indicator, _ = self._openml_dataset.get_data( + target=target_col + ) + df = pd.concat([X, y], axis=1) + self._task_type: torch_frame.TaskType = None + self._num_classes: int = None + + # The column type can be inferred from the categorical_indicator + col_to_stype = { + col: stype.categorical if self.categorical_indicator[i] else stype.numerical + for i, col in enumerate(X.columns) + } + + # Infer the stype of the target column + target_col_type = infer_series_stype(df[target_col]) + if target_col_type == torch_frame.categorical: + assert self.dataset_info["NumberOfClasses"] > 0 + if self.dataset_info["NumberOfClasses"] == 2: + assert df[target_col].nunique() == 2 + self._task_type = torch_frame.TaskType.BINARY_CLASSIFICATION + self._num_classes = 2 + else: + assert df[target_col].nunique() == self.dataset_info["NumberOfClasses"] + self._task_type = torch_frame.TaskType.MULTICLASS_CLASSIFICATION + self._num_classes = int(self.dataset_info["NumberOfClasses"]) + col_to_stype[target_col] = torch_frame.categorical + else: + assert self.dataset_info["NumberOfClasses"] == 0 + self._task_type = torch_frame.TaskType.REGRESSION + self._num_classes = 0 + col_to_stype[target_col] = torch_frame.numerical + + super().__init__(df=df, col_to_stype=col_to_stype, target_col=target_col) + + # NOTE: Overriding the `task_type()` and `num_classes` property method + @property + def task_type(self) -> torch_frame.TaskType: + """Returns the task type of the dataset. + + Returns: + torch_frame.TaskType: The task type of the dataset. + """ + return self._task_type + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset. + + Returns: + int: The number of classes in the dataset. + """ + return self._num_classes From 6e4be0623460670fbe9552d870da4710620e4733 Mon Sep 17 00:00:00 2001 From: QiyuanChen Date: Thu, 8 Aug 2024 10:03:00 +0800 Subject: [PATCH 2/7] style: :art: refine class docstring and initialize cache_dir --- torch_frame/datasets/openml_dataset.py | 34 +++++++++++++++----------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/torch_frame/datasets/openml_dataset.py b/torch_frame/datasets/openml_dataset.py index 5f40cca5..a7498e47 100644 --- a/torch_frame/datasets/openml_dataset.py +++ b/torch_frame/datasets/openml_dataset.py @@ -9,18 +9,19 @@ class OpenMLDataset(torch_frame.data.Dataset): - """ - A dataset class for loading datasets from OpenML, designed to integrate with the torch_frame library. + """A dataset class for loading datasets from OpenML, \ + designed to integrate with the torch_frame library. More information about OpenML can be found at https://www.openml.org/. Parameters: - dataset_id (int): The ID of the dataset to be loaded from OpenML. - - cache_dir (str, optional): The directory where the dataset is cached. If None, the default cache directory is used. + - cache_dir (str, optional): The directory where the dataset is cached. \ + If None, the default cache directory is used. """ - - def __init__(self, dataset_id: int, cache_dir: str = None): + def __init__(self, dataset_id: int, cache_dir: str | None = None): if cache_dir is not None: - openml.config.set_root_cache_directory(os.path.expanduser(cache_dir)) + openml.config.set_root_cache_directory( + os.path.expanduser(cache_dir)) self.dataset_id = dataset_id self._openml_dataset = openml.datasets.get_dataset( self.dataset_id, @@ -32,15 +33,17 @@ def __init__(self, dataset_id: int, cache_dir: str = None): self.dataset_info = self._openml_dataset.qualities target_col = self._openml_dataset.default_target_attribute X, y, self.categorical_indicator, _ = self._openml_dataset.get_data( - target=target_col - ) + target=target_col) df = pd.concat([X, y], axis=1) - self._task_type: torch_frame.TaskType = None - self._num_classes: int = None + self._task_type: torch_frame.TaskType = ( + torch_frame.TaskType.BINARY_CLASSIFICATION) + self._num_classes: int = 0 # The column type can be inferred from the categorical_indicator col_to_stype = { - col: stype.categorical if self.categorical_indicator[i] else stype.numerical + col: + stype.categorical + if self.categorical_indicator[i] else stype.numerical for i, col in enumerate(X.columns) } @@ -53,8 +56,10 @@ def __init__(self, dataset_id: int, cache_dir: str = None): self._task_type = torch_frame.TaskType.BINARY_CLASSIFICATION self._num_classes = 2 else: - assert df[target_col].nunique() == self.dataset_info["NumberOfClasses"] - self._task_type = torch_frame.TaskType.MULTICLASS_CLASSIFICATION + assert df[target_col].nunique( + ) == self.dataset_info["NumberOfClasses"] + self._task_type = ( + torch_frame.TaskType.MULTICLASS_CLASSIFICATION) self._num_classes = int(self.dataset_info["NumberOfClasses"]) col_to_stype[target_col] = torch_frame.categorical else: @@ -63,7 +68,8 @@ def __init__(self, dataset_id: int, cache_dir: str = None): self._num_classes = 0 col_to_stype[target_col] = torch_frame.numerical - super().__init__(df=df, col_to_stype=col_to_stype, target_col=target_col) + super().__init__(df=df, col_to_stype=col_to_stype, + target_col=target_col) # NOTE: Overriding the `task_type()` and `num_classes` property method @property From ad50d9ee5edae704fdf91fc023eb068892f5b20c Mon Sep 17 00:00:00 2001 From: QiyuanChen Date: Thu, 8 Aug 2024 10:04:00 +0800 Subject: [PATCH 3/7] feat: :package: add OpenMLDataset to real-world datasets list --- torch_frame/datasets/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_frame/datasets/__init__.py b/torch_frame/datasets/__init__.py index 832a477a..3c3a1d96 100644 --- a/torch_frame/datasets/__init__.py +++ b/torch_frame/datasets/__init__.py @@ -19,6 +19,7 @@ from .amazon_fine_food_reviews import AmazonFineFoodReviews from .diamond_images import DiamondImages from .huggingface_dataset import HuggingFaceDatasetDict +from .openml_dataset import OpenMLDataset real_world_datasets = [ 'Titanic', @@ -38,6 +39,7 @@ 'Movielens1M', 'AmazonFineFoodReviews', 'DiamondImages', + 'OpenMLDataset', ] synthetic_datasets = [ From 74a86c5879eaac66e9f51492496f280fe653df77 Mon Sep 17 00:00:00 2001 From: QiyuanChen Date: Thu, 8 Aug 2024 10:05:04 +0800 Subject: [PATCH 4/7] test: :white_check_mark: add unit tests for OpenMLDataset class --- test/datasets/test_data_frame_openml.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 test/datasets/test_data_frame_openml.py diff --git a/test/datasets/test_data_frame_openml.py b/test/datasets/test_data_frame_openml.py new file mode 100644 index 00000000..c8fbdb14 --- /dev/null +++ b/test/datasets/test_data_frame_openml.py @@ -0,0 +1,20 @@ +import pytest + +from torch_frame.datasets import OpenMLDataset +from torch_frame.typing import TaskType + + +@pytest.mark.parametrize("dataset_id", [8, 31, 455]) +def test_data_frame_openml(dataset_id): + dataset = OpenMLDataset(dataset_id) + if dataset_id == 8: + assert dataset.task_type == TaskType.REGRESSION + assert dataset.target_col == "drinks" + if dataset_id == 31: + assert dataset.task_type == TaskType.BINARY_CLASSIFICATION + assert dataset.num_classes == 2 + assert dataset.target_col == "class" + if dataset_id == 455: + assert dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION + assert dataset.num_classes == 3 + assert dataset.target_col == "origin" From 2c6feffa7b50c84f92cf8fc6efb13920d9e13478 Mon Sep 17 00:00:00 2001 From: QiyuanChen Date: Thu, 8 Aug 2024 10:06:08 +0800 Subject: [PATCH 5/7] build: :heavy_plus_sign: add openml to pyproject.toml dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a611688e..74e6ea6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ full=[ "lightgbm", "datasets", "torchmetrics", + "openml", ] [project.urls] From 8e9172d0ab3d9b234eb156d4c6c327f6d9cbae6c Mon Sep 17 00:00:00 2001 From: QiyuanChen Date: Thu, 8 Aug 2024 10:28:29 +0800 Subject: [PATCH 6/7] ci: :bug: fix typing errors in git actions --- torch_frame/datasets/openml_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_frame/datasets/openml_dataset.py b/torch_frame/datasets/openml_dataset.py index a7498e47..17ffecb0 100644 --- a/torch_frame/datasets/openml_dataset.py +++ b/torch_frame/datasets/openml_dataset.py @@ -1,4 +1,5 @@ import os +from typing import Optional import openml import pandas as pd @@ -9,16 +10,16 @@ class OpenMLDataset(torch_frame.data.Dataset): - """A dataset class for loading datasets from OpenML, \ + r"""A dataset class for loading datasets from OpenML, designed to integrate with the torch_frame library. More information about OpenML can be found at https://www.openml.org/. Parameters: - dataset_id (int): The ID of the dataset to be loaded from OpenML. - - cache_dir (str, optional): The directory where the dataset is cached. \ + - cache_dir (str, optional): The directory where the dataset is cached. If None, the default cache directory is used. """ - def __init__(self, dataset_id: int, cache_dir: str | None = None): + def __init__(self, dataset_id: int, cache_dir: Optional[str] = None): if cache_dir is not None: openml.config.set_root_cache_directory( os.path.expanduser(cache_dir)) From 54d512ced1dd0bebf55483e57144bc36de6c3cba Mon Sep 17 00:00:00 2001 From: QiyuanChen Date: Fri, 9 Aug 2024 17:39:13 +0800 Subject: [PATCH 7/7] fix: :bug: ensure openml is installed for OpenMLDataset --- torch_frame/datasets/openml_dataset.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/torch_frame/datasets/openml_dataset.py b/torch_frame/datasets/openml_dataset.py index 17ffecb0..7508b8a2 100644 --- a/torch_frame/datasets/openml_dataset.py +++ b/torch_frame/datasets/openml_dataset.py @@ -1,7 +1,6 @@ import os from typing import Optional -import openml import pandas as pd import torch_frame @@ -10,16 +9,26 @@ class OpenMLDataset(torch_frame.data.Dataset): - r"""A dataset class for loading datasets from OpenML, + r"""The `OpenML`_. + + Args: + dataset_id (int): The ID of the dataset to be loaded from OpenML. + cache_dir (str, optional): The directory where the dataset is cached. + If None, the default cache directory is used. """ def __init__(self, dataset_id: int, cache_dir: Optional[str] = None): + try: + import openml + except ImportError: + raise ImportError( + "The OpenML library is required to use the OpenMLDataset class. " + "You can install it using `pip install openml`." + ) if cache_dir is not None: openml.config.set_root_cache_directory( os.path.expanduser(cache_dir))