/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.hadoop.als;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.Closeable;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.hadoop.als.ALSUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
import org.apache.mahout.common.mapreduce.TransposeMapper;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factorizes a user/item rating matrix into a user-feature matrix U and an item-feature matrix M with
 * alternating least squares (ALS), running every step as a Hadoop MapReduce job.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>Parse the text input (one preference per line, tokens split by
 *       {@link TasteHadoopUtils#splitPrefTokens(String)} as {@code userID, itemID, rating}) into one
 *       sparse rating vector per item, indexed by user ID (temp path {@code itemRatings}).</li>
 *   <li>Transpose those into one rating vector per user (output path {@code userRatings}).</li>
 *   <li>Compute each item's average rating and use it as the first feature of the initial M; the
 *       remaining features are random.</li>
 *   <li>For {@code numIterations} rounds, recompute U from M, then M from U. The last round writes to
 *       the output paths {@code U} and {@code M}; earlier rounds write to temp paths.</li>
 * </ol>
 *
 * <p>Options: {@code --lambda} (required), {@code --numFeatures} (required), {@code --numIterations}
 * (required), {@code --implicitFeedback} (default {@code false}) and {@code --alpha} (default {@code 40},
 * only used with implicit feedback).
 */
public class ParallelALSFactorizationJob extends AbstractJob {
    private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class);

    /** Configuration key passing the dimension of the feature space to the solver mappers. */
    static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
    /** Configuration key passing the regularization parameter to the solver mappers. */
    static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
    /** Configuration key passing the confidence parameter (implicit feedback only) to the solver mappers. */
    static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha";
    /** Configuration key passing the path of the currently fixed feature matrix to the solver mappers. */
    static final String FEATURE_MATRIX = ParallelALSFactorizationJob.class.getName() + ".featureMatrix";

    private boolean implicitFeedback;
    private int numIterations;
    private int numFeatures;
    private double lambda;
    private double alpha;

    /**
     * Command-line entry point; delegates to {@link ToolRunner}.
     *
     * @param args command-line arguments, see the class documentation for the supported options
     * @throws Exception anything thrown while running the job
     */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new ParallelALSFactorizationJob(), args);
    }

    /**
     * Parses the options and runs the whole factorization pipeline.
     *
     * @param args command-line arguments
     * @return {@code 0} on success, {@code -1} if argument parsing or one of the preparation jobs failed
     * @throws IllegalArgumentException if {@code numFeatures} or {@code numIterations} is not positive
     * @throws IllegalStateException if one of the solver jobs fails
     * @throws Exception on I/O or job-submission errors
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption("lambda", null, "regularization parameter", true);
        addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false));
        addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40));
        addOption("numFeatures", null, "dimension of the feature space", true);
        addOption("numIterations", null, "number of iterations", true);

        Map<String, List<String>> parsedArgs = parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }

        numFeatures = Integer.parseInt(getOption("numFeatures"));
        numIterations = Integer.parseInt(getOption("numIterations"));
        lambda = Double.parseDouble(getOption("lambda"));
        alpha = Double.parseDouble(getOption("alpha"));
        implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback"));

        // Validate before launching any MapReduce job: a non-positive numFeatures would otherwise only
        // fail inside initializeM (after three jobs), and numIterations <= 0 would never produce U.
        Preconditions.checkArgument(numFeatures > 0, "numFeatures must be positive, was %s", numFeatures);
        Preconditions.checkArgument(numIterations > 0, "numIterations must be positive, was %s", numIterations);

        // Job 1: text preferences -> one sparse rating vector per item, keyed by item ID, indexed by user ID.
        Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(), TextInputFormat.class,
                ItemRatingVectorsMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class,
                IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        itemRatings.setCombinerClass(VectorSumReducer.class);
        if (!itemRatings.waitForCompletion(true)) {
            return -1;
        }

        // Job 2: transpose the item vectors into one rating vector per user.
        Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(), TransposeMapper.class,
                IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, IntWritable.class,
                VectorWritable.class);
        userRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!userRatings.waitForCompletion(true)) {
            return -1;
        }

        // Job 3: every item's average rating is emitted under key 0 and merged into a single vector.
        Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"),
                AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
                IntWritable.class, VectorWritable.class);
        averageItemRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!averageItemRatings.waitForCompletion(true)) {
            return -1;
        }

        Vector averageRatings = ALSUtils.readFirstRow(getTempPath("averageRatings"), getConf());
        initializeM(averageRatings);

        // Alternate: solve U with M fixed, then M with the freshly computed U fixed.
        for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
            log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations);
            runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1));
            log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations);
            runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration));
        }
        return 0;
    }

    /**
     * Writes the initial item-feature matrix M (iteration -1).
     *
     * <p>For every non-zero entry of {@code averageRatings}, one row of length {@code numFeatures} is
     * written: column 0 holds the item's average rating, the remaining columns are uniform random values.
     * NOTE(review): items whose average is exactly 0 are skipped by {@code iterateNonZero} and get no
     * initial row.
     *
     * @param averageRatings per-item average ratings, indexed by item ID
     * @throws IOException if writing or closing the sequence file fails
     */
    private void initializeM(Vector averageRatings) throws IOException {
        Random random = RandomUtils.getRandom();
        Path initialM = pathToM(-1);
        FileSystem fs = FileSystem.get(initialM.toUri(), getConf());
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, getConf(), new Path(initialM, "part-m-00000"),
                IntWritable.class, VectorWritable.class);
        // Guava idiom: propagate close() failures (which may mean lost data) on the success path, but
        // do not let them mask an exception that is already propagating.
        boolean threw = true;
        try {
            Iterator<Vector.Element> averages = averageRatings.iterateNonZero();
            while (averages.hasNext()) {
                Vector.Element e = averages.next();
                Vector row = new DenseVector(numFeatures);
                row.setQuick(0, e.get());
                for (int m = 1; m < numFeatures; m++) {
                    row.setQuick(m, random.nextDouble());
                }
                writer.append(new IntWritable(e.index()), new VectorWritable(row));
            }
            threw = false;
        } finally {
            Closeables.close(writer, threw);
        }
    }

    /**
     * Runs one map-only solver job that recomputes one side of the factorization.
     *
     * @param ratings rating vectors of the side being recomputed (user or item ratings)
     * @param output where the recomputed feature vectors are written
     * @param pathToUorI the currently fixed feature matrix, read by every mapper in {@code setup}
     * @throws IllegalStateException if the job does not complete successfully
     */
    private void runSolver(Path ratings, Path output, Path pathToUorI)
            throws ClassNotFoundException, IOException, InterruptedException {
        @SuppressWarnings("rawtypes") // AbstractJob.prepareJob takes Class<? extends Mapper>
        Class<? extends Mapper> solverMapper = implicitFeedback
                ? SolveImplicitFeedbackMapper.class
                : SolveExplicitFeedbackMapper.class;
        Job solverForUorI = prepareJob(ratings, output, SequenceFileInputFormat.class, solverMapper,
                IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        Configuration solverConf = solverForUorI.getConfiguration();
        solverConf.set(LAMBDA, String.valueOf(lambda));
        solverConf.set(ALPHA, String.valueOf(alpha));
        solverConf.setInt(NUM_FEATURES, numFeatures);
        solverConf.set(FEATURE_MATRIX, pathToUorI.toString());
        if (!solverForUorI.waitForCompletion(true)) {
            throw new IllegalStateException("Job failed!");
        }
    }

    /** Location of M for {@code iteration}; the last iteration goes to the output dir, the rest to temp. */
    private Path pathToM(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath("M") : getTempPath("M-" + iteration);
    }

    /** Location of U for {@code iteration}; the last iteration goes to the output dir, the rest to temp. */
    private Path pathToU(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath("U") : getTempPath("U-" + iteration);
    }

    /** Temp location of the per-item rating vectors. */
    private Path pathToItemRatings() {
        return getTempPath("itemRatings");
    }

    /** Output location of the per-user rating vectors. */
    private Path pathToUserRatings() {
        return getOutputPath("userRatings");
    }

    /**
     * Emits the average of each item's non-zero ratings as a one-entry sparse vector (index = item ID)
     * under the constant key 0, so that all averages end up merged into a single vector.
     */
    static class AverageRatingMapper
            extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        // `Context` (not the raw `Mapper.Context`) is required for this to override Mapper.map.
        @Override
        protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
            FullRunningAverage avg = new FullRunningAverage();
            Iterator<Vector.Element> elements = v.get().iterateNonZero();
            while (elements.hasNext()) {
                avg.addDatum(elements.next().get());
            }
            Vector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
            vector.setQuick(r.get(), avg.getAverage());
            ctx.write(new IntWritable(0), new VectorWritable(vector));
        }
    }

    /**
     * Recomputes one feature vector per input row with the implicit-feedback ALS solver, holding the
     * feature matrix named by {@link #FEATURE_MATRIX} fixed.
     */
    static class SolveImplicitFeedbackMapper
            extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        private ImplicitFeedbackAlternatingLeastSquaresSolver solver;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            Configuration conf = ctx.getConfiguration();
            double lambda = Double.parseDouble(conf.get(LAMBDA));
            double alpha = Double.parseDouble(conf.get(ALPHA));
            int numFeatures = conf.getInt(NUM_FEATURES, -1);
            // Validate before the (potentially large) fixed matrix is loaded and the solver is built.
            Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set correctly!");
            Path fixedFeaturesPath = new Path(conf.get(FEATURE_MATRIX));
            OpenIntObjectHashMap<Vector> fixedFeatures = ALSUtils.readMatrixByRows(fixedFeaturesPath, conf);
            solver = new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, fixedFeatures);
        }

        @Override
        protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
                throws IOException, InterruptedException {
            Vector ratings = new SequentialAccessSparseVector(ratingsWritable.get());
            Vector uiOrmj = solver.solve(ratings);
            ctx.write(userOrItemID, new VectorWritable(uiOrmj));
        }
    }

    /**
     * Recomputes one feature vector per input row with the explicit-feedback ALS solver. For each row,
     * only the fixed feature vectors of the entries the row actually rated are passed to the solver.
     */
    static class SolveExplicitFeedbackMapper
            extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        private double lambda;
        private int numFeatures;
        private OpenIntObjectHashMap<Vector> featureMatrix;
        private AlternatingLeastSquaresSolver solver;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            Configuration conf = ctx.getConfiguration();
            lambda = Double.parseDouble(conf.get(LAMBDA));
            numFeatures = conf.getInt(NUM_FEATURES, -1);
            // Validate before the (potentially large) fixed matrix is loaded.
            Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set correctly!");
            solver = new AlternatingLeastSquaresSolver();
            featureMatrix = ALSUtils.readMatrixByRows(new Path(conf.get(FEATURE_MATRIX)), conf);
        }

        @Override
        protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
                throws IOException, InterruptedException {
            Vector ratings = new SequentialAccessSparseVector(ratingsWritable.get());
            List<Vector> featureVectors = Lists.newArrayListWithCapacity(ratings.getNumNondefaultElements());
            Iterator<Vector.Element> interactions = ratings.iterateNonZero();
            while (interactions.hasNext()) {
                int index = interactions.next().index();
                Vector featureVector = featureMatrix.get(index);
                // A missing row (e.g. an item skipped by initializeM) would otherwise surface as an
                // opaque failure inside the solver.
                Preconditions.checkState(featureVector != null, "no feature vector found for index %s", index);
                featureVectors.add(featureVector);
            }
            Vector uiOrmj = solver.solve(featureVectors, ratings, lambda, numFeatures);
            ctx.write(userOrItemID, new VectorWritable(uiOrmj));
        }
    }

    /**
     * Turns one text preference line ({@code userID, itemID, rating}) into a one-entry sparse vector
     * (index = user ID, value = rating) keyed by the item ID.
     */
    static class ItemRatingVectorsMapper
            extends Mapper<LongWritable, Text, IntWritable, VectorWritable> {

        @Override
        protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
            int userID = Integer.parseInt(tokens[0]);
            int itemID = Integer.parseInt(tokens[1]);
            float rating = Float.parseFloat(tokens[2]);
            Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
            ratings.set(userID, rating);
            // NOTE(review): the `true` flag presumably selects lax (float) write precision — confirm in VectorWritable.
            ctx.write(new IntWritable(itemID), new VectorWritable(ratings, true));
        }
    }
}

