-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpropdiff_1.py.lp
174 lines (127 loc) · 5.89 KB
/
propdiff_1.py.lp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
% Solving cost functions with Clingo + Python using propagators
% Version 0.1
% Copyright (c) 2018 Matthias Nickles
%
% Licensed under the Apache License, Version 2.0 (the "License");
% you may not use this file except in compliance with the License.
% You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
% See the License for the specific language governing permissions and
% limitations under the License.
#script(python)
import clingo
import re
import random
import sys
import ad # ad-1.3.2 (for automatic differentiation, https://pypi.org/project/ad/#description)
from ad import adnumber
import copy
# Module-level state shared between Propagator and main().

# Literal the propagator should force next (positive or negative program
# literal of a parameter atom). The initial value is a sentinel meaning
# "nothing selected yet". sys.maxint exists only on Python 2; sys.maxsize is
# used as the fallback so the script also loads on Python 3.
branch_param_lit = getattr(sys, "maxint", sys.maxsize)
# Maps the program literal of each parameter atom to the atom's name;
# filled in main() after grounding.
param_atoms = {}
# Declared global in main() but otherwise unused in this file.
__inner_costs = {}
# Step width for the numerical (finite difference) differentiation.
h = 0.0001
# psi and max_trials are defined inside main() below
class Propagator:
    """Clingo propagator that forces the currently selected parameter literal.

    ``init`` watches both polarities of every atom listed in the module-level
    ``param_atoms`` dict. ``propagate`` then adds a unit clause for the literal
    stored in the module-level ``branch_param_lit`` (if one has been chosen).
    """

    # Maps program literals of parameter atoms to their solver literals.
    # Class-level, i.e. shared by all Propagator instances; filled by init().
    solver_lits = {}

    # Sentinel stored in branch_param_lit while no literal has been chosen.
    # sys.maxint only exists on Python 2; on Python 3 fall back to sys.maxsize.
    _NO_BRANCH_LIT = getattr(sys, "maxint", sys.maxsize)

    def init(self, init):
        """Translate parameter atoms to solver literals and watch them.

        :param init: the clingo propagate-init object passed by the solver.
        """
        # Only the keys (program literals) are needed here; iterating the dict
        # directly works on Python 2 and 3 (iteritems() is Python 2 only).
        for atomproglit in param_atoms:
            solver_lit = init.solver_literal(atomproglit)
            Propagator.solver_lits[atomproglit] = solver_lit
            # Watch both polarities of the atom.
            init.add_watch(solver_lit)
            init.add_watch(-solver_lit)
        # Note: this class does not define a check() method.
        init.check_mode = clingo.PropagatorCheckMode.Fixpoint

    def propagate(self, control, changes):
        """Add a unit clause that forces the selected branching literal.

        :param control: the clingo propagate-control object used to add clauses.
        :param changes: the changed watched literals (not inspected here).
        """
        lit = branch_param_lit
        if lit == Propagator._NO_BRANCH_LIT:
            # Nothing has been selected yet (before the first cost update).
            return
        solver_lit = Propagator.solver_lits[abs(lit)]
        # A negative selection forces the atom to be false. The second
        # argument is add_clause's "tag" flag.
        control.add_clause([solver_lit if lit > 0 else -solver_lit], True)
def main(prg):
    """Sample models one at a time until the example cost drops below ``psi``.

    After every solve call the frequencies of the parameter atoms ``a`` and
    ``b`` in the sampled models are recomputed, the cost function is
    evaluated, and the literal with the smallest partial derivative of the
    cost is stored in the module-level ``branch_param_lit`` so that
    ``Propagator`` forces it in the next solve call. A warning is printed if
    ``max_trials`` solve calls do not reach the threshold.

    :param prg: the clingo control object handed to the script's main().
    """
    # current_cost is written to module scope, as in the original script.
    global current_cost

    psi = 1.0E-4  # cost function convergence threshold (accuracy)
    max_trials = 1000
    # sys.maxint only exists on Python 2; sys.maxsize is the Python 3 fallback.
    max_int = getattr(sys, "maxint", sys.maxsize)

    prg.configuration.solver.sign_def = "rnd"
    #prg.configuration.solver.rand_freq = 0.7
    # NOTE(review): the solver may only accept a narrower seed range - confirm.
    prg.configuration.solver.seed = random.randint(0, max_int)

    # One string per sampled model: its atoms joined by, and padded with, blanks.
    models = []

    def __cost(freqs):
        """Evaluate the cost expression for the given atom frequencies."""
        return __cost_ad(freqs, "")

    def __cost_ad(freqs, atom_x):
        """Return the cost, or its partial derivative w.r.t. ``atom_x``.

        Uses automatic differentiation (ad package). If ``atom_x`` is the
        empty string, the cost expression itself is returned.
        """
        ps = [(adnumber(freqs['a']), 'a'), (adnumber(freqs['b']), 'b')]
        # Example MSE-shaped cost function (0.2 = target probability of a,
        # 0.6 = target probability of b).
        c = ((1 * (0.2 - ps[0][0]) ** 2) + (1 * (0.6 - ps[1][0]) ** 2)) / 2
        if atom_x == "":
            return c
        pxi = next(i for i, v in enumerate(ps) if v[1] == atom_x)
        return c.d(ps[pxi][0])

    def __cost_upd(freqs, upd_atom):
        """Return the cost after increasing the frequency of ``upd_atom`` by h.

        Used for single point numerical differentiation.
        """
        if not freqs:
            return sys.float_info.max
        freqs_upd = freqs.copy()
        freqs_upd[upd_atom] += h
        return __cost(freqs_upd)

    def update_cost():
        """Recompute the cost and select the next branching literal.

        Sets the module-level ``branch_param_lit`` and returns the current
        cost.
        """
        global branch_param_lit
        # Choose numerical or automatic differentiation (using the ad package).
        use_auto_diff = False

        num_models = len(models)
        freqs = {}
        for atom in param_atoms.values():
            if num_models == 0:
                freqs[atom] = 0
            else:
                # The surrounding blanks make this match whole atoms only.
                needle = " " + atom + " "
                hits = sum(needle in ms for ms in models)
                freqs[atom] = float(hits) / float(num_models)
        # linear-time freqs updates possible too (straightforward)

        cost = __cost(freqs)
        print("Current cost = " + str(cost))

        # Search for the minimum partial derivative over both polarities.
        # next(iter(...)) replaces the Python-2-only param_atoms.keys()[0].
        best_lit = next(iter(param_atoms))
        best_diff = sys.float_info.max
        for atomlit, atom in param_atoms.items():
            if use_auto_diff:
                diff_p = __cost_ad(freqs, atom)
            else:
                diff_p = (__cost_upd(freqs, atom) - cost) / h
            if diff_p < best_diff:
                best_lit, best_diff = atomlit, diff_p
            # Making the atom false corresponds to the negated derivative.
            diff_n = -diff_p
            if diff_n < best_diff:
                best_lit, best_diff = -atomlit, diff_n
        branch_param_lit = best_lit
        return cost

    def on_model(m):
        """Add the set of atoms in model ``m`` to the sample multiset."""
        ms = " ".join(map(str, m.symbols(atoms=True)))
        models.append(" " + ms + " ")

    prg.configuration.solve.models = 1
    prg.register_propagator(Propagator())  # see class Propagator
    prg.ground([("base", [])])

    # Record the program literals of the parameter atoms a and b.
    for atom in prg.symbolic_atoms:
        atom_str = str(atom.symbol)
        if atom_str in ("b", "a"):
            param_atoms[atom.literal] = atom_str

    # An explicit flag replaces overwriting the loop counter with sys.maxint.
    converged = False
    for i in range(1, max_trials + 1):
        print("Model sampling iteration: " + str(i) + "...")
        ret = prg.solve(on_model=on_model)
        print(ret)
        print("#models sampled so far: " + str(len(models)))
        current_cost = update_cost()
        if current_cost <= psi:  # target accuracy reached
            converged = True
            break
    if not converged:
        print("Warning: No convergence with psi = " + str(psi) + " and max_trials = " + str(max_trials))
    # return 'models', or compute further models st. current_cost<=psi, or sample uniformly from 'models'
#end.