diff --git a/src/flint/test/test_all.py b/src/flint/test/test_all.py index aca7668e..bc66da75 100644 --- a/src/flint/test/test_all.py +++ b/src/flint/test/test_all.py @@ -1357,6 +1357,24 @@ def test_nmod(): assert str(G(3,5)) == "3" assert G(3,5).repr() == "nmod(3, 5)" + # We can compare to int and fmpz types + assert G(1, 5) == int(1) + assert G(4, 5) == int(-1) + assert G(1, 5) == flint.fmpz(1) + assert G(4, 5) == flint.fmpz(-1) + + # When the modulus matches, we can compare fmpz_mod + R = flint.fmpz_mod_ctx(5) + assert G(1, 5) == R(1) + assert G(1, 5) != R(-1) + assert G(4, 5) == R(4) + assert G(4, 5) == R(-1) + # when the modulus doesnt match, everything fails + assert G(1, 7) != R(1) + assert G(1, 7) != R(-1) + assert G(4, 7) != R(4) + assert G(4, 7) != R(-1) + def test_nmod_poly(): N = flint.nmod P = flint.nmod_poly diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index 7bcab98a..60f78b68 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -3,6 +3,7 @@ from flint.utils.typecheck cimport typecheck from flint.types.fmpq cimport any_as_fmpq from flint.types.fmpz cimport any_as_fmpz from flint.types.fmpz cimport fmpz +from flint.types.fmpz_mod cimport fmpz_mod from flint.types.fmpq cimport fmpq from flint.flintlib.flint cimport ulong @@ -66,25 +67,28 @@ cdef class nmod(flint_scalar): def modulus(self): return self.mod.n - def __richcmp__(s, t, int op): - cdef mp_limb_t v + def __richcmp__(self, other, int op): cdef bint res + if op != 2 and op != 3: raise TypeError("nmods cannot be ordered") - if typecheck(s, nmod) and typecheck(t, nmod): - res = ((s).val == (t).val) and \ - ((s).mod.n == (t).mod.n) - if op == 2: - return res - else: - return not res - elif typecheck(s, nmod) and typecheck(t, int): - res = s.val == (t % s.mod.n) - if op == 2: - return res - else: - return not res - return NotImplemented + + if typecheck(other, nmod): + res = self.val == (other).val and \ + self.mod.n == (other).mod.n + elif typecheck(other, int): + res = self.val == (other % self.mod.n) + elif typecheck(other, fmpz): + res = self.val == (int(other) % self.mod.n) + elif typecheck(other, fmpz_mod): + res = self.mod.n == (other).ctx.modulus() and \ + self.val == int(other) + else: + return NotImplemented + + if op == 2: + return res + return not res def __hash__(self): return hash((int(self.val), self.modulus))