Source code for substrata.logging

"""Logging configuration and helpers for the substrata package."""

import logging
from typing import Optional
from contextlib import contextmanager
import joblib

# Shared logger for this module, named after the module's import path (__name__).
# NOTE(review): setup_logging() configures only this logger object; whether other
# modules in the package log through it is not visible from here -- confirm.
logger: logging.Logger = logging.getLogger(__name__)


def setup_logging(
    level: int = logging.INFO,
    format_string: Optional[str] = None,
) -> None:
    """Configure console logging on the module-level ``logger``.

    A ``StreamHandler`` is attached to ``logger`` the first time this is
    called (i.e. while ``logger`` has no handlers yet). Propagation to
    ancestor loggers is switched off, so the root logger -- and therefore
    output from third-party libraries (kaleido, choreographer, etc.) -- is
    left untouched.

    Args:
        level: Threshold applied to ``logger``. Defaults to ``logging.INFO``.
        format_string: Message format for the handler. When ``None``, the
            format ``"%(asctime)s - %(levelname)s - %(message)s"`` is used.
            Only takes effect on the call that actually creates the handler.
    """
    already_has_handler = bool(logger.handlers)

    # Install the console handler once; repeated calls must not stack
    # handlers and duplicate every message.
    if not already_has_handler:
        message_format = (
            "%(asctime)s - %(levelname)s - %(message)s"
            if format_string is None
            else format_string
        )
        formatter = logging.Formatter(message_format, datefmt="%Y-%m-%d %H:%M:%S")
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger.addHandler(console)

    # Level and propagation are (re)applied on every call.
    logger.propagate = False
    logger.setLevel(level)
@contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to integrate joblib.Parallel with tqdm progress bars.

    While the context is active, ``joblib.parallel.BatchCompletionCallBack``
    is replaced by a subclass that advances ``tqdm_object`` by the batch size
    each time the callback is invoked, then defers to the original callback.
    On exit -- including when the body raises -- the original callback class
    is restored and ``tqdm_object`` is closed.

    The swap is done on a module-level attribute, so it is visible
    process-wide for as long as the context is active.

    Usage:
        with tqdm_joblib(tqdm(total=N)) as progress_bar:
            Parallel(n_jobs=-1)(delayed(func)(x) for x in items)

    Args:
        tqdm_object: An instance of tqdm configured with the desired total.

    Yields:
        The same ``tqdm_object``, so it can be bound with ``as``. Callers
        that ignore the yielded value are unaffected.
    """

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        """Batch-completion callback that also ticks the progress bar."""

        def __call__(self, *args, **kwargs):
            # Advance by the number of tasks in the finished batch before
            # handing control back to joblib's own completion handling.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        # Yield the bar itself (previously nothing was yielded, so
        # ``with ... as bar`` bound None).
        yield tqdm_object
    finally:
        # Always undo the monkey-patch and release the bar, even on error.
        joblib.parallel.BatchCompletionCallBack = old_callback
        tqdm_object.close()