Commit 17fd1aa

Optimization of regroup_data_series (#21)
* regroup_data_series: avoid double unique, add percentage
* add test with not enough duplicates
* remove wrappers when calling a wrapped function
* clean errors in merge
1 parent 2443c79 commit 17fd1aa
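
In practice, the new max_percent_unique keyword controls when the grouping optimization kicks in: duplicates are only regrouped if the share of unique values in the Series does not exceed that threshold. Below is a minimal usage sketch based on the new test further down; keep_test_word is an illustrative stand-in for the test helper and is not part of the commit:

import pandas as pd
from words_n_fun import utils

# Illustrative processing function: collapse any document containing "test" to the word "test"
def keep_test_word(docs: pd.Series) -> pd.Series:
    return docs.apply(lambda doc: "test" if "test" in str(doc) else doc)

docs = pd.Series(['avant'] + [f"ceci est un test {i}" for i in range(10)] + ['milieu'] * 2 + ['après'], name='test')

# 13 of the 14 rows are unique (~93% > 50%), so the wrapper skips the regrouping
# and simply calls keep_test_word on the raw Series
processed = utils.regroup_data_series(keep_test_word, min_nb_data=1, max_percent_unique=0.5)(docs)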

3 files changed: +34 -22 lines

tests/test_1_utils.py

+7
@@ -511,6 +511,8 @@ def test_function(docs):
 docs_results = pd.Series(['avant'] + ["test"] * 5000 + ['milieu'] + ["test"] * 5000 + ['après'], name='test')
 data_no_duplicates = pd.Series(['avant'] + ["ceci est un test"] + ['milieu'] + ['après'], name='test')
 data_no_duplicates_results = pd.Series(['avant'] + ["test"] + ['milieu'] + ['après'], name='test')
+data_not_enough_duplicates = pd.Series(['avant'] + [f"ceci est un test {i}" for i in range(10)] + ['milieu']*2 + ['après'], name='test')
+data_not_enough_duplicates_results = pd.Series(['avant'] + ["test" for i in range(10)] + ['milieu'] * 2 + ['après'], name='test')
 
 
 # Check the standard behaviour
@@ -521,6 +523,11 @@ def test_function(docs):
 pd.testing.assert_series_equal(docs_test, docs_test_copy)
 # Check the behaviour when there are no duplicates
 pd.testing.assert_series_equal(utils.regroup_data_series(test_function, min_nb_data=1)(data_no_duplicates), data_no_duplicates_results)
+# Check the behaviour when there are not enough duplicates
+pd.testing.assert_series_equal(utils.regroup_data_series(test_function, min_nb_data=1, max_percent_unique=0.5)(data_not_enough_duplicates), data_not_enough_duplicates_results)
+
+
+
 
 
 def test_regroup_data_df(self):

words_n_fun/preprocessing/basic.py

+17 -17
@@ -38,18 +38,19 @@
 # - fix_text -> Fixes numerous inconsistencies within a text (via ftfy)
 
 
-import ftfy
 import logging
 import unicodedata
+from typing import List, Union
+
+import ftfy
+import numpy as np
 import pandas as pd
-from typing import Union, List
 from nltk.stem.snowball import FrenchStemmer
 
-from words_n_fun import utils
 from words_n_fun import CustomTqdm as tqdm
-from words_n_fun.preprocessing import stopwords
-from words_n_fun.preprocessing import lemmatizer
-from words_n_fun.preprocessing import synonym_malefemale_replacement
+from words_n_fun import utils
+from words_n_fun.preprocessing import (lemmatizer, stopwords,
+synonym_malefemale_replacement)
 
 tqdm.pandas()
 
@@ -313,7 +314,6 @@ def remove_numeric(docs: pd.Series, replacement_char: str = ' ') -> pd.Series:
 logger.debug('Calling basic.remove_numeric')
 return impl_remove_numeric(docs, replacement_char)
 
-@utils.regroup_data_series
 def impl_remove_stopwords(docs: pd.Series, opt: str = 'all', set_to_add: Union[list, None] = None,
 set_to_remove: Union[list, None] = None) -> pd.Series:
 '''Removes stopwords
@@ -327,12 +327,13 @@ def impl_remove_stopwords(docs: pd.Series, opt: str = 'all', set_to_add: Union[l
 Returns:
 pd.Series: Modified documents
 '''
+# stopwords.remove_stopwords already uses the data_agnostic and regroup_data_series wrappers
 return stopwords.remove_stopwords(docs, opt=opt, set_to_add=set_to_add, set_to_remove=set_to_remove)
 
 
-@utils.data_agnostic
-def remove_stopwords(docs: pd.Series, opt: str = 'all', set_to_add: Union[list, None] = None,
-set_to_remove: Union[list, None] = None) -> pd.Series:
+# the called function already applies the wrappers
+def remove_stopwords(docs: Union[str, list, np.ndarray, pd.Series, pd.DataFrame], opt: str = 'all', set_to_add: Union[list, None] = None,
+set_to_remove: Union[list, None] = None) -> Union[str, list, np.ndarray, pd.Series, pd.DataFrame]:
 '''Removes stopwords
 
 Args:
@@ -379,7 +380,6 @@ def remove_accents(docs: pd.Series, use_tqdm: bool = False) -> pd.Series:
 '''
 return impl_remove_accents(docs, use_tqdm)
 
-@utils.regroup_data_series
 def impl_remove_gender_synonyms(docs: pd.Series) -> pd.Series:
 '''[French] Removes gendered synonyms
 # Find occurrences such as "male version / female version" (eg: Coiffeur / Coiffeuse)
@@ -391,11 +391,12 @@ def impl_remove_gender_synonyms(docs: pd.Series) -> pd.Series:
 Returns:
 pd.Series: Modified documents
 '''
+# synonym_malefemale_replacement.remove_gender_synonyms already uses the data_agnostic and regroup_data_series wrappers
 return synonym_malefemale_replacement.remove_gender_synonyms(docs)
 
 
-@utils.data_agnostic
-def remove_gender_synonyms(docs: pd.Series) -> pd.Series:
+# the wrappers are applied by the called function
+def remove_gender_synonyms(docs: Union[str, list, np.ndarray, pd.Series, pd.DataFrame]) -> Union[str, list, np.ndarray, pd.Series, pd.DataFrame]:
 '''[French] Removes gendered synonyms
 # Find occurrences such as "male version / female version" (eg: Coiffeur / Coiffeuse)
 # By convention, the male version is kept (in accordance with the lemmatizer)
@@ -409,7 +410,7 @@ def remove_gender_synonyms(docs: pd.Series) -> pd.Series:
 logger.debug('Calling basic.remove_gender_synonyms')
 return impl_remove_gender_synonyms(docs)
 
-@utils.regroup_data_series
+# lemmatizer.lemmatize already applies the wrappers
 def impl_lemmatize(docs: pd.Series) -> pd.Series:
 '''Lemmatizes the documents
 Calls an external API
@@ -424,9 +425,8 @@ def impl_lemmatize(docs: pd.Series) -> pd.Series:
 # Process
 return lemmatizer.lemmatize(docs)
 
-
-@utils.data_agnostic
-def lemmatize(docs: pd.Series) -> pd.Series:
+# lemmatizer.lemmatize already applies the wrappers
+def lemmatize(docs: Union[str, list, np.ndarray, pd.Series, pd.DataFrame]) -> Union[str, list, np.ndarray, pd.Series, pd.DataFrame]:
 '''Lemmatizes the documents
 Calls an external API
 
words_n_fun/utils.py

+10 -5
@@ -656,7 +656,7 @@ def get_column_to_be_processed(docs: Union[str, list, np.ndarray, pd.Series, pd.
 return prefered_column
 
 
-def regroup_data_series(function: Callable, min_nb_data: int = 1000, prefix_text: Union[str, None] = None) -> Callable:
+def regroup_data_series(function: Callable, min_nb_data: int = 1000, prefix_text: Union[str, None] = None, max_percent_unique: float = 0.9) -> Callable:
 '''Wrapper to regroup identical data of a pd.Series before being processed
 Can be used as a decorator
 
@@ -665,6 +665,8 @@ def regroup_data_series(function: Callable, min_nb_data: int = 1000, prefix_text
 Kwargs:
 min_nb_data (int): Minimum number of rows within the document required to apply this wrapper (default : 1000)
 prefix_text (str): Prefix to add
+max_percent_unique (float): Maximum fraction [0-1] of unique values for the reduction to be applied (default: 0.9)
+For very quick functions, max_percent_unique should be low to get a real speed-up
 Returns:
 function: Decorated function
 '''
@@ -678,7 +680,7 @@ def regroup_data_series(function: Callable, min_nb_data: int = 1000, prefix_text
 
 # Set wrapper
 @wraps(function)
-def wrapper(docs: Union[str, list, np.ndarray, pd.Series, pd.DataFrame], *args, **kwargs) -> pd.Series:
+def wrapper(docs: pd.Series, *args, **kwargs) -> pd.Series:
 '''Wrapper
 
 Args:
@@ -692,16 +694,19 @@ def wrapper(docs: Union[str, list, np.ndarray, pd.Series, pd.DataFrame], *args,
 # If there is not enough data, the wrapper is discarded and the function returned as is
 if init_len < min_nb_data:
 return function(docs, *args, **kwargs)
-# If there is no duplicates in the data, the wrapper is discarded as well
-elif len(docs.unique()) == init_len:
+
+# If there are not enough duplicates in the data, the wrapper is discarded as well
+unique_docs = docs.unique()
+if (len(unique_docs) / init_len) > max_percent_unique:
 return function(docs, *args, **kwargs)
+
 init_name = docs.name
 init_index = docs.index
 # Put docs into a dataframe
 df = pd.DataFrame(docs)
 df.columns = ["input_data"]
 # Regroup same values together
-input_data = df["input_data"].dropna().drop_duplicates()
+input_data = pd.Series(unique_docs).dropna()
 logger.debug(f"{prefix_text} Reduced data to be processed by {100 * (df.shape[0] - len(input_data)) / df.shape[0]} % (grouped duplicated rows)")
 # Get output
 output_data = function(input_data, *args, **kwargs)

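For reference, the wrapper's new control flow boils down to the standalone sketch below. It is a simplification of the hunks above: the helper name apply_with_grouping is illustrative, and the results are mapped back by value here, whereas the actual wrapper restores the original name and index and merges the results back through the intermediate DataFrame.

import pandas as pd

def apply_with_grouping(function, docs: pd.Series, min_nb_data: int = 1000, max_percent_unique: float = 0.9) -> pd.Series:
    # Too few rows: grouping is not worth it, call the function directly
    if len(docs) < min_nb_data:
        return function(docs)
    # unique() is now computed once and reused, instead of once for the check and once for the regrouping
    unique_docs = docs.unique()
    # Too few duplicates: grouping would not save much work either
    if (len(unique_docs) / len(docs)) > max_percent_unique:
        return function(docs)
    # Process each distinct value once, then broadcast the results back onto the full Series
    unique_input = pd.Series(unique_docs).dropna()
    results = pd.Series(function(unique_input).values, index=unique_input.values)
    return docs.map(results).rename(docs.name)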