First Steps with Apache Mahout

I’ve been digging into a home-rolled machine learning classification system at $work recently, so I decided to noodle around with Mahout to get some perspective on a different toolchain. Mahout is not the most approachable library in the world, so I picked up Mahout in Action and skipped right to the classifiers section. The book helped (a lot!) but it is still kind of tricky, so here are some notes to help anyone else get started.

First, the different classifiers seem mostly unrelated to each other. The book talked mostly about the SGD implementation, so I focused on that. It seems to be the easiest to work with sans hadoop cluster, so that is convenient.

Creating Vectors

The first problem is that you want to write out vectors somewhere so you can read them back in as you noodle on training. For my experiment I did the stereotypical two-category spam classifier (easy access to a data corpus). My corpus consists of mostly mixed plain text and html-ish samples, so to normalize them I ran each sample through jTidy and then used Tika’s very forgiving text extractor to get just the character data. Finally, I ran it through Lucene’s standard analyzer and downcase filter. I didn’t bother with stemming, as the corpus is mixed language and I didn’t want to get into that for simple experimentation. The code isn’t optimized at all; it’s just the simplest path to the output I want.

Preconditions.checkArgument(title != null, "title may not be null");
Preconditions.checkArgument(description != null, "description may not be null");

try {
    final List<String> fields = Lists.newArrayList();
    for (String raw : Arrays.asList(description, title)) {
        // scrub the html-ish markup with jTidy, swallowing its complaints
        Tidy t = new Tidy();
        t.setErrout(new PrintWriter(new ByteArrayOutputStream()));
        StringWriter out = new StringWriter();
        t.parse(new StringReader(raw), out);

        // let Tika pull out just the character data
        AutoDetectParser p = new AutoDetectParser();
        p.parse(new ByteArrayInputStream(out.getBuffer().toString().getBytes()),
                new TextContentHandler(new DefaultHandler()
                {
                    @Override
                    public void characters(char[] ch, int start, int length) throws SAXException
                    {
                        CharBuffer buf = CharBuffer.wrap(ch, start, length);
                        String s = buf.toString();
                        fields.add(s);
                    }
                }),
                new Metadata());
    }

    // tokenize with Lucene's standard analyzer and downcase everything
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_34);
    StringReader in = new StringReader(Joiner.on(" ")
                                             .join(fields)
                                             .replaceAll("\\s+", " "));
    TokenStream ts = analyzer.tokenStream("content", in);
    ts = new LowerCaseFilter(Version.LUCENE_34, ts);

    CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
    List<String> words = Lists.newArrayList();
    while (ts.incrementToken()) {
        char[] termBuffer = termAtt.buffer();
        int termLen = termAtt.length();
        String w = new String(termBuffer, 0, termLen);
        words.add(w);
    }
    this.scrubbedWords = words;
}
catch (Exception e) {
    throw new RuntimeException(e);
}

SGD works on identically sized vectors, so we need to convert that bag of words into a vector and, as this is training data, associate the value of the target variable (spam or ham) with the vector when we store it. We cannot encode the target variable in the vector itself, since it would then be used as a feature during training, but it turns out that Mahout has a useful NamedVector implementation. This could be used to encode an id or such for external lookup, but I just used it to carry that target variable.
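
Just to make the idea concrete, a NamedVector wraps a delegate vector plus a string you can read back later with getName(). A tiny sketch (the width of 100 here is arbitrary):

// the name rides along with the vector but is not part of the feature data
Vector features = new RandomAccessSparseVector(100);
NamedVector named = new NamedVector(features, "spam");
String target = named.getName(); // "spam"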

Mahout offers a lot of ways to vectorize text and no practical guidance on which way works. I settled on using the AdaptiveWordValueEncoder for each word in the bag of text, in order to get some kind of useful frequency weighting:

public Vector encode(String type, TopicData data) throws IOException
{
    // the adaptive encoder learns word frequencies as it goes
    FeatureVectorEncoder content_encoder = new AdaptiveWordValueEncoder("content");
    content_encoder.setProbes(2);

    FeatureVectorEncoder type_encoder = new StaticWordValueEncoder("type");
    type_encoder.setProbes(2);

    Vector v = new RandomAccessSparseVector(VECTOR_WIDTH);
    type_encoder.addToVector(type, v);

    for (String word : data.getWords()) {
        content_encoder.addToVector(word, v);
    }
    // label (the target variable, "spam" or "ham") comes from the enclosing class
    return new NamedVector(v, label);
}

To break this down, I am encoding two things: a type, which is basically the kind of content this is, and the actual bag of words from the TopicData. I am using the hashing vector muckery in order to keep fixed-width vectors, and using two probes. The number of probes and the vector width (100) are total swags; I have no real reasoning for either choice. The label argument to the NamedVector is either “spam” or “ham” – the aforementioned target variable.

Pushing Data Around

Now, some mechanics. My corpus is sitting in a database, and I want to vectorize it and save it to local disk to noodle with. The book steers you towards Hadoop SequenceFile instances for saving stuff, and you can store them on local disk as easily as on a hadoop cluster, so I did. Sadly, the book, the wiki, and everything else kind of hide the details of this. So, my implementation:

// local filesystem via the file: scheme
URI path = URI.create("file:/tmp/spammery.seq");
Configuration hadoop_conf = new Configuration();
FileSystem fs = FileSystem.get(path, hadoop_conf);
SequenceFile.Writer writer = SequenceFile.createWriter(fs, hadoop_conf,
                                                       new Path(path),
                                                       LongWritable.class,
                                                       VectorWritable.class);
final VectorWriter vw = new SequenceFileVectorWriter(writer);

Nice and succinct, but it took some trial and error to get there. We start with a URI pointing to the local filesystem, do a hadoop dance, and make a sequence file writer. The trial and error came in with figuring out the key and value classes to give the writer: LongWritable and VectorWritable, respectively. After that it all worked nicely.

The VectorWriter is a nice convenience in Mahout for splatting stuff out to the sequence file.

Once we have our writer, the easiest path is to pass an Iterable<Vector> to the writer. Luckily for me, my corpus is in a database, and jDBI Query instances implement Iterable, so that amounted to grabbing two database connections and fetching spam and ham:

// fetch size, etc, is to make results stream out of mysql
// *sigh*
Query<Vector> spam_q = h.createQuery("select spam...")
                        .setFetchDirection(ResultSet.FETCH_FORWARD)
                        .setFetchSize(Integer.MIN_VALUE)
                        .map(new Vectorizer("spam"));

Query<Vector> ham_q = h2.createQuery("select ham...")
                        .setFetchDirection(ResultSet.FETCH_FORWARD)
                        .setFetchSize(Integer.MIN_VALUE)
                        .map(new Vectorizer("ham"));

vw.write(new CollatingIterable<Vector>(ham_q, spam_q));

Two things to note, which I found out the hard way. The first is that you need about the same number of samples for each category; I used 50,000 of each to keep training times short. Secondly, you need to intersperse the samples. The first time through I wrote all the spam, then all the ham, and when I ran the resulting classifier everything was classified as ham. Oops. CollatingIterable just takes a bunch of iterables and collates them lazily, so it alternates between spam and ham.
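
CollatingIterable isn’t anything Mahout provides as far as I know. If you need to roll your own, a round-robin sketch along these lines does the same job – a hypothetical RoundRobinIterable, purely illustrative, used the same way as above: vw.write(new RoundRobinIterable<Vector>(ham_q, spam_q)).

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

// Lazily interleaves elements from several iterables, round-robin,
// skipping any source that runs out early.
public class RoundRobinIterable<T> implements Iterable<T>
{
    private final List<Iterable<T>> sources;

    public RoundRobinIterable(Iterable<T>... sources)
    {
        this.sources = Arrays.asList(sources);
    }

    @Override
    public Iterator<T> iterator()
    {
        final List<Iterator<T>> its = new ArrayList<Iterator<T>>();
        for (Iterable<T> source : sources) {
            its.add(source.iterator());
        }
        return new Iterator<T>()
        {
            private int next = 0;

            @Override
            public boolean hasNext()
            {
                for (Iterator<T> it : its) {
                    if (it.hasNext()) {
                        return true;
                    }
                }
                return false;
            }

            @Override
            public T next()
            {
                // take from the next non-exhausted source, in rotation
                for (int i = 0; i < its.size(); i++) {
                    Iterator<T> it = its.get((next + i) % its.size());
                    if (it.hasNext()) {
                        next = next + i + 1;
                        return it.next();
                    }
                }
                throw new NoSuchElementException();
            }

            @Override
            public void remove()
            {
                throw new UnsupportedOperationException();
            }
        };
    }
}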

When we’re done, close() the vector writer and the sequence file.
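
In my case that’s a single call; as far as I can tell, SequenceFileVectorWriter’s close() also closes the SequenceFile.Writer it wraps:

vw.close(); // appears to close the underlying SequenceFile.Writer too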

Reading data back in is almost the same:

Configuration hconf = new Configuration();
FileSystem fs = FileSystem.get(path, hconf);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(path), hconf);

LongWritable key = new LongWritable();
VectorWritable value = new VectorWritable();

while (reader.next(key, value)) {
    // next() fills in the reusable key and value containers
    NamedVector v = (NamedVector) value.get();
    // do stuff with v
}

The key and value instances are basically containers for pulling info out. I don’t bother with the key; all I care about is the vector. Sadly, we have to downcast it to NamedVector, but the cast works.

Training

So we can write our vectors, and read our vectors. Let’s train a classifier! First pass I tried to use the OnlineLogisticRegression and looked at CrossFoldLearner and got very sad. Then I saw reference to the handy-dandy AdaptiveLogisticRegression and got happy again. Basically, it is a self-tuning magical learning thing that figures out the parameters for you. It seems to work.

Using it is straightforward:

// 2 categories, VECTOR_WIDTH features, and an L1 prior
AdaptiveLogisticRegression reg = new AdaptiveLogisticRegression(2, VECTOR_WIDTH, new L1());
while (reader.next(key, value)) {
    NamedVector v = (NamedVector) value.get();
    // the target category is a zero-based int: 1 for spam, 0 for ham
    reg.train("spam".equals(v.getName()) ? 1 : 0, v);
}
reg.close();

Voila, we trained it! Just to note: if you use too small a dataset, or don’t close it, it gets cranky. The arguments to the ALR are the number of categories (2 in this case), the vector width, and a prior function. I tried L1 and L2; L2 didn’t work, and I have no idea why. When calling train() you need to tell it the category of the target variable, encoded as a zero-based int. It works fine.

The ALR trains up a bunch of classifiers for you, twenty by default as far as I can tell. For each vector it feeds every trainer but one, and uses that vector to test the one it held out, and so on. From this it keeps track of the best performer as you go, and when you are finished you can ask it for the best. I am not certain of its criteria for best, but I trust it works.
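
You can peek at the score it is tracking for the current champion; my reading is that the metric is AUC, but treat that as an assumption rather than gospel. Continuing from the reg above:

// getBest() can return null if it has not seen enough data yet
State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = reg.getBest();
if (best != null) {
    System.out.println("auc: " + best.getPayload().getLearner().auc());
}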

You can then save off the model for later use:

ModelSerializer.writeBinary(model_path, 
                            reg.getBest().getPayload().getLearner());

model_path is a path on the filesystem. It dumps the model out, and life goes on.

Using the Classifier

The first thing I did was run the classifier against the training data, as it should do well against that data (I hope). We start by loading our model back off of disk:

InputStream in = new FileInputStream(model_path);
CrossFoldLearner best = ModelSerializer.readBinary(in, CrossFoldLearner.class);
in.close();

I then reopened the sequence file (not going to show the code again, look earlier) and traversed it, asking the classifier to score each vector:

int total = 0;
int correct = 0;
while (reader.next(key, value)) {
    total++;
    NamedVector v = (NamedVector) value.get();
    int expected = "spam".equals(v.getName()) ? 1 : 0;

    // classifyFull fills p with a likelihood per category
    Vector p = new DenseVector(2);
    best.classifyFull(p, v);

    int cat = p.maxValueIndex();
    if (cat == expected) { correct++; }
}
double cd = correct;
double td = total;
System.out.println(cd / td);

The p vector is for holding the results of the classification. Basically, it puts the calculated likelihood that v falls into a particular category (0 or 1 in this case) into that entry in p. For instance, if a given vector has a likelihood of being ham (category 0) of 0.97 and a likelihood of being spam (category 1) of 0.03, then the p vector will contain {0: 0.97, 1: 0.03}. We just ask for the index with the maximum value and report that as the classification. For a real spam system we’d probably pull out likelihood of it being spam and report that.
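
Pulling out just the spam likelihood would look something like this, reusing best and v from the loop above (spam being category 1 here):

Vector p = new DenseVector(2);
best.classifyFull(p, v);
double spamLikelihood = p.get(1); // the model's estimate that v is spam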

In this case we just keep track of the correct classification rate and print it out at the end.

Real Data

Once I got it working, I hooked the classifier up to a realtime content stream at $work and watched it for a while, visually inspecting for accuracy. It worked! This made me happy.

Conclusions

While it worked, and the code is not bad at the end, it was harder to get to this point than I would have expected. Mahout in Action was instrumental in helping me understand the pieces and point me in the right direction. However, it really only talks about the SGD classifiers, and it mostly talks about working with a specially prepared dataset and the command line tools, not the APIs. The Mahout Wiki is, well, a wiki. The users mailing list seems to have about ten questions for each answer. So, the library is not super approachable, but it seems to work and be pleasant code to work with once you figure it out.

There seem to be a number of other classification algorithms, but it is not obvious how to use them – particularly the naive-bayes variants.

I’ll probably keep working with Mahout, and probably try to contribute some practice-oriented documentation. With any luck I can track down some Mahout-ites at ApacheCon in a couple weeks and learn more.