/* 
 * unsharp.c  -- This code is derived from the GIMP unsharp mask plug-in
 * by Winston Chang, September 2001
 *
 * Copyright (C) 2001 by Winston Chang <winston@stdout.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
 */

#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include <malloc.h>

#include "unsharp.h"

/* Round half up to the nearest integer: ROUND(2.5) == 3.0, ROUND(-2.5) == -2.0.
   The result is still a double; callers cast it to the integer type they need. */
#define ROUND(x) (floor((x) + 0.5))
/* Allocate ccount zero-initialized objects of type ctype (a calloc wrapper).
   The result may be NULL and must be released with free(). */
#define new(ctype, ccount)	calloc(ccount, sizeof(ctype))

/* to show both pretty unoptimized code and ugly optimized code blocks
   There's really no reason to define this, unless you want to see how
   much pointer arithmetic can speed things up.  I find that it is about
   45% faster with the optimized code. */
/*  #define READABLE_CODE 
  */


/* get_col copies pixel column c_col of a wdth x hgt image (channels bytes per
   pixel, rows stored one after another) into the contiguous buffer d_col;
   set_col writes such a buffer back into that column.  unsharp_mask uses
   them to run blur_line over columns. */
void get_col(unsigned char *d_region, unsigned char *d_col, int c_col, int wdth, int hgt, int channels);
void set_col(unsigned char *d_region, unsigned char *d_col, int c_col, int wdth, int hgt, int channels);


/* -------------------------- Unsharp Mask ------------------------- */


/* Perform an unsharp mask on an 8-bit, interleaved image.
 *
 *   src_region     -- source pixels: width * height * channels bytes, stored
 *                     row after row with the channels of a pixel adjacent.
 *                     Only read.
 *   dest_region    -- output buffer of the same size.  Must not overlap
 *                     src_region: it holds the blurred image while the
 *                     source is still being read.
 *   width, height  -- image size in pixels.
 *   channels       -- number of 8-bit channels per pixel (usually 3, RGB).
 *   unsharp_params -- .radius sizes the Gaussian blur, .amount scales the
 *                     (source - blurred) difference that is added back, and
 *                     differences with |2 * diff| below .threshold are left
 *                     unsharpened.
 *
 * Each output byte is src + amount * (src - blurred), clamped to 0..255.
 * If a working buffer cannot be allocated, dest_region receives an
 * unmodified copy of src_region.  Non-positive dimensions are a no-op.
 */
void
unsharp_mask(unsigned char *src_region, unsigned char *dest_region,
             int width, int height, int channels,
             UnsharpMaskParams unsharp_params)
{
    double        *cmatrix  = NULL;  /* 1-D Gaussian convolution matrix */
    double        *ctable   = NULL;  /* lookup table for the convolution matrix */
    unsigned char *cur_col  = NULL;  /* current column before blurring */
    unsigned char *dest_col = NULL;  /* current column after blurring */
    int            cmatrix_length;
    int            row, col;
    unsigned int   threshold;
    size_t         widch, total_bytes, i;

    if (width <= 0 || height <= 0 || channels <= 0)
        return;

    /* size_t arithmetic: width * height * channels can exceed INT_MAX */
    widch       = (size_t) width * (size_t) channels;
    total_bytes = widch * (size_t) height;

    /* generate the convolution matrix and its lookup table */
    cmatrix_length = gen_convolve_matrix(unsharp_params.radius, &cmatrix);
    if (cmatrix != NULL)
        ctable = gen_lookup_table(cmatrix, cmatrix_length);

    /* column buffers: one column is height pixels of `channels` bytes */
    cur_col  = calloc((size_t) height, (size_t) channels);
    dest_col = calloc((size_t) height, (size_t) channels);

    if (cmatrix == NULL || ctable == NULL || cur_col == NULL || dest_col == NULL)
    {
        /* out of memory: hand the image back unsharpened instead of crashing */
        memcpy(dest_region, src_region, total_bytes);
        goto cleanup;
    }

    /* Blur the rows.  A row is contiguous, so it is blurred straight from
       the source into the destination; blur_line writes every byte of its
       output line, so all of dest_region is filled by this pass. */
    for (row = 0; row < height; row++)
    {
        blur_line(ctable, cmatrix, cmatrix_length,
                  src_region + widch * (size_t) row,
                  dest_region + widch * (size_t) row,
                  width, channels);
    }

    /* Blur the columns of the row-blurred image.  Columns are strided, so
       each is gathered into a contiguous buffer, blurred, and scattered back. */
    for (col = 0; col < width; col++)
    {
        get_col(dest_region, cur_col, col, width, height, channels);
        blur_line(ctable, cmatrix, cmatrix_length, cur_col, dest_col, height, channels);
        set_col(dest_region, dest_col, col, width, height, channels);
    }

    /*-------------------- Merging stage ---------------------*/

    threshold = unsharp_params.threshold;

    /* dest_region now holds the blurred image; merge it with the source */
    for (i = 0; i < total_bytes; i++)
    {
        int diff = src_region[i] - dest_region[i];
        int value;

        /* thresholding: small differences are not sharpened */
        if ((unsigned int) abs(2 * diff) < threshold)
            diff = 0;

        /* the unsharped value */
        value = src_region[i] + (unsharp_params.amount * diff);

        /* clamp to the 8-bit range */
        if (value < 0)
            dest_region[i] = 0;
        else if (value > 255)
            dest_region[i] = 255;
        else
            dest_region[i] = (unsigned char) value;
    }

cleanup:
    free(cur_col);
    free(dest_col);
    free(cmatrix);
    free(ctable);
} /* unsharp_mask */


/* this function is written as if it is blurring a column at a time,
   even though it can operate on rows, too.  There is no difference
   in the processing of the lines, at least to the blur_line function. */
/* static?? */
inline void
blur_line (double *ctable,
	   double *cmatrix,
	   int    cmatrix_length,
	   unsigned char   *cur_col,
	   unsigned char   *dest_col,
	   int    y,
	   long   bytes)
{

#ifdef READABLE_CODE
/* ------------- semi-readable code ------------------- */
  double scale;
  double sum;
  int i,j;
  int row;

  /* this is to take care cases in which the matrix can go over
   * both edges at once.  It's not efficient, but this can only
   * happen in small pictures anyway.
   */
  if (cmatrix_length > y)
    {
      for (row = 0; row < y ; row++)
	{
	  scale = 0;
	  /* find the scale factor */
	  for (j = 0; j < y ; j++)
	    {
	      /* if the index is in bounds, add it to the scale counter */
	      if ((j + cmatrix_length/2 - row >= 0) &&
		  (j + cmatrix_length/2 - row < cmatrix_length))
		scale += cmatrix[j + cmatrix_length/2 - row];
	    }
	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      for (j = 0; j < y; j++)
		{
		  if ((j >= row - cmatrix_length/2) &&
		      (j <= row + cmatrix_length/2))
		    sum += cur_col[j*bytes + i] * cmatrix[j];
		}
	      dest_col[row*bytes + i] = (unsigned char) ROUND(sum / scale);
	    }
	}
    }
  else
    {  /* when the cmatrix is smaller than row length */
      /* for the edge condition, we only use available info, and scale to one */
      for (row = 0; row < cmatrix_length/2; row++)
	{
	  /* find scale factor */
	  scale=0;
	  for (j = cmatrix_length/2 - row; j<cmatrix_length; j++)
	    scale += cmatrix[j];

	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      for (j = cmatrix_length/2 - row; j<cmatrix_length; j++)
		{
		  sum += cur_col[(row + j-cmatrix_length/2)*bytes + i] *
		    cmatrix[j];
		}
	      dest_col[row*bytes + i] = (unsigned char) ROUND(sum / scale);
	    }
	}
      /* go through each pixel in each col */
      for (; row < y-cmatrix_length/2; row++)
	{
	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      for (j = 0; j<cmatrix_length; j++)
		{
		  sum += cur_col[(row + j-cmatrix_length/2)*bytes + i] * cmatrix[j];
		}
	      dest_col[row*bytes + i] = (unsigned char) ROUND(sum);
	    }
	}
      /* for the edge condition , we only use available info, and scale to one */
      for (; row < y; row++)
	{
	  /* find scale factor */
	  scale=0;
	  for (j = 0; j< y-row + cmatrix_length/2; j++)
	    scale += cmatrix[j];

	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      for (j = 0; j<y-row + cmatrix_length/2; j++)
		{
		  sum += cur_col[(row + j-cmatrix_length/2)*bytes + i] * cmatrix[j];
		}
	      dest_col[row*bytes + i] = (unsigned char) ROUND(sum / scale);
	    }
	}
    }
#endif

#ifndef READABLE_CODE
  /* --------------- optimized, unreadable code -------------------*/
  double scale;
  double sum;
  int i=0, j=0;
  int row;
  int cmatrix_middle = cmatrix_length/2;

  double *cmatrix_p;
  unsigned char  *cur_col_p;
  unsigned char  *cur_col_p1;
  unsigned char  *dest_col_p;
  double *ctable_p;

  /* this first block is the same as the non-optimized version --
   * it is only used for very small pictures, so speed isn't a
   * big concern.
   */
  if (cmatrix_length > y)
    {
      for (row = 0; row < y ; row++)
	{
	  scale=0;
	  /* find the scale factor */
	  for (j = 0; j < y ; j++)
	    {
	      /* if the index is in bounds, add it to the scale counter */
	      if ((j + cmatrix_length/2 - row >= 0) &&
		  (j + cmatrix_length/2 - row < cmatrix_length))
		scale += cmatrix[j + cmatrix_length/2 - row];
	    }
	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      for (j = 0; j < y; j++)
		{
		  if ((j >= row - cmatrix_length/2) &&
		      (j <= row + cmatrix_length/2))
		    sum += cur_col[j*bytes + i] * cmatrix[j];
		}
	      dest_col[row*bytes + i] = (unsigned char) ROUND(sum / scale);

	    }
	}
    }
  else
    {
      /* for the edge condition, we only use available info and scale to one */
      for (row = 0; row < cmatrix_middle; row++)
	{
	  /* find scale factor */
	  scale=0;
	  for (j = cmatrix_middle - row; j<cmatrix_length; j++)
	    scale += cmatrix[j];
	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      for (j = cmatrix_middle - row; j<cmatrix_length; j++)
		{
		  sum += cur_col[(row + j-cmatrix_middle)*bytes + i] * cmatrix[j];
		}
	      dest_col[row*bytes + i] = (unsigned char) ROUND(sum / scale);
	    }
	}
      /* go through each pixel in each col */
      dest_col_p = dest_col + row*bytes;
      for (; row < y-cmatrix_middle; row++)
	{
	  cur_col_p = (row - cmatrix_middle) * bytes + cur_col;
	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      cmatrix_p = cmatrix;
	      cur_col_p1 = cur_col_p;
	      ctable_p = ctable;
	      for (j = cmatrix_length; j>0; j--)
		{
		  sum += *(ctable_p + *cur_col_p1);
		  cur_col_p1 += bytes;
		  ctable_p += 256;
		}
	      cur_col_p++;
	      *(dest_col_p++) = (unsigned char) ROUND(sum);
	    }
	}
	
      /* for the edge condition , we only use available info, and scale to one */
      for (; row < y; row++)
	{
	  /* find scale factor */
	  scale=0;
	  for (j = 0; j< y-row + cmatrix_middle; j++)
	    scale += cmatrix[j];
	  for (i = 0; i<bytes; i++)
	    {
	      sum = 0;
	      for (j = 0; j<y-row + cmatrix_middle; j++)
		{
		  sum += cur_col[(row + j-cmatrix_middle)*bytes + i] * cmatrix[j];
		}
	      dest_col[row*bytes + i] = (unsigned char) ROUND(sum / scale);
	    }
	}
    }
#endif

}  /*  blur line  */

/* generates a 1-D convolution matrix to be used for each pass of 
 * a two-pass gaussian blur.  Returns the length of the matrix.
*/

/* static??  */
int
 gen_convolve_matrix(double radius,
 double **cmatrix_p)
{
  int matrix_length;
  int matrix_midpoint;
  double *cmatrix;
  int i,j;
  double std_dev;
  double sum, base_x;
	
  /* we want to generate a matrix that goes out a certain radius
   * from the center, so we have to go out ceil(rad-0.5) pixels,
   * inlcuding the center pixel.  Of course, that's only in one direction,
   * so we have to go the same amount in the other direction, but not count
   * the center pixel again.  So we double the previous result and subtract
   * one.
   * The radius parameter that is passed to this function is used as
   * the standard deviation, and the radius of effect is the
   * standard deviation * 2.  It's a little confusing.
   */

  radius = fabs(radius) + 1.0;
  std_dev = radius;
  radius = std_dev * 2;

  /* go out 'radius' in each direction */
  matrix_length = 2 * ceil(radius - 0.5) + 1;
  if ( matrix_length <= 0 ) 
     matrix_length = 1;
    
  matrix_midpoint = matrix_length/2 + 1;
  *cmatrix_p = new(double, matrix_length);
  cmatrix = *cmatrix_p;

  /*  Now we fill the matrix by doing a numeric integration approximation
   * from -2*std_dev to 2*std_dev, sampling 50 points per pixel.
   * We do the bottom half, mirror it to the top half, then compute the
   * center point.  Otherwise asymmetric quantization errors will occur.
   *  The formula to integrate is e^-(x^2/2s^2).
   */

  /* first we do the top (right) half of matrix */

  for (i = matrix_length/2 + 1; i < matrix_length; i++)
  {
      base_x = i - floor(matrix_length/2) - 0.5;
      sum = 0;
      for (j=1; j<=50; j++)
      {
	  if( (base_x + (0.02*j)) <= radius ) 
             sum += exp(-(base_x + (0.02*j))*(base_x + (0.02*j))/(2*
std_dev*std_dev));
 
      }
      cmatrix[i] = sum/50;
  }

  /* mirror the thing to the bottom half */
  for(i=0; i<=matrix_length/2; i++) 
  {
    cmatrix[i] = cmatrix[matrix_length - 1 - i];
  }
	
  /* find center val -- calculate an odd number of quanta to make it symmetric,
   * even if the center point is weighted slightly higher than others. */
  sum = 0;
  for(j=0; j<=50; j++)
  {
     sum += exp( -(0.5 + (0.02*j))*(0.5 + (0.02*j))/(2*std_dev*std_dev));
  
  }
 
  cmatrix[matrix_length/2] = sum/51;
	
  /* normalize the distribution by scaling the total sum to one */
  sum=0;
  for(i=0; i<matrix_length; i++) 
       sum += cmatrix[i];
  for(i=0; i<matrix_length; i++) 
       cmatrix[i] = cmatrix[i] / sum;

  return matrix_length;

} /* gen_convolve_matrix */


/* ----------------------- gen_lookup_table ----------------------- */
/* Generate a lookup table of every product of an 8-bit sample value (0-255)
   and each weight of the convolution matrix:

       table[i * 256 + v] == cmatrix[i] * v

   i.e. indexed first by matrix position, then by sample value.  The table
   is calloc'ed and must be released with free().  Returns NULL if
   cmatrix_length is negative or the allocation fails. */
double *
gen_lookup_table (double *cmatrix,
                  int     cmatrix_length)
{
  double *lookup_table;
  int i, j;

  if (cmatrix_length < 0)
    return NULL;

  lookup_table = calloc ((size_t) cmatrix_length * 256, sizeof *lookup_table);
  if (lookup_table == NULL)
    return NULL;

  for (i = 0; i < cmatrix_length; i++)
    {
      for (j = 0; j < 256; j++)
        lookup_table[(size_t) i * 256 + j] = cmatrix[i] * (double) j;
    }

  return lookup_table;
}

/********************************/

/* Gather pixel column c_col of a wdth x hgt image (channels bytes per pixel,
   rows stored one after another) into the contiguous buffer d_col, which
   must hold hgt * channels bytes.  d_region is only read. */
void get_col(unsigned char *d_region, unsigned char *d_col, int c_col, int wdth, int hgt, int channels)
{
  size_t row_stride = (size_t) wdth * (size_t) channels;   /* bytes per image row */
  size_t col_offset = (size_t) c_col * (size_t) channels;  /* column start within a row */
  int y, i;

  /* Address each row from its start: stepping by a full row stride after
     already advancing past one pixel would drift one column per row. */
  for (y = 0; y < hgt; y++)
  {
        const unsigned char *pixel = d_region + (size_t) y * row_stride + col_offset;

        for (i = 0; i < channels; i++)
        {
            d_col[(size_t) y * (size_t) channels + i] = pixel[i];
        }
  }
}

/* Scatter the contiguous buffer d_col (hgt pixels of channels bytes) back
   into pixel column c_col of a wdth x hgt image (rows stored one after
   another).  The inverse of get_col; d_col is only read. */
void set_col(unsigned char *d_region, unsigned char *d_col, int c_col, int wdth, int hgt, int channels)
{
  size_t row_stride = (size_t) wdth * (size_t) channels;   /* bytes per image row */
  size_t col_offset = (size_t) c_col * (size_t) channels;  /* column start within a row */
  int y, i;

  /* Address each row from its start: stepping by a full row stride after
     already advancing past one pixel would drift one column per row and
     eventually write past the end of the image. */
  for (y = 0; y < hgt; y++)
  {
        unsigned char *pixel = d_region + (size_t) y * row_stride + col_offset;

        for (i = 0; i < channels; i++)
        {
            pixel[i] = d_col[(size_t) y * (size_t) channels + i];
        }
  }
}
