Skip to content

Commit

Permalink
reduce time on textclass test
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Jun 19, 2024
1 parent 8477b04 commit dd518a9
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/test_textclassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def setUp(self):

# Wrangling data into a dataframe and selecting training examples
data = pd.DataFrame({"text": corpus, "label": group_labels})
train_df = data.groupby("label").sample(500)
test_df = data.drop(index=train_df.index)
train_df = data.groupby("label").sample(50)
test_df = data.drop(index=train_df.index).groupby("label").sample(100)

x_train = train_df["text"].values
y_train = train_df["label"].values
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_textregression(self):

# test training results
self.assertAlmostEqual(max(hist.history["lr"]), lr)
self.assertLess(min(hist.history["val_mae"]), 0.1)
self.assertLess(min(hist.history["val_mae"]), 0.5)

# test top losses
obs = learner.top_losses(n=1, val_data=None)
Expand All @@ -150,10 +150,10 @@ def test_textregression(self):

# test predictor
p = ktrain.get_predictor(learner.model, preproc)
self.assertGreater(p.predict([TEST_DOC])[0], 0.9)
self.assertGreater(p.predict([TEST_DOC])[0], 0.5)
p.save("/tmp/test_predictor")
p = ktrain.load_predictor("/tmp/test_predictor")
self.assertGreater(p.predict([TEST_DOC])[0], 0.9)
self.assertGreater(p.predict([TEST_DOC])[0], 0.5)
self.assertIsNone(p.explain(TEST_DOC))


Expand Down

0 comments on commit dd518a9

Please sign in to comment.