Skip to content

Instantly share code, notes, and snippets.

@yujp
Created April 13, 2021 14:43
Show Gist options
  • Save yujp/cc1743955f1ced31b5ab7b73c99f2b0b to your computer and use it in GitHub Desktop.
Save yujp/cc1743955f1ced31b5ab7b73c99f2b0b to your computer and use it in GitHub Desktop.
package main
import (
"context"
"flag"
"fmt"
"log"
"net"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
)
// target is a domain name together with the IPv4 host routes (CIDR strings
// such as "93.184.216.34/32") derived from its resolved addresses.
type target struct {
	domain string
	ips    []string
}

// firstIPv4 returns the first IPv4 address found in addrs (which may hold
// *net.IPNet or *net.IPAddr values, as returned by net.Interface.Addrs), or
// nil if there is none.
func firstIPv4(addrs []net.Addr) net.IP {
	for _, a := range addrs {
		var ip net.IP
		switch v := a.(type) {
		case *net.IPNet:
			ip = v.IP
		case *net.IPAddr:
			ip = v.IP
		}
		if ip4 := ip.To4(); ip4 != nil {
			return ip4
		}
	}
	return nil
}

// hostRoute returns the /32 CIDR destination for an IPv4 address. It reports
// ok=false for anything that is not IPv4 (e.g. IPv6 results from a lookup),
// which previously produced the bogus destination "<nil>/0".
func hostRoute(ip net.IP) (route string, ok bool) {
	ip4 := ip.To4()
	if ip4 == nil {
		return "", false
	}
	return fmt.Sprintf("%s/32", ip4.String()), true
}

// resolveTargets looks up every domain and collects its IPv4 host routes.
// Non-IPv4 results are skipped because routes are added via an IPv4 gateway.
// It fails on the first domain that cannot be resolved.
func resolveTargets(domains []string) ([]*target, error) {
	targets := make([]*target, 0, len(domains))
	for _, d := range domains {
		ips, err := net.LookupIP(d)
		if err != nil {
			return nil, fmt.Errorf("failed to find the domain(%s): %w", d, err)
		}
		t := &target{domain: d}
		for _, ip := range ips {
			if r, ok := hostRoute(ip); ok {
				t.ips = append(t.ips, r)
			}
		}
		targets = append(targets, t)
	}
	return targets, nil
}

// runRoute runs the system `route` command with args, logs the outcome for
// dest, and reports whether the command exited successfully. Combined output
// is used so the command's own error text appears in the log on failure.
func runRoute(dest string, args ...string) bool {
	out, err := exec.Command("route", args...).CombinedOutput()
	msg := strings.TrimSpace(string(out))
	if err != nil {
		if msg != "" {
			log.Printf("%s\tFailed: %v: %s", dest, err, msg)
		} else {
			log.Printf("%s\tFailed: %v", dest, err)
		}
		return false
	}
	log.Printf("%s\t%s", dest, msg)
	return true
}

// main adds /32 host routes for the IPv4 addresses of the given domains,
// using the first IPv4 address of the chosen interface as the gateway, then
// waits for SIGINT/SIGTERM and removes the routes it successfully added.
//
// Usage: drr -ifname <network-interface-name> <domain1> <domain2>...
func main() {
	ifName := flag.String("ifname", "", "network-interface-name")
	flag.Parse()
	domains := flag.Args()
	if len(domains) < 1 || *ifName == "" {
		log.Println("Usage: drr -ifname <network-interface-name> <domain1> <domain2>...")
		// Non-zero status so scripts can tell a usage error from success.
		os.Exit(2)
	}

	ifi, err := net.InterfaceByName(*ifName)
	if err != nil {
		log.Fatalf("failed to get the interface(%s): %v", *ifName, err)
	}
	addrs, err := ifi.Addrs()
	if err != nil {
		log.Fatalf("failed to get the addresses of the interface(%s): %v", *ifName, err)
	}
	// Destinations are IPv4 host routes, so the gateway must be IPv4 as well;
	// addrs[0] may be an IPv6 address, so search instead of taking the first.
	gw := firstIPv4(addrs)
	if gw == nil {
		log.Fatalf("failed to get an IPv4 address of the interface(%s)", *ifName)
	}
	ifIP := gw.String()

	targets, err := resolveTargets(domains)
	if err != nil {
		log.Fatal(err)
	}

	// Capture signals before touching the routing table: a Ctrl-C during
	// registration is then queued on ctx instead of killing the process and
	// leaving routes behind. (os.Interrupt is SIGINT, so it is not repeated.)
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop() // deferred first, so it runs after cleanup has finished

	// added[i] holds the routes for targets[i] whose `route add` succeeded.
	// Only these are deleted on exit; a failed add (e.g. the route already
	// existed) must not cause us to delete a route we did not create.
	added := make([][]string, len(targets))
	defer func() {
		for i, t := range targets {
			log.Printf("Unregistering routes for the domain '%s'", t.domain)
			for _, dest := range added[i] {
				runRoute(dest, "delete", dest)
			}
		}
	}()

	for i, t := range targets {
		log.Printf("Registering routes for the domain '%s'", t.domain)
		if len(t.ips) == 0 {
			log.Printf("No IPv4 addresses for the domain '%s'; nothing to route", t.domain)
		}
		for _, dest := range t.ips {
			if runRoute(dest, "add", dest, ifIP) {
				added[i] = append(added[i], dest)
			}
		}
	}

	<-ctx.Done()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment