diff --git a/cfr/__init__.py b/cfr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cfr/cfr_net.py b/cfr/cfr_net.py new file mode 100644 index 0000000..9805614 --- /dev/null +++ b/cfr/cfr_net.py @@ -0,0 +1,277 @@ +import tensorflow as tf +import numpy as np + +from util import * + +class cfr_net(object): + """ + cfr_net implements the counterfactual regression neural network + by F. Johansson, U. Shalit and D. Sontag: https://arxiv.org/abs/1606.03976 + + This file contains the class cfr_net as well as helper functions. + The network is implemented as a tensorflow graph. The class constructor + creates an object containing relevant TF nodes as member variables. + """ + + def __init__(self, x, t, y_ , p_t, FLAGS, r_alpha, r_lambda, do_in, do_out, dims): + self.variables = {} + self.wd_loss = 0 + + if FLAGS.nonlin.lower() == 'elu': + self.nonlin = tf.nn.elu + else: + self.nonlin = tf.nn.relu + + self._build_graph(x, t, y_ , p_t, FLAGS, r_alpha, r_lambda, do_in, do_out, dims) + + def _add_variable(self, var, name): + ''' Adds variables to the internal track-keeper ''' + basename = name + i = 0 + while name in self.variables: + name = '%s_%d' % (basename, i) #@TODO: not consistent with TF internally if changed + i += 1 + + self.variables[name] = var + + def _create_variable(self, var, name): + ''' Create and adds variables to the internal track-keeper ''' + + var = tf.Variable(var, name=name) + self._add_variable(var, name) + return var + + def _create_variable_with_weight_decay(self, initializer, name, wd): + ''' Create and adds variables to the internal track-keeper + and adds it to the list of weight decayed variables ''' + var = self._create_variable(initializer, name) + self.wd_loss += wd*tf.nn.l2_loss(var) + return var + + def _build_graph(self, x, t, y_ , p_t, FLAGS, r_alpha, r_lambda, do_in, do_out, dims): + """ + Constructs a TensorFlow subgraph for counterfactual regression. 
+        Sets the following member variables (to TF nodes):
+
+        self.output         The output prediction "y"
+        self.tot_loss       The total objective to minimize
+        self.imb_loss       The imbalance term of the objective
+        self.pred_loss      The prediction term of the objective
+        self.weights_in     The input/representation layer weights
+        self.weights_out    The output/post-representation layer weights
+        self.weights_pred   The (linear) prediction layer weights
+        self.h_rep          The layer of the penalized representation
+        """
+
+        self.x = x
+        self.t = t
+        self.y_ = y_
+        self.p_t = p_t
+        self.r_alpha = r_alpha
+        self.r_lambda = r_lambda
+        self.do_in = do_in
+        self.do_out = do_out
+
+        dim_input = dims[0]
+        dim_in = dims[1]
+        dim_out = dims[2]
+
+        weights_in = []; biases_in = []
+
+        if FLAGS.n_in == 0 or (FLAGS.n_in == 1 and FLAGS.varsel):
+            dim_in = dim_input
+        if FLAGS.n_out == 0:
+            if FLAGS.split_output == False:
+                dim_out = dim_in+1
+            else:
+                dim_out = dim_in
+
+        if FLAGS.batch_norm:
+            bn_biases = []
+            bn_scales = []
+
+        ''' Construct input/representation layers '''
+        h_in = [x]
+        for i in range(0, FLAGS.n_in):
+            if i==0:
+                ''' If using variable selection, first layer is just rescaling'''
+                if FLAGS.varsel:
+                    weights_in.append(tf.Variable(1.0/dim_input*tf.ones([dim_input])))
+                else:
+                    weights_in.append(tf.Variable(tf.random_normal([dim_input, dim_in], stddev=FLAGS.weight_init/np.sqrt(dim_input))))
+            else:
+                weights_in.append(tf.Variable(tf.random_normal([dim_in,dim_in], stddev=FLAGS.weight_init/np.sqrt(dim_in))))
+
+            ''' If using variable selection, first layer is just rescaling'''
+            if FLAGS.varsel and i==0:
+                biases_in.append([])
+                h_in.append(tf.mul(h_in[i],weights_in[i]))
+            else:
+                biases_in.append(tf.Variable(tf.zeros([1,dim_in])))
+                z = tf.matmul(h_in[i], weights_in[i]) + biases_in[i]
+
+                if FLAGS.batch_norm:
+                    batch_mean, batch_var = tf.nn.moments(z, [0])
+
+                    if FLAGS.normalization == 'bn_fixed':
+                        z = tf.nn.batch_normalization(z, batch_mean, batch_var, 0, 1, 1e-3)
+                    else:
+                        bn_biases.append(tf.Variable(tf.zeros([dim_in])))
+                        bn_scales.append(tf.Variable(tf.ones([dim_in])))
+                        z = tf.nn.batch_normalization(z, batch_mean, batch_var, bn_biases[-1], bn_scales[-1], 1e-3)
+
+                h_in.append(self.nonlin(z))
+                h_in[i+1] = tf.nn.dropout(h_in[i+1], do_in)
+
+        h_rep = h_in[len(h_in)-1]
+
+        if FLAGS.normalization == 'divide':
+            h_rep_norm = h_rep / safe_sqrt(tf.reduce_sum(tf.square(h_rep), axis=1, keep_dims=True))
+        else:
+            h_rep_norm = 1.0*h_rep
+
+        ''' Construct output layers '''
+        y, weights_out, weights_pred = self._build_output_graph(h_rep_norm, t, dim_in, dim_out, do_out, FLAGS)
+
+        ''' Compute sample reweighting '''
+        if FLAGS.reweight_sample:
+            w_t = t/(2*p_t)
+            w_c = (1-t)/(2*(1-p_t))
+            sample_weight = w_t + w_c
+        else:
+            sample_weight = 1.0
+
+        self.sample_weight = sample_weight
+
+        ''' Construct factual loss function '''
+        if FLAGS.loss == 'l1':
+            risk = tf.reduce_mean(sample_weight*tf.abs(y_-y))
+            pred_error = tf.reduce_mean(tf.abs(y_-y))
+        elif FLAGS.loss == 'log':
+            y = 0.995/(1.0+tf.exp(-y)) + 0.0025
+            res = y_*tf.log(y) + (1.0-y_)*tf.log(1.0-y)
+
+            risk = -tf.reduce_mean(sample_weight*res)
+            pred_error = -tf.reduce_mean(res)
+        else:
+            risk = tf.reduce_mean(sample_weight*tf.square(y_ - y))
+            pred_error = tf.sqrt(tf.reduce_mean(tf.square(y_ - y)))
+
+        ''' Regularization '''
+        if FLAGS.p_lambda>0 and FLAGS.rep_weight_decay:
+            for i in range(0, FLAGS.n_in):
+                if not (FLAGS.varsel and i==0): # No penalty on W in variable selection
+                    self.wd_loss += tf.nn.l2_loss(weights_in[i])
+
+        ''' Imbalance error '''
+        if FLAGS.use_p_correction:
p_ipm = self.p_t + else: + p_ipm = 0.5 + + if FLAGS.imb_fun == 'mmd2_rbf': + imb_dist = mmd2_rbf(h_rep_norm,t,p_ipm,FLAGS.rbf_sigma) + imb_error = r_alpha*imb_dist + elif FLAGS.imb_fun == 'mmd2_lin': + imb_dist = mmd2_lin(h_rep_norm,t,p_ipm) + imb_error = r_alpha*mmd2_lin(h_rep_norm,t,p_ipm) + elif FLAGS.imb_fun == 'mmd_rbf': + imb_dist = tf.abs(mmd2_rbf(h_rep_norm,t,p_ipm,FLAGS.rbf_sigma)) + imb_error = safe_sqrt(tf.square(r_alpha)*imb_dist) + elif FLAGS.imb_fun == 'mmd_lin': + imb_dist = mmd2_lin(h_rep_norm,t,p_ipm) + imb_error = safe_sqrt(tf.square(r_alpha)*imb_dist) + elif FLAGS.imb_fun == 'wass': + imb_dist, imb_mat = wasserstein(h_rep_norm,t,p_ipm,lam=FLAGS.wass_lambda,its=FLAGS.wass_iterations,sq=False,backpropT=FLAGS.wass_bpt) + imb_error = r_alpha * imb_dist + self.imb_mat = imb_mat # FOR DEBUG + elif FLAGS.imb_fun == 'wass2': + imb_dist, imb_mat = wasserstein(h_rep_norm,t,p_ipm,lam=FLAGS.wass_lambda,its=FLAGS.wass_iterations,sq=True,backpropT=FLAGS.wass_bpt) + imb_error = r_alpha * imb_dist + self.imb_mat = imb_mat # FOR DEBUG + else: + imb_dist = lindisc(h_rep_norm,p_ipm,t) + imb_error = r_alpha * imb_dist + + ''' Total error ''' + tot_error = risk + + if FLAGS.p_alpha>0: + tot_error = tot_error + imb_error + + if FLAGS.p_lambda>0: + tot_error = tot_error + r_lambda*self.wd_loss + + + if FLAGS.varsel: + self.w_proj = tf.placeholder("float", shape=[dim_input], name='w_proj') + self.projection = weights_in[0].assign(self.w_proj) + + self.output = y + self.tot_loss = tot_error + self.imb_loss = imb_error + self.imb_dist = imb_dist + self.pred_loss = pred_error + self.weights_in = weights_in + self.weights_out = weights_out + self.weights_pred = weights_pred + self.h_rep = h_rep + self.h_rep_norm = h_rep_norm + + def _build_output(self, h_input, dim_in, dim_out, do_out, FLAGS): + h_out = [h_input] + dims = [dim_in] + ([dim_out]*FLAGS.n_out) + + weights_out = []; biases_out = [] + + for i in range(0, FLAGS.n_out): + wo = self._create_variable_with_weight_decay( + tf.random_normal([dims[i], dims[i+1]], + stddev=FLAGS.weight_init/np.sqrt(dims[i])), + 'w_out_%d' % i, 1.0) + weights_out.append(wo) + + biases_out.append(tf.Variable(tf.zeros([1,dim_out]))) + z = tf.matmul(h_out[i], weights_out[i]) + biases_out[i] + # No batch norm on output because p_cf != p_f + + h_out.append(self.nonlin(z)) + h_out[i+1] = tf.nn.dropout(h_out[i+1], do_out) + + weights_pred = self._create_variable(tf.random_normal([dim_out,1], + stddev=FLAGS.weight_init/np.sqrt(dim_out)), 'w_pred') + bias_pred = self._create_variable(tf.zeros([1]), 'b_pred') + + if FLAGS.varsel or FLAGS.n_out == 0: + self.wd_loss += tf.nn.l2_loss(tf.slice(weights_pred,[0,0],[dim_out-1,1])) #don't penalize treatment coefficient + else: + self.wd_loss += tf.nn.l2_loss(weights_pred) + + ''' Construct linear classifier ''' + h_pred = h_out[-1] + y = tf.matmul(h_pred, weights_pred)+bias_pred + + return y, weights_out, weights_pred + + def _build_output_graph(self, rep, t, dim_in, dim_out, do_out, FLAGS): + ''' Construct output/regression layers ''' + + if FLAGS.split_output: + + i0 = tf.to_int32(tf.where(t < 1)[:,0]) + i1 = tf.to_int32(tf.where(t > 0)[:,0]) + + rep0 = tf.gather(rep, i0) + rep1 = tf.gather(rep, i1) + + y0, weights_out0, weights_pred0 = self._build_output(rep0, dim_in, dim_out, do_out, FLAGS) + y1, weights_out1, weights_pred1 = self._build_output(rep1, dim_in, dim_out, do_out, FLAGS) + + y = tf.dynamic_stitch([i0, i1], [y0, y1]) + weights_out = weights_out0 + weights_out1 + weights_pred = weights_pred0 + weights_pred1 + else: + 
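+            # split_output is off: append the treatment indicator t to the
+            # representation and fit a single output head on dim_in+1 features.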
h_input = tf.concat(1,[rep, t]) + y, weights_out, weights_pred = self._build_output(h_input, dim_in+1, dim_out, do_out, FLAGS) + + return y, weights_out, weights_pred diff --git a/cfr/evaluation.py b/cfr/evaluation.py new file mode 100644 index 0000000..e0e1ecb --- /dev/null +++ b/cfr/evaluation.py @@ -0,0 +1,402 @@ +import numpy as np +import os + +from logger import Logger as Log +from loader import * + +POL_CURVE_RES = 40 + +class NaNException(Exception): + pass + + +def policy_range(n, res=10): + step = int(float(n)/float(res)) + n_range = range(0,int(n+1),step) + if not n_range[-1] == n: + n_range.append(n) + + # To make sure every curve is same length. Incurs a small error if res high. + # Only occurs if number of units considered differs. + # For example if resampling validation sets (with different number of + # units in the randomized sub-population) + + while len(n_range) > res: + k = np.random.randint(len(n_range)-2)+1 + del n_range[k] + + return n_range + +def policy_val(t, yf, eff_pred, compute_policy_curve=False): + """ Computes the value of the policy defined by predicted effect """ + + if np.any(np.isnan(eff_pred)): + return np.nan, np.nan + + policy = eff_pred>0 + treat_overlap = (policy==t)*(t>0) + control_overlap = (policy==t)*(t<1) + + if np.sum(treat_overlap)==0: + treat_value = 0 + else: + treat_value = np.mean(yf[treat_overlap]) + + if np.sum(control_overlap)==0: + control_value = 0 + else: + control_value = np.mean(yf[control_overlap]) + + pit = np.mean(policy) + policy_value = pit*treat_value + (1-pit)*control_value + + policy_curve = [] + + if compute_policy_curve: + n = t.shape[0] + I_sort = np.argsort(-eff_pred) + + n_range = policy_range(n, POL_CURVE_RES) + + for i in n_range: + I = I_sort[0:i] + + policy_i = 0*policy + policy_i[I] = 1 + pit_i = np.mean(policy_i) + + treat_overlap = (policy_i>0)*(t>0) + control_overlap = (policy_i<1)*(t<1) + + if np.sum(treat_overlap)==0: + treat_value = 0 + else: + treat_value = np.mean(yf[treat_overlap]) + + if np.sum(control_overlap)==0: + control_value = 0 + else: + control_value = np.mean(yf[control_overlap]) + + policy_curve.append(pit_i*treat_value + (1-pit_i)*control_value) + + return policy_value, policy_curve + +def pdist2(X,Y): + """ Computes the squared Euclidean distance between all pairs x in X, y in Y """ + C = -2*X.dot(Y.T) + nx = np.sum(np.square(X),1,keepdims=True) + ny = np.sum(np.square(Y),1,keepdims=True) + D = (C + ny.T) + nx + + return np.sqrt(D + 1e-8) + +def cf_nn(x, t): + It = np.array(np.where(t==1))[0,:] + Ic = np.array(np.where(t==0))[0,:] + + x_c = x[Ic,:] + x_t = x[It,:] + + D = pdist2(x_c, x_t) + + nn_t = Ic[np.argmin(D,0)] + nn_c = It[np.argmin(D,1)] + + return nn_t, nn_c + +def pehe_nn(yf_p, ycf_p, y, x, t, nn_t=None, nn_c=None): + if nn_t is None or nn_c is None: + nn_t, nn_c = cf_nn(x,t) + + It = np.array(np.where(t==1))[0,:] + Ic = np.array(np.where(t==0))[0,:] + + ycf_t = 1.0*y[nn_t] + eff_nn_t = ycf_t - 1.0*y[It] + eff_pred_t = ycf_p[It] - yf_p[It] + + eff_pred = eff_pred_t + eff_nn = eff_nn_t + + ''' + ycf_c = 1.0*y[nn_c] + eff_nn_c = ycf_c - 1.0*y[Ic] + eff_pred_c = ycf_p[Ic] - yf_p[Ic] + + eff_pred = np.vstack((eff_pred_t, eff_pred_c)) + eff_nn = np.vstack((eff_nn_t, eff_nn_c)) + ''' + + pehe_nn = np.sqrt(np.mean(np.square(eff_pred - eff_nn))) + + return pehe_nn + +def evaluate_bin_att(predictions, data, i_exp, I_subset=None, + compute_policy_curve=False, nn_t=None, nn_c=None): + + x = data['x'][:,:,i_exp] + t = data['t'][:,i_exp] + e = data['e'][:,i_exp] + yf = data['yf'][:,i_exp] + 
yf_p = predictions[:,0] + ycf_p = predictions[:,1] + + att = np.mean(yf[t>0]) - np.mean(yf[(1-t+e)>1]) + + if not I_subset is None: + x = x[I_subset,:] + t = t[I_subset] + e = e[I_subset] + yf_p = yf_p[I_subset] + ycf_p = ycf_p[I_subset] + yf = yf[I_subset] + + yf_p_b = 1.0*(yf_p>0.5) + ycf_p_b = 1.0*(ycf_p>0.5) + + if np.any(np.isnan(yf_p)) or np.any(np.isnan(ycf_p)): + raise NaNException('NaN encountered') + + #IMPORTANT: NOT USING BINARIZATION FOR EFFECT, ONLY FOR CLASSIFICATION! + + eff_pred = ycf_p - yf_p; + eff_pred[t>0] = -eff_pred[t>0]; + + ate_pred = np.mean(eff_pred[e>0]) + atc_pred = np.mean(eff_pred[(1-t+e)>1]) + + att_pred = np.mean(eff_pred[(t+e)>1]) + bias_att = att_pred - att + + err_fact = np.mean(np.abs(yf_p_b-yf)) + + p1t = np.mean(yf[t>0]) + p1t_p = np.mean(yf_p[t>0]) + + lpr = np.log(p1t / p1t_p + 0.001) + + policy_value, policy_curve = \ + policy_val(t[e>0], yf[e>0], eff_pred[e>0], compute_policy_curve) + + pehe_appr = pehe_nn(yf_p, ycf_p, yf, x, t, nn_t, nn_c) + + return {'ate_pred': ate_pred, 'att_pred': att_pred, + 'bias_att': bias_att, 'atc_pred': atc_pred, + 'err_fact': err_fact, 'lpr': lpr, + 'policy_value': policy_value, 'policy_risk': 1-policy_value, + 'policy_curve': policy_curve, 'pehe_nn': pehe_appr} + +def evaluate_cont_ate(predictions, data, i_exp, I_subset=None, + compute_policy_curve=False, nn_t=None, nn_c=None): + + x = data['x'][:,:,i_exp] + t = data['t'][:,i_exp] + yf = data['yf'][:,i_exp] + ycf = data['ycf'][:,i_exp] + mu0 = data['mu0'][:,i_exp] + mu1 = data['mu1'][:,i_exp] + yf_p = predictions[:,0] + ycf_p = predictions[:,1] + + if not I_subset is None: + x = x[I_subset,] + t = t[I_subset] + yf_p = yf_p[I_subset] + ycf_p = ycf_p[I_subset] + yf = yf[I_subset] + ycf = ycf[I_subset] + mu0 = mu0[I_subset] + mu1 = mu1[I_subset] + + eff = mu1-mu0 + + rmse_fact = np.sqrt(np.mean(np.square(yf_p-yf))) + rmse_cfact = np.sqrt(np.mean(np.square(ycf_p-ycf))) + + eff_pred = ycf_p - yf_p; + eff_pred[t>0] = -eff_pred[t>0]; + + ite_pred = ycf_p - yf + ite_pred[t>0] = -ite_pred[t>0] + rmse_ite = np.sqrt(np.mean(np.square(ite_pred-eff))) + + ate_pred = np.mean(eff_pred) + bias_ate = ate_pred-np.mean(eff) + + att_pred = np.mean(eff_pred[t>0]) + bias_att = att_pred - np.mean(eff[t>0]) + + atc_pred = np.mean(eff_pred[t<1]) + bias_atc = atc_pred - np.mean(eff[t<1]) + + pehe = np.sqrt(np.mean(np.square(eff_pred-eff))) + + pehe_appr = pehe_nn(yf_p, ycf_p, yf, x, t, nn_t, nn_c) + + # @TODO: Not clear what this is for continuous data + #policy_value, policy_curve = policy_val(t, yf, eff_pred, compute_policy_curve) + + return {'ate_pred': ate_pred, 'att_pred': att_pred, + 'atc_pred': atc_pred, 'bias_ate': bias_ate, + 'bias_att': bias_att, 'bias_atc': bias_atc, + 'rmse_fact': rmse_fact, 'rmse_cfact': rmse_cfact, + 'pehe': pehe, 'rmse_ite': rmse_ite, 'pehe_nn': pehe_appr} + #'policy_value': policy_value, 'policy_curve': policy_curve} + +def evaluate_result(result, data, validation=False, + multiple_exps=False, binary=False): + + predictions = result['pred'] + + if validation: + I_valid = result['val'] + + n_units, _, n_rep, n_outputs = predictions.shape + + #@TODO: Should depend on parameter + compute_policy_curve = True + + eval_results = [] + #Loop over output_times + for i_out in range(n_outputs): + eval_results_out = [] + + if not multiple_exps and not validation: + nn_t, nn_c = cf_nn(data['x'][:,:,0], data['t'][:,0]) + + + #Loop over repeated experiments + for i_rep in range(n_rep): + + if validation: + I_valid_rep = I_valid[i_rep,:] + else: + I_valid_rep = None + + if 
multiple_exps: + i_exp = i_rep + if validation: + nn_t, nn_c = cf_nn(data['x'][I_valid_rep,:,i_exp], data['t'][I_valid_rep,i_exp]) + else: + nn_t, nn_c = cf_nn(data['x'][:,:,i_exp], data['t'][:,i_exp]) + else: + i_exp = 0 + + if validation and not multiple_exps: + nn_t, nn_c = cf_nn(data['x'][I_valid_rep,:,i_exp], data['t'][I_valid_rep,i_exp]) + + if binary: + eval_result = evaluate_bin_att(predictions[:,:,i_rep,i_out], + data, i_exp, I_valid_rep, compute_policy_curve, nn_t=nn_t, nn_c=nn_c) + else: + eval_result = evaluate_cont_ate(predictions[:,:,i_rep,i_out], + data, i_exp, I_valid_rep, compute_policy_curve, nn_t=nn_t, nn_c=nn_c) + + eval_results_out.append(eval_result) + + eval_results.append(eval_results_out) + + # Reformat into dict + eval_dict = {} + keys = eval_results[0][0].keys() + for k in keys: + arr = [[eval_results[i][j][k] for i in range(n_outputs)] for j in range(n_rep)] + v = np.array([[eval_results[i][j][k] for i in range(n_outputs)] for j in range(n_rep)]) + eval_dict[k] = v + + # Gather loss + # Shape [times, types, reps] + # Types: obj_loss, f_error, cf_error, imb_err, valid_f_error, valid_imb, valid_obj + if 'loss' in result.keys() and result['loss'].shape[1]>=6: + losses = result['loss'] + n_loss_outputs = losses.shape[0] + + if validation: + objective = np.array([losses[(n_loss_outputs*i)/n_outputs,6,:] for i in range(n_outputs)]).T + else: + objective = np.array([losses[(n_loss_outputs*i)/n_outputs,0,:] for i in range(n_outputs)]).T + + eval_dict['objective'] = objective + + return eval_dict + +def evaluate(output_dir, data_path_train, data_path_test=None, binary=False): + + print '\nEvaluating experiment %s...' % output_dir + + # Load results for all configurations + results = load_results(output_dir) + + if len(results) == 0: + raise Exception('No finished results found.') + + # Separate configuration files + configs = [r['config'] for r in results] + + # Test whether multiple experiments (different data) + multiple_exps = (configs[0]['experiments'] > 1) + if Log.VERBOSE and multiple_exps: + print 'Multiple data (experiments) detected' + + # Load training data + if Log.VERBOSE: + print 'Loading TRAINING data %s...' % data_path_train + data_train = load_data(data_path_train) + + # Load test data + if data_path_test is not None: + if Log.VERBOSE: + print 'Loading TEST data %s...' % data_path_test + data_test = load_data(data_path_test) + else: + data_test = None + + + + # Evaluate all results + eval_results = [] + configs_out = [] + i = 0 + if Log.VERBOSE: + print 'Evaluating result (out of %d): ' % len(results) + for result in results: + if Log.VERBOSE: + print 'Evaluating %d...' % (i+1) + + try: + eval_train = evaluate_result(result['train'], data_train, + validation=False, multiple_exps=multiple_exps, binary=binary) + + eval_valid = evaluate_result(result['train'], data_train, + validation=True, multiple_exps=multiple_exps, binary=binary) + + if data_test is not None: + eval_test = evaluate_result(result['test'], data_test, + validation=False, multiple_exps=multiple_exps, binary=binary) + else: + eval_test = None + + eval_results.append({'train': eval_train, 'valid': eval_valid, 'test': eval_test}) + configs_out.append(configs[i]) + except NaNException as e: + print 'WARNING: Encountered NaN exception. Skipping.' 
+ print e + + i += 1 + + # Reformat into dict + eval_dict = {'train': {}, 'test': {}, 'valid': {}} + keys = eval_results[0]['train'].keys() + for k in keys: + v = np.array([eval_results[i]['train'][k] for i in range(len(eval_results))]) + eval_dict['train'][k] = v + + v = np.array([eval_results[i]['valid'][k] for i in range(len(eval_results))]) + eval_dict['valid'][k] = v + + if eval_test is not None and k in eval_results[0]['test']: + v = np.array([eval_results[i]['test'][k] for i in range(len(eval_results))]) + eval_dict['test'][k] = v + + return eval_dict, configs_out diff --git a/cfr/loader.py b/cfr/loader.py new file mode 100644 index 0000000..f2951aa --- /dev/null +++ b/cfr/loader.py @@ -0,0 +1,137 @@ +import os +import numpy as np + +from logger import Logger as Log + +def load_result_file(file): + arr = np.load(file) + + D = dict([(k, arr[k]) for k in arr.keys()]) + + return D + +def load_config(cfgfile): + """ Parses a configuration file """ + + cfgf = open(cfgfile,'r') + cfg = {} + for l in cfgf: + ps = [p.strip() for p in l.split(':')] + if len(ps)==2: + try: + cfg[ps[0]] = float(ps[1]) + except ValueError: + cfg[ps[0]] = ps[1] + if cfg[ps[0]] == 'False': + cfg[ps[0]] = False + elif cfg[ps[0]] == 'True': + cfg[ps[0]] = True + cfgf.close() + return cfg + +def load_single_result(result_dir): + if Log.VERBOSE: + print 'Loading %s...' % result_dir + + config_path = '%s/config.txt' % result_dir + has_config = os.path.isfile(config_path) + if not has_config: + print 'WARNING: Could not find config.txt for %s. Skipping.' % os.path.basename(result_dir) + config = None + else: + config = load_config(config_path) + + train_path = '%s/result.npz' % result_dir + test_path = '%s/result.test.npz' % result_dir + + has_test = os.path.isfile(test_path) + + try: + train_results = load_result_file(train_path) + except: + 'WARNING: Couldnt load result file. Skipping' + return None + + n_rep = np.max([config['repetitions'], config['experiments']]) + + if len(train_results['pred'].shape) < 4 or train_results['pred'].shape[2] < n_rep: + print 'WARNING: Experiment %s appears not to have finished. Skipping.' % result_dir + return None + + if has_test: + test_results = load_result_file(test_path) + else: + test_results = None + + return {'train': train_results, 'test': test_results, 'config': config} + +def load_results(output_dir): + + if Log.VERBOSE: + print 'Loading results from %s...' % output_dir + + ''' Detect results structure ''' + # Single result + if os.path.isfile('%s/results.npz' % output_dir): + #@TODO: Implement + pass + + # Multiple results + files = ['%s/%s' % (output_dir, f) for f in os.listdir(output_dir)] + exp_dirs = [f for f in files if os.path.isdir(f) + if os.path.isfile('%s/result.npz' % f)] + + if Log.VERBOSE: + print 'Found %d experiment configurations.' % len(exp_dirs) + + # Load each result folder + results = [] + for dir in exp_dirs: + dir_result = load_single_result(dir) + if dir_result is not None: + results.append(dir_result) + + return results + +def load_data(datapath): + """ Load dataset """ + arr = np.load(datapath) + xs = arr['x'] + + HAVE_TRUTH = False + SPARSE = False + + if len(xs.shape)==1: + SPARSE = True + + ts = arr['t'] + yfs = arr['yf'] + try: + es = arr['e'] + except: + es = None + try: + ate = np.mean(arr['ate']) + except: + ate = None + try: + ymul = arr['ymul'][0,0] + yadd = arr['yadd'][0,0] + except: + ymul = 1 + yadd = 0 + try: + ycfs = arr['ycf'] + mu0s = arr['mu0'] + mu1s = arr['mu1'] + HAVE_TRUTH = True + except: + print 'Couldn\'t find ground truth. 
Proceeding...' + ycfs = None; mu0s = None; mu1s = None + + data = {'x':xs, 't':ts, 'e':es, 'yf':yfs, 'ycf':ycfs, \ + 'mu0':mu0s, 'mu1':mu1s, 'ate':ate, 'YMUL': ymul, \ + 'YADD': yadd, 'ATE': ate.tolist(), 'HAVE_TRUTH': HAVE_TRUTH, \ + 'SPARSE': SPARSE} + + return data diff --git a/cfr/logger.py b/cfr/logger.py new file mode 100644 index 0000000..67b865e --- /dev/null +++ b/cfr/logger.py @@ -0,0 +1,2 @@ +class Logger(): + VERBOSE = False diff --git a/cfr/plotting.py b/cfr/plotting.py new file mode 100644 index 0000000..0ed5c51 --- /dev/null +++ b/cfr/plotting.py @@ -0,0 +1,622 @@ +import sys +import os +import numpy as np +import matplotlib as mpl +mpl.use('Agg') +import matplotlib.pyplot as plt + +from loader import * + +LINE_WIDTH = 2 +FONTSIZE_LGND = 8 +FONTSIZE = 16 + +EARLY_STOP_SET_CONT = 'valid' +EARLY_STOP_CRITERION_CONT = 'objective' +CONFIG_CHOICE_SET_CONT = 'valid' +CONFIG_CRITERION_CONT = 'pehe_nn' +CORR_CRITERION_CONT = 'pehe' +CORR_CHOICE_SET_CONT = 'test' + +EARLY_STOP_SET_BIN = 'valid' +EARLY_STOP_CRITERION_BIN = 'policy_risk' +CONFIG_CHOICE_SET_BIN = 'valid' +CONFIG_CRITERION_BIN = 'policy_risk' +CORR_CRITERION_BIN = 'policy_risk' +CORR_CHOICE_SET_BIN = 'test' + +CURVE_TOP_K = 7 + +def fix_log_axes(x): + ax = plt.axes() + plt.draw() + labels = [item.get_text() for item in ax.get_xticklabels()] + labels[1] = r'0' + ax.set_xticklabels(labels) + d=0.025 + kwargs = dict(transform=ax.transAxes, color='k', clip_on=False) + ax.plot((0.04-0.25*d, 0.04+0.25*d), (-d, +d), **kwargs) + ax.plot((0.06-0.25*d, 0.06+0.25*d), (-d, +d), **kwargs) + plt.xlim(np.min(x), np.max(x)) + +def plot_format(): + plt.grid(linestyle='-', color=[0.8,0.8,0.8]) + ax = plt.gca() + ax.set_axisbelow(True) + +def fill_bounds(data, axis=0, std_error=False): + if std_error: + dev = np.std(data, axis)/np.sqrt(data.shape[axis]) + else: + dev = np.std(data, axis) + + ub = np.mean(data, axis) + dev + lb = np.mean(data, axis) - dev + + return lb, ub + +def plot_with_fill(x, y, axis=0, std_error=False, color='r'): + plt.plot(x, np.mean(y, axis), '.-', linewidth=2, color=color) + lb, ub = fill_bounds(y, axis=axis, std_error=std_error) + plt.fill_between(x, lb, ub, linewidth=0, facecolor=color, alpha=0.1) + +def cap(s): + t = s[0].upper() + s[1:] + return t + +def table_str_bin(result_set, row_labels, labels_long=None, binary=False): + if binary: + cols = ['policy_risk', 'bias_att', 'err_fact', 'objective', 'pehe_nn'] + else: + cols = ['pehe', 'bias_ate', 'rmse_fact', 'rmse_ite', 'objective', 'pehe_nn'] + + cols = [c for c in cols if c in result_set[0]] + + head = [cap(c) for c in cols] + colw = np.max([16, np.max([len(h)+1 for h in head])]) + col1w = np.max([len(h)+1 for h in row_labels]) + + def rpad(s): + return s+' '*(colw-len(s)) + + def r1pad(s): + return s+' '*(col1w-len(s)) + + head_pad = [r1pad('')]+[rpad(h) for h in head] + + head_str = '| '.join(head_pad) + s = head_str + '\n' + '-'*len(head_str) + '\n' + + for i in range(len(result_set)): + vals = [np.mean(np.abs(result_set[i][c])) for c in cols] # @TODO: np.abs just to make err not bias. change! 
+ stds = [np.std(result_set[i][c])/np.sqrt(result_set[i][c].shape[0]) for c in cols] + val_pad = [r1pad(row_labels[i])] + [rpad('%.3f +/- %.3f ' % (vals[j], stds[j])) for j in range(len(vals))] + val_str = '| '.join(val_pad) + + if labels_long is not None: + s += labels_long[i] + '\n' + + s += val_str + '\n' + + return s + +def evaluation_summary(result_set, row_labels, output_dir, labels_long=None, binary=False): + s = '' + for i in ['train', 'valid', 'test']: + s += 'Mode: %s\n' % cap(i) + s += table_str_bin([results[i] for results in result_set], row_labels, labels_long, binary) + s += '\n' + + return s + +def select_parameters(results, configs, stop_set, stop_criterion, choice_set, choice_criterion): + + if stop_criterion == 'objective' and 'objective' not in results[stop_set]: + if 'err_fact' in results[stop_set]: + stop_criterion = 'err_fact' + else: + stop_criterion = 'rmse_fact' + + ''' Select early stopping for each repetition ''' + n_exp = results[stop_set][stop_criterion].shape[1] + i_sel = np.argmin(results[stop_set][stop_criterion],2) + results_sel = {'train': {}, 'valid': {}, 'test': {}} + + for k in results['valid'].keys(): + # To reduce dimension + results_sel['train'][k] = np.sum(results['train'][k],2) + results_sel['valid'][k] = np.sum(results['valid'][k],2) + + if k in results['test']: + results_sel['test'][k] = np.sum(results['test'][k],2) + + for ic in range(len(configs)): + for ie in range(n_exp): + results_sel['train'][k][ic,ie,] = results['train'][k][ic,ie,i_sel[ic,ie],] + results_sel['valid'][k][ic,ie,] = results['valid'][k][ic,ie,i_sel[ic,ie],] + + if k in results['test']: + results_sel['test'][k][ic,ie,] = results['test'][k][ic,ie,i_sel[ic,ie],] + + print 'Early stopping:' + print np.mean(i_sel,1) + + ''' Select configuration ''' + results_all = [dict([(k1, dict([(k2, v[i,]) for k2,v in results_sel[k1].iteritems()])) + for k1 in results_sel.keys()]) for i in range(len(configs))] + + labels = ['%d' % i for i in range(len(configs))] + + sort_key = np.argsort([np.mean(r[choice_set][choice_criterion]) for r in results_all]) + results_all = [results_all[i] for i in sort_key] + configs_all = [configs[i] for i in sort_key] + labels = [labels[i] for i in sort_key] + + return results_all, configs_all, labels, sort_key + +def plot_option_correlation(output_dir, diff_opts, results, configs, + choice_set, choice_criterion, filter_str=''): + + topk = int(np.min([CURVE_TOP_K, len(configs)])) + + opts_dir = '%s/opts%s' % (output_dir, filter_str) + + try: + os.mkdir(opts_dir) + except: + pass + + for k in diff_opts: + + x_range = sorted(list(set([configs[i][k] for i in range(len(configs))]))) + + x_range_bins = [None]*len(x_range) + x_range_bins_top = [None]*len(x_range) + + plt.figure() + for i in range(0, len(configs)): + x = x_range.index(configs[i][k]) + y = np.mean(results[i][choice_set][choice_criterion]) + + if x_range_bins[x] is None: + x_range_bins[x] = [] + x_range_bins[x].append(y) + + plt.plot(x + 0.2*np.random.rand()-0.1, y , 'ob') + + for i in range(topk): + x = x_range.index(configs[i][k]) + y = np.mean(results[i][choice_set][choice_criterion]) + + if x_range_bins_top[x] is None: + x_range_bins_top[x] = [] + x_range_bins_top[x].append(y) + + plt.plot(x + 0.2*np.random.rand()-0.1, y , 'og') + + for i in range(len(x_range)): + m1 = np.mean(x_range_bins[i]) + plt.plot([i-0.2, i+0.2], [m1, m1], 'r', linewidth=LINE_WIDTH) + + if x_range_bins_top[i] is not None: + m2 = np.mean(x_range_bins_top[i]) + plt.plot([i-0.1, i+0.1], [m2, m2], 'g', linewidth=LINE_WIDTH) + + 
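+        # Red segments mark the mean criterion over all configurations at each
+        # option value; green segments mark the mean over the top-k configurations.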
plt.xticks(range(len(x_range)), x_range) + plt.title(r'$\mathrm{Influence\/of\/%s\/on\/%s\/on\/%s}$' % (k, choice_criterion, choice_set)) + plt.ylabel('%s' % (choice_criterion)) + plt.xlabel('options') + plt.xlim(-0.5, len(x_range)-0.5) + plt.savefig('%s/opt.%s.%s.%s.pdf' % (opts_dir, choice_set, choice_criterion, k)) + plt.close() + +def plot_evaluation_cont(results, configs, output_dir, data_train_path, data_test_path, filters=None): + + data_train = load_data(data_train_path) + data_test = load_data(data_test_path) + + propensity = {} + propensity['train'] = np.mean(data_train['t']) + propensity['valid'] = np.mean(data_train['t']) + propensity['test'] = np.mean(data_test['t']) + + ''' Select by filter ''' + filter_str = '' + if filters is not None: + filter_str = '.'+'.'.join(['%s.%s' % (k,filters[k]) for k in sorted(filters.keys())]) + + N = len(configs) + I = [i for i in range(N) if np.all( \ + [configs[i][k]==filters[k] for k in filters.keys()] \ + )] + + results = dict([(s,dict([(k,results[s][k][I,]) for k in results[s].keys()])) for s in ['train', 'valid', 'test']]) + configs = [configs[i] for i in I] + + ''' Do parameter selection and early stopping ''' + results_all, configs_all, labels, sort_key = select_parameters(results, + configs, EARLY_STOP_SET_CONT, EARLY_STOP_CRITERION_CONT, + CONFIG_CHOICE_SET_CONT, CONFIG_CRITERION_CONT) + + ''' Save sorted configurations by parameters that differ ''' + diff_opts = sorted([k for k in configs[0] if len(set([cfg[k] for cfg in configs]))>1]) + labels_long = [', '.join(['%s=%s' % (k,str(configs[i][k])) for k in diff_opts]) for i in sort_key] + + with open('%s/configs_sorted%s.txt' % (output_dir, filter_str), 'w') as f: + f.write('\n'.join(labels_long)) + + ''' Compute evaluation summary and store''' + eval_str = evaluation_summary(results_all, labels, output_dir, binary=False) + + with open('%s/results_summary%s.txt' % (output_dir, filter_str), 'w') as f: + f.write('Selected early stopping based on individual \'%s\' on \'%s\'\n' % (EARLY_STOP_CRITERION_CONT, EARLY_STOP_SET_CONT)) + f.write('Selected configuration based on mean \'%s\' on \'%s\'\n' % (CONFIG_CRITERION_CONT, CONFIG_CHOICE_SET_CONT)) + f.write(eval_str) + + ''' Plot option correlation ''' + plot_option_correlation(output_dir, diff_opts, results_all, configs_all, + CORR_CHOICE_SET_CONT, CORR_CRITERION_CONT, filter_str) + + +def plot_evaluation_bin(results, configs, output_dir, data_train_path, data_test_path, filters=None): + + data_train = load_data(data_train_path) + data_test = load_data(data_test_path) + + propensity = {} + propensity['train'] = np.mean(data_train['t'][data_train['e']==1,]) + propensity['valid'] = np.mean(data_train['t'][data_train['e']==1,]) + propensity['test'] = np.mean(data_test['t'][data_test['e']==1,]) + + ''' Select by filter ''' + filter_str = '' + if filters is not None: + filter_str = '.'+'.'.join(['%s.%s' % (k,filters[k]) for k in sorted(filters.keys())]) + + def cmp(u,v): + if isinstance(u, basestring): + return u.lower()==v.lower() + else: + return u==v + + N = len(configs) + I = [i for i in range(N) if np.all( \ + [cmp(configs[i][k],filters[k]) for k in filters.keys()] \ + )] + + results = dict([(s,dict([(k,results[s][k][I,]) for k in results[s].keys()])) for s in ['train', 'valid', 'test']]) + configs = [configs[i] for i in I] + + ''' Do parameter selection and early stopping ''' + results_all, configs_all, labels, sort_key = select_parameters(results, + configs, EARLY_STOP_SET_BIN, EARLY_STOP_CRITERION_BIN, + CONFIG_CHOICE_SET_BIN, 
CONFIG_CRITERION_BIN) + + ''' Save sorted configurations by parameters that differ ''' + diff_opts = sorted([k for k in configs[0] if len(set([cfg[k] for cfg in configs]))>1]) + labels_long = [', '.join(['%s=%s' % (k,str(configs[i][k])) for k in diff_opts]) for i in sort_key] + + with open('%s/configs_sorted%s.txt' % (output_dir,filter_str), 'w') as f: + f.write('\n'.join(labels_long)) + + ''' Compute evaluation summary and store''' + eval_str = evaluation_summary(results_all, labels, output_dir, binary=True) + + with open('%s/results_summary%s.txt' % (output_dir,filter_str), 'w') as f: + f.write('Selected early stopping based on individual \'%s\' on \'%s\'\n' % (EARLY_STOP_CRITERION_BIN, EARLY_STOP_SET_BIN)) + f.write('Selected configuration based on mean \'%s\' on \'%s\'\n' % (CONFIG_CRITERION_BIN, CONFIG_CHOICE_SET_BIN)) + f.write(eval_str) + + ''' Policy curve for top-k configurations ''' + colors = 'rgbcmyk' + topk = int(np.min([CURVE_TOP_K, len(configs)])) + + for eval_set in ['train', 'valid', 'test']: + pc = np.mean(results_all[0][eval_set]['policy_curve'],0) + x = np.array(range(len(pc))).astype(np.float32)/(len(pc)-1) + for i in range(topk): + plot_with_fill(x, results_all[i][eval_set]['policy_curve'], axis=0, std_error=True, color=colors[i]) + plt.plot([0,1], [pc[0], pc[-1]], '--k', linewidth=2) + + + p = propensity[eval_set] + x_lim = plt.xlim() + y_lim = plt.ylim() + plt.plot([p,p], y_lim, ':k') + plt.text(p+0.01*(x_lim[1]-x_lim[0]),y_lim[0]+0.05*(y_lim[1]-y_lim[0]), r'$p(t)$', fontsize=14) + plt.ylim(y_lim) + + plt.xlabel(r'$\mathrm{Inclusion\/rate}$', fontsize=FONTSIZE) + plt.ylabel(r'$\mathrm{Policy\/value}$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{Policy\/curve\/%s\/(w.\/early\/stopping)}$' % eval_set) + plt.legend(['Configuration %d' % i for i in range(topk)], fontsize=FONTSIZE_LGND) + plot_format() + plt.savefig('%s/policy_curve%s.%s.pdf' % (output_dir, filter_str, eval_set)) + plt.close() + + ''' Plot option correlation ''' + plot_option_correlation(output_dir, diff_opts, results_all, configs_all, + CORR_CHOICE_SET_BIN, CORR_CRITERION_BIN, filter_str) + + +def plot_cfr_evaluation_bin(results, configs, output_dir): + alphas = [cfg['p_alpha'] for cfg in configs] + palphas = alphas; palphas[0] = 1e-7; + + EARLY_STOP_CRITERION_BIN = 'err_fact' + ALPHA_CRITERION = 'policy_risk' + EARLY_STOP_SET_BIN = 'valid' + ALPHA_CHOICE_SET = 'valid' + + ''' Select early stopping for each repetition ''' + n_exp = results[EARLY_STOP_SET_BIN][EARLY_STOP_CRITERION_BIN].shape[1] + i_sel = np.argmin(results[EARLY_STOP_SET_BIN][EARLY_STOP_CRITERION_BIN],2) + results_sel = {'train': {}, 'valid': {}, 'test': {}} + + for k in results['valid'].keys(): + # To reduce dimension + results_sel['train'][k] = np.sum(results['train'][k],2) + results_sel['valid'][k] = np.sum(results['valid'][k],2) + results_sel['test'][k] = np.sum(results['test'][k],2) + + for ia in range(len(alphas)): + for ie in range(n_exp): + results_sel['train'][k][ia,ie,] = results['train'][k][ia,ie,i_sel[ia,ie],] + results_sel['valid'][k][ia,ie,] = results['valid'][k][ia,ie,i_sel[ia,ie],] + results_sel['test'][k][ia,ie,] = results['test'][k][ia,ie,i_sel[ia,ie],] + + print 'Early stopping:' + print np.mean(i_sel,1) + + ''' Select alpha based on mean criterion''' + i_skip=1 + A = np.mean(results_sel[ALPHA_CHOICE_SET][ALPHA_CRITERION],1) + ia = i_skip + A[i_skip:].argmin() + print 'Alpha selection criterion:' + print A + + ''' Print evaluation results ''' + results_alphas = [dict([(k1, dict([(k2, v[i,]) for k2,v in 
results_sel[k1].iteritems()])) + for k1 in results_sel.keys()]) for i in range(len(alphas))] + di = configs[0]['n_in'] + do = configs[0]['n_out'] + labels=['CFR-%d-%d' % (di,do)] + for i in range(len(alphas)): + if i==0: + continue + m = '' + if i==ia: + m = ' *' + labels.append('CFR-%d-%d %s a=%.2g%s' % (di,do,configs[0]['imb_fun'],alphas[i],m)) + eval_str = evaluation_summary_bin(results_alphas, labels, output_dir) + print eval_str + + with open('%s/results_summary.txt' % output_dir, 'w') as f: + f.write('Selected early stopping based on individual \'%s\' on \'%s\'\n' % (EARLY_STOP_CRITERION_BIN, EARLY_STOP_SET_BIN)) + f.write('Selected alpha based on mean \'%s\' on \'%s\'\n\n' % (ALPHA_CRITERION, ALPHA_CHOICE_SET)) + f.write(eval_str) + + ''' Plotting policy curves ''' + pc = np.mean(results['train']['policy_curve'][0,:,-1,:],0) + x = np.array(range(len(pc))).astype(np.float32)/(len(pc)-1) + plot_with_fill(x, results_sel['train']['policy_curve'][0,:,:], axis=0, std_error=True, color='b') + plot_with_fill(x, results_sel['train']['policy_curve'][ia,:,:], axis=0, std_error=True, color='g') + plt.plot([0,1], [pc[0], pc[-1]], '--k', linewidth=2) + plt.xlabel(r'$\mathrm{Inclusion\/rate}$', fontsize=FONTSIZE) + plt.ylabel(r'$\mathrm{Policy\/value}$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{Policy\/curve\/(w.\/early\/stopping)}$') + plt.legend(['alpha=0', 'alpha=%.2g' % alphas[ia]]) + plot_format() + plt.savefig('%s/policy_curve_train.pdf' % (output_dir)) + plt.close() + + pc = np.mean(results['test']['policy_curve'][0,:,-1,:],0) + x = np.array(range(len(pc))).astype(np.float32)/(len(pc)-1) + plot_with_fill(x, results_sel['test']['policy_curve'][0,:,:], axis=0, std_error=True, color='b') + plot_with_fill(x, results_sel['test']['policy_curve'][ia,:,:], axis=0, std_error=True, color='g') + plt.plot([0,1], [pc[0], pc[-1]], '--k', linewidth=2) + plt.xlabel(r'$\mathrm{Inclusion\/rate}$', fontsize=FONTSIZE) + plt.ylabel(r'$\mathrm{Policy\/value}$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{Policy\/curve\/(w.\/early\/stopping)}$') + plt.legend(['alpha=0', 'alpha=%.2g' % alphas[ia]]) + plot_format() + plt.savefig('%s/policy_curve_test.pdf' % (output_dir)) + plt.close() + + ''' Policy value at early stopping point ''' + plot_with_fill(palphas, results_sel['train']['policy_value'][:,:], axis=1, std_error=True, color='r') + plot_with_fill(palphas, results_sel['valid']['policy_value'][:,:], axis=1, std_error=True, color='g') + plot_with_fill(palphas, results_sel['test']['policy_value'][:,:], axis=1, std_error=True, color='b') + + plt.xscale('log') + fix_log_axes(palphas) + plt.ylabel(r'$\mathrm{Policy\/value}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Imbalance\/penalty},\/\alpha$', fontsize=FONTSIZE) + plt.legend(['Train', 'Valid', 'Test']) + plot_format() + plt.savefig('%s/policy_value_sel.pdf' % (output_dir)) + plt.close() + + ''' Policy value at each stage of training ''' + for t in range(results['train']['policy_value'].shape[2]): + plot_with_fill(palphas, results['train']['policy_value'][:,:,t], axis=1, std_error=True, color='r') + plot_with_fill(palphas, results['valid']['policy_value'][:,:,t], axis=1, std_error=True, color='g') + plot_with_fill(palphas, results['test']['policy_value'][:,:,t], axis=1, std_error=True, color='b') + + plt.xscale('log') + fix_log_axes(palphas) + plt.ylabel(r'$\mathrm{Policy\/value}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Imbalance\/penalty},\/\alpha$', fontsize=FONTSIZE) + plt.legend(['Train', 'Valid', 'Test']) + plot_format() + 
plt.savefig('%s/policy_value_end_t%d.pdf' % (output_dir,t)) + plt.close() + + ''' Accuracy at end of training ''' + err_train = results_sel['train']['err_fact'] + err_valid = results_sel['valid']['err_fact'] + err_test = results_sel['test']['err_fact'] + + plot_with_fill(palphas, err_train, axis=1, std_error=True, color='r') + plot_with_fill(palphas, err_valid, axis=1, std_error=True, color='g') + plot_with_fill(palphas, err_test, axis=1, std_error=True, color='b') + plt.xscale('log') + fix_log_axes(palphas) + + plt.ylabel(r'$\mathrm{Factual\/error\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Imbalance\/penalty},\/\alpha$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{Factual\/error\/(w.\/early\/stopping)}$') + plt.legend(['Train', 'Valid', 'Test']) + plot_format() + plt.savefig('%s/err_fact_alpha.pdf' % output_dir) + plt.close() + + ''' Accuracy for different iterations ''' + colors = 'rgbcmyk' + markers = '.d*ox' + err_test = results['test']['err_fact'][:,:,:] + ts = range(err_test.shape[2]) + for i in range(len(alphas)): + plt.plot(ts, np.mean(err_test[i,],0), '-%s' % markers[i%len(markers)], + color=colors[i%len(colors)], linewidth=LINE_WIDTH) + plt.ylabel(r'$\mathrm{Factual\/error\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Iteration}$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{Test\/factual\/error}$') + plt.legend(['Alpha=%.2g' % a for a in alphas], fontsize=(FONTSIZE_LGND-2)) + plot_format() + plt.savefig('%s/err_fact_iterations_test.pdf' % output_dir) + plt.close() + + ''' Policy value for different iterations ''' + colors = 'rgbcmyk' + markers = '.d*ox' + y_test = results['test']['policy_value'][:,:,:] + ts = range(y_test.shape[2]) + for i in range(len(alphas)): + plt.plot(ts, np.mean(y_test[i,],0), '-%s' % markers[i%len(markers)], + color=colors[i%len(colors)], linewidth=LINE_WIDTH) + plt.ylabel(r'$\mathrm{Polcy\/value\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Iteration}$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{Policy\/value}$') + plt.legend(['Alpha=%.2g' % a for a in alphas], fontsize=(FONTSIZE_LGND-2)) + plot_format() + plt.savefig('%s/policy_val_iterations_test.pdf' % output_dir) + plt.close() + +def plot_cfr_evaluation_cont(results, configs, output_dir): + alphas = [cfg['p_alpha'] for cfg in configs] + + ''' Select early stopping for each experiment ''' + n_exp = results['valid']['pehe'].shape[1] + i_sel = np.argmin(results['valid']['pehe'],2) + results_sel = {'train': {}, 'valid': {}, 'test': {}} + + for k in results['valid'].keys(): + # To reduce dimension + results_sel['train'][k] = np.sum(results['train'][k],2) + results_sel['valid'][k] = np.sum(results['valid'][k],2) + results_sel['test'][k] = np.sum(results['test'][k],2) + + for ia in range(len(alphas)): + for ie in range(n_exp): + results_sel['train'][k][ia,ie] = results['train'][k][ia,ie,i_sel[ia,ie]].copy() + results_sel['valid'][k][ia,ie] = results['valid'][k][ia,ie,i_sel[ia,ie]].copy() + results_sel['test'][k][ia,ie] = results['test'][k][ia,ie,i_sel[ia,ie]].copy() + + + ''' Select alpha and early stopping based on MEAN validation pehe (@TODO: not used) ''' + i_skip=1 + j_skip=1 + A = np.mean(results['valid']['pehe'],1) + i,j = np.unravel_index(A[i_skip:,j_skip:].argmin(), A[i_skip:,j_skip:].shape) + ia = i+i_skip + it = j+j_skip + + ''' Factual vs alphas ''' + err_train = results_sel['train']['rmse_fact'] + err_valid = results_sel['valid']['rmse_fact'] + err_test = results_sel['test']['rmse_fact'] + + plot_with_fill(alphas, err_train, axis=1, std_error=True, color='r') + 
plot_with_fill(alphas, err_valid, axis=1, std_error=True, color='g') + plot_with_fill(alphas, err_test, axis=1, std_error=True, color='b') + plt.xscale('log') + fix_log_axes(palphas) + + plt.ylabel(r'$\mathrm{Factual\/error\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Imbalance\/penalty},\/\alpha$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{RMSE\/fact\/vs\/alpha}$') + plt.legend(['Train', 'Valid', 'Test']) + plot_format() + plt.savefig('%s/err_fact_alpha.pdf' % output_dir) + plt.close() + + ''' Counterfactual vs alphas ''' + err_train = results_sel['train']['rmse_cfact'] + err_valid = results_sel['valid']['rmse_cfact'] + err_test = results_sel['test']['rmse_cfact'] + + plot_with_fill(alphas, err_train, axis=1, std_error=True, color='r') + plot_with_fill(alphas, err_valid, axis=1, std_error=True, color='g') + plot_with_fill(alphas, err_test, axis=1, std_error=True, color='b') + plt.xscale('log') + fix_log_axes(palphas) + + plt.ylabel(r'$\mathrm{Factual\/error\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Imbalance\/penalty},\/\alpha$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{RMSE\/cfact\/vs\/\alpha}$') + plt.legend(['Train', 'Valid', 'Test']) + plot_format() + plt.savefig('%s/err_cfact_alpha.pdf' % output_dir) + plt.close() + + ''' PEHE vs alphas ''' + err_train = results_sel['train']['pehe'] + err_valid = results_sel['valid']['pehe'] + err_test = results_sel['test']['pehe'] + + plot_with_fill(alphas, err_train, axis=1, std_error=True, color='r') + plot_with_fill(alphas, err_valid, axis=1, std_error=True, color='g') + plot_with_fill(alphas, err_test, axis=1, std_error=True, color='b') + plt.xscale('log') + fix_log_axes(palphas) + + plt.ylabel(r'$\mathrm{Factual\/error\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Imbalance\/penalty},\/\alpha$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{PEHE vs alpha}$') + plt.legend(['Train', 'Valid', 'Test']) + plot_format() + plt.savefig('%s/pehe_alpha.pdf' % output_dir) + plt.close() + + ''' Accuracy for different iterations ''' + colors = 'rgbcmyk' + markers = '.d*ox' + err_test = results['test']['rmse_fact'][:,:,:] + ts = range(err_test.shape[2]) + for i in range(len(alphas)): + plt.plot(ts, np.mean(err_test[i,],0), '-%s' % markers[i%len(markers)], + color=colors[i%len(colors)], linewidth=LINE_WIDTH) + plt.ylabel(r'$\mathrm{Factual\/error\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Iteration}$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{Test\/factual\/error}$') + plt.legend(['Alpha=%.2g' % a for a in alphas], fontsize=(FONTSIZE_LGND-2)) + plot_format() + plt.savefig('%s/err_fact_iterations_test.pdf' % output_dir) + plt.close() + + ''' PEHE for different iterations ''' + colors = 'rgbcmyk' + markers = '.d*ox' + y_test = results['test']['pehe'][:,:,:] + ts = range(y_test.shape[2]) + for i in range(len(alphas)): + plt.plot(ts, np.mean(y_test[i,],0), '-%s' % markers[i%len(markers)], + color=colors[i%len(colors)], linewidth=LINE_WIDTH) + plt.ylabel(r'$\mathrm{Polcy\/value\/(test)}$', fontsize=FONTSIZE) + plt.xlabel(r'$\mathrm{Iteration}$', fontsize=FONTSIZE) + plt.title(r'$\mathrm{PEHE\/(Test)}$') + plt.legend(['Alpha=%.2g' % a for a in alphas], fontsize=(FONTSIZE_LGND-2)) + plot_format() + plt.savefig('%s/pehe_iterations_test.pdf' % output_dir) + plt.close() diff --git a/cfr/util.py b/cfr/util.py new file mode 100644 index 0000000..1e21c13 --- /dev/null +++ b/cfr/util.py @@ -0,0 +1,228 @@ +import tensorflow as tf +import numpy as np + +SQRT_CONST = 1e-10 + +FLAGS = tf.app.flags.FLAGS + +def validation_split(D_exp, 
val_fraction): + """ Construct a train/validation split """ + n = D_exp['x'].shape[0] + + if val_fraction > 0: + n_valid = int(val_fraction*n) + n_train = n-n_valid + I = np.random.permutation(range(0,n)) + I_train = I[:n_train] + I_valid = I[n_train:] + else: + I_train = range(n) + I_valid = [] + + return I_train, I_valid + +def log(logfile,str): + """ Log a string in a file """ + with open(logfile,'a') as f: + f.write(str+'\n') + print str + +def save_config(fname): + """ Save configuration """ + flagdict = FLAGS.__dict__['__flags'] + s = '\n'.join(['%s: %s' % (k,str(flagdict[k])) for k in sorted(flagdict.keys())]) + f = open(fname,'w') + f.write(s) + f.close() + +def load_data(fname): + """ Load data set """ + if fname[-3:] == 'npz': + data_in = np.load(fname) + data = {'x': data_in['x'], 't': data_in['t'], 'yf': data_in['yf']} + try: + data['ycf'] = data_in['ycf'] + except: + data['ycf'] = None + else: + if FLAGS.sparse>0: + data_in = np.loadtxt(open(fname+'.y',"rb"),delimiter=",") + x = load_sparse(fname+'.x') + else: + data_in = np.loadtxt(open(fname,"rb"),delimiter=",") + x = data_in[:,5:] + + data['x'] = x + data['t'] = data_in[:,0:1] + data['yf'] = data_in[:,1:2] + data['ycf'] = data_in[:,2:3] + + data['HAVE_TRUTH'] = not data['ycf'] is None + + data['dim'] = data['x'].shape[1] + data['n'] = data['x'].shape[0] + + return data + +def load_sparse(fname): + """ Load sparse data set """ + E = np.loadtxt(open(fname,"rb"),delimiter=",") + H = E[0,:] + n = int(H[0]) + d = int(H[1]) + E = E[1:,:] + S = sparse.coo_matrix((E[:,2],(E[:,0]-1,E[:,1]-1)),shape=(n,d)) + S = S.todense() + + return S + +def safe_sqrt(x, lbound=SQRT_CONST): + ''' Numerically safe version of TensorFlow sqrt ''' + return tf.sqrt(tf.clip_by_value(x, lbound, np.inf)) + +def lindisc(X,p,t): + ''' Linear MMD ''' + + it = tf.where(t>0)[:,0] + ic = tf.where(t<1)[:,0] + + Xc = tf.gather(X,ic) + Xt = tf.gather(X,it) + + mean_control = tf.reduce_mean(Xc,reduction_indices=0) + mean_treated = tf.reduce_mean(Xt,reduction_indices=0) + + c = tf.square(2*p-1)*0.25 + f = tf.sign(p-0.5) + + mmd = tf.reduce_sum(tf.square(p*mean_treated - (1-p)*mean_control)) + mmd = f*(p-0.5) + safe_sqrt(c + mmd) + + return mmd + +def mmd2_lin(X,t,p): + ''' Linear MMD ''' + + it = tf.where(t>0)[:,0] + ic = tf.where(t<1)[:,0] + + Xc = tf.gather(X,ic) + Xt = tf.gather(X,it) + + mean_control = tf.reduce_mean(Xc,reduction_indices=0) + mean_treated = tf.reduce_mean(Xt,reduction_indices=0) + + mmd = tf.reduce_sum(tf.square(2.0*p*mean_treated - 2.0*(1.0-p)*mean_control)) + + return mmd + +def mmd2_rbf(X,t,p,sig): + """ Computes the l2-RBF MMD for X given t """ + + it = tf.where(t>0)[:,0] + ic = tf.where(t<1)[:,0] + + Xc = tf.gather(X,ic) + Xt = tf.gather(X,it) + + Kcc = tf.exp(-pdist2sq(Xc,Xc)/tf.square(sig)) + Kct = tf.exp(-pdist2sq(Xc,Xt)/tf.square(sig)) + Ktt = tf.exp(-pdist2sq(Xt,Xt)/tf.square(sig)) + + m = tf.to_float(tf.shape(Xc)[0]) + n = tf.to_float(tf.shape(Xt)[0]) + + mmd = tf.square(1.0-p)/(m*(m-1.0))*(tf.reduce_sum(Kcc)-m) + mmd = mmd + tf.square(p)/(n*(n-1.0))*(tf.reduce_sum(Ktt)-n) + mmd = mmd - 2.0*p*(1.0-p)/(m*n)*tf.reduce_sum(Kct) + mmd = 4.0*mmd + + return mmd + +def pdist2sq(X,Y): + """ Computes the squared Euclidean distance between all pairs x in X, y in Y """ + C = -2*tf.matmul(X,tf.transpose(Y)) + nx = tf.reduce_sum(tf.square(X),1,keep_dims=True) + ny = tf.reduce_sum(tf.square(Y),1,keep_dims=True) + D = (C + tf.transpose(ny)) + nx + return D + +def pdist2(X,Y): + """ Returns the tensorflow pairwise distance matrix """ + return 
safe_sqrt(pdist2sq(X,Y)) + +def pop_dist(X,t): + it = tf.where(t>0)[:,0] + ic = tf.where(t<1)[:,0] + Xc = tf.gather(X,ic) + Xt = tf.gather(X,it) + nc = tf.to_float(tf.shape(Xc)[0]) + nt = tf.to_float(tf.shape(Xt)[0]) + + ''' Compute distance matrix''' + M = pdist2(Xt,Xc) + return M + +def wasserstein(X,t,p,lam=10,its=10,sq=False,backpropT=False): + """ Returns the Wasserstein distance between treatment groups """ + + it = tf.where(t>0)[:,0] + ic = tf.where(t<1)[:,0] + Xc = tf.gather(X,ic) + Xt = tf.gather(X,it) + nc = tf.to_float(tf.shape(Xc)[0]) + nt = tf.to_float(tf.shape(Xt)[0]) + + ''' Compute distance matrix''' + if sq: + M = pdist2sq(Xt,Xc) + else: + M = safe_sqrt(pdist2sq(Xt,Xc)) + + ''' Estimate lambda and delta ''' + M_mean = tf.reduce_mean(M) + M_drop = tf.nn.dropout(M,10/(nc*nt)) + delta = tf.stop_gradient(tf.reduce_max(M)) + eff_lam = tf.stop_gradient(lam/M_mean) + + ''' Compute new distance matrix ''' + Mt = M + row = delta*tf.ones(tf.shape(M[0:1,:])) + col = tf.concat(0,[delta*tf.ones(tf.shape(M[:,0:1])),tf.zeros((1,1))]) + Mt = tf.concat(0,[M,row]) + Mt = tf.concat(1,[Mt,col]) + + ''' Compute marginal vectors ''' + a = tf.concat(0,[p*tf.ones(tf.shape(tf.where(t>0)[:,0:1]))/nt, (1-p)*tf.ones((1,1))]) + b = tf.concat(0,[(1-p)*tf.ones(tf.shape(tf.where(t<1)[:,0:1]))/nc, p*tf.ones((1,1))]) + + ''' Compute kernel matrix''' + Mlam = eff_lam*Mt + K = tf.exp(-Mlam) + 1e-6 # added constant to avoid nan + U = K*Mt + ainvK = K/a + + u = a + for i in range(0,its): + u = 1.0/(tf.matmul(ainvK,(b/tf.transpose(tf.matmul(tf.transpose(u),K))))) + v = b/(tf.transpose(tf.matmul(tf.transpose(u),K))) + + T = u*(tf.transpose(v)*K) + + if not backpropT: + T = tf.stop_gradient(T) + + E = T*Mt + D = 2*tf.reduce_sum(E) + + return D, Mlam + +def simplex_project(x,k): + """ Projects a vector x onto the k-simplex """ + d = x.shape[0] + mu = np.sort(x,axis=0)[::-1] + nu = (np.cumsum(mu)-k)/range(1,d+1) + I = [i for i in range(0,d) if mu[i]>nu[i]] + theta = nu[I[-1]] + w = np.maximum(x-theta,0) + return w diff --git a/cfr_net_train.py b/cfr_net_train.py new file mode 100644 index 0000000..9a450bd --- /dev/null +++ b/cfr_net_train.py @@ -0,0 +1,428 @@ +import tensorflow as tf +import numpy as np +import sys, os +import getopt +import random +import datetime +import traceback + +import cfr.cfr_net as cfr +from cfr.util import * + +''' Define parameter flags ''' +FLAGS = tf.app.flags.FLAGS +tf.app.flags.DEFINE_string('loss', 'l2', """Which loss function to use (l1/l2/log)""") +tf.app.flags.DEFINE_integer('n_in', 2, """Number of representation layers. """) +tf.app.flags.DEFINE_integer('n_out', 2, """Number of regression layers. """) +tf.app.flags.DEFINE_float('p_alpha', 1e-4, """Imbalance regularization param. """) +tf.app.flags.DEFINE_float('p_lambda', 0.0, """Weight decay regularization parameter. """) +tf.app.flags.DEFINE_integer('rep_weight_decay', 1, """Whether to penalize representation layers with weight decay""") +tf.app.flags.DEFINE_float('dropout_in', 0.9, """Input layers dropout keep rate. """) +tf.app.flags.DEFINE_float('dropout_out', 0.9, """Output layers dropout keep rate. """) +tf.app.flags.DEFINE_string('nonlin', 'relu', """Kind of non-linearity. Default relu. """) +tf.app.flags.DEFINE_float('lrate', 0.05, """Learning rate. """) +tf.app.flags.DEFINE_float('decay', 0.5, """RMSProp decay. """) +tf.app.flags.DEFINE_integer('batch_size', 100, """Batch size. """) +tf.app.flags.DEFINE_integer('dim_in', 100, """Pre-representation layer dimensions. 
""") +tf.app.flags.DEFINE_integer('dim_out', 100, """Post-representation layer dimensions. """) +tf.app.flags.DEFINE_integer('batch_norm', 0, """Whether to use batch normalization. """) +tf.app.flags.DEFINE_string('normalization', 'none', """How to normalize representation (after batch norm). none/bn_fixed/divide/project """) +tf.app.flags.DEFINE_float('rbf_sigma', 0.1, """RBF MMD sigma """) +tf.app.flags.DEFINE_integer('experiments', 1, """Number of experiments. """) +tf.app.flags.DEFINE_integer('iterations', 2000, """Number of iterations. """) +tf.app.flags.DEFINE_float('weight_init', 0.01, """Weight initialization scale. """) +tf.app.flags.DEFINE_float('lrate_decay', 0.95, """Decay of learning rate every 100 iterations """) +tf.app.flags.DEFINE_integer('wass_iterations', 20, """Number of iterations in Wasserstein computation. """) +tf.app.flags.DEFINE_float('wass_lambda', 1, """Wasserstein lambda. """) +tf.app.flags.DEFINE_integer('wass_bpt', 0, """Backprop through T matrix? """) +tf.app.flags.DEFINE_integer('varsel', 0, """Whether the first layer performs variable selection. """) +tf.app.flags.DEFINE_string('outdir', '../results/tfnet_topic/alpha_sweep_22_d100/', """Output directory. """) +tf.app.flags.DEFINE_string('datadir', '../data/topic/csv/', """Data directory. """) +tf.app.flags.DEFINE_string('dataform', 'topic_dmean_seed_%d.csv', """Training data filename form. """) +tf.app.flags.DEFINE_string('data_test', '', """Test data filename form. """) +tf.app.flags.DEFINE_integer('sparse', 0, """Whether data is stored in sparse format (.x, .y). """) +tf.app.flags.DEFINE_integer('seed', 1, """Seed. """) +tf.app.flags.DEFINE_integer('repetitions', 1, """Repetitions with different seed.""") +tf.app.flags.DEFINE_integer('use_p_correction', 1, """Whether to use population size p(t) in mmd/disc/wass.""") +tf.app.flags.DEFINE_string('optimizer', 'RMSProp', """Which optimizer to use. (RMSProp/Adagrad/GradientDescent/Adam)""") +tf.app.flags.DEFINE_string('imb_fun', 'mmd_lin', """Which imbalance penalty to use (mmd_lin/mmd_rbf/mmd2_lin/mmd2_rbf/lindisc/wass). """) +tf.app.flags.DEFINE_integer('output_csv',0,"""Whether to save a CSV file with the results""") +tf.app.flags.DEFINE_integer('output_delay', 100, """Number of iterations between log/loss outputs. """) +tf.app.flags.DEFINE_integer('pred_output_delay', -1, """Number of iterations between prediction outputs. (-1 gives no intermediate output). """) +tf.app.flags.DEFINE_integer('debug', 0, """Debug mode. """) +tf.app.flags.DEFINE_integer('save_rep', 0, """Save representations after training. """) +tf.app.flags.DEFINE_float('val_part', 0, """Validation part. """) +tf.app.flags.DEFINE_boolean('split_output', 0, """Whether to split output layers between treated and control. """) +tf.app.flags.DEFINE_boolean('reweight_sample', 1, """Whether to reweight sample for prediction loss with average treatment probability. 
""") + +if FLAGS.sparse: + import scipy.sparse as sparse + +NUM_ITERATIONS_PER_DECAY = 100 + +__DEBUG__ = False +if FLAGS.debug: + __DEBUG__ = True + +def train(CFR, sess, train_step, D, I_valid, D_test, logfile, i_exp): + """ Trains a CFR model on supplied data """ + + ''' Train/validation split ''' + n = D['x'].shape[0] + I = range(n); I_train = list(set(I)-set(I_valid)) + n_train = len(I_train) + + ''' Compute treatment probability''' + p_treated = np.mean(D['t'][I_train,:]) + + ''' Set up loss feed_dicts''' + dict_factual = {CFR.x: D['x'][I_train,:], CFR.t: D['t'][I_train,:], CFR.y_: D['yf'][I_train,:], \ + CFR.do_in: 1.0, CFR.do_out: 1.0, CFR.r_alpha: FLAGS.p_alpha, \ + CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated} + + if FLAGS.val_part > 0: + dict_valid = {CFR.x: D['x'][I_valid,:], CFR.t: D['t'][I_valid,:], CFR.y_: D['yf'][I_valid,:], \ + CFR.do_in: 1.0, CFR.do_out: 1.0, CFR.r_alpha: FLAGS.p_alpha, \ + CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated} + + if D['HAVE_TRUTH']: + dict_cfactual = {CFR.x: D['x'][I_train,:], CFR.t: 1-D['t'][I_train,:], CFR.y_: D['ycf'][I_train,:], \ + CFR.do_in: 1.0, CFR.do_out: 1.0} + + ''' Initialize TensorFlow variables ''' + sess.run(tf.global_variables_initializer()) + + ''' Set up for storing predictions ''' + preds_train = [] + preds_test = [] + + ''' Compute losses ''' + losses = [] + obj_loss, f_error, imb_err = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist],\ + feed_dict=dict_factual) + + cf_error = np.nan + if D['HAVE_TRUTH']: + cf_error = sess.run(CFR.pred_loss, feed_dict=dict_cfactual) + + valid_obj = np.nan; valid_imb = np.nan; valid_f_error = np.nan; + if FLAGS.val_part > 0: + valid_obj, valid_f_error, valid_imb = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist],\ + feed_dict=dict_valid) + + losses.append([obj_loss, f_error, cf_error, imb_err, valid_f_error, valid_imb, valid_obj]) + + objnan = False + + reps = [] + reps_test = [] + + ''' Train for multiple iterations ''' + for i in range(FLAGS.iterations): + + ''' Fetch sample ''' + I = random.sample(range(0, n_train), FLAGS.batch_size) + x_batch = D['x'][I_train,:][I,:] + t_batch = D['t'][I_train,:][I] + y_batch = D['yf'][I_train,:][I] + + if __DEBUG__: + M = sess.run(cfr.pop_dist(CFR.x, CFR.t), feed_dict={CFR.x: x_batch, CFR.t: t_batch}) + log(logfile, 'Median: %.4g, Mean: %.4f, Max: %.4f' % (np.median(M.tolist()), np.mean(M.tolist()), np.amax(M.tolist()))) + + ''' Do one step of gradient descent ''' + if not objnan: + sess.run(train_step, feed_dict={CFR.x: x_batch, CFR.t: t_batch, \ + CFR.y_: y_batch, CFR.do_in: FLAGS.dropout_in, CFR.do_out: FLAGS.dropout_out, \ + CFR.r_alpha: FLAGS.p_alpha, CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated}) + + ''' Project variable selection weights ''' + if FLAGS.varsel: + wip = simplex_project(sess.run(CFR.weights_in[0]), 1) + sess.run(CFR.projection, feed_dict={CFR.w_proj: wip}) + + ''' Compute loss every N iterations ''' + if i % FLAGS.output_delay == 0 or i==FLAGS.iterations-1: + obj_loss,f_error,imb_err = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist], + feed_dict=dict_factual) + + rep = sess.run(CFR.h_rep_norm, feed_dict={CFR.x: D['x'], CFR.do_in: 1.0}) + rep_norm = np.mean(np.sqrt(np.sum(np.square(rep), 1))) + + cf_error = np.nan + if D['HAVE_TRUTH']: + cf_error = sess.run(CFR.pred_loss, feed_dict=dict_cfactual) + + valid_obj = np.nan; valid_imb = np.nan; valid_f_error = np.nan; + if FLAGS.val_part > 0: + valid_obj, valid_f_error, valid_imb = sess.run([CFR.tot_loss, CFR.pred_loss, CFR.imb_dist], feed_dict=dict_valid) + + 
losses.append([obj_loss, f_error, cf_error, imb_err, valid_f_error, valid_imb, valid_obj]) + loss_str = str(i) + '\tObj: %.3f,\tF: %.3f,\tCf: %.3f,\tImb: %.1g,\tVal: %.3f,\tValImb: %.1g,\tValObj: %.2f,\tRepNrm: %.3f' \ + % (obj_loss, f_error, cf_error, imb_err, valid_f_error, valid_imb, valid_obj, rep_norm) + + if FLAGS.loss == 'log': + y_pred = sess.run(CFR.output, feed_dict={CFR.x: x_batch, \ + CFR.t: t_batch, CFR.do_in: 1.0, CFR.do_out: 1.0}) + y_pred = 1.0*(y_pred > 0.5) + acc = 100*(1 - np.mean(np.abs(y_batch - y_pred))) + loss_str += ',\tAcc: %.2f%%' % acc + + log(logfile, loss_str) + + if np.isnan(obj_loss): + log(logfile,'Experiment %d: Objective is NaN. Skipping.' % i_exp) + objnan = True + + ''' Compute predictions every M iterations ''' + if (FLAGS.pred_output_delay > 0 and i % FLAGS.pred_output_delay == 0) or i==FLAGS.iterations-1: + + y_pred_f = sess.run(CFR.output, feed_dict={CFR.x: D['x'], \ + CFR.t: D['t'], CFR.do_in: 1.0, CFR.do_out: 1.0}) + y_pred_cf = sess.run(CFR.output, feed_dict={CFR.x: D['x'], \ + CFR.t: 1-D['t'], CFR.do_in: 1.0, CFR.do_out: 1.0}) + preds_train.append(np.concatenate((y_pred_f, y_pred_cf),axis=1)) + + if D_test is not None: + y_pred_f_test = sess.run(CFR.output, feed_dict={CFR.x: D_test['x'], \ + CFR.t: D_test['t'], CFR.do_in: 1.0, CFR.do_out: 1.0}) + y_pred_cf_test = sess.run(CFR.output, feed_dict={CFR.x: D_test['x'], \ + CFR.t: 1-D_test['t'], CFR.do_in: 1.0, CFR.do_out: 1.0}) + preds_test.append(np.concatenate((y_pred_f_test, y_pred_cf_test),axis=1)) + + if FLAGS.save_rep and i_exp == 1: + reps_i = sess.run([CFR.h_rep], feed_dict={CFR.x: D['x'], \ + CFR.do_in: 1.0, CFR.do_out: 0.0}) + reps.append(reps_i) + + if D_test is not None: + reps_test_i = sess.run([CFR.h_rep], feed_dict={CFR.x: D_test['x'], \ + CFR.do_in: 1.0, CFR.do_out: 0.0}) + reps_test.append(reps_test_i) + + return losses, preds_train, preds_test, reps, reps_test + +def run(outdir): + """ Runs an experiment and stores result in outdir """ + + ''' Set up paths and start log ''' + npzfile = outdir+'result' + npzfile_test = outdir+'result.test' + repfile = outdir+'reps' + repfile_test = outdir+'reps.test' + outform = outdir+'y_pred' + outform_test = outdir+'y_pred.test' + lossform = outdir+'loss' + logfile = outdir+'log.txt' + f = open(logfile,'w') + f.close() + dataform = FLAGS.datadir + FLAGS.dataform + + has_test = False + if not FLAGS.data_test == '': # if test set supplied + has_test = True + dataform_test = FLAGS.datadir + FLAGS.data_test + + ''' Set random seeds ''' + random.seed(FLAGS.seed) + tf.set_random_seed(FLAGS.seed) + np.random.seed(FLAGS.seed) + + ''' Save parameters ''' + save_config(outdir+'config.txt') + + log(logfile, 'Training with hyperparameters: alpha=%.2g, lambda=%.2g' % (FLAGS.p_alpha,FLAGS.p_lambda)) + + ''' Load Data ''' + npz_input = False + if dataform[-3:] == 'npz': + npz_input = True + if npz_input: + datapath = dataform + if has_test: + datapath_test = dataform_test + else: + datapath = dataform % 1 + if has_test: + datapath_test = dataform_test % 1 + + log(logfile, 'Training data: ' + datapath) + if has_test: + log(logfile, 'Test data: ' + datapath_test) + D = load_data(datapath) + D_test = None + if has_test: + D_test = load_data(datapath_test) + + log(logfile, 'Loaded data with shape [%d,%d]' % (D['n'], D['dim'])) + + ''' Start Session ''' + sess = tf.Session() + + ''' Initialize input placeholders ''' + x = tf.placeholder("float", shape=[None, D['dim']], name='x') # Features + t = tf.placeholder("float", shape=[None, 1], name='t') # Treatent + y_ = 
tf.placeholder("float", shape=[None, 1], name='y_') # Outcome + + ''' Parameter placeholders ''' + r_alpha = tf.placeholder("float", name='r_alpha') + r_lambda = tf.placeholder("float", name='r_lambda') + do_in = tf.placeholder("float", name='dropout_in') + do_out = tf.placeholder("float", name='dropout_out') + p = tf.placeholder("float", name='p_treated') + + ''' Define model graph ''' + log(logfile, 'Defining graph...\n') + dims = [D['dim'], FLAGS.dim_in, FLAGS.dim_out] + CFR = cfr.cfr_net(x, t, y_, p, FLAGS, r_alpha, r_lambda, do_in, do_out, dims) + + ''' Set up optimizer ''' + global_step = tf.Variable(0, trainable=False) + lr = tf.train.exponential_decay(FLAGS.lrate, global_step, \ + NUM_ITERATIONS_PER_DECAY, FLAGS.lrate_decay, staircase=True) + + opt = None + if FLAGS.optimizer == 'Adagrad': + opt = tf.train.AdagradOptimizer(lr) + elif FLAGS.optimizer == 'GradientDescent': + opt = tf.train.GradientDescentOptimizer(lr) + elif FLAGS.optimizer == 'Adam': + opt = tf.train.AdamOptimizer(lr) + else: + opt = tf.train.RMSPropOptimizer(lr, FLAGS.decay) + + ''' Unused gradient clipping ''' + #gvs = opt.compute_gradients(CFR.tot_loss) + #capped_gvs = [(tf.clip_by_value(grad, -1.0, 1.0), var) for grad, var in gvs] + #train_step = opt.apply_gradients(capped_gvs, global_step=global_step) + + train_step = opt.minimize(CFR.tot_loss,global_step=global_step) + + ''' Set up for saving variables ''' + all_losses = [] + all_preds_train = [] + all_preds_test = [] + all_valid = [] + if FLAGS.varsel: + all_weights = None + all_beta = None + + all_preds_test = [] + + ''' Handle repetitions ''' + n_experiments = FLAGS.experiments + if FLAGS.repetitions>1: + if FLAGS.experiments>1: + log(logfile, 'ERROR: Use of both repetitions and multiple experiments is currently not supported.') + sys.exit(1) + n_experiments = FLAGS.repetitions + + ''' Run for all repeated experiments ''' + for i_exp in range(1,n_experiments+1): + + if FLAGS.repetitions>1: + log(logfile, 'Training on repeated initialization %d/%d...' % (i_exp, FLAGS.repetitions)) + else: + log(logfile, 'Training on experiment %d/%d...' 
% (i_exp, n_experiments)) + + ''' Load Data (if multiple repetitions, reuse first set)''' + + if i_exp==1 or FLAGS.experiments>1: + D_exp_test = None + if npz_input: + D_exp = {} + D_exp['x'] = D['x'][:,:,i_exp-1] + D_exp['t'] = D['t'][:,i_exp-1:i_exp] + D_exp['yf'] = D['yf'][:,i_exp-1:i_exp] + if D['HAVE_TRUTH']: + D_exp['ycf'] = D['ycf'][:,i_exp-1:i_exp] + else: + D_exp['ycf'] = None + + if has_test: + D_exp_test = {} + D_exp_test['x'] = D_test['x'][:,:,i_exp-1] + D_exp_test['t'] = D_test['t'][:,i_exp-1:i_exp] + D_exp_test['yf'] = D_test['yf'][:,i_exp-1:i_exp] + if D_test['HAVE_TRUTH']: + D_exp_test['ycf'] = D_test['ycf'][:,i_exp-1:i_exp] + else: + D_exp_test['ycf'] = None + else: + datapath = dataform % i_exp + D_exp = load_data(datapath) + if has_test: + datapath_test = dataform_test % i_exp + D_exp_test = load_data(datapath_test) + + D_exp['HAVE_TRUTH'] = D['HAVE_TRUTH'] + if has_test: + D_exp_test['HAVE_TRUTH'] = D_test['HAVE_TRUTH'] + + ''' Split into training and validation sets ''' + I_train, I_valid = validation_split(D_exp, FLAGS.val_part) + + ''' Run training loop ''' + losses, preds_train, preds_test, reps, reps_test = \ + train(CFR, sess, train_step, D_exp, I_valid, \ + D_exp_test, logfile, i_exp) + + ''' Collect all reps ''' + all_preds_train.append(preds_train) + all_preds_test.append(preds_test) + all_losses.append(losses) + + ''' Fix shape for output (n_units, dim, n_reps, n_outputs) ''' + out_preds_train = np.swapaxes(np.swapaxes(all_preds_train,1,3),0,2) + if has_test: + out_preds_test = np.swapaxes(np.swapaxes(all_preds_test,1,3),0,2) + out_losses = np.swapaxes(np.swapaxes(all_losses,0,2),0,1) + + ''' Store predictions ''' + log(logfile, 'Saving result to %s...\n' % outdir) + if FLAGS.output_csv: + np.savetxt('%s_%d.csv' % (outform,i_exp), preds_train[-1], delimiter=',') + np.savetxt('%s_%d.csv' % (outform_test,i_exp), preds_test[-1], delimiter=',') + np.savetxt('%s_%d.csv' % (lossform,i_exp), losses, delimiter=',') + + ''' Compute weights if doing variable selection ''' + if FLAGS.varsel: + if i_exp == 1: + all_weights = sess.run(CFR.weights_in[0]) + all_beta = sess.run(CFR.weights_pred) + else: + all_weights = np.dstack((all_weights, sess.run(CFR.weights_in[0]))) + all_beta = np.dstack((all_beta, sess.run(CFR.weights_pred))) + + ''' Save results and predictions ''' + all_valid.append(I_valid) + if FLAGS.varsel: + np.savez(npzfile, pred=out_preds_train, loss=out_losses, w=all_weights, beta=all_beta, val=np.array(all_valid)) + else: + np.savez(npzfile, pred=out_preds_train, loss=out_losses, val=np.array(all_valid)) + + if has_test: + np.savez(npzfile_test, pred=out_preds_test) + + ''' Save representations ''' + if FLAGS.save_rep and i_exp == 1: + np.savez(repfile, rep=reps) + + if has_test: + np.savez(repfile_test, rep=reps_test) + +def main(argv=None): # pylint: disable=unused-argument + """ Main entry point """ + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S-%f") + outdir = FLAGS.outdir+'/results_'+timestamp+'/' + os.mkdir(outdir) + + try: + run(outdir) + except Exception as e: + with open(outdir+'error.txt','w') as errfile: + errfile.write(''.join(traceback.format_exception(*sys.exc_info()))) + raise + +if __name__ == '__main__': + tf.app.run() diff --git a/cfr_param_search.py b/cfr_param_search.py new file mode 100644 index 0000000..0eff053 --- /dev/null +++ b/cfr_param_search.py @@ -0,0 +1,82 @@ +import sys +import os +import numpy as np +from subprocess import call + +def load_config(cfg_file): + cfg = {} + + with open(cfg_file,'r') as f: + for l in 
f: + l = l.strip() + if len(l)>0 and not l[0] == '#': + vs = l.split('=') + if len(vs)>0: + k,v = (vs[0], eval(vs[1])) + if not isinstance(v,list): + v = [v] + cfg[k] = v + return cfg + +def sample_config(configs): + cfg_sample = {} + for k in configs.keys(): + opts = configs[k] + c = np.random.choice(len(opts),1)[0] + cfg_sample[k] = opts[c] + return cfg_sample + +def cfg_string(cfg): + ks = sorted(cfg.keys()) + cfg_str = ','.join(['%s:%s' % (k, str(cfg[k])) for k in ks]) + return cfg_str.lower() + +def is_used_cfg(cfg, used_cfg_file): + cfg_str = cfg_string(cfg) + used_cfgs = read_used_cfgs(used_cfg_file) + return cfg_str in used_cfgs + +def read_used_cfgs(used_cfg_file): + used_cfgs = set() + with open(used_cfg_file, 'r') as f: + for l in f: + used_cfgs.add(l.strip()) + + return used_cfgs + +def save_used_cfg(cfg, used_cfg_file): + with open(used_cfg_file, 'a') as f: + cfg_str = cfg_string(cfg) + f.write('%s\n' % cfg_str) + +def run(cfg_file, num_runs): + configs = load_config(cfg_file) + + outdir = configs['outdir'][0] + used_cfg_file = '%s/used_configs.txt' % outdir + + if not os.path.isfile(used_cfg_file): + f = open(used_cfg_file, 'w') + f.close() + + for i in range(num_runs): + cfg = sample_config(configs) + if is_used_cfg(cfg, used_cfg_file): + print 'Configuration used, skipping' + continue + + save_used_cfg(cfg, used_cfg_file) + + print '------------------------------' + print 'Run %d of %d:' % (i+1, num_runs) + print '------------------------------' + print '\n'.join(['%s: %s' % (str(k), str(v)) for k,v in cfg.iteritems() if len(configs[k])>1]) + + flags = ' '.join('--%s %s' % (k,str(v)) for k,v in cfg.iteritems()) + call('python cfr_net_train.py %s' % flags, shell=True) + +if __name__ == "__main__": + if len(sys.argv) < 3: + print 'Usage: python cfr_param_search.py <config_file> <num_runs>' + else: + run(sys.argv[1], int(sys.argv[2]))
diff --git a/configs/example_ihdp.txt b/configs/example_ihdp.txt new file mode 100755 index 0000000..7d590a5 --- /dev/null +++ b/configs/example_ihdp.txt @@ -0,0 +1,37 @@ +p_alpha=[0, 1e-4, 3e-4, 1e-3, 3e-4, 1e-2, 3e-2, 1e-1, 3e-1, 1, 3, 10, 30] +p_lambda=[1e-3] +n_in=[2] +n_out=[2] +dropout_in=1.0 +dropout_out=[1.0] +nonlin='elu' +lrate=[1e-2] +lrate_decay=0.97 +decay=0.3 +optimizer='Adam' +batch_size=[100] +dim_in=[100] +dim_out=[50] +batch_norm=[1] +normalization=['bn_fixed'] +rbf_sigma=0.01 +imb_fun=['mmd2_lin', 'wass'] +wass_lambda=10.0 +wass_iterations=10 +wass_bpt=1 +use_p_correction=0 +reweight_sample=1 +experiments=1000 +iterations=4000 +weight_init=[0.1] +outdir='../results/ihdp_search_2-2' +datadir='../data/ihdp/' +dataform='ihdp_npci_1-1000.train.npz' +data_test='ihdp_npci_1-1000.test.npz' +pred_output_delay=200 +loss='l2' +sparse=0 +varsel=0 +repetitions=1 +val_part=0.3 +split_output=[1]
diff --git a/data/ihdp_npci_1-100.test.npz b/data/ihdp_npci_1-100.test.npz new file mode 100644 index 0000000..c807538 Binary files /dev/null and b/data/ihdp_npci_1-100.test.npz differ
diff --git a/data/ihdp_npci_1-100.train.npz b/data/ihdp_npci_1-100.train.npz new file mode 100644 index 0000000..d3ea83b Binary files /dev/null and b/data/ihdp_npci_1-100.train.npz differ
diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..e757ee8 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,117 @@ +import sys +import os + +import cPickle as pickle + +from cfr.logger import Logger as Log +Log.VERBOSE = True + +import cfr.evaluation as evaluation +from cfr.plotting import * + +def sort_by_config(results, configs, key): + vals = np.array([cfg[key] for cfg in configs]) + I_vals = 
np.argsort(vals) + + for k in results['train'].keys(): + results['train'][k] = results['train'][k][I_vals,] + results['valid'][k] = results['valid'][k][I_vals,] + + if k in results['test']: + results['test'][k] = results['test'][k][I_vals,] + + configs_sorted = [] + for i in I_vals: + configs_sorted.append(configs[i]) + + return results, configs_sorted + +def evaluate(path, dataset, overwrite=False, filters=None): + if not os.path.isdir(path): + raise Exception('Could not find output at path: %s' % path) + + output_dir = path + + if dataset==1: + data_train = '../data/LaLonde/jobs_dw_bin.train.npz' + data_test = '../data/LaLonde/jobs_dw_bin.test.npz' + binary = True + elif dataset==2: + data_train = '../data/LaLonde/jobs_dw_bin.new.10.train.npz' + data_test = '../data/LaLonde/jobs_dw_bin.new.10.test.npz' + binary = True + elif dataset==3: + data_train = '../data/LaLonde/jobs_DW_bin.bias.married.10.train.npz' + data_test = '../data/LaLonde/jobs_DW_bin.bias.married.10.test.npz' + binary = True + elif dataset==4: + data_train = '../data/LaLonde/jobs_DW_bin.bias.nodegr.10.train.npz' + data_test = '../data/LaLonde/jobs_DW_bin.bias.nodegr.10.test.npz' + binary = True + elif dataset==5: + data_train = '../data/ihdp/ihdp_imb/ihdp_imb_p0_400_1_1-500.npz' + data_test = '../data/ihdp/ihdp_imb/ihdp_imb_p0_400_1_1-500.npz' + binary = False + elif dataset==6: + data_train = '../data/ihdp/ihdp_imb/ihdp_imb_p0_400_2_1-500.npz' + data_test = '../data/ihdp/ihdp_imb/ihdp_imb_p0_400_2_1-500.npz' + binary = False + elif dataset==7: + data_train = '../data/ihdp/ihdp_imb/ihdp_imb_p0_400_3_1-500.npz' + data_test = '../data/ihdp/ihdp_imb/ihdp_imb_p0_400_3_1-500.npz' + binary = False + elif dataset==8: + data_train = '../data/LaLonde/jobs_DW_bin.bias.nodegr.25.train.npz' + data_test = '../data/LaLonde/jobs_DW_bin.bias.nodegr.25.test.npz' + binary = True + else: + data_train = '../data/ihdp/ihdp_npci_1-1000.train.npz' + data_test = '../data/ihdp/ihdp_npci_1-1000.test.npz' + binary = False + + + # Evaluate results + eval_path = '%s/evaluation.npz' % output_dir + if overwrite or (not os.path.isfile(eval_path)): + eval_results, configs = evaluation.evaluate(output_dir, + data_path_train=data_train, + data_path_test=data_test, + binary=binary) + # Save evaluation + pickle.dump((eval_results, configs), open(eval_path, "wb")) + else: + if Log.VERBOSE: + print 'Loading evaluation results from %s...' 
% eval_path + # Load evaluation + eval_results, configs = pickle.load(open(eval_path, "rb")) + + # Sort by alpha + #eval_results, configs = sort_by_config(eval_results, configs, 'p_alpha') + + # Print evaluation results + if binary: + plot_evaluation_bin(eval_results, configs, output_dir, data_train, data_test, filters) + else: + plot_evaluation_cont(eval_results, configs, output_dir, data_train, data_test, filters) + + # Plot evaluation + #if configs[0]['loss'] == 'log': + # plot_cfr_evaluation_bin(eval_results, configs, output_dir) + #else: + # plot_cfr_evaluation_cont(eval_results, configs, output_dir) + +if __name__ == "__main__": + if len(sys.argv) < 3: + print 'Usage: python evaluate.py <output_dir> <dataset> [overwrite] [filters]' + else: + dataset = int(sys.argv[2]) + + overwrite = False + if len(sys.argv)>3 and sys.argv[3] == '1': + overwrite = True + + filters = None + if len(sys.argv)>4: + filters = eval(sys.argv[4]) + + evaluate(sys.argv[1], dataset, overwrite, filters=filters)
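Taken together, the files above form a small pipeline: cfr_param_search.py repeatedly samples a hyperparameter configuration from a config file such as configs/example_ihdp.txt and launches cfr_net_train.py with the corresponding flags, while evaluate.py post-processes the resulting output directory. The end-to-end invocation below is a minimal sketch, not part of the repository: the number of sampled configurations and the dataset index are illustrative, the outdir from the example config (../results/ihdp_search_2-2) is assumed to exist before the sweep starts, and exactly which subdirectories the evaluation walks is determined by cfr/evaluation.py, which is not shown here.

    # sample and train 10 random configurations from the example IHDP config
    python cfr_param_search.py configs/example_ihdp.txt 10

    # summarize the finished runs; a dataset index outside 1-8 falls through
    # to the IHDP npz files hard-coded in evaluate.py
    python evaluate.py ../results/ihdp_search_2-2 0

Note that the scripts target Python 2 (print statements, dict.iteritems) and an old TensorFlow API (tf.app.flags, tf.concat with the axis as the first argument), so they need a matching interpreter and TensorFlow release.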