Skip to content

Instantly share code, notes, and snippets.

@CarsonSlovoka
Created August 13, 2024 11:16
Show Gist options
  • Save CarsonSlovoka/8aede9fa1d715d34cb048fb3ab9ec8e2 to your computer and use it in GitHub Desktop.
Save CarsonSlovoka/8aede9fa1d715d34cb048fb3ab9ec8e2 to your computer and use it in GitHub Desktop.
ssh server
package main
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/pem"
"fmt"
"golang.org/x/crypto/ssh"
"io"
"log"
"net"
"os"
"strings"
"time"
)
// User is an account the SSH server knows about.
type User struct {
	// Name is the login name. It is copied into the "user" extension of the
	// ssh.Permissions returned on successful authentication.
	Name string

	// IPublicKey is the one public key accepted for this user. It is nil
	// until a key has been registered for the account.
	IPublicKey ssh.PublicKey
}
// users is the in-memory account table, keyed by login name.
// Alice's public key is left at its zero value here; init fills it in
// before main starts accepting connections.
var users = map[string]*User{
	"alice": {Name: "alice"},
}
// init creates the demo user's key pair and registers it with the server.
//
// It generates a 4096-bit RSA key for "alice" and writes the OpenSSH-format
// private key to alice_rsa in the working directory (move it to
// ~/.ssh/alice_rsa to use it with `ssh -i`). It writes the
// authorized_keys-style public key to alice_rsa.pub. It then parses that
// public key line back, the way a server would load a key uploaded by a
// user, and stores it in users["alice"].
//
// Key generation and encoding failures are fatal, because without a key
// nobody can log in. File writes are best-effort and only logged.
func init() {
	privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		log.Fatal(err)
	}

	pemBlock, err := ssh.MarshalPrivateKey(crypto.PrivateKey(privateKey), "")
	if err != nil {
		// Without this check a nil block would make pem.EncodeToMemory panic.
		log.Fatal(err)
	}
	bsPrivateKeyPem := pem.EncodeToMemory(pemBlock)

	// OpenSSH refuses to use a private key file that group/others can access,
	// so the file must be 0600. os.WriteFile only applies perm when it creates
	// the file, so tighten the mode explicitly in case a looser file exists.
	if err := os.WriteFile("alice_rsa", bsPrivateKeyPem, 0600); err != nil {
		log.Printf("write alice_rsa: %v", err)
	} else if err = os.Chmod("alice_rsa", 0600); err != nil {
		log.Printf("chmod alice_rsa: %v", err)
	}

	sshPublicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
	if err != nil {
		log.Fatal(err)
	}
	// authorized_keys format: "<key type> <base64 wire-format key>".
	publicKeyStr := sshPublicKey.Type() + " " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal())

	// The .pub file is only written for convenient inspection; nothing reads it.
	if err := os.WriteFile("alice_rsa.pub", []byte(publicKeyStr), 0644); err != nil {
		log.Printf("write alice_rsa.pub: %v", err)
	}

	// Parse the line back, as a server would when loading a user's uploaded key.
	iPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKeyStr))
	if err != nil {
		log.Fatal(err)
	}
	users["alice"].IPublicKey = iPublicKey
	fmt.Printf("Alice's public key: %s\n", ssh.MarshalAuthorizedKey(sshPublicKey))
}
// main generates a fresh 2048-bit RSA host key and prints it for the demo:
// the PEM private key, the base64 public key blob, and the SHA256
// fingerprint. It then serves SSH on 127.0.0.1:2222, authenticating clients
// by public key against the users table.
func main() {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		log.Fatal(err)
	}

	// Print the host key pair. Demo only: never print a real server's private key.
	pemBlock, err := ssh.MarshalPrivateKey(crypto.PrivateKey(privateKey), "") // encode the private key as PEM
	if err != nil {
		log.Fatal(err)
	}
	bsPrivateKeyPem := pem.EncodeToMemory(pemBlock)
	fmt.Printf("server私鑰:\n%s\n", string(bsPrivateKeyPem))

	sshPublicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
	if err != nil {
		log.Fatal(err)
	}
	publicKeyStr := base64.StdEncoding.EncodeToString(sshPublicKey.Marshal())
	fmt.Printf("server公鑰:\n%s\n", publicKeyStr) // the server key blob that goes into the client's known_hosts

	// Fingerprint of the host key. It must use unpadded base64: StdEncoding
	// would pad with '=' up to a multiple of 4, which does not match the
	// fingerprint ssh clients display.
	sum := sha256.Sum256(sshPublicKey.Marshal())
	fingerprint := base64.RawStdEncoding.EncodeToString(sum[:])
	fmt.Printf("RSA key fingerprint is SHA256:%s\n", fingerprint)

	// Configure the SSH server.
	config := &ssh.ServerConfig{
		PublicKeyCallback: authenticatePublicKey,
	}
	// Build the host-key signer directly from the key generated above instead
	// of round-tripping it through the PEM encoding.
	signer, err := ssh.NewSignerFromKey(privateKey)
	if err != nil {
		log.Fatal(err)
	}
	config.AddHostKey(signer)

	// Start listening.
	listener, err := net.Listen("tcp", "127.0.0.1:2222")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("SSH server listening on port 2222...")
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Printf("Failed to accept connection: %v", err)
			// Back off briefly so a persistent error (e.g. running out of file
			// descriptors) does not turn this loop into a busy spin.
			time.Sleep(100 * time.Millisecond)
			continue
		}
		go handleConnection(conn, config)
	}
}

// authenticatePublicKey is the server's PublicKeyCallback. It accepts a
// connection only when the login name is in users and the offered key is
// byte-for-byte identical to that user's registered key. On success it
// records the user's name in the "user" permission extension.
func authenticatePublicKey(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
	// The offered key is the public key the client presents (e.g. alice's).
	log.Println(base64.StdEncoding.EncodeToString(key.Marshal()))

	user, ok := users[conn.User()]
	if !ok {
		return nil, fmt.Errorf("unknown user %q", conn.User())
	}
	// A user with no registered key must be rejected, not dereferenced.
	// Calling a method on a nil interface panics, and an unrecovered panic in
	// any goroutine terminates the whole server.
	if user.IPublicKey == nil ||
		user.IPublicKey.Type() != key.Type() || // key algorithm, e.g. ssh-rsa, ssh-ed25519
		!bytes.Equal(user.IPublicKey.Marshal(), key.Marshal()) {
		return nil, fmt.Errorf("invalid key for %q", conn.User())
	}
	return &ssh.Permissions{
		Extensions: map[string]string{
			"user": user.Name,
		},
	}, nil
}
// handleConnection performs the SSH handshake on a freshly accepted TCP
// connection and then serves its channels until the client disconnects.
// Global requests are discarded. Only "session" channels are accepted, and
// each one is served by handleChannel in its own goroutine.
func handleConnection(conn net.Conn, config *ssh.ServerConfig) {
	// Bound the handshake so a client that connects and then goes silent
	// cannot pin this goroutine and its socket forever.
	const handshakeTimeout = 30 * time.Second
	_ = conn.SetDeadline(time.Now().Add(handshakeTimeout))

	sshConn, chans, reqs, err := ssh.NewServerConn(conn, config)
	if err != nil {
		// NewServerConn closes conn itself when the handshake fails.
		log.Printf("Failed to handshake: %v", err)
		return
	}
	// An established session may legitimately sit idle, so lift the deadline.
	_ = conn.SetDeadline(time.Time{})
	defer func() {
		_ = sshConn.Close()
	}()

	log.Printf("New SSH connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion())
	go ssh.DiscardRequests(reqs)

	for newChannel := range chans {
		if newChannel.ChannelType() != "session" {
			_ = newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}
		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Printf("Could not accept channel: %v", err)
			continue
		}
		// Trace on the server side only. Writing it into the channel would
		// prepend internal debug text to every command's output.
		log.Printf("handleChannel(channel, requests) for %s", sshConn.RemoteAddr())
		go handleChannel(channel, requests)
	}
}
// handleChannel serves one accepted "session" channel.
//
// An "exec" request runs a single command via executeCommand, sends an
// exit-status of 0, and ends the session. A "shell" request starts the
// interactive loop in handleShell. Every other request type is rejected.
// On return the channel receives a final "Connection closed." line and is
// closed.
func handleChannel(channel ssh.Channel, requests <-chan *ssh.Request) {
	defer func() {
		_, _ = io.WriteString(channel, "Connection closed.\r\n")
		_ = channel.Close()
	}()

	for req := range requests {
		switch req.Type {
		case "exec": // ssh -i ~/.ssh/alice_rsa -p 2222 alice@127.0.0.1 xxx
			// The exec payload is an SSH string (uint32 length followed by the
			// bytes). Decode it with its length prefix instead of blindly
			// slicing off four bytes; malformed payloads are rejected.
			var payload struct{ Command string }
			if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
				_ = req.Reply(false, nil)
				continue
			}
			_ = req.Reply(true, nil)
			executeCommand(channel, payload.Command)
			_, _ = channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) // uint32 exit status 0
			return
		case "shell": // ssh -i ~/.ssh/alice_rsa -p 2222 alice@127.0.0.1
			_ = req.Reply(true, nil)
			_, _ = io.WriteString(channel, "Welcome to the Go SSH server!\r\n")
			_, _ = io.WriteString(channel, "Type 'exit' to disconnect.\r\n")
			handleShell(channel)
			return
		default:
			// pty-req, env, window-change, ... are not supported. Reply must be
			// called for every request that wants a reply (it is a no-op
			// otherwise), so answer explicitly instead of leaving the client
			// waiting on an unanswered request.
			_ = req.Reply(false, nil)
		}
	}
}
// executeCommand writes the response to a single command to channel.
//
// Supported commands are "list", "hello" and "time". Any other input,
// including the empty string, produces an "Unknown command: <cmd>" line.
// Every response line ends in "\r\n". Write errors are ignored.
//
// channel only needs to be an io.Writer: an ssh.Channel works exactly as
// before, and other writers (e.g. a buffer in tests) can be used too.
func executeCommand(channel io.Writer, cmd string) {
	switch cmd {
	case "list":
		_, _ = io.WriteString(channel, "Available commands: list, hello, time\r\n")
	case "hello":
		_, _ = io.WriteString(channel, "Hello, SSH user!\r\n")
	case "time":
		_, _ = fmt.Fprintf(channel, "Current server time: %s\r\n", time.Now().Format(time.RFC3339))
	default:
		_, _ = fmt.Fprintf(channel, "Unknown command: %s\r\n", cmd)
	}
}
// handleShell runs the interactive prompt loop for a "shell" request.
//
// It prints "> ", reads one line from the client, and dispatches the trimmed
// line to executeCommand. Blank lines just re-prompt. The loop ends when the
// client sends "exit" or when reading fails (EOF, a channel error, or an
// over-long line). Without that check a disconnected client would leave the
// loop spinning forever on a dead channel.
func handleShell(channel ssh.Channel) {
	for {
		_, _ = io.WriteString(channel, "> ")
		line, err := readShellLine(channel)
		if err != nil {
			return
		}
		cmd := strings.TrimSpace(line) // also strips a trailing '\r' from "\r\n"
		switch cmd {
		case "":
			continue
		case "exit":
			return
		}
		executeCommand(channel, cmd)
	}
}

// readShellLine reads from r up to the next '\n' and returns the line without
// the terminator.
//
// It reads one byte at a time so nothing past the newline is consumed and no
// input is lost between calls. A line cut short by a read error is discarded
// and the error returned. Lines longer than maxShellLine bytes are rejected so
// a client cannot make the server buffer unbounded input.
func readShellLine(r io.Reader) (string, error) {
	const maxShellLine = 4096
	var line strings.Builder
	var b [1]byte
	for {
		n, err := r.Read(b[:])
		if n == 1 {
			if b[0] == '\n' {
				return line.String(), nil
			}
			if line.Len() >= maxShellLine {
				return "", fmt.Errorf("shell line longer than %d bytes", maxShellLine)
			}
			line.WriteByte(b[0])
		}
		if err != nil {
			return "", err
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment