performance, continue analysis where left off (on file basis)
@@ -1,15 +1,25 @@
+import pickle
+import os.path
 
 
-def load_whitelist():
+def generate_whitelist():
     filename = 'res/benign_domains.txt'
     whitelist = []
     for item in open(filename).read().splitlines():
         if item not in whitelist:
             whitelist.append(item)
+    whitelist_pkl = open('whitelist.pkl', 'wb')
+    pickle.dump(whitelist, whitelist_pkl)
+    whitelist_pkl.close()
+
+
+def load_whitelist():
+    whitelist_pkl = open('whitelist.pkl', 'rb')
+    whitelist = pickle.load(whitelist_pkl)
     return whitelist
 
 
-def load_blacklist():
+def generate_blacklist():
     filename = 'res/malicious_domains.txt'
     blacklist = []
     for item in open(filename).read().splitlines():
@@ -17,6 +27,14 @@ def load_blacklist():
         # do not add to black (as EXPOSURE is handling)
         if item not in blacklist and item not in whitelist:
             blacklist.append(item)
+    blacklist_pkl = open('blacklist.pkl', 'wb')
+    pickle.dump(blacklist, blacklist_pkl)
+    blacklist_pkl.close()
+
+
+def load_blacklist():
+    blacklist_pkl = open('blacklist.pkl', 'rb')
+    blacklist = pickle.load(blacklist_pkl)
     return blacklist
 
 
@@ -25,7 +43,7 @@ def is_malicious(domain):
 
 
 def test():
-    print('blacklist length: ' + str(len(blacklist)))
+    # print('blacklist length: ' + str(len(blacklist)))
 
     # dupes = [x for n, x in enumerate(whitelist) if x in whitelist[:n]]
     # print(dupes)
@@ -37,7 +55,14 @@ def test():
     pass
 
 
+if not os.path.isfile('whitelist.pkl'):
+    generate_whitelist()
+
 whitelist = load_whitelist()
+
+
+if not os.path.isfile('blacklist.pkl'):
+    generate_blacklist()
 blacklist = load_blacklist()
 
 if __name__ == "__main__":
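
The hunks above replace the repeated parsing of res/benign_domains.txt and res/malicious_domains.txt with a generate-once, load-from-pickle scheme. A minimal standalone sketch of that caching pattern (generate_list and load_list are hypothetical helper names; the file paths mirror the diff and are assumed to exist):

import os.path
import pickle


def generate_list(txt_path, pkl_path):
    # deduplicate the text file once and cache the result as a pickle
    items = []
    for item in open(txt_path).read().splitlines():
        if item not in items:
            items.append(item)
    with open(pkl_path, 'wb') as f:
        pickle.dump(items, f)


def load_list(pkl_path):
    # cheap load on every later run
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)


if not os.path.isfile('whitelist.pkl'):
    generate_list('res/benign_domains.txt', 'whitelist.pkl')
whitelist = load_list('whitelist.pkl')

A set would make the membership test O(1) instead of O(n); the sketch keeps the list-based deduplication used in the original code.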

@@ -7,9 +7,19 @@ train_end = datetime.date(2017, 9, 7)
 
 analysis_start_date = datetime.date(2017, 9, 1)
 analysis_days_amount = 7
-#pdns_logs_path = '/home/felix/pdns/'
-pdns_logs_path = '/mnt/old/2017'
 
 # e.g. analysis_days = ['2017-04-07', '2017-04-08', '2017-04-09']
 analysis_days = [(analysis_start_date + datetime.timedelta(days=x)).strftime(format_date) for x in
                  range(analysis_days_amount)]
+
+serialized_path = 'serialized/'
+#pdns_logs_path = '/home/felix/pdns/'
+pdns_logs_path = '/mnt/old/2017/'
+
+# 32 cores on janus
+num_cores = 32
+
+gz = False
+multiprocessed = True
+
+
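
For reference, a worked example of the analysis_days expression above, assuming format_date (defined elsewhere in the config, not shown in this hunk) is '%Y-%m-%d':

import datetime

format_date = '%Y-%m-%d'  # assumption: defined earlier in config, not shown in this hunk
analysis_start_date = datetime.date(2017, 9, 1)
analysis_days_amount = 7

analysis_days = [(analysis_start_date + datetime.timedelta(days=x)).strftime(format_date)
                 for x in range(analysis_days_amount)]
print(analysis_days)
# ['2017-09-01', '2017-09-02', '2017-09-03', '2017-09-04', '2017-09-05', '2017-09-06', '2017-09-07']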

@@ -44,8 +44,6 @@ def serialize_logs_to_db():
 
     # for log_file in ['data/pdns_capture.pcap-sgsgpdc0n9x-2017-04-07_00-00-02.csv.gz']:
 
-    # TODO
-
     for day in range(analysis_days_amount):
         log_files_hour = get_log_files_for_hours_of_day(analysis_days[day])
         # everything[day] = {}
@@ -55,7 +53,7 @@ def serialize_logs_to_db():
         for hour in progress_bar(range(24)):
             progress_bar.next()
             for hour_files in log_files_hour[hour]:
-                with gzip.open(hour_files, 'rt', newline='') as file:
+                with open(hour_files, 'rt') as file:
                     reader = csv.reader(file)
                     all_rows = list(reader)
 
@@ -91,7 +89,7 @@ def batch(iterable, n=1):
     # raise Exception('Log files inconsistency')
 
 
-def get_log_files_for_range_of_day(date, minutes_range, gz=True):
+def get_log_files_for_range_of_day(date, minutes_range, gz=False):
     slot_files = {}
     slots_amount = int(1440 / minutes_range)
 
@@ -103,7 +101,7 @@ def get_log_files_for_range_of_day(date, minutes_range, gz=True):
         slot_files[slot] = 'data/*' + date + '_' + time_range + '*.csv' + ('.gz' if gz else '')
 
 
-def get_log_files_for_hours_of_day(date, gz=True):
+def get_log_files_for_hours_of_day(date, gz=False):
     slot_files = {}
     slots_amount = 24
 
@@ -113,7 +111,7 @@ def get_log_files_for_hours_of_day(date, gz=True):
     return slot_files
 
 
-def get_log_files_for_day(date, gz=True):
+def get_log_files_for_day(date, gz=False):
     log_files = 'data/*' + date + '*.csv.gz' + ('.gz' if gz else '')
 
     return glob.glob(log_files)

@@ -11,14 +11,34 @@ logger.setLevel(logging.INFO)
 logger.debug('connecting redis')
 
 redis_host = 'localhost'
-# TODO name ports properly
-redis_start_port_first_seen = 2337
-redis_start_port_last_seen = 2340
-redis_port_reverse = 2343
-redis_port_4 = 2344
-redis_port_ttl = 2345
 bucket_mod = 1048576 # = 2 ** 20
 
+# redis_start_port_first_seen = 2337
+# redis_start_port_last_seen = 2340
+# redis_port_reverse = 2343
+# redis_port_4 = 2344
+# redis_port_ttl = 2345
+# redis_f = Redis(unix_socket_path=base + 'redis_local_f.sock')
+# redis_f1 = Redis(redis_host, port=2338)
+# redis_f2 = Redis(redis_host, port=2339)
+# redis_l = Redis(redis_host, port=2340)
+# redis_l1 = Redis(redis_host, port=2341)
+# redis_l2 = Redis(redis_host, port=2342)
+# redis_r = Redis(redis_host, port=2343)
+# redis_v = Redis(redis_host, port=2344)
+# redis_t = Redis(redis_host, port=2345)
+
+base = '/home/tek/felix/redis/redis/'
+redis_f = Redis(unix_socket_path=base + 'redis_local_f.sock')
+redis_f1 = Redis(unix_socket_path=base + 'redis_local_f2.sock')
+redis_f2 = Redis(unix_socket_path=base + 'redis_local_f3.sock')
+redis_l = Redis(unix_socket_path=base + 'redis_local_l.sock')
+redis_l1 = Redis(unix_socket_path=base + 'redis_local_l2.sock')
+redis_l2 = Redis(unix_socket_path=base + 'redis_local_l3.sock')
+redis_r = Redis(unix_socket_path=base + 'redis_local_r.sock')
+redis_v = Redis(unix_socket_path=base + 'redis_local_v.sock')
+redis_t = Redis(unix_socket_path=base + 'redis_local_t.sock')
 
 
 def _get_redis_shard(rrname):
     bucket = crc32(rrname.encode('utf-8')) % bucket_mod # convert string to byte array
@@ -29,11 +49,15 @@ def _get_redis_shard(rrname):
 
 def get_stats_for_domain(rrname, rrtype='A'):
     bucket, shard = _get_redis_shard(rrname)
+    local_redis_f = redis_f
+    local_redis_l = redis_l
 
-    redis_f = Redis(redis_host, port=redis_start_port_first_seen + shard)
-    redis_l = Redis(redis_host, port=redis_start_port_last_seen + shard)
-    redis_r = Redis(redis_host, port=redis_port_reverse)
-    redis_t = Redis(redis_host, port=redis_port_ttl)
+    if shard == 1:
+        local_redis_f = redis_f1
+        local_redis_l = redis_l1
+    elif shard == 2:
+        local_redis_f = redis_f2
+        local_redis_l = redis_l2
 
     ttls_b = redis_t.lrange('t:{}:{}'.format(rrname, rrtype), 0, -1)
 
@@ -55,8 +79,8 @@ def get_stats_for_domain(rrname, rrtype='A'):
         logger.debug('res: ' + str(result))
         logger.debug('id: ' + str('f' + str(bucket)))
 
-        t_f = float(unpack('<L', redis_f.hget('f' + str(bucket), rrname + ':' + result))[0])
-        t_l = float(unpack('<L', redis_l.hget('l' + str(bucket), rrname + ':' + result))[0])
+        t_f = float(unpack('<L', local_redis_f.hget('f' + str(bucket), rrname + ':' + result))[0])
+        t_l = float(unpack('<L', local_redis_l.hget('l' + str(bucket), rrname + ':' + result))[0])
 
         t_f = datetime.utcfromtimestamp(t_f).strftime('%Y-%m-%dT%H:%M:%SZ')
         t_l = datetime.utcfromtimestamp(t_l).strftime('%Y-%m-%dT%H:%M:%SZ')
@@ -65,7 +89,7 @@ def get_stats_for_domain(rrname, rrtype='A'):
             'rrname': rrname,
             'rrtype': rrtype.replace('rrtype_', ''),
             'rdata': result,
-            'ttls': list(map(int, ttls)), # TODO do we need to convert iterable of type map to list? (e.g. for length)
+            'ttls': list(map(int, ttls)),
             'time_first': t_f,
             'time_last': t_l
         })
@@ -75,7 +99,6 @@ def get_stats_for_domain(rrname, rrtype='A'):
 
 
 def get_all_ips_for_domain(rrname):
-    redis_r = Redis(redis_host, port=redis_port_reverse)
 
     # remove trailing slash
     rrname = rrname.rstrip('/')
@@ -94,7 +117,6 @@ def get_all_ips_for_domain(rrname):
 
 
 def get_stats_for_ip(rdata):
-    redis_v = Redis(redis_host, port=redis_port_4)
 
     try:
         results = []
@@ -102,11 +124,18 @@ def get_stats_for_ip(rdata):
             result = result.decode('utf-8') # convert to string (python 3)
             bucket, shard = _get_redis_shard(result)
 
-            redis_f = Redis(redis_host, port=redis_start_port_first_seen + shard)
-            redis_l = Redis(redis_host, port=redis_start_port_last_seen + shard)
+            local_redis_f = redis_f
+            local_redis_l = redis_l
 
-            t_f = float(unpack('<L', redis_f.hget('f' + str(bucket), result + ':' + rdata))[0])
-            t_l = float(unpack('<L', redis_l.hget('l' + str(bucket), result + ':' + rdata))[0])
+            if shard == 1:
+                local_redis_f = redis_f1
+                local_redis_l = redis_l1
+            elif shard == 2:
+                local_redis_f = redis_f2
+                local_redis_l = redis_l2
+
+            t_f = float(unpack('<L', local_redis_f.hget('f' + str(bucket), result + ':' + rdata))[0])
+            t_l = float(unpack('<L', local_redis_l.hget('l' + str(bucket), result + ':' + rdata))[0])
             t_f = datetime.utcfromtimestamp(t_f).strftime('%Y-%m-%dT%H:%M:%SZ')
             t_l = datetime.utcfromtimestamp(t_l).strftime('%Y-%m-%dT%H:%M:%SZ')
 
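
The hunks above replace per-call Redis(...) constructions with module-level connections over unix sockets and pick the shard-specific handle per lookup. A condensed sketch of that selection, with placeholder socket paths; the hunk does not show how _get_redis_shard derives the shard index, so bucket % 3 below is only an assumption:

from zlib import crc32

from redis import Redis

bucket_mod = 1048576  # = 2 ** 20

# long-lived connections created once at import time (placeholder socket paths)
redis_f = Redis(unix_socket_path='/tmp/redis_local_f.sock')
redis_f1 = Redis(unix_socket_path='/tmp/redis_local_f2.sock')
redis_f2 = Redis(unix_socket_path='/tmp/redis_local_f3.sock')
first_seen_by_shard = {0: redis_f, 1: redis_f1, 2: redis_f2}


def get_redis_shard(rrname):
    # crc32 of the name folded into 2 ** 20 buckets, as in the diff;
    # deriving the shard as bucket % 3 is an assumption
    bucket = crc32(rrname.encode('utf-8')) % bucket_mod
    return bucket, bucket % 3


bucket, shard = get_redis_shard('example.com')
local_redis_f = first_seen_by_shard.get(shard, redis_f)  # same effect as the if/elif chain

A dict lookup stands in here for the if/elif chain from the diff; both resolve to the same pre-opened connection.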

@@ -4,7 +4,7 @@ import json
 from geoip2 import database, errors
 
 logger = logging.getLogger('ip')
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
 
 top_100_hosters = json.load(open('res/asns.json'))
 top_100_hosters_asns = []
@@ -12,27 +12,34 @@ for hoster in top_100_hosters.values():
     top_100_hosters_asns.extend(hoster['asns'])
 
 
+country_reader = database.Reader('res/GeoLite2-Country_20170905/GeoLite2-Country.mmdb')
+asn_reader = database.Reader('res/GeoLite2-ASN_20171107/GeoLite2-ASN.mmdb')
+
+
 def is_hoster_ip(ip):
     return str(get_isp_by_ip(ip)) in top_100_hosters_asns
 
 
 # if specific country not available in database take continent instead
 def get_country_by_ip(ip):
-    with database.Reader('res/GeoLite2-Country_20170905/GeoLite2-Country.mmdb') as reader:
-        result = reader.country(ip)
-        if not result.country:
-            return result.continent.geoname_id
-        else:
-            return result.country.geoname_id
+    try:
+        result = country_reader.country(ip)
+    except errors.AddressNotFoundError:
+        logger.debug('address not in location database ' + str(ip))
+        return 0
+
+    if not result.country:
+        return result.continent.geoname_id
+    else:
+        return result.country.geoname_id
 
 
 def get_isp_by_ip(ip):
-    with database.Reader('res/GeoLite2-ASN_20171107/GeoLite2-ASN.mmdb') as reader:
-        try:
-            result = reader.asn(ip)
-            return result.autonomous_system_number
-        except errors.AddressNotFoundError:
-            logger.debug('address not in isp database')
+    try:
+        result = asn_reader.asn(ip)
+        return result.autonomous_system_number
+    except errors.AddressNotFoundError:
+        logger.debug('address not in isp database ' + str(ip))
 
 
 def ratio_ips_hoster(ips):
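
The same reuse idea applies to the GeoIP lookups: open the GeoLite2 readers once at import time and handle AddressNotFoundError per query. A minimal sketch using the geoip2 calls visible in the diff; the .mmdb paths are placeholders and the files must exist for Reader() to open:

from geoip2 import database, errors

country_reader = database.Reader('res/GeoLite2-Country.mmdb')  # placeholder path
asn_reader = database.Reader('res/GeoLite2-ASN.mmdb')          # placeholder path


def get_asn(ip):
    # one shared reader per process instead of one Reader per call
    try:
        return asn_reader.asn(ip).autonomous_system_number
    except errors.AddressNotFoundError:
        return None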

@@ -1,7 +1,7 @@
 import logging
 import datetime
-# logfile = 'analysis_' + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') + '.log' # https://stackoverflow.com/questions/1943747/python-logging-before-you-run-logging-basicconfig
-# logging.basicConfig(filename=logfile, filemode='w') # important to set basicConfig only once for all modules
+logfile = 'analysis_' + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') + '.log' # https://stackoverflow.com/questions/1943747/python-logging-before-you-run-logging-basicconfig
+logging.basicConfig(filename=logfile, filemode='w') # important to set basicConfig only once for all modules
 logging.basicConfig()
 
 import logging
@@ -21,11 +21,14 @@ import pickle
 import classify
 import config
 import traceback
+import multiprocessing
+import os
 # import db_sql
 
 from sklearn.datasets import load_iris
 from sklearn import tree
 
 
 logger = logging.getLogger('train')
 logger.setLevel(logging.INFO)
+
@@ -35,65 +38,126 @@ train_end = config.train_end
 
 id_upto = 379283817
 
-# record types that should be analysed (e.g. only A)
-record_types = ['A']
 
 
 # id_upto = db.mariadb_get_nearest_id(train_end.strftime(db_format_time))
 
 def generate_features_and_classify():
-    start = time.time()
-    logger.info('feature generation start: ' + str(start))
 
     all_features = []
     all_classifications = []
     for day in range(config.analysis_days_amount):
-        # TODO dev
-        # log_files_hour = csv_tools.get_log_files_for_hours_of_day(config.analysis_days[day], gz=False)
-        log_files_hour = csv_tools.get_log_files_for_hours_of_day(config.analysis_days[day])
+        log_files_hour = csv_tools.get_log_files_for_hours_of_day(config.analysis_days[day], gz=config.gz)
 
+        hour_start = time.time()
         progress_bar = progressbar.ProgressBar()
 
         for hour in progress_bar(range(24)):
-            for hour_files in log_files_hour[hour]:
-                # TODO dev
-                # with open(hour_files, 'rt') as file:
-                with gzip.open(hour_files, 'rt', newline='') as file:
-                    reader = csv.reader(file)
-
-                    for row in reader:
-                        if row[2] in record_types:
-                            entity = {'timestamp': row[0], 'domain': row[1], 'type': row[2],
-                                      'record': row[3], 'ttl': row[4]}
-
-                            try:
-                                all_features.append(prepare_features_redis(entity))
-                                all_classifications.append(classify.is_malicious(entity['domain']))
-                            except Exception:
-                                logger.error(traceback.format_exc())
-                                logger.error('Exception occured processing entity: ' + str(entity))
-                            # break
-                        # break
-                    # break
-            # break
+            file_start = time.time()
+            # TODO TODO debug only one file per hour
+            log_files_hour[hour] = [log_files_hour[hour][0]]
+            for hour_chunkfile in log_files_hour[hour]:
+                chunk_basename = os.path.basename(hour_chunkfile)
+                if os.path.isfile(config.serialized_path + chunk_basename + '_feat.pkl') and os.path.isfile(config.serialized_path + chunk_basename + '_class.pkl'):
+                    logger.info('chunkfile already processed: ' + str(chunk_basename))
+                else:
+                    with open_file(hour_chunkfile) as file:
+                        reader = csv.reader(file)
+
+                        if not config.multiprocessed:
+                            for row in reader:
+                                if row[2] == 'A':
+                                    try:
+                                        features, classification = process_row(row)
+                                        all_features.append(features)
+                                        all_classifications.append(classification)
+                                    except Exception:
+                                        logger.error(traceback.format_exc())
+                                        logger.error('Exception occured processing entity: ' + str(row))
+                        else:
+                            logger.info('start analysing file: ' + str(hour_chunkfile))
+                            reader = csv.reader(file)
+
+                            # load file into memory
+                            rows = list(reader)
+                            rows = filter(lambda e: e[2] == 'A', rows)
+
+                            pool = multiprocessing.Pool(processes=config.num_cores)
+
+                            # https://stackoverflow.com/questions/41273960/python-3-does-pool-keep-the-original-order-of-data-passed-to-map
+                            # TODO check if len(features) == len(classification) before unzip
+                            features, classification = zip(*pool.map(process_row, rows))
+                            # TODO filter None values?
+
+                            serialize(features, chunk_basename + '_feat.pkl')
+                            serialize(classification, chunk_basename + '_class.pkl')
+
+                            all_features.extend(features)
+                            all_classifications.extend(classification)
+                    logger.info('file took: ' + str(time.time() - file_start) + 's')
+                    file_start = time.time()
+            logger.info('hour took approx: ' + str(time.time() - hour_start) + 's')
+            hour_start = time.time()
     # iris = load_iris()
     # return iris.data, iris.target
 
-    logger.info('feature generation duration: ' + str(time.time() - start) + 's')
     return np.array(all_features), np.array(all_classifications)
 
 
+def open_file(hour_files):
+    if config.gz:
+        return gzip.open(hour_files, 'rt', newline='')
+    else:
+        return open(hour_files, 'rt')
+
+
+def process_row(row):
+    asd = time.time()
+    entity = {'timestamp': row[0], 'domain': row[1], 'type': row[2],
+              'record': row[3], 'ttl': row[4]}
+    features = None
+    classification = None
+
+    try:
+        features = prepare_features_redis(entity)
+        classification = classify.is_malicious(row[1])
+    except Exception: # log trace to logger (+ show exact exception in multiprocessing)
+        logger.error(traceback.format_exc())
+        logger.error('Exception occured processing entity: ' + str(entity))
+    logger.info('ent took' + str(time.time() - asd))
+    return features, classification
+
+
 def train():
     start = time.time()
+    checkpoint = start
     logger.info('training start: ' + str(start))
 
     features, classification = generate_features_and_classify()
 
-    # TODO save serialized features and classification
+    logger.info('feature preprocessing duration: ' + str(time.time() - checkpoint) + 's')
+
+    if len(features) == 0:
+        logger.info('no features available, check config')
+        exit()
+    else:
+        logger.info('# of features: ' + str(len(features)))
+        logger.info('# of classifications: ' + str(len(classification)))
+
+    checkpoint = time.time()
+    logger.info('serialization start: ' + str(checkpoint))
+    # save serialized features, classification
+    serialize(features, 'complete_feat_' + datetime.datetime.now().strftime(config.format_date) + '.pkl')
+
+    serialize(classification, 'complete_class_' + datetime.datetime.now().strftime(config.format_date) + '.pkl')
+
+    logger.info('serialization duration: ' + str(time.time() - checkpoint))
+
+    checkpoint = time.time()
+    logger.info('building decision tree start: ' + str(checkpoint))
     decision_tree_model = tree.DecisionTreeClassifier()
     decision_tree_model = decision_tree_model.fit(features, classification) # training set, manual classification
 
+    logger.info('building decision tree duration: ' + str(time.time() - checkpoint))
+
     # predict single or multiple sets with clf.predict([[]])
 
     # visualize decision tree classifier
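
This hunk carries the point of the commit: results for each hourly chunk file are pickled under config.serialized_path, and a chunk whose _feat.pkl and _class.pkl already exist is skipped on the next run, so an interrupted analysis continues where it left off on a per-file basis. Stripped of the CSV and Redis details, the pattern looks roughly like this (a sketch under assumptions: process_chunkfile is a hypothetical wrapper, process_row is a stand-in for the real feature extraction, and the output directory is assumed to exist):

import multiprocessing
import os
import pickle


def serialize(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def process_row(row):
    # stand-in for prepare_features_redis + classify.is_malicious
    return [len(row[1])], 0


def process_chunkfile(rows, chunk_basename, serialized_path='serialized/', num_cores=4):
    feat_pkl = os.path.join(serialized_path, chunk_basename + '_feat.pkl')
    class_pkl = os.path.join(serialized_path, chunk_basename + '_class.pkl')

    # continue where the last run left off, on a per-file basis
    if os.path.isfile(feat_pkl) and os.path.isfile(class_pkl):
        return None

    if not rows:
        return [], []

    with multiprocessing.Pool(processes=num_cores) as pool:
        # pool.map preserves the input order, so zip(*...) splits the
        # (features, classification) pairs back into two aligned sequences
        features, classifications = zip(*pool.map(process_row, rows))

    serialize(features, feat_pkl)
    serialize(classifications, class_pkl)
    return features, classifications

Rerunning the same job then costs only two os.path.isfile checks per finished chunk, and the order-preserving pool.map keeps features and classifications aligned.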

@@ -102,10 +166,7 @@ def train():
     graph.render('plot' + datetime.datetime.now().strftime(config.format_date))
 
     # dump trained decision tree classifier to file
-    decision_tree_pkl_filename = 'dtc_' + datetime.datetime.now().strftime(config.format_date) + '.pkl'
-    decision_tree_model_pkl = open(decision_tree_pkl_filename, 'wb')
-    pickle.dump(decision_tree_model, decision_tree_model_pkl)
-    decision_tree_model_pkl.close()
+    serialize(decision_tree_model, 'dtc_' + datetime.datetime.now().strftime(config.format_date) + '.pkl')
 
 
 def prepare_features_redis(entity):
@@ -117,7 +178,8 @@ def prepare_features_redis(entity):
     logger.debug(domain_stats)
 
     if not domain_stats:
-        logger.debug('no stats in redis for entity: ' + entity)
+        logger.debug('no stats in redis for entity: ' + str(entity))
+        return
 
     domain_stats = domain_stats[0]
 
@@ -125,6 +187,8 @@ def prepare_features_redis(entity):
 
     logger.debug('all ips seen for domain ' + str(ips))
 
+    # TODO readd short lived and those features we are able to calc
+
     # feature 5: Number of distinct IP addresses (0)
 
     distinct_ips = len(ips)
@@ -199,7 +263,7 @@ def prepare_features_redis(entity):
     return all_features
 
 
-# TODO depreated
+# TODO old
 def get_logs_from_db():
     results = db_sql.mariadb_get_logs(id_upto)
 
@@ -349,6 +413,12 @@ def prepare_features_mysql(entity):
     return all_features
 
 
+def serialize(obj, filename):
+    file = open(config.serialized_path + filename, 'wb')
+    pickle.dump(obj, file)
+    file.close()
+
+
 def test():
     start = time.time()
     logger.info('starting training ' + str(start))
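
The serialize() helper added above writes both the per-chunk _feat.pkl/_class.pkl files and the complete_* dumps. A matching loader is not part of this commit; a hypothetical counterpart would be:

import pickle


def deserialize(path):
    # read back an object written by serialize()
    with open(path, 'rb') as f:
        return pickle.load(f)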