Last active: November 10, 2016 07:25
Save hallkk/da5adbeb8df88774864bc0e1b1ac4151 to your computer and use it in GitHub Desktop.
Hive UDTF example: converting one column value into multiple rows (hive UDTF样例,将列转化为多行)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
编写自己的UDTF: | |
1.继承org.apache.Hadoop.hive.ql.udf.generic.GenericUDTF。 | |
2.实现initialize(),process(),close()三个方法。 | |
3.UDTF首先会调用initialize()方法,此方法返回UDTF的返回行的信息(返回个数,类型)。 | |
4.初始化完成后会调用process()方法,对传入的参数进行处理,可以通过forward()方法把结果返回。 | |
5.最后调用close()对需要清理的方法进行清理。 | |
**/ | |
@Description(name = "convert_nplus_freq", | |
value = "convert_nplus_freq(nPlusFreqInfo,[freqSeparator]) - convert n+ freq info to n freq") | |
public class ConvertNPlusFreqUdf extends GenericUDTF { | |
private static final String DEFAULT_SEPARATOR = ","; | |
@Override | |
public StructObjectInspector initialize(ObjectInspector[] args) | |
throws UDFArgumentException { | |
if (args.length < 1) { | |
throw new UDFArgumentLengthException("convert_nplus_freq takes require arguments"); | |
} | |
if (args[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { | |
throw new UDFArgumentException("convert_nplus_freq takes string as parameter"); | |
} | |
List<String> fieldNames = new ArrayList(); | |
fieldNames.add("freq"); | |
fieldNames.add("count"); | |
List<ObjectInspector> fieldOIs = new ArrayList(); | |
fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector); | |
fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector); | |
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); | |
} | |
@Override | |
public void process(Object[] args) throws HiveException { | |
String freqInfo = args[0].toString(); | |
String freqSeparator = DEFAULT_SEPARATOR; | |
if (args.length > 1) { | |
freqSeparator = args[1].toString(); | |
} | |
for (Object[] row : this.processRecord(freqInfo, freqSeparator)) { | |
forward(row); | |
} | |
} | |
public List<Object[]> processRecord(String nPlusFreqInfo, String freqSeparator) { | |
List<Integer> nPlusFreqList = this.parseIntList(nPlusFreqInfo, freqSeparator); | |
List<Object[]> resultRows = new ArrayList(); | |
for (int index = 0; index < nPlusFreqList.size(); index++) { | |
//遍历到N+频次为0,停止计算 | |
if (nPlusFreqList.get(index) == 0) { | |
break; | |
} | |
int freq = index + 1; | |
//遍历到结尾,直接输出 | |
if (freq == nPlusFreqList.size()) { | |
resultRows.add(new Object[]{freq, nPlusFreqList.get(index)}); | |
} | |
//计算freq的值 | |
else { | |
resultRows.add(new Object[]{freq, nPlusFreqList.get(index) - nPlusFreqList.get(index + 1)}); | |
} | |
} | |
return resultRows; | |
} | |
@Override | |
public void close() throws HiveException { | |
} | |
/** | |
* 根据输入n+频次信息和分隔符,转化为n+频次数组。 | |
* 如果在转化过程中发生异常,返回空集合。 | |
* | |
* @param nPlusFreqInfo | |
* @param freqSeparator | |
* @return | |
*/ | |
private List<Integer> parseIntList(String nPlusFreqInfo, String freqSeparator) { | |
List<Integer> freqList = Lists.newArrayList(); | |
for (String freq : nPlusFreqInfo.split(freqSeparator)) { | |
try { | |
freqList.add(Integer.parseInt(freq)); | |
} catch (NumberFormatException ex) { | |
return Lists.newArrayList(); | |
} | |
} | |
return freqList; | |
} | |
} | |
~ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* 对UDTF进行单元测试。 | |
**/ | |
public class ConvertNPlusFreqUdfTest { | |
@Test | |
public void testconvertion() throws UDFArgumentException { | |
ConvertNPlusFreqUdf udf = new ConvertNPlusFreqUdf(); | |
ObjectInspector[] inputOI = {PrimitiveObjectInspectorFactory.javaStringObjectInspector}; | |
// the value exists | |
try { | |
udf.initialize(inputOI); | |
} catch (Exception ex) { | |
throw ex; | |
} | |
// 目标方法应为private方法,通过反射进行单元测试 | |
List<Object[]> results = udf.processRecord("3,2,1,0,0", ","); | |
Assert.assertEquals(3, results.size()); | |
Assert.assertEquals(1, results.get(0)[0]); | |
Assert.assertEquals(1, results.get(0)[1]); | |
Assert.assertEquals(2, results.get(1)[0]); | |
Assert.assertEquals(1, results.get(1)[1]); | |
Assert.assertEquals(3, results.get(2)[0]); | |
Assert.assertEquals(1, results.get(2)[1]); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.