performance, continue analysis where it left off (on a per-file basis)

2017-12-12 18:18:38 +01:00
parent 4b9b497fd5
commit 45a9e74c14
6 changed files with 221 additions and 82 deletions
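
The common thread of the diffs below is resuming work on a per-file basis: every processed chunk file gets its features and classifications pickled, and chunks that already have both pickles are skipped on the next run. A minimal sketch of that pattern; the file suffixes and serialize() are taken from the train.py hunks, while process_chunkfile itself is hypothetical:

import os
import pickle

serialized_path = 'serialized/'  # assumed; mirrors config.serialized_path in the diffs below

def serialize(obj, filename):
    # cache a result next to the other per-chunk artifacts
    with open(serialized_path + filename, 'wb') as file:
        pickle.dump(obj, file)

def process_chunkfile(chunkfile, process):
    # skip chunk files that already have serialized results from an earlier run
    chunk_basename = os.path.basename(chunkfile)
    if (os.path.isfile(serialized_path + chunk_basename + '_feat.pkl')
            and os.path.isfile(serialized_path + chunk_basename + '_class.pkl')):
        return
    features, classification = process(chunkfile)
    serialize(features, chunk_basename + '_feat.pkl')
    serialize(classification, chunk_basename + '_class.pkl')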

View File

@@ -1,15 +1,25 @@
import pickle
import os.path
def load_whitelist():
def generate_whitelist():
filename = 'res/benign_domains.txt'
whitelist = []
for item in open(filename).read().splitlines():
if item not in whitelist:
whitelist.append(item)
whitelist_pkl = open('whitelist.pkl', 'wb')
pickle.dump(whitelist, whitelist_pkl)
whitelist_pkl.close()
def load_whitelist():
whitelist_pkl = open('whitelist.pkl', 'rb')
whitelist = pickle.load(whitelist_pkl)
return whitelist
def load_blacklist():
def generate_blacklist():
filename = 'res/malicious_domains.txt'
blacklist = []
for item in open(filename).read().splitlines():
@@ -17,6 +27,14 @@ def load_blacklist():
# do not add to blacklist (as EXPOSURE is already handling these)
if item not in blacklist and item not in whitelist:
blacklist.append(item)
blacklist_pkl = open('blacklist.pkl', 'wb')
pickle.dump(blacklist, blacklist_pkl)
blacklist_pkl.close()
def load_blacklist():
blacklist_pkl = open('blacklist.pkl', 'rb')
blacklist = pickle.load(blacklist_pkl)
return blacklist
@@ -25,7 +43,7 @@ def is_malicious(domain):
def test():
print('blacklist length: ' + str(len(blacklist)))
# print('blacklist length: ' + str(len(blacklist)))
# dupes = [x for n, x in enumerate(whitelist) if x in whitelist[:n]]
# print(dupes)
@@ -37,7 +55,14 @@ def test():
pass
if not os.path.isfile('whitelist.pkl'):
generate_whitelist()
whitelist = load_whitelist()
if not os.path.isfile('blacklist.pkl'):
generate_blacklist()
blacklist = load_blacklist()
if __name__ == "__main__":
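
Because of the module-level guards above, the pickles are built on the first import and merely unpickled afterwards. A hypothetical caller (the module name classify is taken from the import in the train.py diff further down; the exact return value of is_malicious is not shown here):

import classify  # first import generates whitelist.pkl / blacklist.pkl, later imports just load them

label = classify.is_malicious('example.com')  # used as the training label in train.py
print(label)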

View File

@@ -7,9 +7,19 @@ train_end = datetime.date(2017, 9, 7)
analysis_start_date = datetime.date(2017, 9, 1)
analysis_days_amount = 7
#pdns_logs_path = '/home/felix/pdns/'
pdns_logs_path = '/mnt/old/2017'
# e.g. analysis_days = ['2017-04-07', '2017-04-08', '2017-04-09']
analysis_days = [(analysis_start_date + datetime.timedelta(days=x)).strftime(format_date) for x in
range(analysis_days_amount)]
serialized_path = 'serialized/'
#pdns_logs_path = '/home/felix/pdns/'
pdns_logs_path = '/mnt/old/2017/'
# 32 cores on janus
num_cores = 32
gz = False
multiprocessed = True
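
With the values above (analysis_start_date = 2017-09-01, analysis_days_amount = 7, and format_date presumably '%Y-%m-%d' as the e.g. comment suggests), the list comprehension expands to:

analysis_days = ['2017-09-01', '2017-09-02', '2017-09-03', '2017-09-04',
                 '2017-09-05', '2017-09-06', '2017-09-07']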

View File

@@ -44,8 +44,6 @@ def serialize_logs_to_db():
# TODO
# for log_file in ['data/pdns_capture.pcap-sgsgpdc0n9x-2017-04-07_00-00-02.csv.gz']:
for day in range(analysis_days_amount):
log_files_hour = get_log_files_for_hours_of_day(analysis_days[day])
# everything[day] = {}
@@ -55,7 +53,7 @@ def serialize_logs_to_db():
for hour in progress_bar(range(24)):
progress_bar.next()
for hour_files in log_files_hour[hour]:
with gzip.open(hour_files, 'rt', newline='') as file:
with open(hour_files, 'rt') as file:
reader = csv.reader(file)
all_rows = list(reader)
@@ -91,7 +89,7 @@ def batch(iterable, n=1):
# raise Exception('Log files inconsistency')
def get_log_files_for_range_of_day(date, minutes_range, gz=True):
def get_log_files_for_range_of_day(date, minutes_range, gz=False):
slot_files = {}
slots_amount = int(1440 / minutes_range)
@@ -103,7 +101,7 @@ def get_log_files_for_range_of_day(date, minutes_range, gz=True):
slot_files[slot] = 'data/*' + date + '_' + time_range + '*.csv' + ('.gz' if gz else '')
def get_log_files_for_hours_of_day(date, gz=True):
def get_log_files_for_hours_of_day(date, gz=False):
slot_files = {}
slots_amount = 24
@@ -113,7 +111,7 @@ def get_log_files_for_hours_of_day(date, gz=True):
return slot_files
def get_log_files_for_day(date, gz=True):
def get_log_files_for_day(date, gz=False):
log_files = 'data/*' + date + '*.csv' + ('.gz' if gz else '')
return glob.glob(log_files)
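
The hunk header at @@ -91,7 +89,7 @@ references a batch(iterable, n=1) helper whose body is not part of this diff; the usual shape of that idiom is a slicing generator along these lines (a sketch, not necessarily the repository's version):

def batch(iterable, n=1):
    # yield consecutive slices of length n; the last slice may be shorter
    length = len(iterable)
    for start in range(0, length, n):
        yield iterable[start:start + n]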

View File

@@ -11,14 +11,34 @@ logger.setLevel(logging.INFO)
logger.debug('connecting redis')
redis_host = 'localhost'
# TODO name ports properly
redis_start_port_first_seen = 2337
redis_start_port_last_seen = 2340
redis_port_reverse = 2343
redis_port_4 = 2344
redis_port_ttl = 2345
bucket_mod = 1048576 # = 2 ** 20
# redis_start_port_first_seen = 2337
# redis_start_port_last_seen = 2340
# redis_port_reverse = 2343
# redis_port_4 = 2344
# redis_port_ttl = 2345
# redis_f = Redis(unix_socket_path=base + 'redis_local_f.sock')
# redis_f1 = Redis(redis_host, port=2338)
# redis_f2 = Redis(redis_host, port=2339)
# redis_l = Redis(redis_host, port=2340)
# redis_l1 = Redis(redis_host, port=2341)
# redis_l2 = Redis(redis_host, port=2342)
# redis_r = Redis(redis_host, port=2343)
# redis_v = Redis(redis_host, port=2344)
# redis_t = Redis(redis_host, port=2345)
base = '/home/tek/felix/redis/redis/'
redis_f = Redis(unix_socket_path=base + 'redis_local_f.sock')
redis_f1 = Redis(unix_socket_path=base + 'redis_local_f2.sock')
redis_f2 = Redis(unix_socket_path=base + 'redis_local_f3.sock')
redis_l = Redis(unix_socket_path=base + 'redis_local_l.sock')
redis_l1 = Redis(unix_socket_path=base + 'redis_local_l2.sock')
redis_l2 = Redis(unix_socket_path=base + 'redis_local_l3.sock')
redis_r = Redis(unix_socket_path=base + 'redis_local_r.sock')
redis_v = Redis(unix_socket_path=base + 'redis_local_v.sock')
redis_t = Redis(unix_socket_path=base + 'redis_local_t.sock')
def _get_redis_shard(rrname):
bucket = crc32(rrname.encode('utf-8')) % bucket_mod # convert string to byte array
@@ -29,11 +49,15 @@ def _get_redis_shard(rrname):
def get_stats_for_domain(rrname, rrtype='A'):
bucket, shard = _get_redis_shard(rrname)
local_redis_f = redis_f
local_redis_l = redis_l
redis_f = Redis(redis_host, port=redis_start_port_first_seen + shard)
redis_l = Redis(redis_host, port=redis_start_port_last_seen + shard)
redis_r = Redis(redis_host, port=redis_port_reverse)
redis_t = Redis(redis_host, port=redis_port_ttl)
if shard == 1:
local_redis_f = redis_f1
local_redis_l = redis_l1
elif shard == 2:
local_redis_f = redis_f2
local_redis_l = redis_l2
ttls_b = redis_t.lrange('t:{}:{}'.format(rrname, rrtype), 0, -1)
@@ -55,8 +79,8 @@ def get_stats_for_domain(rrname, rrtype='A'):
logger.debug('res: ' + str(result))
logger.debug('id: ' + str('f' + str(bucket)))
t_f = float(unpack('<L', redis_f.hget('f' + str(bucket), rrname + ':' + result))[0])
t_l = float(unpack('<L', redis_l.hget('l' + str(bucket), rrname + ':' + result))[0])
t_f = float(unpack('<L', local_redis_f.hget('f' + str(bucket), rrname + ':' + result))[0])
t_l = float(unpack('<L', local_redis_l.hget('l' + str(bucket), rrname + ':' + result))[0])
t_f = datetime.utcfromtimestamp(t_f).strftime('%Y-%m-%dT%H:%M:%SZ')
t_l = datetime.utcfromtimestamp(t_l).strftime('%Y-%m-%dT%H:%M:%SZ')
@@ -65,7 +89,7 @@ def get_stats_for_domain(rrname, rrtype='A'):
'rrname': rrname,
'rrtype': rrtype.replace('rrtype_', ''),
'rdata': result,
'ttls': list(map(int, ttls)), # TODO do we need to convert iterable of type map to list? (e.g. for length)
'ttls': list(map(int, ttls)),
'time_first': t_f,
'time_last': t_l
})
@@ -75,7 +99,6 @@ def get_stats_for_domain(rrname, rrtype='A'):
def get_all_ips_for_domain(rrname):
redis_r = Redis(redis_host, port=redis_port_reverse)
# remove trailing slash
rrname = rrname.rstrip('/')
@@ -94,7 +117,6 @@ def get_all_ips_for_domain(rrname):
def get_stats_for_ip(rdata):
redis_v = Redis(redis_host, port=redis_port_4)
try:
results = []
@@ -102,11 +124,18 @@ def get_stats_for_ip(rdata):
result = result.decode('utf-8') # convert to string (python 3)
bucket, shard = _get_redis_shard(result)
redis_f = Redis(redis_host, port=redis_start_port_first_seen + shard)
redis_l = Redis(redis_host, port=redis_start_port_last_seen + shard)
local_redis_f = redis_f
local_redis_l = redis_l
t_f = float(unpack('<L', redis_f.hget('f' + str(bucket), result + ':' + rdata))[0])
t_l = float(unpack('<L', redis_l.hget('l' + str(bucket), result + ':' + rdata))[0])
if shard == 1:
local_redis_f = redis_f1
local_redis_l = redis_l1
elif shard == 2:
local_redis_f = redis_f2
local_redis_l = redis_l2
t_f = float(unpack('<L', local_redis_f.hget('f' + str(bucket), result + ':' + rdata))[0])
t_l = float(unpack('<L', local_redis_l.hget('l' + str(bucket), result + ':' + rdata))[0])
t_f = datetime.utcfromtimestamp(t_f).strftime('%Y-%m-%dT%H:%M:%SZ')
t_l = datetime.utcfromtimestamp(t_l).strftime('%Y-%m-%dT%H:%M:%SZ')
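
The performance change in this file is that the Redis connections are opened once at module level (over unix sockets) and then selected by shard, instead of constructing Redis(...) clients inside every lookup. The if/elif selection could also be written as a lookup table; a sketch under that assumption, reusing the connection names from the diff:

# hypothetical alternative to the if/elif shard selection above
first_seen_by_shard = {0: redis_f, 1: redis_f1, 2: redis_f2}
last_seen_by_shard = {0: redis_l, 1: redis_l1, 2: redis_l2}

def _connections_for_shard(shard):
    # default to shard 0, matching the fallthrough behaviour in the diff
    return (first_seen_by_shard.get(shard, redis_f),
            last_seen_by_shard.get(shard, redis_l))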

View File

@@ -4,7 +4,7 @@ import json
from geoip2 import database, errors
logger = logging.getLogger('ip')
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.INFO)
top_100_hosters = json.load(open('res/asns.json'))
top_100_hosters_asns = []
@@ -12,27 +12,34 @@ for hoster in top_100_hosters.values():
top_100_hosters_asns.extend(hoster['asns'])
country_reader = database.Reader('res/GeoLite2-Country_20170905/GeoLite2-Country.mmdb')
asn_reader = database.Reader('res/GeoLite2-ASN_20171107/GeoLite2-ASN.mmdb')
def is_hoster_ip(ip):
return str(get_isp_by_ip(ip)) in top_100_hosters_asns
# if the specific country is not available in the database, take the continent instead
def get_country_by_ip(ip):
with database.Reader('res/GeoLite2-Country_20170905/GeoLite2-Country.mmdb') as reader:
result = reader.country(ip)
if not result.country:
return result.continent.geoname_id
else:
return result.country.geoname_id
try:
result = country_reader.country(ip)
except errors.AddressNotFoundError:
logger.debug('address not in location database ' + str(ip))
return 0
if not result.country:
return result.continent.geoname_id
else:
return result.country.geoname_id
def get_isp_by_ip(ip):
with database.Reader('res/GeoLite2-ASN_20171107/GeoLite2-ASN.mmdb') as reader:
try:
result = reader.asn(ip)
return result.autonomous_system_number
except errors.AddressNotFoundError:
logger.debug('address not in isp database')
try:
result = asn_reader.asn(ip)
return result.autonomous_system_number
except errors.AddressNotFoundError:
logger.debug('address not in isp database ' + str(ip))
def ratio_ips_hoster(ips):
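
This file follows the same pattern: the GeoLite2 readers are created once at module level and reused, rather than opening a database.Reader inside every call. The body of ratio_ips_hoster(ips) is cut off by the diff; assuming it reports the share of IPs hosted at a top-100 hoster ASN via is_hoster_ip, it could look like this (an assumption, not the repository's code):

def ratio_ips_hoster(ips):
    # assumed: fraction of the given IPs whose ASN belongs to a top-100 hoster
    if not ips:
        return 0.0
    return sum(1 for ip in ips if is_hoster_ip(ip)) / len(ips)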

View File

@@ -1,7 +1,7 @@
import logging
import datetime
# logfile = 'analysis_' + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') + '.log' # https://stackoverflow.com/questions/1943747/python-logging-before-you-run-logging-basicconfig
# logging.basicConfig(filename=logfile, filemode='w') # important to set basicConfig only once for all modules
logfile = 'analysis_' + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') + '.log' # https://stackoverflow.com/questions/1943747/python-logging-before-you-run-logging-basicconfig
logging.basicConfig(filename=logfile, filemode='w') # important to set basicConfig only once for all modules
logging.basicConfig()
import logging
@@ -21,11 +21,14 @@ import pickle
import classify
import config
import traceback
import multiprocessing
import os
# import db_sql
from sklearn.datasets import load_iris
from sklearn import tree
logger = logging.getLogger('train')
logger.setLevel(logging.INFO)
@@ -35,65 +38,126 @@ train_end = config.train_end
id_upto = 379283817
# record types that should be analysed (e.g. only A)
record_types = ['A']
# id_upto = db.mariadb_get_nearest_id(train_end.strftime(db_format_time))
def generate_features_and_classify():
start = time.time()
logger.info('feature generation start: ' + str(start))
all_features = []
all_classifications = []
for day in range(config.analysis_days_amount):
# TODO dev
# log_files_hour = csv_tools.get_log_files_for_hours_of_day(config.analysis_days[day], gz=False)
log_files_hour = csv_tools.get_log_files_for_hours_of_day(config.analysis_days[day])
log_files_hour = csv_tools.get_log_files_for_hours_of_day(config.analysis_days[day], gz=config.gz)
hour_start = time.time()
progress_bar = progressbar.ProgressBar()
for hour in progress_bar(range(24)):
for hour_files in log_files_hour[hour]:
# TODO dev
# with open(hour_files, 'rt') as file:
with gzip.open(hour_files, 'rt', newline='') as file:
reader = csv.reader(file)
file_start = time.time()
# TODO debug: only one file per hour
log_files_hour[hour] = [log_files_hour[hour][0]]
for hour_chunkfile in log_files_hour[hour]:
chunk_basename = os.path.basename(hour_chunkfile)
if os.path.isfile(config.serialized_path + chunk_basename + '_feat.pkl') and os.path.isfile(config.serialized_path + chunk_basename + '_class.pkl'):
logger.info('chunkfile already processed: ' + str(chunk_basename))
else:
with open_file(hour_chunkfile) as file:
reader = csv.reader(file)
for row in reader:
if row[2] in record_types:
entity = {'timestamp': row[0], 'domain': row[1], 'type': row[2],
'record': row[3], 'ttl': row[4]}
if not config.multiprocessed:
for row in reader:
if row[2] == 'A':
try:
features, classification = process_row(row)
all_features.append(features)
all_classifications.append(classification)
except Exception:
logger.error(traceback.format_exc())
logger.error('Exception occurred processing entity: ' + str(row))
else:
logger.info('start analysing file: ' + str(hour_chunkfile))
reader = csv.reader(file)
try:
all_features.append(prepare_features_redis(entity))
all_classifications.append(classify.is_malicious(entity['domain']))
except Exception:
logger.error(traceback.format_exc())
logger.error('Exception occurred processing entity: ' + str(entity))
# break
# break
# break
# break
# load file into memory
rows = list(reader)
rows = filter(lambda e: e[2] == 'A', rows)
pool = multiprocessing.Pool(processes=config.num_cores)
# https://stackoverflow.com/questions/41273960/python-3-does-pool-keep-the-original-order-of-data-passed-to-map
# TODO check if len(features) == len(classification) before unzip
features, classification = zip(*pool.map(process_row, rows))
# TODO filter None values?
serialize(features, chunk_basename + '_feat.pkl')
serialize(classification, chunk_basename + '_class.pkl')
all_features.extend(features)
all_classifications.extend(classification)
logger.info('file took: ' + str(time.time() - file_start) + 's')
file_start = time.time()
logger.info('hour took approx: ' + str(time.time() - hour_start) + 's')
hour_start = time.time()
# iris = load_iris()
# return iris.data, iris.target
logger.info('feature generation duration: ' + str(time.time() - start) + 's')
return np.array(all_features), np.array(all_classifications)
def open_file(hour_files):
if config.gz:
return gzip.open(hour_files, 'rt', newline='')
else:
return open(hour_files, 'rt')
def process_row(row):
row_start = time.time()
entity = {'timestamp': row[0], 'domain': row[1], 'type': row[2],
'record': row[3], 'ttl': row[4]}
features = None
classification = None
try:
features = prepare_features_redis(entity)
classification = classify.is_malicious(row[1])
except Exception: # log trace to logger (+ show exact exception in multiprocessing)
logger.error(traceback.format_exc())
logger.error('Exception occurred processing entity: ' + str(entity))
logger.info('entity took: ' + str(time.time() - row_start) + 's')
return features, classification
def train():
start = time.time()
checkpoint = start
logger.info('training start: ' + str(start))
features, classification = generate_features_and_classify()
# TODO save serialized features and classification
logger.info('feature preprocessing duration: ' + str(time.time() - checkpoint) + 's')
if len(features) == 0:
logger.info('no features available, check config')
exit()
else:
logger.info('# of features: ' + str(len(features)))
logger.info('# of classifications: ' + str(len(classification)))
checkpoint = time.time()
logger.info('serialization start: ' + str(checkpoint))
# save serialized features, classification
serialize(features, 'complete_feat_' + datetime.datetime.now().strftime(config.format_date) + '.pkl')
serialize(classification, 'complete_class_' + datetime.datetime.now().strftime(config.format_date) + '.pkl')
logger.info('serialization duration: ' + str(time.time() - checkpoint))
checkpoint = time.time()
logger.info('building decision tree start: ' + str(checkpoint))
decision_tree_model = tree.DecisionTreeClassifier()
decision_tree_model = decision_tree_model.fit(features, classification) # training set, manual classification
logger.info('building decision tree duration: ' + str(time.time() - checkpoint))
# predict single or multiple sets with clf.predict([[]])
# visualize decision tree classifier
@@ -102,10 +166,7 @@ def train():
graph.render('plot' + datetime.datetime.now().strftime(config.format_date))
# dump trained decision tree classifier to file
decision_tree_pkl_filename = 'dtc_' + datetime.datetime.now().strftime(config.format_date) + '.pkl'
decision_tree_model_pkl = open(decision_tree_pkl_filename, 'wb')
pickle.dump(decision_tree_model, decision_tree_model_pkl)
decision_tree_model_pkl.close()
serialize(decision_tree_model, 'dtc_' + datetime.datetime.now().strftime(config.format_date) + '.pkl')
def prepare_features_redis(entity):
@@ -117,7 +178,8 @@ def prepare_features_redis(entity):
logger.debug(domain_stats)
if not domain_stats:
logger.debug('no stats in redis for entity: ' + entity)
logger.debug('no stats in redis for entity: ' + str(entity))
return
domain_stats = domain_stats[0]
@@ -125,6 +187,8 @@ def prepare_features_redis(entity):
logger.debug('all ips seen for domain ' + str(ips))
# TODO re-add short-lived and the other features we are able to calculate
# feature 5: Number of distinct IP addresses (0)
distinct_ips = len(ips)
@@ -199,7 +263,7 @@ def prepare_features_redis(entity):
return all_features
# TODO deprecated
# TODO old
def get_logs_from_db():
results = db_sql.mariadb_get_logs(id_upto)
@@ -349,6 +413,12 @@ def prepare_features_mysql(entity):
return all_features
def serialize(obj, filename):
file = open(config.serialized_path + filename, 'wb')
pickle.dump(obj, file)
file.close()
def test():
start = time.time()
logger.info('starting training ' + str(start))
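
For reference, the multiprocessed path introduced in generate_features_and_classify() boils down to: filter the chunk's rows to A records, map process_row over a pool, and unzip the (features, classification) pairs. A condensed sketch (the wrapper name is hypothetical; config.num_cores and process_row are the names from the diff; empty chunks would need an extra guard, since zip(*[]) cannot be unpacked into two values):

import multiprocessing

def process_rows_multiprocessed(rows):
    # keep only A records, as the filter in the diff does
    a_rows = [row for row in rows if row[2] == 'A']
    with multiprocessing.Pool(processes=config.num_cores) as pool:
        # pool.map preserves input order, so features and classifications stay aligned
        features, classification = zip(*pool.map(process_row, a_rows))
    return list(features), list(classification)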