Skip to content

Commit

Permalink
feat: Added support for async watcher callbacks #340 (#341)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanasecucliciu authored Feb 9, 2024
1 parent e409434 commit c04d832
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 12 deletions.
49 changes: 37 additions & 12 deletions casbin/async_internal_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect

from casbin.core_enforcer import CoreEnforcer
from casbin.model import Model, FunctionMap
Expand Down Expand Up @@ -105,8 +106,12 @@ async def save_policy(self):
await self.adapter.save_policy(self.model)

if self.watcher:
if callable(getattr(self.watcher, "update_for_save_policy", None)):
self.watcher.update_for_save_policy(self.model)
update_for_save_policy = getattr(self.watcher, "update_for_save_policy", None)
if callable(update_for_save_policy):
if inspect.iscoroutinefunction(update_for_save_policy):
await update_for_save_policy(self.model)
else:
update_for_save_policy(self.model)
else:
self.watcher.update()

Expand All @@ -122,8 +127,12 @@ async def _add_policy(self, sec, ptype, rule):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_add_policy", None)):
self.watcher.update_for_add_policy(sec, ptype, rule)
update_for_add_policy = getattr(self.watcher, "update_for_add_policy", None)
if callable(update_for_add_policy):
if inspect.iscoroutinefunction(update_for_add_policy):
await update_for_add_policy(sec, ptype, rule)
else:
update_for_add_policy(sec, ptype, rule)
else:
self.watcher.update()

Expand All @@ -144,8 +153,12 @@ async def _add_policies(self, sec, ptype, rules):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_add_policies", None)):
self.watcher.update_for_add_policies(sec, ptype, rules)
update_for_add_policies = getattr(self.watcher, "update_for_add_policies", None)
if callable(update_for_add_policies):
if inspect.iscoroutinefunction(update_for_add_policies):
await update_for_add_policies(sec, ptype, rules)
else:
update_for_add_policies(sec, ptype, rules)
else:
self.watcher.update()

Expand Down Expand Up @@ -224,8 +237,12 @@ async def _remove_policy(self, sec, ptype, rule):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_policy", None)):
self.watcher.update_for_remove_policy(sec, ptype, rule)
update_for_remove_policy = getattr(self.watcher, "update_for_remove_policy", None)
if callable(update_for_remove_policy):
if inspect.iscoroutinefunction(update_for_remove_policy):
await update_for_remove_policy(sec, ptype, rule)
else:
update_for_remove_policy(sec, ptype, rule)
else:
self.watcher.update()

Expand All @@ -246,8 +263,12 @@ async def _remove_policies(self, sec, ptype, rules):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_policies", None)):
self.watcher.update_for_remove_policies(sec, ptype, rules)
update_for_remove_policies = getattr(self.watcher, "update_for_remove_policies", None)
if callable(update_for_remove_policies):
if inspect.iscoroutinefunction(update_for_remove_policies):
await update_for_remove_policies(sec, ptype, rules)
else:
update_for_remove_policies(sec, ptype, rules)
else:
self.watcher.update()

Expand All @@ -265,8 +286,12 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_filtered_policy", None)):
self.watcher.update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
update_for_remove_filtered_policy = getattr(self.watcher, "update_for_remove_filtered_policy", None)
if callable(update_for_remove_filtered_policy):
if inspect.iscoroutinefunction(update_for_remove_filtered_policy):
await update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
else:
update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
else:
self.watcher.update()

Expand Down
178 changes: 178 additions & 0 deletions tests/test_watcher_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import casbin
from tests.test_enforcer import get_examples, TestCaseBase
from unittest import IsolatedAsyncioTestCase


class SampleWatcher:
Expand Down Expand Up @@ -113,6 +114,103 @@ def start_watch(self):
pass


class AsyncSampleWatcher:
def __init__(self):
self.callback = None
self.notify_message = None

async def close(self):
pass

async def set_update_callback(self, callback):
"""
sets the callback function to be called when the policy is updated
:param callable callback: callback(event)
- event: event received from the rabbitmq
:return:
"""
self.callback = callback

async def update(self, msg):
"""
update the policy
"""
self.notify_message = msg
return True

async def update_for_add_policy(self, section, ptype, *params):
"""
update for add policy
:param section: section
:param ptype: policy type
:param params: other params
:return: True if updated
"""
message = "called add policy"
return await self.update(message)

async def update_for_remove_policy(self, section, ptype, *params):
"""
update for remove policy
:param section: section
:param ptype: policy type
:param params: other params
:return: True if updated
"""
message = "called remove policy"
return await self.update(message)

async def update_for_remove_filtered_policy(self, section, ptype, field_index, *params):
"""
update for remove filtered policy
:param section: section
:param ptype: policy type
:param field_index: field index
:param params: other params
:return:
"""
message = "called remove filtered policy"
return await self.update(message)

async def update_for_save_policy(self, model: casbin.Model):
"""
update for save policy
:param model: casbin model
:return:
"""
message = "called save policy"
return await self.update(message)

async def update_for_add_policies(self, section, ptype, *params):
"""
update for add policies
:param section: section
:param ptype: policy type
:param params: other params
:return:
"""
message = "called add policies"
return await self.update(message)

async def update_for_remove_policies(self, section, ptype, *params):
"""
update for remove policies
:param section: section
:param ptype: policy type
:param params: other params
:return:
"""
message = "called remove policies"
return await self.update(message)

async def start_watch(self):
"""
starts the watch thread
:return:
"""
pass


class TestWatcherEx(TestCaseBase):
def get_enforcer(self, model=None, adapter=None):
return casbin.Enforcer(
Expand Down Expand Up @@ -187,3 +285,83 @@ def test_auto_notify_disabled(self):

e.remove_policies(rules)
self.assertEqual(w.notify_message, None)


class TestAsyncWatcherEx(IsolatedAsyncioTestCase):
def get_enforcer(self, model=None, adapter=None):
return casbin.AsyncEnforcer(
model,
adapter,
)

async def test_auto_notify_enabled(self):
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
)
await e.load_policy()

w = AsyncSampleWatcher()
e.set_watcher(w)
e.enable_auto_notify_watcher(True)

await e.save_policy()
self.assertEqual(w.notify_message, "called save policy")

await e.add_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, "called add policy")

await e.remove_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, "called remove policy")

await e.remove_filtered_policy(1, "data1")
self.assertEqual(w.notify_message, "called remove filtered policy")

rules = [
["jack", "data4", "read"],
["katy", "data4", "write"],
["leyo", "data4", "read"],
["ham", "data4", "write"],
]
await e.add_policies(rules)
self.assertEqual(w.notify_message, "called add policies")

await e.remove_policies(rules)
self.assertEqual(w.notify_message, "called remove policies")

async def test_auto_notify_disabled(self):
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
)
await e.load_policy()

w = SampleWatcher()
e.set_watcher(w)
e.enable_auto_notify_watcher(False)

await e.save_policy()
self.assertEqual(w.notify_message, "called save policy")

w.notify_message = None

await e.add_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, None)

await e.remove_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, None)

await e.remove_filtered_policy(1, "data1")
self.assertEqual(w.notify_message, None)

rules = [
["jack", "data4", "read"],
["katy", "data4", "write"],
["leyo", "data4", "read"],
["ham", "data4", "write"],
]
await e.add_policies(rules)
self.assertEqual(w.notify_message, None)

await e.remove_policies(rules)
self.assertEqual(w.notify_message, None)

0 comments on commit c04d832

Please sign in to comment.