.resolve_db_name() and .execute() work without inspect

Refs #420
This commit is contained in:
Simon Willison 2019-03-31 16:51:52 -07:00
commit 7d0f668556
4 changed files with 68 additions and 16 deletions

View file

@ -116,6 +116,43 @@ async def favicon(request):
return response.text("")
class ConnectedDatabase:
def __init__(self, path=None, is_mutable=False, is_memory=False):
self.path = path
self.is_mutable = is_mutable
self.is_memory = is_memory
self.hash = None
self.size = None
if not self.is_mutable:
p = Path(path)
self.hash = inspect_hash(p)
self.size = p.stat().st_size
@property
def name(self):
if self.is_memory:
return ":memory:"
else:
return Path(self.path).stem
def __repr__(self):
tags = []
if self.is_mutable:
tags.append("mutable")
if self.is_memory:
tags.append("memory")
if self.hash:
tags.append("hash={}".format(self.hash))
if self.size is not None:
tags.append("size={}".format(self.size))
tags_str = ""
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<ConnectedDatabase: {}{}>".format(
self.name, tags_str
)
class Datasette:
def __init__(
@ -141,6 +178,18 @@ class Datasette:
self.files = [MEMORY]
elif memory:
self.files = (MEMORY,) + self.files
self.databases = {}
for file in self.files:
path = file
is_memory = False
if file is MEMORY:
path = None
is_memory = True
db = ConnectedDatabase(path, is_mutable=path not in self.immutables, is_memory=is_memory)
if db.name in self.databases:
raise Exception("Multiple files with same stem: {}".format(db.name))
self.databases[db.name] = db
print(self.databases)
self.cache_headers = cache_headers
self.cors = cors
self._inspect = inspect_data
@ -419,17 +468,17 @@ class Datasette:
def in_thread():
conn = getattr(connections, db_name, None)
if not conn:
info = self.inspect()[db_name]
if info["file"] == ":memory:":
db = self.databases[db_name]
if db.is_memory:
conn = sqlite3.connect(":memory:")
else:
# mode=ro or immutable=1?
if info["file"] in self.immutables:
qs = "immutable=1"
else:
if db.is_mutable:
qs = "mode=ro"
else:
qs = "immutable=1"
conn = sqlite3.connect(
"file:{}?{}".format(info["file"], qs),
"file:{}?{}".format(db.path, qs),
uri=True,
check_same_thread=False,
)

View file

@ -156,14 +156,13 @@ class BaseView(RenderMixin):
return r
async def resolve_db_name(self, request, db_name, **kwargs):
databases = self.ds.inspect()
hash = None
name = None
if "-" in db_name:
# Might be name-and-hash, or might just be
# a name with a hyphen in it
name, hash = db_name.rsplit("-", 1)
if name not in databases:
if name not in self.ds.databases:
# Try the whole name
name = db_name
hash = None
@ -171,11 +170,13 @@ class BaseView(RenderMixin):
name = db_name
# Verify the hash
try:
info = databases[name]
db = self.ds.databases[name]
except KeyError:
raise NotFound("Database not found: {}".format(name))
expected = info["hash"][:HASH_LENGTH]
expected = "000"
if db.hash is not None:
expected = db.hash[:HASH_LENGTH]
correct_hash_provided = (expected == hash)
if not correct_hash_provided: