/*
 * Decompiled with CFR 0.152.
 */
package hic.tools.utils.norm.scale;

import hic.HiCGlobals;
import hic.tools.utils.bigarray.BigContactList;
import hic.tools.utils.largelists.BigFloatsArray;
import hic.tools.utils.largelists.BigIntsArray;
import java.util.Arrays;
import javastraw.reader.datastructures.ListOfFloatArrays;
import javastraw.reader.datastructures.ListOfIntArrays;

/**
 * Iterative computation of a scaling vector for a sparse contact matrix, with adaptive exclusion of
 * rows that have too few non-zero entries.
 */
public class FinalScale {
    // Mask values written into the BigIntsArray masks (bad / one / target).
    private static final short S1 = 1;
    private static final short S0 = 0;
    // upperBound is the non-zero count found at this percentile of the sorted non-empty rows.
    private static final float maxPercentile = 10.0f;
    // Convergence tolerance; also the per-row relative-change threshold for "unstable" rows.
    private static final float tolerance = 1.0E-4f;
    // Maximum iterations for a single lowCutoff setting.
    private static final int maxIter = 500;
    // Maximum iterations summed over all lowCutoff settings.
    private static final int totalIterations = 1500;
    // Not referenced anywhere in this class.
    private static final float minErrorThreshold = 0.02f;
    // A run is treated as stalled when error * (1 + del) is not below the error 5 iterations earlier.
    private static final float del = 0.05f;
    // Maximum accepted row-sum error of the final vector.
    private static final float rsError = 0.05f;
    // Unstable-row fraction below which those rows are excluded instead of raising lowCutoff.
    private static final float erez = 1.0E-5f;

    /**
     * Computes a scaling vector for the contact matrix {@code ba}.
     *
     * <p>Each iteration rescales with {@link #update} alternately along rows and columns and
     * combines the two scaling vectors with {@code parSetToGeoMean} into the current estimate.
     * Rows with fewer than {@code lowCutoff} non-zero entries are excluded.
     *
     * <p>How {@code lowCutoff} moves:
     * <ul>
     *   <li>When a cutoff converges, it is lowered: halved, or bisected towards the last stalled
     *       cutoff.</li>
     *   <li>When a cutoff stalls, it is raised: doubled, or bisected towards the last converged
     *       cutoff.</li>
     *   <li>Alternatively, if only a tiny fraction of rows is still changing, those rows are excluded
     *       and the same cutoff is retried. This happens at most once before the cutoff next
     *       changes.</li>
     * </ul>
     *
     * @param ba           sparse contact matrix
     * @param matrixSize   number of rows/columns; must fit in an {@code int}
     * @param initialGuess starting vector, or {@code null} to start from 1 on every non-empty row.
     *                     NOTE(review): when supplied it is used directly as the working vector
     *                     {@code dr} (not copied), passed to {@code parScaleByRatio} and cleared
     *                     before a successful return — callers should not reuse it.
     * @param stem         label printed in the verbose timing message
     * @return the scaling vector with {@code NaN} at excluded rows, or {@code null} when no row has
     *         any non-zero entry, or when the convergence error exceeds {@code tolerance}, the
     *         row-sum error exceeds {@code rsError}, or {@code lowCutoff} ended above the
     *         percentile bound
     */
    public static ListOfFloatArrays scaleToTargetVector(BigContactList ba, long matrixSize, BigFloatsArray initialGuess, String stem) {
        long startTime = System.nanoTime();
        int matrixSizeI = (int) matrixSize;
        BigIntsArray bad = new BigIntsArray(matrixSize);
        BigIntsArray zTargetVector = new BigIntsArray(matrixSize, 1);
        BigFloatsArray calculatedVectorB = new BigFloatsArray(matrixSize);
        BigIntsArray one = new BigIntsArray(matrixSize, 1);
        // Two slots past the last iteration hold the final errors for the verbose report.
        double[] reportErrorForIteration = new double[totalIterations + 3];
        ListOfIntArrays numNonZero = ba.getNumNonZeroInRows();

        // Sort the non-zero counts of the non-empty rows to derive the bound on lowCutoff.
        int[] simpleNumNonZero = new int[matrixSizeI];
        int n0 = 0;
        for (long p = 0L; p < matrixSize; ++p) {
            int a = numNonZero.get(p);
            if (a > 0) {
                simpleNumNonZero[n0++] = a;
            }
        }
        if (n0 == 0) {
            // Nothing to scale. For an empty matrix the percentile lookup below would index past
            // the end of simpleNumNonZero. For an all-zero matrix upperBound would be 0, and
            // lowCutoff (always >= 1) could never pass the final check, so null is the only outcome.
            if (HiCGlobals.printVerboseComments) {
                System.out.println("Setting vector to null (not converged)");
            }
            calculatedVectorB.clear();
            return null;
        }
        Arrays.sort(simpleNumNonZero, 0, n0);
        int upperBound = simpleNumNonZero[(int) (maxPercentile * n0 / 100.0)];

        // Empty rows are excluded for the whole run and get a zero target.
        int lowCutoff = 1;
        for (long p = 0L; p < matrixSize; ++p) {
            if (numNonZero.get(p) < lowCutoff) {
                excludeBadRow(p, bad, zTargetVector, one);
            }
        }
        BigFloatsArray row = ba.parSparseMultiplyAcrossLists(one, matrixSize);
        for (long p = 0L; p < matrixSize; ++p) {
            one.set(p, (short) (1 - bad.get(p)));
        }
        BigFloatsArray dr = initialGuess == null ? one.deepConvertedClone() : initialGuess;
        BigFloatsArray dc = dr.deepClone();
        BigFloatsArray current = dr.deepClone();
        row.parMultiplyBy(dr);
        BigFloatsArray col;

        // State of the search over lowCutoff.
        boolean conv = false;                     // some cutoff has converged
        boolean div = false;                      // some cutoff has stalled
        int low_conv = 1000;                      // most recent converged cutoff
        int low_div = 0;                          // most recent stalled cutoff
        float[] b_conv = new float[matrixSizeI];  // estimate saved at low_conv
        int[] bad_conv = new int[matrixSizeI];    // exclusion mask saved at low_conv
        double ber_conv = 10.0;                   // convergence error saved at low_conv
        boolean yes = true;                       // unstable-row exclusion still allowed
        float[] b0 = new float[matrixSizeI];      // estimate from the previous iteration
        double convergeError = 10.000999999974738;
        int iter = 0;
        int allItersI = 0;
        int realIters = 0;

        while (convergeError > tolerance && iter < maxIter && allItersI < totalIterations) {
            ++iter;
            ++allItersI;
            ++realIters;
            col = update(matrixSize, bad, row, zTargetVector, dr, ba);
            col.parMultiplyBy(dc);
            row = update(matrixSize, bad, col, zTargetVector, dc, ba);
            row.parMultiplyBy(dr);
            calculatedVectorB.parSetToGeoMean(dr, dc);
            convergeError = BigFloatsArray.parCalculateConvergenceError(calculatedVectorB, current, bad);
            int numBad = countUnstableRows(matrixSize, bad, calculatedVectorB, current);
            for (int p = 0; p < matrixSizeI; ++p) {
                b0[p] = current.get(p);
            }
            reportErrorForIteration[allItersI - 1] = convergeError;
            current.parSetTo(calculatedVectorB);

            if (convergeError < tolerance) {
                yes = true;
                if (lowCutoff == 1) {
                    break; // converged while keeping every non-empty row
                }
                // Remember this converged result, then try a lower cutoff.
                conv = true;
                for (int p = 0; p < matrixSizeI; ++p) {
                    b_conv[p] = calculatedVectorB.get(p);
                    bad_conv[p] = bad.get(p);
                }
                ber_conv = convergeError;
                low_conv = lowCutoff;
                if (div) {
                    if (low_conv - low_div <= 1) {
                        break;
                    }
                    lowCutoff = (low_conv + low_div) / 2;
                } else {
                    lowCutoff = low_conv / 2;
                }
                markRowsBelowCutoff(matrixSize, numNonZero, lowCutoff, bad, one);
                convergeError = 10.0;
                iter = 0;
                row = restart(ba, matrixSize, bad, dr, dc);
                continue;
            }

            // Keep going at this cutoff during the first five iterations, or while the error
            // still shrinks by more than a factor (1 + del) over five iterations.
            if (iter <= 5
                    || (reportErrorForIteration[allItersI - 1] * (1.0 + del) < reportErrorForIteration[allItersI - 6]
                        && iter < maxIter)) {
                continue;
            }

            // This cutoff has stalled.
            div = true;
            low_div = lowCutoff;
            if (conv && low_conv - low_div <= 1) {
                // Adjacent converged and stalled cutoffs: fall back to the converged result.
                for (long p = 0L; p < matrixSize; ++p) {
                    calculatedVectorB.set(p, b_conv[(int) p]);
                }
                for (int p = 0; p < matrixSizeI; ++p) {
                    bad.set(p, (short) bad_conv[p]);
                }
                convergeError = ber_conv;
                break;
            }
            if ((double) numBad / n0 < erez && yes) {
                // Only a tiny fraction of rows is still moving: exclude them, retry this cutoff.
                excludeUnstableRows(matrixSizeI, bad, one, calculatedVectorB, b0);
                yes = false;
            } else {
                // Raise the cutoff: bisect towards the last converged cutoff, otherwise double it.
                lowCutoff = conv ? (low_div + low_conv) / 2 : 2 * lowCutoff;
                yes = true;
                markRowsBelowCutoff(matrixSize, numNonZero, lowCutoff, bad, one);
            }
            convergeError = 10.0;
            iter = 0;
            row = restart(ba, matrixSize, bad, dr, dc);
            if (lowCutoff > upperBound || allItersI > totalIterations) {
                break;
            }
        }

        col = ba.parSparseMultiplyAcrossLists(calculatedVectorB, matrixSize);
        double rowSumError = BigFloatsArray.parCalculateError(col, calculatedVectorB, zTargetVector, bad);
        if (HiCGlobals.printVerboseComments) {
            System.out.println("Total iters " + realIters + " \nRow Sums Error " + rowSumError);
            System.out.println("Convergence error " + convergeError);
            System.out.println("Remove rows with less than " + lowCutoff + " nonzeros");
            reportErrorForIteration[allItersI + 1] = convergeError;
            reportErrorForIteration[allItersI + 2] = rowSumError;
        }
        if (convergeError > tolerance || rowSumError > rsError || lowCutoff > upperBound) {
            if (HiCGlobals.printVerboseComments) {
                System.out.println("Setting vector to null (not converged)");
            }
            calculatedVectorB.clear();
            return null;
        }
        // Excluded rows are reported as NaN in the result.
        for (long p = 0L; p < matrixSize; ++p) {
            if (bad.get(p) == 1) {
                calculatedVectorB.set(p, Float.NaN);
            }
        }
        bad.clear();
        zTargetVector.clear();
        one.clear();
        numNonZero.clear();
        row.clear();
        dr.clear();
        dc.clear();
        current.clear();
        col.clear();
        if (HiCGlobals.printVerboseComments) {
            long endTime = System.nanoTime();
            long timeInSecs = (long) ((double) (endTime - startTime) * 1.0E-9);
            System.out.println(stem + " took " + timeInSecs + " seconds");
            System.out.println("Final error in scaling vector is " + reportErrorForIteration[allItersI + 1] + " and in row sums is " + reportErrorForIteration[allItersI + 2]);
        }
        ListOfFloatArrays answer = calculatedVectorB.convertToRegular();
        calculatedVectorB.clear();
        return answer;
    }

    /** Marks row {@code index} as excluded: bad = 1, one = 0, and a zero target. */
    private static void excludeBadRow(long index, BigIntsArray bad, BigIntsArray target, BigIntsArray one) {
        bad.set(index, S1);
        one.set(index, S0);
        target.set(index, S0);
    }

    /**
     * Rebuilds the {@code bad}/{@code one} masks so that exactly the rows with fewer than
     * {@code lowCutoff} non-zero entries are excluded. The target vector is left untouched.
     */
    private static void markRowsBelowCutoff(long matrixSize, ListOfIntArrays numNonZero, int lowCutoff, BigIntsArray bad, BigIntsArray one) {
        for (long p = 0L; p < matrixSize; ++p) {
            boolean sparse = numNonZero.get(p) < lowCutoff;
            bad.set(p, sparse ? S1 : S0);
            one.set(p, sparse ? S0 : S1);
        }
    }

    /**
     * Resets both scaling vectors to the inclusion mask (1 for kept rows, 0 for bad rows) and
     * returns the recomputed row vector, {@code ba * dc} multiplied by {@code dr}.
     */
    private static BigFloatsArray restart(BigContactList ba, long matrixSize, BigIntsArray bad, BigFloatsArray dr, BigFloatsArray dc) {
        for (long p = 0L; p < matrixSize; ++p) {
            float keep = 1.0f - bad.get(p);
            dr.set(p, keep);
            dc.set(p, keep);
        }
        BigFloatsArray row = ba.parSparseMultiplyAcrossLists(dc, matrixSize);
        row.parMultiplyBy(dr);
        return row;
    }

    /** Returns |(a - b) / (a + b)| in float arithmetic; NaN when both are zero. */
    private static float relativeChange(float a, float b) {
        return Math.abs((a - b) / (a + b));
    }

    /** Counts non-excluded rows whose relative change from {@code previous} exceeds tolerance. */
    private static int countUnstableRows(long matrixSize, BigIntsArray bad, BigFloatsArray updated, BigFloatsArray previous) {
        int count = 0;
        for (long p = 0L; p < matrixSize; ++p) {
            if (bad.get(p) != 1 && relativeChange(updated.get(p), previous.get(p)) > tolerance) {
                ++count;
            }
        }
        return count;
    }

    /** Excludes every non-excluded row whose relative change from {@code previous} exceeds tolerance. */
    private static void excludeUnstableRows(int matrixSizeI, BigIntsArray bad, BigIntsArray one, BigFloatsArray updated, float[] previous) {
        for (int p = 0; p < matrixSizeI; ++p) {
            if (bad.get(p) != 1 && relativeChange(updated.get(p), previous[p]) > tolerance) {
                bad.set(p, S1);
                one.set(p, S0);
            }
        }
    }

    /**
     * Sets {@code vector} to 1 at excluded rows, then calls
     * {@code dVector.parScaleByRatio(target, vector)}. Presumably this updates dVector in place by
     * target / vector, and the 1s avoid dividing by zero at excluded rows — TODO confirm.
     * Returns the sparse product of {@code ba} with the updated {@code dVector}.
     */
    private static BigFloatsArray update(long matrixSize, BigIntsArray bad, BigFloatsArray vector, BigIntsArray target, BigFloatsArray dVector, BigContactList ba) {
        for (long p = 0L; p < matrixSize; ++p) {
            if (bad.get(p) == 1) {
                vector.set(p, 1.0f);
            }
        }
        dVector.parScaleByRatio(target, vector);
        return ba.parSparseMultiplyAcrossLists(dVector, matrixSize);
    }
}

