Generate tfrecord in MapReduce
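Two files make up the job. DecodeProcessor turns one sample line into a tensorflow.example.Example proto; TFPreprocess wires it into a MapReduce job whose mapper emits each serialized Example under a random key (to shuffle the records) and whose reducer streams the bytes into TFRecord files via TFRecordFileOutputFormat. An input line looks like this (values are hypothetical):

    1,0 12:1893:1 7:4:1 3:9:0.5;0.7

The first token is the label, either "clickNums,orderNums" or a single 0/1 label; every following token is fieldId:featureId:value, and continuous fields carry ;-separated float values.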
package net.jqian.tutorial.tfrecord;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.Mapper;
import org.tensorflow.example.*;

import java.util.*;

/**
 * Created by jqian on 17/10/10.
 */
public class DecodeProcessor {
    private static final Log LOG = LogFactory.getLog(DecodeProcessor.class);

    // fieldId -> (featureId -> encoded long id); filled from the dict file
    protected Map<Integer, Map<String, Long>> features = new HashMap<>();
    protected List<Integer> contFea = new ArrayList<>();   // continuous (float) fields
    protected List<Integer> deleteFea = new ArrayList<>(); // parsed from delete_fea; not consulted in this snippet
    protected List<Integer> lineFea = new ArrayList<>();   // linear fields
    protected List<Integer> sparseFea = new ArrayList<>(); // sparse fields
    private Configuration conf;
    public void init(Configuration conf) {
        LOG.info("start init decode processor");
        this.conf = conf;
        // NOTE: the dictionary that fills `features` is loaded from this path
        // in the full project; the loading code is elided in this gist.
        String dictFilePath = conf.get("dict").trim();
        String[] linears = conf.getStrings("linear_fea");
        if (linears != null) {
            for (String linear : linears) {
                lineFea.add(Integer.parseInt(linear.trim()));
            }
        }
        String[] deletes = conf.getStrings("delete_fea");
        if (deletes != null) {
            for (String delete : deletes) {
                deleteFea.add(Integer.parseInt(delete.trim()));
            }
        }
        String[] conts = conf.getStrings("cont_fea");
        if (conts != null) {
            for (String cont : conts) {
                contFea.add(Integer.parseInt(cont.trim()));
            }
        }
        LOG.info("load dict success");
        String fieldPath = conf.get("field"); // unused here; consumed elsewhere in the project
    }

    public void setConf(Configuration conf) {
        this.conf = conf;
    }
    public Example processLine(String line, Mapper.Context context) {
        String[] contents = line.trim().split(" ");
        Map<Integer, LinkedList<Long>> x1_0 = new HashMap<>(); // sparse features
        Map<Integer, List<Float>> x2 = new HashMap<>();        // continuous features
        Map<Integer, LinkedList<Long>> x3_0 = new HashMap<>(); // linear features
        // contents[0] is the label; the remaining tokens are fieldId:featureId:value
        for (int i = 1; i < contents.length; i++) {
            String[] idsStr = contents[i].split(":");
            Integer fieldId = Integer.parseInt(idsStr[0]);
            String featureId = idsStr[1];
            String val = idsStr[2];
            if (this.features.containsKey(fieldId)) {
                Long indexed = this.features.get(fieldId).get(featureId);
                long featureValue;
                if (indexed == null) {
                    // unknown feature id: fall back to the "rare" bucket 0
                    LOG.debug(String.format("use rare for field[%d] feature[%s]", fieldId, featureId));
                    featureValue = 0L;
                } else {
                    featureValue = indexed;
                }
                if (this.contFea.contains(fieldId)) {
                    if (!val.trim().isEmpty()) {
                        List<Float> values = new ArrayList<>();
                        for (String v : val.trim().split(";"))
                            values.add(Float.parseFloat(v));
                        x2.put(fieldId, values);
                    }
                } else if (this.lineFea.contains(fieldId)) {
                    putValue(x3_0, fieldId, featureValue);
                } else if (this.sparseFea.contains(fieldId)) {
                    putValue(x1_0, fieldId, featureValue);
                }
            } else if (context != null) {
                context.getCounter("counter", "feature not exist " + fieldId).increment(1);
            }
        }
        // fields absent from this line get the "rare" bucket value 0
        putRares(x1_0, sparseFea);
        if (contFea.size() > 0) {
            for (Integer fieldId : contFea) {
                if (!x2.containsKey(fieldId)) {
                    // fill in the default value -10000
                    // NOTE: globalFieldInfo (per-field metadata such as the continuous
                    // vector size) is defined outside this gist.
                    List<Float> arr = new ArrayList<>();
                    for (int idx = 0; idx < globalFieldInfo.getField(fieldId).contFeaVecSize; idx++)
                        arr.add(-10000.0f);
                    x2.put(fieldId, arr);
                }
            }
        }
        putRares(x3_0, lineFea);
        Features.Builder featuresBuilder = Features.newBuilder();
        // sparse and linear fields share the x_<fieldId> key space; the two
        // field lists are expected to be disjoint
        for (Map.Entry<Integer, LinkedList<Long>> entry : x1_0.entrySet()) {
            String keyId = "x_" + entry.getKey();
            featuresBuilder.putFeature(keyId, createLongListFeature(entry.getValue()));
            saveLongCounter(context, entry, keyId);
        }
        for (Map.Entry<Integer, List<Float>> entry : x2.entrySet()) {
            String keyId = "x2_" + entry.getKey();
            featuresBuilder.putFeature(keyId, createFloatListFeature(entry.getValue()));
        }
        for (Map.Entry<Integer, LinkedList<Long>> entry : x3_0.entrySet()) {
            String keyId = "x_" + entry.getKey();
            featuresBuilder.putFeature(keyId, createLongListFeature(entry.getValue()));
            saveLongCounter(context, entry, keyId);
        }
        String[] labels = contents[0].split(",");
        if (labels.length == 2) {
            // two-part label "clickNums,orderNums": derive y_ctr / y_cvr
            int clickNums = Integer.parseInt(labels[0]);
            int orderNums = Integer.parseInt(labels[1]);
            int y_ctr = clickNums > 0 ? 1 : 0;
            // a conversion only counts when there was also a click
            int y_cvr = (clickNums > 0 && orderNums > 0) ? 1 : 0;
            featuresBuilder.putFeature("y_ctr", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(y_ctr).build()).build());
            featuresBuilder.putFeature("y_cvr", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(y_cvr).build()).build());
            featuresBuilder.putFeature("clickNums", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(clickNums).build()).build());
            featuresBuilder.putFeature("orderNums", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(orderNums).build()).build());
        } else {
            // single label, clamped to 0/1
            int y = Integer.parseInt(labels[0]) > 0 ? 1 : 0;
            featuresBuilder.putFeature("y", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(y).build()).build());
        }
        return Example.newBuilder().setFeatures(featuresBuilder.build()).build();
    }
    private void saveLongCounter(Mapper.Context context, Map.Entry<Integer, LinkedList<Long>> entry, String keyId) {
        if (context != null) {
            if (entry.getValue().size() > 1) {
                context.getCounter("field_size_max_than_one", keyId).increment(1L);
            } else {
                context.getCounter("field_size_eq_one", keyId).increment(1L);
            }
        }
    }

    // append featureValue to the field's value list, creating the list on first use
    private void putValue(Map<Integer, LinkedList<Long>> map, Integer fieldId, long featureValue) {
        if (map.containsKey(fieldId)) {
            map.get(fieldId).add(featureValue);
        } else {
            LinkedList<Long> feaValues = new LinkedList<>();
            feaValues.add(featureValue);
            map.put(fieldId, feaValues);
        }
    }

    // give every listed field that is missing from the line the "rare" value 0
    private void putRares(Map<Integer, LinkedList<Long>> map, List<Integer> features) {
        for (Integer fieldId : features) {
            if (!map.containsKey(fieldId)) {
                LinkedList<Long> feaValues = new LinkedList<>();
                feaValues.add(0L);
                map.put(fieldId, feaValues);
            }
        }
    }

    // currently unused; kept from an indexed-output variant of this processor
    private void putRaresIndex(Map<Integer, LinkedList<Long>> x1, int index, Map<Integer, LinkedList<Long>> x2, List<Integer> features) {
        for (Integer fieldId : features) {
            if (!x2.containsKey(fieldId)) {
                LinkedList<Long> x2Values = new LinkedList<>();
                x2Values.add(0L);
                x2.put(fieldId, x2Values);
                LinkedList<Long> x1Values = new LinkedList<>();
                x1Values.add((long) index);
                x1.put(fieldId, x1Values);
            }
        }
    }

    private Feature createFloatListFeature(Iterable<? extends Float> values) {
        FloatList floatList = FloatList.newBuilder().addAllValue(values).build();
        return Feature.newBuilder().setFloatList(floatList).build();
    }

    private Feature createLongListFeature(Iterable<? extends Long> values) {
        Int64List int64List = Int64List.newBuilder().addAllValue(values).build();
        return Feature.newBuilder().setInt64List(int64List).build();
    }
}
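A minimal sketch of driving DecodeProcessor by hand, outside MapReduce. The dict path is a placeholder and no dictionary is actually loaded (that step is elided in this gist), so all field tokens fall into the "feature not exist" path and only the label features survive; a null context skips the counters:

    Configuration conf = new Configuration();
    conf.set("dict", "/path/to/feature.dict"); // hypothetical path
    DecodeProcessor processor = new DecodeProcessor();
    processor.init(conf);
    Example example = processor.processLine("1,0 12:1893:1 7:4:1", null);
    // example now carries y_ctr=1, y_cvr=0, clickNums=1, orderNums=0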
package net.jqian.tutorial.tfrecord;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.*;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
import org.tensorflow.example.Example;
import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;

import java.io.IOException;
import java.util.List;
import java.util.Random;

public class TFPreprocess {
    private static final Log LOG = LogFactory.getLog(TFPreprocess.class);

    static class ToTFRecordMapper extends Mapper<LongWritable, Text, Text, BytesWritable> {
        private DecodeProcessor processor;
        private final Random random = new Random(System.currentTimeMillis());
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            processor = new DecodeProcessor();
            processor.init(context.getConfiguration());
            // record the configured field lists as counters, for visibility in the job UI
            context.getCounter("counter", "sparse fea: " + joinIds(processor.sparseFea)).increment(1);
            context.getCounter("counter", "cont fea: " + joinIds(processor.contFea)).increment(1);
            context.getCounter("counter", "delete fea: " + joinIds(processor.deleteFea)).increment(1);
            context.getCounter("counter", "linear fea: " + joinIds(processor.lineFea)).increment(1);
        }

        private static String joinIds(List<Integer> ids) {
            StringBuilder sb = new StringBuilder();
            for (Integer id : ids) {
                sb.append(id).append(';');
            }
            return sb.toString();
        }
        @Override
        protected void map(LongWritable key, Text value,
                           Context context) throws IOException, InterruptedException {
            Example example = processor.processLine(value.toString(), context);
            // a random key spreads the serialized Examples across reducers,
            // shuffling the output records
            String outKey = String.valueOf(random.nextInt());
            context.write(new Text(outKey), new BytesWritable(example.toByteArray()));
        }
    }

    public static class ToTFRecordReducer extends Reducer<Text, BytesWritable, BytesWritable, NullWritable> {
        @Override
        protected void reduce(Text key, Iterable<BytesWritable> values, Context context)
                throws IOException, InterruptedException {
            for (BytesWritable v : values) {
                context.write(v, NullWritable.get());
            }
        }
    }
    public static boolean convert(String jobName, Configuration conf,
                                  Class<? extends Mapper> mapperClass,
                                  Class<? extends Reducer> reducerClass,
                                  Class<? extends Writable> mapKeyClass,
                                  Class<? extends Writable> mapValueClass,
                                  Class<? extends Writable> outputKeyClass,
                                  Class<? extends Writable> outputValueClass,
                                  Class<? extends InputFormat> inFormatClass,
                                  Class<? extends OutputFormat> outFormatClass) throws InterruptedException, IOException, ClassNotFoundException {
        Job job = Job.getInstance(conf, jobName);
        job.setJarByClass(mapperClass);
        job.setMapperClass(mapperClass);
        job.setReducerClass(reducerClass);
        job.setNumReduceTasks(conf.getInt("rednum", 128));
        job.setInputFormatClass(inFormatClass);
        job.setOutputFormatClass(outFormatClass);
        job.setMapOutputKeyClass(mapKeyClass);
        job.setMapOutputValueClass(mapValueClass);
        job.setOutputKeyClass(outputKeyClass);
        job.setOutputValueClass(outputValueClass);
        // parallelism: a 512 MB minimum split size caps the number of mappers
        job.getConfiguration().setLong("mapred.min.split.size", 536870912);
        job.getConfiguration().setInt("mapreduce.map.memory.mb", 4096);
        job.getConfiguration().set("mapreduce.map.java.opts", "-Xmx4096m");
        String inputPath = conf.get("in");
        String outputPath = conf.get("out");
        String day = conf.get("day");
        FileSystem fileSystem = FileSystem.get(conf);
        FileInputFormat.addInputPath(job, new Path(inputPath + "/" + day));
        LOG.info("input:" + inputPath + "/" + day);
        FileInputFormat.setInputDirRecursive(job, true);
        Path out = new Path(outputPath + "/" + day);
        // clear any previous output for this day before running
        fileSystem.delete(out, true);
        FileOutputFormat.setOutputPath(job, out);
        return job.waitForCompletion(true);
    }
    public static void main(String[] args) throws Exception {
        GenericOptionsParser parser = new GenericOptionsParser(new Configuration(), args);
        Configuration conf = parser.getConfiguration();
        boolean ret = convert("ToTFR", conf, ToTFRecordMapper.class, ToTFRecordReducer.class,
                Text.class, BytesWritable.class,
                BytesWritable.class, NullWritable.class,
                TextInputFormat.class, TFRecordFileOutputFormat.class);
        System.exit(ret ? 0 : 1);
    }
}
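A hedged launch example: GenericOptionsParser maps -D options onto the configuration keys the code reads ("dict", "linear_fea", "delete_fea", "cont_fea", "in", "out", "day", "rednum"). The jar name and all paths below are placeholders:

    hadoop jar tfrecord-job.jar net.jqian.tutorial.tfrecord.TFPreprocess \
        -Ddict=/dicts/feature.dict -Dlinear_fea=1,2 -Dcont_fea=3 \
        -Din=/data/samples -Dout=/data/tfrecord -Dday=20200324 -Drednum=128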