"""
Name: func.py
Author: Yuxiang LI (li.yuxiang.nj@gmail.com)
Date: 25/03/2016

Description: utility functions
"""

import numpy as np
from PIL import Image
import scipy.ndimage.filters as fi


def read_image(path):
    """Read an image file and return it as a 2D grayscale numpy array.

    Whatever the mode of the stored image (colour or otherwise), it is
    converted to Pillow's 8-bit grayscale mode 'L', so the returned values
    are integers from 0 to 255.

    :param path: path of the image file to read
    :return: 2D numpy array of grayscale pixel values
    """
    # Use the image as a context manager so the file handle Pillow opened is
    # closed deterministically instead of lingering until garbage collection.
    # convert() loads the pixel data into a new, independent image, so the
    # result stays valid after the source image is closed.
    with Image.open(path) as img:
        gray = img.convert('L')
    return np.array(gray)


def save_image(array, path):
    """Save a numpy array of pixel values as an RGB image file.

    Values are clipped to [0, 255] before the image is built. The clipping
    is done on a copy, so the caller's array is left untouched.

    :param array: numpy array (or array-like) of pixel values
    :param path: destination path of the image file
    :return: 0 (kept for backward compatibility)
    """
    # np.clip returns a new array with the same dtype. Assigning through
    # boolean masks would have clipped the caller's data in place.
    clipped = np.clip(array, 0, 255)
    image = Image.fromarray(clipped)
    image.convert('RGB').save(path)
    return 0


def psnr(noise, clean, dynamic=255.0):
    """
    Calculate the PSNR value between noise and clean image.

    PSNR = 10 * log10(dynamic ** 2 / MSE), where MSE is the mean squared
    difference between the two images computed in floating point (so
    unsigned integer inputs cannot wrap around).

    :param noise: numpy array (no < 0 or > dynamic values are allowed)
    :param clean: numpy array without noise
    :param dynamic: the scale (max value) of the image
    :return: PSNR value in dB; ``inf`` when the two images are identical
    """
    noise = np.asarray(noise, dtype=float)
    clean = np.asarray(clean, dtype=float)
    mse = np.square(noise - clean).mean()
    if mse == 0:
        # Identical images have infinite PSNR. Return it explicitly rather
        # than dividing by zero, which emits a RuntimeWarning.
        return float('inf')
    peak = dynamic * dynamic
    return 10 * np.log10(peak / mse)


def gaussian_kernel(size, sigma=1):
    """
    Create a 2-dimensional Gaussian kernel.

    The kernel is obtained by Gaussian-smoothing a unit impulse placed at
    the middle of a ``size`` x ``size`` grid of zeros.

    :param size: side length of the square kernel (odd sizes put the peak
        exactly on the middle element)
    :param sigma: standard deviation of the Gaussian
    :return: Gaussian kernel array of shape (size, size)
    """
    inp = np.zeros((size, size))
    centre = size // 2  # floor division keeps the index an int on Python 3
    inp[centre, centre] = 1  # set element at the middle to one, a dirac delta
    # gaussian-smooth the dirac, resulting in a gaussian filter mask
    return fi.gaussian_filter(inp, sigma)


def int_sqrt(n):
    """
    Get the largest integer whose square does not exceed n using Newton method.

    :param n: non-negative integer
    :return: integer square root, i.e. floor(sqrt(n))
    :raises ValueError: if n is negative
    """
    if n < 0:
        raise ValueError("int_sqrt() argument must be non-negative")
    # Floor division keeps every iterate an exact integer. True division
    # would produce floats on Python 3 and lose precision for large n.
    x = n
    y = (x + 1) // 2
    # The iterates strictly decrease until they reach floor(sqrt(n)).
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x


def denoise(image, model, step=3, sigma=2):
    """
    Framing window technique for image denoising.

    The image is scanned with a ``patch_in`` x ``patch_in`` window moved by
    ``step`` pixels; the last window in each direction is clamped to touch
    the image border. Each window is normalised and passed to
    ``model['predict']``. The predicted ``patch_out`` x ``patch_out`` patch,
    centred on the same pixel, is de-normalised and accumulated with a
    Gaussian weight. Overlapping contributions are averaged by the sum of
    their weights.

    :param image: noisy image, a 2D numpy array with values in [0, 255]
    :param model: mapping with keys ``'patch_in'`` (odd input patch size),
        ``'patch_out'`` (odd output patch size) and ``'predict'`` (callable
        taking a (1, patch_in * patch_in) array of normalised pixels and
        returning patch_out * patch_out normalised values)
    :param step: stride between consecutive patches
    :param sigma: standard deviation of the Gaussian weighting window
    :return: denoised image as a float array clipped to [0, 255]; pixels
        not covered by any output patch (e.g. a border of width
        (patch_in - patch_out) / 2) keep their input value
    :raises ValueError: if ``patch_in`` or ``patch_out`` is even
    """
    patch_in = model['patch_in']
    patch_out = model['patch_out']
    if patch_in % 2 == 0 or patch_out % 2 == 0:
        raise ValueError("patch sizes must be odd, got patch_in=%r, patch_out=%r"
                         % (patch_in, patch_out))

    row, col = image.shape
    # Floor division keeps radii (and hence range arguments / slice bounds)
    # integers on Python 3.
    patch_in_radius = patch_in // 2
    patch_out_radius = patch_out // 2

    kernel = gaussian_kernel(size=patch_out, sigma=sigma)
    output = np.zeros(image.shape)
    kernel_sum = np.zeros(image.shape)

    # Denoise patches one by one. The range end is extended by step - 1 and
    # the centre is clamped, so the final window in each direction always
    # reaches the image border.
    for i in range(patch_in_radius, row - patch_in_radius + step - 1, step):
        x = min(i, row - patch_in_radius - 1)
        in_rows = slice(x - patch_in_radius, x + patch_in_radius + 1)
        out_rows = slice(x - patch_out_radius, x + patch_out_radius + 1)
        for j in range(patch_in_radius, col - patch_in_radius + step - 1, step):
            y = min(j, col - patch_in_radius - 1)
            in_cols = slice(y - patch_in_radius, y + patch_in_radius + 1)
            out_cols = slice(y - patch_out_radius, y + patch_out_radius + 1)

            # Normalise pixels: [0, 255] maps linearly to [-2.5, 2.5].
            box = (image[in_rows, in_cols].reshape(1, patch_in * patch_in) / 255.0 - 0.5) / 0.2

            # Denoise a small patch (to be customized in different methods).
            treated = model['predict'](box)

            # Undo the normalisation and accumulate with Gaussian weights.
            restored = (treated.reshape((patch_out, patch_out)) * 0.2 + 0.5) * 255.0
            output[out_rows, out_cols] += np.multiply(kernel, restored)
            kernel_sum[out_rows, out_cols] += kernel

    # Smooth over the overlapped regions. Pixels that no output patch
    # reached have zero total weight; keep their input value instead of
    # producing 0/0 = NaN.
    covered = kernel_sum > 0
    result = np.array(image, dtype=float)
    result[covered] = output[covered] / kernel_sum[covered]

    return np.clip(result, 0, 255)  # remove outliers


