Skip to content

Commit

Permalink
Add debug mode to reverse lookup, option to change cache TTL
Browse files Browse the repository at this point in the history
  • Loading branch information
domosekai committed Nov 23, 2024
1 parent ae472de commit 2fcf4ac
Showing 1 changed file with 45 additions and 30 deletions.
75 changes: 45 additions & 30 deletions shdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ var minrtt = flag.Int("m", 30, "Minimum possible RTT (ms) for foreign nameserver
var minsafe = flag.Int("s", 100, "Minimum safe RTT (ms) for foreign nameservers. Packets with longer RTT will be immediately accepted. Packets with shorter RTT will be delayed until this threshold.")
var minwait = flag.Int("w", 100, "Time (ms) during which domestic answers are prioritized. Usually used with a local caching resolver.")
var timeout = flag.Int("M", 3000, "DNS query timeout (ms). Use a larger value for high-latency network or DNS-over-HTTPS.")
var reversenet = flag.String("r", "", "Address and port for listening to reverse DNS queries.")
var reversenet = flag.String("r", "", "Address and port for listening to reverse DNS queries from cache")
var cachelife = flag.Int("c", 60, "DNS cache lifetime (minutes) for reverse lookup")
var verbose = flag.Bool("v", false, "Verbose mode. Connection will remain open after replied until timeout.")
var showver = flag.Bool("V", false, "Show version")
var version = "unknown"
Expand All @@ -72,12 +73,13 @@ type answer struct {

type cacheEntry struct {
value string
ns string
modified time.Time
}

type cache struct {
entries map[string]cacheEntry
rw sync.RWMutex
table map[string]cacheEntry
rw sync.RWMutex
}

type byByte []net.IPNet
Expand Down Expand Up @@ -206,36 +208,36 @@ func addTag(bufs []bytes.Buffer, tag string) {
}
}

func (c *cache) add(key, value string) {
func (c *cache) add(key, value, ns string) {
c.rw.Lock()
defer c.rw.Unlock()
c.entries[key] = cacheEntry{value, time.Now()}
c.table[key] = cacheEntry{value, ns, time.Now()}
}

func (c *cache) insert(m map[string]string) {
func (c *cache) insert(m map[string]string, ns string) {
c.rw.Lock()
defer c.rw.Unlock()
t := time.Now()
for key, value := range m {
c.entries[key] = cacheEntry{value, t}
c.table[key] = cacheEntry{value, ns, t}
}
}

func (c *cache) lookup(key string) string {
func (c *cache) lookup(key string) (string, string) {
c.rw.RLock()
defer c.rw.RUnlock()
if e, ok := c.entries[key]; ok {
return e.value
if e, ok := c.table[key]; ok {
return e.value, e.ns
}
return ""
return "", ""
}

func (c *cache) purge(t time.Duration) {
c.rw.Lock()
defer c.rw.Unlock()
for key, value := range c.entries {
for key, value := range c.table {
if value.modified.Add(t).Before(time.Now()) {
delete(c.entries, key)
delete(c.table, key)
}
}
}
Expand All @@ -246,15 +248,28 @@ func handleReverse(conn *net.UDPConn) {
var payload [1500]byte
if n, addr, err := conn.ReadFromUDP(payload[:]); err != nil {
errlog.Println(err)
} else if (n == 5 || n == 17) && payload[0] == 1 {
} else if n == 5 || n == 17 {
ip := net.IP(payload[1:n])
host := reverseTable.lookup(ip.String())
host, ns := reverseTable.lookup(ip.String())
// remove trailing dot (if any)
host = strings.TrimSuffix(host, ".")
op := []byte{2}
_, err := conn.WriteToUDP(append(op, []byte(host)...), addr)
if err != nil {
errlog.Println(err)
switch payload[0] {
case 1:
op := []byte{2}
_, err := conn.WriteToUDP(append(op, host...), addr)
if err != nil {
errlog.Println(err)
}
case 3:
op := []byte{4}
op = append(op, byte(len(host)))
op = append(op, host...)
op = append(op, byte(len(ns)))
op = append(op, ns...)
_, err := conn.WriteToUDP(op, addr)
if err != nil {
errlog.Println(err)
}
}
}
}
Expand Down Expand Up @@ -320,7 +335,7 @@ func handleQuery(addr *net.UDPAddr, payload []byte, inConn *net.UDPConn) { // ne
return
}
defer outConn.Close()
go forwardQueryAndReply(payload, outConn, chAnswer, chSave, chFail, qs[0].Type, hasOPT, dnssec)
go forwardQueryAndReply(payload, outConn, chAnswer, chSave, chFail, qs[0].Type, qs[0].Name.String(), hasOPT, dnssec)
answered := false
waiting := true
var savedAnswer, waitedAnswer, failedAnswer []byte
Expand Down Expand Up @@ -387,14 +402,14 @@ func handleQuery(addr *net.UDPAddr, payload []byte, inConn *net.UDPConn) { // ne
}
}

func forwardQueryAndReply(payload []byte, outConn *net.UDPConn, chAnswer chan<- answer, chSave, chFail chan<- []byte, qType dnsmessage.Type, hasOPT, dnssec bool) {
func forwardQueryAndReply(payload []byte, outConn *net.UDPConn, chAnswer chan<- answer, chSave, chFail chan<- []byte, qType dnsmessage.Type, qName string, hasOPT, dnssec bool) {
defer close(chAnswer)
sentTime := time.Now()
for _, ns := range servers {
outConn.WriteToUDP(payload, ns.udpAddr)
}
outConn.SetReadDeadline(sentTime.Add(time.Duration(*timeout) * time.Millisecond))
parseAnswers(outConn, sentTime, chAnswer, chSave, chFail, qType, hasOPT, dnssec)
parseAnswers(outConn, sentTime, chAnswer, chSave, chFail, qType, qName, hasOPT, dnssec)
}

func lookupServer(addr *net.UDPAddr) (nameserver, bool) {
Expand All @@ -406,7 +421,7 @@ func lookupServer(addr *net.UDPAddr) (nameserver, bool) {
return nameserver{}, false
}

func parseAnswers(conn *net.UDPConn, sentTime time.Time, chAnswer chan<- answer, chSave, chFail chan<- []byte, qType dnsmessage.Type, hasOPT, dnssec bool) {
func parseAnswers(conn *net.UDPConn, sentTime time.Time, chAnswer chan<- answer, chSave, chFail chan<- []byte, qType dnsmessage.Type, qName string, hasOPT, dnssec bool) {
for {
var payload [5000]byte
// receive from nameserver
Expand Down Expand Up @@ -465,7 +480,7 @@ func parseAnswers(conn *net.UDPConn, sentTime time.Time, chAnswer chan<- answer,
}
ip := net.IP(r.A[:]) //r.A is 4-byte
if *reversenet != "" {
reverse[ip.String()] = ah.Name.String()
reverse[ip.String()] = qName
}
if *verbose {
fmt.Fprintf(&buf, " %s %s len %d %dms", ah.Name.String(), ip.String(), len(a), rtt.Nanoseconds()/1000000)
Expand Down Expand Up @@ -508,7 +523,7 @@ func parseAnswers(conn *net.UDPConn, sentTime time.Time, chAnswer chan<- answer,
}
ip := net.IP(r.AAAA[:])
if *reversenet != "" {
reverse[ip.String()] = ah.Name.String()
reverse[ip.String()] = qName
}
if *verbose {
fmt.Fprintf(&buf, " %s %s len %d %dms", ah.Name.String(), ip.String(), len(a), rtt.Nanoseconds()/1000000)
Expand Down Expand Up @@ -613,15 +628,15 @@ func parseAnswers(conn *net.UDPConn, sentTime time.Time, chAnswer chan<- answer,
if *verbose {
fmt.Fprintf(&buf, " %s", net.IP(r.IPv4Hint[i][:]).String())
}
reverse[net.IP(r.IPv4Hint[i][:]).String()] = ah.Name.String()
reverse[net.IP(r.IPv4Hint[i][:]).String()] = qName
}
}
if r.IPv6Hint != nil {
for i := range r.IPv6Hint {
if *verbose {
fmt.Fprintf(&buf, " %s", net.IP(r.IPv6Hint[i][:]).String())
}
reverse[net.IP(r.IPv6Hint[i][:]).String()] = ah.Name.String()
reverse[net.IP(r.IPv6Hint[i][:]).String()] = qName
}
}
if *verbose {
Expand Down Expand Up @@ -736,7 +751,7 @@ func parseAnswers(conn *net.UDPConn, sentTime time.Time, chAnswer chan<- answer,
}
// add to cache for reverse lookup (even for those saved but not used)
if *reversenet != "" {
reverseTable.insert(reverse)
reverseTable.insert(reverse, ns.udpAddr.String())
}
} else if h.RCode == dnsmessage.RCodeServerFailure && ns.sType == foreign {
if *verbose {
Expand Down Expand Up @@ -818,14 +833,14 @@ func main() {
defer inConn.Close()
logger.Printf("Listening on UDP %s", addr)
if *reversenet != "" {
reverseTable.entries = make(map[string]cacheEntry)
reverseTable.table = make(map[string]cacheEntry)
if addr, err := net.ResolveUDPAddr("udp", *reversenet); err == nil {
conn, err := net.ListenUDP("udp", addr)
if err == nil {
go func() {
ticker := time.Tick(15 * time.Minute)
for range ticker {
reverseTable.purge(time.Hour * 3)
reverseTable.purge(time.Minute * time.Duration(*cachelife))
}
}()
go handleReverse(conn)
Expand Down

0 comments on commit 2fcf4ac

Please sign in to comment.