Skip to content

Commit

Permalink
feat: add start of python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Jan 2, 2025
1 parent c6a4376 commit 8827fe7
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 28 deletions.
115 changes: 87 additions & 28 deletions python/bindings/lib.rs
Original file line number Diff line number Diff line change
@@ -1,69 +1,128 @@
// SPDX-FileCopyrightText: 2024 Alexandru Fikl <[email protected]>
// SPDX-License-Identifier: CC0-1.0

use num::complex::Complex64;
use num::complex::{c64, Complex32, Complex64};

use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1};
use numpy::{IntoPyArray, PyReadonlyArrayDyn};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::PyComplex;

use mittagleffler::{GarrappaMittagLeffler, MittagLeffler};

#[pyclass]
pub struct GarrappaMittagLeffler {
inner: mittagleffler::GarrappaMittagLeffler,
#[pyo3(name = "GarrappaMittagLeffler")]
pub struct PyGarrappaMittagLeffler {
inner: GarrappaMittagLeffler,
}

#[pymethods]
impl GarrappaMittagLeffler {
impl PyGarrappaMittagLeffler {
#[new]
#[pyo3(signature = (eps=None))]
#[pyo3(signature = (*, eps=None))]
pub fn new(eps: Option<f64>) -> Self {
GarrappaMittagLeffler {
inner: mittagleffler::GarrappaMittagLeffler::new(eps),
PyGarrappaMittagLeffler {
inner: GarrappaMittagLeffler::new(eps),
}
}

pub fn evaluate(&self, z: Complex64, alpha: f64, beta: f64) -> Option<Complex64> {
self.inner.evaluate(z, alpha, beta)
}

#[setter]
#[setter(eps)]
pub fn set_eps(&mut self, eps: f64) {
self.inner.eps = eps;
}

#[getter]
#[getter(eps)]
pub fn get_eps(&self) -> f64 {
self.inner.eps
}
}

fn mittag_leffler_always_c64(z: &Complex64, alpha: f64, beta: f64) -> Complex64 {
match z.mittag_leffler(alpha, beta) {
Some(value) => value,
None => Complex64 {
re: f64::NAN,
im: f64::NAN,
},
}
}

#[pyfunction]
pub fn mittag_leffler<'py>(
py: Python<'py>,
z: PyReadonlyArray1<Complex64>,
z: Bound<'py, PyAny>,
alpha: f64,
beta: f64,
) -> Bound<'py, PyArray1<Complex64>> {
let z = z.as_array();
let z: Vec<Complex64> = z
.iter()
.map(
|z_i| match mittagleffler::MittagLeffler::mittag_leffler(z_i, alpha, beta) {
Some(value) => value,
None => Complex64 {
re: f64::NAN,
im: f64::NAN,
},
},
)
.collect();

z.into_pyarray(py)
) -> PyResult<PyObject> {
println!("z: {:?} ({:?})", z, z.get_type());

if let Ok(ary) = z.extract::<Complex64>() {
let result = mittag_leffler_always_c64(&ary, alpha, beta);
return Ok(PyComplex::from_doubles(py, result.re, result.im).into());
}

if let Ok(ary) = z.extract::<PyReadonlyArrayDyn<f32>>() {
let ary = ary
.as_array()
.map(|x| mittag_leffler_always_c64(&c64(*x as f64, 0.0), alpha, beta));
#[allow(deprecated)]
return Ok(ary.into_pyarray(py).into_py(py));
}

if let Ok(ary) = z.extract::<PyReadonlyArrayDyn<f64>>() {
let ary = ary
.as_array()
.map(|x| mittag_leffler_always_c64(&c64(*x, 0.0), alpha, beta));
#[allow(deprecated)]
return Ok(ary.into_pyarray(py).into_py(py));
}

if let Ok(ary) = z.extract::<PyReadonlyArrayDyn<Complex32>>() {
let ary = ary
.as_array()
.map(|x| mittag_leffler_always_c64(&c64(x.re, x.im), alpha, beta));
#[allow(deprecated)]
return Ok(ary.into_pyarray(py).into_py(py));
}

if let Ok(ary) = z.extract::<PyReadonlyArrayDyn<Complex64>>() {
let ary = ary
.as_array()
.map(|x| mittag_leffler_always_c64(x, alpha, beta));
#[allow(deprecated)]
return Ok(ary.into_pyarray(py).into_py(py));
}

if let Ok(ary) = z.extract::<PyReadonlyArrayDyn<i32>>() {
let ary = ary
.as_array()
.map(|x| mittag_leffler_always_c64(&c64(*x as f64, 0.0), alpha, beta));
#[allow(deprecated)]
return Ok(ary.into_pyarray(py).into_py(py));
}

if let Ok(ary) = z.extract::<PyReadonlyArrayDyn<i64>>() {
let ary = ary
.as_array()
.map(|x| mittag_leffler_always_c64(&c64(*x as f64, 0.0), alpha, beta));
#[allow(deprecated)]
return Ok(ary.into_pyarray(py).into_py(py));
}

Err(PyTypeError::new_err(format!(
"Input 'z' has unsupported type {}",
z.get_type()
)))
}

#[pymodule]
#[pyo3(name = "_bindings")]
fn _bindings(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<GarrappaMittagLeffler>()?;
m.add_class::<PyGarrappaMittagLeffler>()?;
m.add_function(wrap_pyfunction!(mittag_leffler, m)?)?;

Ok(())
Expand Down
23 changes: 23 additions & 0 deletions python/src/pymittagleffler/_bindings.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: 2024 Alexandru Fikl <[email protected]>
# SPDX-License-Identifier: CC0-1.0

from __future__ import annotations

from typing import Any, overload

import numpy as np

class GarrappaMittagLeffler:
@property
def eps(self) -> float: ...
def __init__(self, *, eps: float | None = None) -> None: ...
def evaluate(self, z: complex, alpha: float, beta: float) -> complex | None: ...

@overload
def mittag_leffler(
z: int | float | complex | np.generic, alpha: float, beta: float
) -> complex: ...
@overload
def mittag_leffler(
z: np.ndarray[tuple[int, ...], np.dtype[Any]], alpha: float, beta: float
) -> np.ndarray[tuple[int, ...], np.dtype[np.complex128]]: ...
62 changes: 62 additions & 0 deletions python/tests/test_mittag_leffler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: 2024 Alexandru Fikl <[email protected]>
# SPDX-License-Identifier: CC0-1.0

from __future__ import annotations

from typing import Any

import numpy as np
import numpy.linalg as la
import pytest


def test_mittag_leffler_scalar() -> None:
from pymittagleffler import mittag_leffler

alpha = beta = 1.0

for cls in (bool, int, float, np.float32, np.float64):
z = cls(1)
_ = mittag_leffler(z, alpha, beta)
print(f"({type(z)}, {type(_)}): {_} {np.exp(z + 0j)}")

for cls in (complex, np.complex64, np.complex128):
z = cls(1.0, 1.0)
_ = mittag_leffler(z, alpha, beta)
print(f"({type(z)}, {type(_)}): {_} {np.exp(z + 0j)}")

with pytest.raises(TypeError):
z = "1.0"
_ = mittag_leffler(z, alpha, beta)


@pytest.mark.parametrize("etype", [np.int32, np.float32, np.float64, np.complex64, np.complex128])
def test_mittag_leffler_vector(etype: Any) -> None:
from pymittagleffler import mittag_leffler

dtype = np.dtype(etype)
rng = np.random.default_rng()
if issubclass(etype, np.integer):
z = rng.integers(0, 10, size=128, dtype=dtype)
elif issubclass(etype, np.floating):
z = rng.random(size=128, dtype=dtype)
elif issubclass(etype, np.complexfloating):
rtype = dtype.type(1.0).real.dtype
z = rng.random(size=128, dtype=rtype) + 1j * rng.random(size=128, dtype=rtype)
else:
raise TypeError(f"Unsupported dtype: {dtype}")

alpha = beta = 1.0
result = mittag_leffler(z, alpha, beta)
ref = np.exp(z)

assert la.norm(result - ref) < 1.0e-4 * la.norm(ref)


if __name__ == "__main__":
import sys

if len(sys.argv) > 1:
exec(sys.argv[1])
else:
pytest.main([__file__])

0 comments on commit 8827fe7

Please sign in to comment.