
Add type hints for utils module

Types make it easier to understand the code and improve autocompletion
in IDEs.
Björn Ricks 2024-01-18 09:18:00 +01:00
commit e4807316ae


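A minimal sketch of the rationale above (illustrative only, not part of this commit): with the annotations added in the diff below, a static type checker such as mypy can reject wrongly typed calls to pelican.utils, and IDEs can complete the annotated parameters.

# Illustrative sketch, assuming Pelican is installed; not part of the diff below.
import datetime

from pelican.utils import slugify, strftime

slug = slugify("Hello World")                               # ok: str in, str out
day = strftime(datetime.datetime(2024, 1, 18), "%d %B %Y")  # ok: returns str
# slugify(42) would now be flagged by a type checker, since "value" is annotated as str.
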
@@ -1,3 +1,5 @@
from __future__ import annotations
import datetime
import fnmatch
import locale
@@ -16,6 +18,21 @@ from html import entities
from html.parser import HTMLParser
from itertools import groupby
from operator import attrgetter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import dateutil.parser
@@ -27,11 +44,15 @@ from markupsafe import Markup
import watchfiles
if TYPE_CHECKING:
from pelican.contents import Content
from pelican.readers import Readers
from pelican.settings import Settings
logger = logging.getLogger(__name__)
def sanitised_join(base_directory, *parts):
def sanitised_join(base_directory: str, *parts: str) -> str:
joined = posixize_path(os.path.abspath(os.path.join(base_directory, *parts)))
base = posixize_path(os.path.abspath(base_directory))
if not joined.startswith(base):
@@ -40,7 +61,7 @@ def sanitised_join(base_directory, *parts):
return joined
def strftime(date, date_format):
def strftime(date: datetime.datetime, date_format: str) -> str:
"""
Enhanced replacement for built-in strftime with zero stripping
@@ -109,10 +130,10 @@ class DateFormatter:
defined in LOCALE setting
"""
def __init__(self):
def __init__(self) -> None:
self.locale = locale.setlocale(locale.LC_TIME)
def __call__(self, date, date_format):
def __call__(self, date: datetime.datetime, date_format: str) -> str:
# on OSX, encoding from LC_CTYPE determines the unicode output in PY3
# make sure it's same as LC_TIME
with temporary_locale(self.locale, locale.LC_TIME), temporary_locale(
@@ -131,11 +152,11 @@ class memoized:
"""
def __init__(self, func):
def __init__(self, func: Callable) -> None:
self.func = func
self.cache = {}
self.cache: Dict[Any, Any] = {}
def __call__(self, *args):
def __call__(self, *args) -> Any:
if not isinstance(args, Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
@@ -147,17 +168,23 @@ class memoized:
self.cache[args] = value
return value
def __repr__(self):
def __repr__(self) -> Optional[str]:
return self.func.__doc__
def __get__(self, obj, objtype):
def __get__(self, obj: Any, objtype):
"""Support instance methods."""
fn = partial(self.__call__, obj)
fn.cache = self.cache
return fn
def deprecated_attribute(old, new, since=None, remove=None, doc=None):
def deprecated_attribute(
old: str,
new: str,
since: Tuple[int, ...],
remove: Optional[Tuple[int, ...]] = None,
doc: Optional[str] = None,
):
"""Attribute deprecation decorator for gentle upgrades
For example:
@@ -198,7 +225,7 @@ def deprecated_attribute(old, new, since=None, remove=None, doc=None):
return decorator
def get_date(string):
def get_date(string: str) -> datetime.datetime:
"""Return a datetime object from a string.
If no format matches the given date, raise a ValueError.
@@ -212,7 +239,9 @@ def get_date(string):
@contextmanager
def pelican_open(filename, mode="r", strip_crs=(sys.platform == "win32")):
def pelican_open(
filename: str, mode: str = "r", strip_crs: bool = (sys.platform == "win32")
) -> Generator[str, None, None]:
"""Open a file and return its content"""
# utf-8-sig will clear any BOM if present
@@ -221,7 +250,12 @@ def pelican_open(filename, mode="r", strip_crs=(sys.platform == "win32")):
yield content
def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
def slugify(
value: str,
regex_subs: Iterable[Tuple[str, str]] = (),
preserve_case: bool = False,
use_unicode: bool = False,
) -> str:
"""
Normalizes string, converts to lowercase, removes non-alpha characters,
and converts spaces to hyphens.
@@ -233,9 +267,10 @@ def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
"""
import unicodedata
import unidecode
def normalize_unicode(text):
def normalize_unicode(text: str) -> str:
# normalize text by compatibility composition
# see: https://en.wikipedia.org/wiki/Unicode_equivalence
return unicodedata.normalize("NFKC", text)
@@ -262,7 +297,9 @@ def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
return value.strip()
def copy(source, destination, ignores=None):
def copy(
source: str, destination: str, ignores: Optional[Iterable[str]] = None
) -> None:
"""Recursively copy source into destination.
If source is a file, destination has to be a file as well.
@@ -334,7 +371,7 @@ def copy(source, destination, ignores=None):
)
def copy_file(source, destination):
def copy_file(source: str, destination: str) -> None:
"""Copy a file"""
try:
shutil.copyfile(source, destination)
@@ -344,7 +381,7 @@ def copy_file(source, destination):
)
def clean_output_dir(path, retention):
def clean_output_dir(path: str, retention: Iterable[str]) -> None:
"""Remove all files from output directory except those in retention list"""
if not os.path.exists(path):
@@ -381,24 +418,24 @@ def clean_output_dir(path, retention):
logger.error("Unable to delete %s, file type unknown", file)
def get_relative_path(path):
def get_relative_path(path: str) -> str:
"""Return the relative path from the given path to the root path."""
components = split_all(path)
if len(components) <= 1:
if components is None or len(components) <= 1:
return os.curdir
else:
parents = [os.pardir] * (len(components) - 1)
return os.path.join(*parents)
def path_to_url(path):
def path_to_url(path: str) -> str:
"""Return the URL corresponding to a given path."""
if path is not None:
path = posixize_path(path)
return path
def posixize_path(rel_path):
def posixize_path(rel_path: str) -> str:
"""Use '/' as path separator, so that source references,
like '{static}/foo/bar.jpg' or 'extras/favicon.ico',
will work on Windows as well as on Mac and Linux."""
@@ -427,20 +464,20 @@ class _HTMLWordTruncator(HTMLParser):
_singlets = ("br", "col", "link", "base", "img", "param", "area", "hr", "input")
class TruncationCompleted(Exception):
def __init__(self, truncate_at):
def __init__(self, truncate_at: int) -> None:
super().__init__(truncate_at)
self.truncate_at = truncate_at
def __init__(self, max_words):
def __init__(self, max_words: int) -> None:
super().__init__(convert_charrefs=False)
self.max_words = max_words
self.words_found = 0
self.open_tags = []
self.last_word_end = None
self.truncate_at = None
self.truncate_at: Optional[int] = None
def feed(self, *args, **kwargs):
def feed(self, *args, **kwargs) -> None:
try:
super().feed(*args, **kwargs)
except self.TruncationCompleted as exc:
@@ -448,29 +485,29 @@ class _HTMLWordTruncator(HTMLParser):
else:
self.truncate_at = None
def getoffset(self):
def getoffset(self) -> int:
line_start = 0
lineno, line_offset = self.getpos()
for i in range(lineno - 1):
line_start = self.rawdata.index("\n", line_start) + 1
return line_start + line_offset
def add_word(self, word_end):
def add_word(self, word_end: int) -> None:
self.words_found += 1
self.last_word_end = None
if self.words_found == self.max_words:
raise self.TruncationCompleted(word_end)
def add_last_word(self):
def add_last_word(self) -> None:
if self.last_word_end is not None:
self.add_word(self.last_word_end)
def handle_starttag(self, tag, attrs):
def handle_starttag(self, tag: str, attrs: Any) -> None:
self.add_last_word()
if tag not in self._singlets:
self.open_tags.insert(0, tag)
def handle_endtag(self, tag):
def handle_endtag(self, tag: str) -> None:
self.add_last_word()
try:
i = self.open_tags.index(tag)
@@ -481,7 +518,7 @@ class _HTMLWordTruncator(HTMLParser):
# all unclosed intervening start tags with omitted end tags
del self.open_tags[: i + 1]
def handle_data(self, data):
def handle_data(self, data: str) -> None:
word_end = 0
offset = self.getoffset()
@@ -499,7 +536,7 @@ class _HTMLWordTruncator(HTMLParser):
if word_end < len(data):
self.add_last_word()
def _handle_ref(self, name, char):
def _handle_ref(self, name: str, char: str) -> None:
"""
Called by handle_entityref() or handle_charref() when a ref like
`&mdash;`, `&#8212;`, or `&#x2014` is found.
@@ -543,7 +580,7 @@ class _HTMLWordTruncator(HTMLParser):
else:
self.add_last_word()
def handle_entityref(self, name):
def handle_entityref(self, name: str) -> None:
"""
Called when an entity ref like '&mdash;' is found
@@ -556,7 +593,7 @@ class _HTMLWordTruncator(HTMLParser):
char = ""
self._handle_ref(name, char)
def handle_charref(self, name):
def handle_charref(self, name: str) -> None:
"""
Called when a char ref like '&#8212;' or '&#x2014' is found
@@ -574,7 +611,7 @@ class _HTMLWordTruncator(HTMLParser):
self._handle_ref("#" + name, char)
def truncate_html_words(s, num, end_text="…"):
def truncate_html_words(s: str, num: int, end_text: str = "…") -> str:
"""Truncates HTML to a certain number of words.
(not counting tags and comments). Closes opened tags if they were correctly
@@ -600,7 +637,10 @@ def truncate_html_words(s, num, end_text="…"):
return out
def process_translations(content_list, translation_id=None):
def process_translations(
content_list: List[Content],
translation_id: Optional[Union[str, Collection[str]]] = None,
) -> Tuple[List[Content], List[Content]]:
"""Finds translations and returns them.
For each content_list item, populates the 'translations' attribute, and
@@ -658,7 +698,7 @@ def process_translations(content_list, translation_id=None):
return index, translations
def get_original_items(items, with_str):
def get_original_items(items: List[Content], with_str: str) -> List[Content]:
def _warn_source_paths(msg, items, *extra):
args = [len(items)]
args.extend(extra)
@@ -698,7 +738,10 @@ def get_original_items(items, with_str):
return original_items
def order_content(content_list, order_by="slug"):
def order_content(
content_list: List[Content],
order_by: Union[str, Callable[[Content], Any], None] = "slug",
) -> List[Content]:
"""Sorts content.
order_by can be a string of an attribute or sorting function. If order_by
@@ -758,7 +801,11 @@ def order_content(content_list, order_by="slug"):
return content_list
def wait_for_changes(settings_file, reader_class, settings):
def wait_for_changes(
settings_file: str,
reader_class: Type["Readers"],
settings: "Settings",
):
content_path = settings.get("PATH", "")
theme_path = settings.get("THEME", "")
ignore_files = {
@@ -788,13 +835,15 @@ def wait_for_changes(settings_file, reader_class, settings):
return next(
watchfiles.watch(
*watching_paths,
watch_filter=watchfiles.DefaultFilter(ignore_entity_patterns=ignore_files),
watch_filter=watchfiles.DefaultFilter(ignore_entity_patterns=ignore_files), # type: ignore
rust_timeout=0,
)
)
def set_date_tzinfo(d, tz_name=None):
def set_date_tzinfo(
d: datetime.datetime, tz_name: Optional[str] = None
) -> datetime.datetime:
"""Set the timezone for dates that don't have tzinfo"""
if tz_name and not d.tzinfo:
timezone = ZoneInfo(tz_name)
@@ -805,11 +854,11 @@ def set_date_tzinfo(d, tz_name=None):
return d
def mkdir_p(path):
def mkdir_p(path: str) -> None:
os.makedirs(path, exist_ok=True)
def split_all(path):
def split_all(path: Union[str, pathlib.Path, None]) -> Optional[Sequence[str]]:
"""Split a path into a list of components
While os.path.split() splits a single component off the back of
@@ -840,12 +889,12 @@ def split_all(path):
)
def path_to_file_url(path):
def path_to_file_url(path: str) -> str:
"""Convert file-system path to file:// URL"""
return urllib.parse.urljoin("file://", urllib.request.pathname2url(path))
def maybe_pluralize(count, singular, plural):
def maybe_pluralize(count: int, singular: str, plural: str) -> str:
"""
Returns a formatted string containing count and plural if count is not 1
Returns count and singular if count is 1
@@ -862,7 +911,9 @@ def maybe_pluralize(count, singular, plural):
@contextmanager
def temporary_locale(temp_locale=None, lc_category=locale.LC_ALL):
def temporary_locale(
temp_locale: Optional[str] = None, lc_category: int = locale.LC_ALL
) -> Generator[None, None, None]:
"""
Enable code to run in a context with a temporary locale
Resets the locale back when exiting context.
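For completeness, a small usage sketch of temporary_locale (assumed usage, not part of this commit): it is a context manager, as seen in DateFormatter.__call__ above, and restores the previous locale on exit.

import locale

from pelican.utils import temporary_locale

# Format dates under the "C" locale for this block only; the original
# LC_TIME setting is restored when the context exits.
with temporary_locale("C", locale.LC_TIME):
    print(locale.setlocale(locale.LC_TIME))  # reports the temporary locale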