Add --include_label_tables argument to optionally include (time-censored) labels as features in the db #272

Open · wants to merge 10 commits into base: main
56 changes: 52 additions & 4 deletions examples/gnn_node.py
@@ -7,6 +7,7 @@
 from typing import Dict

 import numpy as np
+import pandas as pd
 import torch
 from model import Model
 from text_embedder import GloveTextEmbedding
@@ -17,11 +18,11 @@
 from torch_geometric.seed import seed_everything
 from tqdm import tqdm

-from relbench.base import Dataset, EntityTask, TaskType
+from relbench.base import Dataset, EntityTask, Table, TaskType
 from relbench.datasets import get_dataset
 from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
 from relbench.modeling.utils import get_stype_proposal
-from relbench.tasks import get_task
+from relbench.tasks import get_task, get_task_names

 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", type=str, default="rel-event")
@@ -37,6 +38,14 @@
parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument(
"--include_label_tables",
type=str,
default="none",
help="Optionally include labels as autoregressive features with \
appropriate time censoring. One of 'all', \
'task_only', or 'none'.",
)
parser.add_argument(
"--cache_dir",
type=str,
@@ -67,13 +76,52 @@
 with open(stypes_cache_path, "w") as f:
     json.dump(col_to_stype_dict, f, indent=2, default=str)

+if args.include_label_tables == "all":
+    tasks_to_add = get_task_names(args.dataset)
+elif args.include_label_tables == "task_only":
+    tasks_to_add = [args.task]
+else:
+    tasks_to_add = []
+
+db = dataset.get_db()
+# add (time-censored) label tables to the db
+for task_name in tasks_to_add:
+    t = get_task(args.dataset, task_name)
+    if not isinstance(t, EntityTask):
+        continue
+    labels_table_name = f"{task_name}_labels"
+    label_df = pd.concat(
+        [
+            t.get_table("train").df,
+            t.get_table("val").df,
+            # test set not included b/c labels are not revealed
+        ]
+    )
+    # time-censoring labels: we add timedelta to the time column to ensure that
+    # the labels become available at the appropriate time (i.e. no leakage)
+    label_df[t.time_col] = label_df[t.time_col] + t.timedelta
+    db.table_dict[labels_table_name] = Table(
+        df=label_df,
+        fkey_col_to_pkey_table={t.entity_col: t.entity_table},
+        pkey_col=None,
+        time_col=t.time_col,
+    )
+    col_to_stype_dict[labels_table_name] = {
+        t.entity_col: stype.numerical,
+        t.time_col: stype.timestamp,
+        t.target_col: stype.numerical,
+    }
+
+cache_name = (
+    args.include_label_tables if args.include_label_tables != "task_only" else args.task
+)
 data, col_stats_dict = make_pkey_fkey_graph(
-    dataset.get_db(),
+    db,
     col_to_stype_dict=col_to_stype_dict,
     text_embedder_cfg=TextEmbedderConfig(
         text_embedder=GloveTextEmbedding(device=device), batch_size=256
     ),
-    cache_dir=f"{args.cache_dir}/{args.dataset}/materialized",
+    cache_dir=f"{args.cache_dir}/{args.dataset}_{cache_name}/materialized",
 )

 clamp_min, clamp_max = None, None
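Aside: the comment in the hunk above describes the time-censoring rule (a label with seed time T is only made available at T + timedelta). Below is a minimal, self-contained sketch of that shift on a toy pandas DataFrame; the column names, entity ids, and the 30-day horizon are made up for illustration and do not come from any RelBench task.

import pandas as pd

# Hypothetical prediction horizon of a task (illustrative only).
timedelta = pd.Timedelta(days=30)

# Toy label table: each row's timestamp is the seed time of the label window.
label_df = pd.DataFrame(
    {
        "timestamp": pd.to_datetime(["2020-01-01", "2020-02-01"]),
        "entity_id": [7, 7],
        "target": [0.0, 1.0],
    }
)

# A label describing the window [T, T + timedelta) is only fully known at
# T + timedelta, so shifting the time column by timedelta makes the temporal
# sampler treat it as available from that point on (no leakage).
label_df["timestamp"] = label_df["timestamp"] + timedelta
print(label_df["timestamp"].tolist())
# [Timestamp('2020-01-31 00:00:00'), Timestamp('2020-03-02 00:00:00')]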
2 changes: 1 addition & 1 deletion examples/lightgbm_node.py
@@ -110,7 +110,7 @@
     ),
 )
 path = Path(
-    f"{args.cache_dir}/{args.dataset}/tasks/{args.task}/materialized/node_train.pt"
+    f"{args.cache_dir}/{args.dataset}/tasks/{args.task}/materialized/node_train_{args.use_ar_label=}.pt"
 )
 path.parent.mkdir(parents=True, exist_ok=True)
 train_dataset = train_dataset.materialize(path=path)
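Note on the lightgbm_node.py change: the new cache path uses Python's f-string "=" (debug) specifier, so the flag's name and value both end up in the filename and caches for runs with and without autoregressive labels stay separate. A small sketch of the expansion is below; treating use_ar_label as a boolean flag is an assumption, since its definition sits outside the shown hunk.

from argparse import Namespace

# Stand-in for parsed CLI args; use_ar_label being a bool is assumed here.
args = Namespace(use_ar_label=True)

# The "=" specifier expands to the expression text plus the repr of its value.
print(f"node_train_{args.use_ar_label=}.pt")
# node_train_args.use_ar_label=True.pt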