Created
March 16, 2011 17:44
-
-
Save mlimotte/872918 to your computer and use it in GitHub Desktop.
A Cascalog function to join a small file that can fit in memory, map-side.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package foo.cascalog; | |
import cascading.flow.FlowProcess;
import cascading.flow.hadoop.HadoopFlowProcess;
import cascading.operation.FunctionCall;
import cascading.operation.OperationCall;
import cascading.tuple.Tuple;
import cascading.tuple.TupleEntry;
import cascalog.CascalogFunction;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

import java.io.*;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
public class MemoryJoin extends CascalogFunction { | |
private Map<String,String> theMap = new HashMap<String,String>(); | |
private String name; | |
private int keyIndex; | |
private int valueIndex; | |
private int splitLimit; | |
private String defaultValue=null; | |
private boolean caseSensitive; | |
private boolean warnOnNoJoin; | |
private String delim; | |
/** | |
* @param name | |
* @param keyIndex | |
* @param valueIndex | |
* @param defaultValue | |
* @param caseSensitive - only set false if you need case-insensitivity, false is slightly worse performance | |
*/ | |
public MemoryJoin(String name, String delim, int keyIndex, int valueIndex, String defaultValue, | |
boolean caseSensitive, boolean warnOnNoJoin) { | |
this.name = name; | |
this.delim = delim; | |
this.keyIndex = keyIndex; | |
this.valueIndex = valueIndex; | |
int maxIdx = (keyIndex > valueIndex)?keyIndex:valueIndex; | |
this.splitLimit = maxIdx + 2; | |
this.defaultValue = defaultValue; | |
this.caseSensitive = caseSensitive; | |
this.warnOnNoJoin = warnOnNoJoin; | |
} | |
public MemoryJoin(String name, int keyIndex, int valueIndex, String defaultValue, | |
boolean caseSensitive, boolean warnOnNoJoin) { | |
this(name,"\t",keyIndex,valueIndex,defaultValue,caseSensitive,warnOnNoJoin); | |
} | |
public MemoryJoin(String name, String delim, int keyIndex, int valueIndex, String defaultValue) { | |
this(name,delim,keyIndex,valueIndex,defaultValue,true,true); | |
} | |
public MemoryJoin(String name, int keyIndex, int valueIndex, String defaultValue) { | |
this(name,"\t",keyIndex,valueIndex,defaultValue,true,true); | |
} | |
private Path getFileFor(HadoopFlowProcess hfp, String name) throws IOException { | |
Path[] files = DistributedCache.getLocalCacheFiles(hfp.getJobConf()); | |
for (Path cachePath : files) { | |
if (cachePath.getName().equals(name)) { | |
return cachePath; | |
} | |
} | |
// error, if we haven't returned by now | |
throw new RuntimeException("Did not find " + name + " in local cache files (distributedCache contains " | |
+ files.length + " files)."); | |
} | |
@Override | |
public void prepare(FlowProcess flowProcess, OperationCall operationCall) { | |
super.prepare(flowProcess, operationCall); | |
Path f=null; | |
try { | |
HadoopFlowProcess hfp = (HadoopFlowProcess) flowProcess; | |
f = getFileFor(hfp,name); | |
} catch (IOException e) { | |
new RuntimeException("Error getting files from distributedCache for " + name, e); | |
} | |
try { | |
FileSystem fs = FileSystem.getLocal(new Configuration()); | |
InputStream in = fs.open(f,1000000); | |
InputStreamReader inr = new InputStreamReader(in); | |
BufferedReader r = new BufferedReader(inr); | |
String line; | |
while ((line = r.readLine()) != null) { | |
String[] flds = line.split(delim,splitLimit); | |
if (flds.length >= (splitLimit - 1)) { | |
String key = caseSensitive?flds[keyIndex]:flds[keyIndex].toLowerCase(); | |
theMap.put(key,flds[valueIndex]); | |
} | |
} | |
r.close(); | |
} catch (IOException e) { | |
new RuntimeException("Error reading file " + f.toString() + " from distributedCache",e); | |
} | |
System.out.printf("MemoryJoin.prepare() for %s(mappings=%d, path=%s)%n", name, theMap.size(), f.toString()); | |
} | |
@Override | |
public void operate(FlowProcess flowProcess, FunctionCall functionCall) { | |
Tuple result = new Tuple(); | |
TupleEntry arguments = functionCall.getArguments(); | |
String key = arguments.getString(0); | |
String s = ""; | |
if (! "".equals(key)) { | |
key = caseSensitive?key:key.toLowerCase(); | |
s = theMap.get(key); | |
if (s == null) { | |
if (warnOnNoJoin) { | |
System.err.println("MemoryJoin " + name + ": no join found for key \'" + key + "\'"); | |
theMap.put(key,defaultValue); // so we only warn on join once per JVM | |
} | |
s = defaultValue; | |
} | |
} | |
result.add(s); | |
functionCall.getOutputCollector().add(result); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The files that you want to join will need to exist in HDFS, and then you can use the distributed cache to pass them around to each tasktracker. To do this, add to mapred.cache.files in your job-conf. E.g.:
(with-job-conf { "mapred.cache.files" csv-of-paths-in-hdfs}
Where csv-of-paths-in-hdfs is a comma-separated list of paths in HDFS where the files to be joined exist.
In Clojure, you can use this like so: