Skip to content

Commit 47fca2f

Browse files
committed
Add a linalg.lu function for the LU decomposition
Only the default (partial pivoting) algorithm that is implemented in all libraries and for all devices is added here. data-apisgh-627 has details on the no-pivoting case, but it's not universally supported and the only reason to add it would be that it's more performant in some cases where users know it will be numerically stable. Such an addition can be done in the future, but it seems like a potentially large amount of work for implementers for limited gain. Closes data-apisgh-627
1 parent accf8c2 commit 47fca2f

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

spec/draft/extensions/linear_algebra_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ A conforming implementation of this ``linalg`` extension must provide and suppor
9898
eigh
9999
eigvalsh
100100
inv
101+
lu
101102
matmul
102103
matrix_norm
103104
matrix_power

src/array_api_stubs/_draft/linalg.py

+46
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,51 @@ def inv(x: array, /) -> array:
263263
"""
264264

265265

266+
def lu(x: array, /) -> Tuple[array, array, array]:
267+
"""
268+
Returns the LU decomposition of a matrix (or a stack of matrices).
269+
270+
The decomposition is:
271+
272+
.. math:: x = PLU
273+
274+
where :math:`P` is a permutation matrix, :math:`L` lower triangular with unit
275+
diagonal elements, and :math:`U` upper triangular.
276+
277+
Parameters
278+
----------
279+
x : array
280+
input array having shape ``(..., M, N)`` and whose innermost two
281+
dimensions form ``MxN`` matrices. Should have a floating-point data
282+
type.
283+
284+
Returns
285+
-------
286+
out: Tuple[array, array, array]
287+
a namedtuple ``(P, L, U)`` whose
288+
289+
- first element must have the field name ``P`` and must be an array
290+
of shape ``(M, M)``.
291+
- second element must have the field name ``L`` and must be an array
292+
of shape ``(M, K)``, where ``K == min(M, N)``.
293+
- third element must have the field name ``U`` and must be an array
294+
of shape ``(K, N)``.
295+
296+
Notes
297+
-----
298+
A correct decomposition of the prescribed shape must always be returned.
299+
This can be achieved by the implementer using an LU decomposition with
300+
partial pivoting algorithm.
301+
302+
Note that the LU decomposition is usually not unique, hence different
303+
implementations may return different numerical values for the same input
304+
values.
305+
306+
.. versionchanged:: 2023.12
307+
308+
"""
309+
310+
266311
def matmul(x1: array, x2: array, /) -> array:
267312
"""Alias for :func:`~array_api.matmul`."""
268313

@@ -832,6 +877,7 @@ def vector_norm(
832877
"eigh",
833878
"eigvalsh",
834879
"inv",
880+
"lu",
835881
"matmul",
836882
"matrix_norm",
837883
"matrix_power",

0 commit comments

Comments
 (0)