From c86a6255c83f7e126cfe2a8c6b188b3709c34315 Mon Sep 17 00:00:00 2001 From: Lyam Chardey Date: Mon, 20 Jan 2025 12:08:23 +0100 Subject: [PATCH 01/29] Added new file for class --- torch_geometric/datasets/greyc.py | 82 +++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 torch_geometric/datasets/greyc.py diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py new file mode 100644 index 000000000000..3ca868c54e8f --- /dev/null +++ b/torch_geometric/datasets/greyc.py @@ -0,0 +1,82 @@ +from typing import Callable, List, Optional + +import torch +from greycdata.loaders import load_acyclic, load_alkane, load_MAO + +from torch_geometric.data import InMemoryDataset # , download_url +from torch_geometric.utils import from_networkx + +# https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets + + +class DatasetNotFoundError(Exception): + pass + + +class GreycDataset(InMemoryDataset): + r"""Class to load three GREYC Datasets as pytorch geometric dataset.""" + url = 'https://github.com/bgauzere/greycdata/tree/main/greycdata/data' + + def __init__( + self, + root: str, + name: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + ) -> None: + self.name = name.lower() + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + self.data, self.slices = torch.load(self.processed_paths[0]) + + def __str__(self) -> str: + return self.name + + @property + def processed_file_names(self) -> str: + return 'data.pt' + + @property + def raw_file_names(self) -> List[str]: + return [] + + def download(self) -> None: + """Load the right data according to initializer.""" + if self.name == 'alkane': + return load_alkane() + elif self.name == 'acyclic': + return load_acyclic() + elif self.name == 'mao': + return load_MAO() + else: + raise DatasetNotFoundError("Dataset not found") + + def process(self): + """Read data into huge `Data` list.""" + graph_list, property_list = self._load_data() + + # Convert to PyG. 
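+        # `group_node_attrs` below makes `from_networkx` concatenate the
+        # listed node attributes, in order, into the node feature matrix
+        # `data.x`; every listed attribute must be present on every node.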
+ + def from_nx_to_pyg(graph, y): + """Convert networkx graph to pytorch graph and add y.""" + pyg_graph = from_networkx( + graph, + group_node_attrs=['atom_symbol', 'degree', 'x', 'y', 'z']) + pyg_graph.y = y + return pyg_graph + + data_list = [ + from_nx_to_pyg(graph, y) + for graph, y in zip(graph_list, property_list) + ] + + if self.pre_filter is not None: + data_list = [data for data in data_list if self.pre_filter(data)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) From 2fc73c997cf41f9da23d08b756e7953bb12fad0b Mon Sep 17 00:00:00 2001 From: Lyam Chardey Date: Mon, 20 Jan 2025 12:28:10 +0100 Subject: [PATCH 02/29] Added download methods --- torch_geometric/datasets/greyc.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 3ca868c54e8f..74c53a9343aa 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -1,9 +1,8 @@ from typing import Callable, List, Optional import torch -from greycdata.loaders import load_acyclic, load_alkane, load_MAO -from torch_geometric.data import InMemoryDataset # , download_url +from torch_geometric.data import InMemoryDataset, download_url from torch_geometric.utils import from_networkx # https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets @@ -15,7 +14,8 @@ class DatasetNotFoundError(Exception): class GreycDataset(InMemoryDataset): r"""Class to load three GREYC Datasets as pytorch geometric dataset.""" - url = 'https://github.com/bgauzere/greycdata/tree/main/greycdata/data' + + URL = 'https://github.com/bgauzere/greycdata/tree/main/greycdata/data/' def __init__( self, @@ -42,16 +42,31 @@ def processed_file_names(self) -> str: def raw_file_names(self) -> List[str]: return [] + # def _load_alkane(self) -> None: + # """Load Alkane dataset.""" + # url = GreycDataset.URL + "Alkane" + # download_url(url, self.raw_dir) + + # def _load_acyclic(self) -> None: + # """Load Acyclic dataset.""" + # url = GreycDataset.URL + "Acyclic" + # download_url(url, self.raw_dir) + + # def _load_mao(self) -> None: + # """Load MAO dataset.""" + # url = GreycDataset.URL + "MAO" + # download_url(url, self.raw_dir) + def download(self) -> None: """Load the right data according to initializer.""" if self.name == 'alkane': - return load_alkane() + download_url(GreycDataset.URL + "Aklane", self.raw_dir) elif self.name == 'acyclic': - return load_acyclic() + download_url(GreycDataset.URL + "Acyclic", self.raw_dir) elif self.name == 'mao': - return load_MAO() + download_url(GreycDataset.URL + "MAO", self.raw_dir) else: - raise DatasetNotFoundError("Dataset not found") + raise DatasetNotFoundError(f"Dataset `{self.name}` not found") def process(self): """Read data into huge `Data` list.""" From c1dfc4e06f2f2eb077fed022afc86daa2a0dad5d Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Mon, 20 Jan 2025 14:10:13 +0100 Subject: [PATCH 03/29] temp --- ...t-config.yaml => ...pre-commit-config.yaml | 0 torch_geometric/datasets/__init__.py | 4 + torch_geometric/datasets/greyc.py | 23 +- torch_geometric/io/__init__.py | 2 + torch_geometric/io/file_managers.py | 895 ++++++++++++++++++ torch_geometric/io/greyc.py | 155 +++ 6 files changed, 1060 insertions(+), 19 deletions(-) rename .pre-commit-config.yaml => ...pre-commit-config.yaml (100%) 
create mode 100644 torch_geometric/io/file_managers.py create mode 100644 torch_geometric/io/greyc.py diff --git a/.pre-commit-config.yaml b/...pre-commit-config.yaml similarity index 100% rename from .pre-commit-config.yaml rename to ...pre-commit-config.yaml diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 12895ad1dbac..cd4b9388e20b 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -80,6 +80,7 @@ from .git_mol_dataset import GitMolDataset from .molecule_gpt_dataset import MoleculeGPTDataset from .tag_dataset import TAGDataset +from .greyc import GreycDataset from .dbp15k import DBP15K from .aminer import AMiner @@ -196,6 +197,7 @@ 'GitMolDataset', 'MoleculeGPTDataset', 'TAGDataset', + 'GreycDataset', ] hetero_datasets = [ @@ -217,9 +219,11 @@ 'RCDD', 'OPFDataset', ] + hyper_datasets = [ 'CornellTemporalHyperGraphDataset', ] + synthetic_datasets = [ 'FakeDataset', 'FakeHeteroDataset', diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 74c53a9343aa..f585ac0f78f8 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -3,10 +3,9 @@ import torch from torch_geometric.data import InMemoryDataset, download_url +from torch_geometric.io import read_greyc from torch_geometric.utils import from_networkx -# https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets - class DatasetNotFoundError(Exception): pass @@ -15,7 +14,8 @@ class DatasetNotFoundError(Exception): class GreycDataset(InMemoryDataset): r"""Class to load three GREYC Datasets as pytorch geometric dataset.""" - URL = 'https://github.com/bgauzere/greycdata/tree/main/greycdata/data/' + URL = ('https://raw.githubusercontent.com/bgauzere/greycdata/refs/' + 'heads/main/greycdata/data/') def __init__( self, @@ -42,21 +42,6 @@ def processed_file_names(self) -> str: def raw_file_names(self) -> List[str]: return [] - # def _load_alkane(self) -> None: - # """Load Alkane dataset.""" - # url = GreycDataset.URL + "Alkane" - # download_url(url, self.raw_dir) - - # def _load_acyclic(self) -> None: - # """Load Acyclic dataset.""" - # url = GreycDataset.URL + "Acyclic" - # download_url(url, self.raw_dir) - - # def _load_mao(self) -> None: - # """Load MAO dataset.""" - # url = GreycDataset.URL + "MAO" - # download_url(url, self.raw_dir) - def download(self) -> None: """Load the right data according to initializer.""" if self.name == 'alkane': @@ -70,7 +55,7 @@ def download(self) -> None: def process(self): """Read data into huge `Data` list.""" - graph_list, property_list = self._load_data() + graph_list, property_list = read_greyc(self.raw_dir, self.name) # Convert to PyG. 
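Usage sketch (illustration only, not part of the patch; the root path is a
placeholder): once `read_greyc` is wired into `process()` as above, the
dataset behaves like any other `InMemoryDataset`:

    from torch_geometric.datasets import GreycDataset
    from torch_geometric.loader import DataLoader

    # First use triggers download() and process(); afterwards data.pt is reused.
    dataset = GreycDataset(root='data/MAO', name='MAO')
    print(dataset[0])  # Data(x=[num_nodes, num_features], edge_index=[2, num_edges], y=...)

    # Standard PyG mini-batching over the molecular graphs.
    loader = DataLoader(dataset, batch_size=16, shuffle=True)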
diff --git a/torch_geometric/io/__init__.py b/torch_geometric/io/__init__.py index 2b43b6cfc5c6..eae652ef5814 100644 --- a/torch_geometric/io/__init__.py +++ b/torch_geometric/io/__init__.py @@ -6,6 +6,7 @@ from .sdf import read_sdf, parse_sdf from .off import read_off, write_off from .npz import read_npz, parse_npz +from .greyc import read_greyc __all__ = [ 'read_off', @@ -20,4 +21,5 @@ 'parse_sdf', 'read_npz', 'parse_npz', + 'read_greyc', ] diff --git a/torch_geometric/io/file_managers.py b/torch_geometric/io/file_managers.py new file mode 100644 index 000000000000..53f08851674d --- /dev/null +++ b/torch_geometric/io/file_managers.py @@ -0,0 +1,895 @@ +"""Utilities function to manage graph files +Taken from graphkit-learn +""" +from os.path import dirname, splitext + + +class DataLoader(): + def __init__(self, filename, filename_targets=None, gformat=None, + **kwargs): + """Read graph data from filename and load them as NetworkX graphs. + + Parameters + ---------- + filename : string + The name of the file from where the dataset is read. + filename_targets : string + The name of file of the targets corresponding to graphs. + + Notes: + ----- + This function supports following graph dataset formats: + + 'ds': load data from .ds file. See comments of function loadFromDS for a example. + + 'cxl': load data from Graph eXchange Language file (.cxl file). See + `here `__ for detail. + + 'sdf': load data from structured data file (.sdf file). See + `here `__ + for details. + + 'mat': Load graph data from a MATLAB (up to version 7.1) .mat file. See + README in `downloadable file `__ + for details. + + 'txt': Load graph data from the TUDataset. See + `here `__ + for details. Note here filename is the name of either .txt file in + the dataset directory. + """ + if isinstance(filename, str): + extension = splitext(filename)[1][1:] + else: # filename is a list of files. + extension = splitext(filename[0])[1][1:] + + if extension == "ds": + self._graphs, self._targets, self._label_names = self.load_from_ds( + filename, filename_targets) + elif extension == "cxl": + dir_dataset = kwargs.get('dirname_dataset', None) + self._graphs, self._targets, self._label_names = self.load_from_xml( + filename, dir_dataset) + elif extension == 'xml': + dir_dataset = kwargs.get('dirname_dataset', None) + self._graphs, self._targets, self._label_names = self.load_from_xml( + filename, dir_dataset) + elif extension == "mat": + order = kwargs.get('order') + self._graphs, self._targets, self._label_names = self.load_mat( + filename, order) + elif extension == 'txt': + if gformat is None: + self._graphs, self._targets, self._label_names = self.load_tud( + filename) + elif gformat == 'cml': + self._graphs, self._targets, self._label_names = self.load_from_ds( + filename, filename_targets) + + else: + raise ValueError( + 'The input file with the extension ".', extension, + '" is not supported. The supported extensions includes: ".ds", ".cxl", ".xml", ".mat", ".txt".' + ) + + def load_from_ds(self, filename, filename_targets): + """Load data from .ds file. + + Possible graph formats include: + + '.ct': see function load_ct for detail. + + '.gxl': see dunction load_gxl for detail. + + Note these graph formats are checked automatically by the extensions of + graph files. + """ + if isinstance(filename, str): + dirname_dataset = dirname(filename) + with open(filename) as f: + content = f.read().splitlines() + else: # filename is a list of files. 
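+            # Merge the graph-list lines from every referenced .ds file.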
+ dirname_dataset = dirname(filename[0]) + content = [] + for fn in filename: + with open(fn) as f: + content += f.read().splitlines() + # to remove duplicate file names. + + data = [] + y = [] + label_names = { + 'node_labels': [], + 'edge_labels': [], + 'node_attrs': [], + 'edge_attrs': [] + } + # Alkane + content = [line for line in content if not line.endswith('.ds')] + # Acyclic + content = [line for line in content if not line.startswith('#')] + extension = splitext(content[0].split(' ')[0])[1][1:] + if extension == 'ct': + load_file_fun = self.load_ct + # @todo: .sdf not tested yet. + elif extension == 'gxl' or extension == 'sdf': + load_file_fun = self.load_gxl + elif extension == 'cml': # dataset "Chiral" + load_file_fun = self.load_cml + + if filename_targets is None or filename_targets == '': + for i in range(0, len(content)): + tmp = content[i].split(' ') + # remove the '#'s in file names + g, l_names = load_file_fun(dirname_dataset + '/' + + tmp[0].replace('#', '', 1)) + data.append(g) + # @todo: this is so redundant. + self._append_label_names(label_names, l_names) + y.append(float(tmp[1])) + else: # targets in a seperate file + for i in range(0, len(content)): + tmp = content[i] + # remove the '#'s in file names + g, l_names = load_file_fun(dirname_dataset + '/' + + tmp.replace('#', '', 1)) + data.append(g) + self._append_label_names(label_names, l_names) + + with open(filename_targets) as fnt: + content_y = fnt.read().splitlines() + # assume entries in filename and filename_targets have the same order. + for item in content_y: + tmp = item.split(' ') + # assume the 3rd entry in a line is y (for Alkane dataset) + y.append(float(tmp[2])) + + return data, y, label_names + + def load_from_xml(self, filename, dir_dataset=None): + import xml.etree.ElementTree as ET + + def load_one_file(filename, data, y, label_names): + tree = ET.parse(filename) + root = tree.getroot() + # "graph" for ... I forgot; "print" for datasets GREC and Web. + for graph in root.iter('graph') if root.find( + 'graph') is not None else root.iter('print'): + mol_filename = graph.attrib['file'] + mol_class = graph.attrib['class'] + g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename) + data.append(g) + self._append_label_names(label_names, l_names) + y.append(mol_class) + + data = [] + y = [] + label_names = { + 'node_labels': [], + 'edge_labels': [], + 'node_attrs': [], + 'edge_attrs': [] + } + + if isinstance(filename, str): + if dir_dataset is not None: + dir_dataset = dir_dataset + else: + dir_dataset = dirname(filename) + load_one_file(filename, data, y, label_names) + + else: # filename is a list of files. + if dir_dataset is not None: + dir_dataset = dir_dataset + else: + dir_dataset = dirname(filename[0]) + + for fn in filename: + load_one_file(fn, data, y, label_names) + + return data, y, label_names + + # @todo: need to be updated (auto order) or deprecated. + def load_mat(self, filename, order): + """Load graph data from a MATLAB (up to version 7.1) .mat file. + + Notes: + ------ + A MAT file contains a struct array containing graphs, and a column vector lx containing a class label for each graph. + Check README in `downloadable file `__ for detailed structure. 
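+        The `order` argument locates the struct fields: `order[0]` indexes
+        the sparse adjacency matrix, `order[1] == 0` marks datasets that
+        carry edge labels, and `order[3]`/`order[4]` index the node and
+        edge labels.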
+ """ + import networkx as nx + import numpy as np + from scipy.io import loadmat + data = [] + content = loadmat(filename) + for key, value in content.items(): + if key[0] == 'l': # class label + y = np.transpose(value)[0].tolist() + elif key[0] != '_': + # if adjacency matrix is not compressed / edge label exists + if order[1] == 0: + for i, item in enumerate(value[0]): + g = nx.Graph(name=i) # set name of the graph + nl = np.transpose( + item[order[3]][0][0][0]) # node label + for index, label in enumerate(nl[0]): + g.add_node(index, label_1=str(label)) + el = item[order[4]][0][0][0] # edge label + for edge in el: + g.add_edge(edge[0] - 1, edge[1] - 1, + label_1=str(edge[2])) + data.append(g) + else: + for i, item in enumerate(value[0]): + g = nx.Graph(name=i) # set name of the graph + nl = np.transpose( + item[order[3]][0][0][0]) # node label + for index, label in enumerate(nl[0]): + g.add_node(index, label_1=str(label)) + sam = item[order[0]] # sparse adjacency matrix + index_no0 = sam.nonzero() + for col, row in zip(index_no0[0], index_no0[1]): + g.add_edge(col, row) + data.append(g) + + label_names = { + 'node_labels': ['label_1'], + 'edge_labels': [], + 'node_attrs': [], + 'edge_attrs': [] + } + if order[1] == 0: + label_names['edge_labels'].append('label_1') + + return data, y, label_names + + def load_tud(self, filename): + """Load graph data from TUD dataset files. + + Notes: + ------ + The graph data is loaded from separate files. + Check README in `downloadable file `__, 2018 for detailed structure. + """ + from os import listdir + from os.path import basename, dirname + + import networkx as nx + + # @todo: add README (cuniform), maybe node/edge label maps. + def get_infos_from_readme(frm): + """Get information from DS_label_readme.txt file. + """ + def get_label_names_from_line(line): + """Get names of labels/attributes from a line. + """ + str_names = line.split('[')[1].split(']')[0] + names = str_names.split(',') + names = [attr.strip() for attr in names] + return names + + def get_class_label_map(label_map_strings): + label_map = {} + for string in label_map_strings: + integer, label = string.split('\t') + label_map[int(integer.strip())] = label.strip() + return label_map + + label_names = { + 'node_labels': [], + 'node_attrs': [], + 'edge_labels': [], + 'edge_attrs': [] + } + class_label_map = None + class_label_map_strings = [] + with open(frm) as rm: + content_rm = rm.read().splitlines() + i = 0 + while i < len(content_rm): + line = content_rm[i].strip() + # get node/edge labels and attributes. + if line.startswith('Node labels:'): + label_names['node_labels'] = get_label_names_from_line( + line) + elif line.startswith('Node attributes:'): + label_names['node_attrs'] = get_label_names_from_line(line) + elif line.startswith('Edge labels:'): + label_names['edge_labels'] = get_label_names_from_line( + line) + elif line.startswith('Edge attributes:'): + label_names['edge_attrs'] = get_label_names_from_line(line) + # get class label map. + elif line.startswith( + 'Class labels were converted to integer values using this map:' + ): + i += 2 + line = content_rm[i].strip() + while line != '' and i < len(content_rm): + class_label_map_strings.append(line) + i += 1 + line = content_rm[i].strip() + class_label_map = get_class_label_map( + class_label_map_strings) + i += 1 + + return label_names, class_label_map + + # get dataset name. 
+ dirname_dataset = dirname(filename) + filename = basename(filename) + fn_split = filename.split('_A') + ds_name = fn_split[0].strip() + + # load data file names + for name in listdir(dirname_dataset): + if ds_name + '_A' in name: + fam = dirname_dataset + '/' + name + elif ds_name + '_graph_indicator' in name: + fgi = dirname_dataset + '/' + name + elif ds_name + '_graph_labels' in name: + fgl = dirname_dataset + '/' + name + elif ds_name + '_node_labels' in name: + fnl = dirname_dataset + '/' + name + elif ds_name + '_edge_labels' in name: + fel = dirname_dataset + '/' + name + elif ds_name + '_edge_attributes' in name: + fea = dirname_dataset + '/' + name + elif ds_name + '_node_attributes' in name: + fna = dirname_dataset + '/' + name + elif ds_name + '_graph_attributes' in name: + fga = dirname_dataset + '/' + name + elif ds_name + '_label_readme' in name: + frm = dirname_dataset + '/' + name + # this is supposed to be the node attrs, make sure to put this as the last 'elif' + elif ds_name + '_attributes' in name: + fna = dirname_dataset + '/' + name + + # get labels and attributes names. + if 'frm' in locals(): + label_names, class_label_map = get_infos_from_readme(frm) + else: + label_names = { + 'node_labels': [], + 'node_attrs': [], + 'edge_labels': [], + 'edge_attrs': [] + } + class_label_map = None + + with open(fgi) as gi: + content_gi = gi.read().splitlines() # graph indicator + with open(fam) as am: + content_am = am.read().splitlines() # adjacency matrix + + # load targets. + if 'fgl' in locals(): + with open(fgl) as gl: + content_targets = gl.read().splitlines( + ) # targets (classification) + targets = [float(i) for i in content_targets] + elif 'fga' in locals(): + with open(fga) as ga: + content_targets = ga.read().splitlines( + ) # targets (regression) + targets = [int(i) for i in content_targets] + else: + exp_msg = 'Can not find targets file. Please make sure there is a "', ds_name, '_graph_labels.txt" or "', ds_name, '_graph_attributes.txt"', 'file in your dataset folder.' + raise Exception(exp_msg) + if class_label_map is not None: + targets = [class_label_map[t] for t in targets] + + # create graphs and add nodes + data = [nx.Graph(name=str(i)) for i in range(0, len(content_targets))] + if 'fnl' in locals(): + with open(fnl) as nl: + content_nl = nl.read().splitlines() # node labels + for idx, line in enumerate(content_gi): + # transfer to int first in case of unexpected blanks + data[int(line) - 1].add_node(idx) + labels = [l.strip() for l in content_nl[idx].split(',')] + if label_names['node_labels'] == []: # @todo: need fix bug. + for i, label in enumerate(labels): + l_name = 'label_' + str(i) + data[int(line) - 1].nodes[idx][l_name] = label + label_names['node_labels'].append(l_name) + else: + for i, l_name in enumerate(label_names['node_labels']): + data[int(line) - 1].nodes[idx][l_name] = labels[i] + else: + for i, line in enumerate(content_gi): + data[int(line) - 1].add_node(i) + + # add edges + for line in content_am: + tmp = line.split(',') + n1 = int(tmp[0]) - 1 + n2 = int(tmp[1]) - 1 + # ignore edge weight here. 
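+            # Both endpoints of an edge belong to the same graph, so the
+            # target graph is read off the graph indicator of the first node.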
+ g = int(content_gi[n1]) - 1 + data[g].add_edge(n1, n2) + + # add edge labels + if 'fel' in locals(): + with open(fel) as el: + content_el = el.read().splitlines() + for idx, line in enumerate(content_el): + labels = [l.strip() for l in line.split(',')] + n = [int(i) - 1 for i in content_am[idx].split(',')] + g = int(content_gi[n[0]]) - 1 + if label_names['edge_labels'] == []: + for i, label in enumerate(labels): + l_name = 'label_' + str(i) + data[g].edges[n[0], n[1]][l_name] = label + label_names['edge_labels'].append(l_name) + else: + for i, l_name in enumerate(label_names['edge_labels']): + data[g].edges[n[0], n[1]][l_name] = labels[i] + + # add node attributes + if 'fna' in locals(): + with open(fna) as na: + content_na = na.read().splitlines() + for idx, line in enumerate(content_na): + attrs = [a.strip() for a in line.split(',')] + g = int(content_gi[idx]) - 1 + if label_names['node_attrs'] == []: + for i, attr in enumerate(attrs): + a_name = 'attr_' + str(i) + data[g].nodes[idx][a_name] = attr + label_names['node_attrs'].append(a_name) + else: + for i, a_name in enumerate(label_names['node_attrs']): + data[g].nodes[idx][a_name] = attrs[i] + + # add edge attributes + if 'fea' in locals(): + with open(fea) as ea: + content_ea = ea.read().splitlines() + for idx, line in enumerate(content_ea): + attrs = [a.strip() for a in line.split(',')] + n = [int(i) - 1 for i in content_am[idx].split(',')] + g = int(content_gi[n[0]]) - 1 + if label_names['edge_attrs'] == []: + for i, attr in enumerate(attrs): + a_name = 'attr_' + str(i) + data[g].edges[n[0], n[1]][a_name] = attr + label_names['edge_attrs'].append(a_name) + else: + for i, a_name in enumerate(label_names['edge_attrs']): + data[g].edges[n[0], n[1]][a_name] = attrs[i] + + return data, targets, label_names + + def load_ct( + self, filename + ): # @todo: this function is only tested on CTFile V2000; header not considered; only simple cases (atoms and bonds are considered.) + """Load data from a Chemical Table (.ct) file. + + Notes: + ------ + a typical example of data in .ct is like this: + + 3 2 <- number of nodes and edges + + 0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label) + + 0.0000 0.0000 0.0000 C + + 0.0000 0.0000 0.0000 O + + 1 3 1 1 <- each line describes an edge : to, from, bond type, bond stereo + + 2 3 1 1 + + Check `CTFile Formats file `__ + for detailed format discription. + """ + from os.path import basename + + import networkx as nx + g = nx.Graph() + with open(filename) as f: + content = f.read().splitlines() + g = nx.Graph(name=str(content[0]), + filename=basename(filename)) # set name of the graph + + # read the counts line. + tmp = content[1].split(' ') + tmp = [x for x in tmp if x != ''] + nb_atoms = int(tmp[0].strip()) # number of atoms + nb_bonds = int(tmp[1].strip()) # number of bonds + count_line_tags = [ + 'number_of_atoms', 'number_of_bonds', 'number_of_atom_lists', + '', 'chiral_flag', 'number_of_stext_entries', '', '', '', '', + 'number_of_properties', 'CT_version' + ] + i = 0 + while i < len(tmp): + if count_line_tags[i] != '': # if not obsoleted + g.graph[count_line_tags[i]] = tmp[i].strip() + i += 1 + + # read the atom block. 
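+        # Each of the next `nb_atoms` lines is `x y z atom_symbol [...]`,
+        # parsed positionally against `atom_tags` below.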
+ atom_tags = [ + 'x', 'y', 'z', 'atom_symbol', 'mass_difference', 'charge', + 'atom_stereo_parity', 'hydrogen_count_plus_1', + 'stereo_care_box', 'valence', 'h0_designator', '', '', + 'atom_atom_mapping_number', 'inversion_retention_flag', + 'exact_change_flag' + ] + for i in range(0, nb_atoms): + tmp = content[i + 2].split(' ') + tmp = [x for x in tmp if x != ''] + g.add_node(i) + j = 0 + while j < len(tmp): + if atom_tags[j] != '': + g.nodes[i][atom_tags[j]] = tmp[j].strip() + j += 1 + + # read the bond block. + bond_tags = [ + 'first_atom_number', 'second_atom_number', 'bond_type', + 'bond_stereo', '', 'bond_topology', 'reacting_center_status' + ] + for i in range(0, nb_bonds): + tmp = content[i + g.number_of_nodes() + 2].split(' ') + tmp = [x for x in tmp if x != ''] + n1, n2 = int(tmp[0].strip()) - 1, int(tmp[1].strip()) - 1 + g.add_edge(n1, n2) + j = 2 + while j < len(tmp): + if bond_tags[j] != '': + g.edges[(n1, n2)][bond_tags[j]] = tmp[j].strip() + j += 1 + + # get label names. + label_names = { + 'node_labels': [], + 'edge_labels': [], + 'node_attrs': [], + 'edge_attrs': [] + } + atom_symbolic = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, None, None, 1, 1, 1] + for nd in g.nodes(): + for key in g.nodes[nd]: + if atom_symbolic[atom_tags.index(key)] == 1: + label_names['node_labels'].append(key) + else: + label_names['node_attrs'].append(key) + break + bond_symbolic = [None, None, 1, 1, None, 1, 1] + for ed in g.edges(): + for key in g.edges[ed]: + if bond_symbolic[bond_tags.index(key)] == 1: + label_names['edge_labels'].append(key) + else: + label_names['edge_attrs'].append(key) + break + + return g, label_names + + def load_gxl(self, filename): # @todo: directed graphs. + import xml.etree.ElementTree as ET + from os.path import basename + + import networkx as nx + + tree = ET.parse(filename) + root = tree.getroot() + index = 0 + g = nx.Graph(filename=basename(filename), name=root[0].attrib['id']) + dic = {} # used to retrieve incident nodes of edges + for node in root.iter('node'): + dic[node.attrib['id']] = index + labels = {} + # for datasets "GREC" and "Monoterpens". + for attr in node.iter('attr'): + labels[attr.attrib['name']] = attr[0].text + for attr in node.iter('attribute'): # for dataset "Web". + labels[attr.attrib['name']] = attr.attrib['value'] + g.add_node(index, **labels) + index += 1 + + for edge in root.iter('edge'): + labels = {} + # for datasets "GREC" and "Monoterpens". + for attr in edge.iter('attr'): + labels[attr.attrib['name']] = attr[0].text + for attr in edge.iter('attribute'): # for dataset "Web". + labels[attr.attrib['name']] = attr.attrib['value'] + g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], + **labels) + + # get label names. + label_names = { + 'node_labels': [], + 'edge_labels': [], + 'node_attrs': [], + 'edge_attrs': [] + } + # @todo: possible loss of label names if some nodes miss some labels. + for node in root.iter('node'): + # for datasets "GREC" and "Monoterpens". + for attr in node.iter('attr'): + # @todo: this maybe wrong, and slow. "type" is for dataset GREC; "int" is for dataset "Monoterpens". + if attr[0].tag == 'int' or attr.attrib['name'] == 'type': + label_names['node_labels'].append(attr.attrib['name']) + else: + label_names['node_attrs'].append(attr.attrib['name']) + + for attr in node.iter('attribute'): # for dataset "Web". + label_names['node_attrs'].append(attr.attrib['name']) + # @todo: is id useful in dataset "Web"? is "FREQUENCY" symbolic or not? 
+ break + + for edge in root.iter('edge'): + # for datasets "GREC" and "Monoterpens". + for attr in edge.iter('attr'): + # @todo: this maybe wrong, and slow. "frequency" and "type" are for dataset GREC; "int" is for dataset "Monoterpens". + if attr[0].tag == 'int' or attr.attrib[ + 'name'] == 'frequency' or 'type' in attr.attrib['name']: + label_names['edge_labels'].append(attr.attrib['name']) + else: + label_names['edge_attrs'].append(attr.attrib['name']) + + for attr in edge.iter('attribute'): # for dataset "Web". + label_names['edge_attrs'].append(attr.attrib['name']) + break + + return g, label_names + + def load_cml(self, filename): # @todo: directed graphs. + # @todo: what is "atomParity" and "bondStereo" in the data file? + import xml.etree.ElementTree as ET + from os.path import basename + + import networkx as nx + + # @todo: why this has to be added? + xmlns = '{http://www.xml-cml.org/schema}' + tree = ET.parse(filename) + root = tree.getroot() + index = 0 + if root.tag == xmlns + 'molecule': + g_id = root.attrib['id'] + else: + g_id = root.find(xmlns + 'molecule').attrib['id'] + g = nx.Graph(filename=basename(filename), name=g_id) + dic = {} # used to retrieve incident nodes of edges + for atom in root.iter(xmlns + 'atom'): + dic[atom.attrib['id']] = index + labels = {} + for key, val in atom.attrib.items(): + if key != 'id': + labels[key] = val + g.add_node(index, **labels) + index += 1 + + for bond in root.iter(xmlns + 'bond'): + labels = {} + for key, val in bond.attrib.items(): + # "id" is in dataset "ACE". + if key != 'atomRefs2' and key != 'id': + labels[key] = val + n1, n2 = bond.attrib['atomRefs2'].strip().split(' ') + g.add_edge(dic[n1], dic[n2], **labels) + + # get label names. + label_names = { + 'node_labels': [], + 'edge_labels': [], + 'node_attrs': [], + 'edge_attrs': [] + } + # @todo: possible loss of label names if some nodes miss some labels. + for key, val in g.nodes[0].items(): + try: + float(val) + except: + label_names['node_labels'].append(key) + else: + if val.isdigit(): + label_names['node_labels'].append(key) + else: + label_names['node_attrs'].append(key) + for _, _, attrs in g.edges(data=True): + for key, val in attrs.items(): + try: + float(val) + except: + label_names['edge_labels'].append(key) + else: + if val.isdigit(): + label_names['edge_labels'].append(key) + else: + label_names['edge_attrs'].append(key) + break + + return g, label_names + + def _append_label_names(self, label_names, new_names): + for key, val in label_names.items(): + label_names[key] += [ + name for name in new_names[key] if name not in val + ] + + @property + def data(self): + return self._graphs, self._targets, self._label_names + + @property + def graphs(self): + return self._graphs + + @property + def targets(self): + return self._targets + + @property + def label_names(self): + return self._label_names + + +class DataSaver(): + def __init__(self, graphs, targets=None, filename='gfile', gformat='gxl', + group=None, **kwargs): + """Save list of graphs. 
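+        `graphs` is a list of NetworkX graphs; `gformat` selects the output
+        format (only 'gxl' is written here), and `group='xml'` additionally
+        emits an .xml collection file that references the saved .gxl files.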
+ """ + import os + dirname_ds = os.path.dirname(filename) + if dirname_ds != '': + dirname_ds += '/' + os.makedirs(dirname_ds, exist_ok=True) + + if 'graph_dir' in kwargs: + graph_dir = kwargs['graph_dir'] + '/' + os.makedirs(graph_dir, exist_ok=True) + del kwargs['graph_dir'] + else: + graph_dir = dirname_ds + + if group == 'xml' and gformat == 'gxl': + with open(filename + '.xml', 'w') as fgroup: + fgroup.write("") + fgroup.write( + "\n" + ) + fgroup.write("\n") + for idx, g in enumerate(graphs): + fname_tmp = "graph" + str(idx) + ".gxl" + self.save_gxl(g, graph_dir + fname_tmp, **kwargs) + fgroup.write("\n\t") + fgroup.write("\n") + fgroup.close() + + def save_gxl(self, graph, filename, method='default', node_labels=[], + edge_labels=[], node_attrs=[], edge_attrs=[]): + if method == 'default': + gxl_file = open(filename, 'w') + gxl_file.write("\n") + gxl_file.write( + "\n" + ) + gxl_file.write( + "\n") + if 'name' in graph.graph: + name = str(graph.graph['name']) + else: + name = 'dummy' + gxl_file.write("\n") + for v, attrs in graph.nodes(data=True): + gxl_file.write("") + for l_name in node_labels: + gxl_file.write("" + + str(attrs[l_name]) + "") + for a_name in node_attrs: + gxl_file.write("" + + str(attrs[a_name]) + "") + gxl_file.write("\n") + for v1, v2, attrs in graph.edges(data=True): + gxl_file.write("") + for l_name in edge_labels: + gxl_file.write("" + + str(attrs[l_name]) + "") + for a_name in edge_attrs: + gxl_file.write("" + + str(attrs[a_name]) + "") + gxl_file.write("\n") + gxl_file.write("\n") + gxl_file.write("") + gxl_file.close() + elif method == 'benoit': + import xml.etree.ElementTree as ET + root_node = ET.Element('gxl') + attr = dict() + attr['id'] = str(graph.graph['name']) + attr['edgeids'] = 'true' + attr['edgemode'] = 'undirected' + graph_node = ET.SubElement(root_node, 'graph', attrib=attr) + + for v in graph: + current_node = ET.SubElement(graph_node, 'node', + attrib={'id': str(v)}) + for attr in graph.nodes[v].keys(): + cur_attr = ET.SubElement(current_node, 'attr', + attrib={'name': attr}) + cur_value = ET.SubElement( + cur_attr, graph.nodes[v][attr].__class__.__name__) + cur_value.text = graph.nodes[v][attr] + + for v1 in graph: + for v2 in graph[v1]: + if (v1 < v2): # Non oriented graphs + cur_edge = ET.SubElement( + graph_node, 'edge', attrib={ + 'from': str(v1), + 'to': str(v2) + }) + for attr in graph[v1][v2].keys(): + cur_attr = ET.SubElement(cur_edge, 'attr', + attrib={'name': attr}) + cur_value = ET.SubElement( + cur_attr, + graph[v1][v2][attr].__class__.__name__) + cur_value.text = str(graph[v1][v2][attr]) + + tree = ET.ElementTree(root_node) + tree.write(filename) + elif method == 'gedlib': + # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22 + # pass + gxl_file = open(filename, 'w') + gxl_file.write("\n") + gxl_file.write( + "\n" + ) + gxl_file.write( + "\n") + gxl_file.write("\n") + for v, attrs in graph.nodes(data=True): + gxl_file.write("") + gxl_file.write("" + + str(attrs['chem']) + "") + gxl_file.write("\n") + for v1, v2, attrs in graph.edges(data=True): + gxl_file.write("") + gxl_file.write("" + + str(attrs['valence']) + "") + # gxl_file.write("" + "1" + "") + gxl_file.write("\n") + gxl_file.write("\n") + gxl_file.write("") + gxl_file.close() + elif method == 'gedlib-letter': + # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22 + # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl + gxl_file = 
open(filename, 'w') + gxl_file.write("\n") + gxl_file.write( + "\n" + ) + gxl_file.write( + "\n") + gxl_file.write("\n") + for v, attrs in graph.nodes(data=True): + gxl_file.write("") + gxl_file.write("" + + str(attrs['attributes'][0]) + "") + gxl_file.write("" + + str(attrs['attributes'][1]) + "") + gxl_file.write("\n") + for v1, v2, attrs in graph.edges(data=True): + gxl_file.write("\n") + gxl_file.write("\n") + gxl_file.write("") + gxl_file.close() diff --git a/torch_geometric/io/greyc.py b/torch_geometric/io/greyc.py new file mode 100644 index 000000000000..246311d5f367 --- /dev/null +++ b/torch_geometric/io/greyc.py @@ -0,0 +1,155 @@ +"""Module to load greyc datasets as list of networkx graphs +""" + +import os +from typing import List, Union + +from torch_geometric.io.file_managers import DataLoader + +PATH = os.path.dirname(__file__) + + +def one_hot_encode(val: Union[int, str], allowable_set: Union[List[str], + List[int]], + include_unknown_set: bool = False) -> List[float]: + """One hot encoder for elements of a provided set. + + Examples: + -------- + >>> one_hot_encode("a", ["a", "b", "c"]) + [1.0, 0.0, 0.0] + >>> one_hot_encode(2, [0, 1, 2]) + [0.0, 0.0, 1.0] + >>> one_hot_encode(3, [0, 1, 2]) + [0.0, 0.0, 0.0] + >>> one_hot_encode(3, [0, 1, 2], True) + [0.0, 0.0, 0.0, 1.0] + + Parameters + ---------- + val: int or str + The value must be present in `allowable_set`. + allowable_set: List[int] or List[str] + List of allowable quantities. + include_unknown_set: bool, default False + If true, the index of all values not in `allowable_set` is `len(allowable_set)`. + + Returns: + ------- + List[float] + An one-hot vector of val. + If `include_unknown_set` is False, the length is `len(allowable_set)`. + If `include_unknown_set` is True, the length is `len(allowable_set) + 1`. + + Raises: + ------ + ValueError + If include_unknown_set is False and `val` is not in `allowable_set`. + """ + if include_unknown_set is False: + if val not in allowable_set: + logger.info("input {} not in allowable set {}:".format( + val, allowable_set)) + + # init an one-hot vector + if include_unknown_set is False: + one_hot_legnth = len(allowable_set) + else: + one_hot_legnth = len(allowable_set) + 1 + one_hot = [0.0 for _ in range(one_hot_legnth)] + + try: + one_hot[allowable_set.index(val)] = 1.0 # type: ignore + except: + if include_unknown_set: + # If include_unknown_set is True, set the last index is 1. + one_hot[-1] = 1.0 + else: + pass + return one_hot + + +def prepare_graph( + graph, atom_list=['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'H']): + """Prepare graph to include all data before pyg conversion + Parameters + ---------- + graph : networkx graph + graph to be prepared + atom_list : list[str], optional + List of node attributes to be converted into one hot encoding + """ + # Convert attributes. + graph.graph = {} + for node in graph.nodes: + graph.nodes[node]['atom_symbol'] = one_hot_encode( + graph.nodes[node]['atom_symbol'], + atom_list, + include_unknown_set=True, + ) + graph.nodes[node]['degree'] = float(graph.degree[node]) + for attr in ['x', 'y', 'z']: + graph.nodes[node][attr] = float(graph.nodes[node][attr]) + + for edge in graph.edges: + for attr in ['bond_type', 'bond_stereo']: + graph.edges[edge][attr] = float(graph.edges[edge][attr]) + + return graph + + +def _load_greyc_networkx_graphs(dir: str, name: str): + """Load the dataset as a llist of networkx graphs and + returns list of graphs and list of properties + + Args: + dataset:str the dataset to load (Alkane,Acyclic,...) 
+
+    Returns:
+        list of nx graphs
+        list of properties (float or int)
+    """
+    loaders = {
+        "alkane": _loader_alkane,
+        "acyclic": _loader_acyclic,
+        "mao": _loader_mao
+    }
+    loader_f = loaders.get(name, None)
+    if loader_f is None:
+        raise Exception(f"Dataset `{name}` not found")
+    loader = loader_f(dir)
+
+    graphs = [prepare_graph(graph) for graph in loader.graphs]
+    return graphs, loader.targets
+
+
+def read_greyc(dir: str, name: str):
+    return _load_greyc_networkx_graphs(dir, name)
+
+
+def _loader_alkane(dir: str):
+    """Load the 150 graphs of the Alkane dataset and return a DataLoader
+    exposing the graphs and their boiling points.
+    """
+    dloader = DataLoader(
+        os.path.join(dir, 'dataset.ds'),
+        filename_targets=os.path.join(dir, 'dataset_boiling_point_names.txt'),
+        dformat='ds', gformat='ct', y_separator=' ')
+    return dloader
+
+
+def _loader_acyclic(dir: str):
+    dloader = DataLoader(os.path.join(dir,
+                                      'dataset_bps.ds'), filename_targets=None,
+                         dformat='ds', gformat='ct', y_separator=' ')
+    return dloader
+
+
+def _loader_mao(dir: str):
+    dloader = DataLoader(os.path.join(dir,
+                                      'dataset.ds'), filename_targets=None,
+                         dformat='ds', gformat='ct', y_separator=' ')
+    dloader._targets = [int(yi) for yi in dloader.targets]
+    return dloader

From 4e0a79c16b11e2917e9e7e2633c301aadc5798cf Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Mon, 20 Jan 2025 14:10:20 +0100
Subject: [PATCH 04/29] temp

---
 ...pre-commit-config.yaml => .pre-commit-config.yaml | 0
 torch_geometric/io/file_managers.py                  | 9 +++++++--
 2 files changed, 7 insertions(+), 2 deletions(-)
 rename ...pre-commit-config.yaml => .pre-commit-config.yaml (100%)

diff --git a/...pre-commit-config.yaml b/.pre-commit-config.yaml
similarity index 100%
rename from ...pre-commit-config.yaml
rename to .pre-commit-config.yaml

diff --git a/torch_geometric/io/file_managers.py b/torch_geometric/io/file_managers.py
index 53f08851674d..b1449fc8858f 100644
--- a/torch_geometric/io/file_managers.py
+++ b/torch_geometric/io/file_managers.py
@@ -5,8 +5,13 @@


 class DataLoader():
-    def __init__(self, filename, filename_targets=None, gformat=None,
-                 **kwargs):
+    def __init__(
+        self,
+        filename,
+        filename_targets=None,
+        gformat=None,
+        **kwargs
+    ) -> None:
         """Read graph data from filename and load them as NetworkX graphs.
Parameters From 4e0a79c16b11e2917e9e7e2633c301aadc5798cf Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Mon, 20 Jan 2025 14:46:36 +0100 Subject: [PATCH 05/29] download ok --- torch_geometric/datasets/greyc.py | 32 +++++++++++-------- torch_geometric/io/greyc.py | 51 ++++++++++++++++++++----------- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index f585ac0f78f8..9055b2c1ad39 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -1,8 +1,14 @@ +import os from typing import Callable, List, Optional import torch -from torch_geometric.data import InMemoryDataset, download_url +from torch_geometric.data import ( + Data, + InMemoryDataset, + download_url, + extract_zip, +) from torch_geometric.io import read_greyc from torch_geometric.utils import from_networkx @@ -14,8 +20,7 @@ class DatasetNotFoundError(Exception): class GreycDataset(InMemoryDataset): r"""Class to load three GREYC Datasets as pytorch geometric dataset.""" - URL = ('https://raw.githubusercontent.com/bgauzere/greycdata/refs/' - 'heads/main/greycdata/data/') + URL = ('http://localhost:3000/') def __init__( self, @@ -44,21 +49,22 @@ def raw_file_names(self) -> List[str]: def download(self) -> None: """Load the right data according to initializer.""" - if self.name == 'alkane': - download_url(GreycDataset.URL + "Aklane", self.raw_dir) - elif self.name == 'acyclic': - download_url(GreycDataset.URL + "Acyclic", self.raw_dir) - elif self.name == 'mao': - download_url(GreycDataset.URL + "MAO", self.raw_dir) - else: - raise DatasetNotFoundError(f"Dataset `{self.name}` not found") + zips = { + "alkane": "Aklane.zip", + "acyclic": "Acyclic.zip", + "mao": "MAO.zip", + } + file = zips.get(self.name, None) + if file is None: + raise Exception("Wrong dataset name") + path = download_url(GreycDataset.URL + file, self.raw_dir) + extract_zip(path, self.raw_dir) + os.unlink(path) def process(self): """Read data into huge `Data` list.""" graph_list, property_list = read_greyc(self.raw_dir, self.name) - # Convert to PyG. - def from_nx_to_pyg(graph, y): """Convert networkx graph to pytorch graph and add y.""" pyg_graph = from_networkx( diff --git a/torch_geometric/io/greyc.py b/torch_geometric/io/greyc.py index 246311d5f367..bfdaa2d828a9 100644 --- a/torch_geometric/io/greyc.py +++ b/torch_geometric/io/greyc.py @@ -3,15 +3,18 @@ import os from typing import List, Union +import warnings from torch_geometric.io.file_managers import DataLoader PATH = os.path.dirname(__file__) -def one_hot_encode(val: Union[int, str], allowable_set: Union[List[str], - List[int]], - include_unknown_set: bool = False) -> List[float]: +def one_hot_encode( + val: Union[int, str], + allowable_set: Union[List[str], List[int]], + include_unknown_set: bool = False + ) -> List[float]: """One hot encoder for elements of a provided set. Examples: @@ -32,27 +35,32 @@ def one_hot_encode(val: Union[int, str], allowable_set: Union[List[str], allowable_set: List[int] or List[str] List of allowable quantities. include_unknown_set: bool, default False - If true, the index of all values not in `allowable_set` is `len(allowable_set)`. + If true, the index of all values not + in `allowable_set` is `len(allowable_set)`. Returns: ------- List[float] An one-hot vector of val. - If `include_unknown_set` is False, the length is `len(allowable_set)`. - If `include_unknown_set` is True, the length is `len(allowable_set) + 1`. 
+ If `include_unknown_set` is False, + the length is `len(allowable_set)`. + If `include_unknown_set` is True, + the length is `len(allowable_set) + 1`. Raises: ------ ValueError If include_unknown_set is False and `val` is not in `allowable_set`. """ - if include_unknown_set is False: + if not include_unknown_set: if val not in allowable_set: - logger.info("input {} not in allowable set {}:".format( - val, allowable_set)) + warnings.warn( + f"input {val} not in allowable set {allowable_set}.", + UserWarning + ) # init an one-hot vector - if include_unknown_set is False: + if not include_unknown_set: one_hot_legnth = len(allowable_set) else: one_hot_legnth = len(allowable_set) + 1 @@ -70,7 +78,8 @@ def one_hot_encode(val: Union[int, str], allowable_set: Union[List[str], def prepare_graph( - graph, atom_list=['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'H']): + graph, atom_list=['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'H'] + ): """Prepare graph to include all data before pyg conversion Parameters ---------- @@ -141,15 +150,23 @@ def _loader_alkane(dir: str): def _loader_acyclic(dir: str): - dloader = DataLoader(os.path.join(dir, - 'dataset_bps.ds'), filename_targets=None, - dformat='ds', gformat='ct', y_separator=' ') + dloader = DataLoader( + os.path.join(dir, 'dataset_bps.ds'), + filename_targets=None, + dformat='ds', + gformat='ct', + y_separator=' ' + ) return dloader def _loader_mao(dir: str): - dloader = DataLoader(os.path.join(dir, - 'dataset.ds'), filename_targets=None, - dformat='ds', gformat='ct', y_separator=' ') + dloader = DataLoader( + os.path.join(dir, 'dataset.ds'), + filename_targets=None, + dformat='ds', + gformat='ct', + y_separator=' ' + ) dloader._targets = [int(yi) for yi in dloader.targets] return dloader From 65dca72bf6ceb037d9ebe62216bc901ed2fd72e1 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Mon, 20 Jan 2025 15:03:47 +0100 Subject: [PATCH 06/29] load --- torch_geometric/datasets/greyc.py | 40 +++++-------------------------- 1 file changed, 6 insertions(+), 34 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 9055b2c1ad39..34f2d62123e8 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -3,14 +3,7 @@ import torch -from torch_geometric.data import ( - Data, - InMemoryDataset, - download_url, - extract_zip, -) -from torch_geometric.io import read_greyc -from torch_geometric.utils import from_networkx +from torch_geometric.data import InMemoryDataset, download_url, extract_zip class DatasetNotFoundError(Exception): @@ -49,40 +42,19 @@ def raw_file_names(self) -> List[str]: def download(self) -> None: """Load the right data according to initializer.""" - zips = { - "alkane": "Aklane.zip", - "acyclic": "Acyclic.zip", - "mao": "MAO.zip", - } - file = zips.get(self.name, None) - if file is None: - raise Exception("Wrong dataset name") - path = download_url(GreycDataset.URL + file, self.raw_dir) + path = download_url(GreycDataset.URL + self.name + ".zip", self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self): """Read data into huge `Data` list.""" - graph_list, property_list = read_greyc(self.raw_dir, self.name) - - def from_nx_to_pyg(graph, y): - """Convert networkx graph to pytorch graph and add y.""" - pyg_graph = from_networkx( - graph, - group_node_attrs=['atom_symbol', 'degree', 'x', 'y', 'z']) - pyg_graph.y = y - return pyg_graph - - data_list = [ - from_nx_to_pyg(graph, y) - for graph, y in zip(graph_list, property_list) - ] + dataset = 
torch.load(self.raw_dir + self.name + ".pth") if self.pre_filter is not None: - data_list = [data for data in data_list if self.pre_filter(data)] + dataset = [data for data in dataset if self.pre_filter(data)] if self.pre_transform is not None: - data_list = [self.pre_transform(data) for data in data_list] + dataset = [self.pre_transform(data) for data in dataset] - data, slices = self.collate(data_list) + data, slices = self.collate(dataset) torch.save((data, slices), self.processed_paths[0]) From 73a22b77bbad58b1e3c698e330ea7b1385fff100 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Mon, 20 Jan 2025 15:39:26 +0100 Subject: [PATCH 07/29] GreycDataset fonctionnel --- torch_geometric/datasets/greyc.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 34f2d62123e8..0e4b1b3a685e 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -11,7 +11,7 @@ class DatasetNotFoundError(Exception): class GreycDataset(InMemoryDataset): - r"""Class to load three GREYC Datasets as pytorch geometric dataset.""" + """Class to load three GREYC Datasets as pytorch geometric dataset.""" URL = ('http://localhost:3000/') @@ -25,20 +25,22 @@ def __init__( force_reload: bool = False, ) -> None: self.name = name.lower() - super().__init__(root, transform, pre_transform, pre_filter, - force_reload=force_reload) + super().__init__( + root, + transform, + pre_transform, + pre_filter, + force_reload=force_reload + ) self.data, self.slices = torch.load(self.processed_paths[0]) - def __str__(self) -> str: - return self.name - @property def processed_file_names(self) -> str: return 'data.pt' @property - def raw_file_names(self) -> List[str]: - return [] + def raw_file_names(self) -> str: + return f"{self.name.lower()}.pth" def download(self) -> None: """Load the right data according to initializer.""" @@ -48,13 +50,13 @@ def download(self) -> None: def process(self): """Read data into huge `Data` list.""" - dataset = torch.load(self.raw_dir + self.name + ".pth") + data_list = torch.load(os.path.join(self.raw_dir, self.name + ".pth")) if self.pre_filter is not None: - dataset = [data for data in dataset if self.pre_filter(data)] + data_list = [data for data in data_list if self.pre_filter(data)] if self.pre_transform is not None: - dataset = [self.pre_transform(data) for data in dataset] + data_list = [self.pre_transform(data) for data in data_list] - data, slices = self.collate(dataset) + data, slices = self.collate(data_list) torch.save((data, slices), self.processed_paths[0]) From 6fa53833d9aee91e00c972e7b36382018d0b8b9a Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Mon, 20 Jan 2025 15:42:26 +0100 Subject: [PATCH 08/29] GreycDataset fonctionnel --- torch_geometric/datasets/greyc.py | 14 +- torch_geometric/io/file_managers.py | 900 ---------------------------- torch_geometric/io/greyc.py | 172 ------ 3 files changed, 5 insertions(+), 1081 deletions(-) delete mode 100644 torch_geometric/io/file_managers.py delete mode 100644 torch_geometric/io/greyc.py diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 0e4b1b3a685e..1fee339b7194 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -1,5 +1,5 @@ import os -from typing import Callable, List, Optional +from typing import Callable, Optional import torch @@ -25,13 +25,8 @@ def __init__( force_reload: bool = False, ) -> None: self.name = 
name.lower() - super().__init__( - root, - transform, - pre_transform, - pre_filter, - force_reload=force_reload - ) + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) self.data, self.slices = torch.load(self.processed_paths[0]) @property @@ -44,7 +39,8 @@ def raw_file_names(self) -> str: def download(self) -> None: """Load the right data according to initializer.""" - path = download_url(GreycDataset.URL + self.name + ".zip", self.raw_dir) + path = download_url(GreycDataset.URL + self.name + ".zip", + self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) diff --git a/torch_geometric/io/file_managers.py b/torch_geometric/io/file_managers.py deleted file mode 100644 index b1449fc8858f..000000000000 --- a/torch_geometric/io/file_managers.py +++ /dev/null @@ -1,900 +0,0 @@ -"""Utilities function to manage graph files -Taken from graphkit-learn -""" -from os.path import dirname, splitext - - -class DataLoader(): - def __init__( - self, - filename, - filename_targets=None, - gformat=None, - **kwargs - ) -> None: - """Read graph data from filename and load them as NetworkX graphs. - - Parameters - ---------- - filename : string - The name of the file from where the dataset is read. - filename_targets : string - The name of file of the targets corresponding to graphs. - - Notes: - ----- - This function supports following graph dataset formats: - - 'ds': load data from .ds file. See comments of function loadFromDS for a example. - - 'cxl': load data from Graph eXchange Language file (.cxl file). See - `here `__ for detail. - - 'sdf': load data from structured data file (.sdf file). See - `here `__ - for details. - - 'mat': Load graph data from a MATLAB (up to version 7.1) .mat file. See - README in `downloadable file `__ - for details. - - 'txt': Load graph data from the TUDataset. See - `here `__ - for details. Note here filename is the name of either .txt file in - the dataset directory. - """ - if isinstance(filename, str): - extension = splitext(filename)[1][1:] - else: # filename is a list of files. - extension = splitext(filename[0])[1][1:] - - if extension == "ds": - self._graphs, self._targets, self._label_names = self.load_from_ds( - filename, filename_targets) - elif extension == "cxl": - dir_dataset = kwargs.get('dirname_dataset', None) - self._graphs, self._targets, self._label_names = self.load_from_xml( - filename, dir_dataset) - elif extension == 'xml': - dir_dataset = kwargs.get('dirname_dataset', None) - self._graphs, self._targets, self._label_names = self.load_from_xml( - filename, dir_dataset) - elif extension == "mat": - order = kwargs.get('order') - self._graphs, self._targets, self._label_names = self.load_mat( - filename, order) - elif extension == 'txt': - if gformat is None: - self._graphs, self._targets, self._label_names = self.load_tud( - filename) - elif gformat == 'cml': - self._graphs, self._targets, self._label_names = self.load_from_ds( - filename, filename_targets) - - else: - raise ValueError( - 'The input file with the extension ".', extension, - '" is not supported. The supported extensions includes: ".ds", ".cxl", ".xml", ".mat", ".txt".' - ) - - def load_from_ds(self, filename, filename_targets): - """Load data from .ds file. - - Possible graph formats include: - - '.ct': see function load_ct for detail. - - '.gxl': see dunction load_gxl for detail. - - Note these graph formats are checked automatically by the extensions of - graph files. 
- """ - if isinstance(filename, str): - dirname_dataset = dirname(filename) - with open(filename) as f: - content = f.read().splitlines() - else: # filename is a list of files. - dirname_dataset = dirname(filename[0]) - content = [] - for fn in filename: - with open(fn) as f: - content += f.read().splitlines() - # to remove duplicate file names. - - data = [] - y = [] - label_names = { - 'node_labels': [], - 'edge_labels': [], - 'node_attrs': [], - 'edge_attrs': [] - } - # Alkane - content = [line for line in content if not line.endswith('.ds')] - # Acyclic - content = [line for line in content if not line.startswith('#')] - extension = splitext(content[0].split(' ')[0])[1][1:] - if extension == 'ct': - load_file_fun = self.load_ct - # @todo: .sdf not tested yet. - elif extension == 'gxl' or extension == 'sdf': - load_file_fun = self.load_gxl - elif extension == 'cml': # dataset "Chiral" - load_file_fun = self.load_cml - - if filename_targets is None or filename_targets == '': - for i in range(0, len(content)): - tmp = content[i].split(' ') - # remove the '#'s in file names - g, l_names = load_file_fun(dirname_dataset + '/' + - tmp[0].replace('#', '', 1)) - data.append(g) - # @todo: this is so redundant. - self._append_label_names(label_names, l_names) - y.append(float(tmp[1])) - else: # targets in a seperate file - for i in range(0, len(content)): - tmp = content[i] - # remove the '#'s in file names - g, l_names = load_file_fun(dirname_dataset + '/' + - tmp.replace('#', '', 1)) - data.append(g) - self._append_label_names(label_names, l_names) - - with open(filename_targets) as fnt: - content_y = fnt.read().splitlines() - # assume entries in filename and filename_targets have the same order. - for item in content_y: - tmp = item.split(' ') - # assume the 3rd entry in a line is y (for Alkane dataset) - y.append(float(tmp[2])) - - return data, y, label_names - - def load_from_xml(self, filename, dir_dataset=None): - import xml.etree.ElementTree as ET - - def load_one_file(filename, data, y, label_names): - tree = ET.parse(filename) - root = tree.getroot() - # "graph" for ... I forgot; "print" for datasets GREC and Web. - for graph in root.iter('graph') if root.find( - 'graph') is not None else root.iter('print'): - mol_filename = graph.attrib['file'] - mol_class = graph.attrib['class'] - g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename) - data.append(g) - self._append_label_names(label_names, l_names) - y.append(mol_class) - - data = [] - y = [] - label_names = { - 'node_labels': [], - 'edge_labels': [], - 'node_attrs': [], - 'edge_attrs': [] - } - - if isinstance(filename, str): - if dir_dataset is not None: - dir_dataset = dir_dataset - else: - dir_dataset = dirname(filename) - load_one_file(filename, data, y, label_names) - - else: # filename is a list of files. - if dir_dataset is not None: - dir_dataset = dir_dataset - else: - dir_dataset = dirname(filename[0]) - - for fn in filename: - load_one_file(fn, data, y, label_names) - - return data, y, label_names - - # @todo: need to be updated (auto order) or deprecated. - def load_mat(self, filename, order): - """Load graph data from a MATLAB (up to version 7.1) .mat file. - - Notes: - ------ - A MAT file contains a struct array containing graphs, and a column vector lx containing a class label for each graph. - Check README in `downloadable file `__ for detailed structure. 
- """ - import networkx as nx - import numpy as np - from scipy.io import loadmat - data = [] - content = loadmat(filename) - for key, value in content.items(): - if key[0] == 'l': # class label - y = np.transpose(value)[0].tolist() - elif key[0] != '_': - # if adjacency matrix is not compressed / edge label exists - if order[1] == 0: - for i, item in enumerate(value[0]): - g = nx.Graph(name=i) # set name of the graph - nl = np.transpose( - item[order[3]][0][0][0]) # node label - for index, label in enumerate(nl[0]): - g.add_node(index, label_1=str(label)) - el = item[order[4]][0][0][0] # edge label - for edge in el: - g.add_edge(edge[0] - 1, edge[1] - 1, - label_1=str(edge[2])) - data.append(g) - else: - for i, item in enumerate(value[0]): - g = nx.Graph(name=i) # set name of the graph - nl = np.transpose( - item[order[3]][0][0][0]) # node label - for index, label in enumerate(nl[0]): - g.add_node(index, label_1=str(label)) - sam = item[order[0]] # sparse adjacency matrix - index_no0 = sam.nonzero() - for col, row in zip(index_no0[0], index_no0[1]): - g.add_edge(col, row) - data.append(g) - - label_names = { - 'node_labels': ['label_1'], - 'edge_labels': [], - 'node_attrs': [], - 'edge_attrs': [] - } - if order[1] == 0: - label_names['edge_labels'].append('label_1') - - return data, y, label_names - - def load_tud(self, filename): - """Load graph data from TUD dataset files. - - Notes: - ------ - The graph data is loaded from separate files. - Check README in `downloadable file `__, 2018 for detailed structure. - """ - from os import listdir - from os.path import basename, dirname - - import networkx as nx - - # @todo: add README (cuniform), maybe node/edge label maps. - def get_infos_from_readme(frm): - """Get information from DS_label_readme.txt file. - """ - def get_label_names_from_line(line): - """Get names of labels/attributes from a line. - """ - str_names = line.split('[')[1].split(']')[0] - names = str_names.split(',') - names = [attr.strip() for attr in names] - return names - - def get_class_label_map(label_map_strings): - label_map = {} - for string in label_map_strings: - integer, label = string.split('\t') - label_map[int(integer.strip())] = label.strip() - return label_map - - label_names = { - 'node_labels': [], - 'node_attrs': [], - 'edge_labels': [], - 'edge_attrs': [] - } - class_label_map = None - class_label_map_strings = [] - with open(frm) as rm: - content_rm = rm.read().splitlines() - i = 0 - while i < len(content_rm): - line = content_rm[i].strip() - # get node/edge labels and attributes. - if line.startswith('Node labels:'): - label_names['node_labels'] = get_label_names_from_line( - line) - elif line.startswith('Node attributes:'): - label_names['node_attrs'] = get_label_names_from_line(line) - elif line.startswith('Edge labels:'): - label_names['edge_labels'] = get_label_names_from_line( - line) - elif line.startswith('Edge attributes:'): - label_names['edge_attrs'] = get_label_names_from_line(line) - # get class label map. - elif line.startswith( - 'Class labels were converted to integer values using this map:' - ): - i += 2 - line = content_rm[i].strip() - while line != '' and i < len(content_rm): - class_label_map_strings.append(line) - i += 1 - line = content_rm[i].strip() - class_label_map = get_class_label_map( - class_label_map_strings) - i += 1 - - return label_names, class_label_map - - # get dataset name. 
- dirname_dataset = dirname(filename) - filename = basename(filename) - fn_split = filename.split('_A') - ds_name = fn_split[0].strip() - - # load data file names - for name in listdir(dirname_dataset): - if ds_name + '_A' in name: - fam = dirname_dataset + '/' + name - elif ds_name + '_graph_indicator' in name: - fgi = dirname_dataset + '/' + name - elif ds_name + '_graph_labels' in name: - fgl = dirname_dataset + '/' + name - elif ds_name + '_node_labels' in name: - fnl = dirname_dataset + '/' + name - elif ds_name + '_edge_labels' in name: - fel = dirname_dataset + '/' + name - elif ds_name + '_edge_attributes' in name: - fea = dirname_dataset + '/' + name - elif ds_name + '_node_attributes' in name: - fna = dirname_dataset + '/' + name - elif ds_name + '_graph_attributes' in name: - fga = dirname_dataset + '/' + name - elif ds_name + '_label_readme' in name: - frm = dirname_dataset + '/' + name - # this is supposed to be the node attrs, make sure to put this as the last 'elif' - elif ds_name + '_attributes' in name: - fna = dirname_dataset + '/' + name - - # get labels and attributes names. - if 'frm' in locals(): - label_names, class_label_map = get_infos_from_readme(frm) - else: - label_names = { - 'node_labels': [], - 'node_attrs': [], - 'edge_labels': [], - 'edge_attrs': [] - } - class_label_map = None - - with open(fgi) as gi: - content_gi = gi.read().splitlines() # graph indicator - with open(fam) as am: - content_am = am.read().splitlines() # adjacency matrix - - # load targets. - if 'fgl' in locals(): - with open(fgl) as gl: - content_targets = gl.read().splitlines( - ) # targets (classification) - targets = [float(i) for i in content_targets] - elif 'fga' in locals(): - with open(fga) as ga: - content_targets = ga.read().splitlines( - ) # targets (regression) - targets = [int(i) for i in content_targets] - else: - exp_msg = 'Can not find targets file. Please make sure there is a "', ds_name, '_graph_labels.txt" or "', ds_name, '_graph_attributes.txt"', 'file in your dataset folder.' - raise Exception(exp_msg) - if class_label_map is not None: - targets = [class_label_map[t] for t in targets] - - # create graphs and add nodes - data = [nx.Graph(name=str(i)) for i in range(0, len(content_targets))] - if 'fnl' in locals(): - with open(fnl) as nl: - content_nl = nl.read().splitlines() # node labels - for idx, line in enumerate(content_gi): - # transfer to int first in case of unexpected blanks - data[int(line) - 1].add_node(idx) - labels = [l.strip() for l in content_nl[idx].split(',')] - if label_names['node_labels'] == []: # @todo: need fix bug. - for i, label in enumerate(labels): - l_name = 'label_' + str(i) - data[int(line) - 1].nodes[idx][l_name] = label - label_names['node_labels'].append(l_name) - else: - for i, l_name in enumerate(label_names['node_labels']): - data[int(line) - 1].nodes[idx][l_name] = labels[i] - else: - for i, line in enumerate(content_gi): - data[int(line) - 1].add_node(i) - - # add edges - for line in content_am: - tmp = line.split(',') - n1 = int(tmp[0]) - 1 - n2 = int(tmp[1]) - 1 - # ignore edge weight here. 
- g = int(content_gi[n1]) - 1 - data[g].add_edge(n1, n2) - - # add edge labels - if 'fel' in locals(): - with open(fel) as el: - content_el = el.read().splitlines() - for idx, line in enumerate(content_el): - labels = [l.strip() for l in line.split(',')] - n = [int(i) - 1 for i in content_am[idx].split(',')] - g = int(content_gi[n[0]]) - 1 - if label_names['edge_labels'] == []: - for i, label in enumerate(labels): - l_name = 'label_' + str(i) - data[g].edges[n[0], n[1]][l_name] = label - label_names['edge_labels'].append(l_name) - else: - for i, l_name in enumerate(label_names['edge_labels']): - data[g].edges[n[0], n[1]][l_name] = labels[i] - - # add node attributes - if 'fna' in locals(): - with open(fna) as na: - content_na = na.read().splitlines() - for idx, line in enumerate(content_na): - attrs = [a.strip() for a in line.split(',')] - g = int(content_gi[idx]) - 1 - if label_names['node_attrs'] == []: - for i, attr in enumerate(attrs): - a_name = 'attr_' + str(i) - data[g].nodes[idx][a_name] = attr - label_names['node_attrs'].append(a_name) - else: - for i, a_name in enumerate(label_names['node_attrs']): - data[g].nodes[idx][a_name] = attrs[i] - - # add edge attributes - if 'fea' in locals(): - with open(fea) as ea: - content_ea = ea.read().splitlines() - for idx, line in enumerate(content_ea): - attrs = [a.strip() for a in line.split(',')] - n = [int(i) - 1 for i in content_am[idx].split(',')] - g = int(content_gi[n[0]]) - 1 - if label_names['edge_attrs'] == []: - for i, attr in enumerate(attrs): - a_name = 'attr_' + str(i) - data[g].edges[n[0], n[1]][a_name] = attr - label_names['edge_attrs'].append(a_name) - else: - for i, a_name in enumerate(label_names['edge_attrs']): - data[g].edges[n[0], n[1]][a_name] = attrs[i] - - return data, targets, label_names - - def load_ct( - self, filename - ): # @todo: this function is only tested on CTFile V2000; header not considered; only simple cases (atoms and bonds are considered.) - """Load data from a Chemical Table (.ct) file. - - Notes: - ------ - a typical example of data in .ct is like this: - - 3 2 <- number of nodes and edges - - 0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label) - - 0.0000 0.0000 0.0000 C - - 0.0000 0.0000 0.0000 O - - 1 3 1 1 <- each line describes an edge : to, from, bond type, bond stereo - - 2 3 1 1 - - Check `CTFile Formats file `__ - for detailed format discription. - """ - from os.path import basename - - import networkx as nx - g = nx.Graph() - with open(filename) as f: - content = f.read().splitlines() - g = nx.Graph(name=str(content[0]), - filename=basename(filename)) # set name of the graph - - # read the counts line. - tmp = content[1].split(' ') - tmp = [x for x in tmp if x != ''] - nb_atoms = int(tmp[0].strip()) # number of atoms - nb_bonds = int(tmp[1].strip()) # number of bonds - count_line_tags = [ - 'number_of_atoms', 'number_of_bonds', 'number_of_atom_lists', - '', 'chiral_flag', 'number_of_stext_entries', '', '', '', '', - 'number_of_properties', 'CT_version' - ] - i = 0 - while i < len(tmp): - if count_line_tags[i] != '': # if not obsoleted - g.graph[count_line_tags[i]] = tmp[i].strip() - i += 1 - - # read the atom block. 
- atom_tags = [ - 'x', 'y', 'z', 'atom_symbol', 'mass_difference', 'charge', - 'atom_stereo_parity', 'hydrogen_count_plus_1', - 'stereo_care_box', 'valence', 'h0_designator', '', '', - 'atom_atom_mapping_number', 'inversion_retention_flag', - 'exact_change_flag' - ] - for i in range(0, nb_atoms): - tmp = content[i + 2].split(' ') - tmp = [x for x in tmp if x != ''] - g.add_node(i) - j = 0 - while j < len(tmp): - if atom_tags[j] != '': - g.nodes[i][atom_tags[j]] = tmp[j].strip() - j += 1 - - # read the bond block. - bond_tags = [ - 'first_atom_number', 'second_atom_number', 'bond_type', - 'bond_stereo', '', 'bond_topology', 'reacting_center_status' - ] - for i in range(0, nb_bonds): - tmp = content[i + g.number_of_nodes() + 2].split(' ') - tmp = [x for x in tmp if x != ''] - n1, n2 = int(tmp[0].strip()) - 1, int(tmp[1].strip()) - 1 - g.add_edge(n1, n2) - j = 2 - while j < len(tmp): - if bond_tags[j] != '': - g.edges[(n1, n2)][bond_tags[j]] = tmp[j].strip() - j += 1 - - # get label names. - label_names = { - 'node_labels': [], - 'edge_labels': [], - 'node_attrs': [], - 'edge_attrs': [] - } - atom_symbolic = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, None, None, 1, 1, 1] - for nd in g.nodes(): - for key in g.nodes[nd]: - if atom_symbolic[atom_tags.index(key)] == 1: - label_names['node_labels'].append(key) - else: - label_names['node_attrs'].append(key) - break - bond_symbolic = [None, None, 1, 1, None, 1, 1] - for ed in g.edges(): - for key in g.edges[ed]: - if bond_symbolic[bond_tags.index(key)] == 1: - label_names['edge_labels'].append(key) - else: - label_names['edge_attrs'].append(key) - break - - return g, label_names - - def load_gxl(self, filename): # @todo: directed graphs. - import xml.etree.ElementTree as ET - from os.path import basename - - import networkx as nx - - tree = ET.parse(filename) - root = tree.getroot() - index = 0 - g = nx.Graph(filename=basename(filename), name=root[0].attrib['id']) - dic = {} # used to retrieve incident nodes of edges - for node in root.iter('node'): - dic[node.attrib['id']] = index - labels = {} - # for datasets "GREC" and "Monoterpens". - for attr in node.iter('attr'): - labels[attr.attrib['name']] = attr[0].text - for attr in node.iter('attribute'): # for dataset "Web". - labels[attr.attrib['name']] = attr.attrib['value'] - g.add_node(index, **labels) - index += 1 - - for edge in root.iter('edge'): - labels = {} - # for datasets "GREC" and "Monoterpens". - for attr in edge.iter('attr'): - labels[attr.attrib['name']] = attr[0].text - for attr in edge.iter('attribute'): # for dataset "Web". - labels[attr.attrib['name']] = attr.attrib['value'] - g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], - **labels) - - # get label names. - label_names = { - 'node_labels': [], - 'edge_labels': [], - 'node_attrs': [], - 'edge_attrs': [] - } - # @todo: possible loss of label names if some nodes miss some labels. - for node in root.iter('node'): - # for datasets "GREC" and "Monoterpens". - for attr in node.iter('attr'): - # @todo: this maybe wrong, and slow. "type" is for dataset GREC; "int" is for dataset "Monoterpens". - if attr[0].tag == 'int' or attr.attrib['name'] == 'type': - label_names['node_labels'].append(attr.attrib['name']) - else: - label_names['node_attrs'].append(attr.attrib['name']) - - for attr in node.iter('attribute'): # for dataset "Web". - label_names['node_attrs'].append(attr.attrib['name']) - # @todo: is id useful in dataset "Web"? is "FREQUENCY" symbolic or not? 
- break - - for edge in root.iter('edge'): - # for datasets "GREC" and "Monoterpens". - for attr in edge.iter('attr'): - # @todo: this maybe wrong, and slow. "frequency" and "type" are for dataset GREC; "int" is for dataset "Monoterpens". - if attr[0].tag == 'int' or attr.attrib[ - 'name'] == 'frequency' or 'type' in attr.attrib['name']: - label_names['edge_labels'].append(attr.attrib['name']) - else: - label_names['edge_attrs'].append(attr.attrib['name']) - - for attr in edge.iter('attribute'): # for dataset "Web". - label_names['edge_attrs'].append(attr.attrib['name']) - break - - return g, label_names - - def load_cml(self, filename): # @todo: directed graphs. - # @todo: what is "atomParity" and "bondStereo" in the data file? - import xml.etree.ElementTree as ET - from os.path import basename - - import networkx as nx - - # @todo: why this has to be added? - xmlns = '{http://www.xml-cml.org/schema}' - tree = ET.parse(filename) - root = tree.getroot() - index = 0 - if root.tag == xmlns + 'molecule': - g_id = root.attrib['id'] - else: - g_id = root.find(xmlns + 'molecule').attrib['id'] - g = nx.Graph(filename=basename(filename), name=g_id) - dic = {} # used to retrieve incident nodes of edges - for atom in root.iter(xmlns + 'atom'): - dic[atom.attrib['id']] = index - labels = {} - for key, val in atom.attrib.items(): - if key != 'id': - labels[key] = val - g.add_node(index, **labels) - index += 1 - - for bond in root.iter(xmlns + 'bond'): - labels = {} - for key, val in bond.attrib.items(): - # "id" is in dataset "ACE". - if key != 'atomRefs2' and key != 'id': - labels[key] = val - n1, n2 = bond.attrib['atomRefs2'].strip().split(' ') - g.add_edge(dic[n1], dic[n2], **labels) - - # get label names. - label_names = { - 'node_labels': [], - 'edge_labels': [], - 'node_attrs': [], - 'edge_attrs': [] - } - # @todo: possible loss of label names if some nodes miss some labels. - for key, val in g.nodes[0].items(): - try: - float(val) - except: - label_names['node_labels'].append(key) - else: - if val.isdigit(): - label_names['node_labels'].append(key) - else: - label_names['node_attrs'].append(key) - for _, _, attrs in g.edges(data=True): - for key, val in attrs.items(): - try: - float(val) - except: - label_names['edge_labels'].append(key) - else: - if val.isdigit(): - label_names['edge_labels'].append(key) - else: - label_names['edge_attrs'].append(key) - break - - return g, label_names - - def _append_label_names(self, label_names, new_names): - for key, val in label_names.items(): - label_names[key] += [ - name for name in new_names[key] if name not in val - ] - - @property - def data(self): - return self._graphs, self._targets, self._label_names - - @property - def graphs(self): - return self._graphs - - @property - def targets(self): - return self._targets - - @property - def label_names(self): - return self._label_names - - -class DataSaver(): - def __init__(self, graphs, targets=None, filename='gfile', gformat='gxl', - group=None, **kwargs): - """Save list of graphs. 
- """ - import os - dirname_ds = os.path.dirname(filename) - if dirname_ds != '': - dirname_ds += '/' - os.makedirs(dirname_ds, exist_ok=True) - - if 'graph_dir' in kwargs: - graph_dir = kwargs['graph_dir'] + '/' - os.makedirs(graph_dir, exist_ok=True) - del kwargs['graph_dir'] - else: - graph_dir = dirname_ds - - if group == 'xml' and gformat == 'gxl': - with open(filename + '.xml', 'w') as fgroup: - fgroup.write("") - fgroup.write( - "\n" - ) - fgroup.write("\n") - for idx, g in enumerate(graphs): - fname_tmp = "graph" + str(idx) + ".gxl" - self.save_gxl(g, graph_dir + fname_tmp, **kwargs) - fgroup.write("\n\t") - fgroup.write("\n") - fgroup.close() - - def save_gxl(self, graph, filename, method='default', node_labels=[], - edge_labels=[], node_attrs=[], edge_attrs=[]): - if method == 'default': - gxl_file = open(filename, 'w') - gxl_file.write("\n") - gxl_file.write( - "\n" - ) - gxl_file.write( - "\n") - if 'name' in graph.graph: - name = str(graph.graph['name']) - else: - name = 'dummy' - gxl_file.write("\n") - for v, attrs in graph.nodes(data=True): - gxl_file.write("") - for l_name in node_labels: - gxl_file.write("" + - str(attrs[l_name]) + "") - for a_name in node_attrs: - gxl_file.write("" + - str(attrs[a_name]) + "") - gxl_file.write("\n") - for v1, v2, attrs in graph.edges(data=True): - gxl_file.write("") - for l_name in edge_labels: - gxl_file.write("" + - str(attrs[l_name]) + "") - for a_name in edge_attrs: - gxl_file.write("" + - str(attrs[a_name]) + "") - gxl_file.write("\n") - gxl_file.write("\n") - gxl_file.write("") - gxl_file.close() - elif method == 'benoit': - import xml.etree.ElementTree as ET - root_node = ET.Element('gxl') - attr = dict() - attr['id'] = str(graph.graph['name']) - attr['edgeids'] = 'true' - attr['edgemode'] = 'undirected' - graph_node = ET.SubElement(root_node, 'graph', attrib=attr) - - for v in graph: - current_node = ET.SubElement(graph_node, 'node', - attrib={'id': str(v)}) - for attr in graph.nodes[v].keys(): - cur_attr = ET.SubElement(current_node, 'attr', - attrib={'name': attr}) - cur_value = ET.SubElement( - cur_attr, graph.nodes[v][attr].__class__.__name__) - cur_value.text = graph.nodes[v][attr] - - for v1 in graph: - for v2 in graph[v1]: - if (v1 < v2): # Non oriented graphs - cur_edge = ET.SubElement( - graph_node, 'edge', attrib={ - 'from': str(v1), - 'to': str(v2) - }) - for attr in graph[v1][v2].keys(): - cur_attr = ET.SubElement(cur_edge, 'attr', - attrib={'name': attr}) - cur_value = ET.SubElement( - cur_attr, - graph[v1][v2][attr].__class__.__name__) - cur_value.text = str(graph[v1][v2][attr]) - - tree = ET.ElementTree(root_node) - tree.write(filename) - elif method == 'gedlib': - # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22 - # pass - gxl_file = open(filename, 'w') - gxl_file.write("\n") - gxl_file.write( - "\n" - ) - gxl_file.write( - "\n") - gxl_file.write("\n") - for v, attrs in graph.nodes(data=True): - gxl_file.write("") - gxl_file.write("" + - str(attrs['chem']) + "") - gxl_file.write("\n") - for v1, v2, attrs in graph.edges(data=True): - gxl_file.write("") - gxl_file.write("" + - str(attrs['valence']) + "") - # gxl_file.write("" + "1" + "") - gxl_file.write("\n") - gxl_file.write("\n") - gxl_file.write("") - gxl_file.close() - elif method == 'gedlib-letter': - # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22 - # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl - gxl_file = 
open(filename, 'w') - gxl_file.write("\n") - gxl_file.write( - "\n" - ) - gxl_file.write( - "\n") - gxl_file.write("\n") - for v, attrs in graph.nodes(data=True): - gxl_file.write("") - gxl_file.write("" + - str(attrs['attributes'][0]) + "") - gxl_file.write("" + - str(attrs['attributes'][1]) + "") - gxl_file.write("\n") - for v1, v2, attrs in graph.edges(data=True): - gxl_file.write("\n") - gxl_file.write("\n") - gxl_file.write("") - gxl_file.close() diff --git a/torch_geometric/io/greyc.py b/torch_geometric/io/greyc.py deleted file mode 100644 index bfdaa2d828a9..000000000000 --- a/torch_geometric/io/greyc.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Module to load greyc datasets as list of networkx graphs -""" - -import os -from typing import List, Union -import warnings - -from torch_geometric.io.file_managers import DataLoader - -PATH = os.path.dirname(__file__) - - -def one_hot_encode( - val: Union[int, str], - allowable_set: Union[List[str], List[int]], - include_unknown_set: bool = False - ) -> List[float]: - """One hot encoder for elements of a provided set. - - Examples: - -------- - >>> one_hot_encode("a", ["a", "b", "c"]) - [1.0, 0.0, 0.0] - >>> one_hot_encode(2, [0, 1, 2]) - [0.0, 0.0, 1.0] - >>> one_hot_encode(3, [0, 1, 2]) - [0.0, 0.0, 0.0] - >>> one_hot_encode(3, [0, 1, 2], True) - [0.0, 0.0, 0.0, 1.0] - - Parameters - ---------- - val: int or str - The value must be present in `allowable_set`. - allowable_set: List[int] or List[str] - List of allowable quantities. - include_unknown_set: bool, default False - If true, the index of all values not - in `allowable_set` is `len(allowable_set)`. - - Returns: - ------- - List[float] - An one-hot vector of val. - If `include_unknown_set` is False, - the length is `len(allowable_set)`. - If `include_unknown_set` is True, - the length is `len(allowable_set) + 1`. - - Raises: - ------ - ValueError - If include_unknown_set is False and `val` is not in `allowable_set`. - """ - if not include_unknown_set: - if val not in allowable_set: - warnings.warn( - f"input {val} not in allowable set {allowable_set}.", - UserWarning - ) - - # init an one-hot vector - if not include_unknown_set: - one_hot_legnth = len(allowable_set) - else: - one_hot_legnth = len(allowable_set) + 1 - one_hot = [0.0 for _ in range(one_hot_legnth)] - - try: - one_hot[allowable_set.index(val)] = 1.0 # type: ignore - except: - if include_unknown_set: - # If include_unknown_set is True, set the last index is 1. - one_hot[-1] = 1.0 - else: - pass - return one_hot - - -def prepare_graph( - graph, atom_list=['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'H'] - ): - """Prepare graph to include all data before pyg conversion - Parameters - ---------- - graph : networkx graph - graph to be prepared - atom_list : list[str], optional - List of node attributes to be converted into one hot encoding - """ - # Convert attributes. 
-    graph.graph = {}
-    for node in graph.nodes:
-        graph.nodes[node]['atom_symbol'] = one_hot_encode(
-            graph.nodes[node]['atom_symbol'],
-            atom_list,
-            include_unknown_set=True,
-        )
-        graph.nodes[node]['degree'] = float(graph.degree[node])
-        for attr in ['x', 'y', 'z']:
-            graph.nodes[node][attr] = float(graph.nodes[node][attr])
-
-    for edge in graph.edges:
-        for attr in ['bond_type', 'bond_stereo']:
-            graph.edges[edge][attr] = float(graph.edges[edge][attr])
-
-    return graph
-
-
-def _load_greyc_networkx_graphs(dir: str, name: str):
-    """Load the dataset as a llist of networkx graphs and
-    returns list of graphs and list of properties
-
-    Args:
-        dataset:str the dataset to load (Alkane,Acyclic,...)
-
-    Returns:
-        list of nx graphs
-        list of properties (float or int)
-    """
-    loaders = {
-        "alkane": _loader_alkane,
-        "acyclic": _loader_acyclic,
-        "mao": _loader_mao
-    }
-    loader_f = loaders.get(name, None)
-    loader = loader_f(dir)
-    if loader is None:
-        raise Exception("Dataset Not Found")
-
-    graphs = [prepare_graph(graph) for graph in loader.graphs]
-    return graphs, loader.targets
-
-
-def read_greyc(dir: str, name: str):
-    return _load_greyc_networkx_graphs(dir, name)
-
-
-def _loader_alkane(dir: str):
-    """Load the 150 graphs of Alkane datasets
-    returns two lists
-    - 150 networkx graphs
-    - boiling points
-    """
-    dloader = DataLoader(
-        os.path.join(dir, 'dataset.ds'),
-        filename_targets=os.path.join(dir, 'dataset_boiling_point_names.txt'),
-        dformat='ds', gformat='ct', y_separator=' ')
-    return dloader
-
-
-def _loader_acyclic(dir: str):
-    dloader = DataLoader(
-        os.path.join(dir, 'dataset_bps.ds'),
-        filename_targets=None,
-        dformat='ds',
-        gformat='ct',
-        y_separator=' '
-    )
-    return dloader
-
-
-def _loader_mao(dir: str):
-    dloader = DataLoader(
-        os.path.join(dir, 'dataset.ds'),
-        filename_targets=None,
-        dformat='ds',
-        gformat='ct',
-        y_separator=' '
-    )
-    dloader._targets = [int(yi) for yi in dloader.targets]
-    return dloader

From e1a19cac16f5b3c0850befa12338ab99345a8790 Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Mon, 20 Jan 2025 15:45:27 +0100
Subject: [PATCH 09/29] Rollback io.__init__.py

---
 torch_geometric/io/__init__.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torch_geometric/io/__init__.py b/torch_geometric/io/__init__.py
index eae652ef5814..2b43b6cfc5c6 100644
--- a/torch_geometric/io/__init__.py
+++ b/torch_geometric/io/__init__.py
@@ -6,7 +6,6 @@
 from .sdf import read_sdf, parse_sdf
 from .off import read_off, write_off
 from .npz import read_npz, parse_npz
-from .greyc import read_greyc

 __all__ = [
     'read_off',
@@ -21,5 +20,4 @@
     'parse_sdf',
     'read_npz',
     'parse_npz',
-    'read_greyc',
 ]

From d6c1ac32cfa72e6c67f9a96e7f9a7ba8fff6ac9f Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Mon, 20 Jan 2025 15:55:38 +0100
Subject: [PATCH 10/29] Formatting

---
 torch_geometric/datasets/greyc.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py
index 1fee339b7194..1d364c1acedf 100644
--- a/torch_geometric/datasets/greyc.py
+++ b/torch_geometric/datasets/greyc.py
@@ -56,3 +56,9 @@ def process(self):

         data, slices = self.collate(data_list)
         torch.save((data, slices), self.processed_paths[0])
+
+    def __repr__(self) -> str:
+        name = self.name.capitalize()
+        if self.name == "mao":
+            name = self.name.upper()
+        return f'{name}({len(self)})'

From f024e81cc544fa16e67b9874bd0afca628d3dbf6 Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Mon, 20 Jan 2025 16:09:24 +0100
Subject: [PATCH 11/29] First documentation (copy of 
TUDataset) --- torch_geometric/datasets/greyc.py | 72 ++++++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 1d364c1acedf..0229673e86f4 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -6,12 +6,70 @@ from torch_geometric.data import InMemoryDataset, download_url, extract_zip -class DatasetNotFoundError(Exception): - pass - - class GreycDataset(InMemoryDataset): - """Class to load three GREYC Datasets as pytorch geometric dataset.""" + r"""A variety of graph kernel benchmark datasets, *.e.g.*, + :obj:`"IMDB-BINARY"`, :obj:`"REDDIT-BINARY"` or :obj:`"PROTEINS"`, + collected from the `TU Dortmund University + `_. + In addition, this dataset wrapper provides `cleaned dataset versions + `_ as motivated by the + `"Understanding Isomorphism Bias in Graph Data Sets" + `_ paper, containing only non-isomorphic + graphs. + + .. note:: + Some datasets may not come with any node labels. + You can then either make use of the argument :obj:`use_node_attr` + to load additional continuous node attributes (if present) or provide + synthetic node features using transforms such as + :class:`torch_geometric.transforms.Constant` or + :class:`torch_geometric.transforms.OneHotDegree`. + + Args: + root (str): Root directory where the dataset should be saved. + name (str): The `name + `_ of the + dataset. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + + **STATS:** + + .. 
list-table::
+        :widths: 20 10 10 10 10 10
+        :header-rows: 1
+
+        * - Name
+          - #graphs
+          - #nodes
+          - #edges
+          - #features
+          - #classes
+        * - MUTAG
+          - 188
+          - ~17.9
+          - ~39.6
+          - 7
+          - 2
+        * - ENZYMES
+          - 600
+          - ~32.6
+          - ~124.3
+          - 3
+          - 6
+    """

     URL = ('http://localhost:3000/')

@@ -25,6 +83,8 @@ def __init__(
         force_reload: bool = False,
     ) -> None:
         self.name = name.lower()
+        if self.name not in {"acyclic", "alkane", "mao"}:
+            raise ValueError(f"Dataset {self.name} not found.")
         super().__init__(root, transform, pre_transform, pre_filter,
                          force_reload=force_reload)
         self.data, self.slices = torch.load(self.processed_paths[0])

@@ -38,14 +98,12 @@ def raw_file_names(self) -> str:
         return f"{self.name.lower()}.pth"

     def download(self) -> None:
-        """Load the right data according to initializer."""
         path = download_url(GreycDataset.URL + self.name + ".zip",
                             self.raw_dir)
         extract_zip(path, self.raw_dir)
         os.unlink(path)

     def process(self):
-        """Read data into huge `Data` list."""
         data_list = torch.load(os.path.join(self.raw_dir, self.name + ".pth"))

From 3dfca8402a8ca1ca152936fa17955fbd1ec72e3c Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Mon, 20 Jan 2025 16:38:19 +0100
Subject: [PATCH 12/29] Start of documentation after formatting
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_geometric/datasets/greyc.py | 45 ++++++++++++++++---------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py
index 0229673e86f4..71bba9e17664 100644
--- a/torch_geometric/datasets/greyc.py
+++ b/torch_geometric/datasets/greyc.py
@@ -7,15 +7,10 @@


 class GreycDataset(InMemoryDataset):
-    r"""A variety of graph kernel benchmark datasets, *.e.g.*,
-    :obj:`"IMDB-BINARY"`, :obj:`"REDDIT-BINARY"` or :obj:`"PROTEINS"`,
-    collected from the `TU Dortmund University
-    `_.
-    In addition, this dataset wrapper provides `cleaned dataset versions
-    `_ as motivated by the
-    `"Understanding Isomorphism Bias in Graph Data Sets"
-    `_ paper, containing only non-isomorphic
-    graphs.
+    r"""Implementation of three GREYC chemistry small datasets as pytorch
+    geometric datasets : Alkane, Acyclic and MAO.
+    See `"CHEMISTRY" `_
+    for details.

     .. note::
         Some datasets may not come with any node labels.
         You can then either make use of the argument :obj:`use_node_attr`
         to load additional continuous node attributes (if present) or provide
         synthetic node features using transforms such as
         :class:`torch_geometric.transforms.Constant` or
         :class:`torch_geometric.transforms.OneHotDegree`.

     Args:
         root (str): Root directory where the dataset should be saved.
         name (str): The `name
-            `_ of the
+            `_ of the
             dataset.
transform (callable, optional): A function/transform that takes in an
             :obj:`torch_geometric.data.Data` object and returns a transformed
             version. The data object will be transformed before every access.
             (default: :obj:`None`)
         pre_transform (callable, optional): A function/transform that takes in
             an :obj:`torch_geometric.data.Data` object and returns a
             transformed version. The data object will be transformed before
             being saved to disk. (default: :obj:`None`)
         pre_filter (callable, optional): A function that takes in an
             :obj:`torch_geometric.data.Data` object and returns a boolean
             value, indicating whether the data object should be included in the
             final dataset. (default: :obj:`None`)
         force_reload (bool, optional): Whether to re-process the dataset.
             (default: :obj:`False`)

     **STATS:**

     .. list-table::
         :widths: 20 10 10 10 10 10
         :header-rows: 1

         * - Name
           - #graphs
           - #nodes
           - #edges
           - #features
           - #classes
-        * - MUTAG
-          - 188
-          - ~17.9
-          - ~39.6
-          - 7
-          - 2
-        * - ENZYMES
-          - 600
-          - ~32.6
-          - ~124.3
-          - 3
-          - 6
+        * - Acyclic
+          - 183
+          - /
+          - /
+          - /
+          - /
+        * - Alkane
+          - 150
+          - /
+          - /
+          - /
+          - /
+        * - MAO
+          - 68
+          - /
+          - /
+          - /
+          - /
     """

     URL = ('http://localhost:3000/')

From d7527760d552c972ccade77f7d63bc5dfdf03d75 Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Tue, 21 Jan 2025 10:17:45 +0100
Subject: [PATCH 13/29] Documentation update
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_geometric/datasets/greyc.py | 34 +++++++++++++++----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py
index 71bba9e17664..62dbbe4a2fc0 100644
--- a/torch_geometric/datasets/greyc.py
+++ b/torch_geometric/datasets/greyc.py
@@ -1,9 +1,8 @@
 import os
 from typing import Callable, Optional

-import torch
-
 from torch_geometric.data import InMemoryDataset, download_url, extract_zip
+from torch_geometric.io import fs


 class GreycDataset(InMemoryDataset):
@@ -54,22 +53,22 @@ class GreycDataset(InMemoryDataset):
         - #classes
         * - Acyclic
           - 183
-          - /
-          - /
-          - /
-          - /
+          - ~8.2
+          - ~14.3
+          - 15
+          - 148
         * - Alkane
           - 150
-          - /
-          - /
-          - /
-          - /
+          - ~8.9
+          - ~15.8
+          - 15
+          - 123
         * - MAO
           - 68
-          - /
-          - /
-          - /
-          - /
+          - ~18.4
+          - ~39.3
+          - 15
+          - 2
     """

     URL = ('http://localhost:3000/')
@@ -88,7 +87,7 @@ def __init__(
             raise ValueError(f"Dataset {self.name} not found.")
         super().__init__(root, transform, pre_transform, pre_filter,
                          force_reload=force_reload)
-        self.data, self.slices = torch.load(self.processed_paths[0])
+        self.data, self.slices = fs.torch_load(self.processed_paths[0])

     @property
     def processed_file_names(self) -> str:
@@ -105,7 +104,8 @@ def download(self) -> None:
         os.unlink(path)

     def process(self):
-        data_list = torch.load(os.path.join(self.raw_dir, self.name + ".pth"))
+        data_list = fs.torch_load(
+            os.path.join(self.raw_dir, self.name + ".pth"))

         if self.pre_filter is not None:
             data_list = [data for data in data_list if self.pre_filter(data)]
@@ -114,7 +114,7 @@ def process(self):
             data_list = [self.pre_transform(data) for data in data_list]

         data, slices = self.collate(data_list)
-        torch.save((data, slices), self.processed_paths[0])
+        fs.torch_save((data, slices), self.processed_paths[0])

     def __repr__(self) -> str:
         name = self.name.capitalize()

From 1abeb27dd23ff51fb7e39cf7d0560631ba60d3ab Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Tue, 21 Jan 2025 11:09:30 +0100
Subject: [PATCH 14/29] Doc update
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_geometric/datasets/greyc.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py
index 62dbbe4a2fc0..b280416bab57 100644
--- a/torch_geometric/datasets/greyc.py
+++ b/torch_geometric/datasets/greyc.py
@@ -7,8 +7,8 @@ class GreycDataset(InMemoryDataset):
     r"""Implementation of three GREYC chemistry small datasets as pytorch
-    geometric datasets : Alkane, Acyclic and MAO. See
-    `"CHEMISTRY" `_
+    geometric datasets : Alkane, Acyclic and MAO. See
+    `"GREYC's Chemistry dataset" `_
     for details.

     ..
note:: @@ -22,7 +22,7 @@ class GreycDataset(InMemoryDataset): Args: root (str): Root directory where the dataset should be saved. name (str): The `name - `_ of the + `_ of the dataset. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed From 00a27dfb91c5a89ce4ff28ca11669255d7d44da6 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Wed, 22 Jan 2025 10:41:35 +0100 Subject: [PATCH 15/29] Update of docstring --- torch_geometric/datasets/greyc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index b280416bab57..4375de8ce60e 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -6,8 +6,8 @@ class GreycDataset(InMemoryDataset): - r"""Implementation of three GREYC chemistry small datasets as pytorch - geometric datasets : Alkane, Acyclic and MAO. See + r"""Implementation of five GREYC chemistry small datasets as pytorch + geometric datasets : Alkane, Acyclic, MAO, Monoterpens and PAH. See `"GREYC's Chemistry dataset" `_ for details. From 0cd9144ed58b32c4b11ebae08e6e9599cb6eeaab Mon Sep 17 00:00:00 2001 From: Lyam Chardey Date: Thu, 23 Jan 2025 11:08:53 +0100 Subject: [PATCH 16/29] Updated & --- torch_geometric/datasets/greyc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 62dbbe4a2fc0..e8af267b5e94 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -91,11 +91,11 @@ def __init__( @property def processed_file_names(self) -> str: - return 'data.pt' + return os.path.join(self.root, "data.pt") @property def raw_file_names(self) -> str: - return f"{self.name.lower()}.pth" + return os.path.join(self.root, self.name, f"{self.name}.gml") def download(self) -> None: path = download_url(GreycDataset.URL + self.name + ".zip", From 341e3d1e384fcfbf77d7e848ff35899e30fb9b0d Mon Sep 17 00:00:00 2001 From: Lyam Chardey Date: Thu, 23 Jan 2025 11:24:31 +0100 Subject: [PATCH 17/29] Pre-commit --- torch_geometric/datasets/greyc.py | 85 +++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 3 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index e8af267b5e94..1aa1104f7fe5 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -1,7 +1,14 @@ import os -from typing import Callable, Optional +from typing import Callable, List, Optional -from torch_geometric.data import InMemoryDataset, download_url, extract_zip +import torch + +from torch_geometric.data import ( + Data, + InMemoryDataset, + download_url, + extract_zip, +) from torch_geometric.io import fs @@ -95,7 +102,9 @@ def processed_file_names(self) -> str: @property def raw_file_names(self) -> str: - return os.path.join(self.root, self.name, f"{self.name}.gml") + self.gml_datafile = os.path.join(self.root, self.name, + f"{self.name}.gml") + return self.gml_datafile def download(self) -> None: path = download_url(GreycDataset.URL + self.name + ".zip", @@ -116,6 +125,76 @@ def process(self): data, slices = self.collate(data_list) fs.torch_save((data, slices), self.processed_paths[0]) + def _gml_to_data(gml: str, gml_file: bool = True) -> Data: + """Reads a `gml` file and creates a `Data` object. + + Parameters : + ------------ + + * `gml` : gml file path if `gml_file` is + set to `True`, gml content otherwise. 
+ * `gml_file` : indicates whether `gml` is a path + to a gml file or a gml file content. + + Returns : + --------- + + * `Data` : the `Data` object created from the file content. + + Raises : + -------- + + * `FileNotFoundError` : if `gml` is a file path and doesn't exist. + """ + import networkx as nx + if gml_file: + if not os.path.exists(gml): + raise FileNotFoundError(f"File `{gml}` does not exist") + g: nx.Graph = nx.read_gml(gml) + else: + g: nx.Graph = nx.parse_gml(gml) + + x, edge_index, edge_attr = [], [], [] + + y = torch.tensor([g.graph["y"]], + dtype=torch.long) if "y" in g.graph else None + + for _, attr in g.nodes(data=True): + x.append(attr["x"]) + + for u, v, attr in g.edges(data=True): + edge_index.append([int(u), int(v)]) + edge_index.append([int(v), int(u)]) + edge_attr.append(attr["edge_attr"]) + edge_attr.append(attr["edge_attr"]) + + x = torch.tensor(x, dtype=torch.float) + edge_index = torch.tensor(edge_index, + dtype=torch.long).t().contiguous() + edge_attr = torch.tensor(edge_attr, dtype=torch.long) + + return Data(x=x, edge_attr=edge_attr, edge_index=edge_index, y=y) + + def _load_gml_data(self, gml: str) -> List[Data]: + """Reads a dataset from a gml file + and converts it into a list of `Data`. + + Parameters : + ------------ + + * `gml` : Source filename. If `zip` file, it will be extracted. + + Returns : + --------- + + * `List[Data]` : Content of the gml dataset. + """ + GML_SEPARATOR = "---" + with open(gml, encoding="utf8") as f: + gml_contents = f.read() + gml_files = gml_contents.split(GML_SEPARATOR) + return [self._gml_to_data(content, False) for content in gml_files] + def __repr__(self) -> str: name = self.name.capitalize() if self.name == "mao": From f60be55d74fa803b2a421a1cb6897aff3569d2d5 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Thu, 23 Jan 2025 15:16:36 +0100 Subject: [PATCH 18/29] Addition of PAH and Monoterpens input support --- torch_geometric/datasets/greyc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 4375de8ce60e..84f631ff90fb 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -83,7 +83,7 @@ def __init__( force_reload: bool = False, ) -> None: self.name = name.lower() - if self.name not in {"acyclic", "alkane", "mao"}: + if self.name not in {"acyclic", "alkane", "mao", "monoterpens", "pah"}: raise ValueError(f"Dataset {self.name} not found.") super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) @@ -118,6 +118,6 @@ def process(self): def __repr__(self) -> str: name = self.name.capitalize() - if self.name == "mao": + if self.name in ["mao", "pah"]: name = self.name.upper() return f'{name}({len(self)})' From f3844192528a978ad22974116e921ac72b1a8ee1 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Thu, 23 Jan 2025 16:49:46 +0100 Subject: [PATCH 19/29] Loading datasets from compressed gml files --- torch_geometric/datasets/greyc.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 0703cc90bd46..c455511a1b67 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -78,7 +78,8 @@ class GreycDataset(InMemoryDataset): - 2 """ - URL = ('http://localhost:3000/') + URL = ('https://raw.githubusercontent.com/thomasbauer76/' + 'greycdata/refs/heads/main/greycdata/data_gml/') def __init__( self, @@ -98,13 +99,11 @@ def __init__( 
@property def processed_file_names(self) -> str: - return os.path.join(self.root, "data.pt") + return "data.pt" @property def raw_file_names(self) -> str: - self.gml_datafile = os.path.join(self.root, self.name, - f"{self.name}.gml") - return self.gml_datafile + return f"{self.name}.gml" def download(self) -> None: path = download_url(GreycDataset.URL + self.name + ".zip", @@ -113,8 +112,7 @@ def download(self) -> None: os.unlink(path) def process(self): - data_list = fs.torch_load( - os.path.join(self.raw_dir, self.name + ".pth")) + data_list = self._load_gml_data(self.raw_paths[0]) if self.pre_filter is not None: data_list = [data for data in data_list if self.pre_filter(data)] @@ -125,6 +123,7 @@ def process(self): data, slices = self.collate(data_list) fs.torch_save((data, slices), self.processed_paths[0]) + @staticmethod def _gml_to_data(gml: str, gml_file: bool = True) -> Data: """Reads a `gml` file and creates a `Data` object. @@ -175,7 +174,8 @@ def _gml_to_data(gml: str, gml_file: bool = True) -> Data: return Data(x=x, edge_attr=edge_attr, edge_index=edge_index, y=y) - def _load_gml_data(self, gml: str) -> List[Data]: + @staticmethod + def _load_gml_data(gml: str) -> List[Data]: """Reads a dataset from a gml file and converts it into a list of `Data`. @@ -193,7 +193,9 @@ def _load_gml_data(self, gml: str) -> List[Data]: with open(gml, encoding="utf8") as f: gml_contents = f.read() gml_files = gml_contents.split(GML_SEPARATOR) - return [self._gml_to_data(content, False) for content in gml_files] + return [ + GreycDataset._gml_to_data(content, False) for content in gml_files + ] def __repr__(self) -> str: name = self.name.capitalize() From 50d73511404f6de5b0f6f0b57a0254212f026603 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Fri, 24 Jan 2025 09:34:20 +0100 Subject: [PATCH 20/29] Link update and restructuring --- torch_geometric/datasets/greyc.py | 64 +++++++++++++++---------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index c455511a1b67..7f14f974c465 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -78,7 +78,7 @@ class GreycDataset(InMemoryDataset): - 2 """ - URL = ('https://raw.githubusercontent.com/thomasbauer76/' + URL = ('https://raw.githubusercontent.com/bgauzere/' 'greycdata/refs/heads/main/greycdata/data_gml/') def __init__( @@ -95,33 +95,13 @@ def __init__( raise ValueError(f"Dataset {self.name} not found.") super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) - self.data, self.slices = fs.torch_load(self.processed_paths[0]) + self.data, self.slices = torch.load(self.processed_paths[0]) - @property - def processed_file_names(self) -> str: - return "data.pt" - - @property - def raw_file_names(self) -> str: - return f"{self.name}.gml" - - def download(self) -> None: - path = download_url(GreycDataset.URL + self.name + ".zip", - self.raw_dir) - extract_zip(path, self.raw_dir) - os.unlink(path) - - def process(self): - data_list = self._load_gml_data(self.raw_paths[0]) - - if self.pre_filter is not None: - data_list = [data for data in data_list if self.pre_filter(data)] - - if self.pre_transform is not None: - data_list = [self.pre_transform(data) for data in data_list] - - data, slices = self.collate(data_list) - fs.torch_save((data, slices), self.processed_paths[0]) + def __repr__(self) -> str: + name = self.name.capitalize() + if self.name in ["mao", "pah"]: + name = self.name.upper() + return 
f'{name}({len(self)})' @staticmethod def _gml_to_data(gml: str, gml_file: bool = True) -> Data: @@ -197,8 +177,28 @@ def _load_gml_data(gml: str) -> List[Data]: GreycDataset._gml_to_data(content, False) for content in gml_files ] - def __repr__(self) -> str: - name = self.name.capitalize() - if self.name in ["mao", "pah"]: - name = self.name.upper() - return f'{name}({len(self)})' + @property + def processed_file_names(self) -> str: + return "data.pt" + + @property + def raw_file_names(self) -> str: + return f"{self.name}.gml" + + def download(self) -> None: + path = download_url(GreycDataset.URL + self.name + ".zip", + self.raw_dir) + extract_zip(path, self.raw_dir) + os.unlink(path) + + def process(self): + data_list = self._load_gml_data(self.raw_paths[0]) + + if self.pre_filter is not None: + data_list = [data for data in data_list if self.pre_filter(data)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + data, slices = self.collate(data_list) + fs.torch_save((data, slices), self.processed_paths[0]) From 2350c73a4c459100c9132a2577354ac45200f4e6 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Fri, 24 Jan 2025 09:45:17 +0100 Subject: [PATCH 21/29] Fix data loading warning --- torch_geometric/datasets/greyc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 7f14f974c465..2dec562220dd 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -9,7 +9,6 @@ download_url, extract_zip, ) -from torch_geometric.io import fs class GreycDataset(InMemoryDataset): @@ -95,7 +94,7 @@ def __init__( raise ValueError(f"Dataset {self.name} not found.") super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) - self.data, self.slices = torch.load(self.processed_paths[0]) + self.load(self.processed_paths[0]) def __repr__(self) -> str: name = self.name.capitalize() @@ -200,5 +199,4 @@ def process(self): if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] - data, slices = self.collate(data_list) - fs.torch_save((data, slices), self.processed_paths[0]) + self.save(data_list, self.processed_paths[0]) From 0c1e94abc5a605b8fe562160b71267ee3a1926fc Mon Sep 17 00:00:00 2001 From: Lyam Chardey Date: Fri, 24 Jan 2025 10:24:04 +0100 Subject: [PATCH 22/29] Bugfix : y tensor dtype set according to y's class --- torch_geometric/datasets/greyc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 7f14f974c465..cc1a1b7eb9c9 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -135,8 +135,9 @@ def _gml_to_data(gml: str, gml_file: bool = True) -> Data: x, edge_index, edge_attr = [], [], [] - y = torch.tensor([g.graph["y"]], - dtype=torch.long) if "y" in g.graph else None + y = g.graph["y"] if "y" in g.graph else None + dtype = torch.float if isinstance(y, float) else torch.long + y = torch.tensor([y], dtype=dtype) if y is not None else None for _, attr in g.nodes(data=True): x.append(attr["x"]) From 35121ae59a7ab41277378a42ba148ad32b34affd Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Fri, 24 Jan 2025 10:51:33 +0100 Subject: [PATCH 23/29] Override of num_* methods --- torch_geometric/datasets/greyc.py | 53 ++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git 
a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 5b7b26cec167..01ca2dc55d13 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -13,18 +13,10 @@ class GreycDataset(InMemoryDataset): r"""Implementation of five GREYC chemistry small datasets as pytorch - geometric datasets : Alkane, Acyclic, MAO, Monoterpens and PAH. See + geometric datasets : Acyclic, Alkane, MAO, Monoterpens and PAH. See `"GREYC's Chemistry dataset" `_ for details. - .. note:: - Some datasets may not come with any node labels. - You can then either make use of the argument :obj:`use_node_attr` - to load additional continuous node attributes (if present) or provide - synthetic node features using transforms such as - :class:`torch_geometric.transforms.Constant` or - :class:`torch_geometric.transforms.OneHotDegree`. - Args: root (str): Root directory where the dataset should be saved. name (str): The `name @@ -48,7 +40,7 @@ class GreycDataset(InMemoryDataset): **STATS:** .. list-table:: - :widths: 20 10 10 10 10 10 + :widths: 20 10 10 10 10 10 10 :header-rows: 1 * - Name @@ -57,24 +49,42 @@ class GreycDataset(InMemoryDataset): - #edges - #features - #classes + - #type * - Acyclic - 183 - ~8.2 - ~14.3 - - 15 + - 7 - 148 + - Regression * - Alkane - - 150 + - 149 - ~8.9 - - ~15.8 - - 15 + - ~15.7 + - 4 - 123 + - Regression * - MAO - 68 - ~18.4 - ~39.3 - - 15 + - 7 + - 2 + - Classification + * - Monoterpens + - 302 + - ~11.0 + - ~22.2 + - 7 + - 10 + - Classification + * - PAH + - 94 + - ~20.7 + - ~48.9 + - 4 - 2 + - Classification """ URL = ('https://raw.githubusercontent.com/bgauzere/' @@ -185,6 +195,19 @@ def processed_file_names(self) -> str: def raw_file_names(self) -> str: return f"{self.name}.gml" + @property + def num_node_features(self) -> int: + print("test") + return int(self.x.shape[1]) + + @property + def num_edge_features(self) -> int: + return int(self.edge_attr.shape[1]) + + @property + def num_classes(self) -> int: + return len(self.y.unique()) + def download(self) -> None: path = download_url(GreycDataset.URL + self.name + ".zip", self.raw_dir) From 9a83dd47d3bda185719718222e7e13c69bb0e7a3 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Fri, 24 Jan 2025 10:52:08 +0100 Subject: [PATCH 24/29] Deletion of debugging print --- torch_geometric/datasets/greyc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index 01ca2dc55d13..b0222712c109 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -197,7 +197,6 @@ def raw_file_names(self) -> str: @property def num_node_features(self) -> int: - print("test") return int(self.x.shape[1]) @property From eaa33951c81ed9defdb40f539a73f8a2be4b635c Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Fri, 24 Jan 2025 11:13:24 +0100 Subject: [PATCH 25/29] Update of greyc documentation --- torch_geometric/datasets/greyc.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index b0222712c109..afd9e4048dc0 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -12,16 +12,16 @@ class GreycDataset(InMemoryDataset): - r"""Implementation of five GREYC chemistry small datasets as pytorch - geometric datasets : Acyclic, Alkane, MAO, Monoterpens and PAH. See - `"GREYC's Chemistry dataset" `_ - for details. 
+ r"""Implementation of five GREYC chemistry datasets : :obj:`Acyclic`, + :obj:`Alkane`, :obj:`MAO`, :obj:`Monoterpens` and :obj:`PAH`. See + `"GREYC's Chemistry dataset" `_ + for details. Args: root (str): Root directory where the dataset should be saved. - name (str): The `name - `_ of the - dataset. + name (str): The name (:obj:`Acyclic`, + :obj:`Alkane`, :obj:`MAO`, :obj:`Monoterpens` or :obj:`PAH`) + of the dataset. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. @@ -49,7 +49,7 @@ class GreycDataset(InMemoryDataset): - #edges - #features - #classes - - #type + - Type * - Acyclic - 183 - ~8.2 @@ -213,7 +213,7 @@ def download(self) -> None: extract_zip(path, self.raw_dir) os.unlink(path) - def process(self): + def process(self) -> None: data_list = self._load_gml_data(self.raw_paths[0]) if self.pre_filter is not None: From f683e145816cfe8375e2ca75755a479b56f54c02 Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Fri, 24 Jan 2025 11:22:44 +0100 Subject: [PATCH 26/29] Update of GreycDataset documentation --- torch_geometric/datasets/greyc.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py index afd9e4048dc0..b2136cdca758 100644 --- a/torch_geometric/datasets/greyc.py +++ b/torch_geometric/datasets/greyc.py @@ -12,10 +12,9 @@ class GreycDataset(InMemoryDataset): - r"""Implementation of five GREYC chemistry datasets : :obj:`Acyclic`, - :obj:`Alkane`, :obj:`MAO`, :obj:`Monoterpens` and :obj:`PAH`. See - `"GREYC's Chemistry dataset" `_ - for details. + r"""Implementation of five `GREYC's Chemistry datasets + `_ : :obj:`Acyclic`, + :obj:`Alkane`, :obj:`MAO`, :obj:`Monoterpens` and :obj:`PAH`. Args: root (str): Root directory where the dataset should be saved. From c6441d724afa0e33bd0911de0089e9b2286222cc Mon Sep 17 00:00:00 2001 From: Thomas Bauer Date: Fri, 24 Jan 2025 11:35:19 +0100 Subject: [PATCH 27/29] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73a5a3d2a6fd..11fafb8e2907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added

+- Added five molecular datasets implemented in `GreycDataset`
 - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
 - Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
 - Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))

From 75f99362c54e9b144701f3073429106cf29439a4 Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Fri, 24 Jan 2025 11:57:48 +0100
Subject: [PATCH 28/29] Fix mypy assignment errors

---
 torch_geometric/datasets/greyc.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/torch_geometric/datasets/greyc.py b/torch_geometric/datasets/greyc.py
index b2136cdca758..e5529b39589c 100644
--- a/torch_geometric/datasets/greyc.py
+++ b/torch_geometric/datasets/greyc.py
@@ -137,29 +137,29 @@ def _gml_to_data(gml: str, gml_file: bool = True) -> Data:
         if gml_file:
             if not os.path.exists(gml):
                 raise FileNotFoundError(f"File `{gml}` does not exist")
-            g: nx.Graph = nx.read_gml(gml)
+            g = nx.read_gml(gml)
         else:
-            g: nx.Graph = nx.parse_gml(gml)
+            g = nx.parse_gml(gml)

-        x, edge_index, edge_attr = [], [], []
+        x_l, edge_index_l, edge_attr_l = [], [], []

         y = g.graph["y"] if "y" in g.graph else None
         dtype = torch.float if isinstance(y, float) else torch.long
         y = torch.tensor([y], dtype=dtype) if y is not None else None

         for _, attr in g.nodes(data=True):
-            x.append(attr["x"])
+            x_l.append(attr["x"])

         for u, v, attr in g.edges(data=True):
-            edge_index.append([int(u), int(v)])
-            edge_index.append([int(v), int(u)])
-            edge_attr.append(attr["edge_attr"])
-            edge_attr.append(attr["edge_attr"])
+            edge_index_l.append([int(u), int(v)])
+            edge_index_l.append([int(v), int(u)])
+            edge_attr_l.append(attr["edge_attr"])
+            edge_attr_l.append(attr["edge_attr"])

-        x = torch.tensor(x, dtype=torch.float)
-        edge_index = torch.tensor(edge_index,
+        x = torch.tensor(x_l, dtype=torch.float)
+        edge_index = torch.tensor(edge_index_l,
                                   dtype=torch.long).t().contiguous()
-        edge_attr = torch.tensor(edge_attr, dtype=torch.long)
+        edge_attr = torch.tensor(edge_attr_l, dtype=torch.long)

         return Data(x=x, edge_attr=edge_attr, edge_index=edge_index, y=y)

From 2b7c620971a459c4e70e9d1d49318169f3075d2b Mon Sep 17 00:00:00 2001
From: Thomas Bauer
Date: Fri, 24 Jan 2025 12:04:55 +0100
Subject: [PATCH 29/29] Update CHANGELOG.md with pull-request link

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 11fafb8e2907..a627e42bce9b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,7 @@

 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-- Added five molecular datasets implemented in `GreycDataset`
+- Added five molecular datasets implemented in `GreycDataset` ([#9977](https://github.com/pyg-team/pytorch_geometric/pull/9977))
 - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
 - Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
 - Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
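For reviewers, a minimal end-to-end sketch of the class this series adds, as it stands after patch 29. It assumes the branch is installed (so `GreycDataset` is importable from `torch_geometric.datasets`) and the hosted GML archives are reachable; the `root` path is arbitrary and the commented values are taken from the docstring's STATS table rather than verified output.

    from torch_geometric.datasets import GreycDataset

    # First run downloads `mao.zip`, extracts the GML file into `raw/`
    # and caches the collated tensors under `processed/data.pt`.
    dataset = GreycDataset(root="data/GREYC", name="MAO")

    print(dataset)       # MAO(68), via the __repr__ added in patch 10
    data = dataset[0]    # a single molecular graph as a `Data` object
    print(data.x.shape)  # [num_atoms, num_features] node feature matrix
    print(data.edge_index.shape, data.edge_attr.shape, data.y)

Note that graph-level targets follow the dtype switch introduced in patch 22: float-valued `y` entries in the GML files load as `torch.float` (regression) and integer class labels as `torch.long` (classification).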