1
0
Fork 0
forked from github/pelican

Merge pull request #3278 from bjoernricks/contents-types

This commit is contained in:
Justin Mayer 2024-01-24 22:54:01 +01:00 committed by GitHub
commit dbf90a4821
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 150 additions and 88 deletions

View file

@ -6,7 +6,8 @@ import os
import re
from datetime import timezone
from html import unescape
from urllib.parse import unquote, urljoin, urlparse, urlunparse
from typing import Any, Dict, Optional, Set, Tuple
from urllib.parse import ParseResult, unquote, urljoin, urlparse, urlunparse
try:
from zoneinfo import ZoneInfo
@ -15,7 +16,7 @@ except ModuleNotFoundError:
from pelican.plugins import signals
from pelican.settings import DEFAULT_CONFIG
from pelican.settings import DEFAULT_CONFIG, Settings
from pelican.utils import (
deprecated_attribute,
memoized,
@ -44,12 +45,20 @@ class Content:
"""
default_template: Optional[str] = None
mandatory_properties: Tuple[str, ...] = ()
@deprecated_attribute(old="filename", new="source_path", since=(3, 2, 0))
def filename():
return None
def __init__(
self, content, metadata=None, settings=None, source_path=None, context=None
self,
content: str,
metadata: Optional[Dict[str, Any]] = None,
settings: Optional[Settings] = None,
source_path: Optional[str] = None,
context: Optional[Dict[Any, Any]] = None,
):
if metadata is None:
metadata = {}
@ -156,10 +165,10 @@ class Content:
signals.content_object_init.send(self)
def __str__(self):
def __str__(self) -> str:
return self.source_path or repr(self)
def _has_valid_mandatory_properties(self):
def _has_valid_mandatory_properties(self) -> bool:
"""Test mandatory properties are set."""
for prop in self.mandatory_properties:
if not hasattr(self, prop):
@ -169,7 +178,7 @@ class Content:
return False
return True
def _has_valid_save_as(self):
def _has_valid_save_as(self) -> bool:
"""Return true if save_as doesn't write outside output path, false
otherwise."""
try:
@ -190,7 +199,7 @@ class Content:
return True
def _has_valid_status(self):
def _has_valid_status(self) -> bool:
if hasattr(self, "allowed_statuses"):
if self.status not in self.allowed_statuses:
logger.error(
@ -204,7 +213,7 @@ class Content:
# if undefined we allow all
return True
def is_valid(self):
def is_valid(self) -> bool:
"""Validate Content"""
# Use all() to not short circuit and get results of all validations
return all(
@ -216,7 +225,7 @@ class Content:
)
@property
def url_format(self):
def url_format(self) -> Dict[str, Any]:
"""Returns the URL, formatted with the proper values"""
metadata = copy.copy(self.metadata)
path = self.metadata.get("path", self.get_relative_source_path())
@ -232,19 +241,19 @@ class Content:
)
return metadata
def _expand_settings(self, key, klass=None):
def _expand_settings(self, key: str, klass: Optional[str] = None) -> str:
if not klass:
klass = self.__class__.__name__
fq_key = (f"{klass}_{key}").upper()
return str(self.settings[fq_key]).format(**self.url_format)
def get_url_setting(self, key):
def get_url_setting(self, key: str) -> str:
if hasattr(self, "override_" + key):
return getattr(self, "override_" + key)
key = key if self.in_default_lang else "lang_%s" % key
return self._expand_settings(key)
def _link_replacer(self, siteurl, m):
def _link_replacer(self, siteurl: str, m: re.Match) -> str:
what = m.group("what")
value = urlparse(m.group("value"))
path = value.path
@ -272,15 +281,15 @@ class Content:
# XXX Put this in a different location.
if what in {"filename", "static", "attach"}:
def _get_linked_content(key, url):
def _get_linked_content(key: str, url: ParseResult) -> Optional[Content]:
nonlocal value
def _find_path(path):
def _find_path(path: str) -> Optional[Content]:
if path.startswith("/"):
path = path[1:]
else:
# relative to the source path of this content
path = self.get_relative_source_path(
path = self.get_relative_source_path( # type: ignore
os.path.join(self.relative_dir, path)
)
return self._context[key].get(path, None)
@ -324,7 +333,7 @@ class Content:
linked_content = _get_linked_content(key, value)
if linked_content:
if what == "attach":
linked_content.attach_to(self)
linked_content.attach_to(self) # type: ignore
origin = joiner(siteurl, linked_content.url)
origin = origin.replace("\\", "/") # for Windows paths.
else:
@ -359,7 +368,7 @@ class Content:
return "".join((m.group("markup"), m.group("quote"), origin, m.group("quote")))
def _get_intrasite_link_regex(self):
def _get_intrasite_link_regex(self) -> re.Pattern:
intrasite_link_regex = self.settings["INTRASITE_LINK_REGEX"]
regex = r"""
(?P<markup><[^\>]+ # match tag with all url-value attributes
@ -370,7 +379,7 @@ class Content:
(?P=quote)""".format(intrasite_link_regex)
return re.compile(regex, re.X)
def _update_content(self, content, siteurl):
def _update_content(self, content: str, siteurl: str) -> str:
"""Update the content attribute.
Change all the relative paths of the content to relative paths
@ -386,7 +395,7 @@ class Content:
hrefs = self._get_intrasite_link_regex()
return hrefs.sub(lambda m: self._link_replacer(siteurl, m), content)
def get_static_links(self):
def get_static_links(self) -> Set[str]:
static_links = set()
hrefs = self._get_intrasite_link_regex()
for m in hrefs.finditer(self._content):
@ -402,15 +411,15 @@ class Content:
path = self.get_relative_source_path(
os.path.join(self.relative_dir, path)
)
path = path.replace("%20", " ")
path = path.replace("%20", " ") # type: ignore
static_links.add(path)
return static_links
def get_siteurl(self):
def get_siteurl(self) -> str:
return self._context.get("localsiteurl", "")
@memoized
def get_content(self, siteurl):
def get_content(self, siteurl: str) -> str:
if hasattr(self, "_get_content"):
content = self._get_content()
else:
@ -418,11 +427,11 @@ class Content:
return self._update_content(content, siteurl)
@property
def content(self):
def content(self) -> str:
return self.get_content(self.get_siteurl())
@memoized
def get_summary(self, siteurl):
def get_summary(self, siteurl: str) -> str:
"""Returns the summary of an article.
This is based on the summary metadata if set, otherwise truncate the
@ -441,10 +450,10 @@ class Content:
)
@property
def summary(self):
def summary(self) -> str:
return self.get_summary(self.get_siteurl())
def _get_summary(self):
def _get_summary(self) -> str:
"""deprecated function to access summary"""
logger.warning(
@ -454,34 +463,36 @@ class Content:
return self.summary
@summary.setter
def summary(self, value):
def summary(self, value: str):
"""Dummy function"""
pass
@property
def status(self):
def status(self) -> str:
return self._status
@status.setter
def status(self, value):
def status(self, value: str) -> None:
# TODO maybe typecheck
self._status = value.lower()
@property
def url(self):
def url(self) -> str:
return self.get_url_setting("url")
@property
def save_as(self):
def save_as(self) -> str:
return self.get_url_setting("save_as")
def _get_template(self):
def _get_template(self) -> str:
if hasattr(self, "template") and self.template is not None:
return self.template
else:
return self.default_template
def get_relative_source_path(self, source_path=None):
def get_relative_source_path(
self, source_path: Optional[str] = None
) -> Optional[str]:
"""Return the relative path (from the content path) to the given
source_path.
@ -501,7 +512,7 @@ class Content:
)
@property
def relative_dir(self):
def relative_dir(self) -> str:
return posixize_path(
os.path.dirname(
os.path.relpath(
@ -511,7 +522,7 @@ class Content:
)
)
def refresh_metadata_intersite_links(self):
def refresh_metadata_intersite_links(self) -> None:
for key in self.settings["FORMATTED_FIELDS"]:
if key in self.metadata and key != "summary":
value = self._update_content(self.metadata[key], self.get_siteurl())
@ -534,7 +545,7 @@ class Page(Content):
default_status = "published"
default_template = "page"
def _expand_settings(self, key):
def _expand_settings(self, key: str) -> str:
klass = "draft_page" if self.status == "draft" else None
return super()._expand_settings(key, klass)
@ -561,7 +572,7 @@ class Article(Content):
if not hasattr(self, "date") and self.status == "draft":
self.date = datetime.datetime.max.replace(tzinfo=self.timezone)
def _expand_settings(self, key):
def _expand_settings(self, key: str) -> str:
klass = "draft" if self.status == "draft" else "article"
return super()._expand_settings(key, klass)
@ -571,7 +582,7 @@ class Static(Content):
default_status = "published"
default_template = None
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._output_location_referenced = False
@ -588,18 +599,18 @@ class Static(Content):
return None
@property
def url(self):
def url(self) -> str:
# Note when url has been referenced, so we can avoid overriding it.
self._output_location_referenced = True
return super().url
@property
def save_as(self):
def save_as(self) -> str:
# Note when save_as has been referenced, so we can avoid overriding it.
self._output_location_referenced = True
return super().save_as
def attach_to(self, content):
def attach_to(self, content: Content) -> None:
"""Override our output directory with that of the given content object."""
# Determine our file's new output path relative to the linking
@ -624,7 +635,7 @@ class Static(Content):
new_url = path_to_url(new_save_as)
def _log_reason(reason):
def _log_reason(reason: str) -> None:
logger.warning(
"The {attach} link in %s cannot relocate "
"%s because %s. Falling back to "

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime
import fnmatch
import locale
@ -16,6 +18,21 @@ from html import entities
from html.parser import HTMLParser
from itertools import groupby
from operator import attrgetter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import dateutil.parser
@ -27,11 +44,15 @@ from markupsafe import Markup
import watchfiles
if TYPE_CHECKING:
from pelican.contents import Content
from pelican.readers import Readers
from pelican.settings import Settings
logger = logging.getLogger(__name__)
def sanitised_join(base_directory, *parts):
def sanitised_join(base_directory: str, *parts: str) -> str:
joined = posixize_path(os.path.abspath(os.path.join(base_directory, *parts)))
base = posixize_path(os.path.abspath(base_directory))
if not joined.startswith(base):
@ -40,7 +61,7 @@ def sanitised_join(base_directory, *parts):
return joined
def strftime(date, date_format):
def strftime(date: datetime.datetime, date_format: str) -> str:
"""
Enhanced replacement for built-in strftime with zero stripping
@ -109,10 +130,10 @@ class DateFormatter:
defined in LOCALE setting
"""
def __init__(self):
def __init__(self) -> None:
self.locale = locale.setlocale(locale.LC_TIME)
def __call__(self, date, date_format):
def __call__(self, date: datetime.datetime, date_format: str) -> str:
# on OSX, encoding from LC_CTYPE determines the unicode output in PY3
# make sure it's same as LC_TIME
with temporary_locale(self.locale, locale.LC_TIME), temporary_locale(
@ -131,11 +152,11 @@ class memoized:
"""
def __init__(self, func):
def __init__(self, func: Callable) -> None:
self.func = func
self.cache = {}
self.cache: Dict[Any, Any] = {}
def __call__(self, *args):
def __call__(self, *args) -> Any:
if not isinstance(args, Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
@ -147,17 +168,23 @@ class memoized:
self.cache[args] = value
return value
def __repr__(self):
def __repr__(self) -> Optional[str]:
return self.func.__doc__
def __get__(self, obj, objtype):
def __get__(self, obj: Any, objtype):
"""Support instance methods."""
fn = partial(self.__call__, obj)
fn.cache = self.cache
return fn
def deprecated_attribute(old, new, since=None, remove=None, doc=None):
def deprecated_attribute(
old: str,
new: str,
since: Tuple[int, ...],
remove: Optional[Tuple[int, ...]] = None,
doc: Optional[str] = None,
):
"""Attribute deprecation decorator for gentle upgrades
For example:
@ -198,7 +225,7 @@ def deprecated_attribute(old, new, since=None, remove=None, doc=None):
return decorator
def get_date(string):
def get_date(string: str) -> datetime.datetime:
"""Return a datetime object from a string.
If no format matches the given date, raise a ValueError.
@ -212,7 +239,9 @@ def get_date(string):
@contextmanager
def pelican_open(filename, mode="r", strip_crs=(sys.platform == "win32")):
def pelican_open(
filename: str, mode: str = "r", strip_crs: bool = (sys.platform == "win32")
) -> Generator[str, None, None]:
"""Open a file and return its content"""
# utf-8-sig will clear any BOM if present
@ -221,7 +250,12 @@ def pelican_open(filename, mode="r", strip_crs=(sys.platform == "win32")):
yield content
def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
def slugify(
value: str,
regex_subs: Iterable[Tuple[str, str]] = (),
preserve_case: bool = False,
use_unicode: bool = False,
) -> str:
"""
Normalizes string, converts to lowercase, removes non-alpha characters,
and converts spaces to hyphens.
@ -233,9 +267,10 @@ def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
"""
import unicodedata
import unidecode
def normalize_unicode(text):
def normalize_unicode(text: str) -> str:
# normalize text by compatibility composition
# see: https://en.wikipedia.org/wiki/Unicode_equivalence
return unicodedata.normalize("NFKC", text)
@ -262,7 +297,9 @@ def slugify(value, regex_subs=(), preserve_case=False, use_unicode=False):
return value.strip()
def copy(source, destination, ignores=None):
def copy(
source: str, destination: str, ignores: Optional[Iterable[str]] = None
) -> None:
"""Recursively copy source into destination.
If source is a file, destination has to be a file as well.
@ -334,7 +371,7 @@ def copy(source, destination, ignores=None):
)
def copy_file(source, destination):
def copy_file(source: str, destination: str) -> None:
"""Copy a file"""
try:
shutil.copyfile(source, destination)
@ -344,7 +381,7 @@ def copy_file(source, destination):
)
def clean_output_dir(path, retention):
def clean_output_dir(path: str, retention: Iterable[str]) -> None:
"""Remove all files from output directory except those in retention list"""
if not os.path.exists(path):
@ -381,24 +418,24 @@ def clean_output_dir(path, retention):
logger.error("Unable to delete %s, file type unknown", file)
def get_relative_path(path):
def get_relative_path(path: str) -> str:
"""Return the relative path from the given path to the root path."""
components = split_all(path)
if len(components) <= 1:
if components is None or len(components) <= 1:
return os.curdir
else:
parents = [os.pardir] * (len(components) - 1)
return os.path.join(*parents)
def path_to_url(path):
def path_to_url(path: str) -> str:
"""Return the URL corresponding to a given path."""
if path is not None:
path = posixize_path(path)
return path
def posixize_path(rel_path):
def posixize_path(rel_path: str) -> str:
"""Use '/' as path separator, so that source references,
like '{static}/foo/bar.jpg' or 'extras/favicon.ico',
will work on Windows as well as on Mac and Linux."""
@ -427,20 +464,20 @@ class _HTMLWordTruncator(HTMLParser):
_singlets = ("br", "col", "link", "base", "img", "param", "area", "hr", "input")
class TruncationCompleted(Exception):
def __init__(self, truncate_at):
def __init__(self, truncate_at: int) -> None:
super().__init__(truncate_at)
self.truncate_at = truncate_at
def __init__(self, max_words):
def __init__(self, max_words: int) -> None:
super().__init__(convert_charrefs=False)
self.max_words = max_words
self.words_found = 0
self.open_tags = []
self.last_word_end = None
self.truncate_at = None
self.truncate_at: Optional[int] = None
def feed(self, *args, **kwargs):
def feed(self, *args, **kwargs) -> None:
try:
super().feed(*args, **kwargs)
except self.TruncationCompleted as exc:
@ -448,29 +485,29 @@ class _HTMLWordTruncator(HTMLParser):
else:
self.truncate_at = None
def getoffset(self):
def getoffset(self) -> int:
line_start = 0
lineno, line_offset = self.getpos()
for i in range(lineno - 1):
line_start = self.rawdata.index("\n", line_start) + 1
return line_start + line_offset
def add_word(self, word_end):
def add_word(self, word_end: int) -> None:
self.words_found += 1
self.last_word_end = None
if self.words_found == self.max_words:
raise self.TruncationCompleted(word_end)
def add_last_word(self):
def add_last_word(self) -> None:
if self.last_word_end is not None:
self.add_word(self.last_word_end)
def handle_starttag(self, tag, attrs):
def handle_starttag(self, tag: str, attrs: Any) -> None:
self.add_last_word()
if tag not in self._singlets:
self.open_tags.insert(0, tag)
def handle_endtag(self, tag):
def handle_endtag(self, tag: str) -> None:
self.add_last_word()
try:
i = self.open_tags.index(tag)
@ -481,7 +518,7 @@ class _HTMLWordTruncator(HTMLParser):
# all unclosed intervening start tags with omitted end tags
del self.open_tags[: i + 1]
def handle_data(self, data):
def handle_data(self, data: str) -> None:
word_end = 0
offset = self.getoffset()
@ -499,7 +536,7 @@ class _HTMLWordTruncator(HTMLParser):
if word_end < len(data):
self.add_last_word()
def _handle_ref(self, name, char):
def _handle_ref(self, name: str, char: str) -> None:
"""
Called by handle_entityref() or handle_charref() when a ref like
`&mdash;`, `&#8212;`, or `&#x2014` is found.
@ -543,7 +580,7 @@ class _HTMLWordTruncator(HTMLParser):
else:
self.add_last_word()
def handle_entityref(self, name):
def handle_entityref(self, name: str) -> None:
"""
Called when an entity ref like '&mdash;' is found
@ -556,7 +593,7 @@ class _HTMLWordTruncator(HTMLParser):
char = ""
self._handle_ref(name, char)
def handle_charref(self, name):
def handle_charref(self, name: str) -> None:
"""
Called when a char ref like '&#8212;' or '&#x2014' is found
@ -574,7 +611,7 @@ class _HTMLWordTruncator(HTMLParser):
self._handle_ref("#" + name, char)
def truncate_html_words(s, num, end_text=""):
def truncate_html_words(s: str, num: int, end_text: str = "") -> str:
"""Truncates HTML to a certain number of words.
(not counting tags and comments). Closes opened tags if they were correctly
@ -600,7 +637,10 @@ def truncate_html_words(s, num, end_text="…"):
return out
def process_translations(content_list, translation_id=None):
def process_translations(
content_list: List[Content],
translation_id: Optional[Union[str, Collection[str]]] = None,
) -> Tuple[List[Content], List[Content]]:
"""Finds translations and returns them.
For each content_list item, populates the 'translations' attribute, and
@ -658,7 +698,7 @@ def process_translations(content_list, translation_id=None):
return index, translations
def get_original_items(items, with_str):
def get_original_items(items: List[Content], with_str: str) -> List[Content]:
def _warn_source_paths(msg, items, *extra):
args = [len(items)]
args.extend(extra)
@ -698,7 +738,10 @@ def get_original_items(items, with_str):
return original_items
def order_content(content_list, order_by="slug"):
def order_content(
content_list: List[Content],
order_by: Union[str, Callable[[Content], Any], None] = "slug",
) -> List[Content]:
"""Sorts content.
order_by can be a string of an attribute or sorting function. If order_by
@ -758,7 +801,11 @@ def order_content(content_list, order_by="slug"):
return content_list
def wait_for_changes(settings_file, reader_class, settings):
def wait_for_changes(
settings_file: str,
reader_class: Type["Readers"],
settings: "Settings",
):
content_path = settings.get("PATH", "")
theme_path = settings.get("THEME", "")
ignore_files = {
@ -788,13 +835,15 @@ def wait_for_changes(settings_file, reader_class, settings):
return next(
watchfiles.watch(
*watching_paths,
watch_filter=watchfiles.DefaultFilter(ignore_entity_patterns=ignore_files),
watch_filter=watchfiles.DefaultFilter(ignore_entity_patterns=ignore_files), # type: ignore
rust_timeout=0,
)
)
def set_date_tzinfo(d, tz_name=None):
def set_date_tzinfo(
d: datetime.datetime, tz_name: Optional[str] = None
) -> datetime.datetime:
"""Set the timezone for dates that don't have tzinfo"""
if tz_name and not d.tzinfo:
timezone = ZoneInfo(tz_name)
@ -805,11 +854,11 @@ def set_date_tzinfo(d, tz_name=None):
return d
def mkdir_p(path):
def mkdir_p(path: str) -> None:
os.makedirs(path, exist_ok=True)
def split_all(path):
def split_all(path: Union[str, pathlib.Path, None]) -> Optional[Sequence[str]]:
"""Split a path into a list of components
While os.path.split() splits a single component off the back of
@ -840,12 +889,12 @@ def split_all(path):
)
def path_to_file_url(path):
def path_to_file_url(path: str) -> str:
"""Convert file-system path to file:// URL"""
return urllib.parse.urljoin("file://", urllib.request.pathname2url(path))
def maybe_pluralize(count, singular, plural):
def maybe_pluralize(count: int, singular: str, plural: str) -> str:
"""
Returns a formatted string containing count and plural if count is not 1
Returns count and singular if count is 1
@ -862,7 +911,9 @@ def maybe_pluralize(count, singular, plural):
@contextmanager
def temporary_locale(temp_locale=None, lc_category=locale.LC_ALL):
def temporary_locale(
temp_locale: Optional[str] = None, lc_category: int = locale.LC_ALL
) -> Generator[None, None, None]:
"""
Enable code to run in a context with a temporary locale
Resets the locale back when exiting context.