diff --git a/hezar/configs.py b/hezar/configs.py index 27e0b605..47805b9e 100644 --- a/hezar/configs.py +++ b/hezar/configs.py @@ -23,7 +23,7 @@ from omegaconf import DictConfig, OmegaConf from .constants import DEFAULT_MODEL_CONFIG_FILE, HEZAR_CACHE_DIR, ConfigType, TaskType -from .utils import get_logger, get_module_config_class +from .utils import get_module_config_class, Logger __all__ = [ @@ -37,7 +37,7 @@ "MetricConfig", ] -logger = get_logger(__name__) +logger = Logger(__name__) CONFIG_CLASS_VARS = ["name", "config_type"] diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index c53b303d..b943c43c 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -2,7 +2,7 @@ import torch from ..preprocessors import Tokenizer -from ..utils import convert_batch_dict_dtype, get_logger +from ..utils import convert_batch_dict_dtype, Logger __all__ = [ @@ -10,7 +10,7 @@ "SequenceLabelingDataCollator", ] -logger = get_logger(__name__) +logger = Logger(__name__) class TextPaddingDataCollator: diff --git a/hezar/data/datasets/sequence_labeling_dataset.py b/hezar/data/datasets/sequence_labeling_dataset.py index 08a18734..76bf6e05 100644 --- a/hezar/data/datasets/sequence_labeling_dataset.py +++ b/hezar/data/datasets/sequence_labeling_dataset.py @@ -7,12 +7,12 @@ from ...constants import TaskType from ...preprocessors import Tokenizer from ...registry import register_dataset -from ...utils import get_logger +from ...utils import Logger from ..data_collators import SequenceLabelingDataCollator from .dataset import Dataset -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/data/datasets/text_classification_dataset.py b/hezar/data/datasets/text_classification_dataset.py index e7759745..26beba6d 100644 --- a/hezar/data/datasets/text_classification_dataset.py +++ b/hezar/data/datasets/text_classification_dataset.py @@ -8,12 +8,12 @@ from ...constants import TaskType from ...preprocessors import Tokenizer from ...registry import register_dataset -from ...utils import get_logger +from ...utils import Logger from ..data_collators import TextPaddingDataCollator from .dataset import Dataset -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/data/datasets/text_summarization_dataset.py b/hezar/data/datasets/text_summarization_dataset.py index a9d9deaf..02110810 100644 --- a/hezar/data/datasets/text_summarization_dataset.py +++ b/hezar/data/datasets/text_summarization_dataset.py @@ -6,12 +6,12 @@ from ...constants import TaskType from ...preprocessors import Tokenizer from ...registry import register_dataset -from ...utils import get_logger +from ...utils import Logger from ..data_collators import TextPaddingDataCollator from .dataset import Dataset -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/embeddings/embedding.py b/hezar/embeddings/embedding.py index d55858e7..b4e6dffd 100644 --- a/hezar/embeddings/embedding.py +++ b/hezar/embeddings/embedding.py @@ -11,10 +11,10 @@ DEFAULT_EMBEDDING_FILE, DEFAULT_EMBEDDING_SUBFOLDER, ) -from ..utils import get_logger +from ..utils import Logger -logger = get_logger(__name__) +logger = Logger(__name__) class Embedding: diff --git a/hezar/metrics/seqeval.py b/hezar/metrics/seqeval.py index 2a5c0521..70c84234 100644 --- a/hezar/metrics/seqeval.py +++ b/hezar/metrics/seqeval.py @@ -6,11 +6,11 @@ from ..configs import MetricConfig from ..constants import MetricType from ..registry import register_metric -from ..utils import get_logger +from ..utils import Logger from .metric import Metric -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/models/model.py b/hezar/models/model.py index 3b9541b1..3478bad5 100644 --- a/hezar/models/model.py +++ b/hezar/models/model.py @@ -20,7 +20,7 @@ from ..configs import ModelConfig from ..constants import DEFAULT_MODEL_CONFIG_FILE, DEFAULT_MODEL_FILE, HEZAR_CACHE_DIR from ..preprocessors import Preprocessor, PreprocessorsContainer -from ..utils import get_logger +from ..utils import Logger __all__ = [ @@ -28,7 +28,7 @@ "GenerativeModel", ] -logger = get_logger(__name__) +logger = Logger(__name__) class Model(nn.Module): @@ -239,12 +239,14 @@ def push_to_hub( repo_id=repo_id, commit_message=commit_message, ) - logger.info( - f"Uploaded: " - f"`{self.__class__.__name__}(name={self.config.name})`" - f" --> " - f"`{os.path.join(repo_id, filename)}`" - ) + + logger.log_upload_success(self, os.path.join(repo_id, filename)) + # logger.info( + # f"Uploaded: " + # f"`{self.__class__.__name__}(name={self.config.name})`" + # f" --> " + # f"`{os.path.join(repo_id, filename)}`" + # ) def forward(self, inputs, **kwargs) -> Dict: """ diff --git a/hezar/preprocessors/feature_extractors/audio/whisper_feature_extractor.py b/hezar/preprocessors/feature_extractors/audio/whisper_feature_extractor.py index c694a774..0e4cad93 100644 --- a/hezar/preprocessors/feature_extractors/audio/whisper_feature_extractor.py +++ b/hezar/preprocessors/feature_extractors/audio/whisper_feature_extractor.py @@ -4,11 +4,11 @@ import numpy as np from ....registry import register_preprocessor -from ....utils import convert_batch_dict_dtype, get_logger, mel_filter_bank, spectrogram, window_function +from ....utils import convert_batch_dict_dtype, Logger, mel_filter_bank, spectrogram, window_function from .audio_feature_extractor import AudioFeatureExtractor, AudioFeatureExtractorConfig -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/preprocessors/text_normalizer.py b/hezar/preprocessors/text_normalizer.py index fb093608..6439e7d9 100644 --- a/hezar/preprocessors/text_normalizer.py +++ b/hezar/preprocessors/text_normalizer.py @@ -8,10 +8,10 @@ from ..configs import PreprocessorConfig from ..constants import DEFAULT_NORMALIZER_CONFIG_FILE, DEFAULT_PREPROCESSOR_SUBFOLDER from ..registry import register_preprocessor -from ..utils import get_logger +from ..utils import Logger from .preprocessor import Preprocessor -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/preprocessors/tokenizers/tokenizer.py b/hezar/preprocessors/tokenizers/tokenizer.py index 8aafc00d..fecf31d6 100644 --- a/hezar/preprocessors/tokenizers/tokenizer.py +++ b/hezar/preprocessors/tokenizers/tokenizer.py @@ -14,11 +14,11 @@ from ...builders import build_preprocessor from ...configs import PreprocessorConfig from ...constants import DEFAULT_TOKENIZER_CONFIG_FILE, DEFAULT_TOKENIZER_FILE -from ...utils import convert_batch_dict_dtype, get_logger +from ...utils import convert_batch_dict_dtype, Logger from ..preprocessor import Preprocessor -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/preprocessors/tokenizers/whisper_bpe.py b/hezar/preprocessors/tokenizers/whisper_bpe.py index 2e022884..7c1635ba 100644 --- a/hezar/preprocessors/tokenizers/whisper_bpe.py +++ b/hezar/preprocessors/tokenizers/whisper_bpe.py @@ -5,11 +5,11 @@ from tokenizers import processors from ...registry import register_preprocessor -from ...utils import get_logger +from ...utils import Logger from .bpe import BPEConfig, BPETokenizer -logger = get_logger(__name__) +logger = Logger(__name__) LANGUAGES = { "en": "english", diff --git a/hezar/registry.py b/hezar/registry.py index 6d2b61f9..c32efee3 100644 --- a/hezar/registry.py +++ b/hezar/registry.py @@ -28,7 +28,7 @@ from typing import Dict, Optional, Type from .configs import DatasetConfig, EmbeddingConfig, MetricConfig, ModelConfig, PreprocessorConfig, TrainerConfig -from .utils import get_logger +from .utils import Logger __all__ = [ @@ -47,7 +47,7 @@ "trainers_registry", ] -logger = get_logger(__name__) +logger = Logger(__name__) @dataclass diff --git a/hezar/trainers/sequence_labeling/sequence_labeling_trainer.py b/hezar/trainers/sequence_labeling/sequence_labeling_trainer.py index 674b7520..3915ed5d 100644 --- a/hezar/trainers/sequence_labeling/sequence_labeling_trainer.py +++ b/hezar/trainers/sequence_labeling/sequence_labeling_trainer.py @@ -7,11 +7,11 @@ from ...constants import MetricType from ...data.datasets import Dataset from ...models import Model -from ...utils import get_logger +from ...utils import Logger from ..trainer import Trainer -logger = get_logger(__name__) +logger = Logger(__name__) class SequenceLabelingTrainer(Trainer): diff --git a/hezar/trainers/text_classification/text_classification_trainer.py b/hezar/trainers/text_classification/text_classification_trainer.py index ec2b73fc..792bd9dc 100644 --- a/hezar/trainers/text_classification/text_classification_trainer.py +++ b/hezar/trainers/text_classification/text_classification_trainer.py @@ -7,11 +7,11 @@ from ...constants import MetricType from ...data.datasets import Dataset from ...models import Model -from ...utils import get_logger +from ...utils import Logger from ..trainer import Trainer -logger = get_logger(__name__) +logger = Logger(__name__) class TextClassificationTrainer(Trainer): diff --git a/hezar/trainers/trainer.py b/hezar/trainers/trainer.py index 79d65187..0a791784 100644 --- a/hezar/trainers/trainer.py +++ b/hezar/trainers/trainer.py @@ -21,11 +21,11 @@ from ..data.datasets import Dataset from ..models import Model from ..preprocessors import Preprocessor, PreprocessorsContainer -from ..utils import get_logger +from ..utils import Logger from .trainer_utils import MetricsTracker, write_to_tensorboard -logger = get_logger(__name__) +logger = Logger(__name__) optimizers = { "adam": torch.optim.Adam, diff --git a/hezar/utils/audio_utils.py b/hezar/utils/audio_utils.py index 0922a797..8b3df59c 100644 --- a/hezar/utils/audio_utils.py +++ b/hezar/utils/audio_utils.py @@ -5,7 +5,7 @@ import numpy as np -from .logging import get_logger +from .logging import Logger __all__ = [ @@ -18,7 +18,7 @@ "mel_to_hertz", ] -logger = get_logger(__name__) +logger = Logger(__name__) def spectrogram( diff --git a/hezar/utils/core_utils.py b/hezar/utils/core_utils.py index c10b6eda..67b595a8 100644 --- a/hezar/utils/core_utils.py +++ b/hezar/utils/core_utils.py @@ -7,7 +7,7 @@ from ..constants import ConfigType, RegistryType from .common_utils import snake_case -from .logging import get_logger +from .logging import Logger __all__ = [ @@ -20,7 +20,7 @@ "get_registry_point", ] -logger = get_logger(__name__) +logger = Logger(__name__) def flatten_dict(dict_config: Union[Dict, DictConfig]) -> DictConfig: diff --git a/hezar/utils/data_utils.py b/hezar/utils/data_utils.py index aab64831..63f00437 100644 --- a/hezar/utils/data_utils.py +++ b/hezar/utils/data_utils.py @@ -1,6 +1,6 @@ from typing import Any, Dict -from .logging import get_logger +from .logging import Logger __all__ = [ @@ -8,7 +8,7 @@ "get_non_numeric_keys", ] -logger = get_logger(__name__) +logger = Logger(__name__) # TODO: This code might be able to be written in a cleaner way, but be careful, any change might break a lot of things! diff --git a/hezar/utils/file_utils.py b/hezar/utils/file_utils.py index 50582fbc..a52547dd 100644 --- a/hezar/utils/file_utils.py +++ b/hezar/utils/file_utils.py @@ -1,10 +1,10 @@ import gzip import shutil -from .logging import get_logger +from .logging import Logger -logger = get_logger(__name__) +logger = Logger(__name__) __all__ = [ "gunzip" diff --git a/hezar/utils/hub_utils.py b/hezar/utils/hub_utils.py index bcc5707e..e67dae79 100644 --- a/hezar/utils/hub_utils.py +++ b/hezar/utils/hub_utils.py @@ -3,7 +3,7 @@ from huggingface_hub import HfApi, Repository from ..constants import HEZAR_CACHE_DIR, HEZAR_HUB_ID, RepoType -from ..utils.logging import get_logger +from ..utils.logging import Logger __all__ = [ @@ -15,7 +15,7 @@ "list_repo_files", ] -logger = get_logger(__name__) +logger = Logger(__name__) def resolve_pretrained_path(hub_or_local_path): diff --git a/hezar/utils/logging.py b/hezar/utils/logging.py index c95e3945..0dd32f95 100644 --- a/hezar/utils/logging.py +++ b/hezar/utils/logging.py @@ -2,18 +2,20 @@ __all__ = [ - "get_logger", + "Logger" ] -def get_logger(name, level=None, fmt=None): - fmt = fmt or "%(levelname)s: %(message)s" - level = level or "INFO" +class Logger(logging.Logger): + def __init__(self, name: str, level=None, fmt=None): + fmt = fmt or "%(levelname)s: %(message)s" + level = level or "INFO" + super().__init__(name, level) + handler = logging.StreamHandler() + formatter = logging.Formatter(fmt) + handler.setFormatter(formatter) + self.addHandler(handler) - logger = logging.Logger(name, level) - handler = logging.StreamHandler() - formatter = logging.Formatter(fmt) - handler.setFormatter(formatter) - logger.addHandler(handler) - - return logger + def log_upload_success(self, module, path_in_repo: str): + src = f"{module.__class__.__name__}(name={module.config.name})" + self.info(f"Uploaded: `{src}` --> `{path_in_repo}`")