4
4
5
5
import os
6
6
import pytest
7
+ import sys
8
+ try :
9
+ from cStringIO import StringIO
10
+ except ImportError :
11
+ from io import StringIO
7
12
8
13
import ray
9
14
from ray import tune
10
15
from ray .rllib import _register_all
11
16
from ray .tune import commands
12
17
13
18
19
+ class Capturing ():
20
+ def __enter__ (self ):
21
+ self ._stdout = sys .stdout
22
+ sys .stdout = self ._stringio = StringIO ()
23
+ self .captured = []
24
+ return self
25
+
26
+ def __exit__ (self , * args ):
27
+ self .captured .extend (self ._stringio .getvalue ().splitlines ())
28
+ del self ._stringio # free up some memory
29
+ sys .stdout = self ._stdout
30
+
31
+
14
32
@pytest .fixture
15
33
def start_ray ():
16
34
ray .init ()
@@ -19,48 +37,46 @@ def start_ray():
19
37
ray .shutdown ()
20
38
21
39
22
- def test_ls (start_ray , capsys , tmpdir ):
40
+ def test_ls (start_ray , tmpdir ):
23
41
"""This test captures output of list_trials."""
24
42
experiment_name = "test_ls"
25
43
experiment_path = os .path .join (str (tmpdir ), experiment_name )
26
44
num_samples = 2
27
- with capsys .disabled ():
28
- tune .run_experiments ({
29
- experiment_name : {
30
- "run" : "__fake" ,
31
- "stop" : {
32
- "training_iteration" : 1
33
- },
34
- "num_samples" : num_samples ,
35
- "local_dir" : str (tmpdir )
36
- }
37
- })
45
+ tune .run_experiments ({
46
+ experiment_name : {
47
+ "run" : "__fake" ,
48
+ "stop" : {
49
+ "training_iteration" : 1
50
+ },
51
+ "num_samples" : num_samples ,
52
+ "local_dir" : str (tmpdir )
53
+ }
54
+ })
38
55
39
- commands . list_trials ( experiment_path , info_keys = ( "status" , ))
40
- captured = capsys . readouterr (). out . strip ( )
41
- lines = captured . split ( " \n " )
56
+ with Capturing () as output :
57
+ commands . list_trials ( experiment_path , info_keys = ( "status" , ) )
58
+ lines = output . captured
42
59
assert sum ("TERMINATED" in line for line in lines ) == num_samples
43
60
44
61
45
- def test_lsx (start_ray , capsys , tmpdir ):
62
+ def test_lsx (start_ray , tmpdir ):
46
63
"""This test captures output of list_experiments."""
47
64
project_path = str (tmpdir )
48
65
num_experiments = 3
49
66
for i in range (num_experiments ):
50
67
experiment_name = "test_lsx{}" .format (i )
51
- with capsys .disabled ():
52
- tune .run_experiments ({
53
- experiment_name : {
54
- "run" : "__fake" ,
55
- "stop" : {
56
- "training_iteration" : 1
57
- },
58
- "num_samples" : 1 ,
59
- "local_dir" : project_path
60
- }
61
- })
68
+ tune .run_experiments ({
69
+ experiment_name : {
70
+ "run" : "__fake" ,
71
+ "stop" : {
72
+ "training_iteration" : 1
73
+ },
74
+ "num_samples" : 1 ,
75
+ "local_dir" : project_path
76
+ }
77
+ })
62
78
63
- commands . list_experiments ( project_path , info_keys = ( "total_trials" , ))
64
- captured = capsys . readouterr (). out . strip ( )
65
- lines = captured . split ( " \n " )
79
+ with Capturing () as output :
80
+ commands . list_experiments ( project_path , info_keys = ( "total_trials" , ) )
81
+ lines = output . captured
66
82
assert sum ("1" in line for line in lines ) >= 3
0 commit comments