From c72d116f7aad232275c0c694cd3e189bc3f5e39c Mon Sep 17 00:00:00 2001 From: Andrei Ivanov Date: Mon, 10 Feb 2025 13:03:13 -0800 Subject: [PATCH] Restoring the TensorAttr.fully_specify method. --- torch_geometric/data/feature_store.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 6674bf42ed4b..8ec6fcde1d21 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -74,6 +74,13 @@ def is_fully_specified(self) -> bool: r"""Whether the :obj:`TensorAttr` has no unset fields.""" return all([self.is_set(key) for key in self.__dataclass_fields__]) + def fully_specify(self) -> 'TensorAttr': + r"""Sets all :obj:`UNSET` fields to :obj:`None`.""" + for key in self.__dataclass_fields__: + if not self.is_set(key): + setattr(self, key, None) + return self + def update(self, attr: 'TensorAttr') -> 'TensorAttr': r"""Updates an :class:`TensorAttr` with set attributes from another :class:`TensorAttr`. @@ -473,7 +480,9 @@ def __setitem__(self, key: TensorAttr, value: FeatureTensorType): # CastMixin will handle the case of key being a tuple or TensorAttr # object: key = self._tensor_attr_cls.cast(key) - assert key.is_fully_specified() + # We need to fully-specify the key for __setitem__ as it does not make + # sense to work with a view here: + key.fully_specify() self.put_tensor(value, key) def __getitem__(self, key: TensorAttr) -> Any: