Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve torch compatibility #46

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ cython_debug/
#.idea/
.vscode/settings.json
.vscode/launch.json

# Local
tmp/
6 changes: 2 additions & 4 deletions docs/source/03-value-tables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,7 @@
" 0xFF,\n",
" ):\n",
" print(\n",
" str_tablerow(\n",
" fi, decode_float(fi, i), show_b16_info=True, vs_width=8, vs_d=4\n",
" )\n",
" str_tablerow(fi, decode_float(fi, i), show_b16_info=True, vs_width=8, vs_d=4)\n",
" )"
]
},
Expand Down Expand Up @@ -3266,7 +3264,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "gfloat-clean",
"language": "python",
"name": "python3"
},
Expand Down
23 changes: 8 additions & 15 deletions docs/source/04-benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -34,24 +34,17 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"GFloat scalar : 6306.18 nsec (25 runs at size 10000)\n",
"GFloat vectorized, numpy arrays: 52.52 nsec (25 runs at size 1000000)\n",
"GFloat vectorized, JAX JIT : 3.04 nsec (500 runs at size 1000000)\n",
"ML_dtypes : 2.69 nsec (500 runs at size 1000000)\n"
"GFloat scalar : 7510.22 nsec (25 runs at size 10000)\n",
"GFloat vectorized, numpy arrays: 43.82 nsec (25 runs at size 1000000)\n",
"GFloat vectorized, JAX JIT : 2.69 nsec (500 runs at size 1000000)\n",
"ML_dtypes : 2.57 nsec (500 runs at size 1000000)\n"
]
}
],
Expand All @@ -61,7 +54,7 @@
"N = 1_000_000\n",
"a = np.random.rand(N)\n",
"\n",
"jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))\n",
"jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x))\n",
"ja = jnp.array(a)\n",
"jax_round_jit(ja) # Cache compilation\n",
"\n",
Expand Down Expand Up @@ -108,7 +101,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand Down
203 changes: 131 additions & 72 deletions docs/source/05-stochastic-rounding.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ dependencies = {file = ["requirements.txt"]}
optional-dependencies = {dev = {file = ["requirements-dev.txt"]}}

[tool.black]
line-length = 88
line-length = 90
fast = true

[tool.mypy]
[[tool.mypy.overrides]]
module = "mx.*"
module = ["mx.*", "array_api_compat.*", "array_api_strict.*"]
ignore_missing_imports = true

[tool.pytest.ini_options]
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ nbval
ml_dtypes
jaxlib
jax
torch
array-api-strict
airium
pandas
matplotlib
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
numpy
more_itertools
array-api-compat
4 changes: 1 addition & 3 deletions src/gfloat/decode_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from .types import FormatInfo


def decode_ndarray(
fi: FormatInfo, codes: np.ndarray, np: ModuleType = np
) -> np.ndarray:
def decode_ndarray(fi: FormatInfo, codes: np.ndarray, np: ModuleType = np) -> np.ndarray:
r"""
Vectorized version of :meth:`decode_float`

Expand Down
2 changes: 1 addition & 1 deletion src/gfloat/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def format_info_p3109(k: int, precision: int) -> FormatInfo:
ValueError: If p is not in 1..k-1
ValueError: If k is < 2
"""
if precision < 1 or precision > 7:
if precision < 1 or precision > k - 1:
raise ValueError(f"P3109 format not defined for p={precision}")

name = f"p3109_{k}p{precision}"
Expand Down
12 changes: 6 additions & 6 deletions src/gfloat/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def round_float(
case RoundMode.TowardNegative:
should_round_away = sign and delta > 0
case RoundMode.TiesToAway:
should_round_away = delta >= 0.5
should_round_away = delta + 0.5 >= 1.0
case RoundMode.TiesToEven:
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
case RoundMode.Stochastic:
Expand All @@ -113,20 +113,20 @@ def round_float(
(d - floord > 0.5) or ((d - floord == 0.5) and _isodd(floord))
)

should_round_away = d > srbits
should_round_away = d + srbits >= 2.0**srnumbits
case RoundMode.StochasticOdd:
## RTNE delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
d = floord + (
(d - floord > 0.5) or ((d - floord == 0.5) and ~_isodd(floord))
(d - floord > 0.5) or ((d - floord == 0.5) and not _isodd(floord))
)

should_round_away = d > srbits
should_round_away = d + srbits >= 2.0**srnumbits
case RoundMode.StochasticFast:
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
should_round_away = delta + (0.5 + srbits) * 2.0**-srnumbits >= 1.0
case RoundMode.StochasticFastest:
should_round_away = delta > srbits * 2.0**-srnumbits
should_round_away = delta + srbits * 2.0**-srnumbits >= 1.0

if should_round_away:
# This may increase isignificand to 2**p,
Expand Down
113 changes: 78 additions & 35 deletions src/gfloat/round_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,36 @@
from types import ModuleType
from .types import FormatInfo, RoundMode
import numpy as np
import array_api_compat


def _isodd(v: np.ndarray) -> np.ndarray:
return v & 0x1 == 1


def _ldexp(v: np.ndarray, s: np.ndarray) -> np.ndarray:
    """Compute ``v * 2.0**s`` elementwise for any Array-API-compatible arrays.

    Dispatches to the backend's native ``ldexp`` when one exists (torch, jax,
    numpy); otherwise emulates it with exponent arithmetic that stays inside
    the representable range.

    Args:
        v (float array): Mantissa values to scale.
        s (int array): Power-of-two exponents to apply.

    Returns:
        Array of the same shape/backend as ``v`` holding ``v * 2**s``.
    """
    xp = array_api_compat.array_namespace(v, s)
    if (
        array_api_compat.is_torch_array(v)
        or array_api_compat.is_jax_array(v)
        or array_api_compat.is_numpy_array(v)
    ):
        # Fast path: these backends provide ldexp directly.
        return xp.ldexp(v, s)

    # Scale away from subnormal/infinite ranges.
    # Naively computing v * 2.0**s can lose precision (2.0**s underflows to a
    # subnormal) or overflow (2.0**s is inf) even when the true product is
    # representable.  Instead pre-scale v by 2**+/-offset and fold the
    # compensating -/+offset into the exponent, so each factor stays in the
    # normal range.  offset=24 is presumably chosen to clear the float32/
    # float64 subnormal window for the exponents this library produces —
    # TODO confirm the bound against the formats in use.
    offset = 24
    vlo = (v * 2.0**+offset) * 2.0 ** xp.astype(s - offset, v.dtype)
    vhi = (v * 2.0**-offset) * 2.0 ** xp.astype(s + offset, v.dtype)
    # Small v (< 1.0) takes the scaled-up path; large v the scaled-down path.
    return xp.where(v < 1.0, vlo, vhi)


def round_ndarray(
fi: FormatInfo,
v: np.ndarray,
rnd: RoundMode = RoundMode.TiesToEven,
sat: bool = False,
srbits: Optional[np.ndarray] = None,
srnumbits: int = 0,
np: ModuleType = np,
) -> np.ndarray:
"""
Vectorized version of :meth:`round_float`.
Expand All @@ -38,8 +54,6 @@ def round_ndarray(
srbits (int array): Bits to use for stochastic rounding if rnd == Stochastic.
srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.

np (Module): May be `numpy`, `jax.numpy` or another module cloning numpy

Returns:
An array of floats which is a subset of the format's value set.

Expand All @@ -48,27 +62,42 @@ def round_ndarray(
(e.g. converting a `NaN`, or an `Inf` when the target has no
`NaN` or `Inf`, and :paramref:`sat` is false)
"""
xp = array_api_compat.array_namespace(v, srbits)

# Until https://github.com/data-apis/array-api/issues/807
xp_where = lambda a, t, f: xp.where(a, xp.asarray(t), xp.asarray(f))
xp_maximum = lambda a, b: xp.maximum(xp.asarray(a), xp.asarray(b))

p = fi.precision
bias = fi.expBias

is_negative = np.signbit(v) & fi.is_signed
absv = np.where(is_negative, -v, v)
is_negative = xp.signbit(v) & fi.is_signed
absv = xp_where(is_negative, -v, v)

finite_nonzero = ~(np.isnan(v) | np.isinf(v) | (v == 0))
finite_nonzero = ~(xp.isnan(v) | xp.isinf(v) | (v == 0))

# Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan}
absv_masked = np.where(finite_nonzero, absv, 1.0)
absv_masked = xp_where(finite_nonzero, absv, 1.0)

int_type = xp.int64 if fi.k > 8 or srnumbits > 8 else xp.int16

def to_int(x: np.ndarray) -> np.ndarray:
return xp.astype(x, int_type)

def to_float(x: np.ndarray) -> np.ndarray:
return xp.astype(x, v.dtype)

expval = np.floor(np.log2(absv_masked)).astype(int)
expval = to_int(xp.floor(xp.log2(absv_masked)))

if fi.has_subnormals:
expval = np.maximum(expval, 1 - bias)
expval = xp_maximum(expval, 1 - bias)

expval = expval - p + 1
fsignificand = np.ldexp(absv_masked, -expval)
fsignificand = _ldexp(absv_masked, -expval)

isignificand = np.floor(fsignificand).astype(np.int64)
delta = fsignificand - isignificand
floorfsignificand = xp.floor(fsignificand)
isignificand = to_int(floorfsignificand)
delta = fsignificand - floorfsignificand

if fi.precision > 1:
code_is_odd = _isodd(isignificand)
Expand All @@ -77,48 +106,62 @@ def round_ndarray(

match rnd:
case RoundMode.TowardZero:
should_round_away = np.zeros_like(delta, dtype=bool)
should_round_away = xp.zeros_like(delta, dtype=xp.bool)

case RoundMode.TowardPositive:
should_round_away = ~is_negative & (delta > 0)

case RoundMode.TowardNegative:
should_round_away = is_negative & (delta > 0)

case RoundMode.TiesToAway:
should_round_away = delta >= 0.5

case RoundMode.TiesToEven:
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)

case RoundMode.Stochastic:
assert srbits is not None
## RTNE delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
dd = d - floord
drnd = floord + (dd > 0.5) + ((dd == 0.5) & _isodd(floord))
floord = to_int(xp.floor(d))
dd = d - xp.floor(d)
should_round_away_tne = (dd > 0.5) | ((dd == 0.5) & _isodd(floord))
drnd = floord + xp.astype(should_round_away_tne, floord.dtype)

should_round_away = drnd + srbits >= 2**srnumbits

should_round_away = drnd > srbits
case RoundMode.StochasticOdd:
assert srbits is not None
## RTNO delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
dd = d - floord
drnd = floord + (dd > 0.5) + ((dd == 0.5) & ~_isodd(floord))
floord = to_int(xp.floor(d))
dd = d - xp.floor(d)
should_round_away_tno = (dd > 0.5) | ((dd == 0.5) & ~_isodd(floord))
drnd = floord + xp.astype(should_round_away_tno, floord.dtype)

should_round_away = drnd + srbits >= 2**srnumbits

should_round_away = drnd > srbits
case RoundMode.StochasticFast:
assert srbits is not None
should_round_away = delta > (2 * srbits + 1) * 2.0 ** -(1 + srnumbits)
should_round_away = (
delta + to_float(2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0
)

case RoundMode.StochasticFastest:
assert srbits is not None
should_round_away = delta > srbits * 2.0**-srnumbits
should_round_away = delta + to_float(srbits) * 2.0**-srnumbits >= 1.0

isignificand = xp_where(should_round_away, isignificand + 1, isignificand)

isignificand = np.where(should_round_away, isignificand + 1, isignificand)
fresult = _ldexp(to_float(isignificand), expval)

result = np.where(finite_nonzero, np.ldexp(isignificand, expval), absv)
result = xp_where(finite_nonzero, fresult, absv)

amax = np.where(is_negative, -fi.min, fi.max)
amax = xp_where(is_negative, -fi.min, fi.max)

if sat:
result = np.where(result > amax, amax, result)
result = xp_where(result > amax, amax, result)
else:
match rnd:
case RoundMode.TowardNegative:
Expand All @@ -128,25 +171,25 @@ def round_ndarray(
case RoundMode.TowardZero:
put_amax_at = result > amax
case _:
put_amax_at = np.zeros_like(result, dtype=bool)
put_amax_at = xp.zeros_like(result, dtype=xp.bool)

result = np.where(finite_nonzero & put_amax_at, amax, result)
result = xp_where(finite_nonzero & put_amax_at, amax, result)

# Now anything larger than amax goes to infinity or NaN
if fi.has_infs:
result = np.where(result > amax, np.inf, result)
result = xp_where(result > amax, xp.inf, result)
elif fi.num_nans > 0:
result = np.where(result > amax, np.nan, result)
result = xp_where(result > amax, xp.nan, result)
else:
if np.any(result > amax):
if xp.any(result > amax):
raise ValueError(f"No Infs or NaNs in format {fi}, and sat=False")

result = np.where(is_negative, -result, result)
result = xp_where(is_negative, -result, result)

# Make negative zeros negative if has_nz, else make them not negative.
if fi.has_nz:
result = np.where((result == 0) & is_negative, -0.0, result)
result = xp_where((result == 0) & is_negative, -0.0, result)
else:
result = np.where(result == 0, 0.0, result)
result = xp_where(result == 0, 0.0, result)

return result
2 changes: 1 addition & 1 deletion src/gfloat/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def min(self) -> float:
return -self.max
else:
assert not self.has_infs and self.num_high_nans == 0 and not self.has_nz
return -(2 ** (self.emax + 1))
return -(2.0 ** (self.emax + 1))
elif self.has_zero:
return 0.0
else:
Expand Down
Loading