From f31f6453238fb53e628c14f0b57966f3edf40994 Mon Sep 17 00:00:00 2001
From: Felix Steghofer
Date: Mon, 6 Nov 2017 21:29:55 +0100
Subject: [PATCH] first features ready for training

---
 src/DoresA/.gitignore                         |    1 +
 src/DoresA/db.py                              |   61 +-
 src/DoresA/ip.py                              |   74 +
 src/DoresA/logs/one_week_serialize_to_db.txt  |    2 +
 ...are_features_for_one_domain_with_index.txt |    8 +
 ..._features_for_one_domain_without_index.txt |    8 +
 src/DoresA/res/all-tld.txt                    | 1544 +++++++++++++++++
 .../scripts}/mongodb/collection_stats.js      |    0
 src/DoresA/scripts/sql/find_nearest_date.sql  |    1 +
 src/DoresA/serialize_logs_to_db.py            |    7 +-
 src/DoresA/time.py                            |   16 -
 src/DoresA/train.py                           |  160 ++
 12 files changed, 1861 insertions(+), 21 deletions(-)
 create mode 100644 src/DoresA/ip.py
 create mode 100644 src/DoresA/logs/one_week_serialize_to_db.txt
 create mode 100644 src/DoresA/logs/prepare_features_for_one_domain_with_index.txt
 create mode 100644 src/DoresA/logs/prepare_features_for_one_domain_without_index.txt
 create mode 100644 src/DoresA/res/all-tld.txt
 rename {scripts => src/DoresA/scripts}/mongodb/collection_stats.js (100%)
 create mode 100644 src/DoresA/scripts/sql/find_nearest_date.sql
 create mode 100644 src/DoresA/train.py

diff --git a/src/DoresA/.gitignore b/src/DoresA/.gitignore
index 0603c33..6dce8e9 100644
--- a/src/DoresA/.gitignore
+++ b/src/DoresA/.gitignore
@@ -6,3 +6,4 @@
 /include/
 /lib/
 /__pycache__/
+*.pyc
diff --git a/src/DoresA/db.py b/src/DoresA/db.py
index 791ac82..aa0878a 100644
--- a/src/DoresA/db.py
+++ b/src/DoresA/db.py
@@ -79,11 +79,70 @@ def mariadb_insert_logs(csv_entries):
 
 
 def mariadb_get_logs(from_time, to_time):
-    get_logs_from_to = 'SELECT * FROM ' + sql_table_name + ' WHERE timestamp BETWEEN \'{}\' and \'{}\';'.format(from_time, to_time)
+    # get_logs_from_to = 'SELECT * FROM ' + sql_table_name + ' WHERE timestamp BETWEEN \'{}\' and \'{}\';'.format(from_time, to_time)
+    get_logs_from_to = 'SELECT * FROM ' + sql_table_name + ' WHERE id < 379283817;'
     sql_connection.query(get_logs_from_to)
     return sql_connection.use_result()
 
 
+# TODO not used
+# def mariadb_get_distinct_ttl(domain, from_time, to_time):
+#     get_distinct_ttl = 'SELECT DISTINCT ttl FROM ' + sql_table_name + \
+#                        ' WHERE timestamp BETWEEN \'{}\' and \'{}\' '.format(from_time, to_time) + \
+#                        'AND domain=\'' + domain + '\';'
+#     sql_connection.query(get_distinct_ttl)
+#     return sql_connection.use_result()
+
+
+def mariadb_get_logs_for_domain(domain, from_time, to_time):
+    # we need a second connection for this query, as it usually (always) runs in parallel with the first one
+    sql_connection_tmp = mariadb.connect(host=sql_host, user=sql_user_name, passwd=sql_pw, db=sql_db_name, port=sql_port)
+
+    # timestamp comparison is super slow; check whether it gets better with an index
+    # get_logs = 'SELECT * FROM ' + sql_table_name + \
+    #            ' WHERE timestamp BETWEEN \'{}\' and \'{}\' '.format(from_time, to_time) + \
+    #            'AND domain=\'' + domain + '\';'
+    get_logs = 'SELECT * FROM ' + sql_table_name + \
+               ' WHERE id < 379283817 ' + \
+               'AND domain=\'' + domain + '\';'
+    sql_connection_tmp.query(get_logs)
+    result = sql_connection_tmp.use_result()
+    logs_for_domain = result.fetch_row(maxrows=0, how=1)  # TODO this can consume a lot of memory, think of alternatives
+
+    sql_connection_tmp.close()
+
+    return logs_for_domain
+
+
+def mariadb_get_logs_for_ip(ip, from_time, to_time):
+    # we need a second connection for this query, as it usually (always) runs in parallel with the first one
+    sql_connection_tmp = mariadb.connect(host=sql_host, user=sql_user_name, passwd=sql_pw, db=sql_db_name, port=sql_port)
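+    # NOTE: 379283817 is presumably the id of the first row after the training
+    # window, precomputed via mariadb_get_nearest_id (see
+    # scripts/sql/find_nearest_date.sql); comparing against the indexed primary
+    # key avoids the slow timestamp BETWEEN scan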
+    # get_logs = 'SELECT * FROM ' + sql_table_name + \
+    #            ' WHERE timestamp BETWEEN \'{}\' and \'{}\' '.format(from_time, to_time) + \
+    #            'AND record=\'' + str(ip) + '\';'
+    get_logs = 'SELECT * FROM ' + sql_table_name + \
+               ' WHERE id < 379283817 ' + \
+               'AND record=\'' + str(ip) + '\';'
+    sql_connection_tmp.query(get_logs)
+
+    result = sql_connection_tmp.use_result()
+    logs_for_ip = result.fetch_row(maxrows=0, how=1)  # TODO this can consume a lot of memory, think of alternatives
+
+    sql_connection_tmp.close()
+
+    return logs_for_ip
+
+
+def mariadb_get_nearest_id(timestamp):
+    get_nearest_id = 'SELECT id FROM ' + sql_table_name + ' WHERE timestamp > \'{}\' LIMIT 1;'.format(timestamp)
+    sql_connection.query(get_nearest_id)
+    result = sql_connection.use_result()
+    entities = result.fetch_row(maxrows=0, how=1)
+    return entities[0]['id']
+
+
 def mariadb_create_table():
     create_table = 'CREATE TABLE IF NOT EXISTS ' + sql_table_name + """ (
             id INTEGER AUTO_INCREMENT PRIMARY KEY,
diff --git a/src/DoresA/ip.py b/src/DoresA/ip.py
new file mode 100644
index 0000000..8eaff44
--- /dev/null
+++ b/src/DoresA/ip.py
@@ -0,0 +1,74 @@
+import re
+
+
+# proudly taken from https://stackoverflow.com/questions/319279/how-to-validate-ip-address-in-python
+def is_valid_ipv4(ip):
+    """Validates IPv4 addresses.
+    """
+    pattern = re.compile(r"""
+        ^
+        (?:
+          # Dotted variants:
+          (?:
+            # Decimal 1-255 (no leading 0's)
+            [3-9]\d?|2(?:5[0-5]|[0-4]?\d)?|1\d{0,2}
+          |
+            0x0*[0-9a-f]{1,2}  # Hexadecimal 0x0 - 0xFF (possible leading 0's)
+          |
+            0+[1-3]?[0-7]{0,2} # Octal 0 - 0377 (possible leading 0's)
+          )
+          (?:                  # Repeat 0-3 times, separated by a dot
+            \.
+            (?:
+              [3-9]\d?|2(?:5[0-5]|[0-4]?\d)?|1\d{0,2}
+            |
+              0x0*[0-9a-f]{1,2}
+            |
+              0+[1-3]?[0-7]{0,2}
+            )
+          ){0,3}
+        |
+          0x0*[0-9a-f]{1,8}    # Hexadecimal notation, 0x0 - 0xffffffff
+        |
+          0+[0-3]?[0-7]{0,10}  # Octal notation, 0 - 037777777777
+        |
+          # Decimal notation, 1-4294967295:
+          429496729[0-5]|42949672[0-8]\d|4294967[01]\d\d|429496[0-6]\d{3}|
+          42949[0-5]\d{4}|4294[0-8]\d{5}|429[0-3]\d{6}|42[0-8]\d{7}|
+          4[01]\d{8}|[1-3]\d{0,9}|[4-9]\d{0,8}
+        )
+        $
+    """, re.VERBOSE | re.IGNORECASE)
+    return pattern.match(ip) is not None
+
+
+def is_valid_ipv6(ip):
+    """Validates IPv6 addresses.
+    """
+    pattern = re.compile(r"""
+        ^
+        \s*                         # Leading whitespace
+        (?!.*::.*::)                # Only a single wildcard allowed
+        (?:(?!:)|:(?=:))            # Colon iff it would be part of a wildcard
+        (?:                         # Repeat 6 times:
+            [0-9a-f]{0,4}           #   A group of at most four hexadecimal digits
+            (?:(?<=::)|(?<!:):)     #   Colon unless preceded by wildcard
+        ){6}                        #
+        (?:                         # Either
+            [0-9a-f]{0,4}           #   Another group
+            (?:(?<=::)              #   with a wildcard
+               |(?<!:):             #   or a colon
+            )                       #
+            [0-9a-f]{0,4}           #   Last group
+            (?: (?<=::)             #   with a wildcard
+              | (?<!:)(?<=:)        #   or a colon
+            )                       #
+        |                           # OR
+            (?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)  # A v4 address with NO leading zeros
+            (?: \.(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d) ){3}
+        )
+        \s*                         # Trailing whitespace
+        $
+    """, re.VERBOSE | re.IGNORECASE | re.DOTALL)
+    return pattern.match(ip) is not None
diff --git a/scripts/mongodb/collection_stats.js b/src/DoresA/scripts/mongodb/collection_stats.js
similarity index 100%
rename from scripts/mongodb/collection_stats.js
rename to src/DoresA/scripts/mongodb/collection_stats.js
diff --git a/src/DoresA/scripts/sql/find_nearest_date.sql b/src/DoresA/scripts/sql/find_nearest_date.sql
new file mode 100644
--- /dev/null
+++ b/src/DoresA/scripts/sql/find_nearest_date.sql
@@ -0,0 +1 @@
+SELECT id FROM ... WHERE timestamp > '2017-05-08 00:00:00' LIMIT 1;
\ No newline at end of file
diff --git a/src/DoresA/serialize_logs_to_db.py b/src/DoresA/serialize_logs_to_db.py
index 68a160f..6fdf449 100644
--- a/src/DoresA/serialize_logs_to_db.py
+++ b/src/DoresA/serialize_logs_to_db.py
@@ -8,11 +8,10 @@
 from progress.bar import Bar
 
 import db
 
-# TODO environment this
 analysis_start_date = datetime.date(2017, 5, 1)
-analysis_days_amount = 31
+analysis_days_amount = 7
 # pdns_logs_path = 'data/'
-pdns_logs_path = '/data/'
+pdns_logs_path = '/run/media/felix/ext/2017.05/'
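+# NOTE: presumably a local copy of the May 2017 pDNS captures; together with
+# analysis_days_amount = 7 this serializes one week of logs
+# (cf. logs/one_week_serialize_to_db.txt)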
 
 # e.g. analysis_days = ['2017-04-07', '2017-04-08', '2017-04-09']
 analysis_days = [(analysis_start_date + datetime.timedelta(days=x)).strftime('%Y-%m-%d')
                  for x in
@@ -29,7 +28,7 @@ def main():
 
     # everything = {}
     # for log_file in ['data/pdns_capture.pcap-sgsgpdc0n9x-2017-04-07_00-00-02.csv.gz']:
-    
+
     for day in range(analysis_days_amount):
         log_files_hour = get_log_files_for_hours_of_day(analysis_days[day])
         # everything[day] = {}
diff --git a/src/DoresA/time.py b/src/DoresA/time.py
index 03ef5c5..4cae2da 100644
--- a/src/DoresA/time.py
+++ b/src/DoresA/time.py
@@ -18,22 +18,6 @@ def variance(a):
     return np.var(a)
 
 
-def test_decision_tree():
-    from sklearn.datasets import load_iris
-    from sklearn import tree
-    iris = load_iris()
-    clf = tree.DecisionTreeClassifier()
-    clf = clf.fit(iris.data, iris.target)  # training set, manual classification
-
-    # predict single or multiple sets with clf.predict([[]])
-
-    # visualize decision tree classifier
-    import graphviz
-    dot_data = tree.export_graphviz(clf, out_file=None)
-    graph = graphviz.Source(dot_data)
-    graph.render('iris', view=True)
-
-
 def test():
     # a = np.array((1, 2, 3))
     # b = np.array((0, 1, 2))
diff --git a/src/DoresA/train.py b/src/DoresA/train.py
new file mode 100644
index 0000000..23cdc32
--- /dev/null
+++ b/src/DoresA/train.py
@@ -0,0 +1,160 @@
+from sklearn.datasets import load_iris
+from sklearn import tree
+
+import numpy as np
+import graphviz
+import datetime
+import time
+import db
+import domain
+import ip
+import location
+
+db_format_time = '%Y-%m-%d %H:%M:%S'
+
+train_start = datetime.date(2017, 5, 1)
+train_end = datetime.date(2017, 5, 2)
+
+
+def get_logs_from_db():
+    results = db.mariadb_get_logs(train_start.strftime(db_format_time), train_end.strftime(db_format_time))
+
+    row = results.fetch_row(how=1)
+
+    print("# entity: " + row[0]['domain'])
+
+    features = prepare_features(row[0])
+
+    print(str(features))
+    # while row:
+    #     print("# entity: " + row[0]['domain'])
+    #
+    #     features = prepare_features(row[0])
+    #
+    #     print(str(features))
+    #
+    #     row = results.fetch_row(how=1)
+
+
+def prepare_features(entity):
+    # get all logs for the same domain
+    logs_for_domain = db.mariadb_get_logs_for_domain(entity['domain'], train_start.strftime(db_format_time),
+                                                     train_end.strftime(db_format_time))
+    ttls = [log['ttl'] for log in logs_for_domain]
+    ips = [log['record'] for log in logs_for_domain]  # TODO check if valid ip address
+
+    domains_with_same_ip = []
+    # get all logs for the same IP, if the record is a valid IP
+    if ip.is_valid_ipv4(entity['record']) or ip.is_valid_ipv6(entity['record']):
+        logs_for_ip = db.mariadb_get_logs_for_ip(entity['record'], train_start.strftime(db_format_time),
+                                                 train_end.strftime(db_format_time))
+        domains_with_same_ip = [log['domain'] for log in logs_for_ip]
+
+    # feature 1: Short Life
+
+    short_life = 0
+
+    # feature 2: Daily Similarity
+
+    daily_similarity = 0
+
+    # feature 3: Repeating Patterns
+
+    repeating_patterns = 0
+
+    # feature 4: Access Ratio
+
+    access_ratio = 0
+
+    # feature 5: Number of distinct IP addresses
+
+    distinct_ips = len(set(ips))
+
+    # feature 6: Number of distinct countries
+
+    distinct_countries = len({location.get_country_by_ip(addr) for addr in set(ips)})
+
+    # feature 7: Number of (distinct) domains that share the IP
+
+    distinct_domains_with_same_ip = len(set(domains_with_same_ip))
+
+    # feature 8: Reverse DNS query results
+
+    reverse_dns_result = 0
+
+    # feature 9: Average TTL
+
+    average_ttl = sum(ttls) / len(ttls)
+
+    # feature 10: Standard Deviation of TTL
+
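+    # still a placeholder, like features 1-4, 8 and 12; np.std(ttls) would be
+    # the obvious implementation once this feature is wired up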
+    standard_deviation = 0
+
+    # feature 11: Number of distinct TTL values
+
+    distinct_ttl = len(set(ttls))
+
+    # feature 12: Number of TTL changes
+
+    ttl_changes = 0
+
+    # feature 13: Percentage usage of specific TTL ranges
+    # specific ranges: [0, 1], [1, 100], [100, 300], [300, 900], [900, inf]
+    # TODO decide whether 5 individual features make a difference
+
+    ttl = entity['ttl']
+    specific_ttl_ranges = 4  # default is [900, inf]
+
+    if 0 <= ttl <= 1:
+        specific_ttl_ranges = 0
+    elif 1 < ttl <= 100:
+        specific_ttl_ranges = 1
+    elif 100 < ttl <= 300:
+        specific_ttl_ranges = 2
+    elif 300 < ttl <= 900:
+        specific_ttl_ranges = 3
+
+    # feature 14: % of numerical characters
+
+    numerical_characters_percent = domain.ratio_numerical_to_alpha(entity['domain'])
+
+    # feature 15: % of the length of the LMS (longest meaningful substring)
+
+    lms_percent = domain.ratio_lms_to_fqdn(entity['domain'])
+
+    all_features = np.array([
+        short_life, daily_similarity, repeating_patterns, access_ratio, distinct_ips, distinct_countries,
+        distinct_domains_with_same_ip, reverse_dns_result, average_ttl, standard_deviation, distinct_ttl, ttl_changes,
+        specific_ttl_ranges, numerical_characters_percent, lms_percent
+    ])
+
+    return all_features
+
+
+def test():
+    start = time.time()
+    print('starting training ' + str(start))
+
+    get_logs_from_db()
+
+    print('total duration: ' + str(time.time() - start) + 's')
+    db.close()
+
+    # db.mariadb_get_distinct_ttl('d2s45lswxaswrw.cloudfront.net', train_start.strftime(db_format_time), train_end.strftime(db_format_time))
+
+
+def flow():
+    iris = load_iris()
+    clf = tree.DecisionTreeClassifier()
+    clf = clf.fit(iris.data, iris.target)  # training set, manual classification
+
+    # predict single or multiple sets with clf.predict([[]])
+
+    # visualize the decision tree classifier
+    dot_data = tree.export_graphviz(clf, out_file=None)
+    graph = graphviz.Source(dot_data)
+    graph.render('test', view=True)
+
+
+if __name__ == "__main__":
+    test()