/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.tools;

import au.com.bytecode.opencsv.CSVReader;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;

/**
 * Command-line tool that scores every data row of an input CSV file with a MOJO or POJO model
 * and writes one prediction line per data row to an output CSV file.
 *
 * <p>The first line of the input file is always read as the header holding the column names.
 * Cells equal to "", "NA", "N/A" or "-" are treated as missing and are not passed to the model.
 *
 * <p>Exit status: 0 on success; 1 on a usage error or when scoring a data line fails;
 * 2 on any other failure raised while running.
 */
public class PredictCsv {
    /** Path of the CSV file to score (set by {@code --input}). */
    private String inputCSVFileName;
    /** Path of the CSV file that receives the predictions (set by {@code --output}). */
    private String outputCSVFileName;
    /** When true, numbers are written in decimal notation; otherwise as hexadecimal floating point. */
    private boolean useDecimalOutput = false;
    /** Wrapper around the loaded MOJO/POJO that performs the actual scoring. */
    private EasyPredictModelWrapper model;

    /**
     * Entry point: parses the command line, scores the input file and exits.
     *
     * @param args command-line arguments, see {@link #usage()}
     */
    public static void main(String[] args) {
        PredictCsv main = new PredictCsv();
        main.parseArgs(args);
        try {
            main.run();
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(2);
        }
        System.exit(0);
    }

    /**
     * Builds a model input row from one parsed CSV line.
     *
     * <p>Only the first {@code min(inputColumnNames.length, splitLine.length)} cells are used.
     * Cells that are "", "NA", "N/A" or "-" are omitted so the model sees them as missing.
     *
     * @param splitLine        the cells of one data line
     * @param inputColumnNames the cells of the header line
     * @return the row keyed by column name
     */
    private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) {
        RowData row = new RowData();
        int maxI = Math.min(inputColumnNames.length, splitLine.length);
        for (int i = 0; i < maxI; ++i) {
            String cellData = splitLine[i];
            switch (cellData) {
                case "":
                case "NA":
                case "N/A":
                case "-":
                    // Missing value: leave the column out of the row entirely.
                    break;
                default:
                    row.put(inputColumnNames[i], cellData);
            }
        }
        return row;
    }

    /**
     * Formats a number for the output file.
     *
     * @param d the value to format
     * @return "NA" for NaN; otherwise decimal or hexadecimal text depending on {@code --decimal}
     */
    private String myDoubleToString(double d) {
        if (Double.isNaN(d)) {
            return "NA";
        }
        return this.useDecimalOutput ? Double.toString(d) : Double.toHexString(d);
    }

    /**
     * Scores the input CSV file and writes the predictions to the output CSV file.
     *
     * <p>Both files are always closed, so predictions written before a failing line are flushed.
     * A failure while scoring a data line is reported with its 1-based line number, and the
     * process exits with status 1 after the files are closed.
     *
     * @throws Exception if a file cannot be opened or the header cannot be written
     */
    private void run() throws Exception {
        ModelCategory category = this.model.getModelCategory();
        int lineNum = 0;
        boolean failed = false;
        try (CSVReader reader = new CSVReader(new FileReader(this.inputCSVFileName));
             BufferedWriter output = new BufferedWriter(new FileWriter(this.outputCSVFileName))) {
            this.writeHeader(category, output);
            try {
                String[] splitLine;
                String[] inputColumnNames = null;
                while ((splitLine = reader.readNext()) != null) {
                    if (++lineNum == 1) {
                        inputColumnNames = splitLine;
                        continue;
                    }
                    RowData row = PredictCsv.formatDataRow(splitLine, inputColumnNames);
                    this.writePrediction(category, row, output);
                    output.write("\n");
                }
            } catch (Exception e) {
                System.out.println("Caught exception on line " + lineNum);
                System.out.println("");
                e.printStackTrace();
                failed = true;
            }
        }
        // Exit only after try-with-resources has flushed and closed both files.
        if (failed) {
            System.exit(1);
        }
    }

    /**
     * Writes the output header line appropriate for the model category.
     *
     * @param category the model category
     * @param output   destination writer
     * @throws Exception for an unsupported category or on I/O failure
     */
    private void writeHeader(ModelCategory category, BufferedWriter output) throws Exception {
        switch (category) {
            case AutoEncoder:
                output.write(this.model.getHeader());
                break;
            case Binomial:
            case Multinomial:
                output.write("predict");
                for (String domainValue : this.model.getResponseDomainValues()) {
                    output.write(",");
                    output.write(domainValue);
                }
                break;
            case Clustering:
                output.write("cluster");
                break;
            case Regression:
                output.write("predict");
                break;
            default:
                throw new Exception("Unknown model category " + category);
        }
        output.write("\n");
    }

    /**
     * Scores one row and writes its prediction, without the trailing newline.
     *
     * @param category the model category
     * @param row      the input row
     * @param output   destination writer
     * @throws Exception if scoring fails, the category is unsupported, or on I/O failure
     */
    private void writePrediction(ModelCategory category, RowData row, BufferedWriter output)
            throws Exception {
        switch (category) {
            case AutoEncoder:
                throw new UnsupportedOperationException("Scoring of AutoEncoder models is not supported");
            case Binomial: {
                BinomialModelPrediction p = this.model.predictBinomial(row);
                this.writeLabelAndProbabilities(output, p.label, p.classProbabilities);
                break;
            }
            case Multinomial: {
                MultinomialModelPrediction p = this.model.predictMultinomial(row);
                this.writeLabelAndProbabilities(output, p.label, p.classProbabilities);
                break;
            }
            case Clustering: {
                ClusteringModelPrediction p = this.model.predictClustering(row);
                output.write(this.myDoubleToString(p.cluster));
                break;
            }
            case Regression: {
                RegressionModelPrediction p = this.model.predictRegression(row);
                output.write(this.myDoubleToString(p.value));
                break;
            }
            default:
                throw new Exception("Unknown model category " + category);
        }
    }

    /**
     * Writes {@code label,p0,p1,...} for classification models.
     *
     * @param output        destination writer
     * @param label         predicted class label
     * @param probabilities per-class probabilities, in response-domain order
     * @throws IOException on write failure
     */
    private void writeLabelAndProbabilities(BufferedWriter output, String label, double[] probabilities)
            throws IOException {
        output.write(label);
        output.write(",");
        for (int i = 0; i < probabilities.length; ++i) {
            if (i > 0) {
                output.write(",");
            }
            output.write(this.myDoubleToString(probabilities[i]));
        }
    }

    /**
     * Loads a model, trying it as a MOJO first and falling back to a POJO class name.
     *
     * @param modelName MOJO file name or POJO class name
     * @throws Exception if neither loader succeeds; the MOJO failure is attached as suppressed
     */
    private void loadModel(String modelName) throws Exception {
        try {
            this.loadMojo(modelName);
        } catch (IOException mojoError) {
            try {
                this.loadPojo(modelName);
            } catch (Exception pojoError) {
                // Keep the MOJO failure so the user sees why both attempts failed.
                pojoError.addSuppressed(mojoError);
                throw pojoError;
            }
        }
    }

    /**
     * Instantiates a POJO model class reflectively via its no-argument constructor.
     *
     * @param className fully qualified class name of the POJO
     * @throws Exception if the class cannot be found, instantiated or is not a {@link GenModel}
     */
    private void loadPojo(String className) throws Exception {
        GenModel genModel = (GenModel) Class.forName(className).getDeclaredConstructor().newInstance();
        this.model = new EasyPredictModelWrapper(genModel);
    }

    /**
     * Loads a MOJO model.
     *
     * @param modelName name of the MOJO file
     * @throws IOException if the MOJO cannot be read
     */
    private void loadMojo(String modelName) throws IOException {
        MojoModel genModel = MojoModel.load(modelName);
        this.model = new EasyPredictModelWrapper(genModel);
    }

    /** Prints command-line help and terminates the JVM with exit status 1. */
    private static void usage() {
        System.out.println("");
        System.out.println("Usage:  java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName");
        System.out.println("             --pojo pojoName --input inputFile --output outputFile --decimal");
        System.out.println("");
        System.out.println("     --mojo    Name of the zip file containing model's MOJO.");
        System.out.println("     --pojo    Name of the java class containing the model's POJO. Either this ");
        System.out.println("               parameter or --mojo must be specified.");
        System.out.println("     --model   Name of a MOJO zip file or POJO class; loading as a MOJO is tried first.");
        System.out.println("     --input   CSV file containing the test data set to score.");
        System.out.println("     --output  Name of the output CSV file with computed predictions.");
        System.out.println("     --decimal Use decimal numbers in the output (default is to use hexadecimal).");
        System.out.println("");
        System.exit(1);
    }

    /**
     * Parses command-line options into this instance and loads the model.
     * Any error (unknown option, missing value, failed model load, missing required option)
     * prints the usage text and exits with status 1.
     *
     * @param args command-line arguments
     */
    private void parseArgs(String[] args) {
        try {
            for (int i = 0; i < args.length; ++i) {
                String s = args[i];
                if (s.equals("--header")) {
                    // Accepted but ignored: the first input line is always read as the header.
                    continue;
                }
                if (s.equals("--decimal")) {
                    this.useDecimalOutput = true;
                    continue;
                }
                // Every remaining option takes exactly one value.
                if (++i >= args.length) {
                    PredictCsv.usage();
                }
                String sarg = args[i];
                switch (s) {
                    case "--model":
                        this.loadModel(sarg);
                        break;
                    case "--mojo":
                        this.loadMojo(sarg);
                        break;
                    case "--pojo":
                        this.loadPojo(sarg);
                        break;
                    case "--input":
                        this.inputCSVFileName = sarg;
                        break;
                    case "--output":
                        this.outputCSVFileName = sarg;
                        break;
                    default:
                        System.out.println("ERROR: Unknown command line argument: " + s);
                        PredictCsv.usage();
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
            PredictCsv.usage();
        }
        if (this.model == null || this.inputCSVFileName == null || this.outputCSVFileName == null) {
            System.out.println("ERROR: A model (--mojo, --pojo or --model), --input and --output must all be specified.");
            PredictCsv.usage();
        }
    }
}

