File tree 1 file changed +18
-0
lines changed
1 file changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -33,3 +33,21 @@ class Example:
33
33
#: what might be interesting to use in the ID algorithm
34
34
example_queries : Optional [list [Query ]] = None
35
35
generate_data : Optional [Callable [[int , Optional [dict [Variable , float ]]], pd .DataFrame ]] = None
36
+
37
+ def generate_ate (
38
+ self ,
39
+ * ,
40
+ num_samples : int ,
41
+ treatment : Variable ,
42
+ outcome : Variable ,
43
+ treatment_0 : float = 0.0 ,
44
+ treatment_1 : float = 1.0 ,
45
+ ** kwargs ,
46
+ ) -> float :
47
+ """Calculate the ATE for a single treatment/outcome pair."""
48
+ if self .generate_data is None :
49
+ raise TypeError (f"no generation method provided in example: { self .name } " )
50
+
51
+ data_1 = self .generate_data (num_samples , {treatment : treatment_1 }, ** kwargs )
52
+ data_0 = self .generate_data (num_samples , {treatment : treatment_0 }, ** kwargs )
53
+ return data_1 .mean ()[outcome .name ] - data_0 .mean ()[outcome .name ]
You can’t perform that action at this time.
0 commit comments