Skip to content

Commit 403919b

Browse files
committed
minor changes
1 parent 1ae27f1 commit 403919b

File tree

4 files changed

+3
-4
lines changed

4 files changed

+3
-4
lines changed

evaluate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def evaluate_alignment(vectors, args):
177177

178178
def load_perplexity(args):
179179
if not os.path.exists(os.path.join(args.model_name_or_path, "eval_results.txt")):
180-
print("Warning: Perplexity not found.")
180+
logger.warning("Perplexity not found.")
181181
return -1
182182
with open(os.path.join(args.model_name_or_path, "eval_results.txt"), "r") as fp:
183183
text = fp.read().strip()

modifyinput.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from sklearn.metrics.pairwise import cosine_distances
44
import numpy as np
55
VECMAP = None
6-
# TMPCOUNT = -1
76

87

98
def get_is_shifted(inputids: Union[torch.Tensor, List], shift: int) -> Union[torch.Tensor, bool]:
@@ -85,4 +84,3 @@ def replace_with_nn(inputids: torch.Tensor, model: Any, indices_random: torch.Te
8584
nns = torch.LongTensor(np.argsort(dist, axis=1)[:, :replace_with_nn])
8685
choice = torch.randint(low=0, high=nns.shape[1], size=(nns.shape[0], 1))
8786
inputids[indices_random] = torch.gather(nns, 1, choice).squeeze()
88-

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ sentencepiece==0.1.85
1010
six==1.14.0
1111
sklearn==0.0
1212
tensorboard==2.1.1
13-
tensorflow==2.1.2
13+
tensorflow==2.1.0
1414
tensorflow-estimator==2.1.0
1515
tokenizers==0.5.2
1616
torch==1.4.0

shift.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def add_shifted_input(original: List[List[int]], do_not_shift: Set[int], shift:
1919
to_add.append(shift_example(example, do_not_shift, shift))
2020
original.extend(to_add)
2121

22+
2223
def remove_parallel_data(original: List[List[int]]) -> None:
2324
if len(original) % 4 != 0:
2425
raise ValueError("Data not parallel at all?")

0 commit comments

Comments
 (0)