Skip to content

Commit d5e2b56

Browse files
committed
Make joint context creation and coercion explicit
1 parent 9e14bcb commit d5e2b56

File tree

3 files changed

+40
-43
lines changed

3 files changed

+40
-43
lines changed

src/flint/flint_base/flint_base.pyx

+14-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ from flint.flintlib.flint cimport (
55
slong
66
)
77
from flint.flint_base.flint_context cimport thectx
8+
from flint.utils.typecheck cimport typecheck
89
cimport libc.stdlib
10+
911
from typing import Optional
1012

1113

@@ -146,6 +148,9 @@ cdef class flint_mpoly_context(flint_elem):
146148
libc.stdlib.free(self.c_names)
147149
self._init = False
148150

151+
def __str__(self):
152+
return self.__repr__()
153+
149154
def __repr__(self):
150155
return f"{self.__class__.__name__}({self.nvars()}, '{self.ordering()}', {self.names()})"
151156

@@ -196,9 +201,15 @@ cdef class flint_mpoly_context(flint_elem):
196201
)
197202

198203
@classmethod
199-
def joint_context(cls, ctxs):
200-
vars = {x: i for i, x in enumerate({var for ctx in ctxs for var in ctx.py_names})}
201-
return cls.get_context(nvars=len(vars), nametup=tuple(vars.keys())), vars
204+
def joint_context(cls, *ctxs):
205+
vars = set()
206+
for ctx in ctxs:
207+
if not typecheck(ctx, cls):
208+
raise ValueError(f"{ctx} is not a {cls}")
209+
else:
210+
for var in ctx.py_names:
211+
vars.add(var)
212+
return cls.get_context(nvars=len(vars), nametup=tuple(vars))
202213

203214

204215
cdef class flint_mpoly(flint_elem):

src/flint/types/fmpq_mpoly.pyx

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ cdef class fmpq_mpoly_ctx(flint_mpoly_context):
8181
return "deglex"
8282
if self.val.zctx.minfo.ord == ordering_t.ORD_DEGREVLEX:
8383
return "degrevlex"
84+
8485
def gen(self, slong i):
8586
"""
8687
Return the `i`th generator of the polynomial ring

src/flint/types/fmpz_mpoly.pyx

+25-40
Original file line numberDiff line numberDiff line change
@@ -227,46 +227,6 @@ cdef class fmpz_mpoly_ctx(flint_mpoly_context):
227227
return res
228228

229229

230-
def coerce_fmpz_mpolys(args):
231-
cdef:
232-
fmpz_mpoly_ctx ctx
233-
fmpz_mpoly inpoly, outpoly
234-
slong *C
235-
slong i
236-
237-
if not args:
238-
return ctx, []
239-
240-
# If all arguments are fmpz_mpolys and share the same context then nothing needs to be done
241-
if typecheck(args[0], fmpz_mpoly):
242-
ctx = (<fmpz_mpoly> args[0]).ctx
243-
if all(typecheck(args[i], fmpz_mpoly) and (<fmpz_mpoly> args[i]).ctx is ctx for i in range(1, len(args))):
244-
return ctx, list(args)
245-
246-
for i in range(len(args)):
247-
if not typecheck(args[i], fmpz_mpoly):
248-
args[i] = fmpz_mpoly(args[i])
249-
250-
ctx, vars = fmpz_mpoly_ctx.joint_context((<fmpz_mpoly>inpoly).ctx for inpoly in args)
251-
252-
out = [fmpz_mpoly.__new__(fmpz_mpoly) for _ in range(len(args))]
253-
254-
nvars = max((<fmpz_mpoly>inpoly).ctx.nvars() for inpoly in args)
255-
C = <slong *> libc.stdlib.malloc(nvars * sizeof(slong *))
256-
for inpoly, outpoly in zip(args, out):
257-
inpoly = <fmpz_mpoly>inpoly
258-
outpoly = <fmpz_mpoly>outpoly
259-
260-
init_fmpz_mpoly(outpoly, ctx)
261-
for i, var in enumerate(inpoly.ctx.py_names):
262-
C[i] = <slong>vars[var]
263-
264-
fmpz_mpoly_compose_fmpz_mpoly_gen(outpoly.val, inpoly.val, C, inpoly.ctx.val, ctx.val)
265-
266-
libc.stdlib.free(C)
267-
return ctx, out
268-
269-
270230
cdef class fmpz_mpoly(flint_mpoly):
271231
"""
272232
The *fmpz_poly* type represents sparse multivariate polynomials over
@@ -878,3 +838,28 @@ cdef class fmpz_mpoly(flint_mpoly):
878838
fmpz_set((<fmpz>c).val, fac.constant)
879839
fmpz_mpoly_factor_clear(fac, self.ctx.val)
880840
return c, res
841+
842+
def coerce_to_context(self, ctx):
843+
cdef:
844+
fmpz_mpoly outpoly
845+
slong *C
846+
slong i
847+
848+
if not typecheck(ctx, fmpz_mpoly_ctx):
849+
raise ValueError("provided context is not a fmpz_mpoly_ctx")
850+
851+
if self.ctx is ctx:
852+
return self
853+
854+
C = <slong *> libc.stdlib.malloc(self.ctx.val.minfo.nvars * sizeof(slong *))
855+
outpoly = fmpz_mpoly.__new__(fmpz_mpoly)
856+
init_fmpz_mpoly(outpoly, ctx)
857+
858+
vars = {x: i for i, x in enumerate(ctx.py_names)}
859+
for i, var in enumerate(self.ctx.py_names):
860+
C[i] = <slong>vars[var]
861+
862+
fmpz_mpoly_compose_fmpz_mpoly_gen(outpoly.val, self.val, C, self.ctx.val, (<fmpz_mpoly_ctx>ctx).val)
863+
864+
libc.stdlib.free(C)
865+
return outpoly

0 commit comments

Comments
 (0)