Skip to content

Instantly share code, notes, and snippets.

@gangtao
Last active September 21, 2017 06:14
Show Gist options
  • Save gangtao/c04d4d404f153956b6de19ee0a8cb5c0 to your computer and use it in GitHub Desktop.
Save gangtao/c04d4d404f153956b6de19ee0a8cb5c0 to your computer and use it in GitHub Desktop.
ML Explained Naive Bayes
<!-- Two-column layout: the chart on the left, the controls on the right. -->
<div class="container">
<div class="row">
<div class="col-sm-6">
<!-- The script appends its <svg> canvas inside this element. -->
<div id="chart"></div>
</div>
<div class="col-sm-6">
<!-- Starts a new category; subsequent clicks on the chart add points to it. -->
<button type="button" class="btn btn-xs" id="new_category_button">New Category Data</button>
<!-- The script writes the current category name here, tinted with the category color. -->
<span id="category_info"></span>
<!-- Discards all collected points and drawn shapes. -->
<button type="button" class="btn btn-xs" id="clean_button">Clear</button>
<!-- Toggling this trains the model and shows/hides the prediction overlay. -->
<label><input type="checkbox" id="predict_check"> Predict</label>
</div>
</div>
</div>
// Shared configuration and mutable state for the Naive Bayes demo.
// Everything below is a script-level global that the functions in this file
// read and write directly.
var size = 400; // Width and height, in pixels, of the square SVG canvas
var margin_size = 50; // Margin, in pixels, left on every side of the plot area
var point_size = 8; // Radius of the circles drawn for user-placed samples
var colors = d3.scaleOrdinal(d3.schemeCategory10); // Maps a category/label to a color
var domain_max = 100; // Upper bound of both axes and of the prediction grid
var data = []; // Collected samples: { x, y, coordinates, category }
var current_category = undefined; // Category receiving new points; undefined until one is started
var xScale, yScale; // d3 linear scales, assigned once the document is ready
var nb = undefined; // model (assigned by train())
/**
 * Append an SVG <circle> carrying the "circle" class to a d3 selection.
 *
 * @param {Object} container - d3 selection that receives the new element.
 * @param {{x: number, y: number}} p - Center of the circle, in the container's pixel space.
 * @param {number} r - Radius in pixels.
 * @param {string} [color] - Fill color; when falsy the stylesheet's default fill applies.
 * @returns {Object} The selection for the newly appended circle.
 */
function drawCircle(container, p, r, color) {
  var node = container.append("circle");
  var positioned = node.attr("cx", p.x).attr("cy", p.y);
  var circle = positioned.attr("r", r).classed("circle", true);

  // Without an explicit color, leave the fill to the ".circle" CSS rule.
  if (!color) {
    return circle;
  }

  circle.style("fill", color);
  return circle;
}
/**
 * Append an SVG <line> carrying the "line" class to a d3 selection.
 *
 * @param {Object} container - d3 selection that receives the new element.
 * @param {{x: number, y: number}} p1 - Start point, in the container's pixel space.
 * @param {{x: number, y: number}} p2 - End point, in the container's pixel space.
 * @returns {Object} The selection for the newly appended line.
 */
function drawLine(container, p1, p2) {
  var segment = container.append("line");
  var started = segment.attr("x1", p1.x).attr("y1", p1.y);
  var finished = started.attr("x2", p2.x).attr("y2", p2.y);
  return finished.classed("line", true);
}
/**
 * Advance the global `current_category` to the next category index
 * (0 on first use, otherwise previous + 1) and show its name in the
 * "#category_info" element, colored with that category's color.
 */
function gen_new_category() {
  current_category = current_category == undefined ? 0 : current_category + 1;

  var info = d3.select("#category_info");
  var labelled = info.text("category" + current_category);
  labelled.style("color", colors(current_category));
}
/**
 * Reset the demo: drop every collected sample, forget the active category,
 * blank the status labels and remove all drawn circles and lines.
 *
 * The previously trained model (`nb`) is intentionally left untouched.
 */
function clean() {
  data = [];
  current_category = undefined;
  // The former `total_category = 0` assignment was removed: the name is never
  // declared or read in this script, so it only leaked an implicit global
  // (and would throw a ReferenceError under strict mode).
  d3.select("#category_info").text("");
  // NOTE(review): no "#train_info" element appears in the markup shipped with
  // this script; kept because clearing it is harmless — confirm and drop if unused.
  d3.select("#train_info").text("");
  // Prediction dots are created via drawCircle, so they carry ".circle" too.
  $(".circle").remove();
  $(".line").remove();
}
/**
 * Fit a fresh Naive Bayes classifier on every sample collected so far.
 *
 * Each entry of the global `data` contributes the feature pair [x, y] and its
 * `category` as the label. The new classifier is stored in the global `nb`
 * before training starts, and completion is logged to the console.
 */
function train() {
  var features = [];
  var targets = [];

  data.forEach(function(sample) {
    features.push([sample.x, sample.y]);
    targets.push(sample.category);
  });

  var classifier = new ML.SL.NaiveBayes();
  nb = classifier;
  classifier.train(features, targets);
  console.log("training complete");
}
/**
 * Paint the classifier's decision regions as a grid of small colored dots.
 *
 * Every integer point (x, y) with 1 <= x, y <= domain_max is classified in a
 * single batch call, then drawn as a radius-1 circle tagged with the
 * "predict" class so the overlay can be removed independently later.
 *
 * @param {Object} model - Trained classifier exposing `predict(points)`, which
 *   must return one label per input point. When omitted (null/undefined) the
 *   global `nb` is used, matching the previous behavior of this function.
 * @param {Object} container - d3 selection the dots are appended to.
 */
function drawPredict(model, container) {
  // Previously the `model` argument was ignored in favor of the global `nb`.
  var classifier = model == null ? nb : model;

  var grid = [];
  for (var gx = 1; gx <= domain_max; gx++) {
    for (var gy = 1; gy <= domain_max; gy++) {
      grid.push([gx, gy]);
    }
  }

  var labels = classifier.predict(grid);

  grid.forEach(function(cell, index) {
    // Same color for fill and stroke; look it up once per dot.
    var tint = colors(labels[index]);
    drawCircle(container, { x: xScale(cell[0]), y: yScale(cell[1]) }, 1, tint)
      .classed("predict", true)
      .style("stroke", tint);
  });
}
/**
 * Page bootstrap, run by jQuery once the DOM is ready.
 *
 * Builds the SVG canvas and axes inside "#chart", initializes the global
 * `xScale`/`yScale`, and wires up the controls:
 *   - "#new_category_button": start a new category; clicks on the canvas then
 *     add samples of that category to the global `data`.
 *   - "#clean_button": reset everything via clean().
 *   - "#predict_check": when checked, train on the current samples and draw
 *     the prediction overlay; when unchecked, remove the overlay.
 *   - "#predict_button": enable a hover probe that classifies the point under
 *     the cursor and flashes a fading circle in the predicted color.
 */
$(function() {
  var margin = {
    top: margin_size,
    right: margin_size,
    bottom: margin_size,
    left: margin_size
  };
  var width = size - margin.left - margin.right;
  var height = size - margin.top - margin.bottom;

  var root = d3
    .select("#chart")
    .append("svg")
    .attr("width", size)
    .attr("height", size);
  // Plot group, shifted so that (0, 0) is the top-left corner of the plot area.
  var g = root
    .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
  // Empty group kept in the DOM as before; nothing in this block draws into it.
  root.append("g").attr("id", "layer1");

  // Draw Axis
  xScale = d3.scaleLinear().rangeRound([0, width]);
  // Range is flipped because SVG y grows downward while the data axis grows upward.
  yScale = d3.scaleLinear().rangeRound([height, 0]);
  xScale.domain([0, domain_max]);
  yScale.domain([0, domain_max]);
  g
    .append("g")
    .attr("transform", "translate(0," + height + ")")
    .call(d3.axisBottom(xScale));
  g.append("g").call(d3.axisLeft(yScale));

  // Describe the current mouse position relative to `node` in three spaces:
  // raw SVG pixels, plot-area pixels, and rounded data-domain values.
  function locatePointer(node) {
    var coords = d3.mouse(node);
    var plotX = coords[0] - margin.left;
    var plotY = coords[1] - margin.top;
    return {
      svg: { x: coords[0], y: coords[1] },
      plot: { x: plotX, y: plotY },
      domain: {
        x: Math.round(xScale.invert(plotX)),
        y: Math.round(yScale.invert(plotY))
      }
    };
  }

  // Generate Data
  $("#new_category_button").click(function() {
    gen_new_category();
    // d3's on() replaces an existing listener of the same type, so pressing
    // the button repeatedly does not stack click handlers.
    root.on("click", function() {
      // clean() clears current_category while this listener stays attached.
      if (current_category === undefined) {
        return;
      }
      var pointer = locatePointer(this);
      drawCircle(g, pointer.plot, point_size, colors(current_category));
      data.push({
        x: pointer.domain.x,
        y: pointer.domain.y,
        coordinates: pointer.svg,
        category: current_category
      });
    });
  });

  $("#clean_button").click(function() {
    clean();
  });

  $("#predict_check").change(function() {
    // `this` is the checkbox element.
    if (this.checked) {
      // Train only when the overlay is about to be drawn; retraining while
      // merely hiding the overlay was wasted work.
      train();
      drawPredict(nb, g);
    } else {
      $(".predict").remove();
    }
  });

  // NOTE(review): no "#predict_button" element appears in the markup shipped
  // with this script, so this hover probe may never be enabled — confirm.
  $("#predict_button").click(function() {
    root.on("mousemove", function() {
      // Nothing to predict with until train() has produced a model.
      if (!nb) {
        return;
      }
      var pointer = locatePointer(this);
      var label = nb.predict([[pointer.domain.x, pointer.domain.y]]);
      drawCircle(g, pointer.plot, point_size, colors(label[0]))
        .transition()
        .duration(1000)
        .style("fill-opacity", 0)
        .style("stroke-width", "0px")
        // Drop the faded circle so hovering does not pile up invisible nodes.
        .remove();
    });
  });
});
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.9.1/d3.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
<script src="https://www.lactame.com/lib/ml/2.0.0/ml.min.js"></script>
/* Page background and a little breathing room above the layout. */
body {
background-color: #555;
margin-top: 10px;
}
/* Chart container; uses the same color as the page background. */
#chart {
background-color: #555;
}
.row {
padding-top: 2px;
}
/* Default look of every circle created by drawCircle(); the script overrides
   `fill` inline whenever it passes a color. */
.circle {
stroke: #000;
stroke-width: 1px;
fill: #fa6900;
fill-opacity: 1;
}
/* Look of the segments created by drawLine(). */
.line {
stroke: #ccc;
stroke-width: 3px;
}
/* Prediction-overlay dots are drawn mostly transparent. */
.predict {
opacity: 0.2;
}
<link href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet" />
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment