#include <iostream>
#include <stdio.h>
#include <math.h>
#include <time.h>

#include <cufft.h>

using namespace std;

#include "../include/shorteningData.h"

// Device kernels launched by do_shortening. Only the declarations are visible
// here; the definitions live elsewhere (presumably another .cu file), so these
// signatures must match them exactly.

// Launched once after the initial D2Z of c: reads Ck1_f, writes Ck_f, with
// count = fft_len/2+1 spectrum bins at the call site.
// Presumably out = conj(in) — confirm against the definition.
__global__ void complex_conj(int count, int fft_len, Complex *in, Complex *out);
// Not launched in this view; presumably out = in1 .* in2 — confirm.
__global__ void complex_point_mult(int count, int fft_len, Complex *in1, Complex *in2, Complex *out);
// Launched with scale = 1.0/fft_len between a D2Z and a Z2D transform. cuFFT's
// inverse is unnormalized, so the 1/fft_len scaling is what makes
// real(ifft(fft(x) .* S)) come out right.
// Presumably out = scale * (in1 .* in2) — confirm.
__global__ void complex_point_mult_scale(int count, int fft_len, Complex *in1, Complex *in2, Complex *out, double scale);

// Launched twice on the same gk, once with window = wd and once with window = wu.
// From the parameter names it presumably writes |gk .* window| to out_abs and
// its sign to out_sign over Lg taps — TODO confirm.
__global__ void calc_wighted_resp(int Lg, int fft_len, double *gk, double *window, double *out_abs, double *out_sign);
// Launched with pot = pd-1 or pu-1. `out` is Bgd/Bgu, and `tmp` (t1/t2) is
// later fed to reduce2. Presumably raises `in` to the power `pot` — confirm.
__global__ void calc_pot(int Lg, int fft_len, double *in, double *out, double pot, double *tmp);

// Takes no element count and is launched as <<<64, 64, 64*sizeof(double)>>>.
// NOTE(review): the reduced length is presumably implied by the launch
// configuration — confirm. The host reads the final value back through
// shorteningData::copyADouble.
__global__ void reduce2(double *g_idata, double *g_odata, double *res);
// Call-site formula: bu = SAgu*sign(gu).*wu.*Bgu (the wd/bd variant is analogous).
__global__ void calc_b(int Lg, int fft_len, double Sagd, double *sign_gd, double *wd, double *Bgd, double *b);

// Receives F1 and F2 (bu and bd filtered through Ck_f), the step size mue, and
// the device filter h. h is presumably updated in place — confirm.
__global__ void apply_gradient(int Lh, int fft_len, double mue, double *F1, double *F2, double *h);

// Print a failed cuFFT call to stderr. Returns true when `status` is CUFFT_SUCCESS.
static bool shortening_cufft_ok(cufftResult status, const char *what){
    if (status == CUFFT_SUCCESS){
        return true;
    }
    std::cerr << "do_shortening: " << what << " failed (cufftResult "
              << static_cast<int>(status) << ")" << std::endl;
    return false;
}

// Print a failed CUDA runtime call / kernel launch to stderr. Returns true on cudaSuccess.
static bool shortening_cuda_ok(cudaError_t status, const char *what){
    if (status == cudaSuccess){
        return true;
    }
    std::cerr << "do_shortening: " << what << " failed: "
              << cudaGetErrorString(status) << std::endl;
    return false;
}

// Copy src[0, src_len) into dst and zero-fill dst[src_len, total_len).
// Templated on the destination so it accepts whatever indexable buffer
// shorteningData exposes (pointer or array member).
template <typename Dst>
static void shortening_copy_padded(Dst &dst, const double *src, int src_len, int total_len){
    for (int i=0;i<src_len;i++){
        dst[i] = src[i];
    }
    for (int i=src_len;i<total_len;i++){
        dst[i] = 0;
    }
}

// out = real(ifft(fft(in, fft_len) .* spec)) on the device.
// `tmp` receives the D2Z spectrum of `in` (sf_len bins). `tmp2` receives the
// pointwise product scaled by 1/fft_len, which compensates for cuFFT's
// unnormalized inverse. Z2D then writes the real result into `out`.
// Returns the first cuFFT error. The multiply kernel's launch status is left
// for the caller's cudaGetLastError() check.
static cufftResult shortening_filter(cufftHandle plan_DZ, cufftHandle plan_ZD, int sf_len, int fft_len,
                                     double *in, Complex *spec, Complex *tmp, Complex *tmp2, double *out){
    cufftResult status = cufftExecD2Z(plan_DZ, (cufftDoubleReal *)in, (cufftDoubleComplex *)tmp);
    if (status != CUFFT_SUCCESS){
        return status;
    }
    complex_point_mult_scale<<<64,64>>>(sf_len, fft_len, tmp, spec, tmp2, 1.0/fft_len);
    return cufftExecZ2D(plan_ZD, (cufftDoubleComplex *)tmp2, (cufftDoubleReal *)out);
}

// Run the fixed-count iteration on the device buffers held by `sd`.
// Each step does the following:
//   1. Recomputes gk from sd.dev_h.
//   2. Forms the wd- and wu-weighted terms and reduces them to the scalars SAgd/SAgu.
//   3. Builds bd/bu and filters them through Ck_f into F1/F2.
//   4. Calls apply_gradient on sd.dev_h.
// Returns false, after printing the failing step, on the first cuFFT error or
// pending CUDA error.
static bool shortening_run_iterations(shorteningData &sd, cufftHandle plan_DZ, cufftHandle plan_ZD, int Lg, int fft_len){
    const int    n_iter   = 5000;                 // fixed number of iterations
    const int    pd       = 10;                   // calc_pot receives pd-1 for the wd-weighted term
    const int    pu       = 20;                   // calc_pot receives pu-1 for the wu-weighted term
    const double mue      = 1e-11;                // step size handed to apply_gradient
    const int    sf_len   = fft_len/2+1;          // bins in a D2Z spectrum of length fft_len
    const int    smemSize = 64 * sizeof(double);  // one double per reduce2 thread (64 threads/block)

    //init algorithm: Ck1_f = fft(c), then complex_conj(Ck1_f) -> Ck_f
    if (!shortening_cufft_ok(cufftExecD2Z(plan_DZ, (cufftDoubleReal *)sd.dev_c, (cufftDoubleComplex *)sd.Ck1_f),
                             "cufftExecD2Z(c)")){
        return false;
    }
    complex_conj<<<128,128>>>(sf_len, fft_len, sd.Ck1_f, sd.Ck_f);

    for (int iter=0;iter<n_iter;iter++){
        //gk=real(ifft(Ck1.*fft(h,Ln)));
        if (!shortening_cufft_ok(shortening_filter(plan_DZ, plan_ZD, sf_len, fft_len,
                                                   sd.dev_h, sd.Ck1_f, sd.dev_tmp, sd.dev_tmp2, sd.dev_gk),
                                 "FFT filter (gk)")){
            return false;
        }

        calc_wighted_resp<<<64,64>>>(Lg, fft_len, sd.dev_gk, sd.dev_wd, sd.dev_abs_gd,sd.dev_sign_gd);
        calc_wighted_resp<<<64,64>>>(Lg, fft_len, sd.dev_gk, sd.dev_wu, sd.dev_abs_gu,sd.dev_sign_gu);

        calc_pot<<<64,64>>>(Lg, fft_len, sd.dev_abs_gd, sd.dev_Bgd, pd-1, sd.dev_t1);
        calc_pot<<<64,64>>>(Lg, fft_len, sd.dev_abs_gu, sd.dev_Bgu, pu-1, sd.dev_t2);

        // Two-pass reduction, then copyADouble() to obtain sd.aDouble on the host.
        // NOTE(review): the second pass reads and writes sd.dev_t3 in the same
        // 64-block launch. If reduce2 is a per-block tree reduction, blocks can
        // read entries another block is overwriting, and blocks > 0 read past
        // the 64 first-pass partials. A <<<1, 64>>> second pass looks intended;
        // confirm against reduce2's definition before changing it.
        reduce2<<< 64, 64, smemSize >>>(sd.dev_t1, sd.dev_t3, sd.dev_aDouble);
        reduce2<<< 64, 64, smemSize >>>(sd.dev_t3, sd.dev_t3, sd.dev_aDouble);
        sd.copyADouble();
        double SAgd = 1./sd.aDouble;

        reduce2<<< 64, 64, smemSize >>>(sd.dev_t2, sd.dev_t3, sd.dev_aDouble);
        reduce2<<< 64, 64, smemSize >>>(sd.dev_t3, sd.dev_t3, sd.dev_aDouble);
        sd.copyADouble();
        double SAgu = 1./sd.aDouble;

        //bu = SAgu*sign(gu).*wu.*Bgu;
        calc_b<<<64,64>>>(Lg, fft_len, SAgd, sd.dev_sign_gd, sd.dev_wd, sd.dev_Bgd, sd.dev_bd);
        calc_b<<<64,64>>>(Lg, fft_len, SAgu, sd.dev_sign_gu, sd.dev_wu, sd.dev_Bgu, sd.dev_bu);

        //F1=real(ifft(Ck.*fft(bu))); F2=real(ifft(Ck.*fft(bd)));
        if (!shortening_cufft_ok(shortening_filter(plan_DZ, plan_ZD, sf_len, fft_len,
                                                   sd.dev_bu, sd.Ck_f, sd.dev_tmp, sd.dev_tmp2, sd.dev_F1),
                                 "FFT filter (F1)")){
            return false;
        }
        if (!shortening_cufft_ok(shortening_filter(plan_DZ, plan_ZD, sf_len, fft_len,
                                                   sd.dev_bd, sd.Ck_f, sd.dev_tmp, sd.dev_tmp2, sd.dev_F2),
                                 "FFT filter (F2)")){
            return false;
        }

        apply_gradient<<<64,64>>>(sd.h_len, fft_len, mue, sd.dev_F1, sd.dev_F2, sd.dev_h);

        // Kernel launches do not return errors. Launch-configuration errors and
        // faults surfaced by earlier synchronizing calls are recorded by the
        // runtime, so query once per step and stop instead of running thousands
        // of failing iterations.
        if (!shortening_cuda_ok(cudaGetLastError(), "kernel launch / device copy")){
            return false;
        }
    }
    return true;
}

// Host entry point. The original comments say the inputs come from MATLAB.
// The steps are:
//   1. Zero-pad c (c_len), wd and wu (both wd_len), and h (h_len) to fft_len.
//   2. Upload them and create the D2Z/Z2D plans.
//   3. Run the device iteration.
//   4. Write sd.fft_len values into ret_re (taken from sd.tmp) and zeros into ret_im.
//
// Both output arrays must hold at least sd.fft_len doubles. If any CUDA or
// cuFFT step fails, the error is printed to stderr and both outputs are
// zero-filled instead of receiving garbage.
void do_shortening(double *c, int c_len, double *wd, double *wu, int wd_len, double *h, int h_len, int Lg, int fft_len, double *ret_re, double *ret_im){

    shorteningData sd(c_len, wd_len, h_len ,fft_len);

    //copy data from matlab, zero-padding every signal to fft_len
    shortening_copy_padded(sd.c,  c,  c_len,  fft_len);
    shortening_copy_padded(sd.wd, wd, wd_len, fft_len);
    shortening_copy_padded(sd.wu, wu, wd_len, fft_len);   // wu shares wd's length parameter
    shortening_copy_padded(sd.h,  h,  h_len,  fft_len);

    std::cout << "h_len: " << h_len << std::endl;
    std::cout << "fft_len: " << fft_len << std::endl;

    sd.copyToDevice();
    // Surface any error raised during the upload (or left pending by an
    // earlier runtime call) before doing any work.
    bool ok = shortening_cuda_ok(cudaGetLastError(), "device upload");

    // Track which plans exist so that only successfully created handles are destroyed.
    cufftHandle plan_DZ;
    cufftHandle plan_ZD;
    const bool have_DZ = ok && shortening_cufft_ok(cufftPlan1d(&plan_DZ, fft_len, CUFFT_D2Z, 1), "cufftPlan1d(D2Z)");
    const bool have_ZD = ok && shortening_cufft_ok(cufftPlan1d(&plan_ZD, fft_len, CUFFT_Z2D, 1), "cufftPlan1d(Z2D)");

    ok = have_DZ && have_ZD && shortening_run_iterations(sd, plan_DZ, plan_ZD, Lg, fft_len);

    if (ok){
        sd.copyToHost();
        ok = shortening_cuda_ok(cudaGetLastError(), "device download");
    }

    // NOTE(review): ret_re is filled from the host buffer sd.tmp, presumably
    // populated by copyToHost(). apply_gradient was handed sd.dev_h, so verify
    // which buffer is meant to be returned.
    for (int i=0;i<sd.fft_len;i++) {
        if (ok){
            ret_re[i] = sd.tmp[i];
        } else {
            ret_re[i] = 0;
        }
        ret_im[i] = 0;
    }

    if (have_DZ){
        cufftDestroy(plan_DZ);
    }
    if (have_ZD){
        cufftDestroy(plan_ZD);
    }
}

//void do_the_fft_test(double *re1, int len, double *ret_re, double *ret_im){
//    //cout << "Helper reached" <<endl;
//    Complex* h_signal = (Complex*)malloc(sizeof(Complex) * len);
//    for (int ii = 0; ii< len; ii++) {
//        h_signal[ii].x = re1[ii];
//        h_signal[ii].y = 0;
//    }
//    int mem_size = sizeof(Complex) * len;
//    Complex* d_signal;
//    cudaMalloc((void**)&d_signal, mem_size);
//    cudaMemcpy(d_signal, h_signal, mem_size, cudaMemcpyHostToDevice);
//
//
//    cufftHandle plan;
////    cufftPlan1d(&plan, len, CUFFT_C2C, 1);
//    cufftPlan1d(&plan, len, CUFFT_Z2Z, 1);
//
////   cufftExecC2C(plan, (cufftComplex *)d_signal, (cufftComplex *)d_signal, CUFFT_FORWARD);
////   cufftExecC2C(plan, (cufftComplex *)d_signal, (cufftComplex *)d_signal, CUFFT_INVERSE);
//    cufftExecZ2Z(plan, (cufftDoubleComplex *)d_signal, (cufftDoubleComplex *)d_signal, CUFFT_FORWARD);
////    cufftExecZ2Z(plan, (cufftDoubleComplex *)d_signal, (cufftDoubleComplex *)d_signal, CUFFT_INVERSE);
//
//
//    cudaMemcpy(h_signal, d_signal, mem_size, cudaMemcpyDeviceToHost);
//
//    h_signal[0].x = 0.5;
//    h_signal[0].y = 0.75;
//
//    //cout << "fft finished" << endl;
//
//    for (int ii = 0; ii< len; ii++) {
//        ret_re[ii] = h_signal[ii].x;
//        ret_im[ii] = h_signal[ii].y;
//    }
//
//
//}




