Skip to content

Instantly share code, notes, and snippets.

@seraphy
Last active August 9, 2019 01:49
Show Gist options
  • Save seraphy/a7ce1db3dadb9ef3cf66f5fd64646661 to your computer and use it in GitHub Desktop.
Save seraphy/a7ce1db3dadb9ef3cf66f5fd64646661 to your computer and use it in GitHub Desktop.
ソケットを直接使って、簡単なHTTP/1.1と、WebSocketのサーバーの実装例。
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>WebSocket Example</title>
<script type="text/javascript" src="index.js"></script>
</head>
<body>
<div>
<input id="btnConnect" type="button" onclick="doWebSocket()" value="Do WebSocket" />
<input id="btnClose" type="button" onclick="closeWebSocket()" value="Close" disabled/>
</div>
<div>
<input id="txtSend" type="text" value="" />
<input id="btnSend" type="submit" onclick="sendWebSocket()" value="send" disabled/>
</div>
<div>
<textarea id="txtRecv" cols="40" rows="20"></textarea>
</div>
<a href="/shutdown">shutdown</a>
</body>
</html>
var connection;
function doWebSocket() {
if (connection != null) {
return;
}
connection = new WebSocket('ws://localhost/websocket');
connection.onerror = function (evt) {
alert('error');
};
connection.onopen = function (evt) {
document.getElementById('btnSend').disabled = false;
document.getElementById('btnClose').disabled = false;
document.getElementById('btnConnect').disabled = true;
connection.send('hello world!' + new Date());
};
connection.onclose = function (evt) {
connection = null;
var text = '*CLOSED* code=' + evt.code;
document.getElementById('txtRecv').value += text + '\r\n';
document.getElementById('btnSend').disabled = true;
document.getElementById('btnClose').disabled = true;
document.getElementById('btnConnect').disabled = false;
};
connection.onmessage = function (evt) {
var text = evt.data;
document.getElementById('txtRecv').value += text + '\r\n';
};
}
function closeWebSocket() {
if (connection != null) {
connection.close();
connection = null;
}
}
function sendWebSocket() {
if (connection != null) {
var text = document.getElementById('txtSend').value;
connection.send(text);
}
}
package jp.seraphyware.example.java8learn.websocket;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* ソケットを直接使って、簡単なHTTP/1.1と、WebSocketのサーバーの実装例。
*/
public class WebSocketExample {
/**
* HTTPヘッダの区切りの判定用
*/
private static final byte[] headerDelimitter = "\r\n\r\n".getBytes();
/**
* サーバー終了フラグ
*/
private static AtomicBoolean shutdownFlag = new AtomicBoolean();
/**
* エントリ
* @param args
* @throws Exception
*/
public static void main(String[] args) throws Exception {
shutdownFlag.set(false);
// 現在接続中のクライアントのリスト。
// サーバーの終了時に強制的に切断するため
ConcurrentLinkedDeque<Socket> clientSockets = new ConcurrentLinkedDeque<>();
// サーバーの開始
try (ServerSocket server = new ServerSocket(80)) {
// クライアントの接続数のカウント用
AtomicInteger clientCounter = new AtomicInteger();
AtomicInteger aliveThread = new AtomicInteger();
// サーバーの終了フラグが立つまでクライアントの待ち受けをループする
while (!shutdownFlag.get()) {
// クライアントからの接続を待機する。
// ※ ServerSocketをcloseすると、ここで例外が発生して待ち受けを終える
Socket client = server.accept();
// クライアントごとにスレッドを作成して処理する
new Thread(() -> {
clientSockets.add(client);
String clientNo = "(" + clientCounter.incrementAndGet() + ":" + aliveThread.incrementAndGet() + ")";
System.out.println("*connected " + clientNo);
try {
InputStream in = client.getInputStream();
OutputStream out = client.getOutputStream();
// KeepAlive(HTTP/1.1のデフォルト)の場合は接続を切らずに使い回す
for (;;) {
// クライアントからヘッダ部を読み込むか、終端に達するまで読み取る
String headerStr = readHeader(in);
if (headerStr.length() == 0) {
// ヘッダ部が空 = クライアントからの接続が切れて即時終端の場合
System.out.println("*client end " + clientNo);
break;
}
System.out.println("[" + clientNo + "] header:" + headerStr);
Map<String, String> header = parseHeader(headerStr);
if (!dispatch(header, in, out)) {
// Keep-Aliveでない場合はコネクションをcloseする
break;
}
}
clientSockets.remove(client);
client.close();
System.out.println("*closed " + clientNo);
if (shutdownFlag.get()) {
// サーバソケットを閉じて待ち受けを終了する
// (ただし、現在すでに接続されているKeep-Aliveの接続は維持されている)
server.close();
}
} catch (Exception e) {
e.printStackTrace();
}
System.out.println("残存スレッド数: " + aliveThread.decrementAndGet());
}).start();
}
} catch (IOException ex) {
ex.printStackTrace();
}
System.out.println("*server closed");
// 残存するクライアントを閉じる
clientSockets.forEach(client -> {
try {
client.close();
} catch (IOException ex) {
ex.printStackTrace();
}
});
}
/**
* ヘッダの文字列をマップに変換する。
* ステータス行は、METHOD, PATH, PROTOCOLの3つに分解される。
* @param header
* @return
*/
private static Map<String, String> parseHeader(String header) {
Map<String, String> headerKeyValue = new LinkedHashMap<>();
String[] lines = header.split("\r\n");
String firstLine = lines.length > 0 ? lines[0] : "";
String[] tokens = firstLine.trim().split(" ");
String method = tokens[0];
String path = tokens[1];
String protocol = tokens[2];
headerKeyValue.put("METHOD", method);
headerKeyValue.put("PATH", path);
headerKeyValue.put("PROTOCOL", protocol);
for (int idx = 1; idx < lines.length; idx++) {
String line = lines[idx].trim();
if (line.length() == 0) {
continue;
}
int pos = line.indexOf(":");
if (pos > 0) {
String key = line.substring(0, pos).trim();
String val = line.substring(pos + 1).trim();
headerKeyValue.put(key.toUpperCase(), val);
} else {
System.err.println("*unknown: " + line);
}
}
return headerKeyValue;
}
/**
* 入力ストリームからHTTPヘッダを読み取る。
* @param is
* @return
* @throws IOException
*/
private static String readHeader(InputStream is) throws IOException {
int idx = 0;
ByteArrayOutputStream bos = new ByteArrayOutputStream();
int ch;
while ((ch = is.read()) >= 0) {
bos.write(ch);
if (ch == headerDelimitter[idx]) {
idx++;
if (idx == headerDelimitter.length) {
break;
}
} else {
idx = 0;
}
}
return new String(bos.toByteArray(), StandardCharsets.UTF_8);
}
/**
* HTTPヘッダから、Keep-Aliveであるか判定する。
* HTTP/1.0の場合でConnectionがKeep-Aliveでない場合、
* もしくは、HTTP/1.1でConnectionがcloseの場合のみ、Keep-Aliveでないと判定される。
* @param reqHeaders
* @return
*/
private static boolean isKeepAlive(Map<String, String> reqHeaders) {
String protocol = reqHeaders.get("PROTOCOL");
boolean keepAlive;
if (protocol.equalsIgnoreCase("HTTP/1.0")) {
keepAlive = false;
} else {
// http/1.1はデフォルトでkeep-alive
keepAlive = true;
}
String connection = reqHeaders.get("CONNECTION");
if (connection != null) {
if (connection.equalsIgnoreCase("close")) {
// 明示的にcloseの指定がある場合
keepAlive = false;
} else {
// http/1.0ではclose以外を指定した場合にkeep-alive
// http/1.1では互換性のために"keep-alive"を送信するが、なくてもkeep-aliveである
keepAlive = true;
}
}
return keepAlive;
}
/**
* リクエストの処理を振り分ける
* @param reqHeader
* @param is
* @param os
* @return
* @throws IOException
*/
private static boolean dispatch(Map<String, String> reqHeader, InputStream is, OutputStream os) throws IOException {
String path = reqHeader.get("PATH");
// /shutdownへのアクセスがあったらサーバを終了する
if (path.startsWith("/shutdown")) {
return handleShutdown(reqHeader, os);
}
// websocketへのアクセスがあればwebsocketのコネクションにアップグレードする
if (path.startsWith("/websocket")) {
return handleWebsocket(reqHeader, is, os);
}
// それ以外はリソースを返す (なければ404)
return sendResource(reqHeader, os);
}
/**
* shutdownの要求があった場合にサーバーを停止させるための処理。
* @param reqHeaders
* @param os
* @return
* @throws IOException
*/
private static boolean handleShutdown(Map<String, String> reqHeaders, OutputStream os) throws IOException {
String body = "Shutdown::" + reqHeaders.get("PATH") + "::" + new Date();
String protocol = reqHeaders.get("PROTOCOL");
String header = protocol + " 200 OK\r\nContent-Type: text/plain\r\nContent-Length: " + body.length() + "\r\n";
header += "Connection: close\r\n";
os.write((header + "\r\n").getBytes(StandardCharsets.UTF_8));
os.write(body.getBytes(StandardCharsets.UTF_8));
shutdownFlag.set(true);
return false;
}
/**
* WebSocketのSec-WebSocket-Acceptを生成するためのマジックナンバー(RFCの定義)
* https://developer.mozilla.org/ja/docs/WebSockets-840092-dup/Writing_WebSocket_servers
*/
private static final String WEBSOCKET_ACCEPT_MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/**
* ウェブソケットにアップグレードして、エコーを処理する。
* @param reqHeaders
* @param is
* @param os
* @return
* @throws IOException
*/
private static boolean handleWebsocket(Map<String, String> reqHeaders, InputStream is, OutputStream os)
throws IOException {
String protocol = reqHeaders.get("PROTOCOL");
String connection = reqHeaders.get("CONNECTION");
String upgrade = reqHeaders.get("UPGRADE");
String webSocketKey = reqHeaders.get("SEC-WEBSOCKET-KEY");
if (upgrade == null || connection == null || !upgrade.equalsIgnoreCase("websocket")
|| !connection.equalsIgnoreCase("Upgrade") || webSocketKey == null || webSocketKey.length() == 0) {
// websocketのupgrade要求でない場合
String header = protocol + " 403 Forbidden\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
os.write(header.getBytes());
return false;
}
String rawKey = webSocketKey + WEBSOCKET_ACCEPT_MAGIC;
MessageDigest sha1;
try {
sha1 = MessageDigest.getInstance("sha1");
} catch (NoSuchAlgorithmException ex) {
throw new RuntimeException(ex);
}
byte[] hashedKey = sha1.digest(rawKey.getBytes());
String secWebSocketAccept = Base64.getEncoder().encodeToString(hashedKey);
System.out.println("secWebSocketAccept=" + secWebSocketAccept);
StringBuilder buf = new StringBuilder();
buf.append(protocol + " 101 Switching Protocols\r\n");
buf.append("Upgrade: websocket\r\n");
buf.append("Connection: Upgrade\r\n");
buf.append("Sec-WebSocket-Accept: " + secWebSocketAccept + "\r\n");
//buf.append("Sec-WebSocket-Protocol: chat\r\n");
buf.append("\r\n");
System.out.println(buf.toString());
os.write(buf.toString().getBytes());
// クライアントからのエコーループに入る
// クライアントからquitの送信、またはクローズ通知、もしくはソケットが閉じられた場合にループを終了する。
recvLoop(is, os);
// 接続終了
return false;
}
/**
* クライアントからメッセージを受信し、それをエコーとして返す。
* quitを受信するか、クライアントから終了通知があるか、ソケットが閉じられた場合に終了する。
* 終了する際にはクライアントにも終了通知を出す。ただし、ピアの強制終了通知(1001 going away)か、
* クライアントソケット終了の場合は、クライアントには終了通知を送信しない。
* @param is
* @param os
* @throws IOException
*/
private static void recvLoop(InputStream is, OutputStream os) throws IOException {
// https://developer.mozilla.org/ja/docs/WebSockets-840092-dup/Writing_WebSocket_servers
// https://triple-underscore.github.io/RFC6455-ja.html#section-11.3.4
for (;;) {
ByteArrayOutputStream recvBuf = new ByteArrayOutputStream();
int firstOpe = 0;
int recvError = 0;
for (;;) {
int b1 = is.read();
if (b1 < 0) {
// クライアント側から強制切断、もしくはネットワークが何らかの原因でクローズされた場合
// (なお、ブラウザやページをユーザが閉じた場合はソケットが閉じる前に1001クローズ送信が送信される。)
System.out.println("★ disconnected websocket.");
recvError = 1001; // Going Away (ピアの切断済み)
break;
}
b1 &= 0xff; // 8bitに整形する
System.out.println("*recv");
System.out.println("b1=" + Integer.toHexString(b1));
boolean fin = ((b1 >> 7) & 0xff) != 0;
int ope = b1 & 0x0f;
System.out.println(" fin=" + fin + ", ope=" + ope);
// 最初のopeコードを記憶する。
// finフラグが立っておらず継続している場合はopeコードを上書きしない。
// (継続fin != falseの場合はope == 0となるはず)
if (firstOpe == 0 && ope != 0) {
firstOpe = ope;
}
int b2 = is.read() & 0xff; // 8bitに整形する
System.out.println("b2=" + Integer.toHexString(b2));
boolean mask = ((b2 >> 7) & 0xff) != 0;
int len = b2 & 0x7f; // 125バイトまでの長さの判定
if (len == 126 || len == 127) {
DataInputStream dis = new DataInputStream(is);
if (len == 126) {
// 継続の16bitが真の長さ
len = dis.readShort();
} else if (len == 127) {
// 後続の64bitが真の長さ
len = (int) dis.readLong();
}
System.out.println(" long-len=" + len);
}
// マスクビットが立っている場合、後続の4バイトはxorするためのマスク値
// (クライアントから送信される場合は必ずランダム値でマスクされる。)
// (これはプロキシなどでキャッシュが返されないようにするための措置)
byte[] maskingKey = new byte[4];
if (mask) {
is.read(maskingKey);
}
// ※ クライアントから送信されたデータにmaskフラグがない場合は、本来は
// サーバ側はclose operationをエラーコード1002のボディとともに送信して閉じる必要がある。
// https://tools.ietf.org/html/rfc6455#section-5.1
// https://developer.mozilla.org/ja/docs/Web/API/CloseEvent ← エラー一覧
if (!mask) {
recvError = 1002;
}
System.out.println(" mask=" + mask + ", len=" + len);
// データ部の受信
byte[] recv = new byte[len];
is.read(recv);
// マスク値でxorをとる。(マスクがない場合は、そのまま)
for (int idx = 0; idx < len; idx++) {
recv[idx] = (byte) (recv[idx] ^ maskingKey[idx % 4]);
}
// 復号された受信データをfinがくるまで貯めておく
recvBuf.write(recv);
if (fin) {
// finの場合はメッセージ完了
break;
}
}
// opeコードが8の場合はクローズ通知である
if (firstOpe == 8) {
// close opecode
byte[] bytes = recvBuf.toByteArray();
String code;
if (bytes.length < 2) {
// JavaScriptクライアントから明示的にcloseされた場合などはコードがない
recvError = 1000; // Normal Closure
code = Integer.toString(recvError);
} else {
// コードが2バイト以上の場合
// (Closeフレームはコードは2バイトで、追加の情報が付与されていることもありえる)
recvError = ((bytes[1] & 0xff) | (bytes[0] & 0xff) << 8) & 0xffff;
code = Integer.toString(recvError);
code += "@" + IntStream.range(0, bytes.length)
.mapToObj(idx -> Integer.toHexString(bytes[idx] & 0xff)).collect(Collectors.joining(":"));
}
System.out.println("☆ websocket client close operation. code=" + code);
}
// エラーがない場合は受信したテキストを表示して、
// そのままクライアントにエコーバックする
if (recvError == 0) {
// 受信したテキスト
String msg = new String(recvBuf.toByteArray(), StandardCharsets.UTF_8);
System.out.println(">>" + msg);
// エラーでない場合
// 受信したテキストをエコー送信する
sendWebsocket(os, msg);
// quitメッセージを受信したらノーマルエンドのCloseを送信して接続を閉じる
if ("quit".equals(msg)) {
recvError = 1000; // Normal Closure
}
}
if (recvError != 0) {
// エラーが発生している場合、もしくはクライアントがCloseした場合、もしくは終了要求の場合、
// サーバからもCloseを送信して接続を閉じる
if (recvError != 1001) {
// ただし、1001 Going Away の場合は、すでに接続が切られているのでCloseの送信はしない。
// (クライアントのソケットがクローズされている場合は1001と同じように扱う)
sendCloseOpe(os, (short) recvError);
}
// ループ終了
break;
}
}
}
/**
* クローズフレームを送信する。
* ただし、送信に失敗しても例外は返さない。(終了通知なので、このあと送受信することはない。)
*
* 定義済みエラーコードは以下のとおり。
* https://developer.mozilla.org/ja/docs/Web/API/CloseEvent
* @param os
* @param reason エラーコード
*/
private static void sendCloseOpe(OutputStream os, short reason) {
try {
DataOutputStream dos = new DataOutputStream(os);
int s1 = 0x88; // FIN, CLOSE
dos.write(s1);
int s2 = 2 & 0x7f; // NO MASK + 2Bytes
dos.write(s2);
dos.writeShort(reason);
} catch (IOException ex) {
ex.printStackTrace();
}
}
private static void sendWebsocket(OutputStream os, String msg) throws IOException {
byte[] data = msg.getBytes(StandardCharsets.UTF_8);
int len = data.length;
DataOutputStream dos = new DataOutputStream(os);
int s1 = 0x81; // FIN, TEXT
dos.write(s1);
if (len <= 125) {
// 125以下の場合
int s2 = len & 0x7f; // NO MASK
dos.write(s2);
} else if (len <= 65535) {
int s2 = 126 & 0x7f; // NO MASK
dos.write(s2);
dos.writeShort(len);
} else {
int s2 = 127 & 0x7f; // NO MASK
dos.write(s2);
dos.writeLong(len);
}
dos.write(data);
}
/**
* リソースを返す。なければ404
* @param reqHeaders
* @param os
* @return
* @throws IOException
*/
private static boolean sendResource(Map<String, String> reqHeaders, OutputStream os) throws IOException {
String protocol = reqHeaders.get("PROTOCOL");
// パスの取得
String path = reqHeaders.get("PATH");
if (path != null) {
path = path.substring(1); // 先頭の/は消す(手抜き)
}
if ("".equals(path)) {
// 空はindex.htmlと見なす
path = "index.html";
}
// クエリとパスの分離
int qpos = path.indexOf("?");
String query;
if (qpos >= 0) {
query = path.substring(qpos + 1);
path = path.substring(0, qpos);
} else {
query = "";
}
System.out.println("request path: " + path);
System.out.println("request query: " + query);
String header;
byte[] body;
if (path == null || path.contains("/") || path.contains("\\")) {
// 手抜きのため、トップディレクトリ以外の探索はしない
body = ("403 Forbidden: " + path).getBytes(StandardCharsets.UTF_8);
header = protocol + " 403 Forbidden\r\nContent-Type: text/plain\r\nContent-Length: " + body.length + "\r\n";
} else {
// このクラスと同じパッケージ上のリソースを探索する
ByteArrayOutputStream tmp = null;
try (InputStream is = WebSocketExample.class.getResourceAsStream(path)) {
if (is != null) {
tmp = new ByteArrayOutputStream();
byte[] buf = new byte[4096];
for (;;) {
int rd = is.read(buf);
if (rd < 0) {
break;
}
tmp.write(buf, 0, rd);
}
}
}
if (tmp == null) {
body = ("404 Not Found: " + path).getBytes(StandardCharsets.UTF_8);
header = protocol + " 404 Not Found\r\nContent-Type: text/plain\r\nContent-Length: " + body.length
+ "\r\n";
} else {
body = tmp.toByteArray();
String contentType;
if (path.endsWith(".html") || path.endsWith(".htm")) {
contentType = "text/html";
} else if (path.endsWith(".js")) {
contentType = "text/javascript";
} else if (path.endsWith(".png")) {
contentType = "image/png";
} else if (path.endsWith(".jpeg") || path.endsWith(".jpg")) {
contentType = "image/jpeg";
} else {
contentType = "text/plain";
}
header = protocol + " 200 OK\r\nContent-Type: " + contentType + "\r\nContent-Length: " + body.length
+ "\r\n";
}
}
boolean keepAlive = isKeepAlive(reqHeaders);
if (!keepAlive) {
header += "Connection: close\r\n";
}
os.write((header + "\r\n").getBytes(StandardCharsets.UTF_8));
os.write(body);
return keepAlive;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment