From a7893ba85c04c1a53762458fde59d367c15847aa Mon Sep 17 00:00:00 2001 From: PJ v M Date: Thu, 14 Sep 2023 13:36:40 +0000 Subject: [PATCH] adapt genann_train to general activation function Note: requires the user to specify a "differential expression" for the activation function, by which I mean its derivative in terms of its function value. Thus limited to strictly increasing, differentiable functions. --- genann.c | 11 +++++++++-- genann.h | 6 ++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/genann.c b/genann.c index b05fa4f..b871f71 100644 --- a/genann.c +++ b/genann.c @@ -73,6 +73,10 @@ double genann_act_sigmoid(const genann *ann unused, double a) { return 1.0 / (1 + exp(-a)); } +double genann_act_diffexpr_sigmoid(const genann * ann unused, double y) { + return y*(1.0-y); +} + void genann_init_sigmoid_lookup(const genann *ann) { const double f = (sigmoid_dom_max - sigmoid_dom_min) / LOOKUP_SIZE; int i; @@ -143,6 +147,9 @@ genann *genann_init(int inputs, int hidden_layers, int hidden, int outputs) { genann_init_sigmoid_lookup(ret); + ret->diffexpr_activation_hidden = genann_act_diffexpr_sigmoid; + ret->diffexpr_activation_output = genann_act_diffexpr_sigmoid; + return ret; } @@ -296,7 +303,7 @@ void genann_train(genann const *ann, double const *inputs, double const *desired } } else { for (j = 0; j < ann->outputs; ++j) { - *d++ = (*t - *o) * *o * (1.0 - *o); + *d++ = (*t - *o) * ann->diffexpr_activation_output(ann, *o); ++o; ++t; } } @@ -328,7 +335,7 @@ void genann_train(genann const *ann, double const *inputs, double const *desired delta += forward_delta * forward_weight; } - *d = *o * (1.0-*o) * delta; + *d = ann->diffexpr_activation_hidden(ann, *o) * delta; ++d; ++o; } } diff --git a/genann.h b/genann.h index e4b7383..8159de5 100644 --- a/genann.h +++ b/genann.h @@ -53,6 +53,11 @@ typedef struct genann { /* Which activation function to use for output. Default: gennann_act_sigmoid_cached*/ genann_actfun activation_output; + /* Derivative of the activation function, expressed in terms of the function value; i.e. f'(f_inverse(y)) + * Used for backpropagation. Default: y(1-y), corresponding to the sigmoid. */ + genann_actfun diffexpr_activation_hidden; + genann_actfun diffexpr_activation_output; + /* Total number of weights, and size of weights buffer. */ int total_weights; @@ -97,6 +102,7 @@ void genann_write(genann const *ann, FILE *out); void genann_init_sigmoid_lookup(const genann *ann); double genann_act_sigmoid(const genann *ann, double a); double genann_act_sigmoid_cached(const genann *ann, double a); +double genann_act_diffexpr_sigmoid(const genann *ann, double y); double genann_act_threshold(const genann *ann, double a); double genann_act_linear(const genann *ann, double a);