Skip to content

Commit

Permalink
fix: field_index is incorrect in RBAC with domains mode (#345) (#346)
Browse files Browse the repository at this point in the history
* fix: field_index is incorrect in RBAC with domains mode (#345)

* chore: replace field name with constant in ManagementEnforcer
  • Loading branch information
truc0 authored Jun 11, 2024
1 parent 846cf24 commit 9f6a379
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
1 change: 1 addition & 0 deletions casbin/constant/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

# Index constants
ACTION_INDEX = "act"
DOMAIN_INDEX = "dom"
SUBJECT_INDEX = "sub"
OBJECT_INDEX = "obj"
Expand Down
10 changes: 7 additions & 3 deletions casbin/management_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from casbin.internal_enforcer import InternalEnforcer
from casbin.model.policy_op import PolicyOp
from casbin.constant.constants import ACTION_INDEX, SUBJECT_INDEX, OBJECT_INDEX


class ManagementEnforcer(InternalEnforcer):
Expand All @@ -27,23 +28,26 @@ def get_all_subjects(self):

def get_all_named_subjects(self, ptype):
"""gets the list of subjects that show up in the current named policy."""
return self.model.get_values_for_field_in_policy("p", ptype, 0)
field_index = self.model.get_field_index(ptype, SUBJECT_INDEX)
return self.model.get_values_for_field_in_policy("p", ptype, field_index)

def get_all_objects(self):
"""gets the list of objects that show up in the current policy."""
return self.get_all_named_objects("p")

def get_all_named_objects(self, ptype):
"""gets the list of objects that show up in the current named policy."""
return self.model.get_values_for_field_in_policy("p", ptype, 1)
field_index = self.model.get_field_index(ptype, OBJECT_INDEX)
return self.model.get_values_for_field_in_policy("p", ptype, field_index)

def get_all_actions(self):
"""gets the list of actions that show up in the current policy."""
return self.get_all_named_actions("p")

def get_all_named_actions(self, ptype):
"""gets the list of actions that show up in the current named policy."""
return self.model.get_values_for_field_in_policy("p", ptype, 2)
field_index = self.model.get_field_index(ptype, ACTION_INDEX)
return self.model.get_values_for_field_in_policy("p", ptype, field_index)

def get_all_roles(self):
"""gets the list of roles that show up in the current named policy."""
Expand Down
12 changes: 12 additions & 0 deletions tests/test_management_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def test_get_list(self):
self.assertEqual(e.get_all_actions(), ["read", "write"])
self.assertEqual(e.get_all_roles(), ["data2_admin"])

def test_get_list_with_domains(self):
e = self.get_enforcer(
get_examples("rbac_with_domains_model.conf"),
get_examples("rbac_with_domains_policy.csv"),
# True,
)

self.assertEqual(e.get_all_subjects(), ["admin"])
self.assertEqual(e.get_all_objects(), ["data1", "data2"])
self.assertEqual(e.get_all_actions(), ["read", "write"])
self.assertEqual(e.get_all_roles(), ["admin"])

def test_get_policy_api(self):
e = self.get_enforcer(
get_examples("rbac_model.conf"),
Expand Down

0 comments on commit 9f6a379

Please sign in to comment.