diff --git a/CHANGELOG.md b/CHANGELOG.md index 18b015b..82367a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ Nothing notable unreleased. tests equality of Mapping objects not requiring them to be dicts. Similar to `assertSequenceEqual` but for mappings. +* (testing) Added a new method `absltest.assertDictContainsSubset` that + checks that a dictionary contains a subset of keys and values. Similar + to a removed method `unittest.assertDictContainsSubset` (existed until Python 3.11). + ### Fixed * (testing) Fixed an issue where the test reporter crashes with exceptions with diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py index 26a033b..7c16295 100644 --- a/absl/testing/absltest.py +++ b/absl/testing/absltest.py @@ -1656,6 +1656,25 @@ def CheckEqual(a, b): for a, b in itertools.product(group, group): CheckEqual(a, b) + def assertDictContainsSubset( + self, subset: Mapping[Any, Any], dictionary: Mapping[Any, Any], msg=None + ): + """Raises AssertionError if dictionary is not a superset of subset. + + Args: + subset: A dict, the expected subset of the `dictionary`. + dictionary: A dict, the actual value. + msg: An optional str, the associated message. + + Raises: + AssertionError: if dictionary is not a superset of subset. + """ + if not isinstance(subset, dict): + subset = dict(subset) + if not isinstance(dictionary, dict): + dictionary = dict(dictionary) + self.assertDictEqual(dictionary, {**dictionary, **subset}, msg) + def assertDictEqual(self, a, b, msg=None): """Raises AssertionError if a and b are not equal dictionaries. diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py index dbaa62a..cb8fed7 100644 --- a/absl/testing/tests/absltest_test.py +++ b/absl/testing/tests/absltest_test.py @@ -27,7 +27,7 @@ import sys import tempfile import textwrap -from typing import Optional +from typing import Any, ItemsView, Iterator, KeysView, Mapping, Optional, Type, ValuesView import unittest from absl.testing import _bazelize_command @@ -1597,6 +1597,90 @@ def test_assert_json_equal_bad_json(self): with self.assertRaises(ValueError) as error_context: self.assertJsonEqual('', '') + @parameterized.named_parameters( + dict(testcase_name='empty', subset={}, dictionary={}), + dict(testcase_name='empty_is_a_subset', subset={}, dictionary={'a': 1}), + dict( + testcase_name='equal_one_element', + subset={'a': 1}, + dictionary={'a': 1}, + ), + dict( + testcase_name='subset', subset={'a': 1}, dictionary={'a': 1, 'b': 2} + ), + dict( + testcase_name='equal_many_elements', + subset={'a': 1, 'b': 2}, + dictionary={'a': 1, 'b': 2}, + ), + ) + def test_assert_dict_contains_subset( + self, subset: Mapping[Any, Any], dictionary: Mapping[Any, Any] + ): + self.assertDictContainsSubset(subset, dictionary) + + def test_assert_dict_contains_subset_converts_to_dict(self): + class ConvertibleToDict(Mapping): + + def __init__(self, **kwargs): + self._data = kwargs + + def __getitem__(self, key: Any) -> Any: + return self._data[key] + + def __iter__(self) -> Iterator: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def keys(self) -> KeysView: + return self._data.keys() + + def values(self) -> ValuesView: + return self._data.values() + + def items(self) -> ItemsView: + return self._data.items() + + self.assertDictContainsSubset( + ConvertibleToDict(name='a', value=1), + ConvertibleToDict(name='a', value=1), + ) + + @parameterized.named_parameters( + dict(testcase_name='subset_vs_empty', subset={1: 'one'}, dictionary={}), + dict( + testcase_name='value_is_different', + subset={'a': 2}, + dictionary={'a': 1}, + ), + dict( + testcase_name='key_is_different', subset={'c': 1}, dictionary={'a': 1} + ), + dict( + testcase_name='subset_is_larger', + subset={'a': 1, 'c': 1}, + dictionary={'a': 1}, + ), + dict( + testcase_name='UnicodeDecodeError_constructing_failure_msg', + subset={'foo': ''.join(chr(i) for i in range(255))}, + dictionary={'foo': '\uFFFD'}, + ), + ) + def test_assert_dict_contains_subset_fails( + self, subset: Mapping[Any, Any], dictionary: Mapping[Any, Any] + ): + with self.assertRaises(self.failureException): + self.assertDictContainsSubset(subset, dictionary) + + def test_assert_dict_contains_subset_fails_with_msg(self): + with self.assertRaisesRegex( + AssertionError, re.compile('custom message', re.DOTALL) + ): + self.assertDictContainsSubset({'a': 1}, {'a': 2}, msg='custom message') + class GetCommandStderrTestCase(absltest.TestCase):