/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.boosting;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.boosting.RBWeakRanker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

/**
 * RankBoost ranker: a weighted combination of single-feature threshold weak rankers
 * ({@link RBWeakRanker}) learned by iteratively re-weighting document pairs.
 *
 * <p>Typical use is {@link #init()} followed by {@link #learn()}; a trained model can be
 * serialised with {@link #model()} and restored with {@link #loadFromString(String)}.
 */
public class RankBoost extends Ranker {
    /** Maximum number of boosting rounds performed by {@link #learn()}. */
    public static int nIteration = 300;
    /**
     * Number of threshold intervals per feature. When {@code <= 0}, every feature value observed
     * in the training data is used as a candidate threshold instead of an evenly spaced grid.
     */
    public static int nThreshold = 10;

    /**
     * Pair weights: {@code sweight[i][j][k]} (with {@code k > j}) is the weight of the pair
     * (document j, document k) in training list i. The last row of every list is never allocated.
     */
    protected double[][][] sweight = null;
    /** {@code potential[i][j]}: weights of pairs led by j minus weights of pairs in which j trails. */
    protected double[][] potential = null;
    /** Per feature index, per training list: document indices sorted by that feature's value. */
    protected List<List<int[]>> sortedSamples = new ArrayList<List<int[]>>();
    /** Candidate thresholds, indexed by position in {@code features}. */
    protected double[][] thresholds = null;
    /** Sort order of {@link #thresholds} for each feature, as returned by {@code MergeSorter.sort}. */
    protected int[][] tSortedIdx = null;
    /** Weak rankers selected so far, in the order they were learned. */
    protected List<RBWeakRanker> wRankers = null;
    /** Weight (alpha) of each weak ranker in {@link #wRankers}. */
    protected List<Double> rWeight = null;
    /** Snapshot of {@link #wRankers} taken at the best validation score seen during training. */
    protected List<RBWeakRanker> bestModelRankers = new ArrayList<RBWeakRanker>();
    /** Snapshot of {@link #rWeight} taken at the best validation score seen during training. */
    protected List<Double> bestModelWeights = new ArrayList<Double>();

    /** {@code Z_t} times the best r found by the latest {@link #learnWeakRanker()} call. */
    private double R_t = 0.0;
    /** Sum of the un-normalised pair weights computed in the latest boosting round. */
    private double Z_t = 1.0;
    /** Number of document pairs (j, k) with label(j) &gt; label(k) counted by {@link #init()}. */
    private int totalCorrectPairs = 0;

    public RankBoost() {
    }

    public RankBoost(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Returns the sort index of the documents of {@code rl} by the value of feature {@code fid}.
     * The order is whatever {@code MergeSorter.sort(values, false)} produces -- presumably
     * descending, which is what the sweep in {@link #learnWeakRanker()} relies on (TODO confirm).
     */
    private int[] reorder(RankList rl, int fid) {
        double[] score = new double[rl.size()];
        for (int i = 0; i < rl.size(); ++i) {
            score[i] = rl.get(i).getFeatureValue(fid);
        }
        return MergeSorter.sort(score, false);
    }

    /** Recomputes {@link #potential} for every document from the current pair weights. */
    private void updatePotential() {
        for (int i = 0; i < this.samples.size(); ++i) {
            RankList rl = (RankList)this.samples.get(i);
            for (int j = 0; j < rl.size(); ++j) {
                double p = 0.0;
                // Pairs in which document j is the first element.
                for (int k = j + 1; k < rl.size(); ++k) {
                    p += this.sweight[i][j][k];
                }
                // Pairs in which document j is the second element.
                for (int k = 0; k < j; ++k) {
                    p -= this.sweight[i][k][j];
                }
                this.potential[i][j] = p;
            }
        }
    }

    /**
     * Picks the (feature, threshold) combination with the largest r, where r is the sum of the
     * potentials of all documents whose feature value exceeds the threshold, and stores
     * {@code Z_t * r} in {@link #R_t}.
     *
     * @return the selected weak ranker, or {@code null} when no candidate produced an r greater
     *         than the initial bound of -10 (for example when there are no features)
     */
    private RBWeakRanker learnWeakRanker() {
        int bestFid = -1;
        double maxR = -10.0;
        double bestThreshold = -1.0;
        for (int i = 0; i < this.features.length; ++i) {
            List<int[]> sSortedIndex = this.sortedSamples.get(i);
            int[] idx = this.tSortedIdx[i];
            // last[k]: position in list k's sorted order of the last document already added to r.
            int[] last = new int[this.samples.size()];
            for (int j = 0; j < last.length; ++j) {
                last[j] = -1;
            }
            // r accumulates across thresholds: each threshold only adds the documents that newly
            // exceed it, continuing from where the previous threshold stopped.
            double r = 0.0;
            for (int j = 0; j < idx.length; ++j) {
                double t = this.thresholds[i][idx[j]];
                for (int k = 0; k < this.samples.size(); ++k) {
                    RankList rl = (RankList)this.samples.get(k);
                    int[] sk = sSortedIndex.get(k);
                    int l = last[k] + 1;
                    while (l < rl.size() && (double)rl.get(sk[l]).getFeatureValue(this.features[i]) > t) {
                        r += this.potential[k][sk[l]];
                        last[k] = l;
                        ++l;
                    }
                }
                if (r > maxR) {
                    maxR = r;
                    bestThreshold = t;
                    bestFid = this.features[i];
                }
            }
        }
        if (bestFid == -1) {
            return null;
        }
        this.R_t = this.Z_t * maxR;
        return new RBWeakRanker(bestFid, bestThreshold);
    }

    /**
     * Prepares all training state: replaces every training list by its
     * {@code getCorrectRanking()} version, initialises the pair weights uniformly over the
     * correctly ordered pairs, builds the candidate thresholds and pre-sorts the documents of
     * every list by every feature.
     *
     * <p>All state derived here is rebuilt from scratch, so calling this method again starts a
     * fresh training run.
     */
    @Override
    public void init() {
        this.PRINT("Initializing... ");
        this.wRankers = new ArrayList<RBWeakRanker>();
        this.rWeight = new ArrayList<Double>();
        // Reset everything a previous init()/learn() cycle may have left behind; without this
        // sortedSamples would keep the old index arrays at the positions learnWeakRanker reads.
        this.sortedSamples = new ArrayList<List<int[]>>();
        this.R_t = 0.0;
        this.Z_t = 1.0;
        this.totalCorrectPairs = 0;

        for (int i = 0; i < this.samples.size(); ++i) {
            this.samples.set(i, ((RankList)this.samples.get(i)).getCorrectRanking());
            RankList rl = (RankList)this.samples.get(i);
            for (int j = 0; j < rl.size() - 1; ++j) {
                // Scans from the end and stops at the first k whose label is not lower than j's;
                // this presumably relies on getCorrectRanking() ordering by label (TODO confirm).
                for (int k = rl.size() - 1; k >= j + 1 && rl.get(j).getLabel() > rl.get(k).getLabel(); --k) {
                    ++this.totalCorrectPairs;
                }
            }
        }

        // Uniform initial distribution over the correctly ordered pairs; all other pairs get 0.
        this.sweight = new double[this.samples.size()][][];
        for (int i = 0; i < this.samples.size(); ++i) {
            RankList rl = (RankList)this.samples.get(i);
            this.sweight[i] = new double[rl.size()][];
            for (int j = 0; j < rl.size() - 1; ++j) {
                this.sweight[i][j] = new double[rl.size()];
                for (int k = j + 1; k < rl.size(); ++k) {
                    this.sweight[i][j][k] = rl.get(j).getLabel() > rl.get(k).getLabel() ? 1.0 / (double)this.totalCorrectPairs : 0.0;
                }
            }
        }

        this.potential = new double[this.samples.size()][];
        for (int i = 0; i < this.samples.size(); ++i) {
            this.potential[i] = new double[((RankList)this.samples.get(i)).size()];
        }

        this.thresholds = nThreshold <= 0 ? this.buildExhaustiveThresholds() : this.buildUniformThresholds();

        this.tSortedIdx = new int[this.features.length][];
        for (int i = 0; i < this.features.length; ++i) {
            this.tSortedIdx[i] = MergeSorter.sort(this.thresholds[i], false);
        }

        for (int i = 0; i < this.features.length; ++i) {
            ArrayList<int[]> idx = new ArrayList<int[]>();
            for (int j = 0; j < this.samples.size(); ++j) {
                idx.add(this.reorder((RankList)this.samples.get(j), this.features[i]));
            }
            this.sortedSamples.add(idx);
        }
        this.PRINTLN("[Done]");
    }

    /** Builds the threshold table that uses every observed feature value as a candidate. */
    private double[][] buildExhaustiveThresholds() {
        int count = 0;
        for (int i = 0; i < this.samples.size(); ++i) {
            count += ((RankList)this.samples.get(i)).size();
        }
        double[][] result = new double[this.features.length][];
        for (int i = 0; i < this.features.length; ++i) {
            result[i] = new double[count];
        }
        int c = 0;
        for (int i = 0; i < this.samples.size(); ++i) {
            RankList rl = (RankList)this.samples.get(i);
            for (int j = 0; j < rl.size(); ++j) {
                for (int k = 0; k < this.features.length; ++k) {
                    result[k][c] = rl.get(j).getFeatureValue(this.features[k]);
                }
                ++c;
            }
        }
        return result;
    }

    /**
     * Builds the threshold table with {@code nThreshold + 1} candidates per feature: the
     * observed maximum, values stepping down from it in equal steps, and a final candidate far
     * below the observed minimum.
     */
    private double[][] buildUniformThresholds() {
        int intervals = nThreshold;
        int nFeatures = this.features.length;
        double[] fmax = new double[nFeatures];
        double[] fmin = new double[nFeatures];
        // Start from +/- infinity so that values of any magnitude are tracked correctly.
        for (int k = 0; k < nFeatures; ++k) {
            fmax[k] = Double.NEGATIVE_INFINITY;
            fmin[k] = Double.POSITIVE_INFINITY;
        }
        for (int i = 0; i < this.samples.size(); ++i) {
            RankList rl = (RankList)this.samples.get(i);
            for (int j = 0; j < rl.size(); ++j) {
                for (int k = 0; k < nFeatures; ++k) {
                    double f = rl.get(j).getFeatureValue(this.features[k]);
                    if (f > fmax[k]) {
                        fmax[k] = f;
                    }
                    if (f < fmin[k]) {
                        fmin[k] = f;
                    }
                }
            }
        }
        double[][] result = new double[nFeatures][];
        for (int i = 0; i < nFeatures; ++i) {
            if (fmax[i] < fmin[i]) {
                // No comparable value was observed for this feature (no documents, or only
                // NaN values): keep the bounds this code has always used in that situation.
                fmax[i] = -1000000.0;
                fmin[i] = 1000000.0;
            }
            double step = Math.abs(fmax[i] - fmin[i]) / (double)intervals;
            result[i] = new double[intervals + 1];
            result[i][0] = fmax[i];
            for (int j = 1; j < intervals; ++j) {
                result[i][j] = result[i][j - 1] - step;
            }
            // Final candidate lies far below every observed value.
            result[i][intervals] = fmin[i] - 1.0E8;
        }
        return result;
    }

    /**
     * Runs up to {@link #nIteration} boosting rounds. Each round selects a weak ranker, derives
     * its weight, re-weights all document pairs and reports the training (and, if available,
     * validation) score. When validation data is present the ensemble with the best validation
     * score is the one kept at the end.
     */
    @Override
    public void learn() {
        this.PRINTLN("------------------------------------------");
        this.PRINTLN("Training starts...");
        this.PRINTLN("--------------------------------------------------------------------");
        this.PRINTLN(new int[]{7, 8, 9, 9, 9, 9}, new String[]{"#iter", "Sel. F.", "Threshold", "Error", this.scorer.name() + "-T", this.scorer.name() + "-V"});
        this.PRINTLN("--------------------------------------------------------------------");
        for (int t = 1; t <= nIteration; ++t) {
            this.updatePotential();
            RBWeakRanker wr = this.learnWeakRanker();
            if (wr == null) {
                break;
            }
            // NOTE(review): the denominator is zero when R_t equals Z_t; that case is not guarded.
            double alpha_t = 0.5 * SimpleMath.ln((this.Z_t + this.R_t) / (this.Z_t - this.R_t));
            this.wRankers.add(wr);
            this.rWeight.add(alpha_t);

            // Re-weight every pair in place and accumulate the normalisation constant Z_t.
            // Each new weight depends only on the old weight of the same pair, so no temporary
            // matrix is required.
            this.Z_t = 0.0;
            for (int i = 0; i < this.samples.size(); ++i) {
                RankList rl = (RankList)this.samples.get(i);
                double[][] pairWeights = this.sweight[i];
                for (int j = 0; j < rl.size() - 1; ++j) {
                    for (int k = j + 1; k < rl.size(); ++k) {
                        pairWeights[j][k] *= Math.exp(alpha_t * (double)(wr.score(rl.get(k)) - wr.score(rl.get(j))));
                        this.Z_t += pairWeights[j][k];
                    }
                }
            }

            this.PRINT(new int[]{7, 8, 9, 9}, new String[]{t + "", wr.getFid() + "", SimpleMath.round(wr.getThreshold(), 4) + "", SimpleMath.round(this.R_t, 4) + ""});
            this.PRINT(new int[]{9}, new String[]{SimpleMath.round(this.scorer.score(this.rank(this.samples)), 4) + ""});
            if (this.validationSamples != null) {
                double score = this.scorer.score(this.rank(this.validationSamples));
                if (score > this.bestScoreOnValidationData) {
                    // Remember the ensemble that performed best on the validation data.
                    this.bestScoreOnValidationData = score;
                    this.bestModelRankers.clear();
                    this.bestModelRankers.addAll(this.wRankers);
                    this.bestModelWeights.clear();
                    this.bestModelWeights.addAll(this.rWeight);
                }
                this.PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
            }
            this.PRINTLN("");

            // Normalise the pair weights by Z_t.
            for (int i = 0; i < this.samples.size(); ++i) {
                RankList rl = (RankList)this.samples.get(i);
                for (int j = 0; j < rl.size() - 1; ++j) {
                    double[] row = this.sweight[i][j];
                    for (int k = j + 1; k < rl.size(); ++k) {
                        row[k] = row[k] / this.Z_t;
                    }
                }
            }
        }

        // Roll back to the best validation-time ensemble, if one was recorded.
        if (this.validationSamples != null && this.bestModelRankers.size() > 0) {
            this.wRankers.clear();
            this.rWeight.clear();
            this.wRankers.addAll(this.bestModelRankers);
            this.rWeight.addAll(this.bestModelWeights);
        }
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(this.rank(this.samples)), 4);
        this.PRINTLN("--------------------------------------------------------------------");
        this.PRINTLN("Finished sucessfully.");
        this.PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(this.rank(this.validationSamples));
            this.PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        this.PRINTLN("---------------------------------");
    }

    /**
     * Scores a data point as the weighted sum of all weak ranker outputs.
     *
     * @param p the data point to score
     * @return sum over all weak rankers of {@code weight * weakRanker.score(p)}
     */
    @Override
    public double eval(DataPoint p) {
        double score = 0.0;
        for (int j = 0; j < this.wRankers.size(); ++j) {
            score += this.rWeight.get(j) * (double)this.wRankers.get(j).score(p);
        }
        return score;
    }

    @Override
    public Ranker createNew() {
        return new RankBoost();
    }

    /**
     * Returns the weak rankers as space-separated {@code <weakRanker>:<weight>} entries, the
     * format that {@link #loadFromString(String)} parses.
     */
    @Override
    public String toString() {
        StringBuilder output = new StringBuilder();
        int lastIndex = this.rWeight.size() - 1;
        for (int i = 0; i < this.wRankers.size(); ++i) {
            output.append(this.wRankers.get(i).toString()).append(":").append(this.rWeight.get(i));
            if (i != lastIndex) {
                output.append(" ");
            }
        }
        return output.toString();
    }

    /** Returns the serialised model: {@code ##}-prefixed header lines followed by {@link #toString()}. */
    @Override
    public String model() {
        StringBuilder output = new StringBuilder();
        output.append("## ").append(this.name()).append("\n");
        output.append("## Iteration = ").append(nIteration).append("\n");
        output.append("## No. of threshold candidates = ").append(nThreshold).append("\n");
        output.append(this.toString());
        return output.toString();
    }

    /**
     * Restores the weak rankers and their weights from model text. Blank lines and lines
     * starting with {@code ##} are skipped; the first remaining line must hold space-separated
     * {@code fid:threshold:weight} entries, optionally followed by a {@code #} comment.
     * The ranker's state is only replaced once the whole line has been parsed successfully.
     *
     * @param fullText the model text, as produced by {@link #model()}
     * @throws RuntimeException the error created by {@code RankLibError.create} when the text
     *         cannot be parsed
     */
    @Override
    public void loadFromString(String fullText) {
        try {
            String content = null;
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            try {
                // Skip blank lines and "##" header lines; stop at the first model line.
                while ((content = in.readLine()) != null) {
                    content = content.trim();
                    if (content.length() != 0 && !content.startsWith("##")) {
                        break;
                    }
                }
            } finally {
                in.close();
            }
            if (content == null) {
                throw new IllegalArgumentException("Model text contains no weak-ranker line.");
            }

            // Drop a trailing "# ..." comment, if any.
            int idx = content.lastIndexOf("#");
            if (idx != -1) {
                content = content.substring(0, idx).trim();
            }

            List<Double> weights = new ArrayList<Double>();
            List<RBWeakRanker> rankers = new ArrayList<RBWeakRanker>();
            String[] fs = content.split(" ");
            for (int i = 0; i < fs.length; ++i) {
                String entry = fs[i].trim();
                if (entry.length() == 0) {
                    continue;
                }
                String[] strs = entry.split(":");
                if (strs.length < 3) {
                    throw new IllegalArgumentException("Malformed weak-ranker entry (expected fid:threshold:weight): " + entry);
                }
                int fid = Integer.parseInt(strs[0]);
                double threshold = Double.parseDouble(strs[1]);
                double weight = Double.parseDouble(strs[2]);
                weights.add(weight);
                rankers.add(new RBWeakRanker(fid, threshold));
            }

            this.rWeight = weights;
            this.wRankers = rankers;
            // One feature id per weak ranker, in model order (duplicates are kept).
            this.features = new int[this.rWeight.size()];
            for (int i = 0; i < this.rWeight.size(); ++i) {
                this.features[i] = this.wRankers.get(i).getFid();
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in RankBoost::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        this.PRINTLN("No. of rounds: " + nIteration);
        this.PRINTLN("No. of threshold candidates: " + nThreshold);
    }

    @Override
    public String name() {
        return "RankBoost";
    }
}

