Commit aa55cf29 authored by tpetaja1's avatar tpetaja1
Browse files

small fixes to files

parent 3ac184a8
Loading
Loading
Loading
Loading
+24 −14
Original line number Diff line number Diff line

from TVGL import TVGL
from SerialTVGL import SerialTVGL
import numpy as np
import multiprocessing
import mp_workers_async as mp


class AsyncProTVGL(TVGL):
class AsyncProTVGL(SerialTVGL):

    def __init__(self, filename, blocks=10,
                 lambd=30, beta=4, processes=2):
        super(AsyncProTVGL, self).__init__(filename, blocks,
                                           lambd, beta, processes)

    def init_algorithm(self):
        self.pool = multiprocessing.Pool(self.processes)

    def terminate_pools(self):
        self.pool.close()
        self.pool.join()

    """
    def theta_update(self):
        for i in range(self.blocks):
            a = (self.z0s[i] + self.z1s[i] + self.z2s[i] -
@@ -24,23 +32,24 @@ class AsyncProTVGL(TVGL):
            diagonal = np.diag(d) + np.diag(sqrt_matrix)
            self.thetas[i] = np.real(
                self.nju/2*np.dot(np.dot(q, diagonal), qt))
    """

    def z_update(self):
        pool = multiprocessing.Pool(self.processes)
        res_z0s = pool.apply_async(mp.z0_update,
        res_z0s = self.pool.apply_async(mp.z0_update,
                                        (self.thetas, self.z0s,
                                         self.u0s, self.lambd,
                                         self.rho, self.blocks))
        res_z1z2s = pool.apply_async(mp.z1_z2_update,
        res_z1z2s = self.pool.apply_async(mp.z1_z2_update,
                                          (self.thetas, self.z1s, self.z2s,
                                           self.u1s, self.u2s, self.beta,
                                           self.rho, self.blocks))
        self.z0s = res_z0s.get()
        z1s_z2s = res_z1z2s.get()
        self.z0s = res_z0s.get(timeout=1)
        z1s_z2s = res_z1z2s.get(timeout=1)
        self.z1s = z1s_z2s[0]
        self.z2s = z1s_z2s[1]
        pool.close()
        #pool.close()

    """
    def u_update(self):
        pool = multiprocessing.Pool(self.processes)
        res_u0s = pool.apply_async(mp.u0_update,
@@ -56,3 +65,4 @@ class AsyncProTVGL(TVGL):
        self.u1s = res_u1s.get()
        self.u2s = res_u2s.get()
        pool.close()
    """
+2 −2
Original line number Diff line number Diff line
@@ -15,9 +15,9 @@ def z1_z2_update(thetas, z1s, z2s, u1s, u2s, beta, rho, blocks):
            a = thetas[i] - thetas[i-1] + u2s[i] - u1s[i-1]
            e = pf.group_lasso_penalty(a, 2*beta/rho)
            z1s[i-1] = 0.5*(thetas[i-1] + thetas[i]
                            + u1s[i] + u2s[i]) - 0.5*e
                            + u1s[i-i] + u2s[i]) - 0.5*e
            z2s[i] = 0.5*(thetas[i-1] + thetas[i]
                          + u1s[i] + u2s[i]) + 0.5*e
                          + u1s[i-i] + u2s[i]) + 0.5*e
    except Exception as e:
        traceback.print_exc()
        raise e
+2 −2
Original line number Diff line number Diff line
@@ -35,8 +35,8 @@ def z1_z2_update((theta, theta_pre, u1, u1_pre, u2, beta, rho)):
    try:
        a = theta - theta_pre + u2 - u1_pre
        e = pf.group_lasso_penalty(a, 2*beta/rho)
        z1 = 0.5*(theta_pre + theta + u1 + u2) - 0.5*e
        z2 = 0.5*(theta_pre + theta + u1 + u2) + 0.5*e
        z1 = 0.5*(theta_pre + theta + u1_pre + u2) - 0.5*e
        z2 = 0.5*(theta_pre + theta + u1_pre + u2) + 0.5*e
    except Exception as e:
        traceback.print_exc()
        raise e