Skip to content

Instantly share code, notes, and snippets.

@ceefour
Created January 14, 2017 08:38
Show Gist options
  • Save ceefour/5ca03f696dd57361c55abecac1907def to your computer and use it in GitHub Desktop.
Save ceefour/5ca03f696dd57361c55abecac1907def to your computer and use it in GitHub Desktop.
Nyobain pake HMM untuk mengenali kata dalam chat (termasuk yang alay), awalnya sih agak lumayan dengan training sample sedikit.......... giliran training samplenya ditambah, malah ngaco :'( kesimpulan sementara ga bisa pake HMM dengan cara ini ha ha ha.. (menyedihkan)
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the "postagger" experiment: a character-level HMM tagger built on Apache Mahout. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.hendyirawan</groupId>
    <artifactId>postagger</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Pin the source encoding so the build does not depend on the platform default. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <!-- Logging: SLF4J API with Logback as the runtime binding. -->
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.1.7</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.21</version>
        </dependency>

        <!-- Mahout: provides the HMM classes (org.apache.mahout.classifier.sequencelearning.hmm). -->
        <dependency>
            <groupId>org.apache.mahout</groupId>
            <artifactId>mahout-mr</artifactId>
            <version>0.12.2</version>
            <exclusions>
                <!-- Logback is the chosen SLF4J binding; keep the log4j binding off the classpath. -->
                <exclusion>
                    <artifactId>slf4j-log4j12</artifactId>
                    <groupId>org.slf4j</groupId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.mahout</groupId>
            <artifactId>mahout-math</artifactId>
            <version>0.12.2</version>
        </dependency>

        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.5</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <!-- Explicit version: Maven warns about unversioned plugins and the build is otherwise not reproducible. -->
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
import com.google.common.collect.ImmutableList;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmEvaluator;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Experimental character-level tagger built on Mahout's hidden Markov model (HMM) classes.
 *
 * <p>Each training sample is a short text labelled with a single {@link Tag}; every character of
 * the text becomes one observation and the tag's ordinal becomes the hidden state for that
 * character. {@link #run()} trains a supervised HMM on those sequences and decodes a hard-coded
 * test sentence, logging the hidden state chosen for each character.
 *
 * <p>Created by ceefour on 14/01/2017.
 */
public class SentenceTagger {

    private static final Logger log = LoggerFactory.getLogger(SentenceTagger.class);

    /** Number of observable symbols: the 7-bit ASCII range, one state per character code. */
    private static final int OBSERVED_STATE_COUNT = 128;

    /** Symbol substituted for any character that falls outside the 7-bit observation alphabet. */
    private static final char UNKNOWN_SYMBOL = '?';

    /** Pseudo count passed to the supervised trainer. */
    private static final double PSEUDO_COUNT = 0.05;

    /**
     * Hidden states of the model. The declaration order matters: {@link #ordinal()} is used as
     * the hidden-state index handed to the HMM, and {@link #name()} as its registered state name.
     */
    public enum Tag {
        /** Trained on a single space. */
        _,
        /** Trained on ".". */
        DOT,
        /** Trained on ",". */
        COMMA,
        /** Thanks, e.g. "thx". */
        THX,
        /** Greeting, e.g. "pagi". */
        GREET,
        /** First person, e.g. "aku". */
        I
    }

    /** One labelled training text: every character of {@code text} carries the same tag. */
    static class Sample {
        final Tag hidden;
        final String text;

        public Sample(Tag hidden, String text) {
            this.hidden = hidden;
            this.text = text;
        }

        /** @return the tag's ordinal repeated once per character of the text */
        int[] toHiddenSeq() {
            return fill(text.length(), hidden.ordinal());
        }

        /** @return the observation sequence for the text, same length as {@link #toHiddenSeq()} */
        int[] toTextSeq() {
            return textToSeq(text);
        }
    }

    private final List<Sample> trainSamples = new ArrayList<>();

    /**
     * Converts text into a sequence of observation indices, one per {@code char}.
     *
     * <p>Characters are lower-cased individually with {@link Character#toLowerCase(char)}, which
     * is independent of the default locale and never changes the length of the sequence, so the
     * result always lines up with a hidden sequence of {@code text.length()} entries. Any
     * character that is still outside the 7-bit range after lower-casing is replaced by
     * {@code '?'}, because the model only has {@value #OBSERVED_STATE_COUNT} output states.
     *
     * @param text the text to encode
     * @return an array of {@code text.length()} values, each in {@code [0, 128)}
     */
    public static int[] textToSeq(String text) {
        final int[] result = new int[text.length()];
        for (int i = 0; i < result.length; i++) {
            final char lower = Character.toLowerCase(text.charAt(i));
            result[i] = lower < OBSERVED_STATE_COUNT ? lower : UNKNOWN_SYMBOL;
        }
        return result;
    }

    /**
     * Creates an array with every element set to the same value.
     *
     * @param count length of the array
     * @param body value stored in every element
     * @return a new array of {@code count} copies of {@code body}
     */
    public static int[] fill(int count, int body) {
        final int[] result = new int[count];
        Arrays.fill(result, body);
        return result;
    }

    /**
     * Registers a training sample; samples accumulate for the lifetime of this instance.
     *
     * @param hidden the tag assigned to every character of {@code text}
     * @param text the sample text
     */
    public void addTrainSample(Tag hidden, String text) {
        trainSamples.add(new Sample(hidden, text));
    }

    /**
     * Adds the built-in training samples, trains the HMM and logs the decoding of a test sentence.
     *
     * <p>Note that the built-in samples are appended on every call, on top of anything already
     * added through {@link #addTrainSample(Tag, String)}.
     */
    public void run() {
        // Hidden states: one per Tag constant, indexed by ordinal.
        // Observed states: lower-cased 7-bit characters, 128 states in total.

        // TRAIN
        addTrainSample(Tag._, " ");
        addTrainSample(Tag.DOT, ".");
        addTrainSample(Tag.COMMA, ",");
        addTrainSample(Tag.I, "aku");
        // addTrainSample(Tag.I, "I");
        addTrainSample(Tag.THX, "thx");
        // addTrainSample(Tag.THX, "thank you");
        // addTrainSample(Tag.THX, "terima kasih");
        // addTrainSample(Tag.THX, "makasih");
        // addTrainSample(Tag.THX, "thanks");
        // addTrainSample(Tag.THX, "thanx");
        addTrainSample(Tag.GREET, "pagi");
        // addTrainSample(Tag.GREET, "met pagi");
        // addTrainSample(Tag.GREET, "met siang");
        // addTrainSample(Tag.GREET, "selamat pagi");
        // addTrainSample(Tag.GREET, "selamat siang");
        // addTrainSample(Tag.GREET, "good morning");
        // addTrainSample(Tag.GREET, "good day");
        // addTrainSample(Tag.GREET, "good afternoon");

        final List<int[]> hiddenSequences =
                trainSamples.stream().map(Sample::toHiddenSeq).collect(Collectors.toList());
        final List<int[]> observedSequences =
                trainSamples.stream().map(Sample::toTextSeq).collect(Collectors.toList());
        final HmmModel hmm = HmmTrainer.trainSupervisedSequence(
                Tag.values().length, OBSERVED_STATE_COUNT,
                hiddenSequences, observedSequences, PSEUDO_COUNT);

        // Register readable names so the decoded indices can be turned back into tag names.
        hmm.registerHiddenStateNames(
                Arrays.stream(Tag.values()).map(Tag::name).toArray(String[]::new));
        final String[] outputStateNames = new String[OBSERVED_STATE_COUNT];
        for (int i = 0; i < OBSERVED_STATE_COUNT; i++) {
            outputStateNames[i] = new String(new int[] { i }, 0, 1);
        }
        hmm.registerOutputStateNames(outputStateNames);

        // TEST
        final String testText = "aku thxxxx pagi thx p4gi";
        log.info("Test text: {}", testText);
        final int[] decoded = HmmEvaluator.decode(hmm, textToSeq(testText), false);
        final List<String> decodedNames = HmmUtils.decodeStateSequence(hmm, decoded, false, "");
        log.info("Decoded: {}", Arrays.toString(decoded));
        log.info("Decoded: {}", decodedNames);
    }

    public static void main(String... args) {
        new SentenceTagger().run();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment