Skip to content

Commit

Permalink
Add support for --tcp flag.
Browse files Browse the repository at this point in the history
  • Loading branch information
konstgorecki authored and arrowd committed Jul 25, 2018
1 parent 0a3b114 commit 4359358
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 39 deletions.
44 changes: 23 additions & 21 deletions dnsrecon.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def brute_domain(res, dict, dom, filter=None, verbose=False, ignore_wildcard=Fal
return found_hosts


def in_cache(dict_file, ns):
def in_cache(res, dict_file, ns):
"""
Function for Cache Snooping, it will check a given NS server for specific
type of records for a given domain are in it's cache.
Expand All @@ -499,7 +499,7 @@ def in_cache(dict_file, ns):
dom_to_query = str.strip(zone)
query = dns.message.make_query(dom_to_query, dns.rdatatype.A, dns.rdataclass.IN)
query.flags ^= dns.flags.RD
answer = dns.query.udp(query, ns)
answer = res.query(query, ns)
if len(answer.answer) > 0:
for an in answer.answer:
for rcd in an:
Expand Down Expand Up @@ -824,7 +824,7 @@ def write_db(db, data):
def get_nsec_type(domain, res):
target = "0." + domain

answer = get_a_answer(target, res._res.nameservers[0], res._res.timeout)
answer = get_a_answer(res, target, res._res.nameservers[0], res._res.timeout)
for a in answer.authority:
if a.rdtype == 50:
return "NSEC3"
Expand Down Expand Up @@ -854,7 +854,7 @@ def dns_sec_check(domain, res):
except dns.resolver.NXDOMAIN:
print_error("Could not resolve domain: {0}".format(domain))
sys.exit(1)

except dns.resolver.NoNameservers:
print_error("All nameservers failed to answer the DNSSEC query for {0}".format(domain))

Expand All @@ -868,14 +868,14 @@ def dns_sec_check(domain, res):
print_error("DNSSEC is not configured for {0}".format(domain))


def check_bindversion(ns_server, timeout):
def check_bindversion(res, ns_server, timeout):
"""
Check if the version of Bind can be queried for.
"""
version = ""
request = dns.message.make_query('version.bind', 'txt', 'ch')
try:
response = dns.query.udp(request, ns_server, timeout=timeout, one_rr_per_rrset=True)
response = res.query(request, ns_server, timeout=timeout, one_rr_per_rrset=True)
if (len(response.answer) > 0):
print_status("\t Bind Version for {0} {1}".format(ns_server, response.answer[0].items[0].strings[0]))
version = response.answer[0].items[0].strings[0]
Expand All @@ -884,14 +884,14 @@ def check_bindversion(ns_server, timeout):
return version


def check_recursive(ns_server, timeout):
def check_recursive(res, ns_server, timeout):
"""
Check if a NS Server is recursive.
"""
is_recursive = False
query = dns.message.make_query('www.google.com.', dns.rdatatype.NS)
try:
response = dns.query.udp(query, ns_server, timeout)
response = res.query(query, ns_server, timeout)
recursion_flag_pattern = "\.*RA\.*"
flags = dns.flags.to_text(response.flags)
result = re.findall(recursion_flag_pattern, flags)
Expand Down Expand Up @@ -961,8 +961,8 @@ def general_enum(res, domain, do_axfr, do_google, do_bing, do_spf, do_whois, do_
print_status("\t {0} {1} {2}".format(ns_rcrd[0], ns_rcrd[1], ns_rcrd[2]))

# Save dictionary of returned record
recursive = check_recursive(ns_rcrd[2], res._res.timeout)
bind_ver = check_bindversion(ns_rcrd[2], res._res.timeout)
recursive = check_recursive(res, ns_rcrd[2], res._res.timeout)
bind_ver = check_bindversion(res, ns_rcrd[2], res._res.timeout)
returned_records.extend([
{"type": ns_rcrd[0], "target": ns_rcrd[1], "address": ns_rcrd[2], "recursive": str(recursive),
"Version": bind_ver}])
Expand Down Expand Up @@ -1086,7 +1086,7 @@ def general_enum(res, domain, do_axfr, do_google, do_bing, do_spf, do_whois, do_
#sys.exit(0)


def query_ds(target, ns, timeout=5.0):
def query_ds(res, target, ns, timeout=5.0):
"""
Function for performing DS Record queries. Returns answer object. Since a
timeout will break the DS NSEC chain of a zone walk it will exit if a timeout
Expand All @@ -1097,7 +1097,7 @@ def query_ds(target, ns, timeout=5.0):
query.flags += dns.flags.CD
query.use_edns(edns=True, payload=4096)
query.want_dnssec(True)
answer = dns.query.udp(query, ns, timeout)
answer = res.query(query, ns, timeout)
except dns.exception.Timeout:
print_error("A timeout error occurred please make sure you can reach the target DNS Servers")
print_error("directly and requests are not being filtered. Increase the timeout from {0} second".format(timeout))
Expand Down Expand Up @@ -1196,18 +1196,18 @@ def lookup_next(target, res):
return returned_records


def get_a_answer(target, ns, timeout):
def get_a_answer(res, target, ns, timeout):
query = dns.message.make_query(target, dns.rdatatype.A, dns.rdataclass.IN)
query.flags += dns.flags.CD
query.use_edns(edns=True, payload=4096)
query.want_dnssec(True)
answer = dns.query.udp(query, ns, timeout)
answer = res.query(query, ns, timeout)
return answer


def get_next(target, ns, timeout):
def get_next(res, target, ns, timeout):
next_host = None
response = get_a_answer(target, ns, timeout)
response = get_a_answer(res, target, ns, timeout)
for a in response.authority:
if a.rdtype == 47:
for r in a:
Expand Down Expand Up @@ -1275,9 +1275,9 @@ def ds_zone_walk(res, domain):

# Perform a DNS query for the target and process the response
if not nameserver:
response = get_a_answer(target, res._res.nameservers[0], timeout)
response = get_a_answer(res, target, res._res.nameservers[0], timeout)
else:
response = get_a_answer(target, nameserver, timeout)
response = get_a_answer(res, target, nameserver, timeout)
for a in response.authority:
if a.rdtype != 47:
continue
Expand Down Expand Up @@ -1344,6 +1344,7 @@ def usage():
print(" -z Performs a DNSSEC zone walk with standard enumeration.")
print(" --threads <number> Number of threads to use in reverse lookups, forward lookups, brute force and SRV")
print(" record enumeration.")
print(" --tcp Force using TCP protocol when making DNS queries.")
print(" --lifetime <number> Time to wait for a server to response to a query.")
print(" --db <file> SQLite 3 file to save found records.")
print(" --xml <file> XML file to save found records.")
Expand Down Expand Up @@ -1411,6 +1412,7 @@ def main():
parser.add_argument("-z", help="Performs a DNSSEC zone walk with standard enumeration.", action="store_true")
parser.add_argument("--threads", type=int, dest="threads", help="Number of threads to use in reverse lookups, forward lookups, brute force and SRV record enumeration.")
parser.add_argument("--lifetime", type=int, dest="lifetime", help="Time to wait for a server to response to a query.")
parser.add_argument("--tcp", dest="tcp", help="Use TCP protocol to make queries.", action="store_true")
parser.add_argument("--db", type=str, dest="db", help="SQLite 3 file to save found records.")
parser.add_argument("-x", "--xml", type=str, dest="xml", help="XML file to save found records.")
parser.add_argument("-c", "--csv", type=str, dest="csv", help="Comma separated value file.")
Expand Down Expand Up @@ -1484,12 +1486,12 @@ def main():
zonewalk = arguments.z
spf_enum = arguments.s
wildcard_filter = arguments.f

proto = "tcp" if arguments.tcp else "udp"
# Setting the number of threads to 10
pool = ThreadPool(thread_num)

# Set the resolver
res = DnsHelper(domain, ns_server, request_timeout)
res = DnsHelper(domain, ns_server, request_timeout, proto)

domain_req = ["axfr", "std", "srv", "tld", "goo", "bing", "crt", "zonewalk"]
scan_info = [" ".join(sys.argv), str(datetime.datetime.now())]
Expand Down Expand Up @@ -1591,7 +1593,7 @@ def main():
elif r == "snoop":
if (dict is not None) and (ns_server is not None):
print_status("Performing Cache Snooping against NS Server: {0}".format(ns_server))
cache_enum_records = in_cache(dict, ns_server)
cache_enum_records = in_cache(res, dict, ns_server)
if (output_file is not None) or (results_db is not None) or (csv_file is not None) or (json_file is not None):
returned_records.extend(cache_enum_records)

Expand Down
53 changes: 35 additions & 18 deletions lib/dnshelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@


class DnsHelper:
def __init__(self, domain, ns_server=None, request_timeout=3.0, ):
def __init__(self, domain, ns_server=None, request_timeout=3.0, proto="tcp"):
self._domain = domain
self._proto = proto
if ns_server:
self._res = dns.resolver.Resolver(configure=False)
self._res.nameservers = [ns_server]
Expand Down Expand Up @@ -65,22 +66,30 @@ def resolve(self, target, type, ns=None):
Function for performing general resolution types returning the RDATA
"""
if ns:
res = dns.resolver.Resolver(configure=False)
res = dns.resolver.Resolver(configure=False, )
res.nameservers = [ns]
else:
res = dns.resolver.Resolver(configure=True)

answers = res.query(target, type)
tcp = True if self._proto == "tcp" else False
answers = res.query(target, type, tcp=tcp)
return answers

def query(self, q, where, timeout=None, port=53, af=None, source=None, source_port=0, one_rr_per_rrset=False):
if self._proto == "tcp":
return dns.query.tcp(q, where, timeout, port, af, source, source_port, one_rr_per_rrset)
else:
return dns.query.udp(q, where, timeout, port, af, source, source_port, False, one_rr_per_rrset)

def get_a(self, host_trg):
"""
Function for resolving the A Record for a given host. Returns an Array of
the IP Address it resolves to. It will also return CNAME data.
"""
address = []
tcp = True if self._proto == "tcp" else False
try:
ipv4_answers = self._res.query(host_trg, 'A')
ipv4_answers = self._res.query(host_trg, 'A', tcp=tcp)
for ardata in ipv4_answers.response.answer:
for rdata in ardata:
if rdata.rdtype == 5:
Expand All @@ -102,8 +111,9 @@ def get_aaaa(self, host_trg):
the IP Address it resolves to. It will also return CNAME data.
"""
address = []
tcp = True if self._proto == "tcp" else False
try:
ipv6_answers = self._res.query(host_trg, 'AAAA')
ipv6_answers = self._res.query(host_trg, 'AAAA', tcp=tcp)
for ardata in ipv6_answers.response.answer:
for rdata in ardata:
if rdata.rdtype == 5:
Expand Down Expand Up @@ -136,11 +146,12 @@ def get_mx(self):
address of the host both in IPv4 and IPv6. Returns an Array
"""
mx_records = []
answers = self._res.query(self._domain, 'MX')
tcp = True if self._proto == "tcp" else False
answers = self._res.query(self._domain, 'MX', tcp=tcp)
for rdata in answers:
try:
name = rdata.exchange.to_text()
ipv4_answers = self._res.query(name, 'A')
ipv4_answers = self._res.query(name, 'A', tcp=tcp)
for ardata in ipv4_answers:
if name.endswith('.'):
mx_records.append(['MX', name[:-1], ardata.address,
Expand All @@ -153,7 +164,7 @@ def get_mx(self):
try:
for rdata in answers:
name = rdata.exchange.to_text()
ipv6_answers = self._res.query(name, 'AAAA')
ipv6_answers = self._res.query(name, 'AAAA', tcp=tcp)
for ardata in ipv6_answers:
if name.endswith('.'):
mx_records.append(['MX', name[:-1], ardata.address,
Expand All @@ -171,7 +182,8 @@ def get_ns(self):
address of the host both in IPv4 and IPv6. Returns an Array.
"""
name_servers = []
answer = self._res.query(self._domain, 'NS')
tcp = True if self._proto == "tcp" else False
answer = self._res.query(self._domain, 'NS', tcp=tcp)
if answer is not None:
for aa in answer:
name = aa.target.to_text()[:-1]
Expand All @@ -187,17 +199,18 @@ def get_soa(self):
address of the host both in IPv4 and IPv6. Returns an Array.
"""
soa_records = []
tcp = True if self._proto == "tcp" else False
query = dns.message.make_query(self._domain, dns.rdatatype.SOA)
try:
response = dns.query.udp(query, self._res.nameservers[0], self._res.timeout)
response = query(query, self._res.nameservers[0], self._res.timeout)
if len(response.answer) > 0:
answers = response.answer
elif len(response.authority) > 0:
answers = response.authority
for rdata in answers:
# A zone only has one SOA record so we select the first.
name = rdata[0].mname.to_text()
ipv4_answers = self._res.query(name, 'A')
ipv4_answers = self._res.query(name, 'A', tcp=tcp)
for ardata in ipv4_answers:
if name.endswith('.'):
soa_records.append(['SOA', name[:-1], ardata.address])
Expand All @@ -210,7 +223,7 @@ def get_soa(self):
try:
for rdata in answers:
name = rdata.mname.to_text()
ipv4_answers = self._res.query(name, 'AAAA')
ipv4_answers = self._res.query(name, 'AAAA', tcp=tcp)
for ardata in ipv4_answers:
if name.endswith('.'):
soa_records.append(['SOA', name[:-1], ardata.address])
Expand All @@ -227,9 +240,9 @@ def get_spf(self):
Prints the string for the SPF Record and Returns the string
"""
spf_record = []

tcp = True if self._proto == "tcp" else False
try:
answers = self._res.query(self._domain, 'SPF')
answers = self._res.query(self._domain, 'SPF', tcp=tcp)
for rdata in answers:
name = ''.join(rdata.strings)
spf_record.append(['SPF', name])
Expand All @@ -243,10 +256,11 @@ def get_txt(self, target=None):
Function for TXT Record resolving returns the string.
"""
txt_record = []
tcp = True if self._proto == "tcp" else False
if target is None:
target = self._domain
try:
answers = self._res.query(target, 'TXT')
answers = self._res.query(target, 'TXT', tcp=tcp)
for rdata in answers:
string = "".join(rdata.strings)
txt_record.append(['TXT', target, string])
Expand All @@ -260,9 +274,10 @@ def get_ptr(self, ipaddress):
Function for resolving PTR Record given it's IPv4 or IPv6 Address.
"""
found_ptr = []
tcp = True if self._proto == "tcp" else False
n = dns.reversename.from_address(ipaddress)
try:
answers = self._res.query(n, 'PTR')
answers = self._res.query(n, 'PTR', tcp=tcp)
for a in answers:
if a.target.to_text().endswith('.'):
found_ptr.append(['PTR', a.target.to_text()[:-1], ipaddress])
Expand All @@ -277,8 +292,9 @@ def get_srv(self, host):
Function for resolving SRV Records.
"""
record = []
tcp = True if self._proto == "tcp" else False
try:
answers = self._res.query(host, 'SRV')
answers = self._res.query(host, 'SRV', tcp=tcp)
for a in answers:
if a.target.to_text().endswith('.'):
target = a.target.to_text()[:-1]
Expand All @@ -305,7 +321,8 @@ def get_nsec(self, host):
Function for querying for a NSEC record and retrieving the rdata object.
This function is used mostly for performing a Zone Walk against a zone.
"""
answer = self._res.query(host, 'NSEC')
tcp = True if self._proto == "tcp" else False
answer = self._res.query(host, 'NSEC', tcp=tcp)
return answer

def from_wire(self, xfr, zone_factory=Zone, relativize=True):
Expand Down

0 comments on commit 4359358

Please sign in to comment.