Skip to content

Commit

Permalink
Change get_logger to a whole Logger class
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Aug 29, 2023
1 parent 42c0fee commit 3847f91
Show file tree
Hide file tree
Showing 22 changed files with 63 additions and 59 deletions.
4 changes: 2 additions & 2 deletions hezar/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from omegaconf import DictConfig, OmegaConf

from .constants import DEFAULT_MODEL_CONFIG_FILE, HEZAR_CACHE_DIR, ConfigType, TaskType
from .utils import get_logger, get_module_config_class
from .utils import get_module_config_class, Logger


__all__ = [
Expand All @@ -37,7 +37,7 @@
"MetricConfig",
]

logger = get_logger(__name__)
logger = Logger(__name__)

CONFIG_CLASS_VARS = ["name", "config_type"]

Expand Down
4 changes: 2 additions & 2 deletions hezar/data/data_collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import torch

from ..preprocessors import Tokenizer
from ..utils import convert_batch_dict_dtype, get_logger
from ..utils import convert_batch_dict_dtype, Logger


__all__ = [
"TextPaddingDataCollator",
"SequenceLabelingDataCollator",
]

logger = get_logger(__name__)
logger = Logger(__name__)


class TextPaddingDataCollator:
Expand Down
4 changes: 2 additions & 2 deletions hezar/data/datasets/sequence_labeling_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from ...constants import TaskType
from ...preprocessors import Tokenizer
from ...registry import register_dataset
from ...utils import get_logger
from ...utils import Logger
from ..data_collators import SequenceLabelingDataCollator
from .dataset import Dataset


logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions hezar/data/datasets/text_classification_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from ...constants import TaskType
from ...preprocessors import Tokenizer
from ...registry import register_dataset
from ...utils import get_logger
from ...utils import Logger
from ..data_collators import TextPaddingDataCollator
from .dataset import Dataset


logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions hezar/data/datasets/text_summarization_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from ...constants import TaskType
from ...preprocessors import Tokenizer
from ...registry import register_dataset
from ...utils import get_logger
from ...utils import Logger
from ..data_collators import TextPaddingDataCollator
from .dataset import Dataset


logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions hezar/embeddings/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
DEFAULT_EMBEDDING_FILE,
DEFAULT_EMBEDDING_SUBFOLDER,
)
from ..utils import get_logger
from ..utils import Logger


logger = get_logger(__name__)
logger = Logger(__name__)


class Embedding:
Expand Down
4 changes: 2 additions & 2 deletions hezar/metrics/seqeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from ..configs import MetricConfig
from ..constants import MetricType
from ..registry import register_metric
from ..utils import get_logger
from ..utils import Logger
from .metric import Metric


logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
18 changes: 10 additions & 8 deletions hezar/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from ..configs import ModelConfig
from ..constants import DEFAULT_MODEL_CONFIG_FILE, DEFAULT_MODEL_FILE, HEZAR_CACHE_DIR
from ..preprocessors import Preprocessor, PreprocessorsContainer
from ..utils import get_logger
from ..utils import Logger


__all__ = [
"Model",
"GenerativeModel",
]

logger = get_logger(__name__)
logger = Logger(__name__)


class Model(nn.Module):
Expand Down Expand Up @@ -239,12 +239,14 @@ def push_to_hub(
repo_id=repo_id,
commit_message=commit_message,
)
logger.info(
f"Uploaded: "
f"`{self.__class__.__name__}(name={self.config.name})`"
f" --> "
f"`{os.path.join(repo_id, filename)}`"
)

logger.log_upload_success(self, os.path.join(repo_id, filename))
# logger.info(
# f"Uploaded: "
# f"`{self.__class__.__name__}(name={self.config.name})`"
# f" --> "
# f"`{os.path.join(repo_id, filename)}`"
# )

def forward(self, inputs, **kwargs) -> Dict:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import numpy as np

from ....registry import register_preprocessor
from ....utils import convert_batch_dict_dtype, get_logger, mel_filter_bank, spectrogram, window_function
from ....utils import convert_batch_dict_dtype, Logger, mel_filter_bank, spectrogram, window_function
from .audio_feature_extractor import AudioFeatureExtractor, AudioFeatureExtractorConfig


logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions hezar/preprocessors/text_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from ..configs import PreprocessorConfig
from ..constants import DEFAULT_NORMALIZER_CONFIG_FILE, DEFAULT_PREPROCESSOR_SUBFOLDER
from ..registry import register_preprocessor
from ..utils import get_logger
from ..utils import Logger
from .preprocessor import Preprocessor

logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions hezar/preprocessors/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from ...builders import build_preprocessor
from ...configs import PreprocessorConfig
from ...constants import DEFAULT_TOKENIZER_CONFIG_FILE, DEFAULT_TOKENIZER_FILE
from ...utils import convert_batch_dict_dtype, get_logger
from ...utils import convert_batch_dict_dtype, Logger
from ..preprocessor import Preprocessor


logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions hezar/preprocessors/tokenizers/whisper_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from tokenizers import processors

from ...registry import register_preprocessor
from ...utils import get_logger
from ...utils import Logger
from .bpe import BPEConfig, BPETokenizer


logger = get_logger(__name__)
logger = Logger(__name__)

LANGUAGES = {
"en": "english",
Expand Down
4 changes: 2 additions & 2 deletions hezar/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import Dict, Optional, Type

from .configs import DatasetConfig, EmbeddingConfig, MetricConfig, ModelConfig, PreprocessorConfig, TrainerConfig
from .utils import get_logger
from .utils import Logger


__all__ = [
Expand All @@ -47,7 +47,7 @@
"trainers_registry",
]

logger = get_logger(__name__)
logger = Logger(__name__)


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions hezar/trainers/sequence_labeling/sequence_labeling_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from ...constants import MetricType
from ...data.datasets import Dataset
from ...models import Model
from ...utils import get_logger
from ...utils import Logger
from ..trainer import Trainer


logger = get_logger(__name__)
logger = Logger(__name__)


class SequenceLabelingTrainer(Trainer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from ...constants import MetricType
from ...data.datasets import Dataset
from ...models import Model
from ...utils import get_logger
from ...utils import Logger
from ..trainer import Trainer


logger = get_logger(__name__)
logger = Logger(__name__)


class TextClassificationTrainer(Trainer):
Expand Down
4 changes: 2 additions & 2 deletions hezar/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from ..data.datasets import Dataset
from ..models import Model
from ..preprocessors import Preprocessor, PreprocessorsContainer
from ..utils import get_logger
from ..utils import Logger
from .trainer_utils import MetricsTracker, write_to_tensorboard


logger = get_logger(__name__)
logger = Logger(__name__)

optimizers = {
"adam": torch.optim.Adam,
Expand Down
4 changes: 2 additions & 2 deletions hezar/utils/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from .logging import get_logger
from .logging import Logger


__all__ = [
Expand All @@ -18,7 +18,7 @@
"mel_to_hertz",
]

logger = get_logger(__name__)
logger = Logger(__name__)


def spectrogram(
Expand Down
4 changes: 2 additions & 2 deletions hezar/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..constants import ConfigType, RegistryType
from .common_utils import snake_case
from .logging import get_logger
from .logging import Logger


__all__ = [
Expand All @@ -20,7 +20,7 @@
"get_registry_point",
]

logger = get_logger(__name__)
logger = Logger(__name__)


def flatten_dict(dict_config: Union[Dict, DictConfig]) -> DictConfig:
Expand Down
4 changes: 2 additions & 2 deletions hezar/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Any, Dict

from .logging import get_logger
from .logging import Logger


__all__ = [
"convert_batch_dict_dtype",
"get_non_numeric_keys",
]

logger = get_logger(__name__)
logger = Logger(__name__)


# TODO: This code might be able to be written in a cleaner way, but be careful, any change might break a lot of things!
Expand Down
4 changes: 2 additions & 2 deletions hezar/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import gzip
import shutil

from .logging import get_logger
from .logging import Logger


logger = get_logger(__name__)
logger = Logger(__name__)

__all__ = [
"gunzip"
Expand Down
4 changes: 2 additions & 2 deletions hezar/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from huggingface_hub import HfApi, Repository

from ..constants import HEZAR_CACHE_DIR, HEZAR_HUB_ID, RepoType
from ..utils.logging import get_logger
from ..utils.logging import Logger


__all__ = [
Expand All @@ -15,7 +15,7 @@
"list_repo_files",
]

logger = get_logger(__name__)
logger = Logger(__name__)


def resolve_pretrained_path(hub_or_local_path):
Expand Down
24 changes: 13 additions & 11 deletions hezar/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@


__all__ = [
"get_logger",
"Logger"
]


def get_logger(name, level=None, fmt=None):
fmt = fmt or "%(levelname)s: %(message)s"
level = level or "INFO"
class Logger(logging.Logger):
    """
    A `logging.Logger` subclass that attaches its own stream handler on creation.

    Instances are created directly (`Logger(__name__)`), so every instance carries
    exactly one `logging.StreamHandler` with the given format.

    Args:
        name: Name of the logger, typically the calling module's `__name__`.
        level: Logging level as an int or a level name. `None` (or an empty string)
            selects "INFO".
        fmt: Format string for the handler's formatter. Falls back to
            "%(levelname)s: %(message)s" when not given.
    """

    def __init__(self, name: str, level=None, fmt=None):
        fmt = fmt or "%(levelname)s: %(message)s"
        # Do not rely on plain truthiness here: `logging.NOTSET` is 0 and would
        # otherwise be silently replaced by "INFO" even when passed explicitly.
        if level is None or level == "":
            level = "INFO"
        super().__init__(name, level)
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(fmt))
        self.addHandler(handler)

logger = logging.Logger(name, level)
handler = logging.StreamHandler()
formatter = logging.Formatter(fmt)
handler.setFormatter(formatter)
logger.addHandler(handler)

return logger
def log_upload_success(self, module, path_in_repo: str):
    """
    Emit an info-level message stating that `module` was uploaded to `path_in_repo`.

    Args:
        module: The uploaded object; its class name and `config.name` are shown
            in the message.
        path_in_repo: Destination path shown on the right-hand side of the message.
    """
    cls_name = module.__class__.__name__
    description = f"{cls_name}(name={module.config.name})"
    self.info(f"Uploaded: `{description}` --> `{path_in_repo}`")

0 comments on commit 3847f91

Please sign in to comment.