#!/usr/bin/env ruby
# Copyright, 2009, 2012, by Samuel G. D. Williams. <http://www.codeotaku.com>
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# Pulls down DNS data from old-dns
# rd-dns-check -s old-dns.mydomain.com -d mydomain.com. -f old-dns.yml

# Check data against old-dns
# rd-dns-check -s old-dns.mydomain.com -d mydomain.com. -c old-dns.yml

# Check data against new DNS server
# rd-dns-check -s 10.0.0.36 -d mydomain.com. -c old-dns.yml

require 'yaml'
require 'optparse'
require 'set'

# A single DNS resource record, built from one answer line of dig output.
#
# The record wraps an array whose first four fields are
# [hostname, class, type, value]. On construction the hostname and value are
# lower-cased and the class and type are upper-cased, so that records compare
# equal regardless of the case they were reported in.
class DNSRecord
	# Number of leading fields every record has to provide.
	REQUIRED_FIELDS = 4
	
	# arr:: array-like with at least four non-nil fields:
	#       hostname, class, type, value.
	#
	# The fields are copied, so the caller's array is never modified.
	# Raises ArgumentError when fewer than four non-nil fields are given.
	def initialize(arr)
		fields = arr.respond_to?(:to_a) ? arr.to_a : []
		
		if fields.size < REQUIRED_FIELDS || fields.first(REQUIRED_FIELDS).any? { |field| field.nil? }
			raise ArgumentError, "Expected [hostname, class, type, value], got #{arr.inspect}"
		end
		
		# normalize assigns into @record, so work on a private copy.
		@record = fields.dup
		normalize
	end
	
	# Brings the four leading fields into their canonical case.
	def normalize
		@record[0] = @record[0].downcase
		@record[1] = @record[1].upcase
		@record[2] = @record[2].upcase
		@record[3] = @record[3].downcase
	end
	
	# Owner name of the record (lower-case).
	def hostname
		@record[0]
	end
	
	# Record class, e.g. "IN" (upper-case).
	def klass
		@record[1]
	end
	
	# Record type, e.g. "A" or "CNAME" (upper-case).
	def type
		@record[2]
	end
	
	# Record data, e.g. an address or a target host name (lower-case).
	def value
		@record[3]
	end
	
	# True for A and AAAA records.
	def is_address?
		["A", "AAAA"].include?(type)
	end
	
	# True for CNAME records.
	def is_cname?
		type == "CNAME"
	end
	
	# One aligned, human readable line for reports.
	def to_s
		"#{hostname.ljust(50)} #{klass.rjust(4)} #{type.rjust(5)} #{value}"
	end
	
	# Lookup key made of hostname, class and type; the value is not part of it,
	# so two records that differ only in their value share a key.
	def key
		"#{hostname}:#{klass}:#{type}".downcase
	end
	
	# The underlying field array (not a copy).
	def to_a
		@record
	end
	
	# Records are equal when their field arrays are equal. Anything that cannot
	# be converted with to_a is simply unequal instead of raising.
	def == other
		other.respond_to?(:to_a) && @record == other.to_a
	end
end

# Runs dig against +dns_server+ and turns its answer lines into records.
#
# dns_server:: name or address of the server to query (passed to dig as @server).
# cmd:: the remaining dig arguments as one whitespace separated string,
#       e.g. "example.com. AXFR" or "-x 10.0.0.1".
# exclude:: record types that are dropped from the result.
#
# Returns an array of DNSRecord in the order dig printed them.
def dig(dns_server, cmd, exclude = ["TXT", "HINFO", "SOA", "NS"])
	records = []
	
	# An argument vector is used rather than one command string, so no shell is
	# involved and host names can never be interpreted as shell syntax.
	command = ["dig", "@#{dns_server}", "+nottlid", "+nocmd", "+noall", "+answer"] + cmd.split
	
	IO.popen(command) do |pipe|
		pipe.each_line do |line|
			line = line.strip
			
			# Blank lines and comment lines (prefixed with ';') carry no record.
			next if line.empty? || line.start_with?(";")
			
			# Columns may be separated by more than one tab or space, so split on
			# runs of whitespace. The limit keeps values that contain spaces
			# (everything after the type) together in the last field.
			fields = line.split(/\s+/, 4)
			
			# Anything without all four fields cannot be turned into a record.
			next if fields.size < 4
			next if exclude.include?(fields[2])
			
			records << DNSRecord.new(fields)
		end
	end
	
	records
end

# Requests a zone transfer (AXFR) of +dns_root+ from +dns_server+ and returns
# the records dig() produced for it.
def retrieve_records(dns_server, dns_root)
	transfer_query = "#{dns_root} AXFR"
	
	dig(dns_server, transfer_query)
end

# Asks +dns_server+ for the A record of +hostname+.
# Returns the first record dig() produced, or nil when there was none.
def resolve_hostname(dns_server, hostname)
	answers = dig(dns_server, "#{hostname} A")
	
	answers.first
end

# Performs a reverse lookup (dig -x) of +address+ on +dns_server+.
# Returns the first record dig() produced, or nil when there was none.
def resolve_address(dns_server, address)
	answers = dig(dns_server, "-x #{address}")
	
	answers.first
end

# Prints the closing summary of a check run.
#
# records:: everything that was checked; only its size is reported.
# errors:: number of failures that were found.
# okay:: the entries that passed.
#
# When there were failures the passing entries are listed too: each one is
# handed to the given block, or printed indented by 12 columns when no block
# was supplied. Without failures the list is omitted.
def print_summary(records, errors, okay, &block)
	puts "[ Summary ]".center(72, "=")
	puts "Checked #{records.size} record(s). #{errors} errors."
	
	return puts("Everything seemed okay.") if errors == 0
	
	puts "The following records are okay:"
	okay.each do |entry|
		if block
			block.call(entry)
		else
			puts "".rjust(12) + entry.to_s
		end
	end
end

# Builds a table that maps host names to the record they finally resolve to.
#
# Address records ("A"/"AAAA") map to themselves. Every CNAME is followed
# through further CNAMEs until an address record is reached. A chain that ends
# nowhere or runs in a circle is reported with a warning, and the host name is
# then mapped to the last CNAME that could still be followed.
#
# Returns two hashes keyed by host name: the resolved records, and the CNAME
# records themselves.
def resolve_addresses(records)
	addresses = {}
	cnames = {}
	
	# First pass: index address and CNAME records by their host name.
	records.each do |record|
		if record.is_address?
			addresses[record.hostname] = record
		elsif record.is_cname?
			cnames[record.hostname] = record
		end
	end
	
	# Second pass: follow each CNAME to the record it ends at.
	cnames.each do |hostname, cname|
		visited = []
		current = cname
		dangling = false
		
		while current.is_cname?
			visited << current
			target = current.value
			
			# Another CNAME takes precedence over an address for the same name.
			current = cnames.fetch(target) { addresses[target] }
			
			# nil: nothing is known about the target.
			# Already visited: the chain loops back onto itself.
			if current.nil? || visited.include?(current)
				dangling = true
				break
			end
		end
		
		if dangling
			current = visited.last
			puts "*** Warning: CNAME record #{hostname} does not point to actual address!"
			visited.each_with_index do |step, idx|
				puts idx.to_s.rjust(10) + ": " + step.to_s
			end
		end
		
		addresses[cname.hostname] = current
	end
	
	[addresses, cnames]
end

# Checks that every address record has a matching reverse entry.
#
# records:: the records to check; everything that is not an address record
#           is skipped.
# dns_server:: the server asked for the reverse lookups.
#
# For each address record the record's value is looked up in reverse via
# resolve_address. A lookup without result, or one whose value differs from
# the record's host name, is reported and counted as an error. The run ends
# with print_summary, which lists the matching pairs if there were errors.
def check_reverse(records, dns_server)
	errors = 0
	okay = []
	
	puts "[ Checking Reverse Lookups ]".center(72, "=")
	
	records.each do |r|
		next unless r.is_address?
		
		sr = resolve_address(dns_server, r.value)
		
		if sr.nil?
			puts "*** Could not resolve host"
			puts "".rjust(12) + r.to_s
			errors += 1
		elsif r.hostname != sr.value
			puts "*** Hostname does not match"
			puts "Primary: ".rjust(12) + r.to_s
			puts "Secondary: ".rjust(12) + sr.to_s
			errors += 1
		else
			okay << [r, sr]
		end
	end
	
	# Each okay entry is a [record, reverse record] pair. The labels carry a
	# trailing space, like the error reports above, so they do not run into the
	# record text.
	print_summary(records, errors, okay) do |pair|
		puts "Primary: ".rjust(12) + pair[0].to_s
		puts "Secondary: ".rjust(12) + pair[1].to_s
	end
end

# Pings the final target of every host name found in +records+.
#
# Host names are first resolved with resolve_addresses, so a CNAME is pinged
# at the value of the record it resolves to. Hosts that do not answer are
# reported and counted; the run ends with print_summary.
#
# NOTE(review): the ping flags (-t, -o) look like the BSD/macOS variants;
# confirm their meaning on the platform this is run on.
def ping_records(records)
	addresses = resolve_addresses(records).first
	
	errors = 0
	okay = []
	
	puts "[ Pinging Records ]".center(72, "=")
	
	addresses.each do |hostname, r|
		arguments = ["-c", "5", "-t", "5", "-i", "1", "-o", r.value]
		
		# Only used for the failure report below.
		ping = "ping #{arguments.join(' ')} > /dev/null"
		
		# The command is run from an argument vector, so the record's value is
		# never interpreted by a shell. system() is true only for exit status 0.
		if system("ping", *arguments, :out => "/dev/null")
			okay << r
		else
			puts "*** Could not ping host #{hostname.dump}: #{ping.dump}"
			puts "".rjust(12) + r.to_s
			errors += 1
		end
	end
	
	print_summary(records, errors, okay)
end

# Queries +secondary_server+ for every host name in +primary+ and compares the
# answers with the primary records.
#
# primary:: the reference records.
# secondary_server:: the server that is being checked.
#
# A record is okay when the server's answer carries the same value, or when
# both sides end at the same value once CNAMEs have been resolved through the
# primary records. Host names the server cannot resolve and answers that end
# at a different value are reported and counted as errors. The run ends with
# print_summary.
def query_records(primary, secondary_server)
	addresses = resolve_addresses(primary).first
	
	okay = []
	errors = 0
	
	primary.each do |r|
		sr = resolve_hostname(secondary_server, r.hostname)
		
		if sr.nil?
			puts "*** Could not resolve hostname #{r.hostname.dump}"
			puts "Primary: ".rjust(12) + r.to_s
			
			# Show what the server knows about the address, if anything.
			rsr = resolve_address(secondary_server, (addresses[r.value] || r).value)
			puts "Address: ".rjust(12) + rsr.to_s if rsr
			
			errors += 1
		elsif sr.value == r.value
			okay << r
		else
			# The raw values differ; compare what each side finally resolves to.
			ra = addresses[r.value] if r.is_cname?
			sra = addresses[sr.value] if sr.is_cname?
			
			if (sra || sr).value != (ra || r).value
				puts "*** IP Address does not match"
				puts "Primary: ".rjust(12) + r.to_s
				puts "Resolved: ".rjust(12) + ra.to_s if ra
				puts "Secondary: ".rjust(12) + sr.to_s
				puts "Resolved: ".rjust(12) + sra.to_s if sra
				errors += 1
			else
				# Same address behind different records: that is a pass too.
				okay << r
			end
		end
	end
	
	print_summary(primary, errors, okay)
end

# Compares two record lists with each other.
#
# primary:: the reference records.
# secondary:: the records to check against the reference.
#
# The secondary records are indexed by their key. Every primary record is then
# looked up by its own key: a missing counterpart or one that is not equal to
# the primary record is reported and counted as an error. The run ends with
# print_summary.
def check_records(primary, secondary)
	by_key = {}
	secondary.each { |record| by_key[record.key] = record }
	
	puts "[ Checking Records ]".center(72, "=")
	
	matched = []
	failures = 0
	
	primary.each do |record|
		counterpart = by_key[record.key]
		
		if counterpart.nil?
			puts "*** Could not find record"
			puts "Primary: ".rjust(12) + record.to_s
		elsif counterpart == record
			matched << record
			next
		else
			puts "*** Records are different"
			puts "Primary: ".rjust(12) + record.to_s
			puts "Secondary: ".rjust(12) + counterpart.to_s
		end
		
		# Only the two failure branches get here.
		failures += 1
	end
	
	print_summary(primary, failures, matched)
end

# Settings collected from the command line. The -s and -d handlers fill them
# in; the action handlers (-f, -c, -q, -p, -r) read them.
OPTIONS = {
	:DNSServer => nil,
	:DNSRoot => ".",
}

# Option handlers run while the command line is parsed, in the order in which
# the options appear. -s and -d therefore have to be given before the action
# they are meant to apply to.
#
# parse! is called on the value of this block, so the block has to end with a
# call that returns the parser (the final on_tail does).
ARGV.options do |o|
	script_name = File.basename($0)
	
	# Loads a record list that was written by --fetch. Reads +path+, or standard
	# input when no path is given, and closes the file again.
	#
	# The data consists of serialized DNSRecord objects. Where Psych provides
	# unsafe_load, a plain YAML.load may refuse custom classes, so DNSRecord is
	# permitted explicitly and nothing else is. Older versions keep using
	# YAML.load, which instantiates whatever class the file names: only feed
	# those versions files you trust.
	load_records = lambda do |path|
		data = path ? File.read(path) : STDIN.read
		
		if YAML.respond_to?(:unsafe_load)
			YAML.safe_load(data, :permitted_classes => [DNSRecord])
		else
			YAML.load(data)
		end
	end
	
	o.set_summary_indent('  ')
	o.banner = "Usage: #{script_name} [options]"
	o.define_head "This script is designed to test and check DNS servers."
	
	o.on("-s ns.my.domain.", "--server ns.my.domain.", String, "The DNS server to query.") { |host| OPTIONS[:DNSServer] = host }
	o.on("-d my.domain.", "--domain my.domain.", String, "The DNS zone to transfer/test.") { |host| OPTIONS[:DNSRoot] = host }
	
	o.on("-f output.yml", "--fetch output.yml", String, "Pull down a list of hosts. Filters TXT, HINFO, SOA and NS records. DNS transfers must be enabled.") { |f|
		records = retrieve_records(OPTIONS[:DNSServer], OPTIONS[:DNSRoot])
		
		yaml = YAML::dump(records)
		
		if f
			# The block form closes (and thereby flushes) the file.
			File.open(f, "w") { |output| output.write(yaml) }
		else
			STDOUT.write(yaml)
		end
		
		puts "#{records.size} record(s) retrieved."
	}
	
	o.on("-c input.yml", "--check input.yml", String, "Check that the DNS server returns results as specified by the file.") { |f|
		master_records = load_records.call(f)
		secondary_records = retrieve_records(OPTIONS[:DNSServer], OPTIONS[:DNSRoot])
		
		check_records(master_records, secondary_records)
	}
	
	o.on("-q input.yml", "--query input.yml", String, "Query the remote DNS server with all hostnames in the given file, and checks the IP addresses are consistent.") { |f|
		master_records = load_records.call(f)
		
		query_records(master_records, OPTIONS[:DNSServer])
	}
	
	o.on("-p input.yml", "--ping input.yml", String, "Ping all hosts to check if they are available or not.") { |f|
		master_records = load_records.call(f)
		
		ping_records(master_records)
	}
	
	o.on("-r input.yml", "--reverse input.yml", String, "Check that all address records have appropriate reverse entries.") { |f|
		master_records = load_records.call(f)
		
		check_reverse(master_records, OPTIONS[:DNSServer])
	}
	
	o.separator ""
	o.separator "Help and Copyright information"
	
	o.on_tail("--copy", "Display copyright information") {
		puts "#{script_name}. Copyright (c) 2009, 2011 Samuel Williams. Released under the MIT license."
		puts "See http://www.oriontransfer.co.nz/ for more information."
		exit
	}
	
	o.on_tail("-h", "--help", "Show this help message.") { puts o; exit }
end.parse!