diff --git a/billiard/common.py b/billiard/common.py index eda0ffb..b0440cb 100644 --- a/billiard/common.py +++ b/billiard/common.py @@ -6,6 +6,7 @@ import signal import sys +import dill import pickle from .exceptions import RestartFreqExceeded @@ -14,6 +15,9 @@ pickle_load = pickle.load pickle_loads = pickle.loads +dill_load = dill.load +dill_loads = dill.loads + # cPickle.loads does not support buffer() objects, # but we can just create a StringIO and use load. from io import BytesIO diff --git a/billiard/pool.py b/billiard/pool.py index da4cc4e..c56794d 100644 --- a/billiard/pool.py +++ b/billiard/pool.py @@ -27,7 +27,7 @@ from . import cpu_count, get_context from . import util from .common import ( - TERM_SIGNAL, human_status, pickle_loads, reset_signals, restart_state, + TERM_SIGNAL, human_status, dill_loads, reset_signals, restart_state, ) from .compat import get_errno, mem_rss, send_offset from .einfo import ExceptionInfo @@ -441,7 +441,7 @@ def _make_recv_method(self, conn): if hasattr(conn, 'get_payload') and conn.get_payload: get_payload = conn.get_payload - def _recv(timeout, loads=pickle_loads): + def _recv(timeout, loads=dill_loads): return True, loads(get_payload()) else: def _recv(timeout): # noqa @@ -456,7 +456,7 @@ def _recv(timeout): # noqa return False, None return _recv - def _make_child_methods(self, loads=pickle_loads): + def _make_child_methods(self, loads=dill_loads): self.wait_for_job = self._make_protected_receive(self.inq) self.wait_for_syn = (self._make_protected_receive(self.synq) if self.synq else None) diff --git a/requirements/default.txt b/requirements/default.txt new file mode 100644 index 0000000..e2f3fc4 --- /dev/null +++ b/requirements/default.txt @@ -0,0 +1 @@ +dill>=0.3.8 diff --git a/setup.py b/setup.py index ac5cda5..13522d9 100644 --- a/setup.py +++ b/setup.py @@ -166,6 +166,42 @@ def _is_build_command(argv=sys.argv, cmds=('install', 'build', 'bdist')): return arg +def _strip_comments(l): + return l.split('#', 1)[0].strip() + + +def _pip_requirement(req): + if req.startswith('-r '): + _, path = req.split() + return reqs(*path.split('/')) + return [req] + + +def _reqs(*f): + return [ + _pip_requirement(r) for r in ( + _strip_comments(l) for l in open( + os.path.join(os.getcwd(), 'requirements', *f)).readlines() + ) if r] + + +def reqs(*f): + """Parse requirement file. + + Example: + reqs('default.txt') # requirements/default.txt + reqs('extras', 'redis.txt') # requirements/extras/redis.txt + Returns: + List[str]: list of requirements specified in the file. + """ + return [req for subreq in _reqs(*f) for req in subreq] + + +def install_requires(): + """Get list of requirements required for installation.""" + return reqs('default.txt') + + def run_setup(with_extensions=True): extensions = [] if with_extensions: @@ -204,6 +240,7 @@ def run_setup(with_extensions=True): maintainer=meta['maintainer'], maintainer_email=meta['contact'], url=meta['homepage'], + install_requires=install_requires(), zip_safe=False, license='BSD', python_requires='>=3.7', diff --git a/t/integration/tests/test_multiprocessing.py b/t/integration/tests/test_multiprocessing.py index c5c39a5..209b09d 100644 --- a/t/integration/tests/test_multiprocessing.py +++ b/t/integration/tests/test_multiprocessing.py @@ -13,6 +13,8 @@ import array import random import logging + +import dill from StringIO import StringIO import pytest @@ -1410,6 +1412,36 @@ def test_sendbytes(self): self.assertRaises(ValueError, a.send_bytes, msg, -1) self.assertRaises(ValueError, a.send_bytes, msg, 4, -1) + def test_sendlocals(self): + # We test sending and receiving variables (i.e. lambdas or instances of dynamically generated classes) + if self.TYPE != 'processes': + return + + a, b = self.Pipe() + + initial_lambda_function = lambda x: x + 1 + a.send_bytes(dill.dumps(obj=initial_lambda_function)) + received_lambda_function = dill.loads(b.recv_bytes()) + self.assertEqual(initial_lambda_function(0), received_lambda_function(0)) + + class ClassGenerator: + @staticmethod + def generate(generated_class_id: int) -> type: + class GeneratedClass: + class_id: int = generated_class_id + + def __init__(self, instance_id: int): + self.instance_id: int = instance_id + + def __eq__(self, other) -> bool: + return self.class_id == other.class_id and self.instance_id == other.instance_id + return GeneratedClass + generated_class: type = ClassGenerator.generate(generated_class_id=0) + initial_generated_class_instance: generated_class = generated_class(instance_id=1) + a.send_bytes(dill.dumps(obj=initial_generated_class_instance)) + received_generated_class_instance = dill.loads(b.recv_bytes()) + self.assertEqual(initial_generated_class_instance, received_generated_class_instance) + class _TestListenerClient(BaseTestCase):