@@ -24,11 +24,16 @@ def abcd(self):
2424 else :
2525 port = 9200
2626 host = "localhost"
27+ security_enabled = os .getenv ("security_enabled" ) == "true"
2728
2829 logging .basicConfig (level = logging .INFO )
2930
3031 url = f"opensearch://admin:admin@{ host } :{ port } "
31- opensearch_abcd = ABCD .from_url (url , index_name = "test_index" , use_ssl = False )
32+ opensearch_abcd = ABCD .from_url (
33+ url ,
34+ index_name = "test_index" ,
35+ use_ssl = security_enabled ,
36+ )
3237 assert isinstance (opensearch_abcd , OpenSearchDatabase )
3338 return opensearch_abcd
3439
@@ -78,10 +83,9 @@ def test_push(self, abcd):
7883 assert isinstance (atoms_2 , Atoms )
7984 atoms_2 .set_cell ([1 , 1 , 1 ])
8085
86+ abcd .refresh ()
8187 result = AtomsModel (
82- None ,
83- None ,
84- abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
88+ dict = abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
8589 ).to_ase ()
8690 assert atoms_1 == result
8791 assert atoms_2 != result
@@ -117,17 +121,14 @@ def test_bulk(self, abcd):
117121 atoms_list .append (atoms_1 )
118122 atoms_list .append (atoms_2 )
119123 abcd .push (atoms_list )
124+ abcd .refresh ()
120125 assert abcd .count () == 2
121126
122127 result_1 = AtomsModel (
123- None ,
124- None ,
125- abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
128+ dict = abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
126129 ).to_ase ()
127130 result_2 = AtomsModel (
128- None ,
129- None ,
130- abcd .client .search (index = "test_index" )["hits" ]["hits" ][1 ]["_source" ],
131+ dict = abcd .client .search (index = "test_index" )["hits" ]["hits" ][1 ]["_source" ],
131132 ).to_ase ()
132133 assert atoms_1 == result_1
133134 assert atoms_2 == result_2
@@ -151,4 +152,5 @@ def test_count(self, abcd):
151152 atoms .set_cell ([1 , 1 , 1 ])
152153 abcd .push (atoms )
153154 abcd .push (atoms )
155+ abcd .refresh ()
154156 assert abcd .count () == 2
0 commit comments