1
1
import pytest
2
+ from unittest .mock import patch , MagicMock
2
3
from app .modules .classyfire import classify , result
3
- import asyncio
4
4
5
5
6
6
@pytest .fixture
@@ -14,33 +14,69 @@ def invalid_smiles():
14
14
15
15
16
16
@pytest .mark .asyncio
17
- async def test_valid_classyfire (valid_smiles ):
17
+ @patch ("app.modules.classyfire.requests.post" )
18
+ @patch ("app.modules.classyfire.requests.get" )
19
+ async def test_valid_classyfire (mock_get , mock_post , valid_smiles ):
20
+ # Mock the initial classification request
21
+ mock_post_response = MagicMock ()
22
+ mock_post_response .json .return_value = {
23
+ "id" : "12345" ,
24
+ "query_type" : "STRUCTURE" ,
25
+ "query_input" : valid_smiles ,
26
+ }
27
+ mock_post_response .raise_for_status .return_value = None
28
+ mock_post .return_value = mock_post_response
29
+
30
+ # Mock the result retrieval request
31
+ mock_get_response = MagicMock ()
32
+ mock_get_response .json .return_value = {
33
+ "id" : "12345" ,
34
+ "classification_status" : "Done" ,
35
+ "entities" : [{"class" : {"name" : "Imidazopyrimidines" }}],
36
+ }
37
+ mock_get_response .raise_for_status .return_value = None
38
+ mock_get .return_value = mock_get_response
39
+
18
40
result_ = await classify (valid_smiles )
19
41
assert result_ ["query_type" ] == "STRUCTURE"
20
42
id_ = result_ ["id" ]
21
43
22
- while True :
23
- classified = await result (id_ )
24
- if classified ["classification_status" ] == "Done" :
25
- break
26
- await asyncio .sleep (2 )
27
-
44
+ classified = await result (id_ )
28
45
assert classified ["classification_status" ] == "Done"
29
46
assert classified ["entities" ][0 ]["class" ]["name" ] == "Imidazopyrimidines"
30
47
31
48
32
49
@pytest .mark .asyncio
33
- async def test_invalid_classyfire (invalid_smiles ):
50
+ @patch ("app.modules.classyfire.requests.post" )
51
+ @patch ("app.modules.classyfire.requests.get" )
52
+ async def test_invalid_classyfire (mock_get , mock_post , invalid_smiles ):
53
+ # Mock the initial classification request
54
+ mock_post_response = MagicMock ()
55
+ mock_post_response .json .return_value = {
56
+ "id" : "12346" ,
57
+ "query_type" : "STRUCTURE" ,
58
+ "query_input" : invalid_smiles ,
59
+ }
60
+ mock_post_response .raise_for_status .return_value = None
61
+ mock_post .return_value = mock_post_response
62
+
63
+ # Mock the result retrieval request
64
+ mock_get_response = MagicMock ()
65
+ mock_get_response .json .return_value = {
66
+ "id" : "12346" ,
67
+ "classification_status" : "Done" ,
68
+ "invalid_entities" : [
69
+ {"report" : ["Cannot process the input SMILES string, please check again" ]}
70
+ ],
71
+ }
72
+ mock_get_response .raise_for_status .return_value = None
73
+ mock_get .return_value = mock_get_response
74
+
34
75
result_ = await classify (invalid_smiles )
35
76
assert result_ ["query_input" ] == "invalid_smiles"
36
77
id_ = result_ ["id" ]
37
78
38
- while True :
39
- classified = await result (id_ )
40
- if classified ["classification_status" ] == "Done" :
41
- break
42
- await asyncio .sleep (2 )
43
-
79
+ classified = await result (id_ )
44
80
assert classified ["classification_status" ] == "Done"
45
81
assert (
46
82
classified ["invalid_entities" ][0 ]["report" ][0 ]
0 commit comments