@@ -935,6 +935,92 @@ async def test_flow():
935
935
"Completed" ,
936
936
]
937
937
938
+ async def test_task_passes_failed_state_to_retry_fn (self ):
939
+ mock = MagicMock ()
940
+ exc = SyntaxError ("oops" )
941
+ handler_mock = MagicMock ()
942
+
943
+ async def handler (task , task_run , state ):
944
+ handler_mock ()
945
+ assert state .is_failed ()
946
+ try :
947
+ await state .result ()
948
+ except SyntaxError :
949
+ return True
950
+ return False
951
+
952
+ @task (retries = 3 , retry_condition_fn = handler )
953
+ async def flaky_function ():
954
+ mock ()
955
+ if mock .call_count == 2 :
956
+ return True
957
+ raise exc
958
+
959
+ @flow
960
+ async def test_flow ():
961
+ return await flaky_function (return_state = True )
962
+
963
+ task_run_state = await test_flow ()
964
+ task_run_id = task_run_state .state_details .task_run_id
965
+
966
+ assert task_run_state .is_completed ()
967
+ assert await task_run_state .result () is True
968
+ assert mock .call_count == 2
969
+ assert handler_mock .call_count == 1
970
+
971
+ states = await get_task_run_states (task_run_id )
972
+
973
+ state_names = [state .name for state in states ]
974
+ assert state_names == [
975
+ "Pending" ,
976
+ "Running" ,
977
+ "Retrying" ,
978
+ "Completed" ,
979
+ ]
980
+
981
+ async def test_task_passes_failed_state_to_retry_fn_sync (self ):
982
+ mock = MagicMock ()
983
+ exc = SyntaxError ("oops" )
984
+ handler_mock = MagicMock ()
985
+
986
+ def handler (task , task_run , state ):
987
+ handler_mock ()
988
+ assert state .is_failed ()
989
+ try :
990
+ state .result ()
991
+ except SyntaxError :
992
+ return True
993
+ return False
994
+
995
+ @task (retries = 3 , retry_condition_fn = handler )
996
+ def flaky_function ():
997
+ mock ()
998
+ if mock .call_count == 2 :
999
+ return True
1000
+ raise exc
1001
+
1002
+ @flow
1003
+ def test_flow ():
1004
+ return flaky_function (return_state = True )
1005
+
1006
+ task_run_state = test_flow ()
1007
+ task_run_id = task_run_state .state_details .task_run_id
1008
+
1009
+ assert task_run_state .is_completed ()
1010
+ assert await task_run_state .result () is True
1011
+ assert mock .call_count == 2
1012
+ assert handler_mock .call_count == 1
1013
+
1014
+ states = await get_task_run_states (task_run_id )
1015
+
1016
+ state_names = [state .name for state in states ]
1017
+ assert state_names == [
1018
+ "Pending" ,
1019
+ "Running" ,
1020
+ "Retrying" ,
1021
+ "Completed" ,
1022
+ ]
1023
+
938
1024
async def test_task_retries_receive_latest_task_run_in_context (self ):
939
1025
state_names : List [str ] = []
940
1026
run_counts = []
0 commit comments