#!/usr/bin/env python
import sys

from scipy import *
from scipy import interpolate
from scipy import integrate
from scipy import optimize
from pylab import *
from pylab import *
# Explicit alias, bound after the star imports so they cannot shadow it.
import numpy as np

def n_bose(x,beta):
    """Bose-Einstein occupation 1/(exp(x*beta)-1), clipped for large |x*beta|.

    For x*beta > 100 the asymptotic value 0.0 is returned and for
    x*beta < -100 the value -1.0, so exp() is never evaluated far outside
    its useful range.
    """
    arg = x*beta
    if arg > 100:
        return 0.0
    if arg < -100:
        return -1.0
    return 1./(exp(arg)-1.)
def ferm(x,beta):
    """Fermi-Dirac occupation 1/(exp(x*beta)+1) for a scalar x.

    Outside |x*beta| <= 100 the exact limits are returned directly
    (0.0 above, 1.0 below), which avoids overflow in exp().
    """
    arg = x*beta
    if arg > 100:
        return 0.0
    if arg < -100:
        return 1.0
    return 1./(exp(arg)+1.)
def ferma(x,beta):
    """Fermi function 1/(exp(x*beta)+1) evaluated over an array x.

    Entries where |x*beta| > 100 are overwritten with their exact limits:
    0.0 when x > 0, 1.0 otherwise. x must support element-wise arithmetic
    with beta and positional iteration (a numpy array); a new array is
    returned.
    """
    occ = 1./(exp(x*beta)+1.)
    for i, xi in enumerate(x):
        # written as "not >" so NaN entries are left untouched, as before
        if not abs(xi*beta) > 100:
            continue
        occ[i] = 0.0 if xi > 0 else 1.0
    return occ

def Sigma_x(EF, k, lam, eps):
    """Exchange self-energy of an electron gas with a Yukawa-screened interaction.

    Evaluates
        Sigma_x(k) = -2 kF/(pi eps) * [1 - P(k) - L(k)],
        P(k) = lam/kF * (atan((k+kF)/lam) - atan((k-kF)/lam))   (0 if lam <= 0),
        L(k) = (lam^2+kF^2-k^2)/(4 k kF)
               * log((lam^2+(k-kF)^2)/(lam^2+(k+kF)^2)),
    with kF = sqrt(EF).

    Parameters
    ----------
    EF : float
        Fermi energy; kF = sqrt(EF).
    k : float or array
        Momentum (evaluated element-wise for arrays).
    lam : float
        Screening parameter of the interaction; lam <= 0 drops the arctan term.
    eps : float
        Dielectric constant dividing the interaction.

    Returns
    -------
    numpy scalar for scalar k, otherwise an array shaped like k.

    L(k) is 0/0 at k == 0 and, for lam == 0, at k == kF. There the analytic
    limits (-1 and 0 respectively) are used instead of returning NaN.
    """
    kF = sqrt(EF)
    kk = np.asarray(k, dtype=float)
    pp = 0
    if lam>0:
        pp = lam/kF*(arctan((kk+kF)/lam)-arctan((kk-kF)/lam))
    # Removable singularities are patched below, so silence their warnings here.
    with np.errstate(divide='ignore', invalid='ignore'):
        lterm = (lam**2+kF**2-kk**2)/(4*kk*kF)*log((lam**2+(kk-kF)**2)/(lam**2+(kk+kF)**2))
    # k -> 0: log(...) ~ -4 k kF/(lam^2+kF^2), so L(k) -> -1.
    lterm = np.where(kk == 0, -1.0, lterm)
    if lam == 0:
        # k -> kF without screening: (kF^2-k^2)*log|k-kF| -> 0.
        lterm = np.where(kk == kF, 0.0, lterm)
    qq = 1 - pp - lterm
    result = -2*kF/(pi*eps)*qq
    # keep returning a numpy scalar (not a 0-d array) for scalar k
    return result if np.ndim(result) else np.float64(result)

def EpsEx(rs,lam, eps):
    """Exchange energy per electron for the (optionally screened) electron gas.

    With kF = (9 pi/4)^(1/3)/rs the unscreened result is
    -(3/(2 pi eps)) * kF. This equals -0.916/rs for eps = 1, presumably in
    Rydberg units. For lam > 0 it is multiplied by the screening factor
    f(x), with x = kF/lam. For lam <= 0 the factor is 1.
    """
    prefactor = -(3./(2.*pi*eps))*(9*pi/4.)**(1./3.)/rs
    if not lam > 0:
        return prefactor*1.
    x = (9*pi/4)**(1./3.) *1./(lam*rs)
    x2 = x**2
    screening = 1 - 1./(6*x2) - 4*arctan(2*x)/(3*x) + (1.+1./(12*x2))*log(1+4*x2)/(2*x2)
    return prefactor*screening

def ReadMesh(fmesh):
    """Read the Matsubara/momentum mesh and run parameters from a mesh file.

    The first line is a header whose first character is dropped (presumably
    '#', so that loadtxt treats it as a comment). It must contain
        N_w N_k nom_low kF beta lam eps PhiC EpotC0 EC_LDA
    The numeric data that follow are read with loadtxt. The first N_w values
    are the Matsubara indices nw; the remainder is the momentum mesh km.

    Returns
    -------
    (nw, km, nom_low, kF, beta, lam, eps, PhiC, EpotC0, EC_LDA)
    """
    # Context manager guarantees the handle is closed even if parsing fails.
    with open(fmesh,'r') as fi:
        header = fi.readline()[1:].split()

    # N_k is parsed (and validated as an int) but not otherwise used.
    N_w, _N_k, nom_low = [int(v) for v in header[:3]]
    kF, beta, lam, eps = [float(v) for v in header[3:7]]
    PhiC, EpotC0, EC_LDA = [float(v) for v in header[7:10]]

    dat = loadtxt(fmesh).transpose()
    nw = dat[:N_w]
    km = dat[N_w:]
    return (nw, km, nom_low, kF, beta, lam, eps, PhiC, EpotC0, EC_LDA)

def Occup(mu, k, SigC, nom_low, kF, lam, eps, nw=None, beta=None):
    """Momentum occupation n(k) of the interacting system at chemical potential mu.

    The exchange self-energy Sigma_x(k) is added to the correlation part SigC,
    which is tabulated on the Matsubara indices nw. The sum is interpolated
    onto a dense, unit-spaced index mesh from nw[0] up to (excluding) nw[-1].
    Separate interpolating splines are used below and above nom_low.
    The occupation is then
        n(k) = (2/beta) * sum_n Re[G(iw_n) - G_HF(iw_n)] + f(k^2 + Sigma_x - mu),
    where G_HF contains only the static exchange and f is the Fermi function.
    Subtracting G_HF makes the Matsubara sum converge, and f adds its
    contribution back analytically.
    NOTE(review): summing only the tabulated (presumably non-negative)
    frequencies with a factor 2 and Re assumes G(-iw) = conj(G(iw)); confirm nw[0] == 0.

    Parameters
    ----------
    mu : float
        Chemical potential.
    k : float
        Momentum (scalar).
    SigC : complex array, length len(nw)
        Correlation self-energy at this k on the nw mesh.
    nom_low : int
        Number of leading nw points treated as the low-frequency segment.
    kF, lam, eps : float
        Fermi momentum, screening parameter and dielectric constant (for Sigma_x).
    nw : array, optional
        Matsubara indices. Defaults to the module-level global ``nw`` set by
        the __main__ driver, for backward compatibility.
    beta : float, optional
        Inverse temperature. Defaults to the module-level global ``beta``.

    Returns
    -------
    float : n(k)
    """
    # Historical callers pass only seven arguments and rely on the globals
    # that the script driver assigns; keep that working.
    if nw is None:
        nw = globals()['nw']
    if beta is None:
        beta = globals()['beta']

    # exchange
    Sx = Sigma_x(kF**2, k, lam, eps)
    # exchange-correlation
    Sxc = Sx + SigC

    # Dense Matsubara mesh with many more points than were computed.
    # Split at the midpoint between the last low- and the first high-frequency
    # tabulated index.
    midpoint = int(0.5*(nw[nom_low-1]+nw[nom_low]))
    mwl = arange(nw[0],       midpoint)
    mwh = arange(midpoint, nw[-1])
    # Interpolating splines (s=0) on each segment, separately for Re and Im
    Srl = interpolate.UnivariateSpline(nw[:nom_low], real(Sxc[:nom_low]), s=0)
    Sil = interpolate.UnivariateSpline(nw[:nom_low], imag(Sxc[:nom_low]), s=0)
    Sxcl = Srl(mwl)+Sil(mwl)*1j

    Srh = interpolate.UnivariateSpline(nw[nom_low:], real(Sxc[nom_low:]), s=0)
    Sih = interpolate.UnivariateSpline(nw[nom_low:], imag(Sxc[nom_low:]), s=0)
    Sxch = Srh(mwh)+Sih(mwh)*1j

    _mw_ = hstack( (mwl, mwh) )
    _Sxc_ = hstack( (Sxcl,Sxch) )

    # fermionic Matsubara frequencies w_n = (2n+1) pi / beta
    wnc = (2*_mw_+1)*pi/beta
    # interacting Green's function
    Gc  = 1/(wnc*1j + mu - k**2 - _Sxc_)
    # Corresponding HF (exchange-only) Green's function
    Gc0 = 1/(wnc*1j + mu - k**2 - Sx)
    # Sum over all Matsubara points + analytic HF occupation
    nocc = sum((Gc-Gc0).real)*2/beta + ferm(k**2+Sx-mu,beta)
    return nocc

def ComputeChemicalPotential(km, SigC_all, nom_low, kF, lam, eps, max_steps=500):
    """Chemical potential for which the interacting density equals the free one.

    Solves N/N0 - 1 = 0, where
        N/N0 = 3/kF^3 * integral n(k) k^2 dk
    is evaluated with Simpson's rule on km, and n(k) comes from Occup. The
    search starts at the Hartree-Fock level kF^2 + Sigma_x(kF). It steps mu
    by kF^2/10 in the direction that reduces the mismatch until the sign
    changes, then refines the root with Brent's method inside that bracket.

    Parameters
    ----------
    km : array
        Momentum mesh.
    SigC_all : complex array, indexed as SigC_all[:, ik]
        Correlation self-energy for each momentum km[ik].
    nom_low, kF, lam, eps :
        Passed through to Occup.
    max_steps : int, optional
        Maximum number of bracketing steps before giving up.

    Returns
    -------
    float : mu

    Raises
    ------
    ValueError
        If the particle-number mismatch is not finite during bracketing.
    RuntimeError
        If no sign change is found within max_steps steps.
    """

    def dN(mu, km, SigC_all, nom_low, kF, lam, eps):
        "N/N0 - 1 for a trial chemical potential mu."
        nk = array([Occup(mu, km[ik], SigC_all[:,ik], nom_low, kF, lam, eps) for ik in range(len(km))])
        N_over_N0 = integrate.simps(nk * km**2, x=km)*(3/kF**3)
        return N_over_N0 - 1.0

    args = (km, SigC_all, nom_low, kF, lam, eps)

    # Hartree-Fock starting guess. k is taken just below kF, presumably
    # because Sigma_x's log term can be singular exactly at k = kF when lam == 0.
    Sx0 = Sigma_x(kF**2, kF-1e-6, lam, eps)  # S_exchange(k=kF)
    mu = kF**2+Sx0
    deps = kF**2/10.
    dn = dN(mu, *args)
    if not np.isfinite(dn):
        raise ValueError("particle-number mismatch is not finite at mu=%r" % (mu,))
    if dn == 0:
        # starting guess is already the root
        return mu

    # Bracket the zero: too many particles (dn > 0) -> lower mu, and vice versa.
    sgn = sign(dn)
    for _ in range(max_steps):
        mu_prev = mu
        mu -= deps*sgn
        dn = dN(mu, *args)
        if not np.isfinite(dn):
            raise ValueError("particle-number mismatch is not finite at mu=%r" % (mu,))
        if dn*sgn <= 0:
            break
    else:
        raise RuntimeError("could not bracket the chemical potential within %d steps" % max_steps)

    # Use the exact previous mu (not mu+deps*sgn) so the bracket signs are guaranteed.
    a, b = min(mu_prev, mu), max(mu_prev, mu)
    return optimize.brentq(dN, a, b, args=args)

def dTrElnG(mu, k, SigC, nom_low, kF, lam, eps, nw=None, beta=None):
    """Per-momentum correlation contributions to the potential and free energy.

    Returns (dGG0k, trLnG) for momentum k:
        dGG0k = (2/beta) sum_n Re[(G - G00) * Sigma_c]       ~ Tr((G-G0) Sigma_c)
        trLnG = (2/beta) sum_n Re[ln G - ln G_HF] + F0 - F00   ~ Tr ln(G/G0)
    where
        G      full Green's function (Sigma_x + Sigma_c) at chemical potential mu,
        G_HF   exchange-only Green's function at mu,
        G00    free Green's function with chemical potential kF**2,
        F0     single-level free energy of the HF level k**2 + Sigma_x - mu,
        F00    single-level free energy of the free level k**2 - kF**2,
    using Tr ln(G/G0) = Tr ln(G/G_HF) + Tr ln(G_HF/G0).

    The self-energy SigC (tabulated on the Matsubara indices nw) plus
    Sigma_x(k) is spline-interpolated onto a dense, unit-spaced index mesh
    from nw[0] up to (excluding) nw[-1]. Separate splines are used below
    and above nom_low.

    Parameters
    ----------
    mu : float
        Chemical potential.
    k : float
        Momentum (scalar).
    SigC : complex array, length len(nw)
        Correlation self-energy at this k.
    nom_low : int
        Number of leading nw points treated as the low-frequency segment.
    kF, lam, eps : float
        Fermi momentum, screening parameter and dielectric constant (for Sigma_x).
    nw : array, optional
        Matsubara indices. Defaults to the module-level global ``nw`` set by
        the __main__ driver, for backward compatibility.
    beta : float, optional
        Inverse temperature. Defaults to the module-level global ``beta``.
    """
    def Free0(E,beta):
        "Free energy -1/beta*log(1+exp(-beta*E)) of one fermionic level, with |beta*E|>100 limits."
        if beta*E<-100:
            return E
        elif beta*E>100:
            return 0
        return -1./beta * log(1.+exp(-beta*E))

    # Historical callers pass only seven arguments and rely on the globals
    # that the script driver assigns; keep that working.
    if nw is None:
        nw = globals()['nw']
    if beta is None:
        beta = globals()['beta']

    # exchange
    Sx = Sigma_x(kF**2, k, lam, eps)
    # exchange-correlation
    Sxc = SigC + Sx

    # Dense Matsubara mesh with many more points than were computed
    midpoint = int(0.5*(nw[nom_low-1]+nw[nom_low]))
    mwl = arange(nw[0],       midpoint)
    mwh = arange(midpoint, nw[-1])
    # Interpolation for low frequency
    Srl = interpolate.UnivariateSpline(nw[:nom_low], real(Sxc[:nom_low]), s=0)
    Sil = interpolate.UnivariateSpline(nw[:nom_low], imag(Sxc[:nom_low]), s=0)
    Sxcl = Srl(mwl)+Sil(mwl)*1j
    # Interpolation for high frequency
    Srh = interpolate.UnivariateSpline(nw[nom_low:], real(Sxc[nom_low:]), s=0)
    Sih = interpolate.UnivariateSpline(nw[nom_low:], imag(Sxc[nom_low:]), s=0)
    Sxch = Srh(mwh)+Sih(mwh)*1j
    # entire frequency mesh
    _mw_ = hstack( (mwl, mwh) )
    # entire self-energy
    _Sxc_ = hstack( (Sxcl,Sxch) )
    Sc = _Sxc_-Sx # The correlation part only

    # fermionic Matsubara frequencies w_n = (2n+1) pi / beta
    wnc = (2*_mw_+1)*pi/beta
    # interacting Green's function and ln G written as -ln(-G^{-1})
    # NOTE(review): the sign flip presumably keeps the log argument off the
    # principal-branch cut; confirm.
    Gc  = 1/(wnc*1j + mu - k**2 - _Sxc_)
    lnGc  = -log(-(wnc*1j + mu - k**2 - _Sxc_))
    # Free (non-interacting) Green's function with chemical potential kF**2
    Gc00 = 1/(wnc*1j + kF**2 - k**2)
    # ln of the Hartree-Fock (exchange-only) Green's function at mu
    lnGc0 = -log(-(wnc*1j + mu - k**2 - Sx))

    # Finally sum over all Matsubara points
    dGG0k = sum(((Gc-Gc00)*Sc).real)*2/beta  # Tr((G-G0)*Sigma_c)
    F0 = Free0(k**2+Sx-mu,beta)  # free energy of HF
    F00 = Free0(k**2-kF**2,beta) # free energy of non-interacting
    trLnG = sum((lnGc-lnGc0).real)*2/beta + F0 - F00 # Tr(log(G/G0))=Tr(log(G/G_HF)+log(G_HF/G0))
    return (dGG0k,trLnG)

if __name__ == '__main__':
    # Script driver. For each rs value it:
    #   1. reads the precomputed mesh and correlation self-energy files,
    #   2. interpolates the self-energy onto a denser k-mesh,
    #   3. finds the chemical potential,
    #   4. accumulates the correlation corrections to the total energy (xEtot)
    #      and the free energy (xFtot), next to the EC_LDA value from the header.
    # NOTE(review): `nw` and `beta` assigned below by ReadMesh are read as module
    # globals inside Occup and dTrElnG.

    # rs values to process (presumably Wigner-Seitz radii)
    rsx=[0.5,1,2,3,4,5,6,7,8]
    xEtot=[]
    xFtot=[]
    xEc_LDA=[]
    # eps0/lam0 are only used to build the input file names; the eps/lam
    # actually used come from the mesh-file header.
    eps0=1; lam0=2.5
    for rs in rsx:
        # file-name suffix shared by the mesh and SigC files for this rs
        end = '_rs_'+str(rs)+'_eps_'+str(eps0)+'_lam_'+str(lam0)+'.dat'
        # Read correlation self-energy data from the file
        (nw, km, nom_low, kF, beta, lam, eps, PhiC, EpotC0, EC_LDA) = ReadMesh('mesh'+end)
        SigC_all = loadtxt('SigC'+end)
        # Even rows hold real parts and odd rows imaginary parts; combine them
        # into one complex array indexed as SigC_all[iw, ik]
        SigC_all = SigC_all[::2,:] + SigC_all[1::2,:]*1j
    
        # Interpolate on larger k-mesh
        Np=3  # how much more points in k-mesh
        # Denser mesh refined around kF (0.9 kF .. 1.2333 kF). It starts at 1e-10
        # rather than 0, presumably to avoid the 1/k factor in Sigma_x.
        _km_ = hstack( ([1e-10], linspace(0.01,0.9*kF, 10*Np), linspace(0.9*kF,1.2333*kF, 20*Np)[1:], linspace(1.2333*kF, 2.1*kF, 5*Np)[1:]) )
        _SigC_all_ = zeros((len(nw),len(_km_)),dtype=complex)
        # Interpolating splines (s=0) in k, one per Matsubara index, Re and Im separately
        for iw in range(len(nw)):
            _Sr_=interpolate.UnivariateSpline(km,real(SigC_all[iw,:]),s=0)
            _Si_=interpolate.UnivariateSpline(km,imag(SigC_all[iw,:]),s=0)
            _SigC_all_[iw,:] = _Sr_(_km_) + _Si_(_km_)*1j
        km = _km_
        SigC_all = _SigC_all_
    
        # Compute the chemical potential using this self-energy
        mu = ComputeChemicalPotential(km, SigC_all, nom_low, kF, lam, eps)
    
        # Compute occupancy n(k)==n_k
        nk=zeros(len(km))
        for ik in range(len(km)):
            nk[ik] = Occup(mu, km[ik], SigC_all[:,ik], nom_low, kF, lam, eps)
    
        # Check N/N0=1
        N_over_N0 = integrate.simps(nk * km**2, x=km)*(3/kF**3)
        Ek_over_Ek0 = integrate.simps(nk * km**4, x=km)*(5/kF**5)
        # Ek0 = integral_0^kF k^2 * k^2 dk / pi^2, i.e. per unit volume
        Ek0=kF**5/(5*pi**2)  # kinetic energy of the non-interacting system
        rho=kF**3/(3*pi**2)  # density
    
        print 'mu=', mu, 'N/N0=', N_over_N0, 'Ek/Ek0', Ek_over_Ek0
        # difference in kinetic energy between interacting and non-interacting system : 
        # dEk=(Ek-Ek0)/rho
        dEk = (Ek_over_Ek0*Ek0-Ek0)/rho  
    
        # Ex0 = 1/2*Tr(Sigma_x*G0)/rho, with the free occupation ferm(k^2-kF^2)
        cc=[Sigma_x(kF**2, k, lam, eps)*ferm(k**2-kF**2,beta)*0.5*k**2 for (ik,k) in enumerate(km)]
        Ex0 = integrate.simps(cc,x=km)/pi**2/ rho
    
        # dEpotx = 1/2*Tr((G-G0)*Sigma_x)/rho
        cc=[Sigma_x(kF**2, k, lam, eps)*(nk[ik]-ferm(k**2-kF**2,beta))*0.5*k**2 for (ik,k) in enumerate(km)]
        dEpotx = integrate.simps(cc,x=km)/pi**2/rho
    
        # dEpotc = 1/2*Tr((G-G0)*Sigma_c)/rho
        # TrLogGoG0 = Tr(log(G/G0))/rho
        Scdn = zeros(len(km))
        trlng = zeros(len(km))
        for ik in range(len(km)):
            (dGG0k,trLnG) = dTrElnG(mu, km[ik], SigC_all[:,ik], nom_low, kF, lam, eps)
            Scdn[ik] = 0.5*dGG0k * km[ik]**2
            trlng[ik] = trLnG * km[ik]**2
        dEpotc  = integrate.simps(Scdn,x=km)/pi**2/rho
        TrLogGoG0 = integrate.simps(trlng,x=km)/pi**2/rho
    
        # dEtot = Etot-Ek0-Ex= (Ek-Ek0) + 1/2*Tr(Sig_c*G0) + 1/2*Tr(Sig_c*(G-G0)) + 1/2*Tr(Sig_x*(G-G0))
        # (EpotC0, read from the mesh header, supplies the 1/2*Tr(Sig_c*G0) term)
        dEtot = dEk + EpotC0 + dEpotc + dEpotx
        # dFtot = Tr(log(G/G0)) - Tr(Sig_c*G0) - Tr(Sig_c*(G-G0)) - Tr(Sig_x*G0) - Tr(Sig_x*(G-G0)) + Phi^c + mu-EF
        # (the factors 2 undo the 1/2 stored in EpotC0, dEpotc, Ex0 and dEpotx)
        dFtot = TrLogGoG0 - 2*(EpotC0+dEpotc) -2*(Ex0+dEpotx)+ PhiC  + mu-kF**2
        xEtot.append(dEtot)
        xFtot.append(dFtot)
        xEc_LDA.append(EC_LDA)
        print 'Ec_LDA=', EC_LDA
        print 'dEtot=', dEtot
        print 'dFtot=', dFtot
    
    # Output uses eps/lam from the last mesh file read (normally equal to eps0/lam0).
    savetxt('ec_eps_'+str(eps)+'_lam_'+str(lam)+'.dat', vstack((rsx,xEtot)).transpose())
    plot(rsx, xEtot, 'o-', label='Etot_corr')
    plot(rsx, xFtot, 'o-', label='Ftot_corr')
    plot(rsx, xEc_LDA, 'o-', label='Ec_LDA')
    legend(loc='best')
    show()
