Fix training and baseline code (#64)
Adapt `train.py` to the updated relbench package.

This is work in progress.
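
In essence, the change swaps the old rtb-style accessors for the new relbench ones. A minimal sketch of the before/after, using only names that appear in the diff below (exact signatures may differ in other relbench versions):

    # Before: tables were built through the dataset object.
    dataset = get_dataset(name=args.dataset, root="./data")
    task = dataset.tasks[args.task]
    train_table = dataset.make_train_table(args.task)

    # After: the task object owns its tables.
    dataset = get_dataset(name=args.dataset, process=True)
    task = dataset.get_task(args.task)
    train_table = task.train_table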

---------

Co-authored-by: kexinhuang12345 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Dec 12, 2023
1 parent 0433616 commit 8702a2d
Showing 15 changed files with 387 additions and 711 deletions.
244 changes: 84 additions & 160 deletions examples/baseline.py
@@ -1,174 +1,98 @@
import argparse
from typing import Dict

import numpy as np
import pandas as pd
import torch
from rtb.data import Table
from rtb.data.task import TaskType
from rtb.datasets import get_dataset
from torch import Tensor
from torchmetrics import AUROC, AveragePrecision, MeanAbsoluteError

from relbench.data import RelBenchDataset, Table
from relbench.data.task import TaskType
from relbench.datasets import get_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="relbench-forum")
parser.add_argument("--task", type=str, default="UserContributionTask")
parser.add_argument("--dataset", type=str, default="rel-stackex")
parser.add_argument("--task", type=str, default="rel-stackex-engage")
# Classification task: rel-stackex-engage
# Regression task: rel-stackex-votes
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = get_dataset(name=args.dataset, root="./data")
if args.task not in dataset.tasks:
    raise ValueError(
        f"'{args.dataset}' does not support the given task {args.task}. "
        f"Please choose the task from {list(dataset.tasks.keys())}."
    )

task = dataset.tasks[args.task]
train_table = dataset.make_train_table(args.task)
val_table = dataset.make_val_table(args.task)
test_table = dataset.make_test_table(args.task)

if task.task_type == TaskType.BINARY_CLASSIFICATION:
    metrics = {
        "AUROC": AUROC(task="binary").to(device),
        "AP": AveragePrecision(task="binary").to(device),
    }

elif task.task_type == TaskType.REGRESSION:
    metrics = {
        # MeanAbsoluteError takes no `squared` argument (that kwarg belongs
        # to MeanSquaredError), so it is constructed without arguments.
        "MAE": MeanAbsoluteError().to(device),
    }


def get_metrics(pred: Tensor, target: Tensor) -> Dict[str, float]:
    out_dict: Dict[str, float] = {}
    for metric_name, metric in metrics.items():
        metric.reset()
        metric.update(pred, target)
        out_dict[metric_name] = float(metric.compute())
    return out_dict


def global_zero(train_table: Table, pred_table: Table) -> Dict[str, float]:
    target = pred_table.df[task.target_col].astype(float).values
    target = torch.from_numpy(target)

    pred = torch.zeros_like(target)

    return get_metrics(pred, target)


def global_mean(train_table: Table, pred_table: Table) -> Dict[str, float]:
    target = pred_table.df[task.target_col].astype(float).values
    target = torch.from_numpy(target)

    pred = train_table.df[task.target_col].astype(float).values
    pred = torch.from_numpy(pred)
    pred = pred.mean().expand(target.size(0))

    return get_metrics(pred, target)


def global_median(train_table: Table, pred_table: Table) -> Dict[str, float]:
    target = pred_table.df[task.target_col].astype(float).values
    target = torch.from_numpy(target)

    pred = train_table.df[task.target_col].astype(float).values
    pred = torch.from_numpy(pred)
    pred = pred.median().expand(target.size(0))

    return get_metrics(pred, target)


def entity_mean(train_table: Table, pred_table: Table) -> Dict[str, float]:
    fkey = list(train_table.fkey_col_to_pkey_table.keys())[0]
    df = train_table.df.groupby(fkey).agg({task.target_col: "mean"})
    df = pred_table.df.merge(df, how="left", on=fkey)

    target = df[f"{task.target_col}_x"].astype(float).values
    target = torch.from_numpy(target)

    pred = df[f"{task.target_col}_y"].fillna(0).astype(float).values
    pred = torch.from_numpy(pred)

    return get_metrics(pred, target)


def entity_median(train_table: Table, pred_table: Table) -> Dict[str, float]:
    fkey = list(train_table.fkey_col_to_pkey_table.keys())[0]
    df = train_table.df.groupby(fkey).agg({task.target_col: "median"})
    df = pred_table.df.merge(df, how="left", on=fkey)

    target = df[f"{task.target_col}_x"].astype(float).values
    target = torch.from_numpy(target)

    pred = df[f"{task.target_col}_y"].fillna(0).astype(float).values
    pred = torch.from_numpy(pred)

    return get_metrics(pred, target)


def random(train_table: Table, pred_table: Table) -> Dict[str, float]:
    target = pred_table.df[task.target_col].astype(int).values
    target = torch.from_numpy(target)

    pred = torch.rand(target.size())

    return get_metrics(pred, target)


def majority(train_table: Table, pred_table: Table) -> Dict[str, float]:
    target = pred_table.df[task.target_col].astype(int).values
    target = torch.from_numpy(target)

    past_target = train_table.df[task.target_col].astype(int).values
    past_target = torch.from_numpy(past_target)

    majority_label = float(past_target.bincount().argmax())
    pred = torch.full((target.numel(),), fill_value=majority_label)

    return get_metrics(pred, target)

# TODO: remove process=True once correct data is uploaded.
dataset: RelBenchDataset = get_dataset(name=args.dataset, process=True)
task = dataset.get_task(args.task)

train_table = task.train_table
val_table = task.val_table
test_table = task.test_table


def evaluate(train_table: Table, pred_table: Table, name: str) -> Dict[str, float]:
    # The test table ships without the target column; in that case
    # `task.evaluate` is handed `None` instead of the prediction table.
    is_test = task.target_col not in pred_table.df
    if name == "global_zero":
        pred = np.zeros(len(pred_table))
    elif name == "global_mean":
        mean = train_table.df[task.target_col].astype(float).values.mean()
        pred = np.ones(len(pred_table)) * mean
    elif name == "global_median":
        median = np.median(train_table.df[task.target_col].astype(float).values)
        pred = np.ones(len(pred_table)) * median
    elif name == "entity_mean":
        fkey = list(train_table.fkey_col_to_pkey_table.keys())[0]
        df = train_table.df.groupby(fkey).agg({task.target_col: "mean"})
        df.rename(columns={task.target_col: "__target__"}, inplace=True)
        df = pred_table.df.merge(df, how="left", on=fkey)
        pred = df["__target__"].fillna(0).astype(float).values
    elif name == "entity_median":
        fkey = list(train_table.fkey_col_to_pkey_table.keys())[0]
        df = train_table.df.groupby(fkey).agg({task.target_col: "median"})
        df.rename(columns={task.target_col: "__target__"}, inplace=True)
        df = pred_table.df.merge(df, how="left", on=fkey)
        pred = df["__target__"].fillna(0).astype(float).values
    elif name == "random":
        pred = np.random.rand(len(pred_table))
    elif name == "majority":
        past_target = train_table.df[task.target_col].astype(int)
        # `mode()` returns a Series; take its first entry as the majority label.
        majority_label = int(past_target.mode().iloc[0])
        # Keep predictions as a numpy array, consistent with the other branches.
        pred = np.full(len(pred_table), fill_value=majority_label, dtype=float)
    else:
        raise ValueError(f"Unknown eval name: {name}.")
    return task.evaluate(pred, None if is_test else pred_table)


trainval_table_df = pd.concat([train_table.df, val_table.df], axis=0)
trainval_table = Table(
    df=trainval_table_df,
    fkey_col_to_pkey_table=train_table.fkey_col_to_pkey_table,
    pkey_col=train_table.pkey_col,
    time_col=train_table.time_col,
)
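# Note: this merged train+val table drives the test-split baselines below;
# the test table carries no target column, so `task.evaluate` receives the
# predictions alone (presumably scored against the held-out test targets).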

if task.task_type == TaskType.REGRESSION:
    train_metrics = global_zero(train_table, train_table)
    val_metrics = global_zero(train_table, val_table)
    print("Global Zero:")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")

    train_metrics = global_mean(train_table, train_table)
    val_metrics = global_mean(train_table, val_table)
    print("Global Mean:")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")

    train_metrics = global_median(train_table, train_table)
    val_metrics = global_median(train_table, val_table)
    print("Global Median:")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")

    train_metrics = entity_mean(train_table, train_table)
    val_metrics = entity_mean(train_table, val_table)
    print("Entity Mean:")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")

    train_metrics = entity_median(train_table, train_table)
    val_metrics = entity_median(train_table, val_table)
    print("Entity Median:")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")
    eval_name_list = [
        "global_zero",
        "global_mean",
        "global_median",
        "entity_mean",
        "entity_median",
    ]

    for name in eval_name_list:
        train_metrics = evaluate(train_table, train_table, name=name)
        val_metrics = evaluate(train_table, val_table, name=name)
        test_metrics = evaluate(trainval_table, test_table, name=name)
        print(f"{name}:")
        print(f"Train: {train_metrics}")
        print(f"Val: {val_metrics}")
        print(f"Test: {test_metrics}")

elif task.task_type == TaskType.BINARY_CLASSIFICATION:
    train_metrics = random(train_table, train_table)
    val_metrics = random(train_table, val_table)
    print("Random:")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")

    train_metrics = majority(train_table, train_table)
    val_metrics = majority(train_table, val_table)
    print("Majority:")
    print(f"Train: {train_metrics}")
    print(f"Val: {val_metrics}")
eval_name_list = ["random", "majority"]
for name in eval_name_list:
train_metrics = evaluate(train_table, train_table, name=name)
val_metrics = evaluate(train_table, val_table, name=name)
test_metrics = evaluate(trainval_table, test_table, name=name)
print(f"{name}:")
print(f"Train: {train_metrics}")
print(f"Val: {val_metrics}")
print(f"Test: {test_metrics}")