"""Code for nice-looking plots, with no titles."""

from __future__ import division
from pylab import *
import matplotlib.pyplot as plt
import mpmath as mp
# Route all figure text through LaTeX and use a serif font family.
rcParams['text.usetex'], rcParams['font.family'] = True, 'serif'

##############################################################
#####  PLOTTING FUNCTIONS
##############################################################

def latexify(s):
    r"""Return a LaTeX math-mode label for the tick value ``s``.

    Values whose ``str()`` is in scientific notation (contains ``'e'``, e.g.
    ``1e-05``) are rendered as powers of ten.  A mantissa other than +/-1 is
    kept explicitly, so distinct ticks such as ``2e-05`` and ``3e-05`` get
    distinct labels and negative values are handled.  Any other value is
    wrapped verbatim in ``$...$``.

    Examples::

        1e-05   -> '$10^{-5}$'
        2.5e-05 -> '$2.5 \times 10^{-5}$'
        -1e-05  -> '$-10^{-5}$'
        0.5     -> '$0.5$'
    """
    text = str(s)
    if 'e' not in text:
        return r'${}$'.format(text)

    # Read the exponent from the string itself rather than rounding
    # log10(s): rounding collapses different mantissas onto one label and
    # gives NaN for negative values.
    mantissa, _, exponent = text.partition('e')
    power = r'10^{{{}}}'.format(int(exponent))
    coeff = float(mantissa)
    if coeff == 1:
        return r'${}$'.format(power)
    if coeff == -1:
        return r'$-{}$'.format(power)
    return r'${} \times {}$'.format(mantissa, power)

def prettiness(ax):
    """Apply the common tick styling to the 2-D axes ``ax`` and return it.

    Major tick labels on both axes are rendered through the module-level
    ``latexify`` at font size 16.  Major ticks get length 7 and minor ticks
    length 4, both with width 1.

    The labels come from a ``FuncFormatter`` that is evaluated at draw time.
    They therefore stay aligned with the ticks even if the axis limits change
    after this call, e.g. when further artists are added.  A label list
    frozen from ``get_xticks()`` would be assigned positionally to whatever
    ticks exist later.

    Returns ``ax`` so that ``ax = prettiness(ax)`` keeps a usable handle.
    """
    from matplotlib.ticker import FuncFormatter

    for axis in (ax.xaxis, ax.yaxis):
        # One formatter instance per axis: matplotlib binds a Formatter to
        # the Axis it is installed on.
        axis.set_major_formatter(
            FuncFormatter(lambda value, pos: latexify(value)))

    ax.tick_params('both', which='major', length=7, width=1, labelsize=16)
    ax.tick_params('both', which='minor', length=4, width=1)
    return ax

def plot_gens(data):
    """Scatter the orbit values ``x_n`` against the curve ``x(n)`` and save a PDF.

    Reads from ``data``:

    * ``'ns'``: x-axis values.
    * ``'Rlist'``: scattered, labelled ``$x_n$``, using colour ``'col1'`` and
      marker ``'mar1'``.
    * ``'Rpnlist'``: drawn as a red line, labelled ``$x(n)$``.
    * ``'zeta'``, ``'t'``, ``'n_end'``, ``'N'``, ``'dec'``: used only in the
      output file name ``zeta{zeta}_t{t}_n{n_end}_N{N}_dps{dec}_asymptotics.pdf``.

    Returns ``data`` unchanged.
    """
    dec = data['dec']
    t = data['t']
    zeta = data['zeta']
    N = data['N']
    nlist = data['ns']
    n_end = data['n_end']
    col1 = data['col1']
    mar1 = data['mar1']

    # Materialise as lists.  On Python 3 ``map`` returns a one-shot iterator,
    # which matplotlib cannot convert into a numeric array.
    Rpnlist = [float(v) for v in data['Rpnlist']]
    Rlist = [float(v) for v in data['Rlist']]

    figure()
    scatter(nlist, Rlist, color=col1, marker=mar1, label='$x_n$')
    plot(nlist, Rpnlist, 'r-', label='$x(n)$')

    prettiness(plt.gca())

    xlabel(r'$n$', fontsize=24)
#    s = r'$\mathrm{Plot~of}~x_n~\mathrm{approaching}~x(n)$'
#    title(s, fontsize=24)
    legend(loc=0, fontsize=20)
    tight_layout()
    plt.savefig('zeta{}_t{}_n{}_N{}_dps{}_asymptotics.pdf'.format(zeta, t, n_end, N, dec))
    return data

def plot_3d(data):
    """Plot the orbit as the points ``(n, x_n, x_{n-1})`` in 3-D and save a PDF.

    Reads from ``data``:

    * ``'ns'``: x coordinates.
    * ``'Rlist'``: y coordinates, labelled ``$x_n$``.
    * ``'Slist'``: z coordinates, labelled ``$x_{n-1}$``.
    * ``'zeta'``, ``'t'``, ``'n_end'``, ``'N'``, ``'dec'``: used only in the
      file name ``zeta{zeta}_t{t}_n{n_end}_N{N}_dps{dec}_3dim.pdf``.

    Returns ``data`` unchanged.
    """
    dec = data['dec']
    zeta = data['zeta']
    t = data['t']
    N = data['N']
    n_end = data['n_end']

    nlist = data['ns']
    # Lists, not ``map`` iterators: on Python 3 a ``map`` object is not
    # array-like, and Axes3D.plot cannot use it as coordinates.
    Rlist = [float(v) for v in data['Rlist']]
    Slist = [float(v) for v in data['Slist']]

    # NOTE(review): on matplotlib < 3.2 the '3d' projection only exists after
    # ``mpl_toolkits.mplot3d`` has been imported somewhere; confirm for the
    # targeted matplotlib version.
    diagram = plt.figure()
    ax = diagram.add_subplot(111, projection='3d')
    ax.set_xlabel('$n$', fontsize=24)
    ax.set_ylabel('$x_n$', fontsize=24, rotation='horizontal')
    ax.set_zlabel('$x_{n-1}$', fontsize=24, rotation='horizontal')
    ax.plot(nlist, Rlist, Slist, 'b.')
#    s = r'$\mathrm{Orbit~plot~of~}(n, x_n, x_{n-1})$'
#    ax.set_title(s, fontsize=24)

    def one_decimal(value):
        # Distinct name so the module-level ``latexify`` is not shadowed.
        return r'${:1.1f}$'.format(value)

    # x ticks are step indices, shown without decimals; y/z with one decimal.
    ax.set_xticklabels([r'${:1.0f}$'.format(v) for v in ax.get_xticks()], fontsize=16)
    ax.set_yticklabels([one_decimal(v) for v in ax.get_yticks()], fontsize=16)
    ax.set_zticklabels([one_decimal(v) for v in ax.get_zticks()], fontsize=16)
    ax.tick_params('both', length=7, width=1, which='major')

    plt.savefig('zeta{}_t{}_n{}_N{}_dps{}_3dim.pdf'.format(zeta, t, n_end, N, dec))
    return data

def plot_2dprojection(data):
    """Plot the orbit projected onto the ``(x_n, x_{n-1})`` plane and save a PDF.

    Reads from ``data``:

    * ``'Rlist'`` / ``'Slist'``: the orbit points, drawn with colour
      ``'col1'`` and marker ``'mar1'``.
    * ``('R0', 'S0')``: a single point (presumably the initial condition;
      confirm with the caller), drawn in red and annotated with its
      coordinates to two decimals.
    * ``'zeta'``, ``'t'``, ``'n_end'``, ``'N'``, ``'dec'``: used only in the
      file name ``zeta{zeta}_t{t}_n{n_end}_N{N}_dps{dec}_2dim.pdf``.

    Returns ``data`` unchanged.
    """
    col = data['col1']
    mar = data['mar1']
    dec = data['dec']
    zeta = data['zeta']
    t = data['t']
    N = data['N']
    n_end = data['n_end']

    # Lists, not ``map`` iterators, so matplotlib can build arrays from them.
    Slist = [float(v) for v in data['Slist']]
    Rlist = [float(v) for v in data['Rlist']]
    # Convert once, so every matplotlib call below receives plain floats
    # rather than the raw stored values.
    Rr = float(data['R0'])
    Ss = float(data['S0'])

    fig, ax = plt.subplots()
    ax.scatter(Rlist, Slist, c=col, marker=mar, edgecolor=col)
    # Draw the red point once, above the orbit markers (zorder 3), and before
    # the tick styling so the axis limits are final when labels are set.
    ax.scatter(Rr, Ss, c='red', edgecolor='red', zorder=3)

    prettiness(ax)
    xlabel('$x_n$', fontsize=24)
    ylabel('$x_{n-1}$', fontsize=24, rotation='horizontal')

    ax.text(Rr + .1, Ss, r'$({:1.2f},{:1.2f})$'.format(Rr, Ss), fontsize=16)

#    s = r'$\mathrm{Orbit~plot~of~}(x_n, x_{n-1})$'
#    title(s, fontsize=24)
    tight_layout()
    plt.savefig('zeta{}_t{}_n{}_N{}_dps{}_2dim.pdf'.format(zeta, t, n_end, N, dec))
    return data

##############################################################
#####  VERIFYING THE NUMERICS
##############################################################

def gen_fnofxy(zeta, t, n, N, x, y):
    """Evaluate the Hamiltonian at time ``n`` and point ``(x, y)``.

    f_n(x, y) = x*y*(zeta + 4*t*x + 4*t*y) - (n/N)*x - (n/N)*y

    Terms are evaluated in the same order as written above, so
    floating-point results are reproducible.
    """
    ratio = n / N
    cubic = x * y * (zeta + 4 * t * x + 4 * t * y)
    return cubic - ratio * x - ratio * y

def gen_energy_diff(data):
    """Creates Diagnostic plot from integrability #2.

    For each ``n`` in ``data['ns'][1:-2]`` this plots, on a log y scale, the
    absolute difference

        |gen_fnofxy(..., n+1, ..., x_{n+1}, x_n) - gen_fnofxy(..., n+1, ..., x_n, x_{n-1})|

    where ``x_k = data['Rlist'][k]``.  It also draws a red horizontal
    reference line at ``10**(-dec)`` across the full width of the axes.  The
    figure is saved as ``zeta{zeta}_t{t}_n{n_end}_N{N}_dps{dec}_abserror.pdf``.

    Assumes the values in ``data['ns']`` are valid indices into
    ``data['Rlist']``.  Returns ``data`` unchanged.
    """
    dec = data['dec']
    zeta = data['zeta']
    t = data['t']
    N = data['N']
    ns = data['ns']
    Rlist = data['Rlist']
    n_end = data['n_end']

    # NOTE(review): both evaluations use time index n+1, whereas the
    # commented-out axis label below reads f_n(...) - f_n(...); confirm which
    # index is intended.
    steps = ns[1:-2]
    diffn = [
        abs(gen_fnofxy(zeta, t, n + 1, N, Rlist[n + 1], Rlist[n])
            - gen_fnofxy(zeta, t, n + 1, N, Rlist[n], Rlist[n - 1]))
        for n in steps
    ]

    figure()
    plot(steps, diffn, 'b-')
    yscale('log')

    ax = plt.gca()
    # axhline's xmin/xmax are axes fractions in [0, 1], not data coordinates,
    # so (0, 1) spans the full width.  The line is added before the tick
    # styling because it may extend the y-limits.
    ax.axhline(y=10**(-dec), xmin=0, xmax=1, color='r')
    prettiness(ax)

    xlabel('$n$', fontsize=24)
    ylabel(r'$\mathrm{abs~error}$', fontsize=24)

#    ylabel('$|f_n(x_{n+1}, x_n) - f_n(x_n, x_{n-1})|$',
#        fontsize=24)
#    s = 'Rounding error diagnostic: \n absolute error between~energy \n at adjacent orbit points'

#    s = r'$\mathrm{Rounding~error~diagnostic:}$' + '\n' + r'$\mathrm{absolute~error~between~energy}$' + '\n' + r'$\mathrm{at~adjacent~orbit~points}$'
#    title(s, fontsize=24)
    tight_layout()
    plt.savefig('zeta{}_t{}_n{}_N{}_dps{}_abserror.pdf'.format(zeta, t, n_end, N, dec))
    return data
            
def gen_energy_eval(data):
    """Store the Hamiltonian evaluated along the orbit in ``data['energies']``.

    Entry ``k`` is ``gen_fnofxy(zeta, t, n, N, Rlist[n], Slist[n])`` for the
    ``k``-th ``n`` of ``data['ns'][:-1]``.  ``N`` is coerced to ``int``
    first.  Mutates ``data`` in place and returns it.
    """
    t, zeta = data['t'], data['zeta']
    N = int(data['N'])
    xs, ys = data['Rlist'], data['Slist']
    steps = data['ns']

    data['energies'] = [
        gen_fnofxy(zeta, t, n, N, xs[n], ys[n]) for n in steps[:-1]
    ]
    return data

def gen_energy_error(data):
    """Creates Diagnostic plot from integrability #1.

    For each ``n`` in ``data['ns'][1:-2]`` this plots two curves on a log
    y scale:

    * blue: ``|e[n] - e[n-1]|``, with ``e = data['energies']`` (e.g. as
      stored by ``gen_energy_eval``);
    * red: ``|x_n + x_{n-1}| / N``, with ``x_k = data['Rlist'][k]``.

    Both curves are drawn against the same ``n`` values they were computed
    for.  The figure is saved as
    ``zeta{zeta}_t{t}_n{n_end}_N{N}_dps{dec}_energyerror.pdf``.

    Returns ``data`` unchanged.
    """
    zeta = data['zeta']
    t = data['t']
    N = data['N']
    dec = data['dec']
    e = data['energies']
    ns = data['ns']
    Rlist = data['Rlist']
    n_end = data['n_end']

    # Use the actual n values as the x-axis (labelled $n$ below).  Plotting
    # against range(len(...)) would shift every point left by ns[1].
    steps = ns[1:-2]
    e_diff = [abs(e[n] - e[n-1]) for n in steps]
    R_diff = [abs(Rlist[n] + Rlist[n-1]) / N for n in steps]

    figure()
    semilogy(steps, e_diff, 'b-',
             label=r'$\displaystyle |f_n(x_{n+1}, x_n) - f_{n-1}(x_n, x_{n-1})|$')
    semilogy(steps, R_diff, 'r-',
             label=r'$\displaystyle\frac{|x_n + x_{n-1}|}{N}$')

    prettiness(plt.gca())

    legend(loc=7)
    xlabel('$n$', fontsize=24)

#    ylabel('$\mathrm{energy~differences}$',
#        fontsize=24)
#    ex1 = 'Plot comparing the energy and its theoretical value,'
#    ex2 = 'based on the calculation that $f_n(x_{n+1}, x_n) - f_{n-1}(x_n, x_{n-1}) = -(x_n + x_{n-1})$.'
#    title(ex1 + '\n' + ex2 + '\n' + 'dps = ' + str(dec) + ', $\zeta = $' + str(zeta) + ', $t = $' + str(t) + ', $N = $' + str(N), 
#        fontsize=24)
#    s = 'Energy accumulation diagnostic: \n time-dependent energy differences \n at adjacent orbit points'

#    s = r'$\mathrm{Energy~accumulation~diagnostic:}$' + '\n' + r'$\mathrm{time~dependent~energy~differences}$' + '\n' + r'$\mathrm{at~adjacent~orbit~points}$'
#    title(s, fontsize=24)
    tight_layout()
    plt.savefig('zeta{}_t{}_n{}_N{}_dps{}_energyerror.pdf'.format(zeta, t, n_end, N, dec))
    return data
