#!/usr/bin/env python
# Author: Kristjan Haule, Jan 2008


import os,sys,re
from bisect import bisect

if 'ROOT' not in globals():
    # ROOT may already be defined by a script that runs this file in its own
    # namespace; otherwise take it from the environment.
    ROOT = os.environ.get('LMTO_DMFT_ROOT')
    if ROOT is not None:
        # Make the modules in ROOT and ROOT/fort importable. When the variable
        # is unset, sys.path is left alone; the imports below then either find
        # the modules elsewhere or fail with a clear ImportError (instead of a
        # TypeError from None + '/fort').
        sys.path.append( ROOT )
        sys.path.append( ROOT + '/fort')

from ld_base import *
import optics

# MPI start
from mpi4py import MPI
# Rank that gathers the results of all ranks and writes the output files
Master = 0
# for the version 0.5
# NOTE(review): _rprint, WORLD_RANK and WORLD_SIZE are the mpi4py 0.5 API;
# they are aliased to the shorter names rprint/rank/size used in this file.
# Presumably not available in newer mpi4py releases -- confirm before upgrading.
MPI.rprint = MPI._rprint
MPI.rank = MPI.WORLD_RANK
MPI.size = MPI.WORLD_SIZE
# Report the master rank and the number of processes
MPI.rprint('mpi.(root,size) =  %d %d'%(Master,MPI.size), MPI.COMM_WORLD, Master)

def read_velocity(lmt):
    """Load the LmtArt velocities into ``lmt.vel`` and convert them to eV.

    The file ``<lmt.inpdir>/vel.dat`` is read with ``fort_read.read_v`` using
    the dimensions (3, lmt.ndim, lmt.ndim, lmt.nkp). The result is then
    scaled in place by ``13.6058*2j/lmt.latcon``. Progress messages are
    written to ``lmt.fh_info``.

    Output:
            matrix: lmt.vel[3,ndim,ndim,nkp]
    """
    log = lmt.fh_info
    log.write("-"*80+"\n")
    log.write("READ_VEL:   Velocities\n\n")

    velocity_path = lmt.inpdir + "/vel.dat"
    log.write("Reading velocities from "+velocity_path+"\n")
    lmt.vel = fort_read.read_v(velocity_path, 3, lmt.ndim, lmt.ndim, lmt.nkp, 1, 1)

    log.write("Assuming velocities are in Rydberg->Changing to electron volts\n")
    rydberg_in_ev = 13.6058
    # The 2j/latcon factor follows the "SS definition" of velocities.
    # NOTE(review): the in-place multiply assumes read_v returns a complex array.
    lmt.vel *= rydberg_in_ev*2j/lmt.latcon
    

def read_OutsideSig(filename):
    """ Frequency Om at which optics sigma(Om) will be computed!

    Reads the first whitespace-separated column of every line of
    ``filename`` as a float. Blank lines and lines whose first non-blank
    character is '#' are skipped.

    Returns the list of frequencies in file order.
    """
    Om = []
    fos = open(filename)
    try:
        for line in fos:
            fields = line.split()
            # Skip blank lines and comment lines (the comment line itself, not
            # the line after it).
            if not fields or fields[0].startswith('#'):
                continue
            Om.append(float(fields[0]))
    finally:
        fos.close()
    return Om


def Interpolate(x, om, Sigc):
    """Linearly interpolate Sigc, tabulated on the sorted mesh om, at the point x.

    The bracketing interval is located by bisection and clamped to the first
    or last interval of the mesh. Points outside [om[0], om[-1]] are therefore
    linearly extrapolated from the nearest end interval. The entries of Sigc
    only need to support +, - and scaling by a float (e.g. numpy matrices).
    """
    last_interval = len(om) - 2
    ix = max(min(bisect(om, x) - 1, last_interval), 0)
    left, right = Sigc[ix], Sigc[ix + 1]
    return left + (right - left)*(x - om[ix])/(om[ix + 1] - om[ix])
    

def Optics_Frequencies(Om, om, small=1e-12):
    """Build the inner-frequency mesh and integration limits for every Omega in Om.

    For each Omega the inner mesh holds the points of om that lie in
    [0, Omega]. Omega itself is appended when no such point exists or the
    last one is more than ``small`` below Omega. 0.0 is prepended when the
    first point lies above ``small``.

    Each inner point eps_i gets the integration interval
    [(eps_{i-1}+eps_i)/2, (eps_i+eps_{i+1})/2], with the outer limits clipped
    to the ends of the mesh.

    Returns (epsi, ab):
        epsi[iOm]     -- inner mesh for Om[iOm]
        ab[iOm][i]    -- [lower, upper] limits belonging to epsi[iOm][i]
    """
    epsi = []
    for Omega in Om:
        mesh = [e for e in om if 0 <= e <= Omega]
        if not mesh or mesh[-1] < Omega - small:
            mesh.append(Omega)
        if mesh[0] > small:
            mesh.insert(0, 0.0)
        epsi.append(mesh)

    ab = []
    for mesh in epsi:
        last = len(mesh) - 1
        limits = [[mesh[0], 0.5*(mesh[1] + mesh[0])]]
        limits.extend([[0.5*(mesh[k] + mesh[k-1]), 0.5*(mesh[k] + mesh[k+1])] for k in range(1, last)])
        limits.append([0.5*(mesh[-2] + mesh[-1]), mesh[-1]])
        ab.append(limits)

    return (epsi, ab)


def pr_limits(start, end, MPI_size):
    """ Returns a 2D-list of points pnts[nprocessors][:] describing which points are to be computed on each processor

    The integers start..end-1 are dealt out round-robin, so processor i gets
    start+i, start+i+MPI_size, ...  Processor 0 therefore owns the most
    points, and no two processors differ by more than one point.

    Raises ValueError when MPI_size < 1; without this guard a non-positive
    size would never advance the counter.
    """
    if MPI_size < 1:
        raise ValueError("MPI_size must be a positive integer, got %r" % (MPI_size,))
    return [list(range(start + i, end, MPI_size)) for i in range(MPI_size)]

def GatherAndPrint(tops, Om, indom, pnts, outdir):
    """Sum the optics contributions of all ranks and write outdir/optics.dat.

    Every rank passes ``tops``, its contributions in the order of
    ``pnts[rank]``. On the Master rank each contribution is mapped through
    ``indom`` to its outer-frequency index and accumulated into ``copt``.
    The file then receives one line "Om copt" per outer frequency. All ranks
    meet at a barrier before returning.
    """
    gathered = MPI.WORLD[Master].Gather(tops)
    if MPI.rank == Master:
        copt = zeros(len(Om))
        for proc in range(MPI.size):
            owned = pnts[proc]
            for i, contribution in enumerate(gathered[proc]):
                iOm, _ = indom[owned[i]]
                copt[iOm] += contribution
        fh_optics = open(outdir+'/optics.dat', 'w')
        for Omega, value in zip(Om, copt):
            fh_optics.write("%s %s\n" % (Omega, value))
        fh_optics.close()
    MPI.WORLD.Barrier()


def cmp_Optics(lmt, Om, epsi, ab, om, Sigc, alphaV, outdir, fh_info, gamma=[0.0,0.0], how_often=10):
    """ Actually computes optical conductivity

    Every (Omega, epsilon) pair of the 2D mesh is assigned to an MPI rank.
    For each pair, optics.cmp_opt3 is evaluated with the self-energy
    interpolated at epsilon and epsilon-Omega. Results (opt/Omega) are
    summed per Omega on the master and written to outdir/optics.dat.

    Input:
       lmt        -- lmtart object; mu, vel, hamf, olap, latcon,
                     unit_cell_volume and wk are passed to optics.cmp_opt3
       Om         -- outer frequencies
       epsi, ab   -- inner meshes and integration limits per Om
                     (as returned by Optics_Frequencies)
       om, Sigc   -- mesh and self-energy matrices tabulated on it
       alphaV     -- three weights placed on the diagonal of a 3x3 tensor
                     (sets the direction for conductivity)
       outdir     -- directory receiving optics.dat
       fh_info    -- per-rank log; one "Omega opt/Omega" line per point
       gamma      -- currently unused; kept for interface compatibility
       how_often  -- rewrite optics.dat after every how_often steps
                     (0 disables the intermediate output)
    """
    # Diagonal 3x3 weight tensor (Fortran-ordered)
    alphaV1 = zeros((3,3), order='F')
    for i in range(3): alphaV1[i,i] = alphaV[i]

    # index array with all frequencies (outside and inside)
    indom = [(iOm, ie) for iOm in range(len(Om)) for ie in range(len(epsi[iOm]))]

    # Parallel run! Which points of indom are computed by each processor
    pnts = pr_limits(0, len(indom), MPI.size)
    mypnts = pnts[MPI.rank]
    # GatherAndPrint is collective (Gather + Barrier): every rank must call it
    # the same number of times, otherwise the run deadlocks. Ranks may own one
    # point fewer than rank 0, so all ranks step through the longest list and
    # simply idle once their own points are exhausted.
    nsteps = max([len(p) for p in pnts])

    tops = []  # opt/Omega for every point computed on this rank, in mypnts order
    for ii in range(nsteps):
        if ii < len(mypnts):
            (iOm, ie) = indom[mypnts[ii]]
            Omega = Om[iOm]
            epsilon = epsi[iOm][ie]

            # Self-energy linearly interpolated at epsilon and epsilon-Omega
            sig_p = Interpolate(epsilon,       om, Sigc)
            sig_m = Interpolate(epsilon-Omega, om, Sigc)

            opt = optics.cmp_opt3(Omega, lmt.mu, epsilon, ab[iOm][ie][0], ab[iOm][ie][1], sig_p, sig_m, lmt.vel, lmt.hamf, lmt.olap, alphaV1, lmt.latcon, lmt.unit_cell_volume, lmt.wk)
            tops.append(opt/Omega)

            print >> fh_info, Omega, opt/Omega
            fh_info.flush()

        if how_often and (ii+1) % how_often == 0:
            GatherAndPrint(tops, Om, indom, pnts, outdir)

    GatherAndPrint(tops, Om, indom, pnts, outdir)

def cmp_Dos_Gloc_Delta(lmt, omega, Sigma, pmax=1000, max_metropolis_steps=0):
    """
    Computes total and partial DOS, the correlated Green's function and Delta
    at a single frequency.
    Input:
       lmt                     -- lmtart class
       omega                   -- frequency (real or matsubara); lmt.mu is added to it
       Sigma[ndim,ndim]        -- self-energy matrix
       pmax                    -- only orbitals p < pmax contribute to the total DOS
       max_metropolis_steps    -- metropolis steps in sorting bands (passed to cmp_Gloc)
       ---------------------------------------------
          The following members of class lmt are used:
             mu                -- the chemical potential
             olap[ndim,ndim]   -- overlap matrix
             hamf[ndim,ndim]   -- hamiltonian matrix
             itt[5,ntet]       -- index of tetrahedron mesh
             T2C[ndim,ndim]    -- matrix which rotates from spheric to cubic/relativistic coordinates
             bndindx[ndim,4]   -- index from orbital to atom and l
             ndim              -- size of the ham matrix
             Sigind[ndim,ndim] -- correlated index matrix
    Output: the tuple (dos_tot, pdos, gc, Delta)
       dos_tot                 -- total DOS: sum of -Im G_pp/pi over orbitals p < pmax
       pdos[natom,lmax+1]      -- partial DOS, -Im G_pp/pi summed per atom and l
       gc                      -- Green's function of the correlated orbitals, one entry
                                  per (1-based) Sigind column, divided by deg
       Delta                   -- same reduction applied to z*1 - G^{-1} - Sigma
    """
    # Frequency shifted by the chemical potential
    z = omega + lmt.mu

    # Actually computes local Green's function - sum over irreducible BZ
    gloc = cmp_Gloc(z, lmt.olap, lmt.hamf, Sigma, lmt.itt, max_metropolis_steps)
    # Array has to be converted to Fortran type array to symmetrize
    gloc = array(gloc, order='F')
    # Symmetrization replaces summation over full 1-st BZ (modifies gloc in place)
    sym.symmetrize_cubic(gloc, lmt.T2C)

    # z*1 - G^{-1} - Sigma. NOTE(review): no local-level term is subtracted, so
    # this presumably still contains the impurity levels -- confirm intended.
    MDelta = z*identity(lmt.ndim) - linalg.inv(matrix(gloc)) - matrix(Sigma)
    
    # Partial DOS is computed for each atom and each l
    # We need information about lmax and natom to do that
    # Information is obtained from lmt.bndindx (column 0: 1-based atom, column 1: l)
    lmax  = int(max([x[1] for x in lmt.bndindx]))
    natom = int(max([x[0] for x in lmt.bndindx]))
    # Partial DOS is stored below
    dos_tot=0
    pdos = zeros((natom,lmax+1),dtype=float)
    for p in range(lmt.ndim):
        wdos = -gloc[p,p].imag/pi
        iatom = int(lmt.bndindx[p][0]-1)
        il    = int(lmt.bndindx[p][1])
        pdos[iatom, il] += wdos
        # Only the first pmax orbitals enter the total DOS
        if (p<pmax) : dos_tot += wdos
    # Green's function of the correlated block is stored below
    # We need information about the nonzero entries in the matrix
    sind = lmt.Sigind.flatten()
    # NOTE(review): cols presumably holds the distinct Sigind values and
    # deg[c] the number of Sigind entries equal to c -- confirm utl semantics.
    cols = utl.union(sind)
    deg = utl.repeat(sind,cols)
    
    gc = zeros(max(cols),dtype=complex)
    Delta = zeros(max(cols),dtype=complex)
    # Accumulate every matrix entry into the slot of its (1-based) Sigind column
    for p in range(lmt.ndim):
        for q in range(lmt.ndim):
            if (lmt.Sigind[p,q]>0):
                ii = lmt.Sigind[p,q]-1
                gc[ii] += gloc[p,q]
                Delta[ii] += MDelta[p,q]
                
    # Normalize by deg, indexed by the 1-based column number
    for p in range(len(gc)):
        gc[p] /= deg[p+1]
        Delta[p] /= deg[p+1]

    return (dos_tot, pdos, gc, Delta)


def GatherAndPrintDos(gcDeDo, om, pnts, outdir):
    """Collect DOS, Gf and Delta from all ranks and write them on the Master rank.

    ``gcDeDo`` holds this rank's (dos_tot, pdos, gc, Delta) tuples in the
    order of ``pnts[rank]``. The Master places each tuple at its frequency
    index and writes one line per computed frequency to outdir/dos.out,
    outdir/gf.out and outdir/Delta.out. Frequencies without a result yet are
    skipped. All ranks meet at a barrier before returning.
    """
    gathered = MPI.WORLD[Master].Gather(gcDeDo)
    if MPI.rank == Master:
        # Reorder per-rank results by frequency index; [] means "not computed yet"
        by_freq = [[] for _ in range(len(om))]
        for proc in range(MPI.size):
            owned = pnts[proc]
            for i, result in enumerate(gathered[proc]):
                by_freq[owned[i]] = result

        fh_dos = open(outdir+'/dos.out','w')
        fh_gf  = open(outdir+'/gf.out', 'w')
        fh_delta = open(outdir+'/Delta.out', 'w')

        for iom, result in enumerate(by_freq):
            if len(result) == 0:
                continue
            (tdos, pdos, gc, Delta) = result
            natom, nl = shape(pdos)[0], shape(pdos)[1]

            dos_cols = "".join(["\t%f " % pdos[iatom,il] for iatom in range(natom) for il in range(nl)])
            fh_dos.write("%f \t%f " % (om[iom], tdos) + dos_cols + "\n")

            gf_cols = "".join(["\t%f  %f " % (g.real, g.imag) for g in gc])
            fh_gf.write("%f " % (om[iom]) + gf_cols + "\n")

            delta_cols = "".join(["\t%f  %f " % (g.real, g.imag) for g in Delta])
            fh_delta.write("%f " % (om[iom]) + delta_cols + "\n")

        fh_dos.close()
        fh_gf.close()
        fh_delta.close()
    MPI.WORLD.Barrier()
    

def Print_Dos_Gloc_Delta(lmt, om, Sigc, outdir, fh_info, pmax, max_metropolis_steps=0, how_often=10):
    """
    Computes and prints three quantities: DOS, Gf, Delta.

    The frequencies om are distributed round-robin over the MPI ranks. For
    every assigned frequency, cmp_Dos_Gloc_Delta is called with the
    self-energy matrix Sigc[iom]. The results are gathered on the master and
    written to outdir/dos.out, outdir/gf.out and outdir/Delta.out. This
    happens after every how_often steps (0 disables it) and once at the end.
    Each rank logs "om dos_tot" per computed frequency to fh_info.
    """
    # Parallel run! Which frequencies need to be computed by each processor
    pnts = pr_limits(0, len(om), MPI.size)
    mypnts = pnts[MPI.rank]
    # GatherAndPrintDos is collective (Gather + Barrier): every rank must call
    # it the same number of times, otherwise the run deadlocks. Ranks may own
    # one point fewer than rank 0, so all ranks step through the longest list
    # and idle once their own points are exhausted.
    nsteps = max([len(p) for p in pnts])

    gcDeDo = []  # one (dos_tot, pdos, gc, Delta) tuple per computed frequency
    for ii in range(nsteps):
        if ii < len(mypnts):
            iom = mypnts[ii]
            gcDeDo.append( cmp_Dos_Gloc_Delta(lmt, om[iom], Sigc[iom], pmax, max_metropolis_steps) )
            print >> fh_info, om[iom], gcDeDo[-1][0]
            fh_info.flush()
        if how_often and (ii+1) % how_often == 0:
            GatherAndPrintDos(gcDeDo, om, pnts, outdir)

    GatherAndPrintDos(gcDeDo, om, pnts, outdir)
    

#-------------- for upfolding -------------------------
def strip_comments(data):
    """Tokenize lines, dropping comments and blank lines.

    Everything from the FIRST '#' on a line onward is a comment. Each
    remaining non-empty line is returned as its list of whitespace-separated
    tokens, in the original order.
    """
    newdata = []
    for line in data:
        tokens = line.split('#', 1)[0].split()
        if tokens:
            newdata.append(tokens)
    return newdata

def pars_index(ndim, lines, i0):
    """ reading the rules to transform
    columns of the sigma file to matrix

    Consumes (removes) ndim rows from ``lines``, starting at index i0. Each
    row is a list of at least ndim string tokens. A token containing ``$k``
    maps that matrix entry to column k (sind[p,q] = k). The token with
    ``$k`` replaced by 1 is evaluated as the coefficient, e.g.
    '0.5*$3' -> sind=3, coef=0.5. Tokens without '$' leave both entries 0.

    Returns (sind, coef): ndim x ndim int and complex arrays.

    WARNING: coefficients are evaluated with eval(); only use this on
    trusted input files.
    """
    sind = zeros((ndim,ndim), dtype=int)
    coef = zeros((ndim,ndim), dtype=complex)
    col_re = re.compile(r'\$(\d+)')
    for p in range(ndim):
        row = lines.pop(i0)
        for q in range(ndim):
            token = row[q]
            cols = col_re.findall(token)
            if cols:
                sind[p,q] = int(cols[0])
                coef[p,q] = eval(re.sub(r'\$'+cols[0], '1', token))
    return (sind, coef)


def pars_coeff(sc):
    """ Given a list string such as '$1$1 + 1./4.*$6$1 + (1.+1j)*$3$3*2'
        returns the tuple (scols, cols, coef):
          scols -- column references as written:   ['$1$1', '$6$1', '$3$3']
          cols  -- the same as integer pairs:      [(1,1), (6,1), (3,3)]
          coef  -- coefficient of each reference:  [1, 0.25, 2+2j]
        The coefficient of a reference is sc evaluated with that reference
        replaced by 1 and every other reference replaced by 0.

        WARNING: uses eval(); only call on trusted input.
    """
    ref_re = re.compile(r'\$(\d+)\$(\d+)')
    matches = list(ref_re.finditer(sc))
    scols = [m.group(0) for m in matches]
    cols = [(int(m.group(1)), int(m.group(2))) for m in matches]

    coef = []
    for target in cols:
        # Substitute every reference in a single pass. The regex consumes
        # whole digit runs, so '$1$1' is never matched inside '$1$12'.
        expr = ref_re.sub(lambda m: '1' if (int(m.group(1)), int(m.group(2))) == target else '0', sc)
        coef.append(eval(expr))

    return (scols, cols, coef)

def read_sind(fh_sind, ndim):
    """ Reads index array to make a matrix of self-energy from columns of the file

    fh_sind is the NAME of the file. After comments are removed
    (strip_comments), the file may contain the optional sections below in
    any order. Each section starts with a line holding only its keyword and
    is followed by ndim rows of ndim entries:

      transformation     -- complex matrix Uk (default: identity(ndim))
      correlated_blocks  -- integer matrix, stored as corind = entry - 1
                            (default: ndim x ndim array of ones)
      NM_structure       -- index/coefficient rules parsed by pars_index
                            into NMsind, NMcoef (default: arrays of ones)

    Returns (Uk, corind, NMsind, NMcoef).
    """
    f = open(fh_sind, 'r')
    try:
        data = f.readlines()
    finally:
        f.close()
    lines = strip_comments(data)

    # Transformation from momentum to real space base
    Uk = identity(ndim)
    if ['transformation'] in lines:
        i0 = lines.index(['transformation'])
        del lines[i0]
        rows = []
        for p in range(ndim):
            rows.append([complex(x) for x in lines.pop(i0)])
        Uk = array(rows)

    # Shape of the correlated blocks (entries in the file are shifted by -1)
    corind = ones((ndim,ndim),dtype=int)
    if ['correlated_blocks'] in lines:
        i0 = lines.index(['correlated_blocks'])
        del lines[i0]
        rows = []
        for p in range(ndim):
            rows.append([int(x)-1 for x in lines.pop(i0)])
        corind = array(rows)

    # Rules to transform columns of the sigma file to matrix
    NMcoef = ones((ndim,ndim))
    NMsind = ones((ndim,ndim),dtype=int)
    if ['NM_structure'] in lines:
        i0 = lines.index(['NM_structure'])
        del lines[i0]
        (NMsind, NMcoef) = pars_index(ndim, lines, i0)

    return (Uk, corind, NMsind, NMcoef)

def CreateSelfEnergyMatrixUpfold(Sig, Uk, sind, coef, gamma=0.01):
    """Build the self-energy matrix of the correlated subspace and rotate it by Uk.

    Entry (p,q) takes column sind[p,q] (1-based) of Sig, multiplied by
    coef[p,q]. Entries with sind[p,q] <= 0 stay zero. Entries that point past
    the last column of Sig are set to -gamma*1j. The return value is
    Uk * sig * conj(Uk)^T as a numpy matrix.
    """
    nd = len(sind)
    sig = matrix(zeros((nd, nd), dtype=complex))
    for p in range(len(sig)):
        for q in range(len(sig)):
            col = sind[p, q] - 1
            if col < 0:
                continue
            sig[p, q] = Sig[col]*coef[p, q] if col < shape(Sig)[0] else -gamma*1j
    return matrix(Uk) * matrix(sig) * matrix(conjugate(Uk.transpose()))
#-------------- for upfolding -------------------------



if __name__ == '__main__':

    max_metropolis_steps=0 #50000
    cmp_dos = False
    cmp_opt = False
    inpdir = '.'
    outdir = '.'
    sig = 'sig.inp'
    soutside = 'sig_out'
    gamma = [0.01,0.01]
    alphaV = [1/3.,1/3.,1/3.] # Sets the direction for conductivity
    how_often = 10
    pmax = 10000
    
    execfile('PARAMS.opt') # File containing all parameters!


    fh_info = open(inpdir+'/opt.info.'+str(MPI.rank), 'w')
    print >> fh_info, '-'*80
    print >> fh_info, '*'*20, 'Optics & DOS calculation'
    print >> fh_info, '-'*80, '\n'

    
    if upfold:
        # Reads self-energy from file
        (om,Sigc) = ReadSelfEnergyColumns(inpdir+"/"+sig)
    else:
        (om,Sigc,Sigma_oo_IMPv,Edc_IMP) = ReadSelfEnergyColumns(inpdir+"/"+sig, contains_DC=True)
        

    # Reads LMTO data
    lmt = Lmtart(inpdir, fh_info)
    lmt.read_cix()    
    trn.transform2(lmt.olap, lmt.hamf, lmt.Utk)


    dcorr = utl.flatten(lmt.vSigind)
    icorr = [lmt.Sigind[ic][ic] for ic in dcorr]
    if sum(sort(icorr)-array(range(1,len(icorr)+1)))==0: # icor contains right seuqence of integers
        corr = [dcorr[i-1] for i in icorr]               # kcix.dat contains  sequence of orbitals
        print 'reordering corr'
    else:
        corr = dcorr
    print 'corr = ', corr

    if upfold:
        s_ndim = len(corr)
        (Uk, corind, NMsind, NMcoef) = read_sind(f_sind, s_ndim)
        
        SigOm=[]
        for iom in range(len(om)):
            sig = CreateSelfEnergyMatrixUpfold(Sigc[iom], Uk, NMsind, NMcoef, gamma[0])
            
            Sig = zeros((lmt.ndim,lmt.ndim), dtype=complex)
            for ip in range(len(Sig)): Sig[ip,ip] = -gamma[1]*1j
            
            for iq0,q0 in enumerate(corr):
                for iq1,q1 in enumerate(corr):
                    Sig[q0,q1] = sig[iq0,iq1]
            SigOm.append(Sig)
        lmt.mu += dmu
        
    else:
        SigOm=[]
        for iom in range(len(om)):
            Sigma_oo = CreateSelfEnergyMatrix(array(Sigma_oo_IMPv)-array(Edc_IMP), lmt.Sigind)
            SigOm.append(CreateSelfEnergyMatrix(Sigc[iom], lmt.Sigind, gamma) + Sigma_oo)
        lmt.mu = mu
            
    print >> fh_info, 'Self-energy created'  
    # Computes DOS
    if cmp_dos:
        Print_Dos_Gloc_Delta(lmt, om, SigOm, outdir, fh_info, pmax, max_metropolis_steps, how_often)

    # Computes optics
    if cmp_opt:
    
        # Reads mesh for optical conductivity calculation
        Om = read_OutsideSig(inpdir+"/"+soutside)
        
        # Reads velocities
        read_velocity(lmt)
        trn.transform3(lmt.vel, lmt.Utk) # transforms velocities to DMFT base
        
        
        # Construcs a 2D mesh of all points which need to be computed for optics
        (epsi, ab) = Optics_Frequencies(Om, om)
        
        print >> lmt.fh_info, 'Total number of frequency points to compute: ', sum(map(len,epsi))

        # Parallel execution of optics calculation
        cmp_Optics(lmt, Om, epsi, ab, om, SigOm, alphaV, outdir, fh_info, gamma, how_often)
    
