Add support for compound file extensions.

Created a new class, ReaderTree, backed by an arbitrarily deeply
nested defaultdict keyed by the dot-separated components of the extension.
See comments on PR #2816.
This commit is contained in:
Holden Nelson 2021-12-18 23:33:52 -08:00
commit 208780e477
4 changed files with 261 additions and 8 deletions

View file

@ -123,10 +123,13 @@ class Generator:
if any(fnmatch.fnmatch(basename, ignore) for ignore in ignores):
return False
ext = os.path.splitext(basename)[1][1:]
if extensions is False or ext in extensions:
if extensions is False:
return True
for ext in extensions:
if basename.endswith(f'.{ext}'):
return True
return False
def get_files(self, paths, exclude=[], extensions=None):

View file

@ -1,8 +1,10 @@
import datetime
import logging
import operator
import os
import re
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from functools import reduce
from html import escape
from html.parser import HTMLParser
from io import StringIO
@ -496,8 +498,8 @@ class Readers(FileStampDataCacher):
def __init__(self, settings=None, cache_name=''):
self.settings = settings or {}
self.readers = {}
self.reader_classes = {}
self.readers = ReaderTree()
self.reader_classes = ReaderTree()
for cls in [BaseReader] + BaseReader.__subclasses__():
if not cls.enabled:
@ -542,8 +544,7 @@ class Readers(FileStampDataCacher):
source_path, content_class.__name__)
if not fmt:
_, ext = os.path.splitext(os.path.basename(path))
fmt = ext[1:]
fmt = self.readers.get_format(path)
if fmt not in self.readers:
raise TypeError(
@ -746,3 +747,159 @@ def parse_path_metadata(source_path, settings=None, process=None):
v = process(k, v)
metadata[k] = v
return metadata
class ReaderTree():
    """Dict-like mapping from (possibly compound) file extensions to readers.

    Extensions are stored in a tree of nested ``defaultdict`` nodes. A key
    such as ``'md.html'`` is split on ``'.'`` and its components are walked
    from right to left (``'html'`` then ``'md'``); the value is stored in the
    final node under the empty-string key ``''``. Registering ``'html'`` and
    ``'md.html'`` therefore yields::

        {'html': {'': <html reader>, 'md': {'': <md.html reader>}}}

    Lookups never create nodes, so testing for a missing key leaves the tree
    unchanged.
    """

    # Sentinel distinguishing "no default supplied" from an explicit ``None``.
    _MISSING = object()

    def __init__(self):
        self.tree_dd = ReaderTree._rec_dd()

    def __str__(self):
        return str(self.as_dict())

    def __iter__(self):
        """Iterate over every registered extension (e.g. ``'md.html'``)."""
        for key, _ in self.items():
            yield key

    def __setitem__(self, key, value):
        """Register *value* for the extension *key*, creating nodes as needed."""
        node = self.tree_dd
        for component in reversed(key.split('.')):
            node = node[component]
        node[''] = value

    def __getitem__(self, key):
        """Return the value registered for exactly *key*.

        Raises:
            KeyError: if nothing is registered for *key*, including when
                *key* is only an intermediate component of a longer key.
        """
        node = self._find_node(key)
        if node is None or '' not in node:
            raise KeyError(key)
        return node['']

    def __delitem__(self, key):
        """Remove the value registered for exactly *key*.

        Nodes left empty by the removal are pruned so that ``as_dict()``
        does not accumulate dead branches.

        Raises:
            KeyError: if nothing is registered for *key*.
        """
        path = []
        node = self.tree_dd
        for component in reversed(key.split('.')):
            child = node.get(component)
            if not isinstance(child, defaultdict):
                raise KeyError(key)
            path.append((node, component))
            node = child
        if '' not in node:
            raise KeyError(key)
        del node['']
        # Walk back towards the root, dropping nodes that became empty.
        for parent, component in reversed(path):
            if parent[component]:
                break
            del parent[component]

    def __contains__(self, item):
        if not isinstance(item, str):
            return False
        try:
            self[item]
        except KeyError:
            return False
        return True

    def __len__(self):
        return sum(1 for _ in self.items())

    def keys(self):
        """Return an iterator over the registered extensions."""
        return iter(self)

    def values(self):
        """Yield the registered values, in the same order as ``keys()``."""
        for _, value in self.items():
            yield value

    def items(self):
        """Return an iterator of ``(extension, value)`` pairs.

        Keys and values come from a single traversal, so each key is always
        paired with its own value regardless of registration order.
        """
        return ReaderTree._walk(self.tree_dd, ())

    def get(self, key, default=None):
        """Return the value for *key*, or *default* if it is not registered."""
        try:
            return self[key]
        except KeyError:
            return default

    def setdefault(self, key, value=None):
        """Return the value for *key*, registering *value* first if absent."""
        try:
            return self[key]
        except KeyError:
            self[key] = value
            return value

    def clear(self):
        self.tree_dd.clear()

    def pop(self, key, default=_MISSING):
        """Remove *key* and return its value.

        If *key* is not registered, return *default* when one was supplied
        (any value, including falsy ones); otherwise raise ``KeyError``.
        """
        try:
            value = self[key]
        except KeyError:
            if default is ReaderTree._MISSING:
                raise
            return default
        del self[key]
        return value

    def copy(self):
        """Return a new, independent ``ReaderTree`` with the same entries."""
        duplicate = ReaderTree()
        duplicate.update(self)
        return duplicate

    def update(self, d):
        """Register every ``key: value`` pair from *d*.

        *d* may be a mapping (anything with ``items()``) or an iterable of
        ``(key, value)`` pairs.
        """
        pairs = d.items() if hasattr(d, 'items') else d
        for key, value in pairs:
            self[key] = value

    def get_format(self, filename):
        """Return the longest registered extension that *filename* ends with.

        Only the base name is examined, so dots in directory names are
        ignored. The part of the name before its first dot (the stem) is
        never treated as an extension, and leading dots of hidden files are
        skipped. Returns ``''`` when no registered extension matches.
        """
        name = os.path.basename(filename).lstrip('.')
        components = name.split('.')[1:]  # drop the stem
        node = self.tree_dd
        matched = []
        best = ''
        for component in reversed(components):
            if not component:
                # An empty component ('file..html', 'file.') cannot be part
                # of an extension; '' is reserved for stored values.
                break
            child = node.get(component)
            if not isinstance(child, defaultdict):
                break
            matched.append(component)
            node = child
            if '' in node:
                # Remember the deepest node that actually holds a value.
                best = '.'.join(reversed(matched))
        return best

    def has_reader(self, filename):
        """Return True if a value is registered for *filename*'s format."""
        fmt = self.get_format(filename)
        return fmt in self

    def as_dict(self):
        """Return the tree as plain nested ``dict`` objects."""
        return ReaderTree._rec_dd_to_dict(self.tree_dd)

    def _find_node(self, key):
        """Return the node for *key* without creating it, or None."""
        node = self.tree_dd
        for component in reversed(key.split('.')):
            node = node.get(component)
            if not isinstance(node, defaultdict):
                return None
        return node

    @staticmethod
    def _rec_dd():
        """Create a node: a defaultdict whose missing keys are new nodes."""
        return defaultdict(ReaderTree._rec_dd)

    @staticmethod
    def _rec_dd_to_dict(dd):
        """Recursively convert nested defaultdict nodes to plain dicts."""
        return {
            key: (ReaderTree._rec_dd_to_dict(value)
                  if isinstance(value, defaultdict) else value)
            for key, value in dd.items()
        }

    @staticmethod
    def _walk(node, suffix):
        """Yield ``(extension, value)`` for *node* and all its descendants.

        *suffix* is the tuple of components leading to *node*, already in
        left-to-right order. A node's own value is yielded before those of
        its children.
        """
        if suffix and '' in node:
            yield '.'.join(suffix), node['']
        for component, child in node.items():
            if component != '' and isinstance(child, defaultdict):
                yield from ReaderTree._walk(child, (component,) + suffix)

View file

@ -41,6 +41,9 @@ class TestGenerator(unittest.TestCase):
ignored_file = os.path.join(CUR_DIR, 'content', 'ignored1.rst')
self.assertFalse(include_path(ignored_file))
compound_file = os.path.join(CUR_DIR, 'content', 'compound.md.html')
self.assertTrue(include_path(compound_file, extensions=('md.html',)))
def test_get_files_exclude(self):
"""Test that Generator.get_files() properly excludes directories.
"""

View file

@ -1,5 +1,5 @@
import os
from unittest.mock import patch
from unittest.mock import Mock, patch
from pelican import readers
from pelican.tests.support import get_settings, unittest
@ -76,6 +76,18 @@ class DefaultReaderTest(ReaderTest):
with self.assertRaises(TypeError):
self.read_file(path='article_with_metadata.unknownextension')
with self.assertRaises(TypeError):
self.read_file(path='article_with.compound.extension')
# Reads a file whose name ends in '.compound.extension' while a Mock is
# registered as the reader for the compound extension 'compound.extension'.
# NOTE(review): indentation is lost in this view; confirm whether the final
# assert_called_with line sits inside or outside the `with` block. If it is
# inside, it is never reached once read_file raises.
def test_readfile_compound_extension(self):
# A bare Mock stands in for the reader class registered via READERS.
CompoundReader = Mock()
# read_file is expected to raise TypeError; per the original author this
# is a side effect of using a Mock as the reader.
with self.assertRaises(TypeError):
self.read_file(path='article_with.compound.extension',
READERS={'compound.extension': CompoundReader})
# NOTE(review): this checks `read` on the Mock itself, not on the object the
# Mock returns when called - confirm which one read_file actually invokes.
CompoundReader.read.assert_called_with('article_with.compound.extension')
def test_readfile_path_metadata_implicit_dates(self):
test_file = 'article_with_metadata_implicit_dates.html'
page = self.read_file(path=test_file, DEFAULT_DATE='fs')
@ -918,3 +930,81 @@ class HTMLReaderTest(ReaderTest):
'title': 'Article with an inline SVG',
}
self.assertDictHasSubset(page.metadata, expected)
class ReaderTreeTest(unittest.TestCase):
    """Tests for the extension-to-reader mapping ``readers.ReaderTree``."""

    def setUp(self):
        # Reader names are plain strings here; the tree stores any value.
        readers_and_exts = {
            'BaseReader': ['static'],
            'RstReader': ['rst'],
            'HtmlReader': ['htm', 'html'],
            'MDReader': ['md', 'mk', 'mkdown', 'mkd'],
            'MDeepReader': ['md.html'],
            'FooReader': ['foo.bar.baz.yaz'],
        }

        self.reader_classes = readers.ReaderTree()
        for reader, exts in readers_and_exts.items():
            for ext in exts:
                self.reader_classes[ext] = reader

    def test_correct_mapping_generated(self):
        # Components are nested right-to-left; '' holds the reader itself.
        expected_mapping = {
            'static': {'': 'BaseReader'},
            'rst': {'': 'RstReader'},
            'htm': {'': 'HtmlReader'},
            'html': {
                '': 'HtmlReader',
                'md': {'': 'MDeepReader'},
            },
            'md': {'': 'MDReader'},
            'mk': {'': 'MDReader'},
            'mkdown': {'': 'MDReader'},
            'mkd': {'': 'MDReader'},
            'yaz': {
                'baz': {
                    'bar': {
                        'foo': {'': 'FooReader'}}}}}
        self.assertEqual(expected_mapping, self.reader_classes.as_dict())

    def test_containment(self):
        self.assertIn('md.html', self.reader_classes)
        self.assertIn('html', self.reader_classes)
        self.assertNotIn('txt', self.reader_classes)

    def test_deletion(self):
        self.assertIn('rst', self.reader_classes)
        del self.reader_classes['rst']
        self.assertNotIn('rst', self.reader_classes)

    def test_update(self):
        self.reader_classes.update({
            'new.ext': 'NewExtReader',
            'txt': 'TxtReader',
        })
        self.assertEqual(self.reader_classes['new.ext'], 'NewExtReader')
        self.assertEqual(self.reader_classes['txt'], 'TxtReader')

    def test_get_format(self):
        get_format = self.reader_classes.get_format
        self.assertEqual(get_format('text.html'), 'html')
        self.assertEqual(get_format('another.md'), 'md')
        # The longest registered suffix wins over the plain 'html' reader.
        self.assertEqual(get_format('dots.compound.md.html'), 'md.html')
        self.assertEqual(get_format('no_extension'), '')
        # 'bar' is only an inner component of 'foo.bar.baz.yaz'.
        self.assertEqual(get_format('file.bar'), '')

    def test_has_reader(self):
        has_reader = self.reader_classes.has_reader
        self.assertTrue(has_reader('text.html'))
        self.assertFalse(has_reader('no_ext'))
        self.assertFalse(has_reader('bad_ext.bar'))