/*
 * Decompiled with CFR 0.152.
 */
package keel.Algorithms.Instance_Generation.LVQ;

import java.util.ArrayList;
import java.util.HashMap;
import keel.Algorithms.Instance_Generation.Basic.Prototype;
import keel.Algorithms.Instance_Generation.Basic.PrototypeGenerationAlgorithm;
import keel.Algorithms.Instance_Generation.Basic.PrototypeSet;
import keel.Algorithms.Instance_Generation.LVQ.LVQ1;
import keel.Algorithms.Instance_Generation.utilities.Debug;
import keel.Algorithms.Instance_Generation.utilities.KNN.KNN;
import keel.Algorithms.Instance_Generation.utilities.Pair;
import keel.Algorithms.Instance_Generation.utilities.Parameters;

/**
 * LVQ with Training Counter (LVQTC) prototype generation.
 *
 * <p>Extends {@link LVQ1}. For every prototype it counts how many training instances of
 * each class had that prototype as their nearest neighbour. Within an epoch the update
 * step of a prototype is the learning rate divided by its current total count.
 *
 * <p>After each epoch:
 * <ul>
 *   <li>prototypes whose total count is below the retention threshold are pruned;</li>
 *   <li>new prototypes are created from the instances a prototype misclassified.</li>
 * </ul>
 */
public class LVQTC
extends LVQ1 {
    /** Learning rate used when the nearest prototype has the same label as the instance. */
    private double alpha_r = ALPHA_DEFAULT_VALUE;
    /** Learning rate used when the nearest prototype has a different label than the instance. */
    private double alpha_w = ALPHA_DEFAULT_VALUE;
    /** Number of epochs; each epoch is train, then prune, then create. */
    private int epoches = 4;
    /**
     * A prototype whose total hit count is at least this value survives pruning.
     * A single wrong-class count must strictly exceed it to spawn a new prototype.
     */
    private int retentionThreshold = 3;
    /**
     * Output class values of the training set.
     * Per-instance so that separately constructed runs do not overwrite each other's list.
     */
    private ArrayList<Double> posibleClasses = null;
    /** Per prototype: how many training instances of each class had it as nearest prototype. */
    private HashMap<Prototype, HashMap<Double, Integer>> counter = null;
    /** Per prototype: cached total of its {@link #counter} entry, or -1 if not yet computed. */
    private HashMap<Prototype, Integer> sumCounter = null;
    /** Per prototype: training instances it was nearest to whose label differed from its own. */
    private HashMap<Prototype, PrototypeSet> wrong = null;

    /**
     * Builds the algorithm from a parameter source.
     *
     * @param traDataSet training set
     * @param parameters parameter source. After the {@link LVQ1} parameters it is read for
     *                   alpha_w (double), the retention threshold (int) and the number of
     *                   epochs (int), in that order. alpha_r is taken from LVQ1's alpha_0.
     */
    public LVQTC(PrototypeSet traDataSet, Parameters parameters) {
        super(traDataSet, parameters);
        this.algorithmName = "LVQTC";
        this.alpha_r = this.alpha_0;
        this.alpha_w = parameters.getNextAsDouble();
        this.retentionThreshold = parameters.getNextAsInt();
        this.epoches = parameters.getNextAsInt();
        this.posibleClasses = traDataSet.getPosibleValuesOfOutput();
        this.counter = new HashMap<Prototype, HashMap<Double, Integer>>();
        this.sumCounter = new HashMap<Prototype, Integer>();
        this.wrong = new HashMap<Prototype, PrototypeSet>();
    }

    /**
     * Builds the algorithm from explicit values.
     *
     * @param traDataSet training set
     * @param it         training iterations per epoch (forwarded to {@link LVQ1})
     * @param percProts  percentage of prototypes (forwarded to {@link LVQ1})
     * @param alpha_r    learning rate for correctly classified instances
     * @param alpha_w    learning rate for misclassified instances
     * @param T          retention threshold
     * @param epoches    number of epochs
     */
    public LVQTC(PrototypeSet traDataSet, int it, double percProts, double alpha_r, double alpha_w, int T, int epoches) {
        super(traDataSet, it, percProts, alpha_r);
        this.algorithmName = "LVQTC";
        this.alpha_r = alpha_r;
        this.alpha_w = alpha_w;
        this.retentionThreshold = T;
        this.epoches = epoches;
        this.posibleClasses = traDataSet.getPosibleValuesOfOutput();
        this.counter = new HashMap<Prototype, HashMap<Double, Integer>>();
        this.sumCounter = new HashMap<Prototype, Integer>();
        this.wrong = new HashMap<Prototype, PrototypeSet>();
    }

    /**
     * Sets every per-class counter of {@code i} to zero and marks its cached sum as unknown.
     *
     * @param i prototype whose counters are (re)initialised
     */
    protected void initCounterOf(Prototype i) {
        HashMap<Double, Integer> perClass = new HashMap<Double, Integer>();
        for (Double d : this.posibleClasses) {
            perClass.put(d, 0);
        }
        this.counter.put(i, perClass);
        this.sumCounter.put(i, -1);
    }

    /**
     * Starts a new epoch: clears all counters and wrong-instance sets, then creates fresh
     * ones for every prototype in {@code data}.
     */
    private void reset(PrototypeSet data) {
        // Without clearing, entries for prototypes pruned in earlier epochs would accumulate.
        this.counter.clear();
        this.sumCounter.clear();
        this.wrong.clear();
        for (Prototype p : data) {
            this.initCounterOf(p);
            this.wrong.put(p, new PrototypeSet());
        }
    }

    /** Returns the sum of all values in {@code v}. */
    private int sum(HashMap<Double, Integer> v) {
        int acc = 0;
        for (int count : v.values()) {
            acc += count;
        }
        return acc;
    }

    /**
     * Returns the total hit count of {@code p} over all classes.
     * Computes and caches it if the cache holds -1.
     */
    private int sumOfCounterOf(Prototype p) {
        int value = 0;
        Debug.force(this.sumCounter.containsKey(p), "ERROR en sumOfCounter");
        if (this.sumCounter.get(p) == -1) {
            int _sum = this.sum(this.counter.get(p));
            this.sumCounter.put(p, _sum);
            value = _sum;
        } else {
            value = this.sumCounter.get(p);
        }
        return value;
    }

    /**
     * Finds the class, other than {@code p.assignedClass()}, with the highest counter for
     * {@code p}. Only counters strictly greater than the retention threshold are considered.
     * NOTE(review): the exclusion uses assignedClass(), not label(). Confirm this is the
     * intended "own class" of the prototype.
     *
     * @return (true, class) if such a class exists; otherwise (false, p.label())
     */
    private Pair<Boolean, Double> maximumWrongClassCounter(Prototype p) {
        HashMap<Double, Integer> h = this.counter.get(p);
        ArrayList<Double> list = new ArrayList<Double>(h.keySet());
        double classWrong = p.label();
        int max = this.retentionThreshold;
        boolean found = false;
        for (Double klass : list) {
            if (klass.doubleValue() == p.assignedClass() || h.get(klass) <= max) continue;
            classWrong = klass;
            max = h.get(klass);
            found = true;
        }
        return new Pair<Boolean, Double>(found, classWrong);
    }

    /**
     * Adds one to the counter of class {@code _class} for prototype {@code i}.
     * Also updates the cached total so later learning-rate divisions use the current count.
     */
    private void incrementCounterOf(Prototype i, double _class) {
        Debug.force(this.counter.containsKey(i), "No contiene la clave");
        int oldValue = this.counter.get(i).get(_class);
        this.counter.get(i).put(_class, oldValue + 1);
        // Without this update the cache would keep the first total computed in the epoch,
        // and the step size alpha / q_i would never shrink.
        int cachedSum = this.sumCounter.get(i);
        if (cachedSum != -1) {
            this.sumCounter.put(i, cachedSum + 1);
        }
    }

    /** Moves {@code m} towards {@code x} by alpha_r divided by m's current total hit count. */
    @Override
    protected void reward(Prototype m, Prototype x) {
        int q_i = this.sumOfCounterOf(m);
        Debug.force(q_i > 0, "CERAPIO en reward");
        m.set(m.add(x.sub(m).mul(this.alpha_r / (double)q_i)));
    }

    /** Moves {@code m} away from {@code x} by alpha_w divided by m's current total hit count. */
    @Override
    protected void penalize(Prototype m, Prototype x) {
        int q_i = this.sumOfCounterOf(m);
        Debug.force(q_i > 0, "CERAPIO en penalize");
        m.set(m.sub(x.sub(m).mul(this.alpha_w / (double)q_i)));
    }

    /** Records {@code newWrong} as an instance misclassified by prototype {@code p}. */
    void updateCentroidOfWrongClass(Prototype p, Prototype newWrong) {
        PrototypeSet oldSet = this.wrong.get(p);
        oldSet.add(newWrong);
        this.wrong.put(p, oldSet);
    }

    /**
     * Performs one training step for instance {@code i}. The nearest prototype in
     * {@code tData} has its counter for i's label incremented. It is then rewarded if the
     * labels match; otherwise it is penalized and i is stored as one of its wrong instances.
     */
    @Override
    protected void correct(Prototype i, PrototypeSet tData) {
        Prototype nearest = KNN._1nn(i, tData);
        double i_label = i.label();
        this.incrementCounterOf(nearest, i_label);
        double nearest_prot_label = nearest.label();
        if (i_label != nearest_prot_label) {
            this.penalize(nearest, i);
            this.updateCentroidOfWrongClass(nearest, i);
        } else {
            this.reward(nearest, i);
        }
    }

    /**
     * Keeps the prototypes whose total hit count reaches the retention threshold.
     * If none qualify, keeps the first prototype with the highest count, so a non-empty
     * input never produces an empty result.
     *
     * @param data current prototypes
     * @return the retained prototypes; empty only if {@code data} is empty
     */
    protected PrototypeSet neuronPruning(PrototypeSet data) {
        PrototypeSet edited = new PrototypeSet();
        Prototype pMC = null;
        // Start below any possible count so a fallback is chosen even when every count is 0.
        int mc = -1;
        for (Prototype p : data) {
            int currentCounter = this.sum(this.counter.get(p));
            if (currentCounter >= this.retentionThreshold) {
                edited.add(p);
            }
            if (mc >= currentCounter) continue;
            mc = currentCounter;
            pMC = p;
        }
        if (edited.size() == 0 && pMC != null) {
            edited.add(pMC);
        }
        return edited;
    }

    /**
     * For each prototype with a qualifying wrong class (see {@link #maximumWrongClassCounter}),
     * creates a new prototype from the average of its wrong instances, labelled with that class.
     * The new prototypes are appended to {@code data}.
     *
     * @param data current prototypes (modified in place)
     * @return {@code data}
     */
    protected PrototypeSet neuronCreation(PrototypeSet data) {
        PrototypeSet newPrototypes = new PrototypeSet();
        for (Prototype p : data) {
            Pair<Boolean, Double> isWrong = this.maximumWrongClassCounter(p);
            if (!isWrong.first().booleanValue()) continue;
            Prototype w = this.wrong.get(p).avg();
            w.setLabel(isWrong.second());
            newPrototypes.add(w);
        }
        // Appended after the scan so the loop above never sees the new prototypes.
        for (Prototype newP : newPrototypes) {
            data.add(newP);
        }
        return data;
    }

    /**
     * Runs {@code iterations} training steps. Each step calls {@link #correct} on an instance
     * obtained from {@code extract(trainingDataSet)}.
     *
     * @param outputDataSet prototypes updated in place
     * @return {@code outputDataSet}
     */
    protected PrototypeSet doEpoche(PrototypeSet outputDataSet) {
        for (int it = 0; it < this.iterations; ++it) {
            Prototype instance = this.extract(this.trainingDataSet);
            this.correct(instance, outputDataSet);
        }
        return outputDataSet;
    }

    /**
     * Generates the reduced set. Starts from {@code initDataSet()} and, for each epoch,
     * resets the counters, trains, prunes, then creates new prototypes.
     */
    @Override
    public PrototypeSet reduceSet() {
        PrototypeSet outputDataSet = this.initDataSet();
        for (int e = 0; e < this.epoches; ++e) {
            this.reset(outputDataSet);
            outputDataSet = this.doEpoche(outputDataSet);
            outputDataSet = this.neuronPruning(outputDataSet);
            outputDataSet = this.neuronCreation(outputDataSet);
        }
        return outputDataSet;
    }

    /**
     * Command-line entry point.
     * Arguments: training file, test file, seed, iterations per epoch, % of prototypes,
     * alpha_r, alpha_w, retention threshold, number of epochs.
     */
    public static void main(String[] args) {
        Parameters.setUse("LVQTC", "<seed> <iterations per epoch> <% of prots> <alpha_r> <alpha_w> <retention threshold> <number of epoches>");
        Parameters.assertBasicArgs(args);
        Debug.setStdDebugMode(false);
        PrototypeSet training = PrototypeGenerationAlgorithm.readPrototypeSet(args[0]);
        PrototypeSet test = PrototypeGenerationAlgorithm.readPrototypeSet(args[1]);
        long seed = Parameters.assertExtendedArgAsInt(args, 2, "seed", 0.0, 9.223372036854776E18);
        int iter = Parameters.assertExtendedArgAsInt(args, 3, "number of iterations per epoch", 1.0, 2.147483647E9);
        double pcProt = Parameters.assertExtendedArgAsDouble(args, 4, "% of prototypes", 0.0, 100.0);
        double alphaR = Parameters.assertExtendedArgAsDouble(args, 5, "alpha_r", 0.0, 1.0);
        double alphaW = Parameters.assertExtendedArgAsDouble(args, 6, "alpha_w", 0.0, 1.0);
        int Q = Parameters.assertExtendedArgAsInt(args, 7, "retention threshold (Q)", 1.0, 2.147483647E9);
        int epoches = Parameters.assertExtendedArgAsInt(args, 8, "number of epoches of the algorithm", 1.0, 2.147483647E9);
        LVQTC.setSeed(seed);
        LVQTC generator = new LVQTC(training, iter, pcProt, alphaR, alphaW, Q, epoches);
        PrototypeSet resultingSet = generator.execute();
        int accuracy1NN = KNN.classficationAccuracy(resultingSet, test);
        generator.showResultsOfAccuracy(Parameters.getFileName(), accuracy1NN, test);
    }
}

