Support "triangular-scan" for native lower/upper triangular algorithms #2667
Comments
Hello! I'm a fan of the idea! A few thoughts:
cc @wsmoses, who I think has talked about something similar to this as well. Let me know if you have any thoughts.
So, offhand, I was earlier thinking of potentially having a type or operation attribute that specifies whether the data has some known structure. For example, for matmul, if the first operand is known to be upper triangular, one could use https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html instead of GEMM. Of course, this is still a spec-level change; it would be up to whatever StableHLO lowers into to leverage the additional information (or not).
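To make the potential saving concrete, here is a minimal NumPy sketch of what knowing "the first operand is upper triangular" buys. The function name `upper_tri_matmul` is hypothetical, and this is only an illustration of the idea, not how the linked BLAS routine or XLA actually implement it: row `i` of the product only needs `a[i, i:]` and `b[i:, :]`, so the work below the diagonal is skipped.

```python
import numpy as np


def upper_tri_matmul(a, b):
    """Compute a @ b assuming `a` is upper triangular.

    Entries of `a` below the diagonal are ignored (treated as zero). Row i of
    the product only touches a[i, i:] and b[i:, :], so roughly half of the
    multiply-adds of a dense GEMM are never performed.
    """
    n = a.shape[0]
    c = np.zeros((n, b.shape[1]), dtype=np.result_type(a, b))
    for i in range(n):
        c[i, :] = a[i, i:] @ b[i:, :]
    return c
```

In a real library the saving would come from dispatching to a triangular BLAS routine rather than a Python loop; the loop here only makes the skipped work visible.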
Hi all, thanks for the detailed responses!
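My initial thoughts were the latter, resembling something along the lines of:

```python
def triangular_scan(my_func, x, init, triangle="lower"):
    n = x.shape[0]
    carry = init
    for i in range(n):
        if triangle == "lower":
            # row i of the lower triangle: columns 0..i, diagonal included
            carry = my_func(x[i, : i + 1], carry)
        else:
            # row i of the upper triangle: columns i..n-1, diagonal included
            carry = my_func(x[i, i:n], carry)
    return carry
```

This structure encompasses back/forward solves and other linear-algebraic operations that only need to inspect the upper or lower triangular part of a square matrix.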
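As a concrete instance of the pattern, here is a minimal NumPy sketch of forward substitution (solving `L y = b` for lower-triangular `L`) written as a lower-triangular scan. The names `triangular_scan` and `forward_substitution` are hypothetical, chosen for illustration; each row passed to the step function includes the diagonal.

```python
import numpy as np


def triangular_scan(my_func, x, init, triangle="lower"):
    # Mirrors the loop sketched in the comment; each row slice includes the diagonal.
    n = x.shape[0]
    carry = init
    for i in range(n):
        row = x[i, : i + 1] if triangle == "lower" else x[i, i:n]
        carry = my_func(row, carry)
    return carry


def forward_substitution(L, b):
    """Solve L @ y = b for lower-triangular L by scanning its rows."""

    def step(row, y):
        i = row.shape[0] - 1  # row i of the lower triangle has i + 1 entries
        # row[:i] holds L[i, 0..i-1]; y holds the i entries solved so far.
        y_i = (b[i] - row[:i] @ y) / row[i]
        return np.append(y, y_i)  # the carry grows by one entry per row

    return triangular_scan(step, L, np.zeros(0))
```

Because the carry `y` grows by one entry per row, this version relies on NumPy's dynamic shapes and cannot be written directly with `jax.lax.scan`, which requires the carry to keep the same shape and dtype on every iteration.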
While I would love to be able to, I'm afraid I do not have the time to contribute substantially towards this, due to my current research/mentoring/administrative duties. I apologize, as I realize it's a bit unfair for me to request a feature while contributing so little towards its realization.
Yes, absolutely. The case I've outlined above could be cast as a special case of
Request description
Hi! First, I want to emphasize how much my lab and I use JAX to dramatically improve our scientific software and algorithms, and how much we truly appreciate the incredible work that the JAX/XLA teams have done.
I'm curious whether it would be possible to define HLO/XLA primitives that target the double loops common in upper/lower triangular matrix algorithms (e.g., backsolve, Levinson-Durbin). Because shapes must stay static across loop iterations, workarounds for this style of problem usually involve masking. However, as long as the original shape is static and known, the shape of each row operation can be inferred as well.
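For comparison, here is a sketch of the masking workaround in JAX, again for forward substitution (the function name `forward_substitution_masked` is hypothetical). Every iteration of `jax.lax.fori_loop` works on a full length-`n` row and a length-`n` carry, and a mask zeroes the columns that fall outside the part of the row that matters.

```python
import jax
import jax.numpy as jnp


def forward_substitution_masked(L, b):
    """Solve L @ y = b for lower-triangular L, keeping every loop shape static."""
    n = L.shape[0]
    cols = jnp.arange(n)

    def body(i, y):
        # Take the full length-n row i, but zero out columns j >= i (the diagonal
        # and everything right of it), so every iteration has the same static shape.
        row = jnp.where(cols < i, L[i], 0.0)
        # Subtract the contribution of the already-solved entries y[0..i-1].
        y_i = (b[i] - row @ y) / L[i, i]
        return y.at[i].set(y_i)

    # The carry is a fixed-length vector; entries i..n-1 are still zero at step i.
    return jax.lax.fori_loop(0, n, body, jnp.zeros_like(b))
```

Each iteration still performs `n` multiply-adds even though row `i` has only `i` relevant off-diagonal entries; that wasted work is what a native triangular primitive could avoid. For an actual triangular solve, `jax.scipy.linalg.solve_triangular` already exists; the sketch only illustrates the general masking pattern that other triangular algorithms fall back on.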
For example, given an n x n matrix, the outer loop iterates over i = 1, ..., n, while the inner loop runs over j = 1, ..., i; the inner loop is often handled by subsetting/vectorizing the necessary computation.
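Counting inner-loop iterations shows what masking costs. The native triangular double loop performs one inner step per entry of the lower triangle, while a masked implementation with static shapes performs n inner steps for every i:

$$
\underbrace{\sum_{i=1}^{n} i}_{\text{triangular}} = \frac{n(n+1)}{2},
\qquad
\underbrace{\sum_{i=1}^{n} n}_{\text{masked}} = n^2,
\qquad
\frac{n^2}{n(n+1)/2} = \frac{2n}{n+1} \xrightarrow{\;n \to \infty\;} 2 .
$$

So masking does roughly twice the necessary arithmetic for large n (for n = 1000, 1,000,000 inner steps instead of 500,500). This counts iterations only; actual runtime also depends on vectorization and hardware.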
Is this special case something that could be supported? While it may seem niche, it covers quite a few classical algorithms in linear algebra.