|
/* |
|
* Copyright 2016 Xavier Mendez |
|
* All rights reserved. |
|
* |
|
* Redistribution and use in source and binary forms, with or without |
|
* modification, are permitted provided that the following conditions |
|
* are met: |
|
* 1. Redistributions of source code must retain the above copyright |
|
* notice, this list of conditions and the following disclaimer. |
|
* 2. Redistributions in binary form must reproduce the above copyright |
|
* notice, this list of conditions and the following disclaimer in the |
|
* documentation and/or other materials provided with the distribution. |
|
* |
|
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR |
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES |
|
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. |
|
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, |
|
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT |
|
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF |
|
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
*/ |
|
|
|
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <poll.h>
#include <strings.h>
#include <unistd.h>
|
|
|
/* Returns the numeric value (0-15) of the hexadecimal digit x.
   x must satisfy isxdigit(); other characters give meaningless results.
   In ASCII, adding 9 to 'A'-'F' or 'a'-'f' leaves the digit value in the
   low nibble, while '0'-'9' already have it there. */
char xdigitvalue(char x) {
    return (char) ((x > '9' ? x + 9 : x) & 0xF);
}

/* Decodes %XX escapes in the NUL-terminated string p, in place.
   A '%' that is not followed by two hex digits is copied through unchanged.
   "%00" is deliberately left encoded so decoding can never truncate the
   string. Each escape is decoded once: "%2541" becomes "%41".
   The result is never longer than the input. */
void percent_decode(char *p) {
    const char *src = p;
    char *dst = p;

    while (*src) {
        /* The ctype functions need a value representable as unsigned char. */
        if (src[0] == '%' && isxdigit((unsigned char) src[1])
                && isxdigit((unsigned char) src[2])) {
            char c = (char) ((xdigitvalue(src[1]) << 4) | xdigitvalue(src[2]));
            if (c != '\0') {
                *dst++ = c;
                src += 3;
                continue;
            }
        }
        *dst++ = *src++;
    }
    *dst = '\0';
}
|
|
|
/* Components of a URI as split by scan_uri(). Every string points into the
   buffer passed to scan_uri(); absent components stay NULL, and port stays
   0 when the URI has no (or an empty) port. */
struct uri_parts {
    char *scheme;    /* text before the first ':' (e.g. "socks5") */
    char *auth;      /* userinfo before '@' (e.g. "user:password") */
    char *host;      /* host, with auth and port removed */
    int port;        /* 1-65535, or 0 if not given */
    char *path;      /* text after the first '/' following the host */
    char *fragment;  /* text after the first '#' */
};

/* Scans the URI in name, sets each part if present.
   Expected shape: scheme "://" [auth "@"] host [":" port] ["/" path] ["#" fragment]
   Anything after the first space is discarded. Whitespace is removed from
   the text preceding the scheme's ':' (or the first '/' or '?').
   Note: this will destroy `name` contents; the pointers stored in *parts
   point into it.
   Note: host is not validated, and a "?query" is not split off.
   Note: fields are not percent-decoded.
   Returns 0 on success, or -1 if there is no "//" authority, the scheme or
   host is empty, or the port is outside 1-65535. */
int scan_uri(char *name, struct uri_parts *parts) {
    char *p;

    memset(parts, 0, sizeof *parts);

    if ((p = strchr(name, '#')) != NULL) {
        *p++ = '\0';
        parts->fragment = p;
    }

    if ((p = strchr(name, ' ')) != NULL)
        *p = '\0';

    /* Find the scheme. Whitespace is squeezed out in place; after a removal
       the same position is examined again, so we never step before `name`. */
    p = name;
    while (*p) {
        if (isspace((unsigned char) *p)) {
            memmove(p, p + 1, strlen(p + 1) + 1);
            continue;
        }
        if (*p == '/' || *p == '#' || *p == '?')
            break;
        if (*p == ':') {
            *p = '\0';
            parts->scheme = name;
            name = p + 1;
            break;
        }
        p++;
    }

    p = name;
    if (p[0] != '/' || p[1] != '/')
        return -1;
    parts->host = p + 2;
    *p = '\0';
    if ((p = strchr(parts->host, '/')) != NULL) {
        *p = '\0';
        parts->path = p + 1;
    }

    /* A trailing ":digits" is the port. Walk back over the digits; `digits`
       ends up at the first digit (or the end if there are none). */
    char *digits = parts->host + strlen(parts->host);
    while (digits > parts->host && isdigit((unsigned char) digits[-1]))
        digits--;
    if (digits > parts->host && digits[-1] == ':') {
        digits[-1] = '\0';
        if (*digits) {
            /* strtol reports overflow (ERANGE) instead of invoking UB like atoi. */
            errno = 0;
            long value = strtol(digits, NULL, 10);
            if (errno != 0 || value < 1 || value > 65535)
                return -1;
            parts->port = (int) value;
        }
    }

    /* Userinfo characters up to an '@' form the auth part. The explicit
       '\0' check matters: strchr() would match the terminator of the set
       and let the scan run past the end of the host string. */
    p = parts->host;
    while (isalnum((unsigned char) *p)
            || (*p != '\0' && strchr("-._~!$&'()*+,;=%:", *p) != NULL))
        p++;
    if (*p == '@') {
        *p++ = '\0';
        parts->auth = parts->host;
        parts->host = p;
    }

    if (!(parts->scheme && *parts->scheme && *parts->host))
        return -1;
    return 0;
}
|
|
|
/* Resolves hostname and opens a TCP connection to it on the given port.
   Every address getaddrinfo() returns is tried in order until one accepts.
   Returns the connected socket. Returns -1 after printing a diagnostic to
   stderr if the name cannot be resolved or no address can be connected. */
int connect_to(const char *hostname, int port) {
    struct addrinfo hints;
    struct addrinfo *result, *rp;
    int r, sock = -1, last_errno = 0;
    char service[8];

    assert(port > 0 && port < 65536);
    r = snprintf(service, sizeof service, "%d", port);
    assert(r > 0 && (size_t) r < sizeof service);

    memset(&hints, 0, sizeof hints);
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;

    r = getaddrinfo(hostname, service, &hints, &result);
    if (r != 0) {
        fprintf(stderr, "ssh: couldn't resolve %s: %s\n", hostname, gai_strerror(r));
        /* Report failure like the connect path so the caller can clean up. */
        return -1;
    }

    for (rp = result; rp != NULL; rp = rp->ai_next) {
        sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
        if (sock == -1) {
            /* Remember why, so the final message is meaningful even if
               no socket could be created at all. */
            last_errno = errno;
            continue;
        }
        if (connect(sock, rp->ai_addr, rp->ai_addrlen) == 0)
            break;
        last_errno = errno;
        close(sock);
        sock = -1;
    }

    freeaddrinfo(result);

    if (sock == -1) {
        fprintf(stderr, "ssh: couldn't connect to %s port %d: %s\n", hostname, port, strerror(last_errno));
        return -1;
    }
    return sock;
}
|
|
|
/* Sends file descriptor fd over the UNIX-domain socket sock as SCM_RIGHTS
   ancillary data, accompanied by a single NUL byte of regular data.
   Retries while sendmsg() fails with EAGAIN or EINTR, waiting with poll()
   for the socket to become writable.
   Returns 0 on success, -1 on failure (a diagnostic is printed to stderr).
   Borrowed from ssh source code. */
int mm_send_fd(int sock, int fd) {
    struct msghdr msg;
    union {
        struct cmsghdr hdr;   /* gives buf the alignment a cmsghdr needs */
        char buf[CMSG_SPACE(sizeof(int))];
    } cmsgbuf;
    struct cmsghdr *cmsg;
    struct iovec vec;
    char ch = '\0';
    ssize_t n;
    struct pollfd pfd;

    memset(&msg, 0, sizeof(msg));
    memset(&cmsgbuf, 0, sizeof(cmsgbuf));
    msg.msg_control = cmsgbuf.buf;
    msg.msg_controllen = sizeof(cmsgbuf.buf);
    cmsg = CMSG_FIRSTHDR(&msg);
    cmsg->cmsg_len = CMSG_LEN(sizeof(int));
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SCM_RIGHTS;
    /* memcpy avoids assuming CMSG_DATA() is suitably aligned for an int. */
    memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

    vec.iov_base = &ch;
    vec.iov_len = 1;
    msg.msg_iov = &vec;
    msg.msg_iovlen = 1;

    pfd.fd = sock;
    pfd.events = POLLOUT;
    while ((n = sendmsg(sock, &msg, 0)) == -1 &&
           (errno == EAGAIN || errno == EINTR)) {
        (void)poll(&pfd, 1, -1);
    }
    if (n == -1) {
        fprintf(stderr, "%s: sendmsg(%d): %s\n", __func__, fd, strerror(errno));
        return -1;
    }

    if (n != 1) {
        fprintf(stderr, "%s: sendmsg: expected sent 1 got %zd\n", __func__, n);
        return -1;
    }
    return 0;
}
|
|
|
/* Base64-encodes len bytes of src into dst (RFC 4648 alphabet, '=' padding).
   dst must have room for 4 * ((len + 2) / 3) + 1 bytes and is NUL-terminated. */
static void http_base64_encode(const unsigned char *src, size_t len, char *dst) {
    static const char alphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    size_t i = 0;

    for (; i + 3 <= len; i += 3) {
        unsigned long v = ((unsigned long) src[i] << 16)
                        | ((unsigned long) src[i + 1] << 8)
                        | (unsigned long) src[i + 2];
        *dst++ = alphabet[(v >> 18) & 0x3F];
        *dst++ = alphabet[(v >> 12) & 0x3F];
        *dst++ = alphabet[(v >> 6) & 0x3F];
        *dst++ = alphabet[v & 0x3F];
    }
    if (i < len) {
        unsigned long v = (unsigned long) src[i] << 16;
        if (i + 1 < len)
            v |= (unsigned long) src[i + 1] << 8;
        *dst++ = alphabet[(v >> 18) & 0x3F];
        *dst++ = alphabet[(v >> 12) & 0x3F];
        *dst++ = (i + 1 < len) ? alphabet[(v >> 6) & 0x3F] : '=';
        *dst++ = '=';
    }
    *dst = '\0';
}

/* Writes all len bytes of buf to fd, retrying on EINTR and short writes.
   Returns 0 on success, -1 on error (errno set by write()). */
static int http_write_all(int fd, const char *buf, size_t len) {
    while (len > 0) {
        ssize_t n = write(fd, buf, len);
        if (n < 0 && errno == EINTR)
            continue;
        if (n <= 0)
            return -1;
        buf += n;
        len -= (size_t) n;
    }
    return 0;
}

/* Reads one byte from sock into *c, retrying on EINTR.
   Returns 0 on success, -1 on EOF or error. Reading byte by byte ensures
   nothing after the proxy's response head (tunneled data) is consumed. */
static int http_read_byte(int sock, char *c) {
    ssize_t n;

    do {
        n = read(sock, c, 1);
    } while (n < 0 && errno == EINTR);
    return n == 1 ? 0 : -1;
}

/* Opens a tunnel to hostname:port through an HTTP proxy using CONNECT.
   sock: socket already connected to the proxy.
   proxy_auth: "user:password" from the proxy URL, or NULL. When present it
     is sent as HTTP Basic credentials in a Proxy-Authorization header.
   Returns 0 when the proxy answered 200. The whole response head has then
   been consumed, so the next byte on sock is tunneled data. Returns 1 on
   any failure; a diagnostic is printed to stderr. */
int dial_http(int sock, const char *proxy_auth, const char *hostname, int port) {
    static const char status_prefix[] = "HTTP/";
    static const char request_fmt[] =
        "CONNECT %s%s%s:%d HTTP/1.1\r\n"
        "Host: %s%s%s:%d\r\n"
        "%s%s%s"
        "\r\n";
    const char *lbr = "", *rbr = "";
    const char *auth_prefix = "", *auth_value = "", *auth_suffix = "";
    char *credentials = NULL, *request;
    char code_str[4];
    int request_len, failed, line_empty;
    size_t i;
    char c;

    assert(port > 0 && port < 65536);

    /* hostname is copied verbatim into the request line and Host header;
       refuse anything that would break the request or inject headers. */
    if (*hostname == '\0' || strpbrk(hostname, " \t\r\n") != NULL) {
        fprintf(stderr, "ssh: invalid hostname for HTTP proxy\n");
        return 1;
    }
    /* IPv6 literals must be bracketed in the host:port authority form. */
    if (strchr(hostname, ':') != NULL && hostname[0] != '[') {
        lbr = "[";
        rbr = "]";
    }

    if (proxy_auth) {
        size_t auth_len = strlen(proxy_auth);
        if ((credentials = malloc(4 * ((auth_len + 2) / 3) + 1)) == NULL) abort();
        http_base64_encode((const unsigned char *) proxy_auth, auth_len, credentials);
        auth_prefix = "Proxy-Authorization: Basic ";
        auth_value = credentials;
        auth_suffix = "\r\n";
    }

    request_len = snprintf(NULL, 0, request_fmt, lbr, hostname, rbr, port,
                           lbr, hostname, rbr, port,
                           auth_prefix, auth_value, auth_suffix);
    assert(request_len >= 0);
    if ((request = malloc((size_t) request_len + 1)) == NULL) abort();
    snprintf(request, (size_t) request_len + 1, request_fmt, lbr, hostname, rbr, port,
             lbr, hostname, rbr, port, auth_prefix, auth_value, auth_suffix);

    failed = http_write_all(sock, request, (size_t) request_len) != 0;
    if (failed)
        fprintf(stderr, "ssh: couldn't send request to proxy: %s\n", strerror(errno));
    free(request);
    free(credentials);
    if (failed)
        return 1;

    /* Status line: "HTTP/" version SP 3DIGIT (SP reason) CRLF */
    for (i = 0; status_prefix[i] != '\0'; i++) {
        if (http_read_byte(sock, &c) != 0 || c != status_prefix[i])
            goto malformed;
    }
    do {    /* skip the protocol version, e.g. "1.1" */
        if (http_read_byte(sock, &c) != 0 || c == '\r' || c == '\n')
            goto malformed;
    } while (c != ' ');
    for (i = 0; i < 3; i++) {
        if (http_read_byte(sock, &c) != 0 || !isdigit((unsigned char) c))
            goto malformed;
        code_str[i] = c;
    }
    code_str[3] = '\0';
    if (http_read_byte(sock, &c) != 0 || (c != ' ' && c != '\r'))
        goto malformed;

    /* Skip the rest of the status line and every header line, up to and
       including the empty line ending the head. On entry c holds the byte
       after the status code; lines must end in CRLF. */
    line_empty = 0;
    for (;;) {
        while (c != '\r') {
            line_empty = 0;
            if (c == '\n' || http_read_byte(sock, &c) != 0)
                goto malformed;
        }
        if (http_read_byte(sock, &c) != 0 || c != '\n')
            goto malformed;
        if (line_empty)
            break;
        if (http_read_byte(sock, &c) != 0)
            goto malformed;
        line_empty = 1;
    }

    if (strcmp(code_str, "200") != 0) {
        fprintf(stderr, "ssh: proxy rejected connection (%s response)\n", code_str);
        return 1;
    }
    return 0;

malformed:
    fprintf(stderr, "ssh: malformed response\n");
    return 1;
}
|
|
|
/* Opens a tunnel to hostname:port through a SOCKS4 proxy.
   A dotted IPv4 address is sent directly (plain SOCKS4). Any other name is
   sent after the request with the placeholder address 0.0.0.1, asking the
   proxy to resolve it (SOCKS4a). The USERID field is always empty, so
   proxy_auth only triggers a warning.
   Returns 0 if the proxy granted the request, 1 otherwise (diagnostic on
   stderr). */
int dial_socks4(int sock, const char *proxy_auth, const char *hostname, int port) {
    in_addr_t dest_ip = inet_addr(hostname);
    uint16_t dest_port = htons(port);
    const char *remote_name = hostname;
    char request[9], answer[8];
    struct iovec parts[2];
    int nparts = 1;

    if (dest_ip != INADDR_NONE)
        remote_name = NULL;                  /* numeric: no name to send */
    else
        dest_ip = inet_addr("0.0.0.1");      /* SOCKS4a placeholder */

    if (proxy_auth != NULL)
        fprintf(stderr, "ssh: warning: authentication not supported with SOCKS4\n");

    request[0] = 4;                          /* protocol version */
    request[1] = 1;                          /* command: CONNECT */
    memcpy(request + 2, &dest_port, sizeof dest_port);
    memcpy(request + 4, &dest_ip, sizeof dest_ip);
    request[8] = '\0';                       /* empty USERID */

    parts[0].iov_base = request;
    parts[0].iov_len = sizeof request;
    if (remote_name != NULL) {
        /* The name goes out NUL-terminated, right after the USERID. */
        parts[1].iov_base = (char *) remote_name;
        parts[1].iov_len = strlen(remote_name) + 1;
        nparts = 2;
    }

    if (writev(sock, parts, nparts) < 0) {
        fprintf(stderr, "ssh: couldn't send request to proxy: %s\n", strerror(errno));
        return 1;
    }

    /* Reply: VN (must be 0), CD (0x5A granted, 0x5B-0x5D failures),
       then 6 bytes that are ignored. */
    if (read(sock, answer, sizeof answer) != (ssize_t) sizeof answer
            || answer[0] != 0 || answer[1] < 0x5A || answer[1] > 0x5D) {
        fprintf(stderr, "ssh: malformed response\n");
        return 1;
    }
    if (answer[1] != 0x5A) {
        fprintf(stderr, "ssh: connection rejected by proxy (0x%02X)\n", answer[1]);
        return 1;
    }
    return 0;
}
|
|
|
/* Reads exactly len bytes from sock into buf, retrying on EINTR and short
   reads. Returns 0 on success, -1 on EOF or error. */
static int socks5_read_exact(int sock, void *buf, size_t len) {
    unsigned char *dst = buf;

    while (len > 0) {
        ssize_t n = read(sock, dst, len);
        if (n < 0 && errno == EINTR)
            continue;
        if (n <= 0)
            return -1;
        dst += n;
        len -= (size_t) n;
    }
    return 0;
}

/* Writes exactly len bytes from buf to sock, retrying on EINTR and short
   writes. Returns 0 on success, -1 on error. */
static int socks5_write_exact(int sock, const void *buf, size_t len) {
    const unsigned char *src = buf;

    while (len > 0) {
        ssize_t n = write(sock, src, len);
        if (n < 0 && errno == EINTR)
            continue;
        if (n <= 0)
            return -1;
        src += n;
        len -= (size_t) n;
    }
    return 0;
}

/* Opens a tunnel to hostname:port through a SOCKS5 proxy.
   sock: socket already connected to the proxy.
   proxy_auth: "user:password" (the password part is optional), offered with
     username/password authentication (method 0x02); NULL offers no
     authentication (method 0x00).
   The destination is sent as an IPv4 address, an IPv6 address, or a domain
   name (at most 255 bytes) for the proxy to resolve.
   Returns 0 once the tunnel is up. The reply's bound address and port have
   then been consumed, so the next byte on sock is tunneled data. Returns 1
   on any failure; a diagnostic is printed to stderr. */
int dial_socks5(int sock, const char *proxy_auth, const char *hostname, int port) {
    static const char *const SOCKS5_MESSAGES[] = {
        "OK",
        "general failure",
        "connection not allowed by ruleset",
        "network unreachable",
        "host unreachable",
        "connection refused by destination host",
        "TTL expired",
        "command not supported / protocol error",
        "address type not supported",
    };

    /* Byte buffers are unsigned so lengths and codes >= 0x80 (e.g. the 0xFF
       "no acceptable methods" reply) compare correctly on every platform. */
    unsigned char greeting[3], server_greeting[2], auth_reply[2], reply[5];
    /* Longest auth message: VER, ULEN, 255-byte UNAME, PLEN, 255-byte PASSWD. */
    unsigned char auth_msg[1 + 1 + 255 + 1 + 255];
    /* Longest request: VER CMD RSV ATYP, length byte, 255-byte name, port. */
    unsigned char command[4 + 1 + 255 + 2];
    /* Rest of the reply's bound address (at most 255 bytes) plus its port. */
    unsigned char discard[255 + 2];
    size_t auth_len = 0, command_len, to_discard;
    in_addr_t ip = inet_addr(hostname);
    uint16_t nport = htons(port);
    unsigned char auth_method = proxy_auth ? 0x02 : 0x00;

    /* Build both messages before talking to the proxy, so over-long fields
       are rejected instead of being truncated on the wire. */
    if (proxy_auth) {
        const char *colon = strchr(proxy_auth, ':');
        const char *password = colon ? colon + 1 : "";
        size_t username_length = colon ? (size_t) (colon - proxy_auth) : strlen(proxy_auth);
        size_t password_length = strlen(password);

        if (username_length > 255 || password_length > 255) {
            fprintf(stderr, "ssh: proxy username or password too long for SOCKS5\n");
            return 1;
        }
        auth_msg[auth_len++] = 1;    /* sub-negotiation version */
        auth_msg[auth_len++] = (unsigned char) username_length;
        memcpy(auth_msg + auth_len, proxy_auth, username_length);
        auth_len += username_length;
        auth_msg[auth_len++] = (unsigned char) password_length;
        memcpy(auth_msg + auth_len, password, password_length);
        auth_len += password_length;
    }

    command[0] = 5;    /* version */
    command[1] = 1;    /* command: CONNECT */
    command[2] = 0;    /* reserved */
    if (ip != INADDR_NONE) {
        command[3] = 1;    /* address type: IPv4 */
        memcpy(command + 4, &ip, 4);
        command_len = 8;
    } else if (inet_pton(AF_INET6, hostname, command + 4) == 1) {
        command[3] = 4;    /* address type: IPv6 */
        command_len = 20;
    } else {
        size_t host_len = strlen(hostname);
        if (host_len == 0 || host_len > 255) {
            fprintf(stderr, "ssh: invalid hostname length for SOCKS5\n");
            return 1;
        }
        command[3] = 3;    /* address type: domain name */
        command[4] = (unsigned char) host_len;
        memcpy(command + 5, hostname, host_len);
        command_len = 5 + host_len;
    }
    memcpy(command + command_len, &nport, sizeof nport);
    command_len += sizeof nport;

    /* Send greeting */
    greeting[0] = 5;
    greeting[1] = 1;
    greeting[2] = auth_method;
    if (socks5_write_exact(sock, greeting, sizeof greeting) != 0) {
        fprintf(stderr, "ssh: couldn't send greeting to proxy\n");
        return 1;
    }

    /* Receive response */
    if (socks5_read_exact(sock, server_greeting, sizeof server_greeting) != 0
            || server_greeting[0] != 5
            || (server_greeting[1] != auth_method && server_greeting[1] != 0xFF))
        goto malformed;
    if (server_greeting[1] == 0xFF) {
        fprintf(stderr, "ssh: authentication method not supported by SOCKS5 proxy\n");
        return 1;
    }

    /* Authenticate */
    if (auth_method) {
        if (socks5_write_exact(sock, auth_msg, auth_len) != 0) {
            fprintf(stderr, "ssh: couldn't send authentication data to proxy\n");
            return 1;
        }
        if (socks5_read_exact(sock, auth_reply, sizeof auth_reply) != 0 || auth_reply[0] != 1)
            goto malformed;
        if (auth_reply[1] != 0) {
            fprintf(stderr, "ssh: proxy authentication failed\n");
            return 1;
        }
    }

    /* Send connect request */
    if (socks5_write_exact(sock, command, command_len) != 0) {
        fprintf(stderr, "ssh: couldn't send request to proxy\n");
        return 1;
    }

    /* Receive response: VER REP RSV ATYP plus the first byte of BND.ADDR. */
    if (socks5_read_exact(sock, reply, sizeof reply) != 0
            || reply[0] != 5 || reply[1] > 8 || reply[2] != 0)
        goto malformed;

    /* Skip the rest of BND.ADDR (its first byte is reply[4]) and BND.PORT. */
    switch (reply[3]) {
    case 1:                                  /* IPv4: 4 bytes */
        to_discard = 3 + 2;
        break;
    case 3:                                  /* domain: length byte, then name */
        to_discard = (size_t) reply[4] + 2;
        break;
    case 4:                                  /* IPv6: 16 bytes */
        to_discard = 15 + 2;
        break;
    default:
        goto malformed;
    }
    if (socks5_read_exact(sock, discard, to_discard) != 0)
        goto malformed;

    if (reply[1] != 0) {
        fprintf(stderr, "ssh: tunnel failed (0x%02X): %s\n", reply[1], SOCKS5_MESSAGES[reply[1]]);
        return 1;
    }
    return 0;

malformed:
    fprintf(stderr, "ssh: malformed response\n");
    return 1;
}
|
|
|
struct proxy_protocol { |
|
const char *scheme; |
|
int default_port; |
|
int (*dialer)(int sock, const char *proxy_auth, const char *hostname, int port); |
|
}; |
|
|
|
struct proxy_protocol proxy_protocols[] = { |
|
{ "http", 80, dial_http }, |
|
{ "socks4", 1080, dial_socks4 }, |
|
{ "socks", 1080, dial_socks5 }, |
|
{ "socks5", 1080, dial_socks5 }, |
|
{ NULL } |
|
}; |
|
|
|
const struct proxy_protocol* match_scheme(const char *scheme) { |
|
int i; |
|
for (i = 0; proxy_protocols[i].scheme; i++) |
|
if (strcasecmp(proxy_protocols[i].scheme, scheme) == 0) |
|
return &proxy_protocols[i]; |
|
return NULL; |
|
} |
|
|
|
int dial_proxy(const char *proxy_url, const char *hostname, const char *port_str) { |
|
int r, sock, port; |
|
char *end_ptr, *tmpbuffer; |
|
struct uri_parts proxy; |
|
const struct proxy_protocol *protocol; |
|
|
|
/* Parse port, validate hostname */ |
|
port = strtol(port_str, &end_ptr, 10); |
|
if (!(*port_str && !(*end_ptr) && port > 0 && port < 65536)) { |
|
fprintf(stderr, "Invalid port given\n"); |
|
return 1; |
|
} |
|
|
|
/* Parse proxy URL */ |
|
if (!proxy_url) { |
|
fprintf(stderr, "No proxy URL set, please check your ssh config\n"); |
|
return 1; |
|
} |
|
if ((tmpbuffer = strdup(proxy_url)) == NULL) abort(); |
|
if ((r = scan_uri(tmpbuffer, &proxy))) { |
|
fprintf(stderr, "ssh: Invalid proxy URL: \"%s\"\n", proxy_url); |
|
return 1; |
|
} |
|
|
|
/* Match scheme */ |
|
if ((protocol = match_scheme(proxy.scheme)) == NULL) { |
|
fprintf(stderr, "ssh: Unknown scheme \"%s\" in proxy URL: \"%s\"\n", proxy.scheme, proxy_url); |
|
return 1; |
|
} |
|
if (!proxy.port) |
|
proxy.port = protocol->default_port; |
|
|
|
/* Connect to proxy */ |
|
if ((sock = connect_to(proxy.host, proxy.port)) == -1) |
|
return 1; |
|
|
|
/* Make connection */ |
|
if ((r = protocol->dialer(sock, proxy.auth, hostname, port))) |
|
return 1; |
|
|
|
/* Pass connected FD back to ssh */ |
|
if ((r = mm_send_fd(1, sock))) |
|
return 1; |
|
|
|
free(tmpbuffer); |
|
return 0; |
|
} |
|
|
|
/* Entry point, meant to be invoked by ssh(1):
     test                exit status 0 if a proxy URL is set and non-empty,
                         1 otherwise
     dial <host> <port>  tunnel through the proxy and pass the socket to ssh
   The proxy URL is taken from $ssh_proxy, or from $all_proxy when
   $ssh_proxy is unset. Any other invocation prints usage and exits 1. */
int main(int argc, char **argv) {
    const char *proxy_url = getenv("ssh_proxy");

    if (proxy_url == NULL)
        proxy_url = getenv("all_proxy");

    switch (argc) {
    case 2:
        if (strcmp(argv[1], "test") == 0)
            return proxy_url == NULL || proxy_url[0] == '\0';
        break;
    case 4:
        if (strcmp(argv[1], "dial") == 0)
            return dial_proxy(proxy_url, argv[2], argv[3]);
        break;
    default:
        break;
    }

    fprintf(stderr, "Usage: %1$s test\n %1$s dial <host> <port>\n\n"
        "Helper program to be invoked by ssh(1) to open a proxied connection\n"
        "if $ssh_proxy or $all_proxy are set. To use, put something like:\n"
        "\n"
        " Match exec \"%1$s test\"\n"
        " ProxyCommand %1$s dial '%%h' '%%p'\n"
        " ProxyUseFdpass yes\n"
        "\n"
        "In your ssh config file.\n\n",
        argv[0]);
    return 1;
}