[dataset] Implementation of five datasets in GreycDataset #9977

Open · wants to merge 42 commits into base: master

Commits (42)
c86a625  Added new file for class (Jan 20, 2025)
2fc73c9  Added download methods (Jan 20, 2025)
c1dfc4e  temp (thomasbauer76, Jan 20, 2025)
71aa5b7  temp (thomasbauer76, Jan 20, 2025)
4e0a79c  download ok (thomasbauer76, Jan 20, 2025)
65dca72  load (thomasbauer76, Jan 20, 2025)
73a22b7  GreycDataset functional (thomasbauer76, Jan 20, 2025)
6fa5383  GreycDataset functional (thomasbauer76, Jan 20, 2025)
e1a19ca  Rollback io.__init__.py (thomasbauer76, Jan 20, 2025)
d6c1ac3  Formatting (thomasbauer76, Jan 20, 2025)
f024e81  First documentation (copied from TUDataset) (thomasbauer76, Jan 20, 2025)
3dfca84  Start of documentation after formatting (thomasbauer76, Jan 20, 2025)
d752776  Documentation update (thomasbauer76, Jan 21, 2025)
c311e08  Merge pull request #1 from thomasbauer76/greycdataset (thomasbauer76, Jan 21, 2025)
1abeb27  Documentation update (thomasbauer76, Jan 21, 2025)
00a27df  Update of docstring (thomasbauer76, Jan 22, 2025)
0cd9144  Updated & (Jan 23, 2025)
341e3d1  Pre-commit (Jan 23, 2025)
f60be55  Addition of PAH and Monoterpens input support (thomasbauer76, Jan 23, 2025)
309f898  Merge pull request #2 from pyg-team/master (thomasbauer76, Jan 23, 2025)
740071f  Merge pull request #3 from pyg-team/master (thomasbauer76, Jan 23, 2025)
def79ca  Merge pull request #4 from pyg-team/master (thomasbauer76, Jan 23, 2025)
fe2132a  Merge branch 'greycdataset' into develop (thomasbauer76, Jan 23, 2025)
5b0afe5  Merge pull request #5 from thomasbauer76/develop (thomasbauer76, Jan 23, 2025)
f384419  Loading datasets from compressed gml files (thomasbauer76, Jan 23, 2025)
50d7351  Link update and restructuring (thomasbauer76, Jan 24, 2025)
2350c73  Fix data loading warning (thomasbauer76, Jan 24, 2025)
0c1e94a  Bugfix: y tensor dtype set according to y's class (Jan 24, 2025)
7c1e7a7  Merge branch 'develop' of github.com:RAY41/pytorch_geometric into dev… (Jan 24, 2025)
35121ae  Override of num_* methods (thomasbauer76, Jan 24, 2025)
9a83dd4  Deletion of debugging print (thomasbauer76, Jan 24, 2025)
eaa3395  Update of greyc documentation (thomasbauer76, Jan 24, 2025)
f683e14  Update of GreycDataset documentation (thomasbauer76, Jan 24, 2025)
53339e7  Merge pull request #6 from thomasbauer76/develop (thomasbauer76, Jan 24, 2025)
c6441d7  Update CHANGELOG.md (thomasbauer76, Jan 24, 2025)
75f9936  Fix mypy assignment errors (thomasbauer76, Jan 24, 2025)
2b7c620  Update CHANGELOG.md with pull-request link (thomasbauer76, Jan 24, 2025)
aefbd95  Merge branch 'master' into greycdataset (thomasbauer76, Jan 27, 2025)
b06f4ab  Merge branch 'pyg-team:master' into master (thomasbauer76, Jan 27, 2025)
ca67d60  Merge branch 'master' into greycdataset (thomasbauer76, Jan 27, 2025)
1555d37  Merge branch 'pyg-team:master' into master (thomasbauer76, Jan 27, 2025)
771fe68  Merge pull request #7 from thomasbauer76/master (thomasbauer76, Jan 27, 2025)

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added five molecular datasets implemented in `GreycDataset` ([#9977](https://github.com/pyg-team/pytorch_geometric/pull/9977))
- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
4 changes: 4 additions & 0 deletions torch_geometric/datasets/__init__.py
@@ -81,6 +81,7 @@
from .molecule_gpt_dataset import MoleculeGPTDataset
from .instruct_mol_dataset import InstructMolDataset
from .tag_dataset import TAGDataset
from .greyc import GreycDataset

from .dbp15k import DBP15K
from .aminer import AMiner
@@ -199,6 +200,7 @@
'MoleculeGPTDataset',
'InstructMolDataset',
'TAGDataset',
'GreycDataset',
]

hetero_datasets = [
@@ -220,9 +222,11 @@
'RCDD',
'OPFDataset',
]

hyper_datasets = [
'CornellTemporalHyperGraphDataset',
]

synthetic_datasets = [
'FakeDataset',
'FakeHeteroDataset',
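
With the import and `__all__` entry added above, the new class becomes importable from the top-level datasets namespace like any other built-in dataset. A minimal sketch (the local `root` path is illustrative and assumes this branch is installed):

```python
from torch_geometric.datasets import GreycDataset

# Any of the five names (Acyclic, Alkane, MAO, Monoterpens, PAH) is accepted;
# matching is case-insensitive (see `name.lower()` in the new greyc.py below).
dataset = GreycDataset(root='data/GREYC/MAO', name='MAO')
```
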
224 changes: 224 additions & 0 deletions torch_geometric/datasets/greyc.py
@@ -0,0 +1,224 @@
import os
from typing import Callable, List, Optional

import torch

from torch_geometric.data import (
Data,
InMemoryDataset,
download_url,
extract_zip,
)


class GreycDataset(InMemoryDataset):
r"""Implementation of five `GREYC's Chemistry datasets
<https://lucbrun.ensicaen.fr/CHEMISTRY/>`_ : :obj:`Acyclic`,
:obj:`Alkane`, :obj:`MAO`, :obj:`Monoterpens` and :obj:`PAH`.

Args:
root (str): Root directory where the dataset should be saved.
name (str): The name (:obj:`Acyclic`,
:obj:`Alkane`, :obj:`MAO`, :obj:`Monoterpens` or :obj:`PAH`)
of the dataset.
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)

**STATS:**

.. list-table::
:widths: 20 10 10 10 10 10 10
:header-rows: 1

* - Name
- #graphs
- #nodes
- #edges
- #features
- #classes
- Type
* - Acyclic
- 183
- ~8.2
- ~14.3
- 7
- 148
- Regression
* - Alkane
- 149
- ~8.9
- ~15.7
- 4
- 123
- Regression
* - MAO
- 68
- ~18.4
- ~39.3
- 7
- 2
- Classification
* - Monoterpens
- 302
- ~11.0
- ~22.2
- 7
- 10
- Classification
* - PAH
- 94
- ~20.7
- ~48.9
- 4
- 2
- Classification
"""

URL = ('https://raw.githubusercontent.com/bgauzere/'
'greycdata/refs/heads/main/greycdata/data_gml/')

def __init__(
self,
root: str,
name: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
force_reload: bool = False,
) -> None:
self.name = name.lower()
if self.name not in {"acyclic", "alkane", "mao", "monoterpens", "pah"}:
raise ValueError(f"Dataset {self.name} not found.")
super().__init__(root, transform, pre_transform, pre_filter,
force_reload=force_reload)
self.load(self.processed_paths[0])

def __repr__(self) -> str:
name = self.name.capitalize()
if self.name in ["mao", "pah"]:
name = self.name.upper()
return f'{name}({len(self)})'

@staticmethod
def _gml_to_data(gml: str, gml_file: bool = True) -> Data:
"""Reads a `gml` file and creates a `Data` object.

        Parameters
        ----------
        gml : str
            GML file path if `gml_file` is set to `True`,
            GML content otherwise.
        gml_file : bool
            Whether `gml` is a path to a GML file or the GML content
            itself.

        Returns
        -------
        Data
            The `Data` object created from the file content.

        Raises
        ------
        FileNotFoundError
            If `gml` is a file path and does not exist.
        """
import networkx as nx
if gml_file:
if not os.path.exists(gml):
raise FileNotFoundError(f"File `{gml}` does not exist")
g = nx.read_gml(gml)
else:
g = nx.parse_gml(gml)

x_l, edge_index_l, edge_attr_l = [], [], []

y = g.graph["y"] if "y" in g.graph else None
dtype = torch.float if isinstance(y, float) else torch.long
y = torch.tensor([y], dtype=dtype) if y is not None else None

for _, attr in g.nodes(data=True):
x_l.append(attr["x"])

for u, v, attr in g.edges(data=True):
edge_index_l.append([int(u), int(v)])
edge_index_l.append([int(v), int(u)])
edge_attr_l.append(attr["edge_attr"])
edge_attr_l.append(attr["edge_attr"])

x = torch.tensor(x_l, dtype=torch.float)
edge_index = torch.tensor(edge_index_l,
dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_attr_l, dtype=torch.long)

return Data(x=x, edge_attr=edge_attr, edge_index=edge_index, y=y)

@staticmethod
def _load_gml_data(gml: str) -> List[Data]:
"""Reads a dataset from a gml file
and converts it into a list of `Data`.

Parameters :
------------

* `gml` : Source filename. If `zip` file, it will be extracted.

Returns :
---------

* `List[Data]` : Content of the gml dataset.
"""
GML_SEPARATOR = "---"
with open(gml, encoding="utf8") as f:
gml_contents = f.read()
gml_files = gml_contents.split(GML_SEPARATOR)
return [
GreycDataset._gml_to_data(content, False) for content in gml_files
]

@property
def processed_file_names(self) -> str:
return "data.pt"

@property
def raw_file_names(self) -> str:
return f"{self.name}.gml"

@property
def num_node_features(self) -> int:
return int(self.x.shape[1])

@property
def num_edge_features(self) -> int:
return int(self.edge_attr.shape[1])

@property
def num_classes(self) -> int:
return len(self.y.unique())

def download(self) -> None:
path = download_url(GreycDataset.URL + self.name + ".zip",
self.raw_dir)
extract_zip(path, self.raw_dir)
os.unlink(path)

def process(self) -> None:
data_list = self._load_gml_data(self.raw_paths[0])

if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]

if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]

self.save(data_list, self.processed_paths[0])
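
As a quick sanity check of the class above, here is a hedged usage sketch; the `root` path and batch size are illustrative, while the printed counts follow the STATS table and the `__repr__`/`_gml_to_data` logic in this diff:

```python
import torch
from torch_geometric.datasets import GreycDataset
from torch_geometric.loader import DataLoader

# First access downloads `alkane.zip` from the hard-coded GitHub URL,
# extracts `alkane.gml` (one GML block per graph, separated by "---"),
# and processes everything into `data.pt`.
dataset = GreycDataset(root='data/GREYC/Alkane', name='Alkane')

print(dataset)                    # Alkane(149), per __repr__ and the STATS table
print(dataset[0])                 # one molecular graph as a Data object
print(dataset.num_node_features)  # 4 node features for Alkane (STATS table)

# Alkane targets are continuous, so `_gml_to_data` stores `y` as torch.float;
# MAO/Monoterpens/PAH targets are integer class labels stored as torch.long.
print(dataset[0].y.dtype)

# Standard PyG mini-batching works unchanged on the resulting dataset.
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    pass  # feed `batch` to a GNN here
```

Since the `num_node_features`, `num_edge_features`, and `num_classes` overrides read these values straight from the stored tensors, the printed numbers should line up with the table in the class docstring.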