#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>              /* uint32_t */
#include <math.h>
#include "BanditWorld.h"

#define PI 3.1415926535897932384626
#define SQ(x) ((x)*(x))
#define POWER(x,y) (exp(log(x)*(y)))
#define True 1
#define False 0

typedef struct normal_s {
    double mean;
    double var;
} normal_t;

/* Approximation for the max of two independent Normal random variables.
 * With a = var(X)+var(Y) and b = E[X]-E[Y]:
 *   p  approximates P(X > Y) = Phi(b/sqrt(a)) via an exponential
 *      approximation of the Normal cdf,
 *   q  = sqrt(a/(2*PI)) * exp(-b^2/(2a)) = sqrt(a) * phi(b/sqrt(a)),
 * and the first two moments of max(X,Y),
 *   E[max]   = p*E[X] + (1-p)*E[Y] + q,
 *   E[max^2] = p*(E[X]^2+var(X)) + (1-p)*(E[Y]^2+var(Y)) + q*(E[X]+E[Y]),
 * are matched to the Normal stored in *Res. */
void nmax (normal_t X, normal_t Y, normal_t *Res)
{
    double a = X.var + Y.var;
    double b = X.mean - Y.mean;
    double c = -0.5 * SQ(b) / a;
    double d = c * ((8.0 * (PI-3.0)) / (3.0 * PI * (PI-4.0)));
    double p = 0.5 * ((b>0) ? 1.0 + sqrt (1.0 - exp (c * (d + (4.0/PI)) / (d + 1.0)))
                            : 1.0 - sqrt (1.0 - exp (c * (d + (4.0/PI)) / (d + 1.0))));
    double np = 1.0 - p;
    double q = sqrt(a * (1.0/(2.0*PI))) * exp(c);
    double m = p * X.mean + np * Y.mean + q;
    double sq_m_plus_v = p * (SQ(X.mean) + X.var)
                       + np * (SQ(Y.mean) + Y.var)
                       + q * (X.mean + Y.mean);

    Res->mean = m;
    Res->var = sq_m_plus_v - SQ(m);
}

typedef struct state_eval_s {
    double n;                     /* number of visits to that state */
    normal_t Exp_reward;          /* estimator of the expected reward */
    normal_t Exp_return;          /* estimator of the expected return */
    struct state_eval_s **_succ;  /* pointers to the evaluation structures of that state's successors */
    /**** fields for estimating the expected reward: ****/
    double nr;
    double sum_r;                 /* sum of rewards */
    double sum_sq_r;              /* sum of squared rewards */
} state_eval_t;

int state_count;
int succ_count;                  /* number of successors for each state */
double gama, sq_gama;            /* discount factor and its square */
double horizon;                  /* 1/(1-gama) */
double exploration_need;

state_eval_t *_state;            /* array of state evaluation structures */
/* state_eval_t *unknown_state; */
state_eval_t **_history;         /* queue (circular array) of pointers to the evaluators of the last visited states */
state_eval_t **h_ptr;            /* points to the cell pointing to the last visited state */
state_eval_t **h_end;            /* points to the last cell of the history array */
int h_full;                      /* tells whether the history array has been entirely filled up
                                    (t >= length(history)), i.e. whether we must wrap around when
                                    reaching the array boundaries */

FILE *file;

state_t init (int argc, char *argv[])
{
    int i;
    state_eval_t *s;
    int h_length;
    char filename[100];
    uint32_t seed = strtol (argv[2], NULL, 10);

    state_count = (argc>3) ? strtol (argv[3], NULL, 10) : 1000;
    succ_count  = (argc>4) ? strtol (argv[4], NULL, 10) : 2;

    gama = 1.0 - 1.0 / sqrt((float) state_count);
    sq_gama = SQ(gama);
    horizon = 1.0/(1.0-gama);
    exploration_need = POWER((double)succ_count, horizon);

    _state = malloc (state_count * sizeof(state_eval_t));

    h_length = (int) (log(0.05)/log(gama));
    _history = malloc (h_length * sizeof(state_eval_t *));
    h_end = _history + h_length - 1;
    h_ptr = _history;
    h_full = False;

    for (i=0, s=_state; i<state_count; i++, s++) {
        s->n = 0.0;
        s->Exp_reward.mean = 0.5;
        s->Exp_reward.var = SQ(0.5);
        s->nr = 2.0;
        s->sum_r = s->nr * s->Exp_reward.mean;
        s->sum_sq_r = s->nr * (s->nr * s->Exp_reward.var + SQ(s->Exp_reward.mean));
        s->Exp_return.mean = 0.5 * horizon;
        s->Exp_return.var = ((1.0 / 3.0) - SQ(0.5)) * horizon;
        s->_succ = malloc (succ_count * sizeof(state_eval_t *));
    }

    sprintf (filename, "bandit_world_%u_%d_%d.eval.dat", seed, state_count, succ_count);
    file = fopen (filename, "w");
    fprintf (file, "# state mean var \n");

    return make_MDP (seed, state_count, succ_count);
}

/* Note: this identifier clashes with the POSIX _exit() declared in <unistd.h>. */
void _exit()
{
    int i;
    state_eval_t *s;

    close_MDP();
    for (s = _state, i=0; i<state_count; i++, s++) {
        /* one line per state: index, estimated return mean and variance */
        fprintf (file, "%d %f %f\n", i, s->Exp_return.mean, s->Exp_return.var);
        /* WHY does the following cause 'double free or corruption' error?? */
        /* free (s->_succ); */
    }
    fclose (file);
    free (_state);
    free (_history);
}

/* Upper confidence bound of a Normal value estimate. */
#define UCB_NORMAL(V,eta) ((V).mean + sqrt (16.0 * (V).var * (eta)))

action_t pi (state_eval_t *s)
{
    double eta = log (exploration_need * s->n);
    int i, argmax = 0;
    double x, max = UCB_NORMAL(s->_succ[0]->Exp_return, eta);

    for (i=1; i<succ_count; i++) {
        x = UCB_NORMAL(s->_succ[i]->Exp_return, eta);
        if (x>max) { max = x; argmax = i; }
    }
    return argmax;
}

int main (int argc, char *argv[])
{
    state_t s_id = init (argc, argv);
    state_eval_t *s, *hs, **sp;
    action_t u;
    reward_t r;
    int i, t, T = strtol (argv[1], NULL, 10);
    double x;
    normal_t Vmax;

    *h_ptr = s = _state + s_id;
    for (i=0; i<succ_count; i++)
        s->_succ[i] = _state + successor(i);

    for (t=0; t<T; t++) {
        s->n += 1.0;
        act (u=pi(s), &s_id, &r);

        /* add the new state to the history */
        if (h_ptr == h_end) { h_ptr = _history; h_full = True; }
        else h_ptr++;
        *h_ptr = s = _state + s_id;
        if (s->n == 0)
            for (i=0; i<succ_count; i++)
                s->_succ[i] = _state + successor(i);

        /* update the expected-reward estimator */
        x = 1.0 / (s->nr += 1.0);
        s->Exp_reward.mean = (s->sum_r += r) * x;
        s->Exp_reward.var = ((s->sum_sq_r += SQ(r)) * x - SQ(s->Exp_reward.mean)) * x;

        /* backward-chained update of state values along the history */
        sp = h_ptr;
        while (True) {
            hs = *sp;
            Vmax = hs->_succ[0]->Exp_return;
            for (i=1; i<succ_count; i++)
                nmax (Vmax, hs->_succ[i]->Exp_return, &Vmax);
            hs->Exp_return.mean = hs->Exp_reward.mean + gama * Vmax.mean;
            hs->Exp_return.var = hs->Exp_reward.var + sq_gama * Vmax.var;
            if (sp == _history) {
                if (h_full) sp = h_end;
                else break;
            } else
                sp--;
            if (sp == h_ptr) break;
        }
    }

    _exit();
    return 0;
}
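
/* Build/usage sketch (assumptions: this translation unit is saved as, say,
 * bandit_eval.c, and the MDP simulator behind "BanditWorld.h" lives in a
 * companion file, here assumed to be BanditWorld.c, providing state_t,
 * action_t, reward_t, make_MDP(), successor(), act() and close_MDP() as
 * used above):
 *
 *   gcc -O2 -o bandit_eval bandit_eval.c BanditWorld.c -lm
 *   ./bandit_eval T seed [state_count [succ_count]]
 *
 * T (argv[1]) is the number of simulation steps, seed (argv[2]) the MDP seed;
 * state_count defaults to 1000 and succ_count to 2. The per-state return
 * estimates are written to bandit_world_<seed>_<state_count>_<succ_count>.eval.dat. */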