Skip to content

Commit 32d0db1

Browse files
authored
Merge branch 'pyg-team:master' into master
2 parents 720b65d + 1e208e2 commit 32d0db1

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

docs/source/tutorial/graph_transformer.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Graph Transformer
33

44
`Transformer <https://arxiv.org/abs/1706.03762>`_ is an effictive architecture in `natural language processing <https://arxiv.org/abs/1810.04805>`_ and `computer vision <https://arxiv.org/abs/2010.11929>`_.
55
Recently, there have been some applications(`Grover <https://arxiv.org/abs/2007.02835>`_, `GraphGPS <https://arxiv.org/abs/2205.12454>`_, etc) that combine transformers on graphs.
6-
In this tutorial, we will present how to build a graph transformer model via :pyg:`PyG`.
6+
In this tutorial, we will present how to build a graph transformer model via :pyg:`PyG`. See `<our webinar https://youtu.be/wAYryx3GjLw?si=2vB7imfenP5tUvqd>` for in-depth learning on this topic.
77

88
.. note::
99
Click `here <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py>`_ to download the full example code

torch_geometric/metrics/link_pred.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ def __init__(self, k: int) -> None:
221221
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
222222
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
223223
else:
224-
self.register_buffer('accum', torch.tensor(0.))
225-
self.register_buffer('total', torch.tensor(0))
224+
self.register_buffer('accum', torch.tensor(0.), persistent=False)
225+
self.register_buffer('total', torch.tensor(0), persistent=False)
226226

227227
def update(
228228
self,
@@ -523,10 +523,11 @@ def __init__(self, k: int, weighted: bool = False):
523523
discount = torch.arange(2, k + 2, dtype=dtype).log2()
524524

525525
self.discount: Tensor
526-
self.register_buffer('discount', discount)
526+
self.register_buffer('discount', discount, persistent=False)
527527

528528
if not weighted:
529-
self.register_buffer('idcg', cumsum(1.0 / discount))
529+
self.register_buffer('idcg', cumsum(1.0 / discount),
530+
persistent=False)
530531
else:
531532
self.idcg = None
532533

@@ -617,7 +618,7 @@ def __init__(self, k: int, num_dst_nodes: int) -> None:
617618
if WITH_TORCHMETRICS:
618619
self.add_state('mask', mask, dist_reduce_fx='max')
619620
else:
620-
self.register_buffer('mask', mask)
621+
self.register_buffer('mask', mask, persistent=False)
621622

622623
def update(
623624
self,
@@ -673,11 +674,11 @@ def __init__(self, k: int, category: Tensor) -> None:
673674
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
674675
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
675676
else:
676-
self.register_buffer('accum', torch.tensor(0.))
677-
self.register_buffer('total', torch.tensor(0))
677+
self.register_buffer('accum', torch.tensor(0.), persistent=False)
678+
self.register_buffer('total', torch.tensor(0), persistent=False)
678679

679680
self.category: Tensor
680-
self.register_buffer('category', category)
681+
self.register_buffer('category', category, persistent=False)
681682

682683
def update(
683684
self,
@@ -740,7 +741,7 @@ def __init__(
740741
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
741742
else:
742743
self.preds: List[Tensor] = []
743-
self.register_buffer('total', torch.tensor(0))
744+
self.register_buffer('total', torch.tensor(0), persistent=False)
744745

745746
def update(
746747
self,
@@ -829,11 +830,11 @@ def __init__(self, k: int, popularity: Tensor) -> None:
829830
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
830831
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
831832
else:
832-
self.register_buffer('accum', torch.tensor(0.))
833-
self.register_buffer('total', torch.tensor(0))
833+
self.register_buffer('accum', torch.tensor(0.), persistent=False)
834+
self.register_buffer('total', torch.tensor(0), persistent=False)
834835

835836
self.popularity: Tensor
836-
self.register_buffer('popularity', popularity)
837+
self.register_buffer('popularity', popularity, persistent=False)
837838

838839
def update(
839840
self,

0 commit comments

Comments
 (0)