"""Contains all the functions for calculations in the simulation."""

from __future__ import division
from pylab import *
import matplotlib.pyplot as plt
import mpmath as mp


##############################################################
#####  FIXED POINTS FUNCTIONS
##############################################################

def gen_fp_timedep(data):
    """Calculates the time-dependent version of the fixed points.

    For every n in 0 .. n_end-1 the two roots

        R = (-zeta +/- sqrt(zeta**2 + 48*t*n/N)) / (24*t)

    (i.e. the solutions of 12*t*R**2 + zeta*R - n/N = 0) are evaluated
    with mpmath and stored, indexed by n, as lists under
    data['Rpnlist'] ('+' root) and data['Rmnlist'] ('-' root).

    Reads 'zeta', 't', 'n_end' and 'N' from data and returns the same
    dict with the two lists added.
    """
    zeta = data['zeta']
    t = data['t']
    n_end = data['n_end']
    N = data['N']

    # Build genuine lists.  Pre-filling with range(n_end) and assigning by
    # index only works on Python 2; on Python 3 a range object is
    # immutable and item assignment raises TypeError.
    Rpnlist = []
    Rmnlist = []
    for n in range(n_end):
        # A variant that additionally divides by (n/N) was tried by the
        # original author and left disabled; it is not used here.
        Rp = (-zeta + mp.sqrt(zeta**2+48*t*n/N))/(24*t)
        Rm = (-zeta - mp.sqrt(zeta**2+48*t*(n/N)))/(24*t)
        Rpnlist.append(Rp)
        Rmnlist.append(Rm)

    data['Rpnlist'] = Rpnlist
    data['Rmnlist'] = Rmnlist
    return data

##############################################################
#####  INITIAL CONDITIONS FUNCTIONS
##############################################################

def gen_ic_standard(data):
    """Sets initial conditions for the combinatorial problem.

    Calls x1() so that data['x1'] is populated, then records the starting
    pair of the orbit: data['S0'] = 0 and data['R0'] = data['x1'].
    Returns the updated data dict.
    """
    data = x1(data)
    first_value = data['x1']
    data['S0'] = 0
    data['R0'] = first_value
    return data

def x1(data):
    """Calculates 2nd initial condition for combinatorial problem.

    The value is the ratio of the second moment to the zeroth moment, both
    obtained from moment() with N = data['N'].  It is stored under
    data['x1'] and the data dict is returned.
    """
    size = data['N']
    second_moment = moment(2, size, data)
    zeroth_moment = moment(0, size, data)
    data['x1'] = second_moment/zeroth_moment
    return data

def moment(j, N, data):
    """Returns the j-th moment of the exponential measure.

    Integrates x**j * exp(-N*(zeta/2*x**2 + t*x**4)) over the whole real
    line with mpmath's quad, where t and zeta are read from data['t'] and
    data['zeta'].
    """
    t = data['t']
    zeta = data['zeta']

    def weighted_power(x):
        # x^j times the exponential measure
        return mp.exp(-N*(zeta/2*x**2 + t*x**4))*x**j

    return mp.quad(weighted_power, [-mp.inf, mp.inf])

def ic_declare(data):
    """Keeps the initial conditions already present in data.

    Nothing is computed here: the dict is handed back untouched, so
    whatever initial conditions the caller put into it stay in effect.
    """
    return data

##############################################################
#####  DYNAMICS FUNCTIONS
##############################################################

def gen_standard_Rs(data):
    """Generates list of x_n's recursively from initial conditions.

    Starting from x_0 = 0 and x_1 = data['R0'], each further term is

        x_n = (n-1)/(4*N*t*x_{n-1}) - zeta/(4*t) - x_{n-1} - x_{n-2}

    for n = 2 .. n_end.  The orbit is stored under data['Rlist'] and the
    matching indices under data['ns']; the data dict is returned.
    """
    t = data['t']
    zeta = data['zeta']
    N = data['N']
    n_end = data['n_end']

    orbit = [0, data['R0']]
    later_indices = list(range(2, n_end+1))

    # Loop-invariant pieces of the recurrence.
    four_N_t = 4*N*t
    shift = zeta/(4*t)

    for n in later_indices:
        prev = orbit[-1]
        prev_prev = orbit[-2]
        # The N should probably be N-1 to really match up,
        # but then the case N=1 will throw an error
        # instead of being our case to check accuracy against.
        orbit.append((n-1) / (four_N_t*prev) - shift - prev - prev_prev)

    data['Rlist'] = orbit
    data['ns'] = [0, 1] + later_indices
    return data

def gen_standard(data):
    """Sets lists of x- and y-coordinates of orbit points.

    Runs gen_standard_Rs() and then splits the resulting orbit into
    consecutive pairs: data['Rlist'] drops the first entry, data['Slist']
    drops the last one, and data['ns'] drops its first index.
    """
    data = gen_standard_Rs(data)
    full_orbit = data['Rlist']
    indices = data['ns']
    data['Slist'] = full_orbit[:-1]
    data['Rlist'] = full_orbit[1:]
    data['ns'] = indices[1:]
    return data

##############################################################
#####  VERIFYING THE NUMERICS
##############################################################

def gen_fnofxy(zeta, t, n, N, x, y):
    """Evaluates the Hamiltonian f_n at the point (x, y).

    Returns x*y*(zeta + 4*t*x + 4*t*y) - (n/N)*x - (n/N)*y.
    """
    ratio = n/N
    product_term = x*y*(zeta + 4*t*x + 4*t*y)
    return product_term - ratio*x - ratio*y

def gen_energy_diff(data):
    """Creates Diagnostic plot from integrability #2.

    For every n in data['ns'][1:-2] the Hamiltonian gen_fnofxy is
    evaluated with time index n+1 at the two adjacent orbit pairs
    (Rlist[n+1], Rlist[n]) and (Rlist[n], Rlist[n-1]); the absolute
    difference is plotted against n on a logarithmic y-axis in a new
    figure.

    Reads 'dec', 'zeta', 't', 'N', 'ns' and 'Rlist' from data ('dec' is
    only shown in the title as "dps").  Returns data unchanged.
    """
    dec = data['dec']
    zeta = data['zeta']
    t = data['t']
    N = data['N']
    ns = data['ns']
    Rlist = data['Rlist']

    # The first index and the last two are left out so that Rlist[n-1]
    # and Rlist[n+1] are never taken at the very ends of the index list.
    interior = ns[1:-2]

    diffn = []
    for n in interior:
        cn = gen_fnofxy(zeta, t, n+1, N, Rlist[n+1], Rlist[n])
        dn = gen_fnofxy(zeta, t, n+1, N, Rlist[n], Rlist[n-1])
        diffn.append(abs(cn - dn))

    figure()
    plot(interior, diffn, 'b-')
    yscale('log')
    xlabel('$n$', fontsize=16)
    ylabel('$|f_n(x_{n+1}, x_n) - f_n(x_n, x_{n-1})|$',
        fontsize=16)
    ex1 = 'Plot of absolute error between energy evaluated at'
    ex2 = 'adjacent recursively-generated orbit points.'
    # Raw string for the LaTeX fragment: '\z' is not a valid escape
    # sequence in an ordinary string literal and triggers a warning on
    # current Python versions.  The resulting text is unchanged.
    title(ex1 + '\n' + ex2 + '\n' + 'dps = ' + str(dec) + r', $\zeta = $' + str(zeta) + ', $t = $' + str(t) + ', $N = $' + str(N),
        fontsize='large')
    draw()
    tight_layout()
    return data
            
def gen_energy_eval(data):
    """Creates list of outputs from Hamiltonian.

    For every index n in data['ns'] except the last, gen_fnofxy is
    evaluated at x = data['Rlist'][n], y = data['Slist'][n] with N cast
    to int.  The values are stored under data['energies'] and data is
    returned.
    """
    t = data['t']
    zeta = data['zeta']
    N = int(data['N'])
    Rlist = data['Rlist']
    Slist = data['Slist']
    nlist = data['ns']

    data['energies'] = [
        gen_fnofxy(zeta, t, n, N, Rlist[n], Slist[n])
        for n in nlist[:-1]
    ]
    return data

def gen_energy_error(data):
    """Creates Diagnostic plot from integrability #1.

    Draws two curves on a logarithmic y-axis in a new figure, both taken
    over n in data['ns'][1:-2]:

    * blue: |energies[n] - energies[n-1]|, the change in the stored
      Hamiltonian values between consecutive indices;
    * red:  |Rlist[n] + Rlist[n-1]| / N, the quantity the plot title
      gives as the theoretical value of that change.

    Reads 'zeta', 't', 'N', 'dec', 'energies', 'ns' and 'Rlist' from data
    ('dec' is only shown in the title as "dps").  Returns data unchanged.
    """
    zeta = data['zeta']
    t = data['t']
    N = data['N']
    dec = data['dec']
    e = data['energies']
    ns = data['ns']
    Rlist = data['Rlist']

    interior = ns[1:-2]

    e_diff = [abs(e[n] - e[n-1]) for n in interior]
    figure()
    semilogy(range(len(e_diff)), e_diff, 'b-',
    label='$|f_n(x_{n+1}, x_n) - f_{n-1}(x_n, x_{n-1})|$')

    R_diff = [(abs(Rlist[n] + Rlist[n-1]))/N for n in interior]
    semilogy(range(len(R_diff)), R_diff, 'r-',
    label='$|x_n + x_{n-1}|/N$')

    legend(loc=0)
    xlabel('$n$', fontsize=16)
    ylabel('energy differences, see legend',
        fontsize=16)
    ex1 = 'Plot comparing the energy and its theoretical value,'
    ex2 = 'based on the calculation that $f_n(x_{n+1}, x_n) - f_{n-1}(x_n, x_{n-1}) = -(x_n + x_{n-1})$.'
    # Raw string for the LaTeX fragment: '\z' is not a valid escape
    # sequence in an ordinary string literal and triggers a warning on
    # current Python versions.  The resulting text is unchanged.
    title(ex1 + '\n' + ex2 + '\n' + 'dps = ' + str(dec) + r', $\zeta = $' + str(zeta) + ', $t = $' + str(t) + ', $N = $' + str(N),
        fontsize='large')
    draw()
    tight_layout()
    return data
