#include #include #include #include #include #include #include static struct machine m[1] = {{ .metrics = "round,hash,pc,excited,bitrate,delta,wrong,error,mips", .rounds = LONG_MAX, }}; static struct fann_train_data *D; static mpz_t x, y, d, e, u; static unsigned char *b; static long a, n; int main(int argc, char *argv[]) { sigset_t mask; FILE *stream; char *path; long c, r; /* Parse command line arguments. */ if (options(m, argc, argv) < 0) goto failure; /* Print usage statement and exit if so configured. */ if (m->flags & HELP) { help(); goto success; } /* Initialize vectors. */ mpz_inits(x, y, d, e, u, 0); /* Prepare execution context. */ if ((n = initial(y, x, m, argc - optind, argv + optind)) < 0) goto failure; /* Allocate buffer for portable cryptographic hashes. */ if ((b = calloc(4 + n / 8, sizeof(unsigned char))) == 0) goto failure; /* Display output. */ if (status(r = 0, x, y, u, d, e, b, a = 0, m) < 0) goto failure; /* Initial transient integration loop. */ for (r = 1, mpz_set(x, y); r < MIN(3, m->rounds); r++, mpz_swap(x, y)) { /* Execute one superstep. */ if ((c = integrate(y, x, m, m->mode)) == 0) goto failure; /* Compute difference. */ mpz_xor(d, y, x); /* Display output. */ if (status(r, x, y, u, d, e, b, a = 0, m) < 0) goto failure; /* Stop if child process requests exit. */ if (c < 0) goto success; } if (m->excitations) mpz_set(e, m->excitations); if (m->data) { D = m->data; } else { D = calloc(1, sizeof(struct fann_train_data)); } /* Master integrator loop. */ for ( ; r < m->rounds; r++, mpz_swap(x, y)) { /* Issue prediction. */ predict(u, &m->model, D, x, e); /* Execute one superstep. */ if ((c = integrate(y, x, m, m->mode)) == 0) goto failure; /* Check for correct prediction. */ if (mpz_cmp(y, u) == 0) a++; /* Compute difference. */ mpz_xor(d, y, x); /* Display output. */ if (status(r, x, y, u, d, e, b, a, m) < 0) goto failure; /* Stop if child process requests exit. */ if (c < 0) break; /* Compute excitations. */ mpz_ior(e, e, d); m->Bits += mpz_popcount(d); /* Train on-line weak learner. */ if (update(m->model, D, x, d, u, e) < 0) goto failure; } if (sigemptyset(&mask) < 0) { diagnostic("sigemptyset"); goto failure; } if (sigaddset(&mask, SIGINT) < 0) { diagnostic("sigaddset"); goto failure; } /* Mask interrupts. */ if (sigprocmask(SIG_BLOCK, &mask, 0) < 0) { diagnostic("sigprocmask"); goto failure; } /* Write out the excitation vector. */ if (mpz_popcount(e) > 0) { if (asprintf(&path, "%s.excite", argv[optind]) < 0) goto failure; if ((stream = fopen(path, "w")) == 0) { diagnostic("fopen"); goto failure; } if (mpz_out_raw(stream, e) == 0) { diagnostic("mpz_out_raw"); goto failure; } if (fclose(stream)) { diagnostic("fclose"); goto failure; } } /* Write out the training set. */ if (D->num_data > 0) { if (asprintf(&path, "%s.train", argv[optind]) < 0) goto failure; if (fann_save_train(D, path) < 0) goto failure; } /* Write out the weak learner. */ if (m->model && fann_get_learning_rate(m->model) > 0) { if (asprintf(&path, "%s.net", argv[optind]) < 0) goto failure; if (fann_save(m->model, path) < 0) goto failure; } /* Unmask interrupts. */ if (sigprocmask(SIG_UNBLOCK, &mask, 0) < 0) { diagnostic("sigprocmask"); goto failure; } success: return 0; failure: return 1; }