Add type hints for utils module

Types make it easier to understand the code and improve autocompletion
in IDEs.
This commit is contained in:
Björn Ricks 2024-01-18 09:18:00 +01:00
commit e4807316ae
No known key found for this signature in database

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime import datetime
import fnmatch import fnmatch
import locale import locale
@ -16,6 +18,21 @@ from html import entities
from html.parser import HTMLParser from html.parser import HTMLParser
from itertools import groupby from itertools import groupby
from operator import attrgetter from operator import attrgetter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import dateutil.parser import dateutil.parser
@ -27,11 +44,15 @@ from markupsafe import Markup
import watchfiles import watchfiles
if TYPE_CHECKING:
from pelican.contents import Content
from pelican.readers import Readers
from pelican.settings import Settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def sanitised_join(base_directory, *parts): def sanitised_join(base_directory: str, *parts: str) -> str:
joined = posixize_path(os.path.abspath(os.path.join(base_directory, *parts))) joined = posixize_path(os.path.abspath(os.path.join(base_directory, *parts)))
base = posixize_path(os.path.abspath(base_directory)) base = posixize_path(os.path.abspath(base_directory))
if not joined.startswith(base): if not joined.startswith(base):
@ -40,7 +61,7 @@ def sanitised_join(base_directory, *parts):
return joined return joined
def strftime(date, date_format): def strftime(date: datetime.datetime, date_format: str) -> str:
""" """
Enhanced replacement for built-in strftime with zero stripping Enhanced replacement for built-in strftime with zero stripping
@ -109,10 +130,10 @@ class DateFormatter:
defined in LOCALE setting defined in LOCALE setting
""" """
def __init__(self): def __init__(self) -> None:
self.locale = locale.setlocale(locale.LC_TIME) self.locale = locale.setlocale(locale.LC_TIME)
def __call__(self, date, date_format): def __call__(self, date: datetime.datetime, date_format: str) -> str:
# on OSX, encoding from LC_CTYPE determines the unicode output in PY3 # on OSX, encoding from LC_CTYPE determines the unicode output in PY3
# make sure it's same as LC_TIME # make sure it's same as LC_TIME
with temporary_locale(self.locale, locale.LC_TIME), temporary_locale( with temporary_locale(self.locale, locale.LC_TIME), temporary_locale(
@ -131,11 +152,11 @@ class memoized:
""" """
def __init__(self, func): def __init__(self, func: Callable) -> None:
self.func = func self.func = func
self.cache = {} self.cache: Dict[Any, Any] = {}
def __call__(self, *args): def __call__(self, *args) -> Any:
if not isinstance(args, Hashable): if not isinstance(args, Hashable):
# uncacheable. a list, for instance. # uncacheable. a list, for instance.
# better to not cache than blow up. # better to not cache than blow up.
@ -147,17 +168,23 @@ class memoized:
self.cache[args] = value self.cache[args] = value
return value return value
def __repr__(self): def __repr__(self) -> Optional[str]:
return self.func.__doc__ return self.func.__doc__
def __get__(self, obj, objtype): def __get__(self, obj: Any, objtype):
"""Support instance methods.""" """Support instance methods."""
fn = partial(self.__call__, obj) fn = partial(self.__call__, obj)
fn.cache = self.cache fn.cache = self.cache
return fn return fn
def deprecated_attribute(old, new, since=None, remove=None, doc=None): def deprecated_attribute(
old: str,
new: str,
since: Tuple[int, ...],
remove: Optional[Tuple[int, ...]] = None,
doc: Optional[str] = None,
):
"""Attribute deprecation decorator for gentle upgrades """Attribute deprecation decorator for gentle upgrades
For example: For example:
@ -198,7 +225,7 @@ def deprecated_attribute(old, new, since=None, remove=None, doc=None):
return decorator return decorator
def get_date(string): def get_date(string: str) -> datetime.datetime:
"""Return a datetime object from a string. """Return a datetime object from a string.
If no format matches the given date, raise a ValueError. If no format matches the given date, raise a ValueError.
@ -212,7 +239,9 @@ def get_date(string):
@contextmanager @contextmanager
def pelican_open(filename, mode="r", strip_crs=(sys.platform == "win32")): def pelican_open(
filename: str, mode: str = "r", strip_crs: bool = (sys.platform == "win32")
) -> Generator[str, None, None]:
"""Open a file and return its content""" """Open a file and return its content"""
# utf-8-sig will clear any BOM if present # utf-8-sig will clear any BOM if present
@ -221,7 +250,12 @@ def pelican_open(filename, mode="r", strip_crs=(sys.platform == "win32")):
yield content yield content
def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False): def slugify(
value: str,
regex_subs: Iterable[Tuple[str, str]] = (),
preserve_case: bool = False,
use_unicode: bool = False,
) -> str:
""" """
Normalizes string, converts to lowercase, removes non-alpha characters, Normalizes string, converts to lowercase, removes non-alpha characters,
and converts spaces to hyphens. and converts spaces to hyphens.
@ -233,9 +267,10 @@ def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
""" """
import unicodedata import unicodedata
import unidecode import unidecode
def normalize_unicode(text): def normalize_unicode(text: str) -> str:
# normalize text by compatibility composition # normalize text by compatibility composition
# see: https://en.wikipedia.org/wiki/Unicode_equivalence # see: https://en.wikipedia.org/wiki/Unicode_equivalence
return unicodedata.normalize("NFKC", text) return unicodedata.normalize("NFKC", text)
@ -262,7 +297,9 @@ def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
return value.strip() return value.strip()
def copy(source, destination, ignores=None): def copy(
source: str, destination: str, ignores: Optional[Iterable[str]] = None
) -> None:
"""Recursively copy source into destination. """Recursively copy source into destination.
If source is a file, destination has to be a file as well. If source is a file, destination has to be a file as well.
@ -334,7 +371,7 @@ def copy(source, destination, ignores=None):
) )
def copy_file(source, destination): def copy_file(source: str, destination: str) -> None:
"""Copy a file""" """Copy a file"""
try: try:
shutil.copyfile(source, destination) shutil.copyfile(source, destination)
@ -344,7 +381,7 @@ def copy_file(source, destination):
) )
def clean_output_dir(path, retention): def clean_output_dir(path: str, retention: Iterable[str]) -> None:
"""Remove all files from output directory except those in retention list""" """Remove all files from output directory except those in retention list"""
if not os.path.exists(path): if not os.path.exists(path):
@ -381,24 +418,24 @@ def clean_output_dir(path, retention):
logger.error("Unable to delete %s, file type unknown", file) logger.error("Unable to delete %s, file type unknown", file)
def get_relative_path(path): def get_relative_path(path: str) -> str:
"""Return the relative path from the given path to the root path.""" """Return the relative path from the given path to the root path."""
components = split_all(path) components = split_all(path)
if len(components) <= 1: if components is None or len(components) <= 1:
return os.curdir return os.curdir
else: else:
parents = [os.pardir] * (len(components) - 1) parents = [os.pardir] * (len(components) - 1)
return os.path.join(*parents) return os.path.join(*parents)
def path_to_url(path): def path_to_url(path: str) -> str:
"""Return the URL corresponding to a given path.""" """Return the URL corresponding to a given path."""
if path is not None: if path is not None:
path = posixize_path(path) path = posixize_path(path)
return path return path
def posixize_path(rel_path): def posixize_path(rel_path: str) -> str:
"""Use '/' as path separator, so that source references, """Use '/' as path separator, so that source references,
like '{static}/foo/bar.jpg' or 'extras/favicon.ico', like '{static}/foo/bar.jpg' or 'extras/favicon.ico',
will work on Windows as well as on Mac and Linux.""" will work on Windows as well as on Mac and Linux."""
@ -427,20 +464,20 @@ class _HTMLWordTruncator(HTMLParser):
_singlets = ("br", "col", "link", "base", "img", "param", "area", "hr", "input") _singlets = ("br", "col", "link", "base", "img", "param", "area", "hr", "input")
class TruncationCompleted(Exception): class TruncationCompleted(Exception):
def __init__(self, truncate_at): def __init__(self, truncate_at: int) -> None:
super().__init__(truncate_at) super().__init__(truncate_at)
self.truncate_at = truncate_at self.truncate_at = truncate_at
def __init__(self, max_words): def __init__(self, max_words: int) -> None:
super().__init__(convert_charrefs=False) super().__init__(convert_charrefs=False)
self.max_words = max_words self.max_words = max_words
self.words_found = 0 self.words_found = 0
self.open_tags = [] self.open_tags = []
self.last_word_end = None self.last_word_end = None
self.truncate_at = None self.truncate_at: Optional[int] = None
def feed(self, *args, **kwargs): def feed(self, *args, **kwargs) -> None:
try: try:
super().feed(*args, **kwargs) super().feed(*args, **kwargs)
except self.TruncationCompleted as exc: except self.TruncationCompleted as exc:
@ -448,29 +485,29 @@ class _HTMLWordTruncator(HTMLParser):
else: else:
self.truncate_at = None self.truncate_at = None
def getoffset(self): def getoffset(self) -> int:
line_start = 0 line_start = 0
lineno, line_offset = self.getpos() lineno, line_offset = self.getpos()
for i in range(lineno - 1): for i in range(lineno - 1):
line_start = self.rawdata.index("\n", line_start) + 1 line_start = self.rawdata.index("\n", line_start) + 1
return line_start + line_offset return line_start + line_offset
def add_word(self, word_end): def add_word(self, word_end: int) -> None:
self.words_found += 1 self.words_found += 1
self.last_word_end = None self.last_word_end = None
if self.words_found == self.max_words: if self.words_found == self.max_words:
raise self.TruncationCompleted(word_end) raise self.TruncationCompleted(word_end)
def add_last_word(self): def add_last_word(self) -> None:
if self.last_word_end is not None: if self.last_word_end is not None:
self.add_word(self.last_word_end) self.add_word(self.last_word_end)
def handle_starttag(self, tag, attrs): def handle_starttag(self, tag: str, attrs: Any) -> None:
self.add_last_word() self.add_last_word()
if tag not in self._singlets: if tag not in self._singlets:
self.open_tags.insert(0, tag) self.open_tags.insert(0, tag)
def handle_endtag(self, tag): def handle_endtag(self, tag: str) -> None:
self.add_last_word() self.add_last_word()
try: try:
i = self.open_tags.index(tag) i = self.open_tags.index(tag)
@ -481,7 +518,7 @@ class _HTMLWordTruncator(HTMLParser):
# all unclosed intervening start tags with omitted end tags # all unclosed intervening start tags with omitted end tags
del self.open_tags[: i + 1] del self.open_tags[: i + 1]
def handle_data(self, data): def handle_data(self, data: str) -> None:
word_end = 0 word_end = 0
offset = self.getoffset() offset = self.getoffset()
@ -499,7 +536,7 @@ class _HTMLWordTruncator(HTMLParser):
if word_end < len(data): if word_end < len(data):
self.add_last_word() self.add_last_word()
def _handle_ref(self, name, char): def _handle_ref(self, name: str, char: str) -> None:
""" """
Called by handle_entityref() or handle_charref() when a ref like Called by handle_entityref() or handle_charref() when a ref like
`&mdash;`, `&#8212;`, or `&#x2014` is found. `&mdash;`, `&#8212;`, or `&#x2014` is found.
@ -543,7 +580,7 @@ class _HTMLWordTruncator(HTMLParser):
else: else:
self.add_last_word() self.add_last_word()
def handle_entityref(self, name): def handle_entityref(self, name: str) -> None:
""" """
Called when an entity ref like '&mdash;' is found Called when an entity ref like '&mdash;' is found
@ -556,7 +593,7 @@ class _HTMLWordTruncator(HTMLParser):
char = "" char = ""
self._handle_ref(name, char) self._handle_ref(name, char)
def handle_charref(self, name): def handle_charref(self, name: str) -> None:
""" """
Called when a char ref like '&#8212;' or '&#x2014' is found Called when a char ref like '&#8212;' or '&#x2014' is found
@ -574,7 +611,7 @@ class _HTMLWordTruncator(HTMLParser):
self._handle_ref("#" + name, char) self._handle_ref("#" + name, char)
def truncate_html_words(s, num, end_text=""): def truncate_html_words(s: str, num: int, end_text: str = "") -> str:
"""Truncates HTML to a certain number of words. """Truncates HTML to a certain number of words.
(not counting tags and comments). Closes opened tags if they were correctly (not counting tags and comments). Closes opened tags if they were correctly
@ -600,7 +637,10 @@ def truncate_html_words(s, num, end_text="…"):
return out return out
def process_translations(content_list, translation_id=None): def process_translations(
content_list: List[Content],
translation_id: Optional[Union[str, Collection[str]]] = None,
) -> Tuple[List[Content], List[Content]]:
"""Finds translations and returns them. """Finds translations and returns them.
For each content_list item, populates the 'translations' attribute, and For each content_list item, populates the 'translations' attribute, and
@ -658,7 +698,7 @@ def process_translations(content_list, translation_id=None):
return index, translations return index, translations
def get_original_items(items, with_str): def get_original_items(items: List[Content], with_str: str) -> List[Content]:
def _warn_source_paths(msg, items, *extra): def _warn_source_paths(msg, items, *extra):
args = [len(items)] args = [len(items)]
args.extend(extra) args.extend(extra)
@ -698,7 +738,10 @@ def get_original_items(items, with_str):
return original_items return original_items
def order_content(content_list, order_by="slug"): def order_content(
content_list: List[Content],
order_by: Union[str, Callable[[Content], Any], None] = "slug",
) -> List[Content]:
"""Sorts content. """Sorts content.
order_by can be a string of an attribute or sorting function. If order_by order_by can be a string of an attribute or sorting function. If order_by
@ -758,7 +801,11 @@ def order_content(content_list, order_by="slug"):
return content_list return content_list
def wait_for_changes(settings_file, reader_class, settings): def wait_for_changes(
settings_file: str,
reader_class: Type["Readers"],
settings: "Settings",
):
content_path = settings.get("PATH", "") content_path = settings.get("PATH", "")
theme_path = settings.get("THEME", "") theme_path = settings.get("THEME", "")
ignore_files = { ignore_files = {
@ -788,13 +835,15 @@ def wait_for_changes(settings_file, reader_class, settings):
return next( return next(
watchfiles.watch( watchfiles.watch(
*watching_paths, *watching_paths,
watch_filter=watchfiles.DefaultFilter(ignore_entity_patterns=ignore_files), watch_filter=watchfiles.DefaultFilter(ignore_entity_patterns=ignore_files), # type: ignore
rust_timeout=0, rust_timeout=0,
) )
) )
def set_date_tzinfo(d, tz_name=None): def set_date_tzinfo(
d: datetime.datetime, tz_name: Optional[str] = None
) -> datetime.datetime:
"""Set the timezone for dates that don't have tzinfo""" """Set the timezone for dates that don't have tzinfo"""
if tz_name and not d.tzinfo: if tz_name and not d.tzinfo:
timezone = ZoneInfo(tz_name) timezone = ZoneInfo(tz_name)
@ -805,11 +854,11 @@ def set_date_tzinfo(d, tz_name=None):
return d return d
def mkdir_p(path): def mkdir_p(path: str) -> None:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
def split_all(path): def split_all(path: Union[str, pathlib.Path, None]) -> Optional[Sequence[str]]:
"""Split a path into a list of components """Split a path into a list of components
While os.path.split() splits a single component off the back of While os.path.split() splits a single component off the back of
@ -840,12 +889,12 @@ def split_all(path):
) )
def path_to_file_url(path): def path_to_file_url(path: str) -> str:
"""Convert file-system path to file:// URL""" """Convert file-system path to file:// URL"""
return urllib.parse.urljoin("file://", urllib.request.pathname2url(path)) return urllib.parse.urljoin("file://", urllib.request.pathname2url(path))
def maybe_pluralize(count, singular, plural): def maybe_pluralize(count: int, singular: str, plural: str) -> str:
""" """
Returns a formatted string containing count and plural if count is not 1 Returns a formatted string containing count and plural if count is not 1
Returns count and singular if count is 1 Returns count and singular if count is 1
@ -862,7 +911,9 @@ def maybe_pluralize(count, singular, plural):
@contextmanager @contextmanager
def temporary_locale(temp_locale=None, lc_category=locale.LC_ALL): def temporary_locale(
temp_locale: Optional[str] = None, lc_category: int = locale.LC_ALL
) -> Generator[None, None, None]:
""" """
Enable code to run in a context with a temporary locale Enable code to run in a context with a temporary locale
Resets the locale back when exiting context. Resets the locale back when exiting context.