Commit Logs

Caching predictive models using Guava in Scala

In a previous post, we talked about caching Spark models in memory for a web service to reduce prediction latency. As with any web application, caching strategy can get very interesting, but the patterns for caching a machine learning model are relatively straightforward, since the model is likely to stay static between updates. In particular, I find that Guava provides some handy in-memory caching solutions for our use case.

In the rest of this post, I am going to walk you through some basic caching patterns using Guava. Note that these patterns are inspired by, but not limited to, caching predictive models. As a running example, we assume the goal is to serve a machine learning model, which is updated daily, in a web application built with the Play Framework.

Timed eviction

To start with, a simple caching pattern is to load the model into memory and evict it after a given time period (daily in our case). Here we use a CacheLoader, since there is a default way to load the value (the machine learning model) associated with a key (the model identifier); otherwise, you would need to pass a Callable into each get call.
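
For comparison, here is a minimal sketch of the Callable route, assuming Guava is on the classpath. `Model`, `loadModel` and `CallableCacheSketch` are hypothetical names for illustration only: the cache is built without a loader, so every call to `get` has to carry its own `Callable`.

```scala
import java.util.concurrent.{Callable, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import com.google.common.cache.{Cache, CacheBuilder}

// Hypothetical stand-in for a trained model.
final case class Model(path: String)

object CallableCacheSketch {
  // Counts how often the hypothetical loading function actually runs.
  val loads = new AtomicInteger(0)

  // Hypothetical loading function; a real one would deserialize the model.
  def loadModel(path: String): Model = {
    loads.incrementAndGet()
    Model(path)
  }

  // A plain Cache: no loader is attached at build time...
  val cache: Cache[String, Model] = CacheBuilder.newBuilder()
    .maximumSize(2)
    .expireAfterWrite(24, TimeUnit.HOURS)
    .build[String, Model]()

  // ...so every call site has to pass the loading logic as a Callable.
  def getModel(path: String): Model =
    cache.get(path, new Callable[Model] {
      def call(): Model = loadModel(path)
    })
}
```

Repeated calls return the same cached instance and the loading function runs once. The CacheLoader-based provider used in this post attaches the loading logic to the cache itself, so callers only pass the key.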

With dependency injection, you could define a CacheProvider trait that holds the cache.

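package controllers

import java.util.concurrent.TimeUnit
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

trait CacheProvider {

  // Holds at most two models; each entry is evicted 24 hours after it was written.
  val modelCache: LoadingCache[String, Model] = CacheBuilder.newBuilder()
    .maximumSize(2)
    .expireAfterWrite(24, TimeUnit.HOURS)
    .build(
      new CacheLoader[String, Model] {
        // Called on a cache miss, including the first query after eviction.
        def load(path: String): Model = {
          Model.load(path)
        }
      }
    )

  // Returns the cached model, loading it first if it is missing or expired.
  def getModel: Model = {
    modelCache.get("path/to/model")
  }
}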

With this setup, the cached model is evicted 24 hours after it was loaded. The first query after eviction blocks until the model is loaded again, so expect higher latency for that request.
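
To observe the eviction without waiting a day, you can plug a fake clock into the builder through Guava's `CacheBuilder.ticker`. A minimal sketch, where `Model` (carrying a load counter) and `FakeTicker` are hypothetical stand-ins rather than part of the post's code:

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
import com.google.common.base.Ticker
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

// Hypothetical stand-in for a trained model; `version` counts loads.
final case class Model(path: String, version: Int)

// A clock we advance by hand, so 24 hours pass instantly.
final class FakeTicker extends Ticker {
  private val nanos = new AtomicLong(0L)
  override def read(): Long = nanos.get()
  def advance(hours: Long): Unit = nanos.addAndGet(TimeUnit.HOURS.toNanos(hours))
}

object TimedEvictionSketch {
  val ticker = new FakeTicker
  val loads = new AtomicInteger(0)

  val modelCache: LoadingCache[String, Model] = CacheBuilder.newBuilder()
    .maximumSize(2)
    .expireAfterWrite(24, TimeUnit.HOURS)
    .ticker(ticker) // expiry is measured against the fake clock
    .build(new CacheLoader[String, Model] {
      // Stands in for the slow Model.load(path) call.
      def load(path: String): Model = Model(path, loads.incrementAndGet())
    })
}
```

Just before the 24-hour mark the same instance comes back; just after it, the loader runs again, and in a real service that is the slow query.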

Timed refresh

With timed eviction, if something goes wrong during reloading, the service cannot return anything, because the old model has already been evicted. This is of course not ideal and may cause serious problems.

A better solution may be timed refresh. The difference is that the old model (if any) is still returned while the key is being refreshed. Therefore, even if an exception is thrown while refreshing, the service can still return results from the old model; the exception is logged and swallowed.

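Switching from timed eviction to timed refresh is a minimal change: just replace expireAfterWrite with refreshAfterWrite. Note that Guava does not refresh on a timer; the refresh is triggered by the first query that arrives after the 24-hour period has passed.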
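
A minimal sketch contrasting the two builders when the reload fails, assuming Guava is on the classpath. `Model`, `FakeTicker` and `FlakyLoader` are hypothetical; the loader is rigged to throw on every call after the first.

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
import scala.util.Try
import com.google.common.base.Ticker
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

// Hypothetical stand-in for a trained model.
final case class Model(version: Int)

// A clock we advance by hand instead of waiting 24 real hours.
final class FakeTicker extends Ticker {
  private val nanos = new AtomicLong(0L)
  override def read(): Long = nanos.get()
  def advance(hours: Long): Unit = nanos.addAndGet(TimeUnit.HOURS.toNanos(hours))
}

// Hypothetical loader: the first load succeeds, every later load throws.
final class FlakyLoader extends CacheLoader[String, Model] {
  private val calls = new AtomicInteger(0)
  def load(path: String): Model = {
    val n = calls.incrementAndGet()
    if (n == 1) Model(n) else throw new IllegalStateException(s"load #$n failed")
  }
}

object EvictionVsRefresh {
  // The only difference between the two caches is the timing method.
  def build(refresh: Boolean, ticker: Ticker): LoadingCache[String, Model] = {
    val base = CacheBuilder.newBuilder().maximumSize(2).ticker(ticker)
    val timed =
      if (refresh) base.refreshAfterWrite(24, TimeUnit.HOURS)
      else base.expireAfterWrite(24, TimeUnit.HOURS)
    timed.build(new FlakyLoader)
  }

  // Loads once, jumps past the 24-hour mark, then queries again.
  def secondQuery(refresh: Boolean): Try[Model] = {
    val ticker = new FakeTicker
    val cache = build(refresh, ticker)
    cache.get("path/to/model")
    ticker.advance(25)
    Try(cache.get("path/to/model"))
  }
}
```

With `expireAfterWrite` the second query fails (Guava wraps the loader's exception in an `UncheckedExecutionException`), while with `refreshAfterWrite` it still returns the first model, and Guava logs the swallowed exception as a warning.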

Timed asynchronous refresh

By default, refresh loads the new value synchronously. That means the query that triggers the refresh still waits for the new model to be loaded (other concurrent queries keep getting the old model). That query therefore sees high latency and, thus, a bad user experience.
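
One way to see that the default refresh is synchronous is to record which thread runs the loader. In this minimal sketch (hypothetical `Model` and `FakeTicker`, no `reload` override), the query that crosses the 24-hour mark runs `load` on its own thread and gets the freshly loaded model back, i.e. it had to wait for the load to finish.

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference}
import com.google.common.base.Ticker
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

// Hypothetical stand-in for a trained model.
final case class Model(version: Int)

// A clock we advance by hand instead of waiting 24 real hours.
final class FakeTicker extends Ticker {
  private val nanos = new AtomicLong(0L)
  override def read(): Long = nanos.get()
  def advance(hours: Long): Unit = nanos.addAndGet(TimeUnit.HOURS.toNanos(hours))
}

object SyncRefreshSketch {
  val ticker = new FakeTicker
  // Remembers which thread ran the most recent load.
  val loaderThread = new AtomicReference[Thread]()
  private val loads = new AtomicInteger(0)

  val cache: LoadingCache[String, Model] = CacheBuilder.newBuilder()
    .maximumSize(2)
    .refreshAfterWrite(24, TimeUnit.HOURS)
    .ticker(ticker)
    .build(new CacheLoader[String, Model] {
      // No reload override, so a refresh falls back to calling load in place.
      def load(path: String): Model = {
        loaderThread.set(Thread.currentThread())
        Model(loads.incrementAndGet())
      }
    })
}
```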

The good news is that refresh can be made asynchronous. Specifically, you need to override the CacheLoader's reload method so that it loads the new model on a separate executor and immediately returns a ListenableFuture.

package controllers

import java.util.concurrent.{Callable, Executors, TimeUnit}
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import com.google.common.util.concurrent.{ListenableFuture, ListenableFutureTask}

trait CacheProvider {

  // Thread pool that runs model reloads off the request thread.
  val executor = Executors.newFixedThreadPool(10)

  val modelCache: LoadingCache[String, Model] = CacheBuilder.newBuilder()
    .maximumSize(2)
    .refreshAfterWrite(24, TimeUnit.HOURS)
    .build(
      new CacheLoader[String, Model]() {
        // Used for the initial load, when there is no old model to serve.
        def load(path: String): Model = {
          Model.load(path)
        }

        // Overriding reload makes refresh asynchronous: the load is submitted
        // to the executor, and the old model is served until it completes.
        override def reload(
          path: String,
          prevModel: Model
        ): ListenableFuture[Model] = {
          val task = ListenableFutureTask.create(
            new Callable[Model]() {
              def call(): Model = {
                Model.load(path)
              }
            }
          )

          executor.execute(task)
          task
        }
      }
    )

  def getModel: Model = {
    modelCache.get("path/to/model")
  }
}
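
As an alternative to writing the reload override by hand, Guava also provides `CacheLoader.asyncReloading(loader, executor)` (since Guava 17.0), which wraps a loader so that its `reload` runs on the given executor. A minimal sketch, with hypothetical `Model` and `FakeTicker` and a loader that blocks on a latch, so we can check that the query triggering the refresh returns the old model immediately:

```scala
import java.util.concurrent.{CountDownLatch, Executors, ThreadFactory, TimeUnit}
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
import com.google.common.base.Ticker
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

// Hypothetical stand-in for a trained model.
final case class Model(version: Int)

// A clock we advance by hand instead of waiting 24 real hours.
final class FakeTicker extends Ticker {
  private val nanos = new AtomicLong(0L)
  override def read(): Long = nanos.get()
  def advance(hours: Long): Unit = nanos.addAndGet(TimeUnit.HOURS.toNanos(hours))
}

object AsyncReloadSketch {
  val ticker = new FakeTicker
  // Counted down by the caller to let a pending reload finish.
  val releaseReload = new CountDownLatch(1)
  private val loads = new AtomicInteger(0)

  // Hypothetical slow loader: every load after the first waits on the latch.
  private val slowLoader: CacheLoader[String, Model] = new CacheLoader[String, Model] {
    def load(path: String): Model = {
      val n = loads.incrementAndGet()
      if (n > 1) releaseReload.await()
      Model(n)
    }
  }

  // Daemon thread, so a pending reload never keeps the JVM alive.
  private val executor = Executors.newFixedThreadPool(1, new ThreadFactory {
    def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "model-reloader")
      t.setDaemon(true)
      t
    }
  })

  // asyncReloading keeps load() as is and runs reload() on the executor.
  val cache: LoadingCache[String, Model] = CacheBuilder.newBuilder()
    .maximumSize(2)
    .refreshAfterWrite(24, TimeUnit.HOURS)
    .ticker(ticker)
    .build(CacheLoader.asyncReloading(slowLoader, executor))
}
```

The daemon thread factory is a choice made for this sketch only; in a long-running service you may instead want to shut the executor down when the application stops.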

Summary

Caching is one of the most interesting problems in web applications. Here I only talked about some of the most basic in-memory caching patterns, but they, especially timed asynchronous refresh, seem to work well for predictive models, which are relatively static compared to other content.

As always, I would really appreciate your thoughts and comments. Feel free to leave them below this post or tweet me @_LeiG.