4
4
import sys
5
5
from datetime import datetime
6
6
from pathlib import Path
7
+ import numpy as np
7
8
8
9
from datasets .utils .filelock import FileLock
9
10
10
11
from . import __version__
11
12
12
13
14
+ class NpEncoder (json .JSONEncoder ):
15
+ """Numpy aware JSON encoder."""
16
+
17
+ def default (self , o ):
18
+ if isinstance (o , np .floating ):
19
+ return float (o )
20
+ if isinstance (o , np .integer ):
21
+ return int (o )
22
+ if isinstance (o , np .ndarray ):
23
+ return o .tolist ()
24
+ return super ().default (o )
25
+
26
+
13
27
def save (path_or_file , ** data ):
14
28
"""
15
29
Saves results to a JSON file. Also saves system information such as current time, current commit
@@ -40,7 +54,7 @@ def save(path_or_file, **data):
40
54
41
55
with FileLock (str (file_path ) + ".lock" ):
42
56
with open (file_path , "w" ) as f :
43
- json .dump (data , f )
57
+ json .dump (data , f , cls = NpEncoder )
44
58
45
59
# cleanup lock file
46
60
try :
@@ -65,9 +79,13 @@ def _setup_path(path_or_file, current_time):
65
79
66
80
67
81
def _git_commit_hash ():
68
- res = subprocess .run ("git rev-parse --is-inside-work-tree" .split (), cwd = "./" , stdout = subprocess .PIPE )
82
+ res = subprocess .run (
83
+ "git rev-parse --is-inside-work-tree" .split (), cwd = "./" , stdout = subprocess .PIPE
84
+ )
69
85
if res .stdout .decode ().strip () == "true" :
70
- res = subprocess .run ("git rev-parse HEAD" .split (), cwd = os .getcwd (), stdout = subprocess .PIPE )
86
+ res = subprocess .run (
87
+ "git rev-parse HEAD" .split (), cwd = os .getcwd (), stdout = subprocess .PIPE
88
+ )
71
89
return res .stdout .decode ().strip ()
72
90
else :
73
91
return None
0 commit comments