Skip to content

Commit a99412e

Browse files
Merge pull request #496 from emmaai/master
Support multithreading
2 parents 7484828 + c4f527d commit a99412e

File tree

4 files changed

+308
-5
lines changed

4 files changed

+308
-5
lines changed

bench/large_array_vs_numpy.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#################################################################################
2+
# To mimic the scenario that computation is i/o bound and constrained by memory
3+
#
4+
# It's a much simplified version in which each chunk is computed in a loop,
5+
# and expression is evaluated in a sequence, which is not true in reality.
6+
# Nevertheless, numexpr outperforms numpy.
7+
#################################################################################
8+
"""
9+
Benchmarking Expression 1:
10+
NumPy time (threaded over 32 chunks with 2 threads): 4.612313 seconds
11+
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 0.951172 seconds
12+
numexpr speedup: 4.85x
13+
----------------------------------------
14+
Benchmarking Expression 2:
15+
NumPy time (threaded over 32 chunks with 2 threads): 23.862752 seconds
16+
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.182058 seconds
17+
numexpr speedup: 10.94x
18+
----------------------------------------
19+
Benchmarking Expression 3:
20+
NumPy time (threaded over 32 chunks with 2 threads): 20.594895 seconds
21+
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 2.927881 seconds
22+
numexpr speedup: 7.03x
23+
----------------------------------------
24+
Benchmarking Expression 4:
25+
NumPy time (threaded over 32 chunks with 2 threads): 12.834101 seconds
26+
numexpr time (threaded with re_evaluate over 32 chunks with 2 threads): 5.392480 seconds
27+
numexpr speedup: 2.38x
28+
----------------------------------------
29+
"""
30+
31+
import os
32+
33+
os.environ["NUMEXPR_NUM_THREADS"] = "16"
34+
import numpy as np
35+
import numexpr as ne
36+
import timeit
37+
import threading
38+
39+
# Benchmark configuration.
array_size = 10**8  # total elements per operand
num_runs = 10  # timeit repetitions per chunk
num_chunks = 32  # Number of chunks
num_threads = 2  # Number of threads constrained by how many chunks memory can hold

# Three random float64 operands reshaped to 2-D (10**4 rows); ~800 MB each.
a = np.random.rand(array_size).reshape(10**4, -1)
b = np.random.rand(array_size).reshape(10**4, -1)
c = np.random.rand(array_size).reshape(10**4, -1)

# BUG FIX: chunks are taken by slicing along axis 0 (rows), so the chunk size
# must be expressed in rows.  The original `array_size // num_chunks`
# (3,125,000) exceeds the 10**4 available rows, which made every chunk after
# the first an empty slice — the benchmark was mostly timing no-ops.
chunk_size = a.shape[0] // num_chunks
49+
50+
# NumPy reference implementations: one callable per benchmarked expression.
expressions_numpy = [
    lambda x, y, z: x + y * z,
    lambda x, y, z: x**2 + y**2 - 2 * x * y * np.cos(z),
    lambda x, y, z: np.sin(x) + np.log(y) * np.sqrt(z),
    lambda x, y, z: np.exp(x) + np.tan(y) - np.sinh(z),
]

# The same four expressions as numexpr source strings, index-aligned with
# ``expressions_numpy`` (variable names a/b/c are resolved via local_dict).
expressions_numexpr = [
    "a + b * c",
    "a**2 + b**2 - 2 * a * b * cos(c)",
    "sin(a) + log(b) * sqrt(c)",
    "exp(a) + tan(b) - sinh(c)",
]
63+
64+
65+
def benchmark_numpy_chunk(func, a, b, c, results, indices):
    """
    Time ``func`` on each chunk listed in ``indices``.

    Each chunk is a row-slice of ``a``/``b``/``c`` (module-level ``chunk_size``
    rows), timed over ``num_runs`` repetitions; the elapsed seconds for every
    chunk are appended to the shared ``results`` list (list.append is atomic,
    so concurrent worker threads can share it safely).
    """
    for idx in indices:
        lo, hi = idx * chunk_size, (idx + 1) * chunk_size
        elapsed = timeit.timeit(
            lambda: func(a[lo:hi], b[lo:hi], c[lo:hi]), number=num_runs
        )
        results.append(elapsed)
73+
74+
75+
def benchmark_numexpr_re_evaluate(expr, a, b, c, results, indices):
    """
    Time numexpr on each chunk listed in ``indices``.

    The first chunk a thread processes is timed with ``ne.evaluate`` (which
    compiles the expression and caches it for this thread's context); every
    subsequent chunk reuses that cached expression via ``ne.re_evaluate``.
    Per-chunk elapsed times are appended to the shared ``results`` list.
    """
    # BUG FIX: the original tested `index == 0`, which primes the evaluate
    # cache only in the thread that happens to own chunk 0.  Chunks are
    # striped across threads (`chunk_indices[j::num_threads]`), so every other
    # thread would call re_evaluate without a prior evaluate in its own
    # thread/context-local state.  Prime on each thread's *first* chunk.
    first_chunk = True
    for index in indices:
        start = index * chunk_size
        end = (index + 1) * chunk_size
        if first_chunk:
            first_chunk = False
            # Evaluate the first chunk with evaluate
            time_taken = timeit.timeit(
                lambda: ne.evaluate(
                    expr,
                    local_dict={
                        "a": a[start:end],
                        "b": b[start:end],
                        "c": c[start:end],
                    },
                ),
                number=num_runs,
            )
        else:
            # Re-evaluate subsequent chunks with re_evaluate
            time_taken = timeit.timeit(
                lambda: ne.re_evaluate(
                    local_dict={"a": a[start:end], "b": b[start:end], "c": c[start:end]}
                ),
                number=num_runs,
            )
        results.append(time_taken)
101+
102+
103+
def run_benchmark_threaded():
    """
    Benchmark every expression pair with NumPy and numexpr.

    For each expression, the chunk list is striped across ``num_threads``
    worker threads, per-chunk timings are summed, and the NumPy/numexpr
    totals plus the speedup ratio are printed.
    """
    chunk_indices = list(range(num_chunks))

    def fan_out(worker, common_args):
        # Launch num_threads workers, striping the chunks across them,
        # then block until all of them have finished.
        workers = []
        for tid in range(num_threads):
            stripe = chunk_indices[tid::num_threads]  # Distribute chunks across threads
            t = threading.Thread(target=worker, args=common_args + (stripe,))
            workers.append(t)
            t.start()
        for t in workers:
            t.join()

    for i in range(len(expressions_numpy)):
        print(f"Benchmarking Expression {i+1}:")

        results_numpy = []
        results_numexpr = []

        fan_out(
            benchmark_numpy_chunk,
            (expressions_numpy[i], a, b, c, results_numpy),
        )
        numpy_time = sum(results_numpy)
        print(
            f"NumPy time (threaded over {num_chunks} chunks with {num_threads} threads): {numpy_time:.6f} seconds"
        )

        fan_out(
            benchmark_numexpr_re_evaluate,
            (expressions_numexpr[i], a, b, c, results_numexpr),
        )
        numexpr_time = sum(results_numexpr)
        print(
            f"numexpr time (threaded with re_evaluate over {num_chunks} chunks with {num_threads} threads): {numexpr_time:.6f} seconds"
        )
        print(f"numexpr speedup: {numpy_time / numexpr_time:.2f}x")
        print("-" * 40)


if __name__ == "__main__":
    run_benchmark_threaded()

numexpr/necompiler.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
2121
from numexpr import interpreter, expressions, use_vml
22-
from numexpr.utils import CacheDict
22+
from numexpr.utils import CacheDict, ContextDict
2323

2424
# Declare a double type that does not exist in Python space
2525
double = numpy.double
@@ -776,11 +776,9 @@ def getArguments(names, local_dict=None, global_dict=None, _frame_depth: int=2):
776776
# Dictionaries for caching variable names and compiled expressions
777777
_names_cache = CacheDict(256)
778778
_numexpr_cache = CacheDict(256)
779-
_numexpr_last = {}
779+
_numexpr_last = ContextDict()
780780
evaluate_lock = threading.Lock()
781781

782-
# MAYBE: decorate this function to add attributes instead of having the
783-
# _numexpr_last dictionary?
784782
def validate(ex: str,
785783
local_dict: Optional[Dict] = None,
786784
global_dict: Optional[Dict] = None,
@@ -887,7 +885,7 @@ def validate(ex: str,
887885
compiled_ex = _numexpr_cache[numexpr_key] = NumExpr(ex, signature, sanitize=sanitize, **context)
888886
kwargs = {'out': out, 'order': order, 'casting': casting,
889887
'ex_uses_vml': ex_uses_vml}
890-
_numexpr_last = dict(ex=compiled_ex, argnames=names, kwargs=kwargs)
888+
_numexpr_last.set(ex=compiled_ex, argnames=names, kwargs=kwargs)
891889
except Exception as e:
892890
return e
893891
return None

numexpr/tests/test_numexpr.py

+72
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,7 @@ def run(self):
12011201
test.join()
12021202

12031203
def test_multithread(self):
1204+
12041205
import threading
12051206

12061207
# Running evaluate() from multiple threads shouldn't crash
@@ -1218,6 +1219,77 @@ def work(n):
12181219
for t in threads:
12191220
t.join()
12201221

1222+
def test_thread_safety(self):
    """
    Regression test: concurrent validate()/re_evaluate() pairs must not
    clobber each other's cached last-evaluated expression.

    Expected output

    When not safe (before the PR this test is committed in):
    AssertionError: Thread-0 failed: result does not match expected

    When safe (after the PR this test is committed in):
    Should pass without failure
    """
    import threading
    import time

    # All four threads rendezvous here after validate(), so each has set the
    # module's _numexpr_last state before any thread calls re_evaluate().
    barrier = threading.Barrier(4)

    # Function that each thread will run with different expressions
    def thread_function(a_value, b_value, expression, expected_result, results, index):
        validate(expression, local_dict={"a": a_value, "b": b_value})
        # Wait for all threads to reach this point
        # such that they all set _numexpr_last
        barrier.wait()

        # Simulate some work or a context switch delay
        time.sleep(0.1)

        # With a shared (unsafe) _numexpr_last this re-runs another thread's
        # expression and the comparison below fails.
        result = re_evaluate(local_dict={"a": a_value, "b": b_value})
        results[index] = np.array_equal(result, expected_result)

    def test_thread_safety_with_numexpr():
        num_threads = 4
        array_size = 1000000

        # One distinct expression per thread so cross-thread clobbering is
        # detectable by value.
        expressions = [
            "a + b",
            "a - b",
            "a * b",
            "a / b"
        ]

        # Distinct constant operands per thread: thread i uses a=i+1, b=2*(i+1).
        a_value = [np.full(array_size, i + 1) for i in range(num_threads)]
        b_value = [np.full(array_size, (i + 1) * 2) for i in range(num_threads)]

        # Expected result for each thread, matched to its expression string.
        expected_results = [
            a_value[i] + b_value[i] if expr == "a + b" else
            a_value[i] - b_value[i] if expr == "a - b" else
            a_value[i] * b_value[i] if expr == "a * b" else
            a_value[i] / b_value[i] if expr == "a / b" else None
            for i, expr in enumerate(expressions)
        ]

        results = [None] * num_threads
        threads = []

        # Create and start threads with different expressions
        for i in range(num_threads):
            thread = threading.Thread(
                target=thread_function,
                args=(a_value[i], b_value[i], expressions[i], expected_results[i], results, i)
            )
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        for i in range(num_threads):
            if not results[i]:
                self.fail(f"Thread-{i} failed: result does not match expected")

    test_thread_safety_with_numexpr()
1292+
12211293

12221294
# The worker function for the subprocess (needs to be here because Windows
12231295
# has problems pickling nested functions with the multiprocess module :-/)

numexpr/utils.py

+81
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import os
1515
import subprocess
16+
import contextvars
1617

1718
from numexpr.interpreter import _set_num_threads, _get_num_threads, MAX_THREADS
1819
from numexpr import use_vml
@@ -226,3 +227,83 @@ def __setitem__(self, key, value):
226227
super(CacheDict, self).__delitem__(k)
227228
super(CacheDict, self).__setitem__(key, value)
228229

230+
231+
class ContextDict:
    """
    A dict-like mapping stored in a :class:`contextvars.ContextVar`.

    Every thread (and every asyncio task) observes its own value of the
    variable, so concurrent callers cannot clobber one another's entries.
    Mutating operations copy the current mapping, modify the copy, and store
    it back, leaving previously-read snapshots untouched.
    """

    def __init__(self):
        # The shared default {} is never mutated: every writer copies first.
        self._context_data = contextvars.ContextVar('context_data', default={})

    def set(self, key=None, value=None, **kwargs):
        """Store ``value`` under ``key`` (when given) and/or set ``kwargs``."""
        snapshot = dict(self._context_data.get())
        if key is not None:
            snapshot[key] = value
        snapshot.update(kwargs)
        self._context_data.set(snapshot)

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` when absent."""
        return self._context_data.get().get(key, default)

    def delete(self, key):
        """Remove ``key`` if present; a missing key is silently ignored."""
        snapshot = dict(self._context_data.get())
        snapshot.pop(key, None)
        self._context_data.set(snapshot)

    def clear(self):
        """Drop every entry in the current context."""
        self._context_data.set({})

    def all(self):
        """Return the current context's underlying dict (treat as read-only)."""
        return self._context_data.get()

    def update(self, *args, **kwargs):
        """Merge a dict or iterable of key/value pairs, then ``kwargs``."""
        if len(args) > 1:
            raise TypeError(f"update() takes at most 1 positional argument ({len(args)} given)")
        snapshot = dict(self._context_data.get())
        if args:
            other = args[0]
            if isinstance(other, dict):
                snapshot.update(other)
            else:
                for k, v in other:
                    snapshot[k] = v
        snapshot.update(kwargs)
        self._context_data.set(snapshot)

    def keys(self):
        return self._context_data.get().keys()

    def values(self):
        return self._context_data.get().values()

    def items(self):
        return self._context_data.get().items()

    def __getitem__(self, key):
        # NOTE(review): deliberately mirrors the original contract — a missing
        # key yields None (via get) instead of raising KeyError as a plain
        # dict would.  Callers may rely on the None fallback; confirm before
        # tightening this to dict semantics.
        return self.get(key)

    def __setitem__(self, key, value):
        self.set(key, value)

    def __delitem__(self, key):
        self.delete(key)

    def __contains__(self, key):
        return key in self._context_data.get()

    def __len__(self):
        return len(self._context_data.get())

    def __iter__(self):
        return iter(self._context_data.get())

    def __repr__(self):
        return repr(self._context_data.get())

0 commit comments

Comments
 (0)