
/*
 * Copyright (c) 1997 Massachusetts Institute of Technology
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to use, copy, modify, and distribute the Software without
 * restriction, provided the Software, including any modified copies made
 * under this license, is not distributed for a fee, subject to
 * the following conditions:
 * 
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE MASSACHUSETTS INSTITUTE OF TECHNOLOGY BE LIABLE
 * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
 * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 * 
 * Except as contained in this notice, the name of the Massachusetts
 * Institute of Technology shall not be used in advertising or otherwise
 * to promote the sale, use or other dealings in this Software without
 * prior written authorization from the Massachusetts Institute of
 * Technology.
 *  
 */

#include <stdlib.h>

#include <fftw_threads.h>

/*
 * Forward declarations of the static helpers used by fftwnd_threads().
 * The 2d/3d variants handle rank 2 and rank 3 plans; the nd variants
 * handle every other rank greater than 1.  The out_of_place helpers read
 * from `in` and write to `out`; the in_place helpers transform `in_out`.
 */

/* out-of-place helpers: in -> out */
static void fftw2d_out_of_place_aux_threads(int nthreads, fftwnd_plan p,
					    int howmany,
					    FFTW_COMPLEX *in, int istride,
					    int idist,
					    FFTW_COMPLEX *out, int ostride,
					    int odist);
static void fftw3d_out_of_place_aux_threads(int nthreads, fftwnd_plan p,
					    int howmany,
					    FFTW_COMPLEX *in, int istride,
					    int idist,
					    FFTW_COMPLEX *out, int ostride,
					    int odist);
static void fftwnd_out_of_place_aux_threads(int nthreads, fftwnd_plan p,
					    int howmany,
					    FFTW_COMPLEX *in, int istride,
					    int idist,
					    FFTW_COMPLEX *out, int ostride,
					    int odist);

/* in-place helpers: in_out is overwritten with its transform */
static void fftw2d_in_place_aux_threads(int nthreads, fftwnd_plan p,
					int howmany,
					FFTW_COMPLEX *in_out, int istride,
					int idist);
static void fftw3d_in_place_aux_threads(int nthreads, fftwnd_plan p,
					int howmany,
					FFTW_COMPLEX *in_out, int istride,
					int idist);
static void fftwnd_in_place_aux_threads(int nthreads, fftwnd_plan p,
					int howmany,
					FFTW_COMPLEX *in_out, int istride,
					int idist);

/************** Computing the N-Dimensional FFT **************/

/*
 * fftwnd_threads: execute the multi-dimensional transform described by
 * `plan` on `howmany` arrays, passing `nthreads` on to the workers.
 *
 * In-place plan:
 *   - the data at `in` (element stride `istride`, array spacing `idist`)
 *     is transformed;
 *   - `out`, `ostride` and `odist` are not used.
 *
 * Out-of-place plan:
 *   - the result goes to `out` (stride `ostride`, spacing `odist`);
 *   - an `out` that is NULL or equal to `in` is rejected via fftw_die().
 *
 * A rank-0 plan performs no transform.
 */
void fftwnd_threads(int nthreads,
		    fftwnd_plan plan, int howmany,
		    FFTW_COMPLEX *in, int istride, int idist,
		    FFTW_COMPLEX *out, int ostride, int odist)
{
     const int rank = plan->rank;

     if (plan->is_in_place) {
	  if (rank == 0)
	       return;

	  if (rank == 1) {
	       fftw_threads(nthreads, plan->plans[0], howmany,
			    in, istride, idist,
			    0, 0, 0);
	  } else if (rank == 2) {
	       fftw2d_in_place_aux_threads(nthreads, plan, howmany,
					   in, istride, idist);
	  } else if (rank == 3) {
	       fftw3d_in_place_aux_threads(nthreads, plan, howmany,
					   in, istride, idist);
	  } else {
	       fftwnd_in_place_aux_threads(nthreads, plan, howmany,
					   in, istride, idist);
	  }
	  return;
     }

     /* Out-of-place plans need a distinct, non-NULL output array. */
     if (out == 0 || out == in)
	  fftw_die("Illegal attempt to perform in-place FFT!\n");

     if (rank == 0)
	  return;

     if (rank == 1) {
	  fftw_threads(nthreads, plan->plans[0], howmany,
		       in, istride, idist,
		       out, ostride, odist);
     } else if (rank == 2) {
	  fftw2d_out_of_place_aux_threads(nthreads, plan, howmany,
					  in, istride, idist,
					  out, ostride, odist);
     } else if (rank == 3) {
	  fftw3d_out_of_place_aux_threads(nthreads, plan, howmany,
					  in, istride, idist,
					  out, ostride, odist);
     } else {
	  fftwnd_out_of_place_aux_threads(nthreads, plan, howmany,
					  in, istride, idist,
					  out, ostride, odist);
     }
}

/*
 * Argument block handed to each fftw_many_thread() worker.
 *
 * fftw_many_threads() fills in every field except min/max.  Presumably
 * fftw_thread_spawn_loop() assigns each block its [min, max) slice of
 * batch indices (TODO confirm).
 */
typedef struct {
     int min;			/* first batch index to process (inclusive) */
     int max;			/* end of batch index range (exclusive) */
     int distance;		/* element offset between consecutive batches */
     fftw_plan plan;		/* plan passed to fftw() for every batch */
     int howmany;		/* transforms per fftw() call */
     FFTW_COMPLEX *in;		/* base pointer; batch k starts at in + k*distance */
     int istride;		/* element stride within one transform */
     int idist;			/* spacing between transforms of one batch */
     FFTW_COMPLEX *tmp;		/* this worker's private scratch buffer */
} fftw_many_data;

/*
 * Worker body for fftw_many_threads().
 *
 * For every batch index k in [d->min, d->max) it makes one fftw() call:
 *   - on d->in + k*d->distance;
 *   - performing d->howmany transforms of d->plan;
 *   - with stride d->istride and spacing d->idist.
 *
 * d->tmp is supplied as fftw()'s `out` argument, with ostride and odist
 * set to 0.  Always returns a null pointer.
 */
static void *fftw_many_thread(fftw_many_data *d)
{
     FFTW_COMPLEX *const scratch = d->tmp;
     const int first = d->min;
     const int stop = d->max;
     int batch;

     for (batch = first; batch < stop; batch++) {
	  FFTW_COMPLEX *start = d->in + batch * d->distance;

	  fftw(d->plan, d->howmany, start, d->istride, d->idist,
	       scratch, 0, 0);
     }

     return 0;
}

/*
 * fftw_many_threads: perform `ntimes` batches of one-dimensional
 * transforms with `plan`, spreading the batches over at most `nthreads`
 * threads.
 *
 * Batch k is a single fftw() call on in + k*distance that performs
 * `howmany` transforms with element stride `is` and spacing `id`.  Every
 * worker gets its own plan->n-element scratch buffer, which is passed as
 * fftw()'s `out` argument (ostride/odist 0).  Nothing is done when
 * ntimes <= 0.
 */
static void fftw_many_threads(int nthreads, int ntimes, int distance,
			      fftw_plan plan, 
			      int howmany, 
			      FFTW_COMPLEX *in, int is, int id)
{
     int k;
     int n;
     FFTW_COMPLEX *tmp;
     fftw_many_data *d;

     /* Nothing to do.  Returning here also prevents the clamp below from
	reducing nthreads to 0, which used to produce a write through a
	zero-length argument array. */
     if (ntimes <= 0)
	  return;

     n = plan->n;

     /* Using more threads than batches gains nothing. */
     if (nthreads > ntimes)
	  nthreads = ntimes;

     if (nthreads <= 1) {
	  /* Serial path: a single scratch buffer is reused for every batch. */
	  tmp = (FFTW_COMPLEX *) fftw_malloc(sizeof(FFTW_COMPLEX) * n);
	  for (k = 0; k < ntimes; ++k)
	       fftw(plan, howmany, in + k * distance, is, id, tmp, 0, 0);
	  fftw_free(tmp);
	  return;
     }

     /* Each thread gets n elements of scratch, carved from one
	allocation.  Computing the size in size_t avoids int overflow. */
     tmp = (FFTW_COMPLEX *) fftw_malloc(sizeof(FFTW_COMPLEX) * n * nthreads);

     /* Allocate the per-thread argument blocks on the heap.  alloca() is
	non-standard and strict ISO C <stdlib.h> does not declare it. */
     d = (fftw_many_data *) fftw_malloc(sizeof(fftw_many_data) * nthreads);

     d[0].distance = distance;
     d[0].plan = plan;
     d[0].howmany = howmany;
     d[0].in = in;
     d[0].istride = is;
     d[0].idist = id;
     d[0].tmp = tmp;
     for (k = 1; k < nthreads; ++k) {
	  d[k] = d[0];
	  d[k].tmp = tmp + k * n;	/* private scratch slice */
     }

     /* min/max of each d[k] are not set here; presumably the spawn loop
	assigns them before invoking fftw_many_thread (TODO confirm). */
     fftw_thread_spawn_loop(ntimes, nthreads, fftw_many_thread, d);

     fftw_free(d);
     fftw_free(tmp);
}

/*
 * fftw2d_out_of_place_aux_threads: rank-2 out-of-place transform of
 * `howmany` arrays of p->n[0] x p->n[1] elements (second index fastest).
 *
 * Array k is read from in + k*idist (element stride istride) and written
 * to out + k*odist (element stride ostride).  Dimension 1 is transformed
 * while copying from in to out.  Dimension 0 is then transformed in place
 * within out.
 */
static void fftw2d_out_of_place_aux_threads(int nthreads, 
					    fftwnd_plan p, int howmany,
				FFTW_COMPLEX *in, int istride, int idist,
			       FFTW_COMPLEX *out, int ostride, int odist)
{
     const fftw_plan col_plan = p->plans[0];
     const fftw_plan row_plan = p->plans[1];
     const int nrows = p->n[0];
     const int ncols = p->n[1];
     int k;

     for (k = 0; k < howmany; k++) {
	  FFTW_COMPLEX *src = in + k * idist;
	  FFTW_COMPLEX *dst = out + k * odist;

	  /* dimension 1: nrows contiguous rows, in -> out */
	  fftw_threads(nthreads, row_plan, nrows,
		       src, istride, ncols * istride,
		       dst, ostride, ncols * ostride);

	  /* dimension 0: ncols columns, in place inside out */
	  fftw_threads(nthreads, col_plan, ncols,
		       dst, ncols * ostride, ostride,
		       0, 0, 0);
     }
}

/*
 * fftw3d_out_of_place_aux_threads: rank-3 out-of-place transform of
 * `howmany` arrays of n0 x n1 x n2 elements (last index fastest).
 *
 * Array k is read from in + k*idist (element stride istride) and written
 * to out + k*odist (element stride ostride).  The last dimension is
 * transformed while copying from in to out.  Dimensions 1 and 0 are then
 * transformed in place within out.
 */
static void fftw3d_out_of_place_aux_threads(int nthreads,
					    fftwnd_plan p, int howmany,
				FFTW_COMPLEX *in, int istride, int idist,
			       FFTW_COMPLEX *out, int ostride, int odist)
{
     int fft_iter;
     fftw_plan p0, p1, p2;
     int n0, n1, n2;

     p0 = p->plans[0];
     p1 = p->plans[1];
     p2 = p->plans[2];
     n0 = p->n[0];
     n1 = p->n[1];
     n2 = p->n[2];

     for (fft_iter = 0; fft_iter < howmany; ++fft_iter) {
	  FFTW_COMPLEX *src = in + fft_iter * idist;
	  FFTW_COMPLEX *dst = out + fft_iter * odist;

	  /* FFT z dimension (out-of-place): n0*n1 rows, each spaced n2
	     elements apart */
	  fftw_threads(nthreads, p2, n0 * n1,
		       src, istride, n2 * istride,
		       dst, ostride, n2 * ostride);

	  /* FFT y dimension (in-place): for each of the n0 planes (spaced
	     n1*n2 elements apart), n2 transforms with stride n2*ostride */
	  fftw_many_threads(nthreads, n0, n1 * n2 * ostride,
			    p1, n2, dst, n2 * ostride, ostride);

	  /* FFT x dimension (in-place): n1*n2 transforms with stride
	     n1*n2*ostride */
	  fftw_threads(nthreads, p0, n1 * n2,
		       dst, n1 * n2 * ostride, ostride,
		       0, 0, 0);
     }
}

/*
 * fftwnd_out_of_place_aux_threads: general-rank out-of-place transform
 * (used for rank > 3) of `howmany` arrays.
 *
 * Array k is read from in + k*idist (element stride istride) and written
 * to out + k*odist (element stride ostride).  The last dimension is
 * transformed while copying from in to out.  The first dimension and then
 * the middle dimensions are transformed in place within out.
 *
 * p->n_before[j] and p->n_after[j] are used as the element counts
 * before/after dimension j (presumably the products of the preceding and
 * following dimension sizes; TODO confirm against the planner).
 */
static void fftwnd_out_of_place_aux_threads(int nthreads,
					    fftwnd_plan p, int howmany,
				FFTW_COMPLEX *in, int istride, int idist,
			       FFTW_COMPLEX *out, int ostride, int odist)
{
     int fft_iter;
     int j;
     const int last = p->rank - 1;

     for (fft_iter = 0; fft_iter < howmany; ++fft_iter) {
	  FFTW_COMPLEX *src = in + fft_iter * idist;
	  FFTW_COMPLEX *dst = out + fft_iter * odist;

	  /* do last dimension (out-of-place): contiguous transforms */
	  fftw_threads(nthreads,
		       p->plans[last], p->n_before[last],
		       src, istride, p->n[last] * istride,
		       dst, ostride, p->n[last] * ostride);

	  /* do first dimension (in-place) */
	  fftw_threads(nthreads, p->plans[0], p->n_after[0],
		       dst, p->n_after[0] * ostride, ostride,
		       0, 0, 0);

	  /* do other dimensions (in-place): n_before[j] blocks, each
	     holding n_after[j] transforms along dimension j */
	  for (j = 1; j < last; ++j)
	       fftw_many_threads(nthreads,
				 p->n_before[j],
				 ostride * p->n[j] * p->n_after[j],
				 p->plans[j], p->n_after[j],
				 dst,
				 p->n_after[j] * ostride, ostride);
     }
}

/*
 * fftw2d_in_place_aux_threads: rank-2 in-place transform of `howmany`
 * arrays of p->n[0] x p->n[1] elements (second index fastest).
 *
 * Array k starts at in_out + k*idist with element stride istride.
 * Dimension 1 (the rows) is transformed first, then dimension 0 (the
 * columns).
 */
static void fftw2d_in_place_aux_threads(int nthreads, 
					fftwnd_plan p, int howmany,
			    FFTW_COMPLEX *in_out, int istride, int idist)
{
     const fftw_plan col_plan = p->plans[0];
     const fftw_plan row_plan = p->plans[1];
     const int nrows = p->n[0];
     const int ncols = p->n[1];
     int k;

     for (k = 0; k < howmany; k++) {
	  FFTW_COMPLEX *base = in_out + k * idist;

	  /* rows: nrows transforms spaced istride*ncols apart */
	  fftw_threads(nthreads, row_plan, nrows,
		       base, istride, istride * ncols,
		       0, 0, 0);

	  /* columns: ncols transforms with stride istride*ncols */
	  fftw_threads(nthreads, col_plan, ncols,
		       base, istride * ncols, istride,
		       0, 0, 0);
     }
}

/*
 * fftw3d_in_place_aux_threads: rank-3 in-place transform of `howmany`
 * arrays of n0 x n1 x n2 elements (last index fastest).
 *
 * Array k starts at in_out + k*idist with element stride istride.
 * Dimensions are transformed in the order z (index 2), y (index 1),
 * x (index 0).
 */
static void fftw3d_in_place_aux_threads(int nthreads, 
					fftwnd_plan p, int howmany,
			    FFTW_COMPLEX *in_out, int istride, int idist)
{
     int fft_iter;
     fftw_plan p0, p1, p2;
     int n0, n1, n2;

     p0 = p->plans[0];
     p1 = p->plans[1];
     p2 = p->plans[2];
     n0 = p->n[0];
     n1 = p->n[1];
     n2 = p->n[2];

     for (fft_iter = 0; fft_iter < howmany; ++fft_iter) {
	  FFTW_COMPLEX *data = in_out + fft_iter * idist;

	  /* FFT z dimension: n0*n1 rows, each spaced n2 elements apart */
	  fftw_threads(nthreads, p2, n0 * n1,
		       data, istride, n2 * istride,
		       0, 0, 0);

	  /* FFT y dimension: for each of the n0 planes, n2 transforms with
	     stride n2*istride */
	  fftw_many_threads(nthreads, n0, n1 * n2 * istride,
			    p1, n2,
			    data, n2 * istride, istride);

	  /* FFT x dimension: n1*n2 transforms with stride n1*n2*istride */
	  fftw_threads(nthreads, p0, n1 * n2,
		       data, n1 * n2 * istride, istride,
		       0, 0, 0);
     }
}

/*
 * fftwnd_in_place_aux_threads: general-rank in-place transform (used for
 * rank > 3) of `howmany` arrays.
 *
 * Array k starts at in_out + k*idist with element stride istride.  The
 * last dimension is transformed first, then the first dimension, then the
 * middle dimensions.
 *
 * p->n_before[j] and p->n_after[j] are used as the element counts
 * before/after dimension j (presumably the products of the preceding and
 * following dimension sizes; TODO confirm against the planner).
 */
static void fftwnd_in_place_aux_threads(int nthreads, fftwnd_plan p,
					int howmany,
			    FFTW_COMPLEX *in_out, int istride, int idist)
{
     int fft_iter;
     int j;
     const int last = p->rank - 1;

     for (fft_iter = 0; fft_iter < howmany; ++fft_iter) {
	  FFTW_COMPLEX *data = in_out + fft_iter * idist;

	  /* do last dimension: contiguous transforms */
	  fftw_threads(nthreads, 
		       p->plans[last], p->n_before[last],
		       data, istride, p->n[last] * istride,
		       0, 0, 0);

	  /* do first dimension */
	  fftw_threads(nthreads, p->plans[0], p->n_after[0],
		       data, p->n_after[0] * istride, istride,
		       0, 0, 0);

	  /* do other dimensions: n_before[j] blocks, each holding
	     n_after[j] transforms along dimension j */
	  for (j = 1; j < last; ++j)
	       fftw_many_threads(nthreads,
				 p->n_before[j],
				 istride * p->n[j] * p->n_after[j],
				 p->plans[j], p->n_after[j],
				 data,
				 p->n_after[j] * istride, istride);
     }
}
