From 6e56e9f8cdcd4e18968a2a99a9fe364e50d77692 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Tue, 23 Jul 2024 22:37:28 +0200 Subject: [PATCH] Add `df.nx.set_properties(**kwargs)` --- nx_pandas/_patch.py | 48 +++++++++++++++++++++++++++++ nx_pandas/tests/test_df_nx_attrs.py | 36 ++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/nx_pandas/_patch.py b/nx_pandas/_patch.py index f7e5494..173b9be 100644 --- a/nx_pandas/_patch.py +++ b/nx_pandas/_patch.py @@ -99,6 +99,54 @@ def __dir__(self): attrs.remove("edge_key") return attrs + def set_properties( + self, + *, + source=None, + target=None, + edge_key=None, + is_directed=None, + is_multigraph=None, + cache_enabled=None, + ): + """Set many graph properties (i.e., ``df.nx`` attributes) at once. + + Return the original DataFrame to allow method chaining. For example:: + + >>> df = pd.read_csv("my_data.csv").nx.set_properties(is_directed=False) + + This is a bulk transaction, so either all given attributes will be updated, + or nothing will be set if there was an exception. + """ + prev = {} + cur = {} + if source is not None: + prev["_source"] = self._source + cur["source"] = source + if target is not None: + prev["_target"] = self._target + cur["target"] = target + if is_directed is not None: + prev["is_directed"] = self.is_directed + cur["is_directed"] = is_directed + if is_multigraph is not None: + prev["is_multigraph"] = self.is_multigraph + cur["is_multigraph"] = is_multigraph + if edge_key is not None: + prev["_edge_key"] = self._edge_key + cur["edge_key"] = edge_key + if cache_enabled is not None: + prev["cache_enabled"] = self.cache_enabled + cur["cache_enabled"] = cache_enabled + try: + for attr, val in cur.items(): + setattr(self, attr, val) + except Exception: + for attr, val in prev.items(): + setattr(self, attr, val) + raise + return self._df + def _attr_raise_if_invalid_graph(df, attr): try: diff --git a/nx_pandas/tests/test_df_nx_attrs.py b/nx_pandas/tests/test_df_nx_attrs.py index f38e886..29b7b4e 100644 --- a/nx_pandas/tests/test_df_nx_attrs.py +++ b/nx_pandas/tests/test_df_nx_attrs.py @@ -110,3 +110,39 @@ def test_set_attrs(df): df.nx.edge_key = None with pytest.raises(AttributeError, match="to be used as a networkx graph"): df.__networkx_backend__ + + +def test_set_properties(df): + df2 = df.nx.set_properties( + source="target", + target="source", + edge_key="foo", + is_directed=False, + is_multigraph=True, + cache_enabled=True, + ) + assert df is df2 + assert df.nx.source == "target" + assert df.nx.target == "source" + assert df.nx.edge_key == "foo" + assert df.nx.is_directed is False + assert df.nx.is_multigraph is True + assert df.nx.cache_enabled is True + with pytest.raises( + AttributeError, match="'edge_key' attribute only exists for multigraphs" + ): + df.nx.set_properties( + source="source", + target="target", + is_directed=True, + is_multigraph=False, + cache_enabled=False, + edge_key="BAD", + ) + # Unchanged + assert df.nx.source == "target" + assert df.nx.target == "source" + assert df.nx.edge_key == "foo" + assert df.nx.is_directed is False + assert df.nx.is_multigraph is True + assert df.nx.cache_enabled is True