Skip to content

Commit 871de06

Browse files
committed
add taco test
Signed-off-by: SumanthRH <[email protected]>
1 parent 5dd5978 commit 871de06

File tree

3 files changed

+36
-13
lines changed

3 files changed

+36
-13
lines changed

skythought/evals/scoring/taco/taco.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,18 @@ def score(self, row: Dict[str, Any]):
3737
else:
3838
last_code = code_filter_result[-1]
3939
if self.backend == "mp":
40-
curr_res = _taco_run_tests_mp(input_outputs, generation=last_code)
40+
curr_res, results = _taco_run_tests_mp(
41+
input_outputs, generation=last_code
42+
)
4143
else:
42-
curr_res = _taco_run_tests_ray(input_outputs, generation=last_code)
44+
curr_res, results = _taco_run_tests_ray(
45+
input_outputs, generation=last_code
46+
)
4347

4448
if curr_res:
4549
return {self.SCORE_COLUMN: True}
4650
else:
47-
return {self.SCORE_COLUMN: True}
51+
return {self.SCORE_COLUMN: False}
4852

4953

5054
def _taco_run_tests_mp(input_outputs, generation):
@@ -65,7 +69,7 @@ def _temp_run(input_outputs, generation, debug, result):
6569
p.join()
6670
if p.is_alive():
6771
p.kill()
68-
return bool(result and all(result[0]))
72+
return bool(result and all(result[0])), result
6973

7074

7175
@ray.remote
@@ -82,4 +86,4 @@ def _taco_run_tests_ray(input_outputs, generation):
8286
# run the test in a separate process for safety
8387
obj_ref = _temp_run_ray.remote(input_outputs, generation, False)
8488
result = ray.get(obj_ref)
85-
return bool(result and all(result))
89+
return bool(result and all(result)), result

skythought/evals/tasks/taco/taco_handler.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,6 @@ def generate_prompt(self, problem):
5151
def check_correctness(self, problem, generation):
5252
TIME_OUT = 300
5353

54-
def _temp_run(problem, generation, debug, result):
55-
try:
56-
result.append(
57-
taco_run_test(problem["input_output"], test=generation, debug=debug)
58-
)
59-
except Exception as e:
60-
print(f"Error in _temp_run: {e}")
61-
6254
manager = Manager()
6355
result = manager.list()
6456
p = multiprocessing.Process(
@@ -108,3 +100,12 @@ def load_and_filter_dataset(
108100
)
109101

110102
return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
103+
104+
105+
def _temp_run(problem, generation, debug, result):
106+
try:
107+
result.append(
108+
taco_run_test(problem["input_output"], test=generation, debug=debug)
109+
)
110+
except Exception as e:
111+
print(f"Error in _temp_run: {e}")
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from skythought.evals.scoring.taco import TACOScorer
2+
3+
# an example from the TACO dataset with an incorrect assistant response.
4+
# the response is also in a custom "formatted_response" field
5+
TACO_EXAMPLE = {
6+
"formatted_response": "<|begin_of_thought|>\nAlright, so we have a problem related to calculating the maximum possible Vasya distance between a given DNA sequence and its possible variations. The distance is defined based on the similarity between cyclic shifts of the original string. I need to determine how many strings can achieve this maximum distance.\n\nFirst, I note that the given string `s` has length `n`, and it consists of characters 'A', 'C', 'G', and 'T'. The function `h(s, t)` counts the number of positions where the characters in strings `s` and `t` match. The maximum distance, represented by `ρ(s, t)`, is calculated by considering all possible cyclic shifts of `s`.\n\nTo maximize `ρ(s, t)`, I need to analyze the characters in `s`. The approach involves calculating how many characters of `s` can be matched with characters of `t` based on their frequency. If `s` is composed of only one type of character (like 'TTT'), then there is only one string `t` that can match it entirely, leading to a maximum distance of 1.\n\nFor strings with mixed characters, I can use the frequency of each character. The total possible combinations of matching characters can be calculated using the formula involving the powers of the character frequencies. I will iterate through possible shifts and count the combinations that yield the maximum distance.\n\nTo handle large numbers, I will compute the result modulo \\(10^9 + 7\\). 
This ensures that we stay within the bounds of typical integer operations in programming contests.\n\nNow let's implement this logic in code to compute the answer efficiently.\n<|end_of_thought|>\n\n<|begin_of_solution|>\n```python\nimport sys\ninput = sys.stdin.readline\n\ndef main():\n n = int(input())\n s = input().strip()\n \n ans = 0\n for i in range(1, n):\n ans += pow(26, i, 1000000007) * pow(26, n - i, 1000000007)\n ans %= 1000000007\n \n # The answer is the computed value multiplied by the total number of maximum distance strings\n print(ans)\n\nif __name__ == \"__main__\":\n main()\n```\n<|end_of_solution|>", # noqa: E501
7+
"user_input": '\nQUESTION:\nVasya became interested in bioinformatics. He\'s going to write an article about similar cyclic DNA sequences, so he invented a new method for determining the similarity of cyclic sequences.\n\nLet\'s assume that strings s and t have the same length n, then the function h(s, t) is defined as the number of positions in which the respective symbols of s and t are the same. Function h(s, t) can be used to define the function of Vasya distance ρ(s, t): \n\n<image> where <image> is obtained from string s, by applying left circular shift i times. For example, ρ("AGC", "CGT") = h("AGC", "CGT") + h("AGC", "GTC") + h("AGC", "TCG") + h("GCA", "CGT") + h("GCA", "GTC") + h("GCA", "TCG") + h("CAG", "CGT") + h("CAG", "GTC") + h("CAG", "TCG") = 1 + 1 + 0 + 0 + 1 + 1 + 1 + 0 + 1 = 6\n\nVasya found a string s of length n on the Internet. Now he wants to count how many strings t there are such that the Vasya distance from the string s attains maximum possible value. Formally speaking, t must satisfy the equation: <image>.\n\nVasya could not try all possible strings to find an answer, so he needs your help. 
As the answer may be very large, count the number of such strings modulo 109 + 7.\n\nInput\n\nThe first line of the input contains a single integer n (1 ≤ n ≤ 105).\n\nThe second line of the input contains a single string of length n, consisting of characters "ACGT".\n\nOutput\n\nPrint a single number — the answer modulo 109 + 7.\n\nExamples\n\nInput\n\n1\nC\n\n\nOutput\n\n1\n\n\nInput\n\n2\nAG\n\n\nOutput\n\n4\n\n\nInput\n\n3\nTTT\n\n\nOutput\n\n1\n\nNote\n\nPlease note that if for two distinct strings t1 and t2 values ρ(s, t1) и ρ(s, t2) are maximum among all possible t, then both strings must be taken into account in the answer even if one of them can be obtained by a circular shift of another one.\n\nIn the first sample, there is ρ("C", "C") = 1, for the remaining strings t of length 1 the value of ρ(s, t) is 0.\n\nIn the second sample, ρ("AG", "AG") = ρ("AG", "GA") = ρ("AG", "AA") = ρ("AG", "GG") = 4.\n\nIn the third sample, ρ("TTT", "TTT") = 27\nUse Standard Input format\nANSWER:\n', # noqa: E501
8+
"input_output_field": '{"inputs": ["20\\nTAAGCGACCAGGTGCTTTAC\\n", "15\\nAGCGAATCCCATTGT\\n", "4\\nGACT\\n", "1\\nT\\n", "3\\nGCA\\n", "318\\nTATCAATCGGTACGTGCGCATCATTGTCAATCGGGCTTCATGGCTTGCGGGCGCTACCCGAGGGGAAGCTGCGGACAGGTAGGTAAGATACACACGAACCAAACGGAGTTATGTTGGATAAATTGGCTGGAAGGGCGTAGGTATATCGAGTCGCGGACCTGGCATAGACTATCAGGGGCAGCGGTACAAGGCAACCGTGAGCGGGGTCTGCCCACCATTAGACCGATGCGCCGGCTCGTATATGTGATTCTGGTGAAAAGTATCATGCCGGGACGCGTAATGACCCGGCTGGCTAATCCACCGTGGCAGCAAAATAAC\\n", "5\\nACGTC\\n", "30\\nCCTTTCGGGGCGCGTTGGCCTTTGTCCTGC\\n", "1\\nC\\n", "2\\nAG\\n", "3\\nTTT\\n", "1\\nA\\n", "2\\nCC\\n", "3\\nGGG\\n", "4\\nTTTT\\n", "5\\nAAAAA\\n", "2\\nAC\\n", "3\\nGTC\\n", "4\\nTTAC\\n", "5\\nACGTA\\n", "1\\nG\\n", "2\\nTG\\n", "3\\nATG\\n", "4\\nCATG\\n", "5\\nACATG\\n", "2\\nCG\\n", "3\\nTCG\\n", "4\\nATCG\\n", "5\\nGATCG\\n", "2\\nAA\\n", "4\\nACGT\\n", "5\\nCGCGC\\n", "6\\nTGTGTG\\n", "7\\nAAAAAAA\\n", "1\\nT\\n", "10\\nACGTACGTAC\\n", "3\\nACG\\n", "6\\nAACGTT\\n", "7\\nTGTGTGT\\n", "10\\nGGGGGGGGGG\\n", "8\\nCTGACTGA\\n", "9\\nCCCCCCCCC\\n", "4\\nGATA\\n", "3\\nAGA\\n", "5\\nGCAGC\\n", "4\\nTGTT\\n", "5\\nACTGC\\n", "2\\nGT\\n", "3\\nCGA\\n", "4\\nTGAC\\n", "6\\nAACCGG\\n", "10\\nTTAGCCTTAA\\n", "5\\nTTTGT\\n", "7\\nAGCTTAG\\n", "9\\nACGTTTGGT\\n", "15\\nAGCGTATTATCTTTA\\n", "13\\nACGTGCTGCCGCC\\n", "6\\nTTAGGT\\n", "7\\nGGGTGGA\\n", "8\\nTGCAATGC\\n", "5\\nTTTCT\\n", "4\\nCTAG\\n", "10\\nGCGCGCTCAC\\n", "20\\nACGTACGTACGTACGTACGT\\n", "5\\nTTTTT\\n", "6\\nCCCCGG\\n", "7\\nATCGTCG\\n", "8\\nGGGGATAA\\n", "9\\nCCGTTAACT\\n", "10\\nCCCCGGGGAT\\n", "20\\nACGTTTACGGGCCCCAAGTC\\n", "30\\nCCACGTAAGTACGTTTACGGGCCCCAAGTC\\n", "4\\nTGCA\\n", "5\\nACCGT\\n", "6\\nACGCAG\\n", "7\\nGGGCCCC\\n", "8\\nCGCCGCCC\\n", "9\\nAAGGCCCTT\\n", "10\\nAGATCTGGAA\\n", "10\\nAGAGTTCGAG\\n", "10\\nTTCTGACTGA\\n", "10\\nGTTACTCGTC\\n", "10\\nCCGGAACTGA\\n", "10\\nTGGCGAACCT\\n", "10\\nAACCGGTGAC\\n", "10\\nTGACTGACTT\\n", "2\\nTT\\n", "3\\nAAA\\n", "4\\nTGCT\\n", "5\\nGTAGC\\n", "10\\nTTTTTTTTTT\\n", 
"5\\nCCCGT\\n", "2\\nCT\\n", "2\\nCA\\n", "6\\nATGCGA\\n", "4\\nTACA\\n", "4\\nAAAC\\n", "100\\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\\n", "15\\nAGCTAGCTAGCTAGC\\n", "5\\nTGCAT\\n", "5\\nTTTAA\\n", "6\\nTACTAC\\n", "10\\nAAAAAAAAAA\\n", "11\\nAAAAATTTTTT\\n", "3\\nCCC\\n", "4\\nATGC\\n", "3\\nCGC\\n", "5\\nAGTAG\\n", "6\\nATCGTA\\n", "7\\nATCGTAC\\n", "3\\nCCG\\n", "5\\nCCGCC\\n", "4\\nTTTA\\n", "6\\nATCGTC\\n", "7\\nCCCCCCC\\n", "3\\nCGT\\n", "6\\nATGCAG\\n", "8\\nATGCGTAC\\n", "9\\nTGCGTATCG\\n", "3\\nGCG\\n", "4\\nCGCG\\n", "5\\nCAGCT\\n", "5\\nAACGC\\n", "6\\nTGAACG\\n", "5\\nCGTGC\\n", "4\\nTCTC\\n", "5\\nAGCTG\\n", "6\\nTCAGGC\\n", "3\\nTGG\\n", "6\\nAGCTCG\\n", "8\\nATGCTGCC\\n", "6\\nTACGAC\\n", "7\\nGCTGCTA\\n", "5\\nTGAGT\\n", "4\\nAAAA\\n", "6\\nAGCGTA\\n", "3\\nACC\\n", "4\\nGGGA\\n", "5\\nCTGAC\\n", "6\\nGTACCT\\n", "4\\nCCCC\\n", "5\\nGGGGG\\n", "6\\nTTTTTT\\n", "8\\nCGTACGTA\\n", "9\\nTAACCGGTT\\n", "10\\nACTGACTGAC\\n", "5\\nCGTAC\\n", "6\\nTAGCTA\\n", "4\\nGTCA\\n", "2\\nTA\\n", "4\\nAGCT\\n", "6\\nCCCCCC\\n", "7\\nGGGGGGG\\n", "8\\nTTTTTTTT\\n", "10\\nGGCTAATTCC\\n", "15\\nTCATCGTACTAGCAT\\n", "25\\nAAGCTTCGATCGTACATCGTAGCATCG\\n", "50\\nACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTAC\\n", "5\\nTTTTA\\n", "6\\nAAAACC\\n", "7\\nAAAAAAG\\n", "8\\nACGCACGC\\n", "9\\nTTTTTTTTT\\n", "11\\nAAAAAAAAAAA\\n", "15\\nTGTTACCCTAAGCGA\\n", "4\\nAGCT\\n", "5\\nATGCC\\n", "30\\nCCTTTCGGGGCGCGTTGGCCTTTGCCTTGC\\n", "20\\nTAACCGACCAGGTGCTTTAG\\n", "3\\nACG\\n", "2\\nGA\\n", "4\\nAGTC\\n", "5\\nCTGCA\\n", "30\\nCGTCCTGTTTCCGGTTGCGCGGGGCTTTCC\\n", "4\\nTCGA\\n", "30\\nCCTTTCGGGGCGCGTTTGCCTTGGCCTTGC\\n", "20\\nTAAGCGACCAGGAGCTTTTC\\n", "4\\nGATC\\n", "5\\nCCGTA\\n", "4\\nCTGA\\n", "3\\nGAC\\n", "5\\nGCATC\\n", "30\\nCGTTCCGGTTCCGTTTGCGCGGGGCTTTCC\\n", "20\\nCTTTTCGAGGACCAGCGAAT\\n", "4\\nCTAG\\n", "5\\nCCATG\\n", "4\\nTAGC\\n", "20\\nCTTTTCAAGGGCCAGCGAAT\\n", "4\\nCGAT\\n", "4\\nCAGT\\n", "4\\nTGAC\\n", 
"4\\nTCAG\\n", "5\\nACCTG\\n", "30\\nCGTCCTGTTTCCGGTTGCGCGGGGCTTCTC\\n", "4\\nACTG\\n", "20\\nTAACGGACCAGGTCCTTTAG\\n", "30\\nCTTCCTGTGTCCGGTTGCGCGGGGCTTTCC\\n", "30\\nCCTTTCGGGGCGCGTTTGTCTTGGCCCTGC\\n", "3\\nCAG\\n", "5\\nCTACG\\n", "20\\nCTTTTCGAGGACCAGGCAAT\\n", "20\\nTAAGCGACCGGGAACTTTTC\\n", "5\\nGTCCA\\n", "1\\nA\\n", "15\\nTGTTACCGTAACCGA\\n", "30\\nCGTTCCGTTTCCGGTTGCGCGGGGCTTTCC\\n", "20\\nGATTTCGTGGACCAGCCAAT\\n", "5\\nACTGC\\n", "3\\nAGC\\n", "5\\nGCACT\\n", "30\\nCGTTCCGGTTTCGTTTGCGCGGGGCTCTCC\\n", "4\\nGTAC\\n", "5\\nGTACC\\n", "4\\nTCGC\\n", "4\\nCGTA\\n", "4\\nGTCA\\n", "30\\nCCTCTCGGGGCGCGTTTGCTTTGGCCTTGC\\n", "20\\nTAATCGACCAGGTGCGTTAC\\n", "318\\nCAATAAAACGACGGTGCCACCTAATCGGTCGGCCCAGTAATGCGCAGGGCCGTACTATGAAAAGTGGTCTTAGTGTATATGCTCGGCCGCGTAGCCAGATTACCACCCGTCTGGGGCGAGTGCCAACGGAACATGGCGACGGGGACTATCAGATACGGTCCAGGCGCTGAGCTATATGGATGCGGGAAGGTCGGTTAAATAGGTTGTATTGAGGCAAACCAAGCACACATAGAATGGATGGACAGGCGTCGAAGGGGAGCCCATCGCGGGCGTTCGGTACTTCGGGCTAACTGTTACTACGCGTGCATGGCTAACTAT\\n", "5\\nACGCT\\n", "30\\nCGTTCCGTTTCCGGTTGCTCGGGGCTTGCC\\n", "30\\nCCTTTCGGTGCGCGTTTGCCTTGGCCTGGC\\n", "4\\nGCTC\\n", "4\\nTACG\\n", "20\\nCTTTACAAGGGCCAGCGATT\\n", "4\\nTGCA\\n", "30\\nCGGCCTGTTTCCGGTTGCGCGTGGCTTCTC\\n", "30\\nCTGCCTTTGTCCGGTTGCGCGGGGCTTTCC\\n", "30\\nCTTTTCGGGGCGCGTTTGTCTCGGCCCTGC\\n", "20\\nCTTGTCGAGGACCATGCAAT\\n", "30\\nCCTTTCGGGGCGTGTTGGCCTTCGCCTTGC\\n", "5\\nTCACG\\n", "30\\nTGTTCCGGCTTCGTTTGCGCGGGGCTCTCC\\n", "4\\nCATG\\n", "20\\nCATTGCGTGGACCAGCTAAT\\n", "30\\nCCTTTCGGTGTGCGTCTGCCTTGGCCTGGC\\n", "4\\nATCG\\n", "30\\nCGGCCTGTCTCCGGTTGCGCGTGGCTTTTC\\n", "30\\nCCTCTCGGGGCGCGTTTGCTTCGGCCTTGT\\n", "30\\nCTTTTCGGTGCGCGTTGGCCTCTGTCCGGC\\n", "20\\nTAAGCGACCTGGTGCTATAC\\n", "15\\nAGCGACTACCATTGT\\n", "30\\nGCTTTCGGCGCGCGTTGGCCTTTGTCCTGC\\n", "1\\nG\\n", "15\\nGTTTACCCTAAGCGA\\n", "20\\nCTTTTCGAGGATCAGCGAAC\\n", "5\\nCGATC\\n", "20\\nTAACGGACCAGCTCGTTTAG\\n", "5\\nAGCTC\\n", "20\\nTAACCGACCGGGTGCTTTAA\\n", "4\\nATGC\\n", "5\\nTGACC\\n", "4\\nTGCC\\n", 
"20\\nTAATCGCCCAGGTGCGTTAA\\n", "30\\nCGGTCCGGTTCCGTTTGCGCGTGGCTTTCC\\n", "4\\nCGTC\\n", "20\\nTTAGCGACCGGGAACATTTC\\n", "4\\nACGT\\n", "2\\nAG\\n", "3\\nTTT\\n", "1\\nC\\n"], "outputs": ["511620083\\n", "14348907\\n", "256\\n", "1\\n", "27\\n", "1\\n", "1\\n", "130653412\\n", "1\\n", "4\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "4\\n", "27\\n", "1\\n", "1\\n", "1\\n", "4\\n", "27\\n", "256\\n", "1\\n", "4\\n", "27\\n", "256\\n", "1\\n", "1\\n", "256\\n", "1\\n", "64\\n", "1\\n", "1\\n", "1024\\n", "27\\n", "64\\n", "1\\n", "1\\n", "65536\\n", "1\\n", "1\\n", "1\\n", "32\\n", "1\\n", "1\\n", "4\\n", "27\\n", "256\\n", "729\\n", "1\\n", "1\\n", "2187\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "65536\\n", "1\\n", "256\\n", "1\\n", "511620083\\n", "1\\n", "1\\n", "2187\\n", "1\\n", "512\\n", "1024\\n", "1\\n", "1\\n", "256\\n", "1\\n", "729\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "59049\\n", "1024\\n", "59049\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "4\\n", "4\\n", "64\\n", "1\\n", "1\\n", "1\\n", "14348907\\n", "1\\n", "1\\n", "729\\n", "1\\n", "1\\n", "1\\n", "256\\n", "1\\n", "32\\n", "64\\n", "2187\\n", "1\\n", "1\\n", "1\\n", "64\\n", "1\\n", "27\\n", "64\\n", "65536\\n", "512\\n", "1\\n", "16\\n", "1\\n", "32\\n", "64\\n", "32\\n", "16\\n", "1\\n", "64\\n", "1\\n", "64\\n", "1\\n", "64\\n", "2187\\n", "32\\n", "1\\n", "64\\n", "1\\n", "1\\n", "1\\n", "64\\n", "1\\n", "1\\n", "1\\n", "65536\\n", "1\\n", "1024\\n", "1\\n", "64\\n", "256\\n", "4\\n", "256\\n", "1\\n", "1\\n", "1\\n", "1024\\n", "1\\n", "33554432\\n", "898961331\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "1\\n", "14348907\\n", "256\\n", "1\\n", "130653412\\n", "511620083\\n", "27\\n", "4\\n", "256\\n", "1\\n", "130653412\\n", "256\\n", "130653412\\n", "511620083\\n", "256\\n", "1\\n", "256\\n", "27\\n", "1\\n", "130653412\\n", "511620083\\n", "256\\n", "1\\n", "256\\n", "511620083\\n", "256\\n", "256\\n", "256\\n", "256\\n", "1\\n", 
"130653412\\n", "256\\n", "511620083\\n", "130653412\\n", "130653412\\n", "27\\n", "1\\n", "511620083\\n", "511620083\\n", "1\\n", "1\\n", "14348907\\n", "130653412\\n", "511620083\\n", "1\\n", "27\\n", "1\\n", "130653412\\n", "256\\n", "1\\n", "1\\n", "256\\n", "256\\n", "130653412\\n", "511620083\\n", "1\\n", "1\\n", "130653412\\n", "130653412\\n", "1\\n", "256\\n", "511620083\\n", "256\\n", "130653412\\n", "130653412\\n", "130653412\\n", "511620083\\n", "130653412\\n", "1\\n", "130653412\\n", "256\\n", "511620083\\n", "130653412\\n", "256\\n", "130653412\\n", "130653412\\n", "130653412\\n", "511620083\\n", "14348907\\n", "130653412\\n", "1\\n", "14348907\\n", "511620083\\n", "1\\n", "511620083\\n", "1\\n", "511620083\\n", "256\\n", "1\\n", "1\\n", "511620083\\n", "130653412\\n", "1\\n", "511620083\\n", "256\\n", "4\\n", "1\\n", "1\\n"]}', # noqa: E501
9+
}
10+
11+
12+
def test_taco_scorer():
13+
scorer = TACOScorer(
14+
response_column="formatted_response", input_output_column="input_output_field"
15+
)
16+
out = scorer(TACO_EXAMPLE)
17+
assert set(TACO_EXAMPLE.keys()).issubset(set(out.keys()))
18+
assert not out[scorer.SCORE_COLUMN]

0 commit comments

Comments
 (0)