Commit f41a1f83 authored by tanmoy.7989's avatar tanmoy.7989
Browse files

vectorized in parts and made changes as suggested by evoyiatzis

parent bbb0f574
Loading
Loading
Loading
Loading
+24 −23
Original line number Diff line number Diff line
@@ -38,11 +38,11 @@ StringIO (or io if in Python 3.x)



import os, sys, numpy as np, argparse, time, pickle
import os, numpy as np, argparse, time, pickle
from scipy.special import logsumexp
from mpi4py import MPI

from tqdm import tqdm, trange
from tqdm import tqdm
import gzip, bz2
try:
    # python-2
@@ -78,12 +78,10 @@ def _get_nearest_temp(temps, query_temp):
    """
    
    if isinstance(temps, list): temps = np.array(temps)
    idx = np.argmin(abs(temps - query_temp))
    out_temp = temps[idx]
    return out_temp
    return temps[np.argmin(np.abs(temps-query_temp))]


def readwrite(trajfn, mode = "rb"):
def readwrite(trajfn, mode):
    """ 
    Helper function for input/output LAMMPS traj files.
    Trajectories may be plain text, .gz or .bz2 compressed.
@@ -96,11 +94,14 @@ def readwrite(trajfn, mode = "rb"):
    """
    
    if trajfn.endswith(".gz"):
        return gzip.GzipFile(trajfn, mode)
        of = gzip.open(trajfn, mode)
        #return gzip.GzipFile(trajfn, mode)
    elif trajfn.endswith(".bz2"):
        return bz2.BZ2File(trajfn, mode)
        of = bz2.open(trajfn, mode)
        #return bz2.BZ2File(trajfn, mode)
    else:
        return file(trajfn, mode)
        of = open(trajfn, mode)
    return of


def get_replica_frames(logfn, temps, nswap, writefreq):
@@ -163,7 +164,7 @@ def get_byte_index(rep_inds, byteindfns, intrajfns):
        if os.path.isfile(byteindfns[n]): continue
        
        # extract bytes
        fobj = readwrite(intrajfns[n]) 
        fobj = readwrite(intrajfns[n], "rb") 
        byteinds = [ [0,0] ]
        
        # place file pointer at first line
@@ -243,7 +244,7 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
    for n in temp_inds:
        # open string-buffer and file
        buf = IOBuffer()
        of = readwrite(outtrajfns[n], mode = "wb")
        of = readwrite(outtrajfns[n], "wb")
        
        # get frames
        abs_temp_ind = np.argmin( abs(temps - outtemps[n]) )
@@ -281,7 +282,7 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
        
        
def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
                       kB = 0.001987):
                       kB):
    """
    Gets configurational log-weights (logw) for each frame and at each temp.
    from the REMD simulation. ONLY WRITTEN FOR THE CANONICAL (NVT) ensemble.
@@ -348,25 +349,25 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
    #3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
    u_kln = np.zeros([ntemps, ntemps, nframes], float)
    for k in range(ntemps):
        for l in range(ntemps):
            u_kln[ k, l, 0:nframes_k[k] ] = beta_k[l] * u_kn[k, 0:nframes_k[k]]
        u_kln[k] = np.outer(beta_k, u_kn[k])
        
    # run pymbar and extract the free energies
    print("\nRunning pymbar...")
    mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose = True)
    f_k = mbar.f_k
    f_k = mbar.f_k # (1 x k array)

    # calculate the log-weights
    print("\nExtracting log-weights...")
    log_nframes = np.log(nframes)
    logw = dict( (k, np.zeros([ntemps, nframes], float)) for k in range(ntemps) )
    for l in range(ntemps):
    # get log-weights to reweight to this temp.
    for k in range(ntemps):
        for n in range(nframes):
            num = -beta_k[k] * u_kn[k,n]
            denom = f_k - beta_k[k] * u_kn[k,n]
            for l in range(ntemps):
                logw[l][k,n] = num - logsumexp(denom) - log_nframes
            
    return logw


@@ -515,7 +516,7 @@ if __name__ == "__main__":
    comm.barrier()
    
    # open all replica files for reading
    infobjs = [readwrite(i) for i in intrajfns]
    infobjs = [readwrite(i, "rb") for i in intrajfns]
    
    # open all byteindex files
    byte_inds = dict( (i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns) )