GNN link #36

Open · wants to merge 5 commits into master
Empty file added examples/baseline/__init__.py
152 changes: 152 additions & 0 deletions examples/baseline/model.py
@@ -0,0 +1,152 @@
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from .nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):
def __init__(
self,
data: HeteroData,
col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
num_layers: int,
channels: int,
out_channels: int,
num_src_nodes: int,
num_dst_nodes: int,
src_entity_table: str,
dst_entity_table: str,
aggr: str,
norm: str,
# List of node types to add shallow embeddings to input
shallow_list: List[NodeType] = [],
# ID awareness
id_awareness: bool = False,
):
super().__init__()

self.src_entity_table = src_entity_table
self.dst_entity_table = dst_entity_table
self.encoder = HeteroEncoder(
channels=channels,
src_entity_table=src_entity_table,
dst_entity_table=dst_entity_table,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
node_to_col_names_dict={
node_type: data[node_type].tf.col_names_dict
for node_type in data.node_types
},
node_to_col_stats=col_stats_dict,
)
self.temporal_encoder = HeteroTemporalEncoder(
node_types=[
node_type for node_type in data.node_types
if "time" in data[node_type]
],
channels=channels,
)
self.gnn = HeteroGraphSAGE(
node_types=data.node_types,
edge_types=data.edge_types,
channels=channels,
aggr=aggr,
num_layers=num_layers,
)
self.head = MLP(
channels,
out_channels=out_channels,
norm=norm,
num_layers=1,
)
self.embedding_dict = ModuleDict({
node:
Embedding(data.num_nodes_dict[node], channels)
for node in shallow_list
})

self.id_awareness_emb = None
if id_awareness:
self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.reset_parameters()

def reset_parameters(self):
self.encoder.reset_parameters()
self.temporal_encoder.reset_parameters()
self.gnn.reset_parameters()
self.head.reset_parameters()
for embedding in self.embedding_dict.values():
torch.nn.init.normal_(embedding.weight, std=0.1)
if self.id_awareness_emb is not None:
self.id_awareness_emb.reset_parameters()

def forward(
self,
batch: HeteroData,
entity_table: NodeType,
) -> Tensor:
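        # Rows [0, seed_time.size(0)) of `entity_table` are the seed nodes of
        # this mini-batch; only their final representations are read out.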
seed_time = batch[entity_table].seed_time
#######################################################################
x_dict = self.encoder(batch.tf_dict, batch)
#######################################################################

rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

for node_type, embedding in self.embedding_dict.items():
###################################################################
if node_type not in {self.src_entity_table, self.dst_entity_table}:
x_dict[node_type] = x_dict[node_type] + embedding(
batch[node_type].n_id)
###################################################################

x_dict = self.gnn(
x_dict,
batch.edge_index_dict,
batch.num_sampled_nodes_dict,
batch.num_sampled_edges_dict,
)

return self.head(x_dict[entity_table][:seed_time.size(0)])

def forward_dst_readout(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
) -> Tensor:
if self.id_awareness_emb is None:
raise RuntimeError(
"id_awareness must be set True to use forward_dst_readout")
seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict, batch)
# Add ID-awareness to the root node
        x_dict[entity_table][:seed_time.size(0)] += (
            self.id_awareness_emb.weight)

rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

for node_type, embedding in self.embedding_dict.items():
x_dict[node_type] = x_dict[node_type] + embedding(
batch[node_type].n_id)

x_dict = self.gnn(
x_dict,
batch.edge_index_dict,
)

return self.head(x_dict[dst_table])
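
Both `Model.forward` above and the new `HeteroEncoder` in nn.py below index
shallow embeddings through the batch-local-to-global node-id map `n_id`. A
minimal, self-contained sketch of that pattern (toy sizes and illustrative
names, not part of this diff):

import torch
from torch.nn import Embedding

channels, num_nodes = 8, 100
emb = Embedding(num_nodes, channels)
x = torch.randn(5, channels)            # encoder output for 5 sampled nodes
n_id = torch.tensor([3, 42, 7, 99, 0])  # global ids of those sampled nodes
x = x + emb(n_id)                       # same lookup as in Model.forward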
228 changes: 228 additions & 0 deletions examples/baseline/nn.py
@@ -0,0 +1,228 @@
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.data import HeteroData
from torch_geometric.nn import (
HeteroConv,
LayerNorm,
PositionalEncoding,
SAGEConv,
)
from torch_geometric.typing import EdgeType, NodeType


class HeteroEncoder(torch.nn.Module):
r"""HeteroEncoder based on PyTorch Frame.

Args:
channels (int): The output channels for each node type.
src_entity_table (str): Source entity table name.
dst_entity_table (str): Destination entity table name.
num_src_nodes (int): Number of source nodes.
num_dst_nodes (int): Number of destination nodes.
        node_to_col_names_dict
            (Dict[NodeType, Dict[torch_frame.stype, List[str]]]):
            A dictionary mapping each node type to a column names dictionary
            compatible with PyTorch Frame.
        torch_frame_model_cls: Model class for PyTorch Frame. The class
            takes a :class:`TensorFrame` object as input and outputs
            :obj:`channels`-dimensional embeddings. Defaults to
            :class:`torch_frame.nn.ResNet`.
        torch_frame_model_kwargs (Dict[str, Any]): Keyword arguments for the
            :class:`torch_frame_model_cls` class. The default keyword
            arguments are specific to :class:`torch_frame.nn.ResNet` and are
            expected to change for a different
            :class:`torch_frame_model_cls`.
        default_stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
            A dictionary mapping each :obj:`torch_frame.stype` to a tuple
            specifying a :class:`torch_frame.nn.StypeEncoder` class and its
            keyword arguments :obj:`kwargs`.
"""
def __init__(
self,
channels: int,
#######################################################################
src_entity_table: str,
dst_entity_table: str,
num_src_nodes: int,
num_dst_nodes: int,
#######################################################################
node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype,
List[str]]],
node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
torch_frame_model_cls=ResNet,
torch_frame_model_kwargs: Dict[str, Any] = {
"channels": 128,
"num_layers": 4,
},
default_stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any] = {
torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
torch_frame.multicategorical: (
torch_frame.nn.MultiCategoricalEmbeddingEncoder,
{},
),
torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}),
torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
},
):
super().__init__()

#######################################################################
self.src_entity_table = src_entity_table
self.dst_entity_table = dst_entity_table
#######################################################################

self.encoders = torch.nn.ModuleDict()

for node_type in node_to_col_names_dict.keys():
###################################################################
if node_type == self.src_entity_table:
self.encoders[node_type] = nn.Embedding(
num_src_nodes, channels)
elif node_type == self.dst_entity_table:
self.encoders[node_type] = nn.Embedding(
num_dst_nodes, channels)
###################################################################
else:
stype_encoder_dict = {
stype:
default_stype_encoder_cls_kwargs[stype][0](
**default_stype_encoder_cls_kwargs[stype][1])
for stype in node_to_col_names_dict[node_type].keys()
}
torch_frame_model = torch_frame_model_cls(
**torch_frame_model_kwargs,
out_channels=channels,
col_stats=node_to_col_stats[node_type],
col_names_dict=node_to_col_names_dict[node_type],
stype_encoder_dict=stype_encoder_dict,
)
self.encoders[node_type] = torch_frame_model

def reset_parameters(self):
for node_type, encoder in self.encoders.items():
###################################################################
if node_type in {self.src_entity_table, self.dst_entity_table}:
nn.init.xavier_uniform_(encoder.weight)
###################################################################
else:
encoder.reset_parameters()

def forward(
self,
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
batch: HeteroData,
) -> Dict[NodeType, Tensor]:
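        # Entity tables get pure ID embeddings, looked up through the
        # batch-local-to-global `n_id` map; every other node type is encoded
        # from its tabular features by a PyTorch Frame model.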
x_dict = {}
for node_type, tf in tf_dict.items():
if node_type not in {self.src_entity_table, self.dst_entity_table}:
x_dict[node_type] = self.encoders[node_type](tf)
else:
###############################################################
x_dict[node_type] = self.encoders[node_type](
batch[node_type].n_id)
###############################################################
return x_dict


class HeteroTemporalEncoder(torch.nn.Module):
def __init__(self, node_types: List[NodeType], channels: int):
super().__init__()

self.encoder_dict = torch.nn.ModuleDict({
node_type:
PositionalEncoding(channels)
for node_type in node_types
})
self.lin_dict = torch.nn.ModuleDict({
node_type:
torch.nn.Linear(channels, channels)
for node_type in node_types
})

def reset_parameters(self):
for encoder in self.encoder_dict.values():
encoder.reset_parameters()
for lin in self.lin_dict.values():
lin.reset_parameters()

def forward(
self,
seed_time: Tensor,
time_dict: Dict[NodeType, Tensor],
batch_dict: Dict[NodeType, Tensor],
) -> Dict[NodeType, Tensor]:
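        # For each sampled node, encode how far its timestamp lies before the
        # seed time of the example it was sampled for.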
out_dict: Dict[NodeType, Tensor] = {}

for node_type, time in time_dict.items():
rel_time = seed_time[batch_dict[node_type]] - time
rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days.

x = self.encoder_dict[node_type](rel_time)
x = self.lin_dict[node_type](x)
out_dict[node_type] = x

return out_dict


class HeteroGraphSAGE(torch.nn.Module):
def __init__(
self,
node_types: List[NodeType],
edge_types: List[EdgeType],
channels: int,
aggr: str = "mean",
num_layers: int = 2,
):
super().__init__()

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
conv = HeteroConv(
{
edge_type: SAGEConv(
(channels, channels), channels, aggr=aggr)
for edge_type in edge_types
},
aggr="sum",
)
self.convs.append(conv)

self.norms = torch.nn.ModuleList()
for _ in range(num_layers):
norm_dict = torch.nn.ModuleDict()
for node_type in node_types:
norm_dict[node_type] = LayerNorm(channels, mode="node")
self.norms.append(norm_dict)

def reset_parameters(self):
for conv in self.convs:
conv.reset_parameters()
for norm_dict in self.norms:
for norm in norm_dict.values():
norm.reset_parameters()

def forward(
self,
x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
) -> Dict[NodeType, Tensor]:
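        # `num_sampled_nodes_dict` and `num_sampled_edges_dict` are accepted
        # to keep the call signature uniform across callers but are unused
        # here. Each layer runs a heterogeneous SAGE convolution, then a
        # per-node-type LayerNorm, then a LeakyReLU non-linearity.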
        for conv, norm_dict in zip(self.convs, self.norms):
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
###################################################################
x_dict = {
key: F.leaky_relu(x, negative_slope=0.2)
for key, x in x_dict.items()
}
###################################################################

return x_dict
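
A minimal smoke test for the two building blocks above, assuming `torch` and
`torch_geometric` are installed (node types, sizes, timestamps, and the
import path are illustrative):

import torch

from examples.baseline.nn import HeteroGraphSAGE, HeteroTemporalEncoder

channels = 16

# Relative-time features for three sampled "user" nodes and two seeds.
temporal = HeteroTemporalEncoder(node_types=["user"], channels=channels)
seed_time = torch.tensor([1_000_000, 2_000_000])  # seed timestamps (seconds)
time_dict = {"user": torch.tensor([900_000, 950_000, 1_900_000])}
batch_dict = {"user": torch.tensor([0, 0, 1])}  # seed index per sampled node
rel = temporal(seed_time, time_dict, batch_dict)  # {"user": [3, channels]}

# Two message-passing layers over a toy bipartite graph.
edge_types = [("user", "buys", "item"), ("item", "rev_buys", "user")]
gnn = HeteroGraphSAGE(["user", "item"], edge_types, channels)
x_dict = {"user": torch.randn(3, channels), "item": torch.randn(4, channels)}
edge_index_dict = {
    ("user", "buys", "item"):
    torch.stack([torch.randint(0, 3, (6, )),
                 torch.randint(0, 4, (6, ))]),
    ("item", "rev_buys", "user"):
    torch.stack([torch.randint(0, 4, (6, )),
                 torch.randint(0, 3, (6, ))]),
}
out = gnn(x_dict, edge_index_dict)  # dict of [num_nodes, channels] tensors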