Skip to content

Commit d4f221e

Browse files
authored
Merge pull request #966 from Sekhar-Kumar-Dash/patch-59
Created unittest for profile_handler.py
2 parents 1904e62 + 681554f commit d4f221e

File tree

6 files changed

+2579
-28
lines changed

6 files changed

+2579
-28
lines changed

.github/workflows/unit-tests.yml

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ jobs:
7171
- test_timeline.py
7272
- test_database.py
7373
- test_symbols_handler.py
74+
- test_profile_handler.py
7475

7576
steps:
7677
- uses: actions/checkout@v4

slips_files/core/database/database_manager.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -748,8 +748,8 @@ def mark_profile_as_dhcp(self, *args, **kwargs):
748748
def add_profile(self, *args, **kwargs):
749749
return self.rdb.add_profile(*args, **kwargs)
750750

751-
def set_profile_module_label(self, *args, **kwargs):
752-
return self.rdb.set_profile_module_label(*args, **kwargs)
751+
def set_module_label_for_profile(self, *args, **kwargs):
752+
return self.rdb.set_module_label_for_profile(*args, **kwargs)
753753

754754
def check_tw_to_close(self, *args, **kwargs):
755755
return self.rdb.check_tw_to_close(*args, **kwargs)
@@ -773,7 +773,11 @@ def search_tws_for_flow(self, twid, uid, go_back=False):
773773
"""
774774

775775
# TODO test this
776-
tws_to_search = self.rdb.get_tws_to_search(go_back)
776+
# how many tws so search back in?
777+
tws_to_search = float("inf")
778+
if go_back:
779+
hrs_to_search = float(go_back)
780+
tws_to_search = self.rdb.get_equivalent_tws(hrs_to_search)
777781

778782
twid_number: int = int(twid.split("timewindow")[-1])
779783
while twid_number > -1 and tws_to_search > 0:
@@ -790,8 +794,8 @@ def search_tws_for_flow(self, twid, uid, go_back=False):
790794
# uid isn't in this twid or any of the previous ones
791795
return {uid: None}
792796

793-
def get_profile_modules_labels(self, *args, **kwargs):
794-
return self.rdb.get_profile_modules_labels(*args, **kwargs)
797+
def get_modules_labels_of_a_profile(self, *args, **kwargs):
798+
return self.rdb.get_modules_labels_of_a_profile(*args, **kwargs)
795799

796800
def add_timeline_line(self, *args, **kwargs):
797801
return self.rdb.add_timeline_line(*args, **kwargs)

slips_files/core/database/redis_db/profile_handler.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_timewindow(self, flowtime, profileid):
7676
belong to that tw
7777
if it is == the end of a tw, it will belong to the next one
7878
for example,
79-
a flow with ts = 2 belongs to tw1
79+
a flow with ts = 2 belongs to tw2
8080
a flow with ts = 4 belongs to tw3
8181
8282
tw1 tw2 tw3 tw4
@@ -1178,6 +1178,7 @@ def get_modified_profiles_since(
11781178
time_of_last_modified_tw: float = modified_tws[-1][-1]
11791179

11801180
# this list will store modified profiles without tws
1181+
# this is a list of ips. not profileids
11811182
profiles = []
11821183
profiles.extend(
11831184
modified_tw[0].split("_")[1] for modified_tw in modified_tws
@@ -1472,32 +1473,35 @@ def add_profile(self, profileid, starttime):
14721473
self.print(type(inst), 0, 1)
14731474
self.print(inst, 0, 1)
14741475

1475-
def set_profile_module_label(self, profileid, module, label):
1476+
def set_module_label_for_profile(self, profileid, module, label):
14761477
"""
14771478
Set a module label for a profile.
14781479
A module label is a label set by a module, and not
14791480
a groundtruth label
14801481
"""
1481-
data = self.get_profile_modules_labels(profileid)
1482+
data = self.get_modules_labels_of_a_profile(profileid)
14821483
data[module] = label
14831484
data = json.dumps(data)
14841485
self.r.hset(profileid, "modules_labels", data)
14851486

14861487
def check_tw_to_close(self, close_all=False):
14871488
"""
1488-
Check if we should close some TW
1489-
Search in the modifed tw list and compare when they
1489+
Check if we should close a TW
1490+
Search in the modified tw list and compare when they
14901491
were modified with the slips internal time
14911492
"""
14921493

14931494
sit = self.get_slips_internal_time()
14941495

1495-
# for each modified profile
1496+
# sit is the ts of the last tw modification detected by slips
1497+
# so this line means if 1h(width) passed since the last
1498+
# modification detected, then it's time to close the tw
14961499
modification_time = float(sit) - self.width
14971500
if close_all:
14981501
# close all tws no matter when they were last modified
14991502
modification_time = float("inf")
15001503

1504+
# these are the tws that havent been modified in the last 1h
15011505
profiles_tws_to_close = self.r.zrangebyscore(
15021506
self.constants.MODIFIED_TIMEWINDOWS,
15031507
0,
@@ -1675,15 +1679,7 @@ def add_tuple(
16751679
)
16761680
self.print(traceback.format_exc(), 0, 1)
16771681

1678-
def get_tws_to_search(self, go_back):
1679-
tws_to_search = float("inf")
1680-
1681-
if go_back:
1682-
hrs_to_search = float(go_back)
1683-
tws_to_search = self.get_equivalent_tws(hrs_to_search)
1684-
return tws_to_search
1685-
1686-
def get_profile_modules_labels(self, profileid):
1682+
def get_modules_labels_of_a_profile(self, profileid):
16871683
"""
16881684
Get labels set by modules in the profile.
16891685
"""
@@ -1710,7 +1706,7 @@ def get_timeline_last_lines(
17101706
key = str(
17111707
profileid + self.separator + twid + self.separator + "timeline"
17121708
)
1713-
# The the amount of lines in this list
1709+
# The amount of lines in this list
17141710
last_index = self.r.zcard(key)
17151711
# Get the data in the list from the index asked (first_index) until the last
17161712
data = self.r.zrange(key, first_index, last_index - 1)

tests/module_factory.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from modules.flowalerts.dns import DNS
2323
from modules.flowalerts.downloaded_file import DownloadedFile
2424
from slips_files.core.helpers.symbols_handler import SymbolHandler
25+
from slips_files.core.database.redis_db.profile_handler import ProfileHandler
2526
from modules.flowalerts.notice import Notice
2627
from modules.flowalerts.smtp import SMTP
2728
from modules.flowalerts.software import Software
@@ -619,11 +620,6 @@ def create_riskiq_obj(self, mock_db):
619620
riskiq.db = mock_db
620621
return riskiq
621622

622-
def create_alert_handler_obj(self):
623-
alert_handler = AlertHandler()
624-
alert_handler.constants = Constants()
625-
return alert_handler
626-
627623
@patch(MODULE_DB_MANAGER, name="mock_db")
628624
def create_timeline_object(self, mock_db):
629625
logger = Mock()
@@ -633,3 +629,18 @@ def create_timeline_object(self, mock_db):
633629
tl = Timeline(logger, output_dir, redis_port, termination_event)
634630
tl.db = mock_db
635631
return tl
632+
633+
def create_alert_handler_obj(self):
634+
alert_handler = AlertHandler()
635+
alert_handler.constants = Constants()
636+
return alert_handler
637+
638+
def create_profile_handler_obj(self):
639+
handler = ProfileHandler()
640+
handler.constants = Constants()
641+
handler.r = Mock()
642+
handler.rcache = Mock()
643+
handler.separator = "_"
644+
handler.width = 3600
645+
handler.print = Mock()
646+
return handler

tests/test_database.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def test_profile_moddule_labels():
153153
db = ModuleFactory().create_db_manager_obj(6387, flush_db=True)
154154
module_label = "malicious"
155155
module_name = "test"
156-
db.set_profile_module_label(profileid, module_name, module_label)
157-
labels = db.get_profile_modules_labels(profileid)
156+
db.set_module_label_for_profile(profileid, module_name, module_label)
157+
labels = db.get_modules_labels_of_a_profile(profileid)
158158
assert "test" in labels
159159
assert labels["test"] == "malicious"
160160

0 commit comments

Comments
 (0)