(*
 *
 * This is a primitive irrational-base discrete weighted transform taken
 * from pseudocode found in: "Prime Numbers: A Computational Perspective"
 *
 * The code was originally written by Rick Lavoie and translated to ATS
 * by Hongwei Xi (November, 2005)
 *
 * Absolutely no optimization has gone into this code; written for clarity,
 * not speed.
 *
 *)

(*
 *
 * The code is ported to ATS/Geizella by Hongwei Xi (July 2007)
 *
 * test: 2 ^ 44497 - 1 is a prime! (BITSIZE = 16)
 * time: 357.479u 0.234s 5:59.57 99.4%	0+0k 0+0io 0pf+0w
 *
 * test: 2 ^ 110503 - 1 is a prime! (BITSIZE = 16) 
 * time: 2171.782u 2.094s 36:18.67 99.7% 0+0k 0+0io 0pf+0w
 *
 *)

//

staload "libc/SATS/math.sats"
staload "libc/SATS/stdio.sats"

%{^

#include "prelude/CATS/array.cats"
#include "libc/CATS/math.cats"
#include "libc/CATS/stdio.cats"

/*
 * Identity function on doubles: returns its argument unchanged.
 * The ATS declarations real_of_double and imag_of_double are both bound to
 * this C name, so converting a double to `real` or `imag` only relabels the
 * value at the ATS type level.
 */
static inline
ats_double_type
double_of_double (ats_double_type x) { return x ; }

%}

//

// The constant pi, used to build the FFT twiddle angles.
macdef PI = 3.1415926535897932384626

//

// The real part of a signal entry is an ordinary double.
typedef real = double

// Relabels a double as a `real`; bound to the C identity double_of_double.
extern fun real_of_double : double -<> real = "double_of_double"
// Shorthand: d2r x == real_of_double x
#define d2r real_of_double

// This is an experimental effort to differentiate imaginary from real

// `imag` is an abstract flat type whose representation is a C double
// (ats_double_type). Being abstract, it cannot be mixed with `real` except
// through the functions declared below.
abst@ype imag = $extype "ats_double_type"

// Relabels a double as an `imag`; bound to the C identity double_of_double.
extern fun imag_of_double : double -<> imag
  = "double_of_double"

// Shorthand: d2i x == imag_of_double x
#define d2i imag_of_double

// imag + imag is an imag; bound to the C name atspre_add_double_double.
extern fun add_imag_imag : (imag, imag) -> imag
  = "atspre_add_double_double"
overload + with add_imag_imag

// imag - imag is an imag; bound to the C name atspre_sub_double_double.
extern fun sub_imag_imag : (imag, imag) -> imag
  = "atspre_sub_double_double"
overload - with sub_imag_imag

// imag * imag is typed as a real. It is bound to the C name
// atspre_mul_double_double, so it is presumably the plain product of the two
// coefficients with no sign flip -- NOTE(review): confirm against the prelude.
// Callers in this file subtract the result explicitly (see convolve_signal).
extern fun mul_imag_imag : (imag, imag) -> real
  = "atspre_mul_double_double"

// real * imag is an imag.
extern fun mul_real_imag : (real, imag) -> imag
  = "atspre_mul_double_double"

// imag * real is an imag.
extern fun mul_imag_real : (imag, real) -> imag
  = "atspre_mul_double_double"

// Make `*` resolve to the three mixed products declared above.
overload * with mul_imag_imag
overload * with mul_real_imag
overload * with mul_imag_real

//

// Shorthand: i2d k == double_of_int k
#define i2d double_of_int

//

// Returns A[0] + A[1] + ... + A[n-1], accumulated front to back starting
// from 0.0. The result is 0.0 when n = 0.
fn compute_signal_sum {n:nat} {l:addr}
  (pf: !array_v (real, n, l) | A: ptr l, n: int n): real =
  let
     // total holds the sum of the entries before index k
     fun accumulate {k:nat | k <= n} .<n-k>.
       (pf: !array_v (real, n, l) | k: int k, total: real): real =
       if k < n then
         let
            val total = total + A[k]
         in
            accumulate (pf | k+1, total)
         end
       else
         total
  in
     accumulate (pf | 0, d2r 0.0)
  end

// Squares every complex entry of the signal in place. For an entry with real
// part re and imaginary coefficient im:
//   new real part        = re*re - im*im
//   new imaginary coeff. = re*im + re*im
fn convolve_signal {n:nat} {l_r,l_i:addr}
  (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
   R: ptr l_r, I: ptr l_i, n: int n): void =
  let
     fun square_from {k:nat | k <= n} .<n-k>.
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
        k: int k): void =
       if k < n then
         let
           val re = R[k]
           val im = I[k]
           val re_sq = re * re
           val im_sq = im * im // imag * imag is typed as a real
           val cross = re * im // real * imag is typed as an imag
           val () = R[k] := re_sq - im_sq
           val () = I[k] := cross + cross
         in
           square_from (pf_r, pf_i | k+1)
         end
       else ()
  in
     square_from (pf_r, pf_i | 0)
  end

// Replaces each R[k] by floor (R[k] + 0.5) and resets each I[k] to zero.
fn round_signal {n:nat} {l_r,l_i:addr}
  (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
   R: ptr l_r, I: ptr l_i, n: int n): void =
  let
     fun round_from {k:nat | k <= n} .<n-k>.
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
        k: int k): void =
       if k < n then
         let
            val nearest = floor (R[k] + d2r 0.5)
            val () = R[k] := nearest
            val () = I[k] := d2i (0.0)
         in
            round_from (pf_r, pf_i | k+1)
         end
       else ()
  in
     round_from (pf_r, pf_i | 0)
  end

//

// Forward transform of the complex signal (R, I), done in place.
//
// Steps, in order:
//   1. multiply every R[k] by the weight W[k];
//   2. record the sum of the weighted real parts (this is the return value);
//   3. run the butterfly passes for m = n/2, n/4, ..., 1.
fun forward_fft {n:nat} {l_r, l_i, l_w:addr}
   (pf_r: !array_v (real, n, l_r) ,
    pf_i: !array_v (imag, n, l_i) ,
    pf_w: !array_v (real, n, l_w) |
    R: ptr l_r, I: ptr l_i, W: ptr l_w, n: int n): real =
  let
     // R[k] := R[k] * W[k] for every k from the given start up to n-1
     fun apply_weights {k:nat | k <= n} .<n-k>.
       (pf_r: !array_v (real, n, l_r), pf_w: !array_v (real, n, l_w) |
        k: int k): void =
       if k < n then begin
         R[k] := R[k] * W[k]; apply_weights (pf_r, pf_w | k+1)
       end else ()

     // Butterflies for one twiddle factor (tw_r, tw_i): pairs entry lo with
     // entry lo+m, then moves on by 2*m until lo+m falls outside the signal.
     fun butterflies {i,m:nat}
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
        lo: int i, m: int m, tw_r: real, tw_i: imag): void =
       let
          val hi = lo + m
       in
          if hi < n then
            let
               val hi_r = R[hi] and hi_i = I[hi]
               val lo_r = R[lo] and lo_i = I[lo]
               // the lower entry receives the sum of the pair
               val () = R[lo] := lo_r + hi_r
               val () = I[lo] := lo_i + hi_i
               // the upper entry receives the difference times the twiddle
               val d_r = lo_r - hi_r and d_i = lo_i - hi_i
               val () = R[hi] := tw_r * d_r - tw_i * d_i
               val () = I[hi] := tw_r * d_i + tw_i * d_r
            in
               butterflies (pf_r, pf_i | hi + m, m, tw_r, tw_i)
            end
          else ()
       end

     // For j = start .. m-1, applies the twiddle (cos a, -sin a) with
     // a = PI * j / m.
     fun twiddles {j,m:nat}
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
        j: int j, m: int m): void =
       if j < m then
         let
            val angle = d2r (PI * (i2d j) / (i2d m))
            val c = d2r (cos angle)
            val s = d2i ~(sin angle)
            val () = butterflies (pf_r, pf_i | j, m, c, s)
         in
            twiddles (pf_r, pf_i | j+1, m)
         end
       else ()

     // One pass per m, halving m each time until it reaches 0.
     fun stages {m:nat}
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) | m: int m): void =
       if m >= 1 then begin
         twiddles (pf_r, pf_i | 0, m); stages (pf_r, pf_i | nhalf m)
       end else ()

     val () = apply_weights (pf_r, pf_w | 0)
     val weighted_sum = compute_signal_sum (pf_r | R, n)
     val () = stages (pf_r, pf_i | nhalf n)
  in
     weighted_sum
  end

//

// Inverse transform of the complex signal (R, I), done in place.
//
// Steps, in order:
//   1. run the butterfly passes for m = 1, 2, 4, ... while m < n;
//   2. divide every R[k] by n;
//   3. record the sum of the real parts (this is the return value);
//   4. divide every R[k] by the weight W[k].
// Only the real parts are rescaled in steps 2 and 4.
fun inverse_fft {n:nat} {l_r, l_i, l_w:addr}
   (pf_r: !array_v (real, n, l_r) ,
    pf_i: !array_v (imag, n, l_i) ,
    pf_w: !array_v (real, n, l_w) |
    R: ptr l_r, I: ptr l_i, W: ptr l_w, n: int n): real =
  let
     // Butterflies for one twiddle factor (tw_r, tw_i): pairs entry lo with
     // entry lo+m, then moves on by 2*m until lo+m falls outside the signal.
     fun butterflies {i,m:nat}
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
        lo: int i, m: int m, tw_r: real, tw_i: imag): void =
       let
          val hi = lo + m
       in
          if hi < n then
            let
               val hi_r = R[hi] and hi_i = I[hi]
               // upper entry multiplied by the twiddle factor
               val rot_r = hi_r * tw_r - hi_i * tw_i
               val rot_i = hi_r * tw_i + hi_i * tw_r
               val lo_r = R[lo] and lo_i = I[lo]
               val () = R[lo] := lo_r + rot_r
               val () = I[lo] := lo_i + rot_i
               val () = R[hi] := lo_r - rot_r
               val () = I[hi] := lo_i - rot_i
            in
               butterflies (pf_r, pf_i | hi + m, m, tw_r, tw_i)
            end
          else ()
       end

     // For j = start .. m-1, applies the twiddle (cos a, sin a) with
     // a = PI * j / m.
     fun twiddles {j,m:nat}
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
        j: int j, m: int m): void =
       if j < m then
         let
            val angle = d2r (PI * (i2d j) / (i2d m))
            val c = d2r (cos angle)
            val s = d2i (sin angle)
            val () = butterflies (pf_r, pf_i | j, m, c, s)
         in
            twiddles (pf_r, pf_i | j+1, m)
         end
       else ()

     // One pass per m, doubling m each time while it stays below n.
     fun stages {m:nat}
       (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) |
        m: int m): void =
       if m < n then begin
         twiddles (pf_r, pf_i | 0, m); stages (pf_r, pf_i | m+m)
       end else ()

     // R[k] := R[k] / scale for every k from the given start up to n-1
     fun normalize {k:nat | k <= n}
       (pf_r: !array_v (real, n, l_r) | k: int k, scale: real): void =
       if k < n then begin
         R[k] := R[k] / scale; normalize (pf_r | k+1, scale)
       end else ()

     // R[k] := R[k] / W[k] for every k from the given start up to n-1
     fun unweight {k:nat | k <= n}
       (pf_r: !array_v (real, n, l_r), pf_w: !array_v (real, n, l_w) |
        k: int k): void =
       if k < n then begin
         R[k] := R[k] / W[k]; unweight (pf_r, pf_w | k+1)
       end else ()

     val () = stages (pf_r, pf_i | 1)
     val () = normalize (pf_r | 0, d2r (i2d n))
     val normalized_sum = compute_signal_sum (pf_r | R, n)
     val () = unweight (pf_r, pf_w | 0)
  in
     normalized_sum
  end

//

// Brings every digit R[k] into the range of its own base B[k].
//
// A first sweep over k = 0 .. n-1 adds the incoming carry to R[k], keeps
// fmod (total, B[k]) as the digit and passes floor (total / B[k]) on as the
// next carry. Whatever carry leaves the last digit is then fed back in at
// digit 0, wrapping around the signal as often as needed until it is zero.
fun modular_reduce {n:nat} {l_r,l_b:addr}
  (pf_r: !array_v (real, n, l_r), pf_b: !array_v (real, n, l_b) |
   R: ptr l_r, B: ptr l_b, n: int n): void =
  let
     // One full pass over the digits; returns the carry out of the last one.
     fun sweep {k:nat | k <= n}
       (pf_r: !array_v (real, n, l_r), pf_b: !array_v (real, n, l_b) |
        k: int k, carry: real): real =
       if k < n then
         let
            val radix = B[k]
            val total = R[k] + carry
            val () = R[k] := fmod (total, radix)
            val carry_out = floor (total / radix)
         in
            sweep (pf_r, pf_b | k+1, carry_out)
         end
       else carry

     // Propagates a carry starting at digit k and stops as soon as it is
     // zero; on running off the end it restarts from digit 0.
     fun fold_back {k:nat | k <= n}
       (pf_r: !array_v (real, n, l_r), pf_b: !array_v (real, n, l_b) |
        k: int k, carry: real): void =
       if k < n then
         if carry = d2r 0.0 then ()
         else let
           val radix = B[k]
           val total = R[k] + carry
           val () = R[k] := fmod (total, radix)
           val carry_out = floor (total / radix)
         in
            fold_back (pf_r, pf_b | k+1, carry_out)
         end
       else fold_back (pf_r, pf_b | 0, carry)

     val leftover = sweep (pf_r, pf_b | 0, d2r 0.0)
  in
     fold_back (pf_r, pf_b | 0, leftover)
  end

//

// Subtracts 2 from the number held in the mixed-base digits R (bases in B).
//
// The amount still to be subtracted travels from digit 0 upwards and the
// walk stops as soon as it is zero. When it exceeds the current digit, the
// digit becomes B[k] - borrow + R[k] and 1 is borrowed from the next digit.
// A borrow still pending after the last digit is dropped.
fun subtract2 {n:nat} {l_r,l_b:addr}
  (pf_r: !array_v (real, n, l_r), pf_b: !array_v (real, n, l_b) |
   R: ptr l_r, B: ptr l_b, n: int n): void =
  let
     fun borrow_from {k:nat | k <= n}
       (pf_r: !array_v (real, n, l_r), pf_b: !array_v (real, n, l_b) |
        k: int k, borrow: real): void =
       if k < n then
         if borrow = d2r 0.0 then ()
         else let
           val digit = R[k]
           val radix = B[k]
         in
           if borrow > digit then begin
             R[k] := radix - borrow + digit;
             borrow_from (pf_r, pf_b | k+1, d2r 1.0)
           end else let
             val rest = digit - borrow
             val () = R[k] := fmod (rest, radix)
           in
             borrow_from (pf_r, pf_b | k+1, floor (rest / radix))
           end
         end
       else ()
  in
     borrow_from (pf_r, pf_b | 0, d2r 2.0)
  end

//

// Halving of the exponent stops once the remainder drops below this value.
#define BITSIZE 16

// Returns the signal length for the given exponent: starting from 1, the
// length is doubled each time the exponent can be halved (nhalf) without
// first being below BITSIZE. The result is 1 when exponent < BITSIZE.
fun compute_optimal_signal_size (exponent: Nat): intGt 0 =
  let
     fun grow (remaining: Nat, size: intGt 0): intGt 0 =
       if remaining >= BITSIZE then
         grow (nhalf remaining, size + size)
       else
         size
  in
     grow (exponent, 1)
  end

//

// Fills B with the per-digit bases of the signal:
//   B[i] = 2 ^ (ceil (exponent*(i+1)/n) - ceil (exponent*i/n))   for 0 <= i < n
//
// B        : array of n reals, overwritten
// exponent : the exponent being tested
// n        : the signal length
fun compute_base_signal {n:nat} {l_b:addr}
  (pf_b: !array_v (real, n, l_b) | B: ptr l_b, exponent: Nat, n: int n): void =
  let
    // exponent / n is the same for every index, so it is computed once here
    // instead of on every loop iteration. (When n = 0 the loop body never
    // runs and this value is unused.)
    val en = (i2d exponent) / (i2d n)

    // The termination metric n-i matches the one in compute_weight_signal.
    fun aux {i:nat | i <= n} .<n-i>.
      (pf_b: !array_v (real, n, l_b) | i: int i): void =
      if i < n then
        let
          val eni = en * (i2d i)
        in
          // eni + en is exponent*(i+1)/n, the upper edge of digit i
          B[i] := pow (d2r 2.0, d2r (ceil (eni + en) - ceil (eni)));
          aux (pf_b | i+1)
        end
  in
    aux (pf_b | 0)
  end

//

// Fills W with the weights of the transform:
//   W[i] = 2 ^ (ceil (exponent*i/n) - exponent*i/n)   for 0 <= i < n
//
// W        : array of n reals, overwritten
// exponent : the exponent being tested
// n        : the signal length
fun compute_weight_signal {n:nat} {l_w:addr}
  (pf_w: !array_v (real, n, l_w) | W: ptr l_w, exponent: Nat, n: int n): void =
  let
     // exponent / n is the same for every index, so it is computed once here
     // instead of on every loop iteration. (When n = 0 the loop body never
     // runs and this value is unused.)
     val en = (i2d exponent) / (i2d n)

     fun aux {i:nat | i <= n} .<n-i>.
       (pf_w: !array_v (real, n, l_w) | i: int i): void =
       if i < n then
         let
           val eni = en * (i2d i)
         in
           W[i] := pow (d2r 2.0, d2r (ceil eni - eni));
           aux (pf_w | i+1)
         end
  in
     aux (pf_w | 0)
  end

//

// Returns true exactly when every entry of A equals 0.0 (vacuously true for
// n = 0). The scan stops at the first entry that is not equal to 0.0.
fun check_zero_signal {n:nat} {l:addr}
   (pf: !array_v (real, n, l) | A: ptr l, n: int n): bool =
  let
    fun all_zero_from {k:nat | k <= n} .<n-k>.
      (pf: !array_v (real, n, l) | k: int k): bool =
      if k < n then
        let
          val entry = A[k]
        in
          if entry = d2r 0.0 then all_zero_from (pf | k+1) else false
        end
      else
        true
  in
    all_zero_from (pf | 0)
  end

//

// Performs one step on the signal held in R: square it, then subtract 2.
//
// The squaring is forward_fft, convolve_signal, inverse_fft. As a sanity
// check the square of the sum reported by forward_fft is compared with the
// sum reported by inverse_fft, and a warning is printed when they differ by
// more than 1.0. Afterwards the digits are rounded, reduced to their bases
// and 2 is subtracted.
fun do_iteration {n:nat} {l_r,l_i,l_b,l_w:addr}
   (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) , 
    pf_b: !array_v (real, n, l_b), pf_w: !array_v (real, n, l_w) |
    R: ptr l_r, I: ptr l_i, B: ptr l_b, W: ptr l_w, n: int n): void =
  let
     val sum_before = forward_fft (pf_r, pf_i, pf_w | R, I, W, n)
     val () = convolve_signal (pf_r, pf_i | R, I, n)
     val sum_after = inverse_fft (pf_r, pf_i, pf_w | R, I, W, n)
     val mismatch = abs (sum_before * sum_before - sum_after)
     val () =
       if mismatch > d2r 1.0 then begin
         print "Possible error: difference = ";
         print mismatch;
         print_newline ()
       end
     val () = round_signal (pf_r, pf_i | R, I, n)
     val () = modular_reduce (pf_r, pf_b | R, B, n)
  in
     subtract2 (pf_r, pf_b | R, B, n)
  end

// A progress dot is printed once every STRIDE iterations.
#define STRIDE 100

// Runs do_iteration for the counter values i, i+1, ..., limit-1. Whenever
// the counter is a multiple of STRIDE, a '.' is printed and stdout flushed
// before that iteration is carried out.
fun main_loop {n:nat} {l_r,l_i,l_b,l_w:addr}
   (pf_r: !array_v (real, n, l_r), pf_i: !array_v (imag, n, l_i) , 
    pf_b: !array_v (real, n, l_b), pf_w: !array_v (real, n, l_w) |
    R: ptr l_r, I: ptr l_i, B: ptr l_b, W: ptr l_w, n: int n,
    i: Nat, limit: Nat): void =
  if i < limit then
    let
      val () =
        if (i mod STRIDE = 0) then begin
          print '.'; fflush_stdout ()
        end
      val () = do_iteration (pf_r, pf_i, pf_b, pf_w | R, I, B, W, n)
    in
      main_loop (pf_r, pf_i, pf_b, pf_w | R, I, B, W, n, i+1, limit)
    end
  else ()

// Here are the exponents p of the first few Mersenne primes 2^p - 1:
// 2, 3, 5, 7, 13, 17, 19, 31, 61, 89, 107, 127, 521, 607,
// 1279, 2203, 2281, 3217, 4253, 4423, 9689, 9941, 11213,
// 19937, 21701, 23209, 44497, 86243, 110503, 132049, 216091,
// 756839, 859433, 1257787, 1398269, 2976221, 3021377, 6972593,
// 13466917, ...

// Tests whether 2^exponent - 1 is a prime.
//
// The signal starts at 4 and do_iteration (square, then subtract 2) is
// applied exponent - 2 times; the number is reported prime exactly when the
// resulting signal is all zeros.
//
// exponent = 2 is answered directly: it would run zero iterations and leave
// the starting value 4 in the signal, so 2^2 - 1 = 3 would be reported as a
// composite even though it is a prime.
//
// Returns true if 2^exponent - 1 is found to be prime, false otherwise. The
// four work arrays are allocated here and freed before returning.
fun is_mersenne_prime (exponent: intGte 2): bool =
  if exponent = 2 then true
  else let
     val n = compute_optimal_signal_size (exponent)

     // R: real parts (the digits), I: imaginary parts,
     // B: per-digit bases, W: weights -- all of length n, zero-initialized
     val (pf_r | R) =
       array_ptr_make_fun_tsz {real} (n, lam (x, _) =<> x := d2r 0.0, sizeof<real>)
     val (pf_i | I) =
       array_ptr_make_fun_tsz {imag} (n, lam (x, _) =<> x := d2i 0.0, sizeof<imag>)
     val (pf_b | B) =
       array_ptr_make_fun_tsz {real} (n, lam (x, _) =<> x := d2r 0.0, sizeof<real>)
     val (pf_w | W) =
       array_ptr_make_fun_tsz {real} (n, lam (x, _) =<> x := d2r 0.0, sizeof<real>)

     val () = compute_base_signal (pf_b | B, exponent, n)
     val () = compute_weight_signal (pf_w | W, exponent, n)
     // the iteration starts from the value 4
     val () = R[0] := d2r 4.0
     val () = main_loop (pf_r, pf_i, pf_b, pf_w | R, I, B, W, n, 0, exponent - 2)
     // ends the line of progress dots printed by main_loop
     val () = print_newline ()
     val ans = check_zero_signal (pf_r | R, n)
  in
     array_ptr_free {real} (pf_r | R);
     array_ptr_free {imag} (pf_i | I);
     array_ptr_free {real} (pf_b | B);
     array_ptr_free {real} (pf_w | W);
     ans
  end

//

// Prints a one-line reminder of the expected command-line argument.
fn usage (): void =
  let
    val () = print "A natural number argument >= 2 is needed."
  in
    print_newline ()
  end

// Entry point: reads the exponent from the first command-line argument and
// reports whether 2^exponent - 1 is a prime. If the argument is missing or
// its value is at most 1, the usage message is printed and the program exits
// with status 1.
implement main (argc, argv) = let
  val () = if argc <= 1 then begin usage (); exit {void} (1) end
  // records argc > 1 for the typechecker so that argv.[1] is in bounds
  val () = assert (argc > 1)
  val exponent = int1_of (argv.[1])
  val () = if exponent <= 1 then begin usage (); exit {void} (1) end
  // records exponent > 1, as required by is_mersenne_prime
  val () = assert (exponent > 1)
  val is_prime = is_mersenne_prime exponent
in
  if is_prime then
    printf ("2 ^ %i - 1 is a prime!\n", @(exponent))
  else
    printf ("2 ^ %i - 1 is a composite.\n", @(exponent))
end

////

/* FFT.c
Very primitive implementation of the irrational-base discrete weighted transform.
Written by Rick Lavoie
Taken from pseudocode found in "Prime Numbers: A Computational Perspective"

Absolutely no optimization has gone into this code. It's written for clarity, not speed.
*/
#include <stdio.h>
#include <math.h>
#include <memory.h>
#include <stdlib.h>

#define PI 3.1415926535897932384626

// Displays the contents of a signal
int display_signal(double* real, double* imag, unsigned int signal_size)
{
	unsigned int i;

	if (imag == NULL) { // Pass NULL as the second parameter if we just want the real part
		for (i = 0; i < signal_size; i++)
			printf("%.2f ", real[i]);
	} else {
		for (i = 0; i < signal_size; i++) 
			printf("%.2f+%.2fi ", real[i], imag[i]);
	}

	printf("\n");
	return 0;
}

// Adds together the values of a signal. Used for error-checking
double compute_signal_sum(double* real, unsigned int signal_size)
{
	unsigned int i;
	double running_total = 0.0;

	for (i = 0; i < signal_size; i++) 
		running_total += real[i]; 
	
	return running_total;
}

// Computes the forward fast fourier transform of a complex signal. Uses Gentleman-Sande decimation
// in frequency variant. weight is a weight signal, and signal_size is the size of the signal. real
// and imag are the signal to be transformed. Bitscrambling is not done.
double forward_fast_fourier(double* real, double* imag, double* weight, unsigned int signal_size)
{
	unsigned int m, j, i;
	double a_real, a_imag, temp1_real, temp1_imag, temp2_real, temp2_imag;
	double sum;
	
	for (m = 0; m < signal_size; m++) // Initially multiply the input by the weights
		real[m] *= weight[m];
	
	sum = compute_signal_sum(real, signal_size);

	// Taken straight from "Prime Numbers"
	for (m = signal_size/2; m >= 1; m = m/2) {
		for (j = 0; j < m; j++) {
			a_real = cos((2.0*PI*j*signal_size/(2*m))/signal_size);
			a_imag = -sin((2.0*PI*j*signal_size/(2*m))/signal_size);
			for (i = j; i < signal_size; i += 2*m) {
				temp1_real = real[i+m];
				temp1_imag = imag[i+m];
				temp2_real = real[i];
				temp2_imag = imag[i];
				real[i] = temp2_real + temp1_real;
				imag[i] = temp2_imag + temp1_imag;
				temp1_real = temp2_real - temp1_real;
				temp1_imag = temp2_imag - temp1_imag;
				real[i+m] = (a_real * temp1_real) - (a_imag * temp1_imag);
				imag[i+m] = (a_real * temp1_imag) + (a_imag * temp1_real);
			}
		}
	}

	return sum;
}

// Computes the inverse fast fourier transform. Uses Cooley-Tukey decimation in time. Parameters
// are the same as the forward transform. Bitscrambling is not done, because the use of the two
// different decimations cancels them out.
double inverse_fast_fourier(double* real, double* imag, double* weight, unsigned int signal_size)
{
	unsigned int m, j, i;
	double a_real, a_imag, temp1_real, temp1_imag, temp2_real, temp2_imag;
	double sum;
		
	for (m = 1; m < signal_size; m = 2*m) {
		for (j = 0; j < m; j++) {
			a_real = cos((2.0*PI*j*signal_size/(2*m))/signal_size);
			a_imag = sin((2.0*PI*j*signal_size/(2*m))/signal_size);
			for (i = j; i < signal_size; i += 2*m) {
				temp1_real = (real[i+m] * a_real) - (imag[i+m] * a_imag);
				temp1_imag = (real[i+m] * a_imag) + (imag[i+m] * a_real);
				temp2_real = real[i];
				temp2_imag = imag[i];
				real[i] = temp2_real + temp1_real;
				imag[i] = temp2_imag + temp1_imag;
				real[i+m] = temp2_real - temp1_real;
				imag[i+m] = temp2_imag - temp1_imag;
			}
		}
	}

	for (m = 0; m < signal_size; m++) // Divide by the signal_size to re-normalize the signal
		real[m] = real[m] / signal_size;
	
	sum = compute_signal_sum(real, signal_size);
	
	for (m = 0; m < signal_size; m++)  // Adjust by the weights
		real[m] = real[m] / weight[m];
	
	return sum;
}

// Convolves a transformed signal, i.e., squares the signal dyadically.
int convulve_signal(double* real, double* imag, unsigned int signal_size)
{
	unsigned int i;
	double temp_real, temp_imag;

	for (i = 0; i < signal_size; i++) {
		temp_real = real[i];
		temp_imag = imag[i];
		real[i] = (temp_real * temp_real) - (temp_imag * temp_imag); // Complex multiplication
		imag[i] = (temp_real * temp_imag) + (temp_real * temp_imag);
	}

	return 0;
}

// Rounds the floating point signal to closest integers.
int round_signal(double* real, double* imag, unsigned int signal_size)
{
	unsigned int i;

	// I'll eventually add a check to see if we're not rounding too far
	for (i = 0; i < signal_size; i++) 
		real[i] = floor(real[i] + 0.5);
	
	return 0;
}

// Reduces the signal into their appropriate bases. base is the base signal. 
int modular_reduce(double* real, double* base, unsigned int signal_size)
{
	unsigned int i;
	double carry = 0;
	double temp;

	for (i = 0; i < signal_size; i++) {
		temp = real[i] + carry;
		real[i] = fmod(temp, base[i]);
		carry = floor(temp/base[i]);
	}
	
	while (carry != 0) {
		for (i = 0; i < signal_size; i++) {
			if (carry == 0)
				break;
			temp = real[i] + carry;
			real[i] = fmod(temp, base[i]);
			carry = floor(temp/base[i]);
		}
	}
	
	return 0;
}

// Subtracts 2 from the signal. This is part of the Lucas-Lehmer test. This code won't work
// if the base is less than 2.
int subtract_2(double* real, double* base, unsigned int signal_size)
{
	unsigned int i;
	double carry = 2; // Subtract 2
	double temp;

	for (i = 0; i < signal_size; i++) {
		if (carry == 0) // No more carry, so we can exit early
			break;
		else if (carry > real[i]) { // If real[i] is less than the carry
			real[i] = base[i] - (carry - real[i]);
			carry = 1;
		} else {
			temp = real[i] - carry; // Subtract the carry
			real[i] = fmod(temp, base[i]); // Modular adjust
			carry = floor(temp/base[i]); // Get the carry
		}
	}

	return 0;
}

// Computes the optimal size of a signal, given an exponent
unsigned int compute_optimal_signal_size(unsigned int exponent)
{
	unsigned int temp = 1;
	// "Prime Numbers" says to find a signal_length such that Floor[2^(exponent/temp)] is
	// an "appropriate" size;
	while (1) { 
		if ((double)exponent/temp < 16) // I have no clue how to decide what's an ideal bitsize
			break; // so I'm just using 16 bits and hope it's adequate
		temp *= 2;
	}

	return temp;
}

// Computes the base signal for an exponent
int compute_base_signal(double* base, unsigned int exponent, unsigned int signal_size)
{
	unsigned int i;
	// "Prime Numbers" says base[i-1] = 2^(Ceil[exponent*i/signal_size] - Ceil[exponent*(i-1)/signal_size]);
	// Running i from 1 to signal_size;
	for (i = 1; i <= signal_size; i++) 
		base[i-1] = pow(2, ceil((double)exponent*i/signal_size)-ceil((double)exponent*(i-1)/signal_size));

	return 0;
}

// Computes the weight signal for an exponent
int compute_weight_signal(double* weight, unsigned int exponent, unsigned int signal_size)
{
	unsigned int i;
	double temp;
	// "Prime Numbers" says weight[i] = 2^(Ceil[exponent*i/signal_size]-(exponent*i/signal_size))
	for (i = 0; i < signal_size; i++) {
		temp = ((double)exponent*i)/signal_size;
		weight[i] = pow(2, ceil(temp) - temp);
	}

	return 0;
}

int check_zero_signal(double* signal, unsigned int signal_size)
{
	int is_zero = 1;
	unsigned int i;

	for (i = 0; i < signal_size; i++) {
		if (signal[i] != 0) {
			is_zero = 0;
			break;
		}
	}

	return is_zero;
}

// Performs one Lucas-Lehmer iteration using the IBDWT algorithm
int do_iteration(double* real, double* imag, double* weight, double* base, unsigned int signal_size)
{
	double sum_in, sum_out, difference;
		
	// I know that you can "embed" the real signal into the complex, but for clarity I'm avoiding that
	sum_in = forward_fast_fourier(real, imag, weight, signal_size);

	// Square the elements
	convulve_signal(real, imag, signal_size);

	// Invert the signal
	sum_out = inverse_fast_fourier(real, imag, weight, signal_size);

	// What's the difference between sum_in^2 and sum_out
	difference = fabs((sum_in * sum_in) - sum_out);
	if (difference > 10)  // Uh oh, an error, they're not matching
		printf("**POSSIBLE ERROR**: SUMIN and SUMOUT differ by %.2f\n", difference);
		
	// Round the signal to integers
	round_signal(real, imag, signal_size);
	memset(imag, 0, signal_size*sizeof(double));

	// Modular reduce them to their variable base
	modular_reduce(real, base, signal_size); 
	// Subtract 2
	subtract_2(real, base, signal_size);

	return 0;
}

int main(int argc, char* argv[])
{

  double* real; // Real part of the signal
  double* imag; // Imaginary part of the signal
  unsigned int signal_size; // Signal length
  unsigned int exponent; // Exponent for the LL test (Change as needed)
  double* base; // Variable base of the signal
  double* weight; // Weight signal
  unsigned int i;

  exponent = atoi(argv[1]);

  // Find out the appropriate size for this exponent
  signal_size = compute_optimal_signal_size(exponent);

  // Allocate memory
  base = malloc(sizeof(double)*signal_size);
  weight = malloc(sizeof(double)*signal_size);
  real = malloc(sizeof(double)*signal_size);
  imag = malloc(sizeof(double)*signal_size);

  // Clear memory and compute the signals
  memset(real, 0, sizeof(double) * signal_size);
  memset(imag, 0, sizeof(double) * signal_size);
  compute_base_signal(base, exponent, signal_size);
  compute_weight_signal(weight, exponent, signal_size);

  real[0] = 4; // LL test starts with 4
  for (i = 0; i < exponent-2; i++) {
    do_iteration(real, imag, weight, base, signal_size);
    if (i % 100 == 0)
      printf("%i\r\n", i);
  }

  if (check_zero_signal(real, signal_size) == 1)
    printf("PRIME!!!\n");
  else
    printf("Not Prime\n");

  // Clear memory
  free(base);
  free(weight);
  free(real);
  free(imag);
  return 0;
}