diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 6ca0fac8..00000000 --- a/.coveragerc +++ /dev/null @@ -1,2 +0,0 @@ -[run] -omit = datasette/_version.py, datasette/utils/shutil_backport.py diff --git a/.dockerignore b/.dockerignore index 5078bf47..2c34db66 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,13 +1,13 @@ .DS_Store .cache .eggs +.git .gitignore .ipynb_checkpoints +.travis.yml build *.spec *.egg-info dist scratchpad venv -*.db -*.sqlite diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs deleted file mode 100644 index 84e574fd..00000000 --- a/.git-blame-ignore-revs +++ /dev/null @@ -1,4 +0,0 @@ -# Applying Black -35d6ee2790e41e96f243c1ff58be0c9c0519a8ce -368638555160fb9ac78f462d0f79b1394163fa30 -2b344f6a34d2adaa305996a1a580ece06397f6e4 diff --git a/.gitattributes b/.gitattributes index 744258eb..fb82f167 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1 @@ -datasette/static/codemirror-* linguist-vendored +datasette/_version.py export-subst diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index f0bcdbe0..00000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1 +0,0 @@ -github: [simonw] diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index 88bb03b1..00000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,11 +0,0 @@ -version: 2 -updates: -- package-ecosystem: pip - directory: "/" - schedule: - interval: daily - time: "13:00" - groups: - python-packages: - patterns: - - "*" diff --git a/.github/workflows/deploy-latest.yml b/.github/workflows/deploy-latest.yml deleted file mode 100644 index b0640ae8..00000000 --- a/.github/workflows/deploy-latest.yml +++ /dev/null @@ -1,132 +0,0 @@ -name: Deploy latest.datasette.io - -on: - workflow_dispatch: - push: - branches: - - main - # - 1.0-dev - -permissions: - contents: read - -jobs: - deploy: - runs-on: ubuntu-latest - steps: - - name: Check out datasette - uses: actions/checkout@v6 - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: "3.13" - cache: pip - - name: Install Python dependencies - run: | - python -m pip install --upgrade pip - python -m pip install . --group dev - python -m pip install sphinx-to-sqlite==0.1a1 - - name: Run tests - if: ${{ github.ref == 'refs/heads/main' }} - run: | - pytest -n auto -m "not serial" - pytest -m "serial" - - name: Build fixtures.db and other files needed to deploy the demo - run: |- - python tests/fixtures.py \ - fixtures.db \ - fixtures-config.json \ - fixtures-metadata.json \ - plugins \ - --extra-db-filename extra_database.db - - name: Build docs.db - if: ${{ github.ref == 'refs/heads/main' }} - run: |- - cd docs - DISABLE_SPHINX_INLINE_TABS=1 sphinx-build -b xml . _build - sphinx-to-sqlite ../docs.db _build - cd .. - - name: Set up the alternate-route demo - run: | - echo ' - from datasette import hookimpl - - @hookimpl - def startup(datasette): - db = datasette.get_database("fixtures2") - db.route = "alternative-route" - ' > plugins/alternative_route.py - cp fixtures.db fixtures2.db - - name: And the counters writable stored query demo - run: | - cat > plugins/counters.py < metadata.json - # cat metadata.json - - id: auth - name: Authenticate to Google Cloud - uses: google-github-actions/auth@v3 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - - name: Set up Cloud SDK - uses: google-github-actions/setup-gcloud@v3 - - name: Deploy to Cloud Run - env: - LATEST_DATASETTE_SECRET: ${{ secrets.LATEST_DATASETTE_SECRET }} - run: |- - gcloud config set run/region us-central1 - gcloud config set project datasette-222320 - export SUFFIX="-${GITHUB_REF#refs/heads/}" - export SUFFIX=${SUFFIX#-main} - # Replace 1.0 with one-dot-zero in SUFFIX - export SUFFIX=${SUFFIX//1.0/one-dot-zero} - datasette publish cloudrun fixtures.db fixtures2.db extra_database.db \ - -m fixtures-metadata.json \ - --plugins-dir=plugins \ - --branch=$GITHUB_SHA \ - --version-note=$GITHUB_SHA \ - --extra-options="--setting template_debug 1 --setting trace_debug 1 --crossdb --root" \ - --install 'datasette-ephemeral-tables>=0.2.2' \ - --service "datasette-latest$SUFFIX" \ - --secret $LATEST_DATASETTE_SECRET - - name: Deploy to docs as well (only for main) - if: ${{ github.ref == 'refs/heads/main' }} - run: |- - # Deploy docs.db to a different service - datasette publish cloudrun docs.db \ - --branch=$GITHUB_SHA \ - --version-note=$GITHUB_SHA \ - --extra-options="--setting template_debug 1" \ - --service=datasette-docs-latest diff --git a/.github/workflows/documentation-links.yml b/.github/workflows/documentation-links.yml deleted file mode 100644 index b8fb8aaa..00000000 --- a/.github/workflows/documentation-links.yml +++ /dev/null @@ -1,16 +0,0 @@ -name: Read the Docs Pull Request Preview -on: - pull_request: - types: - - opened - -permissions: - pull-requests: write - -jobs: - documentation-links: - runs-on: ubuntu-latest - steps: - - uses: readthedocs/actions/preview@v1 - with: - project-slug: "datasette" diff --git a/.github/workflows/prettier.yml b/.github/workflows/prettier.yml deleted file mode 100644 index 735e14e9..00000000 --- a/.github/workflows/prettier.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Check JavaScript for conformance with Prettier - -on: [push] - -permissions: - contents: read - -jobs: - prettier: - runs-on: ubuntu-latest - steps: - - name: Check out repo - uses: actions/checkout@v6 - - uses: actions/cache@v5 - name: Configure npm caching - with: - path: ~/.npm - key: ${{ runner.OS }}-npm-${{ hashFiles('**/package-lock.json') }} - restore-keys: | - ${{ runner.OS }}-npm- - - name: Install dependencies - run: npm ci - - name: Run prettier - run: |- - npm run prettier -- --check diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index 87300593..00000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,109 +0,0 @@ -name: Publish Python Package - -on: - release: - types: [created] - -permissions: - contents: read - -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - steps: - - uses: actions/checkout@v6 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 - with: - python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: pyproject.toml - - name: Install dependencies - run: | - pip install . --group dev - - name: Run tests - run: | - pytest - - deploy: - runs-on: ubuntu-latest - needs: [test] - environment: release - permissions: - id-token: write - steps: - - uses: actions/checkout@v6 - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: '3.13' - cache: pip - cache-dependency-path: pyproject.toml - - name: Install dependencies - run: | - pip install setuptools wheel build - - name: Build - run: | - python -m build - - name: Publish - uses: pypa/gh-action-pypi-publish@release/v1 - - deploy_static_docs: - runs-on: ubuntu-latest - needs: [deploy] - if: "!github.event.release.prerelease" - steps: - - uses: actions/checkout@v6 - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: '3.10' - cache: pip - cache-dependency-path: pyproject.toml - - name: Install dependencies - run: | - python -m pip install . --group dev - python -m pip install sphinx-to-sqlite==0.1a1 - - name: Build docs.db - run: |- - cd docs - DISABLE_SPHINX_INLINE_TABS=1 sphinx-build -b xml . _build - sphinx-to-sqlite ../docs.db _build - cd .. - - id: auth - name: Authenticate to Google Cloud - uses: google-github-actions/auth@v2 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - - name: Set up Cloud SDK - uses: google-github-actions/setup-gcloud@v3 - - name: Deploy stable-docs.datasette.io to Cloud Run - run: |- - gcloud config set run/region us-central1 - gcloud config set project datasette-222320 - datasette publish cloudrun docs.db \ - --service=datasette-docs-stable - - deploy_docker: - runs-on: ubuntu-latest - needs: [deploy] - if: "!github.event.release.prerelease" - steps: - - uses: actions/checkout@v6 - - name: Build and push to Docker Hub - env: - DOCKER_USER: ${{ secrets.DOCKER_USER }} - DOCKER_PASS: ${{ secrets.DOCKER_PASS }} - run: |- - sleep 60 # Give PyPI time to make the new release available - docker login -u $DOCKER_USER -p $DOCKER_PASS - export REPO=datasetteproject/datasette - docker build -f Dockerfile \ - -t $REPO:${GITHUB_REF#refs/tags/} \ - --build-arg VERSION=${GITHUB_REF#refs/tags/} . - docker tag $REPO:${GITHUB_REF#refs/tags/} $REPO:latest - docker push $REPO:${GITHUB_REF#refs/tags/} - docker push $REPO:latest diff --git a/.github/workflows/push_docker_tag.yml b/.github/workflows/push_docker_tag.yml deleted file mode 100644 index e622ef4c..00000000 --- a/.github/workflows/push_docker_tag.yml +++ /dev/null @@ -1,28 +0,0 @@ -name: Push specific Docker tag - -on: - workflow_dispatch: - inputs: - version_tag: - description: Tag to build and push - -permissions: - contents: read - -jobs: - deploy_docker: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - name: Build and push to Docker Hub - env: - DOCKER_USER: ${{ secrets.DOCKER_USER }} - DOCKER_PASS: ${{ secrets.DOCKER_PASS }} - VERSION_TAG: ${{ github.event.inputs.version_tag }} - run: |- - docker login -u $DOCKER_USER -p $DOCKER_PASS - export REPO=datasetteproject/datasette - docker build -f Dockerfile \ - -t $REPO:${VERSION_TAG} \ - --build-arg VERSION=${VERSION_TAG} . - docker push $REPO:${VERSION_TAG} diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml deleted file mode 100644 index 9a808194..00000000 --- a/.github/workflows/spellcheck.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: Check spelling in documentation - -on: [push, pull_request] - -permissions: - contents: read - -jobs: - spellcheck: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: '3.11' - cache: 'pip' - cache-dependency-path: '**/pyproject.toml' - - name: Install dependencies - run: | - pip install . --group dev - - name: Check spelling - run: | - codespell README.md --ignore-words docs/codespell-ignore-words.txt - codespell docs/*.rst --ignore-words docs/codespell-ignore-words.txt - codespell datasette -S datasette/static --ignore-words docs/codespell-ignore-words.txt - codespell tests --ignore-words docs/codespell-ignore-words.txt diff --git a/.github/workflows/stable-docs.yml b/.github/workflows/stable-docs.yml deleted file mode 100644 index 59b5fbc0..00000000 --- a/.github/workflows/stable-docs.yml +++ /dev/null @@ -1,76 +0,0 @@ -name: Update Stable Docs - -on: - release: - types: [published] - push: - branches: - - main - -permissions: - contents: write - -jobs: - update_stable_docs: - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - fetch-depth: 0 # We need all commits to find docs/ changes - - name: Set up Git user - run: | - git config user.name "Automated" - git config user.email "actions@users.noreply.github.com" - - name: Create stable branch if it does not yet exist - run: | - if ! git ls-remote --heads origin stable | grep -qE '\bstable\b'; then - # Make sure we have all tags locally - git fetch --tags --quiet - - # Latest tag that is just numbers and dots (optionally prefixed with 'v') - # e.g., 0.65.2 or v0.65.2 — excludes 1.0a20, 1.0-rc1, etc. - LATEST_RELEASE=$( - git tag -l --sort=-v:refname \ - | grep -E '^v?[0-9]+(\.[0-9]+){1,3}$' \ - | head -n1 - ) - - git checkout -b stable - - # If there are any stable releases, copy docs/ from the most recent - if [ -n "$LATEST_RELEASE" ]; then - rm -rf docs/ - git checkout "$LATEST_RELEASE" -- docs/ || true - fi - - git commit -m "Populate docs/ from $LATEST_RELEASE" || echo "No changes" - git push -u origin stable - fi - - name: Handle Release - if: github.event_name == 'release' && !github.event.release.prerelease - run: | - git fetch --all - git checkout stable - git reset --hard ${GITHUB_REF#refs/tags/} - git push origin stable --force - - name: Handle Commit to Main - if: contains(github.event.head_commit.message, '!stable-docs') - run: | - git fetch origin - git checkout -b stable origin/stable - # Get the list of modified files in docs/ from the current commit - FILES=$(git diff-tree --no-commit-id --name-only -r ${{ github.sha }} -- docs/) - # Check if the list of files is non-empty - if [[ -n "$FILES" ]]; then - # Checkout those files to the stable branch to over-write with their contents - for FILE in $FILES; do - git checkout ${{ github.sha }} -- $FILE - done - git add docs/ - git commit -m "Doc changes from ${{ github.sha }}" - git push origin stable - else - echo "No changes to docs/ in this commit." - exit 0 - fi diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml deleted file mode 100644 index c514048e..00000000 --- a/.github/workflows/test-coverage.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Calculate test coverage - -on: - push: - branches: - - main - pull_request: - branches: - - main -permissions: - contents: read - -jobs: - test: - runs-on: ubuntu-latest - steps: - - name: Check out datasette - uses: actions/checkout@v6 - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: '3.12' - cache: 'pip' - cache-dependency-path: '**/pyproject.toml' - - name: Install Python dependencies - run: | - python -m pip install --upgrade pip - python -m pip install . --group dev - python -m pip install pytest-cov - - name: Run tests - run: |- - ls -lah - cat .coveragerc - pytest -m "not serial" --cov=datasette --cov-config=.coveragerc --cov-report xml:coverage.xml --cov-report term -x - ls -lah - - name: Upload coverage report - uses: codecov/codecov-action@v1 - with: - token: ${{ secrets.CODECOV_TOKEN }} - file: coverage.xml diff --git a/.github/workflows/test-pyodide.yml b/.github/workflows/test-pyodide.yml deleted file mode 100644 index 5162c47a..00000000 --- a/.github/workflows/test-pyodide.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Test in Pyodide with shot-scraper - -on: - push: - pull_request: - workflow_dispatch: - -permissions: - contents: read - -jobs: - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - name: Set up Python 3.10 - uses: actions/setup-python@v6 - with: - python-version: "3.10" - cache: 'pip' - cache-dependency-path: '**/pyproject.toml' - - name: Cache Playwright browsers - uses: actions/cache@v5 - with: - path: ~/.cache/ms-playwright/ - key: ${{ runner.os }}-browsers - - name: Install Playwright dependencies - run: | - pip install shot-scraper build - shot-scraper install - - name: Run test - run: | - ./test-in-pyodide-with-shot-scraper.sh diff --git a/.github/workflows/test-sqlite-support.yml b/.github/workflows/test-sqlite-support.yml deleted file mode 100644 index 23fce459..00000000 --- a/.github/workflows/test-sqlite-support.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: Test SQLite versions - -on: [push, pull_request] - -permissions: - contents: read - -jobs: - test: - runs-on: ${{ matrix.platform }} - continue-on-error: true - strategy: - matrix: - platform: [ubuntu-latest] - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - sqlite-version: [ - #"3", # latest version - "3.46", - #"3.45", - #"3.27", - #"3.26", - "3.25", - #"3.25.3", # 2018-09-25, window functions breaks test_upsert for some reason on 3.10, skip for now - #"3.24", # 2018-06-04, added UPSERT support - #"3.23.1" # 2018-04-10, before UPSERT - ] - steps: - - uses: actions/checkout@v6 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 - with: - python-version: ${{ matrix.python-version }} - allow-prereleases: true - cache: pip - cache-dependency-path: pyproject.toml - - name: Set up SQLite ${{ matrix.sqlite-version }} - uses: asg017/sqlite-versions@71ea0de37ae739c33e447af91ba71dda8fcf22e6 - with: - version: ${{ matrix.sqlite-version }} - cflags: "-DSQLITE_ENABLE_DESERIALIZE -DSQLITE_ENABLE_FTS5 -DSQLITE_ENABLE_FTS4 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_RTREE -DSQLITE_ENABLE_JSON1" - - run: python3 -c "import sqlite3; print(sqlite3.sqlite_version)" - - run: echo $LD_LIBRARY_PATH - - name: Build extension for --load-extension test - run: |- - (cd tests && gcc ext.c -fPIC -shared -o ext.so) - - name: Install dependencies - run: | - pip install . --group dev - pip freeze - - name: Run tests - run: | - pytest -n auto -m "not serial" - pytest -m "serial" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index a1b2e9d2..00000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: Test - -on: [push, pull_request] - -permissions: - contents: read - -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - steps: - - uses: actions/checkout@v6 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 - with: - python-version: ${{ matrix.python-version }} - allow-prereleases: true - cache: pip - cache-dependency-path: pyproject.toml - - name: Build extension for --load-extension test - run: |- - (cd tests && gcc ext.c -fPIC -shared -o ext.so) - - name: Install dependencies - run: | - pip install . --group dev - pip freeze - - name: Run tests - run: | - pytest -n auto -m "not serial" - pytest -m "serial" - # And the test that exceeds a localhost HTTPS server - tests/test_datasette_https_server.sh - - name: Black - run: | - black --version - black --check . - - name: Ruff - run: ruff check datasette tests - - name: Check if cog needs to be run - run: | - cog --check docs/*.rst - - name: Check if blacken-docs needs to be run - run: | - # This fails on syntax errors, or a diff was applied - blacken-docs -l 60 docs/*.rst - - name: Test DATASETTE_LOAD_PLUGINS - run: | - pip install datasette-init datasette-json-html - tests/test-datasette-load-plugins.sh diff --git a/.github/workflows/tmate-mac.yml b/.github/workflows/tmate-mac.yml deleted file mode 100644 index a033cd92..00000000 --- a/.github/workflows/tmate-mac.yml +++ /dev/null @@ -1,15 +0,0 @@ -name: tmate session mac - -on: - workflow_dispatch: - -permissions: - contents: read - -jobs: - build: - runs-on: macos-latest - steps: - - uses: actions/checkout@v6 - - name: Setup tmate session - uses: mxschmitt/action-tmate@v3 diff --git a/.github/workflows/tmate.yml b/.github/workflows/tmate.yml deleted file mode 100644 index 72af1eec..00000000 --- a/.github/workflows/tmate.yml +++ /dev/null @@ -1,18 +0,0 @@ -name: tmate session - -on: - workflow_dispatch: - -permissions: - contents: read - models: read - -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - name: Setup tmate session - uses: mxschmitt/action-tmate@v3 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 12acd87e..cf8578ee 100644 --- a/.gitignore +++ b/.gitignore @@ -3,20 +3,9 @@ datasets.json scratchpad -.vscode - -uv.lock -data.db - -# test databases +# SQLite databases *.db - -# We don't use Pipfile, so ignore them -Pipfile -Pipfile.lock - -fixtures.db -*test.db +*.sqlite # Byte-compiled / optimized / DLL files __pycache__/ @@ -122,13 +111,3 @@ ENV/ # macOS files .DS_Store -node_modules -.*.swp - -# In case someone compiled tests/ext.c for test_load_extensions, don't -# include it in source control. -tests/*.dylib -tests/*.so -tests/*.dll - -.idea \ No newline at end of file diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 222861c3..00000000 --- a/.prettierrc +++ /dev/null @@ -1,4 +0,0 @@ -{ - "tabWidth": 2, - "useTabs": false -} diff --git a/.readthedocs.yaml b/.readthedocs.yaml deleted file mode 100644 index 8b3e54aa..00000000 --- a/.readthedocs.yaml +++ /dev/null @@ -1,17 +0,0 @@ -version: 2 - -sphinx: - configuration: docs/conf.py - -build: - os: ubuntu-24.04 - tools: - python: "3.13" - jobs: - install: - - pip install --upgrade pip - - pip install . --group dev - -formats: -- pdf -- epub diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..fc01da26 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,8 @@ +language: python + +python: + - 3.5 + - 3.6 + +script: + - python setup.py test diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md deleted file mode 100644 index 14d4c567..00000000 --- a/CODE_OF_CONDUCT.md +++ /dev/null @@ -1,128 +0,0 @@ -# Contributor Covenant Code of Conduct - -## Our Pledge - -We as members, contributors, and leaders pledge to make participation in our -community a harassment-free experience for everyone, regardless of age, body -size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, -nationality, personal appearance, race, religion, or sexual identity -and orientation. - -We pledge to act and interact in ways that contribute to an open, welcoming, -diverse, inclusive, and healthy community. - -## Our Standards - -Examples of behavior that contributes to a positive environment for our -community include: - -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, - and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community - -Examples of unacceptable behavior include: - -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting - -## Enforcement Responsibilities - -Community leaders are responsible for clarifying and enforcing our standards of -acceptable behavior and will take appropriate and fair corrective action in -response to any behavior that they deem inappropriate, threatening, offensive, -or harmful. - -Community leaders have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions that are -not aligned to this Code of Conduct, and will communicate reasons for moderation -decisions when appropriate. - -## Scope - -This Code of Conduct applies within all community spaces, and also applies when -an individual is officially representing the community in public spaces. -Examples of representing our community include using an official e-mail address, -posting via an official social media account, or acting as an appointed -representative at an online or offline event. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior may be -reported to the community leaders responsible for enforcement at -`swillison+datasette-code-of-conduct@gmail.com`. -All complaints will be reviewed and investigated promptly and fairly. - -All community leaders are obligated to respect the privacy and security of the -reporter of any incident. - -## Enforcement Guidelines - -Community leaders will follow these Community Impact Guidelines in determining -the consequences for any action they deem in violation of this Code of Conduct: - -### 1. Correction - -**Community Impact**: Use of inappropriate language or other behavior deemed -unprofessional or unwelcome in the community. - -**Consequence**: A private, written warning from community leaders, providing -clarity around the nature of the violation and an explanation of why the -behavior was inappropriate. A public apology may be requested. - -### 2. Warning - -**Community Impact**: A violation through a single incident or series -of actions. - -**Consequence**: A warning with consequences for continued behavior. No -interaction with the people involved, including unsolicited interaction with -those enforcing the Code of Conduct, for a specified period of time. This -includes avoiding interactions in community spaces as well as external channels -like social media. Violating these terms may lead to a temporary or -permanent ban. - -### 3. Temporary Ban - -**Community Impact**: A serious violation of community standards, including -sustained inappropriate behavior. - -**Consequence**: A temporary ban from any sort of interaction or public -communication with the community for a specified period of time. No public or -private interaction with the people involved, including unsolicited interaction -with those enforcing the Code of Conduct, is allowed during this period. -Violating these terms may lead to a permanent ban. - -### 4. Permanent Ban - -**Community Impact**: Demonstrating a pattern of violation of community -standards, including sustained inappropriate behavior, harassment of an -individual, or aggression toward or disparagement of classes of individuals. - -**Consequence**: A permanent ban from any sort of public interaction within -the community. - -## Attribution - -This Code of Conduct is adapted from the [Contributor Covenant][homepage], -version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. - -Community Impact Guidelines were inspired by [Mozilla's code of conduct -enforcement ladder](https://github.com/mozilla/diversity). - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see the FAQ at -https://www.contributor-covenant.org/faq. Translations are available at -https://www.contributor-covenant.org/translations. diff --git a/Dockerfile b/Dockerfile index 9a8f06cf..cb3d6621 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,18 +1,42 @@ -FROM python:3.11.0-slim-bullseye as build +FROM python:3.6-slim-stretch as build -# Version of Datasette to install, e.g. 0.55 -# docker build . -t datasette --build-arg VERSION=0.55 -ARG VERSION +# Setup build dependencies +RUN apt update \ +&& apt install -y python3-dev build-essential wget libxml2-dev libproj-dev libgeos-dev libsqlite3-dev zlib1g-dev pkg-config \ + && apt clean -RUN apt-get update && \ - apt-get install -y --no-install-recommends libsqlite3-mod-spatialite && \ - apt clean && \ - rm -rf /var/lib/apt && \ - rm -rf /var/lib/dpkg/info/* -RUN pip install https://github.com/simonw/datasette/archive/refs/tags/${VERSION}.zip && \ - find /usr/local/lib -name '__pycache__' | xargs rm -r && \ - rm -rf /root/.cache/pip +RUN wget "https://www.sqlite.org/2018/sqlite-autoconf-3230100.tar.gz" && tar xzf sqlite-autoconf-3230100.tar.gz \ + && cd sqlite-autoconf-3230100 && ./configure --disable-static --enable-fts5 --enable-json1 CFLAGS="-g -O2 -DSQLITE_ENABLE_FTS3=1 -DSQLITE_ENABLE_FTS4=1 -DSQLITE_ENABLE_RTREE=1 -DSQLITE_ENABLE_JSON1" \ + && make && make install + +RUN wget "https://www.gaia-gis.it/gaia-sins/freexl-1.0.5.tar.gz" && tar zxf freexl-1.0.5.tar.gz \ + && cd freexl-1.0.5 && ./configure && make && make install + +RUN wget "https://www.gaia-gis.it/gaia-sins/libspatialite-4.4.0-RC0.tar.gz" && tar zxf libspatialite-4.4.0-RC0.tar.gz \ + && cd libspatialite-4.4.0-RC0 && ./configure && make && make install + +RUN wget "https://www.gaia-gis.it/gaia-sins/readosm-1.1.0.tar.gz" && tar zxf readosm-1.1.0.tar.gz && cd readosm-1.1.0 && ./configure && make && make install + +RUN wget "https://www.gaia-gis.it/gaia-sins/spatialite-tools-4.4.0-RC0.tar.gz" && tar zxf spatialite-tools-4.4.0-RC0.tar.gz \ + && cd spatialite-tools-4.4.0-RC0 && ./configure && make && make install + + +# Add local code to the image instead of fetching from pypi. +COPY . /datasette + +RUN pip install /datasette + +FROM python:3.6-slim-stretch + +# Copy python dependencies and spatialite libraries +COPY --from=build /usr/local/lib/ /usr/local/lib/ +# Copy executables +COPY --from=build /usr/local/bin /usr/local/bin +# Copy spatial extensions +COPY --from=build /usr/lib/x86_64-linux-gnu /usr/lib/x86_64-linux-gnu + +ENV LD_LIBRARY_PATH=/usr/local/lib EXPOSE 8001 CMD ["datasette"] diff --git a/Justfile b/Justfile deleted file mode 100644 index 657881be..00000000 --- a/Justfile +++ /dev/null @@ -1,60 +0,0 @@ -export DATASETTE_SECRET := "not_a_secret" - -# Run tests and linters -@default: test lint - -# Setup project -@init: - uv sync - -# Run pytest with supplied options -@test *options: init - uv run pytest -n auto {{options}} - -@codespell: - uv run codespell README.md --ignore-words docs/codespell-ignore-words.txt - uv run codespell docs/*.rst --ignore-words docs/codespell-ignore-words.txt - uv run codespell datasette -S datasette/static --ignore-words docs/codespell-ignore-words.txt - uv run codespell tests --ignore-words docs/codespell-ignore-words.txt - -# Run linters: black, ruff, cog -@lint: codespell - uv run black datasette tests --check - uv run ruff check datasette tests - uv run cog --check README.md docs/*.rst - -# Apply ruff fixes -@fix: - uv run ruff check --fix datasette tests - -# Rebuild docs with cog -@cog: - uv run cog -r README.md docs/*.rst - -# Serve live docs on localhost:8000 -@docs: cog blacken-docs - uv run make -C docs livehtml - -# Build docs as static HTML -@docs-build: cog blacken-docs - rm -rf docs/_build && cd docs && uv run make html - -# Apply Black -@black: - uv run black datasette tests - -# Apply blacken-docs -@blacken-docs: - uv run blacken-docs -l 60 docs/*.rst - -# Apply prettier -@prettier: - npm run fix - -# Format code with both black and prettier -@format: black prettier blacken-docs - -@serve *options: - uv run sqlite-utils create-database data.db - uv run sqlite-utils create-table data.db docs id integer title text --pk id --ignore - uv run python -m datasette data.db --root --reload {{options}} diff --git a/MANIFEST.in b/MANIFEST.in index 8c5e3ee6..cca501c9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,3 @@ recursive-include datasette/static * -recursive-include datasette/templates * include versioneer.py include datasette/_version.py -include LICENSE diff --git a/README.md b/README.md index 393e8e5c..e7388457 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,43 @@ -Datasette +# Datasette [![PyPI](https://img.shields.io/pypi/v/datasette.svg)](https://pypi.org/project/datasette/) -[![Changelog](https://img.shields.io/github/v/release/simonw/datasette?label=changelog)](https://docs.datasette.io/en/latest/changelog.html) -[![Python 3.x](https://img.shields.io/pypi/pyversions/datasette.svg?logo=python&logoColor=white)](https://pypi.org/project/datasette/) -[![Tests](https://github.com/simonw/datasette/workflows/Test/badge.svg)](https://github.com/simonw/datasette/actions?query=workflow%3ATest) -[![Documentation Status](https://readthedocs.org/projects/datasette/badge/?version=latest)](https://docs.datasette.io/en/latest/?badge=latest) -[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://github.com/simonw/datasette/blob/main/LICENSE) -[![docker: datasette](https://img.shields.io/badge/docker-datasette-blue)](https://hub.docker.com/r/datasetteproject/datasette) -[![discord](https://img.shields.io/discord/823971286308356157?label=discord)](https://datasette.io/discord) +[![Travis CI](https://travis-ci.org/simonw/datasette.svg?branch=master)](https://travis-ci.org/simonw/datasette) +[![Documentation Status](https://readthedocs.org/projects/datasette/badge/?version=latest)](http://datasette.readthedocs.io/en/latest/?badge=latest) +[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://github.com/simonw/datasette/blob/master/LICENSE) -*An open source multi-tool for exploring and publishing data* +*An instant JSON API for your SQLite databases* -Datasette is a tool for exploring and publishing data. It helps people take data of any shape or size and publish that as an interactive, explorable website and accompanying API. +Datasette provides an instant, read-only JSON API for any SQLite database. It also provides tools for packaging the database up as a Docker container and deploying that container to hosting providers such as [Zeit Now](https://zeit.co/now). -Datasette is aimed at data journalists, museum curators, archivists, local governments, scientists, researchers and anyone else who has data that they wish to share with the world. +Got CSV data? Use [csvs-to-sqlite](https://github.com/simonw/csvs-to-sqlite) to convert them to SQLite, then publish them with Datasette. Or try [Datasette Publish](https://publish.datasettes.com), a web app that lets you upload CSV data and deploy it using Datasette without needing to install any software. -[Explore a demo](https://datasette.io/global-power-plants/global-power-plants), watch [a video about the project](https://simonwillison.net/2021/Feb/7/video/) or try it out [on GitHub Codespaces](https://github.com/datasette/datasette-studio). +Documentation: http://datasette.readthedocs.io/ Examples: https://github.com/simonw/datasette/wiki/Datasettes -* [datasette.io](https://datasette.io/) is the official project website -* Latest [Datasette News](https://datasette.io/news) -* Comprehensive documentation: https://docs.datasette.io/ -* Examples: https://datasette.io/examples -* Live demo of current `main` branch: https://latest.datasette.io/ -* Questions, feedback or want to talk about the project? Join our [Discord](https://datasette.io/discord) +## News -Want to stay up-to-date with the project? Subscribe to the [Datasette newsletter](https://datasette.substack.com/) for tips, tricks and news on what's new in the Datasette ecosystem. +* 23rd May 2018: [Datasette 0.22.1 bugfix](https://github.com/simonw/datasette/releases/tag/0.22.1) plus we now use [versioneer](https://github.com/warner/python-versioneer) +* 20th May 2018: [Datasette 0.22: Datasette Facets](https://simonwillison.net/2018/May/20/datasette-facets) +* 5th May 2018: [Datasette 0.21: New _shape=, new _size=, search within columns](https://github.com/simonw/datasette/releases/tag/0.21) +* 25th April 2018: [Exploring the UK Register of Members Interests with SQL and Datasette](https://simonwillison.net/2018/Apr/25/register-members-interests/) - a tutorial describing how [register-of-members-interests.datasettes.com](https://register-of-members-interests.datasettes.com/) was built ([source code here](https://github.com/simonw/register-of-members-interests)) +* 20th April 2018: [Datasette plugins, and building a clustered map visualization](https://simonwillison.net/2018/Apr/20/datasette-plugins/) - introducing Datasette's new plugin system and [datasette-cluster-map](https://pypi.org/project/datasette-cluster-map/), a plugin for visualizing data on a map +* 20th April 2018: [Datasette 0.20: static assets and templates for plugins](https://github.com/simonw/datasette/releases/tag/0.20) +* 16th April 2018: [Datasette 0.19: plugins preview](https://github.com/simonw/datasette/releases/tag/0.19) +* 14th April 2018: [Datasette 0.18: units](https://github.com/simonw/datasette/releases/tag/0.18) +* 9th April 2018: [Datasette 0.15: sort by column](https://github.com/simonw/datasette/releases/tag/0.15) +* 28th March 2018: [Baltimore Sun Public Salary Records](https://simonwillison.net/2018/Mar/28/datasette-in-the-wild/) - a data journalism project from the Baltimore Sun powered by Datasette - source code [is available here](https://github.com/baltimore-sun-data/salaries-datasette) +* 27th March 2018: [Cloud-first: Rapid webapp deployment using containers](https://wwwf.imperial.ac.uk/blog/research-software-engineering/2018/03/27/cloud-first-rapid-webapp-deployment-using-containers/) - a tutorial covering deploying Datasette using Microsoft Azure by the Research Software Engineering team at Imperial College London +* 28th January 2018: [Analyzing my Twitter followers with Datasette](https://simonwillison.net/2018/Jan/28/analyzing-my-twitter-followers/) - a tutorial on using Datasette to analyze follower data pulled from the Twitter API +* 17th January 2018: [Datasette Publish: a web app for publishing CSV files as an online database](https://simonwillison.net/2018/Jan/17/datasette-publish/) +* 12th December 2017: [Building a location to time zone API with SpatiaLite, OpenStreetMap and Datasette](https://simonwillison.net/2017/Dec/12/building-a-location-time-zone-api/) +* 9th December 2017: [Datasette 0.14: customization edition](https://github.com/simonw/datasette/releases/tag/0.14) +* 25th November 2017: [New in Datasette: filters, foreign keys and search](https://simonwillison.net/2017/Nov/25/new-in-datasette/) +* 13th November 2017: [Datasette: instantly create and publish an API for your SQLite databases](https://simonwillison.net/2017/Nov/13/datasette/) ## Installation -If you are on a Mac, [Homebrew](https://brew.sh/) is the easiest way to install Datasette: + pip3 install datasette - brew install datasette - -You can also install it using `pip` or `pipx`: - - pip install datasette - -Datasette requires Python 3.8 or higher. We also have [detailed installation instructions](https://docs.datasette.io/en/stable/installation.html) covering other options such as Docker. +Datasette requires Python 3.5 or higher. ## Basic usage @@ -48,12 +49,86 @@ This will start a web server on port 8001 - visit http://localhost:8001/ to acce Use Chrome on OS X? You can run datasette against your browser history like so: - datasette ~/Library/Application\ Support/Google/Chrome/Default/History --nolock + datasette ~/Library/Application\ Support/Google/Chrome/Default/History Now visiting http://localhost:8001/History/downloads will show you a web interface to browse your downloads data: ![Downloads table rendered by datasette](https://static.simonwillison.net/static/2017/datasette-downloads.png) +http://localhost:8001/History/downloads.json will return that data as JSON: + + { + "database": "History", + "columns": [ + "id", + "current_path", + "target_path", + "start_time", + "received_bytes", + "total_bytes", + ... + ], + "table_rows_count": 576, + "rows": [ + [ + 1, + "/Users/simonw/Downloads/DropboxInstaller.dmg", + "/Users/simonw/Downloads/DropboxInstaller.dmg", + 13097290269022132, + 626688, + 0, + ... + ] + ] + } + + +http://localhost:8001/History/downloads.json?_shape=objects will return that data as JSON in a more convenient but less efficient format: + + { + ... + "rows": [ + { + "start_time": 13097290269022132, + "interrupt_reason": 0, + "hash": "", + "id": 1, + "site_url": "", + "referrer": "https://www.dropbox.com/downloading?src=index", + ... + } + ] + } + +## datasette serve options + + $ datasette serve --help + Usage: datasette serve [OPTIONS] [FILES]... + + Serve up specified SQLite database files with a web UI + + Options: + -h, --host TEXT host for server, defaults to 127.0.0.1 + -p, --port INTEGER port for server, defaults to 8001 + --debug Enable debug mode - useful for development + --reload Automatically reload if code change detected - + useful for development + --cors Enable CORS by serving Access-Control-Allow- + Origin: * + --load-extension PATH Path to a SQLite extension to load + --inspect-file TEXT Path to JSON file created using "datasette + inspect" + -m, --metadata FILENAME Path to JSON file containing license/source + metadata + --template-dir DIRECTORY Path to directory containing custom templates + --plugins-dir DIRECTORY Path to directory containing custom plugins + --static STATIC MOUNT mountpoint:path-to-directory for serving static + files + --config CONFIG Set config option using configname:value + datasette.readthedocs.io/en/latest/config.html + --help-config Show available config options + --help Show this message and exit. + ## metadata.json If you want to include licensing and source information in the generated datasette website you can do so using a JSON file that looks something like this: @@ -66,26 +141,115 @@ If you want to include licensing and source information in the generated dataset "source_url": "https://github.com/fivethirtyeight/data" } -Save this in `metadata.json` and run Datasette like so: - - datasette serve fivethirtyeight.db -m metadata.json - The license and source information will be displayed on the index page and in the footer. They will also be included in the JSON produced by the API. ## datasette publish -If you have [Heroku](https://heroku.com/) or [Google Cloud Run](https://cloud.google.com/run/) configured, Datasette can deploy one or more SQLite databases to the internet with a single command: +If you have [Zeit Now](https://zeit.co/now) or [Heroku](https://heroku.com/) configured, datasette can deploy one or more SQLite databases to the internet with a single command: - datasette publish heroku database.db + datasette publish now database.db Or: - datasette publish cloudrun database.db + datasette publish heroku database.db -This will create a docker image containing both the datasette application and the specified SQLite database files. It will then deploy that image to Heroku or Cloud Run and give you a URL to access the resulting website and API. +This will create a docker image containing both the datasette application and the specified SQLite database files. It will then deploy that image to Zeit Now or Heroku and give you a URL to access the API. -See [Publishing data](https://docs.datasette.io/en/stable/publish.html) in the documentation for more details. + $ datasette publish --help + Usage: datasette publish [OPTIONS] PUBLISHER [FILES]... -## Datasette Lite + Publish specified SQLite database files to the internet along with a + datasette API. -[Datasette Lite](https://lite.datasette.io/) is Datasette packaged using WebAssembly so that it runs entirely in your browser, no Python web application server required. Read more about that in the [Datasette Lite documentation](https://github.com/simonw/datasette-lite/blob/main/README.md). + Options for PUBLISHER: * 'now' - You must have Zeit Now installed: + https://zeit.co/now * 'heroku' - You must have Heroku installed: + https://cli.heroku.com/ + + Example usage: datasette publish now my-database.db + + Options: + -n, --name TEXT Application name to use when deploying to Now + (ignored for Heroku) + -m, --metadata FILENAME Path to JSON file containing metadata to publish + --extra-options TEXT Extra options to pass to datasette serve + --force Pass --force option to now + --branch TEXT Install datasette from a GitHub branch e.g. master + --template-dir DIRECTORY Path to directory containing custom templates + --plugins-dir DIRECTORY Path to directory containing custom plugins + --static STATIC MOUNT mountpoint:path-to-directory for serving static + files + --install TEXT Additional packages (e.g. plugins) to install + --spatialite Enable SpatialLite extension + --title TEXT Title for metadata + --license TEXT License label for metadata + --license_url TEXT License URL for metadata + --source TEXT Source label for metadata + --source_url TEXT Source URL for metadata + --help Show this message and exit. + +## datasette package + +If you have docker installed you can use `datasette package` to create a new Docker image in your local repository containing the datasette app and selected SQLite databases: + + $ datasette package --help + Usage: datasette package [OPTIONS] FILES... + + Package specified SQLite files into a new datasette Docker container + + Options: + -t, --tag TEXT Name for the resulting Docker container, can + optionally use name:tag format + -m, --metadata FILENAME Path to JSON file containing metadata to publish + --extra-options TEXT Extra options to pass to datasette serve + --branch TEXT Install datasette from a GitHub branch e.g. master + --template-dir DIRECTORY Path to directory containing custom templates + --plugins-dir DIRECTORY Path to directory containing custom plugins + --static STATIC MOUNT mountpoint:path-to-directory for serving static + files + --install TEXT Additional packages (e.g. plugins) to install + --spatialite Enable SpatialLite extension + --title TEXT Title for metadata + --license TEXT License label for metadata + --license_url TEXT License URL for metadata + --source TEXT Source label for metadata + --source_url TEXT Source URL for metadata + --help Show this message and exit. + +Both publish and package accept an `extra_options` argument option, which will affect how the resulting application is executed. For example, say you want to increase the SQL time limit for a particular container: + + datasette package parlgov.db \ + --extra-options="--config sql_time_limit_ms:2500 --config default_page_size:10" + +The resulting container will run the application with those options. + +Here's example output for the package command: + + $ datasette package parlgov.db --extra-options="--config sql_time_limit_ms:2500" + Sending build context to Docker daemon 4.459MB + Step 1/7 : FROM python:3 + ---> 79e1dc9af1c1 + Step 2/7 : COPY . /app + ---> Using cache + ---> cd4ec67de656 + Step 3/7 : WORKDIR /app + ---> Using cache + ---> 139699e91621 + Step 4/7 : RUN pip install datasette + ---> Using cache + ---> 340efa82bfd7 + Step 5/7 : RUN datasette inspect parlgov.db --inspect-file inspect-data.json + ---> Using cache + ---> 5fddbe990314 + Step 6/7 : EXPOSE 8001 + ---> Using cache + ---> 8e83844b0fed + Step 7/7 : CMD datasette serve parlgov.db --port 8001 --inspect-file inspect-data.json --config sql_time_limit_ms:2500 + ---> Using cache + ---> 1bd380ea8af3 + Successfully built 1bd380ea8af3 + +You can now run the resulting container like so: + + docker run -p 8081:8001 1bd380ea8af3 + +This exposes port 8001 inside the container as port 8081 on your host machine, so you can access the application at http://localhost:8081/ diff --git a/_config.yml b/_config.yml new file mode 100644 index 00000000..3397c9a4 --- /dev/null +++ b/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-architect \ No newline at end of file diff --git a/codecov.yml b/codecov.yml deleted file mode 100644 index bfdc9877..00000000 --- a/codecov.yml +++ /dev/null @@ -1,8 +0,0 @@ -coverage: - status: - project: - default: - informational: true - patch: - default: - informational: true diff --git a/datasette/__init__.py b/datasette/__init__.py index eb18e59e..1ec88d90 100644 --- a/datasette/__init__.py +++ b/datasette/__init__.py @@ -1,9 +1,3 @@ -from datasette.permissions import Permission # noqa from datasette.version import __version_info__, __version__ # noqa -from datasette.events import Event # noqa -from datasette.tokens import TokenHandler, TokenRestrictions # noqa -from datasette.utils.asgi import Forbidden, NotFound, Request, Response # noqa -from datasette.utils import actor_matches_allow # noqa -from datasette.views import Context # noqa -from .hookspecs import hookimpl # noqa -from .hookspecs import hookspec # noqa +from .hookspecs import hookimpl # noqa +from .hookspecs import hookspec # noqa diff --git a/datasette/__main__.py b/datasette/__main__.py deleted file mode 100644 index 4adef844..00000000 --- a/datasette/__main__.py +++ /dev/null @@ -1,4 +0,0 @@ -from datasette.cli import cli - -if __name__ == "__main__": - cli() diff --git a/datasette/_pytest_plugin.py b/datasette/_pytest_plugin.py deleted file mode 100644 index 5fb6b473..00000000 --- a/datasette/_pytest_plugin.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -Pytest plugin that automatically closes any Datasette instances constructed -during a pytest test — both in the test body and in function-scoped -fixtures. Instances constructed by session-, module-, class- or package- -scoped fixtures are left alone, because other tests in the session will -still want to use them. - -Registered as a pytest11 entry point in pyproject.toml so that downstream -projects using Datasette get the same FD-safety net for their own tests. - -Opt out by setting ``datasette_autoclose = false`` in pytest.ini (or the -equivalent ini file). -""" - -from __future__ import annotations - -import contextvars -import weakref - -import pytest - -from datasette.app import Datasette - -_active_instances: contextvars.ContextVar[list | None] = contextvars.ContextVar( - "datasette_active_instances", default=None -) - -_original_init = Datasette.__init__ - - -def _tracking_init(self, *args, **kwargs): - _original_init(self, *args, **kwargs) - instances = _active_instances.get() - if instances is not None: - instances.append(weakref.ref(self)) - - -Datasette.__init__ = _tracking_init - - -def pytest_addoption(parser): - parser.addini( - "datasette_autoclose", - help=( - "Automatically close Datasette instances created inside test " - "bodies and function-scoped fixtures (default: true)." - ), - default="true", - ) - - -def _enabled(config) -> bool: - value = config.getini("datasette_autoclose") - if isinstance(value, bool): - return value - return str(value).strip().lower() not in ("false", "0", "no", "off") - - -@pytest.hookimpl(hookwrapper=True) -def pytest_runtest_protocol(item, nextitem): - """Track Datasette instances across setup, call and teardown; close at end.""" - if not _enabled(item.config): - yield - return - refs: list[weakref.ref] = [] - token = _active_instances.set(refs) - try: - yield - finally: - _active_instances.reset(token) - for ref in reversed(refs): - ds = ref() - if ds is None: - continue - try: - ds.close() - except Exception as e: - item.warn( - pytest.PytestUnraisableExceptionWarning( - f"Error closing Datasette instance: {e!r}" - ) - ) - - -@pytest.hookimpl(hookwrapper=True) -def pytest_fixture_setup(fixturedef, request): - """Exempt instances created by non-function-scoped fixtures. - - Session-, module-, class- and package-scoped fixtures produce Datasette - instances that must survive beyond the current test — other tests in - the session will still use them. When such a fixture creates one or - more Datasette instances during its setup, we snapshot the tracking - list before the fixture runs and subtract off any instances that were - added during its setup, so they don't get closed at test teardown. - """ - refs = _active_instances.get() - if refs is None: - yield - return - before_ids = {id(ref) for ref in refs} - yield - if fixturedef.scope != "function": - new_refs = [ref for ref in refs if id(ref) not in before_ids] - for new_ref in new_refs: - try: - refs.remove(new_ref) - except ValueError: - pass diff --git a/datasette/_version.py b/datasette/_version.py new file mode 100644 index 00000000..73d6658c --- /dev/null +++ b/datasette/_version.py @@ -0,0 +1,520 @@ + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.18 (https://github.com/warner/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + git_date = "$Format:%ci$" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "" + cfg.parentdir_prefix = "datasette-" + cfg.versionfile_source = "datasette/_version.py" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, p.returncode + return stdout, p.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for i in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} diff --git a/datasette/actor_auth_cookie.py b/datasette/actor_auth_cookie.py deleted file mode 100644 index 368213af..00000000 --- a/datasette/actor_auth_cookie.py +++ /dev/null @@ -1,23 +0,0 @@ -from datasette import hookimpl -from itsdangerous import BadSignature -from datasette.utils import baseconv -import time - - -@hookimpl -def actor_from_request(datasette, request): - if "ds_actor" not in request.cookies: - return None - try: - decoded = datasette.unsign(request.cookies["ds_actor"], "actor") - # If it has "e" and "a" keys process the "e" expiry - if not isinstance(decoded, dict) or "a" not in decoded: - return None - expires_at = decoded.get("e") - if expires_at: - timestamp = int(baseconv.base62.decode(expires_at)) - if time.time() > timestamp: - return None - return decoded["a"] - except BadSignature: - return None diff --git a/datasette/app.py b/datasette/app.py index 56b89789..e449dd89 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,1988 +1,284 @@ -from __future__ import annotations - import asyncio -import contextvars -from typing import TYPE_CHECKING, Any, Dict, Iterable, List - -if TYPE_CHECKING: - from datasette.permissions import Resource - from datasette.tokens import TokenRestrictions import collections -import dataclasses -import datetime -import functools -import glob import hashlib -import httpx -import importlib.metadata -import inspect -from itsdangerous import BadSignature +import itertools import json import os -import re -import secrets +import sqlite3 import sys import threading -import time -import types +import traceback import urllib.parse from concurrent import futures from pathlib import Path -from markupsafe import Markup, escape -from itsdangerous import URLSafeSerializer -from jinja2 import ( - ChoiceLoader, - Environment, - FileSystemLoader, - PrefixLoader, -) -from jinja2.environment import Template -from jinja2.exceptions import TemplateNotFound +from markupsafe import Markup +import pluggy +from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader +from sanic import Sanic, response +from sanic.request import Request as SanicRequest +from sanic.exceptions import InvalidUsage, NotFound -from .events import Event -from .column_types import SQLiteType -from . import stored_queries -from .views import Context -from .views.database import ( - database_download, - DatabaseView, - TableCreateView, - QueryView, -) -from .views.execute_write import ExecuteWriteAnalyzeView, ExecuteWriteView -from .views.stored_queries import ( - QueryCreateAnalyzeView, - QueryDeleteView, - QueryDefinitionView, - GlobalQueryListView, - QueryListView, - QueryParametersView, - QueryStoreView, - QueryUpdateView, +from .views.base import ( + DatasetteError, + RenderMixin, + ureg ) +from .views.database import DatabaseDownload, DatabaseView from .views.index import IndexView -from .views.special import ( - JsonDataView, - PatternPortfolioView, - AuthTokenView, - ApiExplorerView, - CreateTokenView, - LogoutView, - AllowDebugView, - PermissionsDebugView, - MessagesDebugView, - AllowedResourcesView, - PermissionRulesView, - PermissionCheckView, - JumpView, - InstanceSchemaView, - DatabaseSchemaView, - TableSchemaView, -) -from .views.table import ( - TableInsertView, - TableUpsertView, - TableSetColumnTypeView, - TableDropView, - table_view, -) -from .views.row import RowView, RowDeleteView, RowUpdateView -from .renderer import json_renderer -from .url_builder import Urls -from .database import Database, QueryInterrupted +from .views.special import JsonDataView +from .views.table import RowView, TableView +from . import hookspecs from .utils import ( - PaginatedResources, - PrefixedUrlString, - SPATIALITE_FUNCTIONS, - StartupError, - async_call_with_supported_arguments, - await_me_maybe, - baseconv, - call_with_supported_arguments, - detect_json1, - display_actor, + InterruptedError, + Results, escape_css_string, escape_sqlite, - find_spatialite, - format_bytes, + get_plugins, module_from_path, - move_plugins_and_allow, - move_table_config, - parse_metadata, - resolve_env_secrets, - resolve_routes, - tilde_decode, - tilde_encode, - to_css_class, - urlsafe_components, - redact_keys, - row_sql_params_pks, + sqlite_timelimit, + to_css_class ) -from .utils.asgi import ( - AsgiLifespan, - Forbidden, - NotFound, - DatabaseNotFound, - TableNotFound, - RowNotFound, - Request, - Response, - AsgiRunOnFirstRequest, - asgi_static, - asgi_send, - asgi_send_file, - asgi_send_redirect, -) -from .csrf import CrossOriginProtectionMiddleware -from .utils.internal_db import init_internal_db, populate_schema_tables -from .utils.sqlite import ( - sqlite3, - using_pysqlite3, -) -from .tracer import AsgiTracer -from .plugins import pm, DEFAULT_PLUGINS, get_plugins +from .inspect import inspect_hash, inspect_views, inspect_tables from .version import __version__ -from .resources import DatabaseResource, TableResource - app_root = Path(__file__).parent.parent +connections = threading.local() -# Context variable to track when code is executing within a datasette.client request -_in_datasette_client = contextvars.ContextVar("in_datasette_client", default=False) +pm = pluggy.PluginManager("datasette") +pm.add_hookspecs(hookspecs) +pm.load_setuptools_entrypoints("datasette") -class _DatasetteClientContext: - """Context manager to mark code as executing within a datasette.client request.""" - - def __enter__(self): - self.token = _in_datasette_client.set(True) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - _in_datasette_client.reset(self.token) - return False - - -@dataclasses.dataclass -class PermissionCheck: - """Represents a logged permission check for debugging purposes.""" - - when: str - actor: Dict[str, Any] | None - action: str - parent: str | None - child: str | None - result: bool - - -# https://github.com/simonw/datasette/issues/283#issuecomment-781591015 -SQLITE_LIMIT_ATTACHED = 10 - -INTERNAL_DB_NAME = "__INTERNAL__" - -Setting = collections.namedtuple("Setting", ("name", "default", "help")) -SETTINGS = ( - Setting("default_page_size", 100, "Default page size for the table view"), - Setting( - "max_returned_rows", - 1000, - "Maximum rows that can be returned from a table or custom query", - ), - Setting( - "max_insert_rows", - 100, - "Maximum rows that can be inserted at a time using the bulk insert API", - ), - Setting( - "num_sql_threads", - 3, - "Number of threads in the thread pool for executing SQLite queries", - ), - Setting("sql_time_limit_ms", 1000, "Time limit for a SQL query in milliseconds"), - Setting( - "default_facet_size", 30, "Number of values to return for requested facets" - ), - Setting("facet_time_limit_ms", 200, "Time limit for calculating a requested facet"), - Setting( - "facet_suggest_time_limit_ms", - 50, - "Time limit for calculating a suggested facet", - ), - Setting( - "allow_facet", - True, - "Allow users to specify columns to facet using ?_facet= parameter", - ), - Setting( - "allow_download", - True, - "Allow users to download the original SQLite database files", - ), - Setting( - "allow_signed_tokens", - True, - "Allow users to create and use signed API tokens", - ), - Setting( - "default_allow_sql", - True, - "Allow anyone to run arbitrary SQL queries", - ), - Setting( - "max_signed_tokens_ttl", - 0, - "Maximum allowed expiry time for signed API tokens", - ), - Setting("suggest_facets", True, "Calculate and display suggested facets"), - Setting( - "default_cache_ttl", - 5, - "Default HTTP cache TTL (used in Cache-Control: max-age= header)", - ), - Setting("cache_size_kb", 0, "SQLite cache size in KB (0 == use SQLite default)"), - Setting( - "allow_csv_stream", - True, - "Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)", - ), - Setting( - "max_csv_mb", - 100, - "Maximum size allowed for CSV export in MB - set 0 to disable this limit", - ), - Setting( - "truncate_cells_html", - 2048, - "Truncate cells longer than this in HTML table view - set 0 to disable", - ), - Setting( - "force_https_urls", - False, - "Force URLs in API output to always use https:// protocol", - ), - Setting( - "template_debug", - False, - "Allow display of template debug information with ?_context=1", - ), - Setting( - "trace_debug", - False, - "Allow display of SQL trace debug information with ?_trace=1", - ), - Setting("base_url", "/", "Datasette URLs should use this base path"), +ConfigOption = collections.namedtuple( + "ConfigOption", ("name", "default", "help") ) -_HASH_URLS_REMOVED = "The hash_urls setting has been removed, try the datasette-hashed-urls plugin instead" -OBSOLETE_SETTINGS = { - "hash_urls": _HASH_URLS_REMOVED, - "default_cache_ttl_hashed": _HASH_URLS_REMOVED, +CONFIG_OPTIONS = ( + ConfigOption("default_page_size", 100, """ + Default page size for the table view + """.strip()), + ConfigOption("max_returned_rows", 1000, """ + Maximum rows that can be returned from a table or custom query + """.strip()), + ConfigOption("num_sql_threads", 3, """ + Number of threads in the thread pool for executing SQLite queries + """.strip()), + ConfigOption("sql_time_limit_ms", 1000, """ + Time limit for a SQL query in milliseconds + """.strip()), + ConfigOption("default_facet_size", 30, """ + Number of values to return for requested facets + """.strip()), + ConfigOption("facet_time_limit_ms", 200, """ + Time limit for calculating a requested facet + """.strip()), + ConfigOption("facet_suggest_time_limit_ms", 50, """ + Time limit for calculating a suggested facet + """.strip()), + ConfigOption("allow_facet", True, """ + Allow users to specify columns to facet using ?_facet= parameter + """.strip()), + ConfigOption("allow_download", True, """ + Allow users to download the original SQLite database files + """.strip()), + ConfigOption("suggest_facets", True, """ + Calculate and display suggested facets + """.strip()), + ConfigOption("allow_sql", True, """ + Allow arbitrary SQL queries via ?sql= parameter + """.strip()), + ConfigOption("default_cache_ttl", 365 * 24 * 60 * 60, """ + Default HTTP cache TTL (used in Cache-Control: max-age= header) + """.strip()), + ConfigOption("cache_size_kb", 0, """ + SQLite cache size in KB (0 == use SQLite default) + """.strip()), +) +DEFAULT_CONFIG = { + option.name: option.default + for option in CONFIG_OPTIONS } -DEFAULT_SETTINGS = {option.name: option.default for option in SETTINGS} - -FAVICON_PATH = app_root / "datasette" / "static" / "favicon.png" - -DEFAULT_NOT_SET = object() -ResourcesSQL = collections.namedtuple("ResourcesSQL", ("sql", "params")) - - -async def favicon(request, send): - await asgi_send_file( - send, - str(FAVICON_PATH), - content_type="image/png", - headers={"Cache-Control": "max-age=3600, immutable, public"}, - ) - - -ResolvedTable = collections.namedtuple("ResolvedTable", ("db", "table", "is_view")) -ResolvedRow = collections.namedtuple( - "ResolvedRow", ("db", "table", "sql", "params", "pks", "pk_values", "row") -) - - -def _to_string(value): - if isinstance(value, str): - return value - else: - return json.dumps(value, default=str) +async def favicon(request): + return response.text("") class Datasette: - # Message constants: - INFO = 1 - WARNING = 2 - ERROR = 3 def __init__( self, - files=None, - immutables=None, + files, cache_headers=True, cors=False, inspect_data=None, - config=None, metadata=None, sqlite_extensions=None, template_dir=None, plugins_dir=None, static_mounts=None, - memory=False, - settings=None, - secret=None, - version_note=None, - config_dir=None, - pdb=False, - crossdb=False, - nolock=False, - internal=None, - default_deny=False, + config=None, ): - self._startup_invoked = False - self._closed = False - assert config_dir is None or isinstance( - config_dir, Path - ), "config_dir= should be a pathlib.Path" - self.config_dir = config_dir - self.pdb = pdb - self._secret = secret or secrets.token_hex(32) - if files is not None and isinstance(files, str): - raise ValueError("files= must be a list of paths, not a string") - self.files = tuple(files or []) + tuple(immutables or []) - if config_dir: - db_files = [] - for ext in ("db", "sqlite", "sqlite3"): - db_files.extend(config_dir.glob("*.{}".format(ext))) - self.files += tuple(str(f) for f in db_files) - if ( - config_dir - and (config_dir / "inspect-data.json").exists() - and not inspect_data - ): - inspect_data = json.loads((config_dir / "inspect-data.json").read_text()) - if not immutables: - immutable_filenames = [i["file"] for i in inspect_data.values()] - immutables = [ - f for f in self.files if Path(f).name in immutable_filenames - ] - self.inspect_data = inspect_data - self.immutables = set(immutables or []) - self.databases = collections.OrderedDict() - self.actions = {} # .invoke_startup() will populate this - self._column_types = {} # .invoke_startup() will populate this - try: - self._refresh_schemas_lock = asyncio.Lock() - except RuntimeError as rex: - # Workaround for intermittent test failure, see: - # https://github.com/simonw/datasette/issues/1802 - if "There is no current event loop in thread" in str(rex): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._refresh_schemas_lock = asyncio.Lock() - else: - raise - self.crossdb = crossdb - self.nolock = nolock - if memory or crossdb or not self.files: - self.add_database( - Database(self, is_mutable=False, is_memory=True), name="_memory" - ) - for file in self.files: - self.add_database( - Database(self, file, is_mutable=file not in self.immutables) - ) - - self.internal_db_created = False - if internal is None: - self._internal_database = Database(self, is_temp_disk=True) - else: - self._internal_database = Database(self, path=internal, mode="rwc") - self._internal_database.name = INTERNAL_DB_NAME - + self.files = files self.cache_headers = cache_headers self.cors = cors - config_files = [] - metadata_files = [] - if config_dir: - metadata_files = [ - config_dir / filename - for filename in ("metadata.json", "metadata.yaml", "metadata.yml") - if (config_dir / filename).exists() - ] - config_files = [ - config_dir / filename - for filename in ("datasette.json", "datasette.yaml", "datasette.yml") - if (config_dir / filename).exists() - ] - if config_dir and metadata_files and not metadata: - with metadata_files[0].open() as fp: - metadata = parse_metadata(fp.read()) - - if config_dir and config_files and not config: - with config_files[0].open() as fp: - config = parse_metadata(fp.read()) - - # Move any "plugins" and "allow" settings from metadata to config - updates them in place - metadata = metadata or {} - config = config or {} - metadata, config = move_plugins_and_allow(metadata, config) - # Now migrate any known table configuration settings over as well - metadata, config = move_table_config(metadata, config) - - self._metadata_local = metadata or {} - self.sqlite_extensions = [] - for extension in sqlite_extensions or []: - # Resolve spatialite, if requested - if extension == "spatialite": - # Could raise SpatialiteNotFound - self.sqlite_extensions.append(find_spatialite()) - else: - self.sqlite_extensions.append(extension) - if config_dir and (config_dir / "templates").is_dir() and not template_dir: - template_dir = str((config_dir / "templates").resolve()) + self._inspect = inspect_data + self.metadata = metadata or {} + self.sqlite_functions = [] + self.sqlite_extensions = sqlite_extensions or [] self.template_dir = template_dir - if config_dir and (config_dir / "plugins").is_dir() and not plugins_dir: - plugins_dir = str((config_dir / "plugins").resolve()) self.plugins_dir = plugins_dir - if config_dir and (config_dir / "static").is_dir() and not static_mounts: - static_mounts = [("static", str((config_dir / "static").resolve()))] self.static_mounts = static_mounts or [] - if config_dir and (config_dir / "datasette.json").exists() and not config: - config = json.loads((config_dir / "datasette.json").read_text()) - - config = config or {} - config_settings = config.get("settings") or {} - - # Validate settings from config file - for key, value in config_settings.items(): - if key not in DEFAULT_SETTINGS: - raise StartupError(f"Invalid setting '{key}' in config file") - # Validate type matches expected type from DEFAULT_SETTINGS - if value is not None: # Allow None/null values - expected_type = type(DEFAULT_SETTINGS[key]) - actual_type = type(value) - if actual_type != expected_type: - raise StartupError( - f"Setting '{key}' in config file has incorrect type. " - f"Expected {expected_type.__name__}, got {actual_type.__name__}. " - f"Value: {value!r}. " - f"Hint: In YAML/JSON config files, remove quotes from boolean and integer values." - ) - - # Validate settings from constructor parameter - if settings: - for key, value in settings.items(): - if key not in DEFAULT_SETTINGS: - raise StartupError(f"Invalid setting '{key}' in settings parameter") - if value is not None: - expected_type = type(DEFAULT_SETTINGS[key]) - actual_type = type(value) - if actual_type != expected_type: - raise StartupError( - f"Setting '{key}' in settings parameter has incorrect type. " - f"Expected {expected_type.__name__}, got {actual_type.__name__}. " - f"Value: {value!r}" - ) - - self.config = config - # CLI settings should overwrite datasette.json settings - self._settings = dict(DEFAULT_SETTINGS, **(config_settings), **(settings or {})) - self.renderers = {} # File extension -> (renderer, can_render) functions - self.version_note = version_note - if self.setting("num_sql_threads") == 0: - self.executor = None - else: - self.executor = futures.ThreadPoolExecutor( - max_workers=self.setting("num_sql_threads") - ) - self.max_returned_rows = self.setting("max_returned_rows") - self.sql_time_limit_ms = self.setting("sql_time_limit_ms") - self.page_size = self.setting("default_page_size") + self.config = dict(DEFAULT_CONFIG, **(config or {})) + self.executor = futures.ThreadPoolExecutor( + max_workers=self.config["num_sql_threads"] + ) + self.max_returned_rows = self.config["max_returned_rows"] + self.sql_time_limit_ms = self.config["sql_time_limit_ms"] + self.page_size = self.config["default_page_size"] # Execute plugins in constructor, to ensure they are available # when the rest of `datasette inspect` executes if self.plugins_dir: - for filepath in glob.glob(os.path.join(self.plugins_dir, "*.py")): - if not os.path.isfile(filepath): - continue - mod = module_from_path(filepath, name=os.path.basename(filepath)) + for filename in os.listdir(self.plugins_dir): + filepath = os.path.join(self.plugins_dir, filename) + mod = module_from_path(filepath, name=filename) try: pm.register(mod) except ValueError: # Plugin already registered pass - # Configure Jinja - default_templates = str(app_root / "datasette" / "templates") - template_paths = [] - if self.template_dir: - template_paths.append(self.template_dir) - plugin_template_paths = [ - plugin["templates_path"] - for plugin in get_plugins() - if plugin["templates_path"] - ] - template_paths.extend(plugin_template_paths) - template_paths.append(default_templates) - template_loader = ChoiceLoader( - [ - FileSystemLoader(template_paths), - # Support {% extends "default:table.html" %}: - PrefixLoader( - {"default": FileSystemLoader(default_templates)}, delimiter=":" - ), - ] - ) - environment = Environment( - loader=template_loader, - autoescape=True, - enable_async=True, - # undefined=StrictUndefined, - ) - environment.filters["escape_css_string"] = escape_css_string - environment.filters["quote_plus"] = urllib.parse.quote_plus - self._jinja_env = environment - environment.filters["escape_sqlite"] = escape_sqlite - environment.filters["to_css_class"] = to_css_class - self._register_renderers() - self._permission_checks = collections.deque(maxlen=200) - self._root_token = secrets.token_hex(32) - self.root_enabled = False - self.default_deny = default_deny - self.client = DatasetteClient(self) - - async def apply_metadata_json(self): - # Apply any metadata entries from metadata.json to the internal tables - # step 1: top-level metadata - for key in self._metadata_local or {}: - if key == "databases": - continue - value = self._metadata_local[key] - await self.set_instance_metadata(key, _to_string(value)) - - # step 2: database-level metadata - for dbname, db in self._metadata_local.get("databases", {}).items(): - for key, value in db.items(): - if key in ("tables", "queries"): - continue - await self.set_database_metadata(dbname, key, _to_string(value)) - - # step 3: table-level metadata - for tablename, table in db.get("tables", {}).items(): - for key, value in table.items(): - if key == "columns": - continue - await self.set_resource_metadata( - dbname, tablename, key, _to_string(value) - ) - - # step 4: column-level metadata (only descriptions in metadata.json) - for columnname, column_description in table.get("columns", {}).items(): - await self.set_column_metadata( - dbname, tablename, columnname, "description", column_description - ) - - # TODO(alex) is metadata.json was loaded in, and --internal is not memory, then log - # a warning to user that they should delete their metadata.json file - - async def _save_queries_from_config(self): - await stored_queries.save_queries_from_config(self) - - def get_jinja_environment(self, request: Request = None) -> Environment: - environment = self._jinja_env - if request: - for environment in pm.hook.jinja2_environment_from_request( - datasette=self, request=request, env=environment - ): - pass - return environment - - def get_action(self, name_or_abbr: str): - """ - Returns an Action object for the given name or abbreviation. Returns None if not found. - """ - if name_or_abbr in self.actions: - return self.actions[name_or_abbr] - # Try abbreviation - for action in self.actions.values(): - if action.abbr == name_or_abbr: - return action - return None - - async def refresh_schemas(self): - # Throttle schema refreshes to at most once per second - if time.monotonic() - getattr(self, "_last_schema_refresh", 0) < 1.0: - return - self._last_schema_refresh = time.monotonic() - if self._refresh_schemas_lock.locked(): - return - async with self._refresh_schemas_lock: - await self._refresh_schemas() - - async def _refresh_schemas(self): - internal_db = self.get_internal_database() - if not self.internal_db_created: - await init_internal_db(internal_db) - await self.apply_metadata_json() - self.internal_db_created = True - current_schema_versions = { - row["database_name"]: row["schema_version"] - for row in await internal_db.execute( - "select database_name, schema_version from catalog_databases" - ) - } - catalog_table_names = ( - "catalog_columns", - "catalog_foreign_keys", - "catalog_indexes", - "catalog_views", - "catalog_tables", - "catalog_databases", - ) - # Delete stale entries for databases that are no longer attached - catalog_database_names = set(current_schema_versions.keys()) - for table in catalog_table_names[:-1]: - catalog_database_names.update( - row["database_name"] - for row in await internal_db.execute( - "select distinct database_name from {}".format(table) - ) - if row["database_name"] is not None - ) - stale_databases = catalog_database_names - set(self.databases.keys()) - if stale_databases: - - def delete_stale_database_catalog(conn): - for stale_db_name in stale_databases: - for table in catalog_table_names: - conn.execute( - "DELETE FROM {} WHERE database_name = ?".format(table), - [stale_db_name], - ) - - await internal_db.execute_write_fn(delete_stale_database_catalog) - for database_name, db in self.databases.items(): - schema_version = (await db.execute("PRAGMA schema_version")).first()[0] - # Compare schema versions to see if we should skip it - if schema_version == current_schema_versions.get(database_name): - continue - placeholders = "(?, ?, ?, ?)" - values = [database_name, str(db.path), db.is_memory, schema_version] - if db.path is None: - placeholders = "(?, null, ?, ?)" - values = [database_name, db.is_memory, schema_version] - await internal_db.execute_write( - """ - INSERT OR REPLACE INTO catalog_databases (database_name, path, is_memory, schema_version) - VALUES {} - """.format(placeholders), - values, - ) - await populate_schema_tables(internal_db, db) - - @property - def urls(self): - return Urls(self) - - @property - def pm(self): - """ - Return the global plugin manager instance. - - This provides access to the pluggy PluginManager that manages all - Datasette plugins and hooks. Use datasette.pm.hook.hook_name() to - call plugin hooks. - """ - return pm - - async def invoke_startup(self): - # This must be called for Datasette to be in a usable state - if self._startup_invoked: - return - # Register event classes - event_classes = [] - for hook in pm.hook.register_events(datasette=self): - extra_classes = await await_me_maybe(hook) - if extra_classes: - event_classes.extend(extra_classes) - self.event_classes = tuple(event_classes) - - # Register actions, but watch out for duplicate name/abbr - action_names = {} - action_abbrs = {} - for hook in pm.hook.register_actions(datasette=self): - if hook: - for action in hook: - if ( - action.name in action_names - and action != action_names[action.name] - ): - raise StartupError( - "Duplicate action name: {}".format(action.name) - ) - if ( - action.abbr - and action.abbr in action_abbrs - and action != action_abbrs[action.abbr] - ): - raise StartupError( - "Duplicate action abbr: {}".format(action.abbr) - ) - action_names[action.name] = action - if action.abbr: - action_abbrs[action.abbr] = action - self.actions[action.name] = action - - # Register column types (classes, not instances) - self._column_types = {} - for hook in pm.hook.register_column_types(datasette=self): - if hook: - for ct_cls in hook: - if ct_cls.name in self._column_types: - raise StartupError(f"Duplicate column type name: {ct_cls.name}") - self._column_types[ct_cls.name] = ct_cls - - for hook in pm.hook.prepare_jinja2_environment( - env=self._jinja_env, datasette=self - ): - await await_me_maybe(hook) - # Ensure internal tables and metadata are populated before startup hooks - await self._refresh_schemas() - await self._save_queries_from_config() - # Load column_types from config into internal DB - await self._apply_column_types_config() - for hook in pm.hook.startup(datasette=self): - await await_me_maybe(hook) - self._startup_invoked = True - - def sign(self, value, namespace="default"): - return URLSafeSerializer(self._secret, namespace).dumps(value) - - def unsign(self, signed, namespace="default"): - return URLSafeSerializer(self._secret, namespace).loads(signed) - - def in_client(self) -> bool: - """Check if the current code is executing within a datasette.client request. - - Returns: - bool: True if currently executing within a datasette.client request, False otherwise. - """ - return _in_datasette_client.get() - - def _token_handlers(self): - """Collect all registered token handlers from plugins.""" - from datasette.tokens import TokenHandler - - handlers = [] - for result in pm.hook.register_token_handler(datasette=self): - if isinstance(result, TokenHandler): - handlers.append(result) - elif isinstance(result, list): - handlers.extend(h for h in result if isinstance(h, TokenHandler)) - return handlers - - async def create_token( - self, - actor_id: str, - *, - expires_after: int | None = None, - restrictions: "TokenRestrictions | None" = None, - handler: str | None = None, - ) -> str: - """ - Create an API token for the given actor. - - Uses the first registered token handler by default, or a specific - handler if ``handler`` is provided (matched by handler name). - - Pass a :class:`TokenRestrictions` to limit which actions the token - can perform. - """ - handlers = self._token_handlers() - if not handlers: - raise RuntimeError("No token handlers are registered") - - if handler is not None: - matched = [h for h in handlers if h.name == handler] - if not matched: - available = [h.name for h in handlers] - raise ValueError( - f"Token handler {handler!r} not found. " - f"Available handlers: {available}" - ) - chosen = matched[0] - else: - chosen = handlers[0] - - return await chosen.create_token( - self, - actor_id, - expires_after=expires_after, - restrictions=restrictions, - ) - - async def verify_token(self, token: str) -> dict | None: - """ - Verify an API token by trying all registered token handlers. - - Returns an actor dict from the first handler that recognizes the - token, or None if no handler accepts it. - """ - for token_handler in self._token_handlers(): - result = await token_handler.verify_token(self, token) - if result is not None: - return result - return None - - def get_database(self, name=None, route=None): - if route is not None: - matches = [db for db in self.databases.values() if db.route == route] - if not matches: - raise KeyError - return matches[0] - if name is None: - name = [key for key in self.databases.keys()][0] - return self.databases[name] - - def add_database(self, db, name=None, route=None): - new_databases = self.databases.copy() - if name is None: - # Pick a unique name for this database - suggestion = db.suggest_name() - name = suggestion - else: - suggestion = name - i = 2 - while name in self.databases: - name = "{}_{}".format(suggestion, i) - i += 1 - db.name = name - db.route = route or name - new_databases[name] = db - # don't mutate! that causes race conditions with live import - self.databases = new_databases - return db - - def add_memory_database(self, memory_name, name=None, route=None): - return self.add_database( - Database(self, memory_name=memory_name), name=name, route=route - ) - - def remove_database(self, name): - self.get_database(name).close() - new_databases = self.databases.copy() - new_databases.pop(name) - self.databases = new_databases - - def close(self): - """Release all resources held by this Datasette instance. - - Closes every attached Database (including the internal database), - shuts down the executor, and unlinks the temporary file used for - the internal database if one was created. Idempotent and one-way. - """ - if self._closed: - return - self._closed = True - first_exception = None - dbs = list(self.databases.values()) + [self._internal_database] - for db in dbs: - try: - db.close() - except Exception as e: - if first_exception is None: - first_exception = e - if self.executor is not None: - try: - self.executor.shutdown(wait=True, cancel_futures=True) - except Exception as e: - if first_exception is None: - first_exception = e - if first_exception is not None: - raise first_exception - - def setting(self, key): - return self._settings.get(key, None) - - def settings_dict(self): - # Returns a fully resolved settings dictionary, useful for templates - return {option.name: self.setting(option.name) for option in SETTINGS} - - def _metadata_recursive_update(self, orig, updated): - if not isinstance(orig, dict) or not isinstance(updated, dict): - return orig - - for key, upd_value in updated.items(): - if isinstance(upd_value, dict) and isinstance(orig.get(key), dict): - orig[key] = self._metadata_recursive_update(orig[key], upd_value) - else: - orig[key] = upd_value - return orig - - async def get_instance_metadata(self): - rows = await self.get_internal_database().execute(""" - SELECT - key, - value - FROM metadata_instance - """) - return dict(rows) - - async def get_database_metadata(self, database_name: str): - rows = await self.get_internal_database().execute( - """ - SELECT - key, - value - FROM metadata_databases - WHERE database_name = ? - """, - [database_name], - ) - return dict(rows) - - async def get_resource_metadata(self, database_name: str, resource_name: str): - rows = await self.get_internal_database().execute( - """ - SELECT - key, - value - FROM metadata_resources - WHERE database_name = ? - AND resource_name = ? - """, - [database_name, resource_name], - ) - return dict(rows) - - async def get_column_metadata( - self, database_name: str, resource_name: str, column_name: str - ): - rows = await self.get_internal_database().execute( - """ - SELECT - key, - value - FROM metadata_columns - WHERE database_name = ? - AND resource_name = ? - AND column_name = ? - """, - [database_name, resource_name, column_name], - ) - return dict(rows) - - async def set_instance_metadata(self, key: str, value: str): - # TODO upsert only supported on SQLite 3.24.0 (2018-06-04) - await self.get_internal_database().execute_write( - """ - INSERT INTO metadata_instance(key, value) - VALUES(?, ?) - ON CONFLICT(key) DO UPDATE SET value = excluded.value; - """, - [key, value], - ) - - async def set_database_metadata(self, database_name: str, key: str, value: str): - # TODO upsert only supported on SQLite 3.24.0 (2018-06-04) - await self.get_internal_database().execute_write( - """ - INSERT INTO metadata_databases(database_name, key, value) - VALUES(?, ?, ?) - ON CONFLICT(database_name, key) DO UPDATE SET value = excluded.value; - """, - [database_name, key, value], - ) - - async def set_resource_metadata( - self, database_name: str, resource_name: str, key: str, value: str - ): - # TODO upsert only supported on SQLite 3.24.0 (2018-06-04) - await self.get_internal_database().execute_write( - """ - INSERT INTO metadata_resources(database_name, resource_name, key, value) - VALUES(?, ?, ?, ?) - ON CONFLICT(database_name, resource_name, key) DO UPDATE SET value = excluded.value; - """, - [database_name, resource_name, key, value], - ) - - async def set_column_metadata( - self, - database_name: str, - resource_name: str, - column_name: str, - key: str, - value: str, - ): - # TODO upsert only supported on SQLite 3.24.0 (2018-06-04) - await self.get_internal_database().execute_write( - """ - INSERT INTO metadata_columns(database_name, resource_name, column_name, key, value) - VALUES(?, ?, ?, ?, ?) - ON CONFLICT(database_name, resource_name, column_name, key) DO UPDATE SET value = excluded.value; - """, - [database_name, resource_name, column_name, key, value], - ) - - @staticmethod - def _query_row_to_stored_query(row) -> stored_queries.StoredQuery | None: - return stored_queries.query_row_to_stored_query(row) - - @staticmethod - def _query_options_json(options): - return stored_queries.query_options_json(options) - - async def add_query( - self, - database: str, - name: str, - sql: str, - *, - title: str | None = None, - description: str | None = None, - description_html: str | None = None, - hide_sql: bool = False, - fragment: str | None = None, - parameters: Iterable[str] | None = None, - is_write: bool = False, - is_private: bool = False, - is_trusted: bool = False, - source: str = "plugin", - owner_id: str | None = None, - on_success_message: str | None = None, - on_success_message_sql: str | None = None, - on_success_redirect: str | None = None, - on_error_message: str | None = None, - on_error_redirect: str | None = None, - replace: bool = True, - ) -> None: - return await stored_queries.add_query( - self, - database, - name, - sql, - title=title, - description=description, - description_html=description_html, - hide_sql=hide_sql, - fragment=fragment, - parameters=parameters, - is_write=is_write, - is_private=is_private, - is_trusted=is_trusted, - source=source, - owner_id=owner_id, - on_success_message=on_success_message, - on_success_message_sql=on_success_message_sql, - on_success_redirect=on_success_redirect, - on_error_message=on_error_message, - on_error_redirect=on_error_redirect, - replace=replace, - ) - - async def update_query( - self, - database: str, - name: str, - *, - sql=stored_queries.UNCHANGED, - title=stored_queries.UNCHANGED, - description=stored_queries.UNCHANGED, - description_html=stored_queries.UNCHANGED, - hide_sql=stored_queries.UNCHANGED, - fragment=stored_queries.UNCHANGED, - parameters=stored_queries.UNCHANGED, - is_write=stored_queries.UNCHANGED, - is_private=stored_queries.UNCHANGED, - is_trusted=stored_queries.UNCHANGED, - source=stored_queries.UNCHANGED, - owner_id=stored_queries.UNCHANGED, - on_success_message=stored_queries.UNCHANGED, - on_success_message_sql=stored_queries.UNCHANGED, - on_success_redirect=stored_queries.UNCHANGED, - on_error_message=stored_queries.UNCHANGED, - on_error_redirect=stored_queries.UNCHANGED, - ) -> None: - return await stored_queries.update_query( - self, - database, - name, - sql=sql, - title=title, - description=description, - description_html=description_html, - hide_sql=hide_sql, - fragment=fragment, - parameters=parameters, - is_write=is_write, - is_private=is_private, - is_trusted=is_trusted, - source=source, - owner_id=owner_id, - on_success_message=on_success_message, - on_success_message_sql=on_success_message_sql, - on_success_redirect=on_success_redirect, - on_error_message=on_error_message, - on_error_redirect=on_error_redirect, - ) - - async def remove_query( - self, database: str, name: str, source: str | None = None - ) -> None: - return await stored_queries.remove_query(self, database, name, source=source) - - async def get_query( - self, database: str, name: str - ) -> stored_queries.StoredQuery | None: - return await stored_queries.get_query(self, database, name) - - async def count_queries( - self, - database: str | None = None, - *, - actor: dict[str, Any] | None = None, - q: str | None = None, - is_write: bool | None = None, - is_private: bool | None = None, - is_trusted: bool | None = None, - source: str | None = None, - owner_id: str | None = None, - ) -> int: - return await stored_queries.count_queries( - self, - database, - actor=actor, - q=q, - is_write=is_write, - is_private=is_private, - is_trusted=is_trusted, - source=source, - owner_id=owner_id, - ) - - async def list_queries( - self, - database: str | None = None, - *, - actor: dict[str, Any] | None = None, - limit: int = 50, - cursor: str | None = None, - q: str | None = None, - is_write: bool | None = None, - is_private: bool | None = None, - is_trusted: bool | None = None, - source: str | None = None, - owner_id: str | None = None, - include_private: bool = False, - ) -> stored_queries.StoredQueryPage: - return await stored_queries.list_queries( - self, - database, - actor=actor, - limit=limit, - cursor=cursor, - q=q, - is_write=is_write, - is_private=is_private, - is_trusted=is_trusted, - source=source, - owner_id=owner_id, - include_private=include_private, - ) - - async def ensure_query_write_permissions( - self, database, sql, *, actor=None, params=None, analysis=None - ): - return await stored_queries.ensure_query_write_permissions( - self, database, sql, actor=actor, params=params, analysis=analysis - ) - - # Column types API - - async def _get_resource_column_details(self, database: str, resource: str): - db = self.databases.get(database) - if db is None: - return {} - try: - return { - column.name: column - for column in await db.table_column_details(resource) - } - except sqlite3.OperationalError: - return {} - - @staticmethod - def _column_type_is_applicable(ct_cls, column_detail) -> bool: - sqlite_types = getattr(ct_cls, "sqlite_types", None) - if sqlite_types is None: - return True - if column_detail is None: - return False - actual_sqlite_type = SQLiteType.from_declared_type(column_detail.type) - return actual_sqlite_type in sqlite_types - - async def _validate_column_type_assignment( - self, database: str, resource: str, column: str, ct_cls - ) -> None: - sqlite_types = getattr(ct_cls, "sqlite_types", None) - if sqlite_types is None: - return - - column_detail = ( - await self._get_resource_column_details(database, resource) - ).get(column) - if column_detail is None: - return - - actual_sqlite_type = SQLiteType.from_declared_type(column_detail.type) - if actual_sqlite_type in sqlite_types: - return - - allowed = ", ".join(sqlite_type.value for sqlite_type in sqlite_types) - actual = ( - actual_sqlite_type.value - if actual_sqlite_type is not None - else "unrecognized {!r}".format(column_detail.type) - ) - raise ValueError( - "Column type {!r} is only applicable to SQLite types {} but {}.{}.{} " - "has SQLite type {}".format( - ct_cls.name, - allowed, - database, - resource, - column, - actual, - ) - ) - - async def _apply_column_types_config(self): - """Load column_types from datasette.json config into the internal DB.""" - import logging - - for db_name, db_conf in (self.config or {}).get("databases", {}).items(): - for table_name, table_conf in db_conf.get("tables", {}).items(): - for col_name, ct in table_conf.get("column_types", {}).items(): - if isinstance(ct, str): - col_type, config = ct, None - else: - col_type = ct["type"] - config = ct.get("config") - if col_type not in self._column_types: - logging.warning( - "column_types config references unknown type %r " - "for %s.%s.%s", - col_type, - db_name, - table_name, - col_name, - ) - try: - await self.set_column_type( - db_name, table_name, col_name, col_type, config - ) - except ValueError as ex: - logging.warning(str(ex)) - - async def get_column_type(self, database: str, resource: str, column: str): - """ - Return a ColumnType instance (with config baked in) for a specific - column, or None if no column type is assigned. - """ - row = await self.get_internal_database().execute( - "SELECT column_type, config FROM column_types " - "WHERE database_name = ? AND resource_name = ? AND column_name = ?", - [database, resource, column], - ) - rows = row.rows - if not rows: - return None - ct_name, config = rows[0] - ct_cls = self._column_types.get(ct_name) - if ct_cls is None: - return None - column_detail = ( - await self._get_resource_column_details(database, resource) - ).get(column) - if not self._column_type_is_applicable(ct_cls, column_detail): - return None - return ct_cls(config=json.loads(config) if config else None) - - async def get_column_types(self, database: str, resource: str) -> dict: - """ - Return {column_name: ColumnType instance (with config)} - for all columns with assigned types on the given resource. - """ - rows = await self.get_internal_database().execute( - "SELECT column_name, column_type, config FROM column_types " - "WHERE database_name = ? AND resource_name = ?", - [database, resource], - ) - column_details = await self._get_resource_column_details(database, resource) - result = {} - for row in rows.rows: - col_name, ct_name, config = row - ct_cls = self._column_types.get(ct_name) - if ct_cls is not None and self._column_type_is_applicable( - ct_cls, column_details.get(col_name) - ): - result[col_name] = ct_cls(config=json.loads(config) if config else None) - return result - - async def set_column_type( - self, - database: str, - resource: str, - column: str, - column_type: str, - config: dict = None, - ) -> None: - """Assign a column type. Overwrites any existing assignment.""" - ct_cls = self._column_types.get(column_type) - if ct_cls is not None: - await self._validate_column_type_assignment( - database, resource, column, ct_cls - ) - await self.get_internal_database().execute_write( - """INSERT OR REPLACE INTO column_types - (database_name, resource_name, column_name, column_type, config) - VALUES (?, ?, ?, ?, ?)""", - [ - database, - resource, - column, - column_type, - json.dumps(config) if config else None, - ], - ) - - async def remove_column_type( - self, database: str, resource: str, column: str - ) -> None: - """Remove a column type assignment.""" - await self.get_internal_database().execute_write( - "DELETE FROM column_types " - "WHERE database_name = ? AND resource_name = ? AND column_name = ?", - [database, resource, column], - ) - - def get_internal_database(self): - return self._internal_database - - def plugin_config(self, plugin_name, database=None, table=None, fallback=True): - """Return config for plugin, falling back from specified database/table""" - if database is None and table is None: - config = self._plugin_config_top(plugin_name) - else: - config = self._plugin_config_nested(plugin_name, database, table, fallback) - - return resolve_env_secrets(config, os.environ) - - def _plugin_config_top(self, plugin_name): - """Returns any top-level plugin configuration for the specified plugin.""" - return ((self.config or {}).get("plugins") or {}).get(plugin_name) - - def _plugin_config_nested(self, plugin_name, database, table=None, fallback=True): - """Returns any database or table-level plugin configuration for the specified plugin.""" - db_config = ((self.config or {}).get("databases") or {}).get(database) - - # if there's no db-level configuration, then return early, falling back to top-level if needed - if not db_config: - return self._plugin_config_top(plugin_name) if fallback else None - - db_plugin_config = (db_config.get("plugins") or {}).get(plugin_name) - - if table: - table_plugin_config = ( - ((db_config.get("tables") or {}).get(table) or {}).get("plugins") or {} - ).get(plugin_name) - - # fallback to db_config or top-level config, in that order, if needed - if table_plugin_config is None and fallback: - return db_plugin_config or self._plugin_config_top(plugin_name) - - return table_plugin_config - - # fallback to top-level if needed - if db_plugin_config is None and fallback: - self._plugin_config_top(plugin_name) - - return db_plugin_config - - def static_hash(self, filename): - if not hasattr(self, "_static_hashes"): - self._static_hashes = {} - path = os.path.join(str(app_root), "datasette/static", filename) - signature = (os.path.getmtime(path), os.path.getsize(path)) - cached = self._static_hashes.get(filename) - if cached and cached["signature"] == signature: - return cached["hash"] - with open(path) as fp: - static_hash = hashlib.sha1(fp.read().encode("utf8")).hexdigest()[:6] - self._static_hashes[filename] = { - "signature": signature, - "hash": static_hash, - } - return static_hash - def app_css_hash(self): - return self.static_hash("app.css") + if not hasattr(self, "_app_css_hash"): + self._app_css_hash = hashlib.sha1( + open( + os.path.join(str(app_root), "datasette/static/app.css") + ).read().encode( + "utf8" + ) + ).hexdigest()[ + :6 + ] + return self._app_css_hash - def _prepare_connection(self, conn, database): + def get_canned_query(self, database_name, query_name): + query = self.metadata.get("databases", {}).get(database_name, {}).get( + "queries", {} + ).get( + query_name + ) + if query: + return {"name": query_name, "sql": query} + + async def get_table_definition(self, database_name, table, type_="table"): + table_definition_rows = list( + await self.execute( + database_name, + 'select sql from sqlite_master where name = :n and type=:t', + {"n": table, "t": type_}, + ) + ) + if not table_definition_rows: + return None + return table_definition_rows[0][0] + + def get_view_definition(self, database_name, view): + return self.get_table_definition(database_name, view, 'view') + + def asset_urls(self, key): + # Flatten list-of-lists from plugins: + seen_urls = set() + for url_or_dict in itertools.chain( + itertools.chain.from_iterable(getattr(pm.hook, key)()), + (self.metadata.get(key) or []) + ): + if isinstance(url_or_dict, dict): + url = url_or_dict["url"] + sri = url_or_dict.get("sri") + else: + url = url_or_dict + sri = None + if url in seen_urls: + continue + seen_urls.add(url) + if sri: + yield {"url": url, "sri": sri} + else: + yield {"url": url} + + def extra_css_urls(self): + return self.asset_urls("extra_css_urls") + + def extra_js_urls(self): + return self.asset_urls("extra_js_urls") + + def update_with_inherited_metadata(self, metadata): + # Fills in source/license with defaults, if available + metadata.update( + { + "source": metadata.get("source") or self.metadata.get("source"), + "source_url": metadata.get("source_url") + or self.metadata.get("source_url"), + "license": metadata.get("license") or self.metadata.get("license"), + "license_url": metadata.get("license_url") + or self.metadata.get("license_url"), + } + ) + + def prepare_connection(self, conn): conn.row_factory = sqlite3.Row conn.text_factory = lambda x: str(x, "utf-8", "replace") - if self.sqlite_extensions and database != INTERNAL_DB_NAME: + for name, num_args, func in self.sqlite_functions: + conn.create_function(name, num_args, func) + if self.sqlite_extensions: conn.enable_load_extension(True) for extension in self.sqlite_extensions: - # "extension" is either a string path to the extension - # or a 2-item tuple that specifies which entrypoint to load. - if isinstance(extension, tuple): - path, entrypoint = extension - conn.execute("SELECT load_extension(?, ?)", [path, entrypoint]) - else: - conn.execute("SELECT load_extension(?)", [extension]) - if self.setting("cache_size_kb"): - conn.execute(f"PRAGMA cache_size=-{self.setting('cache_size_kb')}") - # pylint: disable=no-member - if database != INTERNAL_DB_NAME: - pm.hook.prepare_connection(conn=conn, database=database, datasette=self) - # If self.crossdb and this is _memory, connect the first SQLITE_LIMIT_ATTACHED databases - if self.crossdb and database == "_memory": - count = 0 - for db_name, db in self.databases.items(): - if count >= SQLITE_LIMIT_ATTACHED or db.is_memory: - continue - sql = 'ATTACH DATABASE "file:{path}?{qs}" AS [{name}];'.format( - path=db.path, - qs="mode=ro" if db.is_mutable else "immutable=1", - name=db_name, - ) - conn.execute(sql) - count += 1 - - def add_message(self, request, message, type=INFO): - if not hasattr(request, "_messages"): - request._messages = [] - request._messages_should_clear = False - request._messages.append((message, type)) - - def _write_messages_to_response(self, request, response): - if getattr(request, "_messages", None): - # Set those messages - response.set_cookie("ds_messages", self.sign(request._messages, "messages")) - elif getattr(request, "_messages_should_clear", False): - response.set_cookie("ds_messages", "", expires=0, max_age=0) - - def _show_messages(self, request): - if getattr(request, "_messages", None): - request._messages_should_clear = True - messages = request._messages - request._messages = [] - return messages - else: - return [] - - async def _crumb_items(self, request, table=None, database=None): - crumbs = [] - actor = None - if request: - actor = request.actor - # Top-level link - if await self.allowed(action="view-instance", actor=actor): - crumbs.append({"href": self.urls.instance(), "label": "home"}) - # Database link - if database: - if await self.allowed( - action="view-database", - resource=DatabaseResource(database=database), - actor=actor, - ): - crumbs.append( - { - "href": self.urls.database(database), - "label": database, - } - ) - # Table link - if table: - assert database, "table= requires database=" - if await self.allowed( - action="view-table", - resource=TableResource(database=database, table=table), - actor=actor, - ): - crumbs.append( - { - "href": self.urls.table(database, table), - "label": table, - } - ) - return crumbs - - async def actors_from_ids( - self, actor_ids: Iterable[str | int] - ) -> Dict[int | str, Dict]: - result = pm.hook.actors_from_ids(datasette=self, actor_ids=actor_ids) - if result is None: - # Do the default thing - return {actor_id: {"id": actor_id} for actor_id in actor_ids} - result = await await_me_maybe(result) - return result - - async def track_event(self, event: Event): - assert isinstance(event, self.event_classes), "Invalid event type: {}".format( - type(event) - ) - for hook in pm.hook.track_event(datasette=self, event=event): - await await_me_maybe(hook) - - def resource_for_action(self, action: str, parent: str | None, child: str | None): - """ - Create a Resource instance for the given action with parent/child values. - - Looks up the action's resource_class and instantiates it with the - provided parent and child identifiers. - - Args: - action: The action name (e.g., "view-table", "view-query") - parent: The parent resource identifier (e.g., database name) - child: The child resource identifier (e.g., table/query name) - - Returns: - A Resource instance of the appropriate subclass - - Raises: - ValueError: If the action is unknown - """ - from datasette.permissions import Resource - - action_obj = self.actions.get(action) - if not action_obj: - raise ValueError(f"Unknown action: {action}") - - resource_class = action_obj.resource_class - instance = object.__new__(resource_class) - Resource.__init__(instance, parent=parent, child=child) - return instance - - async def check_visibility( - self, - actor: dict, - action: str, - resource: "Resource" | None = None, - ): - """ - Check if actor can see a resource and if it's private. - - Returns (visible, private) tuple: - - visible: bool - can the actor see it? - - private: bool - if visible, can anonymous users NOT see it? - """ - from datasette.permissions import Resource - - # Validate that resource is a Resource object or None - if resource is not None and not isinstance(resource, Resource): - raise TypeError("resource must be a Resource subclass instance or None.") - - # Check if actor can see it - if not await self.allowed(action=action, resource=resource, actor=actor): - return False, False - - # Check if anonymous user can see it (for "private" flag) - if not await self.allowed(action=action, resource=resource, actor=None): - # Actor can see it but anonymous cannot - it's private - return True, True - - # Both actor and anonymous can see it - it's public - return True, False - - async def allowed_resources_sql( - self, - *, - action: str, - actor: dict | None = None, - parent: str | None = None, - include_is_private: bool = False, - ) -> ResourcesSQL: - """ - Build SQL query to get all resources the actor can access for the given action. - - Args: - action: The action name (e.g., "view-table") - actor: The actor dict (or None for unauthenticated) - parent: Optional parent filter (e.g., database name) to limit results - include_is_private: If True, include is_private column showing if anonymous cannot access - - Returns a namedtuple of (query: str, params: dict) that can be executed against the internal database. - The query returns rows with (parent, child, reason) columns, plus is_private if requested. - - Example: - query, params = await datasette.allowed_resources_sql( - action="view-table", - actor=actor, - parent="mydb", - include_is_private=True - ) - result = await datasette.get_internal_database().execute(query, params) - """ - from datasette.utils.actions_sql import build_allowed_resources_sql - - action_obj = self.actions.get(action) - if not action_obj: - raise ValueError(f"Unknown action: {action}") - - sql, params = await build_allowed_resources_sql( - self, actor, action, parent=parent, include_is_private=include_is_private - ) - return ResourcesSQL(sql, params) - - async def allowed_resources( - self, - action: str, - actor: dict | None = None, - *, - parent: str | None = None, - include_is_private: bool = False, - include_reasons: bool = False, - limit: int = 100, - next: str | None = None, - ) -> PaginatedResources: - """ - Return paginated resources the actor can access for the given action. - - Uses SQL with keyset pagination to efficiently filter resources. - Returns PaginatedResources with list of Resource instances and pagination metadata. - - Args: - action: The action name (e.g., "view-table") - actor: The actor dict (or None for unauthenticated) - parent: Optional parent filter (e.g., database name) to limit results - include_is_private: If True, adds a .private attribute to each Resource - include_reasons: If True, adds a .reasons attribute with List[str] of permission reasons - limit: Maximum number of results to return (1-1000, default 100) - next: Keyset token from previous page for pagination - - Returns: - PaginatedResources with: - - resources: List of Resource objects for this page - - next: Token for next page (None if no more results) - - Example: - # Get first page of tables - page = await datasette.allowed_resources("view-table", actor, limit=50) - for table in page.resources: - print(f"{table.parent}/{table.child}") - - # Get next page - if page.next: - next_page = await datasette.allowed_resources( - "view-table", actor, limit=50, next=page.next - ) - - # With reasons for debugging - page = await datasette.allowed_resources( - "view-table", actor, include_reasons=True - ) - for table in page.resources: - print(f"{table.child}: {table.reasons}") - - # Iterate through all results with async generator - page = await datasette.allowed_resources("view-table", actor) - async for table in page.all(): - print(table.child) - """ - - action_obj = self.actions.get(action) - if not action_obj: - raise ValueError(f"Unknown action: {action}") - - # Validate and cap limit - limit = min(max(1, limit), 1000) - - # Get base SQL query - query, params = await self.allowed_resources_sql( - action=action, - actor=actor, - parent=parent, - include_is_private=include_is_private, - ) - - # Add keyset pagination WHERE clause if next token provided - if next: - try: - components = urlsafe_components(next) - if len(components) >= 2: - last_parent, last_child = components[0], components[1] - # Keyset condition: (parent > last) OR (parent = last AND child > last) - keyset_where = """ - (parent > :keyset_parent OR - (parent = :keyset_parent AND child > :keyset_child)) - """ - # Wrap original query and add keyset filter - query = f"SELECT * FROM ({query}) WHERE {keyset_where}" - params["keyset_parent"] = last_parent - params["keyset_child"] = last_child - except (ValueError, KeyError): - # Invalid token - ignore and start from beginning - pass - - # Add LIMIT (fetch limit+1 to detect if there are more results) - # Note: query from allowed_resources_sql() already includes ORDER BY parent, child - query = f"{query} LIMIT :limit" - params["limit"] = limit + 1 - - # Execute query - result = await self.get_internal_database().execute(query, params) - rows = list(result.rows) - - # Check if truncated (got more than limit rows) - truncated = len(rows) > limit - if truncated: - rows = rows[:limit] # Remove the extra row - - # Build Resource objects with optional attributes - resources = [] - for row in rows: - # row[0]=parent, row[1]=child, row[2]=reason, row[3]=is_private (if requested) - resource = self.resource_for_action(action, parent=row[0], child=row[1]) - - # Add reasons if requested - if include_reasons: - reason_json = row[2] - try: - reasons_array = ( - json.loads(reason_json) if isinstance(reason_json, str) else [] - ) - resource.reasons = [r for r in reasons_array if r is not None] - except (json.JSONDecodeError, TypeError): - resource.reasons = [reason_json] if reason_json else [] - - # Add private flag if requested - if include_is_private: - resource.private = bool(row[3]) - - resources.append(resource) - - # Generate next token if there are more results - next_token = None - if truncated and resources: - last_resource = resources[-1] - # Use tilde-encoding like table pagination - next_token = "{},{}".format( - tilde_encode(str(last_resource.parent)), - tilde_encode(str(last_resource.child)), - ) - - return PaginatedResources( - resources=resources, - next=next_token, - _datasette=self, - _action=action, - _actor=actor, - _parent=parent, - _include_is_private=include_is_private, - _include_reasons=include_reasons, - _limit=limit, - ) - - async def allowed( - self, - *, - action: str, - resource: "Resource" = None, - actor: dict | None = None, - ) -> bool: - """ - Check if actor can perform action on specific resource. - - Uses SQL to check permission for a single resource without fetching all resources. - This is efficient - it does NOT call allowed_resources() and check membership. - - For global actions, resource should be None (or omitted). - - Example: - from datasette.resources import TableResource - can_view = await datasette.allowed( - action="view-table", - resource=TableResource(database="analytics", table="users"), - actor=actor - ) - - # For global actions, resource can be omitted: - can_debug = await datasette.allowed(action="permissions-debug", actor=actor) - """ - from datasette.utils.actions_sql import check_permission_for_resource - - # For global actions, resource remains None - - # Check if this action has also_requires - if so, check that action first - action_obj = self.actions.get(action) - if action_obj and action_obj.also_requires: - # Must have the required action first - if not await self.allowed( - action=action_obj.also_requires, - resource=resource, - actor=actor, - ): - return False - - # For global actions, resource is None - parent = resource.parent if resource else None - child = resource.child if resource else None - - result = await check_permission_for_resource( - datasette=self, - actor=actor, - action=action, - parent=parent, - child=child, - ) - - # Log the permission check for debugging - self._permission_checks.append( - PermissionCheck( - when=datetime.datetime.now(datetime.timezone.utc).isoformat(), - actor=actor, - action=action, - parent=parent, - child=child, - result=result, - ) - ) - - return result - - async def ensure_permission( - self, - *, - action: str, - resource: "Resource" = None, - actor: dict | None = None, - ): - """ - Check if actor can perform action on resource, raising Forbidden if not. - - This is a convenience wrapper around allowed() that raises Forbidden - instead of returning False. Use this when you want to enforce a permission - check and halt execution if it fails. - - Example: - from datasette.resources import TableResource - - # Will raise Forbidden if actor cannot view the table - await datasette.ensure_permission( - action="view-table", - resource=TableResource(database="analytics", table="users"), - actor=request.actor - ) - - # For instance-level actions, resource can be omitted: - await datasette.ensure_permission( - action="permissions-debug", - actor=request.actor - ) - """ - if not await self.allowed(action=action, resource=resource, actor=actor): - raise Forbidden(action) - - async def execute( - self, - db_name, - sql, - params=None, - truncate=False, - custom_time_limit=None, - page_size=None, - log_sql_errors=True, - ): - return await self.databases[db_name].execute( - sql, - params=params, - truncate=truncate, - custom_time_limit=custom_time_limit, - page_size=page_size, - log_sql_errors=log_sql_errors, - ) - - async def expand_foreign_keys(self, actor, database, table, column, values): - """Returns dict mapping (column, value) -> label""" - labeled_fks = {} - db = self.databases[database] - foreign_keys = await db.foreign_keys_for_table(table) - # Find the foreign_key for this column - try: - fk = [ - foreign_key - for foreign_key in foreign_keys - if foreign_key["column"] == column - ][0] - except IndexError: - return {} - # Ensure user has permission to view the referenced table - from datasette.resources import TableResource - - other_table = fk["other_table"] - other_column = fk["other_column"] - visible, _ = await self.check_visibility( - actor, - action="view-table", - resource=TableResource(database=database, table=other_table), - ) - if not visible: - return {} - label_column = await db.label_column_for_table(other_table) - if not label_column: - return {(fk["column"], value): str(value) for value in values} - labeled_fks = {} - sql = """ - select {other_column}, {label_column} - from {other_table} - where {other_column} in ({placeholders}) - """.format( - other_column=escape_sqlite(other_column), - label_column=escape_sqlite(label_column), - other_table=escape_sqlite(other_table), - placeholders=", ".join(["?"] * len(set(values))), - ) - try: - results = await self.execute(database, sql, list(set(values))) - except QueryInterrupted: - pass - else: - for id, value in results: - labeled_fks[(fk["column"], id)] = value - return labeled_fks - - def absolute_url(self, request, path): - url = urllib.parse.urljoin(request.url, path) - if url.startswith("http://") and self.setting("force_https_urls"): - url = "https://" + url[len("http://") :] - return url - - def _connected_databases(self): - return [ - { - "name": d.name, - "route": d.route, - "path": d.path, - "size": d.size, - "is_mutable": d.is_mutable, - "is_memory": d.is_memory, - "hash": d.hash, - } - for name, d in self.databases.items() - ] - - def _versions(self): + conn.execute("SELECT load_extension('{}')".format(extension)) + if self.config["cache_size_kb"]: + conn.execute('PRAGMA cache_size=-{}'.format(self.config["cache_size_kb"])) + pm.hook.prepare_connection(conn=conn) + + def table_exists(self, database, table): + return table in self.inspect().get(database, {}).get("tables") + + def inspect(self): + " Inspect the database and return a dictionary of table metadata " + if self._inspect: + return self._inspect + + self._inspect = {} + for filename in self.files: + path = Path(filename) + name = path.stem + if name in self._inspect: + raise Exception("Multiple files with same stem %s" % name) + + with sqlite3.connect( + "file:{}?immutable=1".format(path), uri=True + ) as conn: + self.prepare_connection(conn) + self._inspect[name] = { + "hash": inspect_hash(path), + "file": str(path), + "views": inspect_views(conn), + "tables": inspect_tables(conn, self.metadata.get("databases", {}).get(name, {})) + } + return self._inspect + + def register_custom_units(self): + "Register any custom units defined in the metadata.json with Pint" + for unit in self.metadata.get("custom_units", []): + ureg.define(unit) + + def versions(self): conn = sqlite3.connect(":memory:") - self._prepare_connection(conn, "_memory") + self.prepare_connection(conn) sqlite_version = conn.execute("select sqlite_version()").fetchone()[0] - sqlite_extensions = {"json1": detect_json1(conn)} + sqlite_extensions = {} for extension, testsql, hasversion in ( + ("json1", "SELECT json('{}')", False), ("spatialite", "SELECT spatialite_version()", True), ): try: @@ -1991,19 +287,8 @@ class Datasette: sqlite_extensions[extension] = result.fetchone()[0] else: sqlite_extensions[extension] = None - except Exception: + except Exception as e: pass - # More details on SpatiaLite - if "spatialite" in sqlite_extensions: - spatialite_details = {} - for fn in SPATIALITE_FUNCTIONS: - try: - result = conn.execute("select {}()".format(fn)) - spatialite_details[fn] = result.fetchone()[0] - except Exception as e: - spatialite_details[fn] = {"error": str(e)} - sqlite_extensions["spatialite"] = spatialite_details - # Figure out supported FTS versions fts_versions = [] for fts in ("FTS5", "FTS4", "FTS3"): @@ -2014,1017 +299,237 @@ class Datasette: fts_versions.append(fts) except sqlite3.OperationalError: continue - datasette_version = {"version": __version__} - if self.version_note: - datasette_version["note"] = self.version_note - try: - # Optional import to avoid breaking Pyodide - # https://github.com/simonw/datasette/issues/1733#issuecomment-1115268245 - import uvicorn - - uvicorn_version = uvicorn.__version__ - except ImportError: - uvicorn_version = None - info = { + return { "python": { - "version": ".".join(map(str, sys.version_info[:3])), - "full": sys.version, + "version": ".".join(map(str, sys.version_info[:3])), "full": sys.version }, - "datasette": datasette_version, - "asgi": "3.0", - "uvicorn": uvicorn_version, + "datasette": {"version": __version__}, "sqlite": { "version": sqlite_version, "fts_versions": fts_versions, "extensions": sqlite_extensions, - "compile_options": [ - r[0] for r in conn.execute("pragma compile_options;").fetchall() - ], }, } - if using_pysqlite3: - for package in ("pysqlite3", "pysqlite3-binary"): - try: - info["pysqlite3"] = importlib.metadata.version(package) - break - except importlib.metadata.PackageNotFoundError: - pass - conn.close() - return info - def _plugins(self, request=None, all=False): - ps = list(get_plugins()) - should_show_all = False - if request is not None: - should_show_all = request.args.get("all") - else: - should_show_all = all - if not should_show_all: - ps = [p for p in ps if p["name"] not in DEFAULT_PLUGINS] - ps.sort(key=lambda p: p["name"]) + def plugins(self): return [ { "name": p["name"], "static": p["static_path"] is not None, "templates": p["templates_path"] is not None, "version": p.get("version"), - "hooks": list(sorted(set(p["hooks"]))), } - for p in ps + for p in get_plugins(pm) ] - def _threads(self): - if self.setting("num_sql_threads") == 0: - return {"num_threads": 0, "threads": []} - threads = list(threading.enumerate()) - d = { - "num_threads": len(threads), - "threads": [ - {"name": t.name, "ident": t.ident, "daemon": t.daemon} for t in threads - ], - } - tasks = asyncio.all_tasks() - d.update( - { - "num_tasks": len(tasks), - "tasks": [_cleaner_task_str(t) for t in tasks], - } - ) - return d - - def _actor(self, request): - return {"actor": request.actor} - - def _actions(self): - return [ - { - "name": action.name, - "abbr": action.abbr, - "description": action.description, - "takes_parent": action.takes_parent, - "takes_child": action.takes_child, - "resource_class": ( - action.resource_class.__name__ if action.resource_class else None - ), - "also_requires": action.also_requires, - } - for action in sorted(self.actions.values(), key=lambda a: a.name) - ] - - async def table_config(self, database: str, table: str) -> dict: - """Return dictionary of configuration for specified table""" - return ( - (self.config or {}) - .get("databases", {}) - .get(database, {}) - .get("tables", {}) - .get(table, {}) - ) - - def _register_renderers(self): - """Register output renderers which output data in custom formats.""" - # Built-in renderers - self.renderers["json"] = (json_renderer, lambda: True) - - # Hooks - hook_renderers = [] - # pylint: disable=no-member - for hook in pm.hook.register_output_renderer(datasette=self): - if type(hook) is list: - hook_renderers += hook - else: - hook_renderers.append(hook) - - for renderer in hook_renderers: - self.renderers[renderer["extension"]] = ( - # It used to be called "callback" - remove this in Datasette 1.0 - renderer.get("render") or renderer["callback"], - renderer.get("can_render") or (lambda: True), - ) - - async def render_template( + async def execute( self, - templates: List[str] | str | Template, - context: Dict[str, Any] | Context | None = None, - request: Request | None = None, - view_name: str | None = None, + db_name, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, ): - if not self._startup_invoked: - raise Exception("render_template() called before await ds.invoke_startup()") - context = context or {} - if isinstance(templates, Template): - template = templates - else: - if isinstance(templates, str): - templates = [templates] - template = self.get_jinja_environment(request).select_template(templates) - if dataclasses.is_dataclass(context): - context = dataclasses.asdict(context) - body_scripts = [] - # pylint: disable=no-member - for extra_script in pm.hook.extra_body_script( - template=template.name, - database=context.get("database"), - table=context.get("table"), - columns=context.get("columns"), - view_name=view_name, - request=request, - datasette=self, - ): - extra_script = await await_me_maybe(extra_script) - if isinstance(extra_script, dict): - script = extra_script["script"] - module = bool(extra_script.get("module")) - else: - script = extra_script - module = False - body_scripts.append({"script": Markup(script), "module": module}) + """Executes sql against db_name in a thread""" + page_size = page_size or self.page_size - extra_template_vars = {} - # pylint: disable=no-member - for extra_vars in pm.hook.extra_template_vars( - template=template.name, - database=context.get("database"), - table=context.get("table"), - columns=context.get("columns"), - view_name=view_name, - request=request, - datasette=self, - ): - extra_vars = await await_me_maybe(extra_vars) - assert isinstance(extra_vars, dict), "extra_vars is of type {}".format( - type(extra_vars) - ) - extra_template_vars.update(extra_vars) - - async def menu_links(): - links = [] - for hook in pm.hook.menu_links( - datasette=self, - actor=request.actor if request else None, - request=request or None, - ): - extra_links = await await_me_maybe(hook) - if extra_links: - links.extend(extra_links) - return links - - template_context = { - **context, - **{ - "request": request, - "crumb_items": self._crumb_items, - "urls": self.urls, - "actor": request.actor if request else None, - "menu_links": menu_links, - "display_actor": display_actor, - "show_logout": request is not None - and "ds_actor" in request.cookies - and request.actor, - "app_css_hash": self.app_css_hash(), - "zip": zip, - "body_scripts": body_scripts, - "format_bytes": format_bytes, - "show_messages": lambda: self._show_messages(request), - "extra_css_urls": await self._asset_urls( - "extra_css_urls", template, context, request, view_name - ), - "extra_js_urls": await self._asset_urls( - "extra_js_urls", template, context, request, view_name - ), - "base_url": self.setting("base_url"), - "csrftoken": ( - request.scope["csrftoken"] - if request and "csrftoken" in request.scope - else lambda: "" - ), - "datasette_version": __version__, - }, - **extra_template_vars, - } - if request and request.args.get("_context") and self.setting("template_debug"): - return "
{}
".format( - escape(json.dumps(template_context, default=repr, indent=4)) - ) - - return await template.render_async(template_context) - - def set_actor_cookie( - self, response: Response, actor: dict, expire_after: int | None = None - ): - data = {"a": actor} - if expire_after: - expires_at = int(time.time()) + (24 * 60 * 60) - data["e"] = baseconv.base62.encode(expires_at) - response.set_cookie("ds_actor", self.sign(data, "actor")) - - def delete_actor_cookie(self, response: Response): - response.set_cookie("ds_actor", "", expires=0, max_age=0) - - async def _asset_urls(self, key, template, context, request, view_name): - # Flatten list-of-lists from plugins: - seen_urls = set() - collected = [] - for hook in getattr(pm.hook, key)( - template=template.name, - database=context.get("database"), - table=context.get("table"), - columns=context.get("columns"), - view_name=view_name, - request=request, - datasette=self, - ): - hook = await await_me_maybe(hook) - collected.extend(hook) - collected.extend((self.config or {}).get(key) or []) - output = [] - for url_or_dict in collected: - if isinstance(url_or_dict, dict): - url = url_or_dict["url"] - sri = url_or_dict.get("sri") - module = bool(url_or_dict.get("module")) - else: - url = url_or_dict - sri = None - module = False - if url in seen_urls: - continue - seen_urls.add(url) - if url.startswith("/"): - # Take base_url into account: - url = self.urls.path(url) - script = {"url": url} - if sri: - script["sri"] = sri - if module: - script["module"] = True - output.append(script) - return output - - def _config(self): - return redact_keys( - self.config, ("secret", "key", "password", "token", "hash", "dsn") - ) - - def _routes(self): - routes = [] - - for routes_to_add in pm.hook.register_routes(datasette=self): - for regex, view_fn in routes_to_add: - routes.append((regex, wrap_view(view_fn, self))) - - def add_route(view, regex): - routes.append((regex, view)) - - add_route(IndexView.as_view(self), r"/(\.(?Pjsono?))?$") - add_route(IndexView.as_view(self), r"/-/(\.(?Pjsono?))?$") - add_route(permanent_redirect("/-/"), r"/-$") - # TODO: /favicon.ico and /-/static/ deserve far-future cache expires - add_route(favicon, "/favicon.ico") - - add_route( - asgi_static(app_root / "datasette" / "static"), r"/-/static/(?P.*)$" - ) - for path, dirname in self.static_mounts: - add_route(asgi_static(dirname), r"/" + path + "/(?P.*)$") - - # Mount any plugin static/ directories - for plugin in get_plugins(): - if plugin["static_path"]: - add_route( - asgi_static(plugin["static_path"]), - f"/-/static-plugins/{plugin['name']}/(?P.*)$", + def sql_operation_in_thread(): + conn = getattr(connections, db_name, None) + if not conn: + info = self.inspect()[db_name] + conn = sqlite3.connect( + "file:{}?immutable=1".format(info["file"]), + uri=True, + check_same_thread=False, ) - # Support underscores in name in addition to hyphens, see https://github.com/simonw/datasette/issues/611 - add_route( - asgi_static(plugin["static_path"]), - "/-/static-plugins/{}/(?P.*)$".format( - plugin["name"].replace("-", "_") - ), - ) - add_route( - permanent_redirect( - "/_memory", forward_query_string=True, forward_rest=True - ), - r"/:memory:(?P.*)$", - ) - add_route( - JsonDataView.as_view(self, "versions.json", self._versions), - r"/-/versions(\.(?Pjson))?$", - ) - add_route( - JsonDataView.as_view( - self, "plugins.json", self._plugins, needs_request=True - ), - r"/-/plugins(\.(?Pjson))?$", - ) - add_route( - JsonDataView.as_view(self, "settings.json", lambda: self._settings), - r"/-/settings(\.(?Pjson))?$", - ) - add_route( - JsonDataView.as_view(self, "config.json", lambda: self._config()), - r"/-/config(\.(?Pjson))?$", - ) - add_route( - JsonDataView.as_view(self, "threads.json", self._threads), - r"/-/threads(\.(?Pjson))?$", - ) - add_route( - JsonDataView.as_view(self, "databases.json", self._connected_databases), - r"/-/databases(\.(?Pjson))?$", - ) - add_route( - JsonDataView.as_view( - self, "actor.json", self._actor, needs_request=True, permission=None - ), - r"/-/actor(\.(?Pjson))?$", - ) - add_route( - JsonDataView.as_view( - self, - "actions.json", - self._actions, - template="debug_actions.html", - permission="permissions-debug", - ), - r"/-/actions(\.(?Pjson))?$", - ) - add_route( - AuthTokenView.as_view(self), - r"/-/auth-token$", - ) - add_route( - CreateTokenView.as_view(self), - r"/-/create-token$", - ) - add_route( - ApiExplorerView.as_view(self), - r"/-/api$", - ) - add_route( - JumpView.as_view(self), - r"/-/jump(\.(?Pjson))?$", - ) - add_route( - GlobalQueryListView.as_view(self), - r"/-/queries(\.(?Pjson))?$", - ) - add_route( - InstanceSchemaView.as_view(self), - r"/-/schema(\.(?Pjson|md))?$", - ) - add_route( - LogoutView.as_view(self), - r"/-/logout$", - ) - add_route( - PermissionsDebugView.as_view(self), - r"/-/permissions$", - ) - add_route( - AllowedResourcesView.as_view(self), - r"/-/allowed(\.(?Pjson))?$", - ) - add_route( - PermissionRulesView.as_view(self), - r"/-/rules(\.(?Pjson))?$", - ) - add_route( - PermissionCheckView.as_view(self), - r"/-/check(\.(?Pjson))?$", - ) - add_route( - MessagesDebugView.as_view(self), - r"/-/messages$", - ) - add_route( - AllowDebugView.as_view(self), - r"/-/allow-debug$", - ) - add_route( - wrap_view(PatternPortfolioView, self), - r"/-/patterns$", - ) - add_route( - wrap_view(database_download, self), - r"/(?P[^\/\.]+)\.db$", - ) - add_route( - wrap_view(DatabaseView, self), - r"/(?P[^\/\.]+)(\.(?P\w+))?$", - ) - add_route(TableCreateView.as_view(self), r"/(?P[^\/\.]+)/-/create$") - add_route( - QueryListView.as_view(self), - r"/(?P[^\/\.]+)/-/queries(\.(?Pjson))?$", - ) - add_route( - QueryCreateAnalyzeView.as_view(self), - r"/(?P[^\/\.]+)/-/queries/analyze$", - ) - add_route( - QueryStoreView.as_view(self), - r"/(?P[^\/\.]+)/-/queries/store$", - ) - add_route( - ExecuteWriteAnalyzeView.as_view(self), - r"/(?P[^\/\.]+)/-/execute-write/analyze$", - ) - add_route( - ExecuteWriteView.as_view(self), - r"/(?P[^\/\.]+)/-/execute-write$", - ) - add_route( - DatabaseSchemaView.as_view(self), - r"/(?P[^\/\.]+)/-/schema(\.(?Pjson|md))?$", - ) - add_route( - QueryParametersView.as_view(self), - r"/(?P[^\/\.]+)/-/query/parameters$", - ) - add_route( - wrap_view(QueryView, self), - r"/(?P[^\/\.]+)/-/query(\.(?P\w+))?$", - ) - add_route( - QueryDefinitionView.as_view(self), - r"/(?P[^\/\.]+)/(?P[^\/\.]+)/-/definition$", - ) - add_route( - QueryUpdateView.as_view(self), - r"/(?P[^\/\.]+)/(?P[^\/\.]+)/-/update$", - ) - add_route( - QueryDeleteView.as_view(self), - r"/(?P[^\/\.]+)/(?P[^\/\.]+)/-/delete$", - ) - add_route( - wrap_view(table_view, self), - r"/(?P[^\/\.]+)/(?P[^\/\.]+)(\.(?P\w+))?$", - ) - add_route( - RowView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^/]+?)/(?P[^/]+?)(\.(?P\w+))?$", - ) - add_route( - TableInsertView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^\/\.]+)/-/insert$", - ) - add_route( - TableUpsertView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^\/\.]+)/-/upsert$", - ) - add_route( - TableSetColumnTypeView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^\/\.]+)/-/set-column-type$", - ) - add_route( - TableDropView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^\/\.]+)/-/drop$", - ) - add_route( - TableSchemaView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^\/\.]+)/-/schema(\.(?Pjson|md))?$", - ) - add_route( - RowDeleteView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^/]+?)/(?P[^/]+?)/-/delete$", - ) - add_route( - RowUpdateView.as_view(self), - r"/(?P[^\/\.]+)/(?P
[^/]+?)/(?P[^/]+?)/-/update$", - ) - return [ - # Compile any strings to regular expressions - ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) - for pattern, view in routes - ] + self.prepare_connection(conn) + setattr(connections, db_name, conn) - async def resolve_database(self, request): - database_route = tilde_decode(request.url_vars["database"]) - try: - return self.get_database(route=database_route) - except KeyError: - raise DatabaseNotFound(database_route) + time_limit_ms = self.sql_time_limit_ms + if custom_time_limit and custom_time_limit < time_limit_ms: + time_limit_ms = custom_time_limit - async def resolve_table(self, request): - db = await self.resolve_database(request) - table_name = tilde_decode(request.url_vars["table"]) - # Table must exist - is_view = False - table_exists = await db.table_exists(table_name) - if not table_exists: - is_view = await db.view_exists(table_name) - if not (table_exists or is_view): - raise TableNotFound(db.name, table_name) - return ResolvedTable(db, table_name, is_view) - - async def resolve_row(self, request): - db, table_name, _ = await self.resolve_table(request) - pk_values = urlsafe_components(request.url_vars["pks"]) - sql, params, pks = await row_sql_params_pks(db, table_name, pk_values) - results = await db.execute(sql, params, truncate=True) - row = results.first() - if row is None: - raise RowNotFound(db.name, table_name, pk_values) - return ResolvedRow(db, table_name, sql, params, pks, pk_values, results.first()) - - def app(self): - """Returns an ASGI app function that serves the whole of Datasette""" - routes = self._routes() - - async def setup_db(): - # First time server starts up, calculate table counts for immutable databases - for database in self.databases.values(): - if not database.is_mutable: - await database.table_counts(limit=60 * 60 * 1000) - - async def _close_on_shutdown(): - self.close() - - asgi = CrossOriginProtectionMiddleware(DatasetteRouter(self, routes), self) - if self.setting("trace_debug"): - asgi = AsgiTracer(asgi) - asgi = AsgiLifespan(asgi, on_shutdown=[_close_on_shutdown]) - asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup]) - for wrapper in pm.hook.asgi_wrapper(datasette=self): - asgi = wrapper(asgi) - return asgi - - -class DatasetteRouter: - def __init__(self, datasette, routes): - self.ds = datasette - self.routes = routes or [] - - async def __call__(self, scope, receive, send): - # Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves - path = scope["path"] - raw_path = scope.get("raw_path") - if raw_path: - path = raw_path.decode("ascii") - path = path.partition("?")[0] - return await self.route_path(scope, receive, send, path) - - async def route_path(self, scope, receive, send, path): - # Strip off base_url if present before routing - base_url = self.ds.setting("base_url") - if base_url != "/" and path.startswith(base_url): - path = "/" + path[len(base_url) :] - scope = dict(scope, route_path=path) - request = Request(scope, receive) - # Populate request_messages if ds_messages cookie is present - try: - request._messages = self.ds.unsign( - request.cookies.get("ds_messages", ""), "messages" - ) - except BadSignature: - pass - - scope_modifications = {} - # Apply force_https_urls, if set - if ( - self.ds.setting("force_https_urls") - and scope["type"] == "http" - and scope.get("scheme") != "https" - ): - scope_modifications["scheme"] = "https" - # Handle authentication - default_actor = scope.get("actor") or None - actor = None - results = pm.hook.actor_from_request(datasette=self.ds, request=request) - for result in results: - result = await await_me_maybe(result) - if result and actor is None: - actor = result - # Don't break — we must await all coroutines to avoid - # "coroutine was never awaited" warnings - scope_modifications["actor"] = actor or default_actor - scope = dict(scope, **scope_modifications) - - match, view = resolve_routes(self.routes, path) - - if match is None: - return await self.handle_404(request, send) - - new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) - request.scope = new_scope - try: - response = await view(request, send) - if response: - self.ds._write_messages_to_response(request, response) - await response.asgi_send(send) - return - except NotFound as exception: - return await self.handle_404(request, send, exception) - except Forbidden as exception: - # Try the forbidden() plugin hook - for custom_response in pm.hook.forbidden( - datasette=self.ds, request=request, message=exception.args[0] - ): - custom_response = await await_me_maybe(custom_response) - assert ( - custom_response - ), "Default forbidden() hook should have been called" - return await custom_response.asgi_send(send) - except Exception as exception: - return await self.handle_exception(request, send, exception) - - async def handle_404(self, request, send, exception=None): - # If path contains % encoding, redirect to tilde encoding - if "%" in request.path: - # Try the same path but with "%" replaced by "~" - # and "~" replaced with "~7E" - # and "." replaced with "~2E" - new_path = ( - request.path.replace("~", "~7E").replace("%", "~").replace(".", "~2E") - ) - if request.query_string: - new_path += "?{}".format(request.query_string) - await asgi_send_redirect(send, new_path) - return - # If URL has a trailing slash, redirect to URL without it - path = request.scope.get( - "raw_path", request.scope["path"].encode("utf8") - ).partition(b"?")[0] - context = {} - if path.endswith(b"/"): - path = path.rstrip(b"/") - if request.scope["query_string"]: - path += b"?" + request.scope["query_string"] - await asgi_send_redirect(send, path.decode("latin1")) - else: - # Is there a pages/* template matching this path? - route_path = request.scope.get("route_path", request.scope["path"]) - # Jinja requires template names to use "/" even on Windows - template_name = "pages" + route_path + ".html" - # Build a list of pages/blah/{name}.html matching expressions - environment = self.ds.get_jinja_environment(request) - pattern_templates = [ - filepath - for filepath in environment.list_templates() - if "{" in filepath and filepath.startswith("pages/") - ] - page_routes = [ - (route_pattern_from_filepath(filepath[len("pages/") :]), filepath) - for filepath in pattern_templates - ] - try: - template = environment.select_template([template_name]) - except TemplateNotFound: - template = None - if template is None: - # Try for a pages/blah/{name}.html template match - for regex, wildcard_template in page_routes: - match = regex.match(route_path) - if match is not None: - context.update(match.groupdict()) - template = wildcard_template - break - - if template: - headers = {} - status = [200] - - def custom_header(name, value): - headers[name] = value - return "" - - def custom_status(code): - status[0] = code - return "" - - def custom_redirect(location, code=302): - status[0] = code - headers["Location"] = location - return "" - - def raise_404(message=""): - raise NotFoundExplicit(message) - - context.update( - { - "custom_header": custom_header, - "custom_status": custom_status, - "custom_redirect": custom_redirect, - "raise_404": raise_404, - } - ) + with sqlite_timelimit(conn, time_limit_ms): try: - body = await self.ds.render_template( - template, - context, - request=request, - view_name="page", - ) - except NotFoundExplicit as e: - await self.handle_exception(request, send, e) - return - # Pull content-type out into separate parameter - content_type = "text/html; charset=utf-8" - matches = [k for k in headers if k.lower() == "content-type"] - if matches: - content_type = headers[matches[0]] - await asgi_send( - send, - body, - status=status[0], - headers=headers, - content_type=content_type, - ) - else: - await self.handle_exception(request, send, exception or NotFound("404")) - - async def handle_exception(self, request, send, exception): - responses = [] - for hook in pm.hook.handle_exception( - datasette=self.ds, - request=request, - exception=exception, - ): - response = await await_me_maybe(hook) - if response is not None: - responses.append(response) - - assert responses, "Default exception handler should have returned something" - # Even if there are multiple responses use just the first one - response = responses[0] - await response.asgi_send(send) - - -_cleaner_task_str_re = re.compile(r"\S*site-packages/") - - -def _cleaner_task_str(task): - s = str(task) - # This has something like the following in it: - # running at /Users/simonw/Dropbox/Development/datasette/venv-3.7.5/lib/python3.7/site-packages/uvicorn/main.py:361> - # Clean up everything up to and including site-packages - return _cleaner_task_str_re.sub("", s) - - -def wrap_view(view_fn_or_class, datasette): - is_function = isinstance(view_fn_or_class, types.FunctionType) - if is_function: - return wrap_view_function(view_fn_or_class, datasette) - else: - if not isinstance(view_fn_or_class, type): - raise ValueError("view_fn_or_class must be a function or a class") - return wrap_view_class(view_fn_or_class, datasette) - - -def wrap_view_class(view_class, datasette): - async def async_view_for_class(request, send): - instance = view_class() - if inspect.iscoroutinefunction(instance.__call__): - return await async_call_with_supported_arguments( - instance.__call__, - scope=request.scope, - receive=request.receive, - send=send, - request=request, - datasette=datasette, - ) - else: - return call_with_supported_arguments( - instance.__call__, - scope=request.scope, - receive=request.receive, - send=send, - request=request, - datasette=datasette, - ) - - async_view_for_class.view_class = view_class - return async_view_for_class - - -def wrap_view_function(view_fn, datasette): - @functools.wraps(view_fn) - async def async_view_fn(request, send): - if inspect.iscoroutinefunction(view_fn): - response = await async_call_with_supported_arguments( - view_fn, - scope=request.scope, - receive=request.receive, - send=send, - request=request, - datasette=datasette, - ) - else: - response = call_with_supported_arguments( - view_fn, - scope=request.scope, - receive=request.receive, - send=send, - request=request, - datasette=datasette, - ) - if response is not None: - return response - - return async_view_fn - - -def permanent_redirect(path, forward_query_string=False, forward_rest=False): - return wrap_view( - lambda request, send: Response.redirect( - path - + (request.url_vars["rest"] if forward_rest else "") - + ( - ("?" + request.query_string) - if forward_query_string and request.query_string - else "" - ), - status=301, - ), - datasette=None, - ) - - -_curly_re = re.compile(r"({.*?})") - - -def route_pattern_from_filepath(filepath): - # Drop the ".html" suffix - if filepath.endswith(".html"): - filepath = filepath[: -len(".html")] - re_bits = ["/"] - for bit in _curly_re.split(filepath): - if _curly_re.match(bit): - re_bits.append(f"(?P<{bit[1:-1]}>[^/]*)") - else: - re_bits.append(re.escape(bit)) - return re.compile("^" + "".join(re_bits) + "$") - - -class NotFoundExplicit(NotFound): - pass - - -class DatasetteClient: - """Internal HTTP client for making requests to a Datasette instance. - - Used for testing and for internal operations that need to make HTTP requests - to the Datasette app without going through an actual HTTP server. - """ - - def __init__(self, ds): - self.ds = ds - - @property - def app(self): - return self.ds.app() - - def actor_cookie(self, actor): - # Utility method, mainly for tests - return self.ds.sign({"a": actor}, "actor") - - def _fix(self, path, avoid_path_rewrites=False): - if not isinstance(path, PrefixedUrlString) and not avoid_path_rewrites: - path = self.ds.urls.path(path) - if path.startswith("/"): - path = f"http://localhost{path}" - return path - - def _apply_actor(self, kwargs): - """If ``actor=`` was supplied, convert it into a signed ds_actor cookie.""" - actor = kwargs.pop("actor", None) - if actor is None: - return - cookies = dict(kwargs.get("cookies") or {}) - if "ds_actor" in cookies: - raise TypeError("Cannot pass both actor= and a ds_actor cookie") - cookies["ds_actor"] = self.actor_cookie(actor) - kwargs["cookies"] = cookies - - async def _request(self, method, path, skip_permission_checks=False, **kwargs): - from datasette.permissions import SkipPermissions - - self._apply_actor(kwargs) - with _DatasetteClientContext(): - if skip_permission_checks: - with SkipPermissions(): - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=self.app), - cookies=kwargs.pop("cookies", None), - ) as client: - return await getattr(client, method)(self._fix(path), **kwargs) - else: - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=self.app), - cookies=kwargs.pop("cookies", None), - ) as client: - return await getattr(client, method)(self._fix(path), **kwargs) - - async def get(self, path, skip_permission_checks=False, **kwargs): - return await self._request( - "get", path, skip_permission_checks=skip_permission_checks, **kwargs - ) - - async def options(self, path, skip_permission_checks=False, **kwargs): - return await self._request( - "options", path, skip_permission_checks=skip_permission_checks, **kwargs - ) - - async def head(self, path, skip_permission_checks=False, **kwargs): - return await self._request( - "head", path, skip_permission_checks=skip_permission_checks, **kwargs - ) - - async def post(self, path, skip_permission_checks=False, **kwargs): - return await self._request( - "post", path, skip_permission_checks=skip_permission_checks, **kwargs - ) - - async def put(self, path, skip_permission_checks=False, **kwargs): - return await self._request( - "put", path, skip_permission_checks=skip_permission_checks, **kwargs - ) - - async def patch(self, path, skip_permission_checks=False, **kwargs): - return await self._request( - "patch", path, skip_permission_checks=skip_permission_checks, **kwargs - ) - - async def delete(self, path, skip_permission_checks=False, **kwargs): - return await self._request( - "delete", path, skip_permission_checks=skip_permission_checks, **kwargs - ) - - async def request(self, method, path, skip_permission_checks=False, **kwargs): - """Make an HTTP request with the specified method. - - Args: - method: HTTP method (e.g., "GET", "POST", "PUT") - path: The path to request - skip_permission_checks: If True, bypass all permission checks for this request - **kwargs: Additional arguments to pass to httpx - - Returns: - httpx.Response: The response from the request - """ - from datasette.permissions import SkipPermissions - - avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None) - self._apply_actor(kwargs) - with _DatasetteClientContext(): - if skip_permission_checks: - with SkipPermissions(): - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=self.app), - cookies=kwargs.pop("cookies", None), - ) as client: - return await client.request( - method, self._fix(path, avoid_path_rewrites), **kwargs + cursor = conn.cursor() + cursor.execute(sql, params or {}) + max_returned_rows = self.max_returned_rows + if max_returned_rows == page_size: + max_returned_rows += 1 + if max_returned_rows and truncate: + rows = cursor.fetchmany(max_returned_rows + 1) + truncated = len(rows) > max_returned_rows + rows = rows[:max_returned_rows] + else: + rows = cursor.fetchall() + truncated = False + except sqlite3.OperationalError as e: + if e.args == ('interrupted',): + raise InterruptedError(e) + print( + "ERROR: conn={}, sql = {}, params = {}: {}".format( + conn, repr(sql), params, e ) - else: - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=self.app), - cookies=kwargs.pop("cookies", None), - ) as client: - return await client.request( - method, self._fix(path, avoid_path_rewrites), **kwargs ) + raise + + if truncate: + return Results(rows, truncated, cursor.description) + + else: + return Results(rows, False, cursor.description) + + return await asyncio.get_event_loop().run_in_executor( + self.executor, sql_operation_in_thread + ) + + def app(self): + app = Sanic(__name__) + default_templates = str(app_root / "datasette" / "templates") + template_paths = [] + if self.template_dir: + template_paths.append(self.template_dir) + template_paths.extend( + [ + plugin["templates_path"] + for plugin in get_plugins(pm) + if plugin["templates_path"] + ] + ) + template_paths.append(default_templates) + template_loader = ChoiceLoader( + [ + FileSystemLoader(template_paths), + # Support {% extends "default:table.html" %}: + PrefixLoader( + {"default": FileSystemLoader(default_templates)}, delimiter=":" + ), + ] + ) + self.jinja_env = Environment(loader=template_loader, autoescape=True) + self.jinja_env.filters["escape_css_string"] = escape_css_string + self.jinja_env.filters["quote_plus"] = lambda u: urllib.parse.quote_plus(u) + self.jinja_env.filters["escape_sqlite"] = escape_sqlite + self.jinja_env.filters["to_css_class"] = to_css_class + pm.hook.prepare_jinja2_environment(env=self.jinja_env) + app.add_route(IndexView.as_view(self), "/") + # TODO: /favicon.ico and /-/static/ deserve far-future cache expires + app.add_route(favicon, "/favicon.ico") + app.static("/-/static/", str(app_root / "datasette" / "static")) + for path, dirname in self.static_mounts: + app.static(path, dirname) + # Mount any plugin static/ directories + for plugin in get_plugins(pm): + if plugin["static_path"]: + modpath = "/-/static-plugins/{}/".format(plugin["name"]) + app.static(modpath, plugin["static_path"]) + app.add_route( + JsonDataView.as_view(self, "inspect.json", self.inspect), + "/-/inspect", + ) + app.add_route( + JsonDataView.as_view(self, "metadata.json", lambda: self.metadata), + "/-/metadata", + ) + app.add_route( + JsonDataView.as_view(self, "versions.json", self.versions), + "/-/versions", + ) + app.add_route( + JsonDataView.as_view(self, "plugins.json", self.plugins), + "/-/plugins", + ) + app.add_route( + JsonDataView.as_view(self, "config.json", lambda: self.config), + "/-/config", + ) + app.add_route( + DatabaseView.as_view(self), "/" + ) + app.add_route( + DatabaseDownload.as_view(self), "/" + ) + app.add_route( + TableView.as_view(self), + "//", + ) + app.add_route( + RowView.as_view(self), + "///", + ) + + self.register_custom_units() + + @app.exception(Exception) + def on_exception(request, exception): + title = None + help = None + if isinstance(exception, NotFound): + status = 404 + info = {} + message = exception.args[0] + elif isinstance(exception, InvalidUsage): + status = 405 + info = {} + message = exception.args[0] + elif isinstance(exception, DatasetteError): + status = exception.status + info = exception.error_dict + message = exception.message + if exception.messagge_is_html: + message = Markup(message) + title = exception.title + else: + status = 500 + info = {} + message = str(exception) + traceback.print_exc() + templates = ["500.html"] + if status != 500: + templates = ["{}.html".format(status)] + templates + info.update( + {"ok": False, "error": message, "status": status, "title": title} + ) + if request.path.split("?")[0].endswith(".json"): + return response.json(info, status=status) + + else: + template = self.jinja_env.select_template(templates) + return response.html(template.render(info), status=status) + + class AsgiApp(): + def __init__(self, scope): + self.scope = scope + + async def __call__(self, receive, send): + # Create Sanic request from scope + path = self.scope["path"].encode("utf8") + if self.scope["query_string"]: + path = b"{}?{}".format(path, self.scope["query_string"]) + request = SanicRequest( + path, + {}, '1.1', 'GET', None + ) + async def write_callback(response): + await send({ + 'type': 'http.response.start', + 'status': 200, + 'headers': [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in response.headers.items() + ], + }) + await send({ + 'type': 'http.response.body', + 'body': response.body, + }) + + # TODO: Fix this + stream_callback = write_callback + await app.handle_request(request, write_callback, stream_callback) + + app.AsgiApp = AsgiApp + + return app diff --git a/datasette/blob_renderer.py b/datasette/blob_renderer.py deleted file mode 100644 index 4d8c6bea..00000000 --- a/datasette/blob_renderer.py +++ /dev/null @@ -1,61 +0,0 @@ -from datasette import hookimpl -from datasette.utils.asgi import Response, BadRequest -from datasette.utils import to_css_class -import hashlib - -_BLOB_COLUMN = "_blob_column" -_BLOB_HASH = "_blob_hash" - - -async def render_blob(datasette, database, rows, columns, request, table, view_name): - if _BLOB_COLUMN not in request.args: - raise BadRequest(f"?{_BLOB_COLUMN}= is required") - blob_column = request.args[_BLOB_COLUMN] - if blob_column not in columns: - raise BadRequest(f"{blob_column} is not a valid column") - - # If ?_blob_hash= provided, use that to select the row - otherwise use first row - blob_hash = None - if _BLOB_HASH in request.args: - blob_hash = request.args[_BLOB_HASH] - for row in rows: - value = row[blob_column] - if hashlib.sha256(value).hexdigest() == blob_hash: - break - else: - # Loop did not break - raise BadRequest( - "Link has expired - the requested binary content has changed or could not be found." - ) - else: - row = rows[0] - - value = row[blob_column] - filename_bits = [] - if table: - filename_bits.append(to_css_class(table)) - if "pks" in request.url_vars: - filename_bits.append(request.url_vars["pks"]) - filename_bits.append(to_css_class(blob_column)) - if blob_hash: - filename_bits.append(blob_hash[:6]) - filename = "-".join(filename_bits) + ".blob" - headers = { - "X-Content-Type-Options": "nosniff", - "Content-Disposition": f'attachment; filename="{filename}"', - } - return Response( - body=value or b"", - status=200, - headers=headers, - content_type="application/binary", - ) - - -@hookimpl -def register_output_renderer(): - return { - "extension": "blob", - "render": render_blob, - "can_render": lambda: False, - } diff --git a/datasette/cli.py b/datasette/cli.py index 93aa22ef..2f61559a 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -1,201 +1,330 @@ -import asyncio -import uvicorn import click from click import formatting -from click.types import CompositeParamType from click_default_group import DefaultGroup -import functools import json import os -import pathlib -from runpy import run_module import shutil -from subprocess import call +from subprocess import call, check_output import sys -import textwrap -import webbrowser -from .app import ( - Datasette, - DEFAULT_SETTINGS, - SETTINGS, - SQLITE_LIMIT_ATTACHED, - pm, -) +from .app import Datasette, DEFAULT_CONFIG, CONFIG_OPTIONS from .utils import ( - LoadExtension, - StartupError, - check_connection, - deep_dict_update, - find_spatialite, - parse_metadata, - ConnectionProblem, - SpatialiteConnectionProblem, - initial_path_for_datasette, - pairs_to_nested_config, temporary_docker_directory, + temporary_heroku_directory, value_as_boolean, - SpatialiteNotFound, - StaticMount, ValueAsBooleanError, ) -from .utils.sqlite import sqlite3 -from .utils.testing import TestClient -from .version import __version__ -def run_sync(coro_func): - """Run an async callable to completion on a fresh event loop.""" - loop = asyncio.new_event_loop() - try: - asyncio.set_event_loop(loop) - return loop.run_until_complete(coro_func()) - finally: - asyncio.set_event_loop(None) - loop.close() +class StaticMount(click.ParamType): + name = "static mount" + + def convert(self, value, param, ctx): + if ":" not in value: + self.fail( + '"{}" should be of format mountpoint:directory'.format(value), + param, ctx + ) + path, dirpath = value.split(":") + if not os.path.exists(dirpath) or not os.path.isdir(dirpath): + self.fail("%s is not a valid directory path" % value, param, ctx) + return path, dirpath -# Use Rich for tracebacks if it is installed -try: - from rich.traceback import install - - install(show_locals=True) -except ImportError: - pass - - -class Setting(CompositeParamType): - name = "setting" - arity = 2 +class Config(click.ParamType): + name = "config" def convert(self, config, param, ctx): - name, value = config - if name in DEFAULT_SETTINGS: - # For backwards compatibility with how this worked prior to - # Datasette 1.0, we turn bare setting names into setting.name - # Type checking for those older settings - default = DEFAULT_SETTINGS[name] - name = "settings.{}".format(name) - if isinstance(default, bool): - try: - return name, "true" if value_as_boolean(value) else "false" - except ValueAsBooleanError: - self.fail(f'"{name}" should be on/off/true/false/1/0', param, ctx) - elif isinstance(default, int): - if not value.isdigit(): - self.fail(f'"{name}" should be an integer', param, ctx) - return name, value - elif isinstance(default, str): - return name, value - else: - # Should never happen: - self.fail("Invalid option") - return name, value - - -def sqlite_extensions(fn): - fn = click.option( - "sqlite_extensions", - "--load-extension", - type=LoadExtension(), - envvar="DATASETTE_LOAD_EXTENSION", - multiple=True, - help="Path to a SQLite extension to load, and optional entrypoint", - )(fn) - - # Wrap it in a custom error handler - @functools.wraps(fn) - def wrapped(*args, **kwargs): - try: - return fn(*args, **kwargs) - except AttributeError as e: - if "enable_load_extension" in str(e): - raise click.ClickException(textwrap.dedent(""" - Your Python installation does not have the ability to load SQLite extensions. - - More information: https://datasette.io/help/extensions - """).strip()) - raise - - return wrapped + if ":" not in config: + self.fail( + '"{}" should be name:value'.format(config), param, ctx + ) + return + name, value = config.split(":") + if name not in DEFAULT_CONFIG: + self.fail( + "{} is not a valid option (--help-config to see all)".format( + name + ), param, ctx + ) + return + # Type checking + default = DEFAULT_CONFIG[name] + if isinstance(default, bool): + try: + return name, value_as_boolean(value) + except ValueAsBooleanError: + self.fail( + '"{}" should be on/off/true/false/1/0'.format(name), param, ctx + ) + return + elif isinstance(default, int): + if not value.isdigit(): + self.fail( + '"{}" should be an integer'.format(name), param, ctx + ) + return + return name, int(value) + else: + # Should never happen: + self.fail('Invalid option') @click.group(cls=DefaultGroup, default="serve", default_if_no_args=True) -@click.version_option(version=__version__) +@click.version_option() def cli(): """ - Datasette is an open source multi-tool for exploring and publishing data - - \b - About Datasette: https://datasette.io/ - Full documentation: https://docs.datasette.io/ + Datasette! """ @cli.command() @click.argument("files", type=click.Path(exists=True), nargs=-1) -@click.option("--inspect-file", default="-") -@sqlite_extensions +@click.option("--inspect-file", default="inspect-data.json") +@click.option( + "sqlite_extensions", + "--load-extension", + envvar="SQLITE_EXTENSIONS", + multiple=True, + type=click.Path(exists=True, resolve_path=True), + help="Path to a SQLite extension to load", +) def inspect(files, inspect_file, sqlite_extensions): - """ - Generate JSON summary of provided database files - - This can then be passed to "datasette --inspect-file" to speed up count - operations against immutable database files. - """ - inspect_data = run_sync(lambda: inspect_(files, sqlite_extensions)) - if inspect_file == "-": - sys.stdout.write(json.dumps(inspect_data, indent=2)) - else: - with open(inspect_file, "w") as fp: - fp.write(json.dumps(inspect_data, indent=2)) - - -async def inspect_(files, sqlite_extensions): - app = Datasette([], immutables=files, sqlite_extensions=sqlite_extensions) - data = {} - for name, database in app.databases.items(): - counts = await database.table_counts(limit=3600 * 1000) - data[name] = { - "hash": database.hash, - "size": database.size, - "file": database.path, - "tables": { - table_name: {"count": table_count} - for table_name, table_count in counts.items() - }, - } - return data - - -@cli.group() -def publish(): - """Publish specified SQLite database files to the internet along with a Datasette-powered interface and API""" - pass - - -# Register publish plugins -pm.hook.publish_subcommand(publish=publish) + app = Datasette(files, sqlite_extensions=sqlite_extensions) + open(inspect_file, "w").write(json.dumps(app.inspect(), indent=2)) @cli.command() -@click.option("--all", help="Include built-in default plugins", is_flag=True) +@click.argument("publisher", type=click.Choice(["now", "heroku"])) +@click.argument("files", type=click.Path(exists=True), nargs=-1) @click.option( - "--requirements", help="Output requirements.txt of installed plugins", is_flag=True + "-n", + "--name", + default="datasette", + help="Application name to use when deploying to Now (ignored for Heroku)", +) +@click.option( + "-m", + "--metadata", + type=click.File(mode="r"), + help="Path to JSON file containing metadata to publish", +) +@click.option("--extra-options", help="Extra options to pass to datasette serve") +@click.option("--force", is_flag=True, help="Pass --force option to now") +@click.option("--branch", help="Install datasette from a GitHub branch e.g. master") +@click.option( + "--template-dir", + type=click.Path(exists=True, file_okay=False, dir_okay=True), + help="Path to directory containing custom templates", ) @click.option( "--plugins-dir", type=click.Path(exists=True, file_okay=False, dir_okay=True), help="Path to directory containing custom plugins", ) -def plugins(all, requirements, plugins_dir): - """List currently installed plugins""" - app = Datasette([], plugins_dir=plugins_dir) - if requirements: - for plugin in app._plugins(): - if plugin["version"]: - click.echo("{}=={}".format(plugin["name"], plugin["version"])) - else: - click.echo(json.dumps(app._plugins(all=all), indent=4)) +@click.option( + "--static", + type=StaticMount(), + help="mountpoint:path-to-directory for serving static files", + multiple=True, +) +@click.option( + "--install", + help="Additional packages (e.g. plugins) to install", + multiple=True, +) +@click.option( + "--spatialite", is_flag=True, help="Enable SpatialLite extension" +) +@click.option("--title", help="Title for metadata") +@click.option("--license", help="License label for metadata") +@click.option("--license_url", help="License URL for metadata") +@click.option("--source", help="Source label for metadata") +@click.option("--source_url", help="Source URL for metadata") +def publish( + publisher, + files, + name, + metadata, + extra_options, + force, + branch, + template_dir, + plugins_dir, + static, + install, + spatialite, + **extra_metadata +): + """ + Publish specified SQLite database files to the internet along with a datasette API. + + Options for PUBLISHER: + * 'now' - You must have Zeit Now installed: https://zeit.co/now + * 'heroku' - You must have Heroku installed: https://cli.heroku.com/ + + Example usage: datasette publish now my-database.db + """ + + def _fail_if_publish_binary_not_installed(binary, publish_target, install_link): + """Exit (with error message) if ``binary` isn't installed""" + if not shutil.which(binary): + click.secho( + "Publishing to {publish_target} requires {binary} to be installed and configured".format( + publish_target=publish_target, binary=binary + ), + bg="red", + fg="white", + bold=True, + err=True, + ) + click.echo( + "Follow the instructions at {install_link}".format( + install_link=install_link + ), + err=True, + ) + sys.exit(1) + + if publisher == "now": + _fail_if_publish_binary_not_installed("now", "Zeit Now", "https://zeit.co/now") + with temporary_docker_directory( + files, + name, + metadata, + extra_options, + branch, + template_dir, + plugins_dir, + static, + install, + spatialite, + extra_metadata, + ): + if force: + call(["now", "--force"]) + else: + call("now") + + elif publisher == "heroku": + _fail_if_publish_binary_not_installed( + "heroku", "Heroku", "https://cli.heroku.com" + ) + if spatialite: + click.secho( + "The --spatialite option is not yet supported for Heroku", + bg="red", + fg="white", + bold=True, + err=True, + ) + click.echo( + "See https://github.com/simonw/datasette/issues/301", + err=True, + ) + sys.exit(1) + + # Check for heroku-builds plugin + plugins = [ + line.split()[0] for line in check_output(["heroku", "plugins"]).splitlines() + ] + if b"heroku-builds" not in plugins: + click.echo( + "Publishing to Heroku requires the heroku-builds plugin to be installed." + ) + click.confirm( + "Install it? (this will run `heroku plugins:install heroku-builds`)", + abort=True, + ) + call(["heroku", "plugins:install", "heroku-builds"]) + + with temporary_heroku_directory( + files, + name, + metadata, + extra_options, + branch, + template_dir, + plugins_dir, + static, + install, + extra_metadata, + ): + create_output = check_output(["heroku", "apps:create", "--json"]).decode( + "utf8" + ) + app_name = json.loads(create_output)["name"] + call(["heroku", "builds:create", "-a", app_name]) + + +@cli.command() +@click.argument("files", type=click.Path(exists=True), nargs=-1, required=True) +@click.option( + "-m", + "--metadata", + default="metadata.json", + help="Name of metadata file to generate", +) +@click.option( + "sqlite_extensions", + "--load-extension", + envvar="SQLITE_EXTENSIONS", + multiple=True, + type=click.Path(exists=True, resolve_path=True), + help="Path to a SQLite extension to load", +) +def skeleton(files, metadata, sqlite_extensions): + "Generate a skeleton metadata.json file for specified SQLite databases" + if os.path.exists(metadata): + click.secho( + "File {} already exists, will not over-write".format(metadata), + bg="red", + fg="white", + bold=True, + err=True, + ) + sys.exit(1) + app = Datasette(files, sqlite_extensions=sqlite_extensions) + databases = {} + for database_name, info in app.inspect().items(): + databases[database_name] = { + "title": None, + "description": None, + "description_html": None, + "license": None, + "license_url": None, + "source": None, + "source_url": None, + "queries": {}, + "tables": { + table_name: { + "title": None, + "description": None, + "description_html": None, + "license": None, + "license_url": None, + "source": None, + "source_url": None, + "units": {}, + } + for table_name in (info.get("tables") or {}) + }, + } + open(metadata, "w").write( + json.dumps( + { + "title": None, + "description": None, + "description_html": None, + "license": None, + "license_url": None, + "source": None, + "source_url": None, + "databases": databases, + }, + indent=4, + ) + ) + click.echo("Wrote skeleton to {}".format(metadata)) @cli.command() @@ -209,10 +338,10 @@ def plugins(all, requirements, plugins_dir): "-m", "--metadata", type=click.File(mode="r"), - help="Path to JSON/YAML file containing metadata to publish", + help="Path to JSON file containing metadata to publish", ) @click.option("--extra-options", help="Extra options to pass to datasette serve") -@click.option("--branch", help="Install datasette from a GitHub branch e.g. main") +@click.option("--branch", help="Install datasette from a GitHub branch e.g. master") @click.option( "--template-dir", type=click.Path(exists=True, file_okay=False, dir_okay=True), @@ -226,34 +355,22 @@ def plugins(all, requirements, plugins_dir): @click.option( "--static", type=StaticMount(), - help="Serve static files from this directory at /MOUNT/...", + help="mountpoint:path-to-directory for serving static files", multiple=True, ) @click.option( - "--install", help="Additional packages (e.g. plugins) to install", multiple=True -) -@click.option("--spatialite", is_flag=True, help="Enable SpatialLite extension") -@click.option("--version-note", help="Additional note to show on /-/versions") -@click.option( - "--secret", - help="Secret used for signing secure values, such as signed cookies", - envvar="DATASETTE_PUBLISH_SECRET", - default=lambda: os.urandom(32).hex(), + "--install", + help="Additional packages (e.g. plugins) to install", + multiple=True, ) @click.option( - "-p", - "--port", - default=8001, - type=click.IntRange(1, 65535), - help="Port to run the server on, defaults to 8001", + "--spatialite", is_flag=True, help="Enable SpatialLite extension" ) @click.option("--title", help="Title for metadata") @click.option("--license", help="License label for metadata") @click.option("--license_url", help="License URL for metadata") @click.option("--source", help="Source label for metadata") @click.option("--source_url", help="Source URL for metadata") -@click.option("--about", help="About label for metadata") -@click.option("--about_url", help="About URL for metadata") def package( files, tag, @@ -265,12 +382,9 @@ def package( static, install, spatialite, - version_note, - secret, - port, - **extra_metadata, + **extra_metadata ): - """Package SQLite files into a Datasette Docker container""" + "Package specified SQLite files into a new datasette Docker container" if not shutil.which("docker"): click.secho( ' The package command requires "docker" to be installed and configured ', @@ -283,18 +397,15 @@ def package( with temporary_docker_directory( files, "datasette", - metadata=metadata, - extra_options=extra_options, - branch=branch, - template_dir=template_dir, - plugins_dir=plugins_dir, - static=static, - install=install, - spatialite=spatialite, - version_note=version_note, - secret=secret, - extra_metadata=extra_metadata, - port=port, + metadata, + extra_options, + branch, + template_dir, + plugins_dir, + static, + install, + spatialite, + extra_metadata, ): args = ["docker", "build"] if tag: @@ -305,85 +416,30 @@ def package( @cli.command() -@click.argument("packages", nargs=-1) +@click.argument("files", type=click.Path(exists=True), nargs=-1) @click.option( - "-U", "--upgrade", is_flag=True, help="Upgrade packages to latest version" + "-h", "--host", default="127.0.0.1", help="host for server, defaults to 127.0.0.1" ) +@click.option("-p", "--port", default=8001, help="port for server, defaults to 8001") @click.option( - "-r", - "--requirement", - type=click.Path(exists=True), - help="Install from requirements file", -) -@click.option( - "-e", - "--editable", - help="Install a project in editable mode from this path", -) -def install(packages, upgrade, requirement, editable): - """Install plugins and packages from PyPI into the same environment as Datasette""" - if not packages and not requirement and not editable: - raise click.UsageError("Please specify at least one package to install") - args = ["pip", "install"] - if upgrade: - args += ["--upgrade"] - if editable: - args += ["--editable", editable] - if requirement: - args += ["-r", requirement] - args += list(packages) - sys.argv = args - run_module("pip", run_name="__main__") - - -@cli.command() -@click.argument("packages", nargs=-1, required=True) -@click.option("-y", "--yes", is_flag=True, help="Don't ask for confirmation") -def uninstall(packages, yes): - """Uninstall plugins and Python packages from the Datasette environment""" - sys.argv = ["pip", "uninstall"] + list(packages) + (["-y"] if yes else []) - run_module("pip", run_name="__main__") - - -@cli.command() -@click.argument("files", type=click.Path(), nargs=-1) -@click.option( - "-i", - "--immutable", - type=click.Path(exists=True), - help="Database files to open in immutable mode", - multiple=True, -) -@click.option( - "-h", - "--host", - default="127.0.0.1", - help=( - "Host for server. Defaults to 127.0.0.1 which means only connections " - "from the local machine will be allowed. Use 0.0.0.0 to listen to " - "all IPs and allow access from other machines." - ), -) -@click.option( - "-p", - "--port", - default=8001, - type=click.IntRange(0, 65535), - help="Port for server, defaults to 8001. Use -p 0 to automatically assign an available port.", -) -@click.option( - "--uds", - help="Bind to a Unix domain socket", + "--debug", is_flag=True, help="Enable debug mode - useful for development" ) @click.option( "--reload", is_flag=True, - help="Automatically reload if code or metadata change detected - useful for development", + help="Automatically reload if code change detected - useful for development", ) @click.option( "--cors", is_flag=True, help="Enable CORS by serving Access-Control-Allow-Origin: *" ) -@sqlite_extensions +@click.option( + "sqlite_extensions", + "--load-extension", + envvar="SQLITE_EXTENSIONS", + multiple=True, + type=click.Path(exists=True, resolve_path=True), + help="Path to a SQLite extension to load", +) @click.option( "--inspect-file", help='Path to JSON file created using "datasette inspect"' ) @@ -391,7 +447,7 @@ def uninstall(packages, yes): "-m", "--metadata", type=click.File(mode="r"), - help="Path to JSON/YAML file containing license/source metadata", + help="Path to JSON file containing license/source metadata", ) @click.option( "--template-dir", @@ -406,102 +462,25 @@ def uninstall(packages, yes): @click.option( "--static", type=StaticMount(), - help="Serve static files from this directory at /MOUNT/...", + help="mountpoint:path-to-directory for serving static files", multiple=True, ) -@click.option("--memory", is_flag=True, help="Make /_memory database available") @click.option( - "-c", "--config", - type=click.File(mode="r"), - help="Path to JSON/YAML Datasette configuration file", -) -@click.option( - "-s", - "--setting", - "settings", - type=Setting(), - help="nested.key, value setting to use in Datasette configuration", + type=Config(), + help="Set config option using configname:value datasette.readthedocs.io/en/latest/config.html", multiple=True, ) @click.option( - "--secret", - help="Secret used for signing secure values, such as signed cookies", - envvar="DATASETTE_SECRET", -) -@click.option( - "--root", - help="Output URL that sets a cookie authenticating the root user", + "--help-config", is_flag=True, -) -@click.option( - "--default-deny", - help="Deny all permissions by default", - is_flag=True, -) -@click.option( - "--get", - help="Run an HTTP GET request against this path, print results and exit", -) -@click.option( - "--headers", - is_flag=True, - help="Include HTTP headers in --get output", -) -@click.option( - "--token", - help="API token to send with --get requests", -) -@click.option( - "--actor", - help="Actor to use for --get requests (JSON string)", -) -@click.option("--version-note", help="Additional note to show on /-/versions") -@click.option("--help-settings", is_flag=True, help="Show available settings") -@click.option("--pdb", is_flag=True, help="Launch debugger on any errors") -@click.option( - "-o", - "--open", - "open_browser", - is_flag=True, - help="Open Datasette in your web browser", -) -@click.option( - "--create", - is_flag=True, - help="Create database files if they do not exist", -) -@click.option( - "--crossdb", - is_flag=True, - help="Enable cross-database joins using the /_memory database", -) -@click.option( - "--nolock", - is_flag=True, - help="Ignore locking, open locked files in read-only mode", -) -@click.option( - "--ssl-keyfile", - help="SSL key file", - envvar="DATASETTE_SSL_KEYFILE", -) -@click.option( - "--ssl-certfile", - help="SSL certificate file", - envvar="DATASETTE_SSL_CERTFILE", -) -@click.option( - "--internal", - type=click.Path(), - help="Path to a persistent Datasette internal SQLite database", + help="Show available config options", ) def serve( files, - immutable, host, port, - uds, + debug, reload, cors, sqlite_extensions, @@ -510,391 +489,49 @@ def serve( template_dir, plugins_dir, static, - memory, config, - settings, - secret, - root, - default_deny, - get, - headers, - token, - actor, - version_note, - help_settings, - pdb, - open_browser, - create, - crossdb, - nolock, - ssl_keyfile, - ssl_certfile, - internal, - return_instance=False, + help_config, ): """Serve up specified SQLite database files with a web UI""" - if help_settings: + if help_config: formatter = formatting.HelpFormatter() - with formatter.section("Settings"): - formatter.write_dl( - [ - (option.name, f"{option.help} (default={option.default})") - for option in SETTINGS - ] - ) + with formatter.section("Config options"): + formatter.write_dl([ + (option.name, '{} (default={})'.format( + option.help, option.default + )) + for option in CONFIG_OPTIONS + ]) click.echo(formatter.getvalue()) sys.exit(0) if reload: import hupper - reloader = hupper.start_reloader("datasette.cli.cli") - if immutable: - reloader.watch_files(immutable) - if config: - reloader.watch_files([config.name]) + reloader = hupper.start_reloader("datasette.cli.serve") if metadata: reloader.watch_files([metadata.name]) inspect_data = None if inspect_file: - with open(inspect_file) as fp: - inspect_data = json.load(fp) + inspect_data = json.load(open(inspect_file)) metadata_data = None if metadata: - metadata_data = parse_metadata(metadata.read()) + metadata_data = json.loads(metadata.read()) - config_data = None - if config: - config_data = parse_metadata(config.read()) - - config_data = config_data or {} - - # Merge in settings from -s/--setting - if settings: - settings_updates = pairs_to_nested_config(settings) - # Merge recursively, to avoid over-writing nested values - # https://github.com/simonw/datasette/issues/2389 - deep_dict_update(config_data, settings_updates) - - kwargs = dict( - immutables=immutable, - cache_headers=not reload, + click.echo("Serve! files={} on port {}".format(files, port)) + ds = Datasette( + files, + cache_headers=not debug and not reload, cors=cors, inspect_data=inspect_data, - config=config_data, metadata=metadata_data, sqlite_extensions=sqlite_extensions, template_dir=template_dir, plugins_dir=plugins_dir, static_mounts=static, - settings=None, # These are passed in config= now - memory=memory, - secret=secret, - version_note=version_note, - pdb=pdb, - crossdb=crossdb, - nolock=nolock, - internal=internal, - default_deny=default_deny, + config=dict(config), ) - - # Separate directories from files - directories = [f for f in files if os.path.isdir(f)] - file_paths = [f for f in files if not os.path.isdir(f)] - - # Handle config_dir - only one directory allowed - if len(directories) > 1: - raise click.ClickException( - "Cannot pass multiple directories. Pass a single directory as config_dir." - ) - elif len(directories) == 1: - kwargs["config_dir"] = pathlib.Path(directories[0]) - - # Verify list of files, create if needed (and --create) - for file in file_paths: - if not pathlib.Path(file).exists(): - if create: - conn = sqlite3.connect(file) - conn.execute("vacuum") - conn.close() - else: - raise click.ClickException( - "Invalid value for '[FILES]...': Path '{}' does not exist.".format( - file - ) - ) - - # Check for duplicate files by resolving all paths to their absolute forms - # Collect all database files that will be loaded (explicit files + config_dir files) - all_db_files = [] - - # Add explicit files - for file in file_paths: - all_db_files.append((file, pathlib.Path(file).resolve())) - - # Add config_dir databases if config_dir is set - if "config_dir" in kwargs: - config_dir = kwargs["config_dir"] - for ext in ("db", "sqlite", "sqlite3"): - for db_file in config_dir.glob(f"*.{ext}"): - all_db_files.append((str(db_file), db_file.resolve())) - - # Check for duplicates - seen = {} - for original_path, resolved_path in all_db_files: - if resolved_path in seen: - raise click.ClickException( - f"Duplicate database file: '{original_path}' and '{seen[resolved_path]}' " - f"both refer to {resolved_path}" - ) - seen[resolved_path] = original_path - - files = file_paths - - try: - ds = Datasette(files, **kwargs) - except SpatialiteNotFound: - raise click.ClickException("Could not find SpatiaLite extension") - except StartupError as e: - raise click.ClickException(e.args[0]) - - if return_instance: - # Private utility mechanism for writing unit tests - return ds - - # Run async soundness checks before startup hooks, since invoke_startup - # now populates internal tables which requires querying each database - run_sync(lambda: check_databases(ds)) - - # Run the "startup" plugin hooks - try: - run_sync(ds.invoke_startup) - except StartupError as e: - raise click.ClickException(e.args[0]) - - if headers and not get: - raise click.ClickException("--headers can only be used with --get") - - if token and not get: - raise click.ClickException("--token can only be used with --get") - - if get: - client = TestClient(ds) - request_headers = {} - if token: - request_headers["Authorization"] = "Bearer {}".format(token) - cookies = {} - if actor: - cookies["ds_actor"] = client.actor_cookie(json.loads(actor)) - response = client.get(get, headers=request_headers, cookies=cookies) - - if headers: - # Output HTTP status code, headers, two newlines, then the response body - click.echo(f"HTTP/1.1 {response.status}") - for key, value in response.headers.items(): - click.echo(f"{key}: {value}") - if response.text: - click.echo() - click.echo(response.text) - else: - click.echo(response.text) - - exit_code = 0 if response.status == 200 else 1 - sys.exit(exit_code) - return - - # Start the server - url = None - if root: - ds.root_enabled = True - url = "http://{}:{}{}?token={}".format( - host, port, ds.urls.path("-/auth-token"), ds._root_token - ) - click.echo(url) - if open_browser: - if url is None: - # Figure out most convenient URL - to table, database or homepage - path = run_sync(lambda: initial_path_for_datasette(ds)) - url = f"http://{host}:{port}{path}" - webbrowser.open(url) - uvicorn_kwargs = dict( - host=host, port=port, log_level="info", lifespan="on", workers=1 - ) - if uds: - uvicorn_kwargs["uds"] = uds - if ssl_keyfile: - uvicorn_kwargs["ssl_keyfile"] = ssl_keyfile - if ssl_certfile: - uvicorn_kwargs["ssl_certfile"] = ssl_certfile - uvicorn.run(ds.app(), **uvicorn_kwargs) - - -@cli.command() -@click.argument("id") -@click.option( - "--secret", - help="Secret used for signing the API tokens", - envvar="DATASETTE_SECRET", - required=True, -) -@click.option( - "-e", - "--expires-after", - help="Token should expire after this many seconds", - type=int, -) -@click.option( - "alls", - "-a", - "--all", - type=str, - metavar="ACTION", - multiple=True, - help="Restrict token to this action", -) -@click.option( - "databases", - "-d", - "--database", - type=(str, str), - metavar="DB ACTION", - multiple=True, - help="Restrict token to this action on this database", -) -@click.option( - "resources", - "-r", - "--resource", - type=(str, str, str), - metavar="DB RESOURCE ACTION", - multiple=True, - help="Restrict token to this action on this database resource (a table, SQL view or named query)", -) -@click.option( - "--debug", - help="Show decoded token", - is_flag=True, -) -@click.option( - "--plugins-dir", - type=click.Path(exists=True, file_okay=False, dir_okay=True), - help="Path to directory containing custom plugins", -) -def create_token( - id, secret, expires_after, alls, databases, resources, debug, plugins_dir -): - """ - Create a signed API token for the specified actor ID - - Example: - - datasette create-token root --secret mysecret - - To allow only "view-database-download" for all databases: - - \b - datasette create-token root --secret mysecret \\ - --all view-database-download - - To allow "create-table" against a specific database: - - \b - datasette create-token root --secret mysecret \\ - --database mydb create-table - - To allow "insert-row" against a specific table: - - \b - datasette create-token root --secret myscret \\ - --resource mydb mytable insert-row - - Restricted actions can be specified multiple times using - multiple --all, --database, and --resource options. - - Add --debug to see a decoded version of the token. - """ - ds = Datasette(secret=secret, plugins_dir=plugins_dir) - - # Run ds.invoke_startup() in an event loop - try: - run_sync(ds.invoke_startup) - except StartupError as e: - raise click.ClickException(e.args[0]) - - # Warn about any unknown actions - actions = [] - actions.extend(alls) - actions.extend([p[1] for p in databases]) - actions.extend([p[2] for p in resources]) - for action in actions: - if not ds.actions.get(action): - click.secho( - f" Unknown permission: {action} ", - fg="red", - err=True, - ) - - from datasette.tokens import TokenRestrictions - - restrictions = TokenRestrictions() - for action in alls: - restrictions.allow_all(action) - for database, action in databases: - restrictions.allow_database(database, action) - for database, resource, action in resources: - restrictions.allow_resource(database, resource, action) - - token = run_sync( - lambda: ds.create_token( - id, - expires_after=expires_after, - restrictions=restrictions, - handler="signed", - ) - ) - click.echo(token) - if debug: - encoded = token[len("dstok_") :] - click.echo("\nDecoded:\n") - click.echo(json.dumps(ds.unsign(encoded, namespace="token"), indent=2)) - - -pm.hook.register_commands(cli=cli) - - -async def check_databases(ds): - # Run check_connection against every connected database - # to confirm they are all usable - for database in list(ds.databases.values()): - try: - await database.execute_fn(check_connection) - except SpatialiteConnectionProblem: - suggestion = "" - try: - find_spatialite() - suggestion = "\n\nTry adding the --load-extension=spatialite option." - except SpatialiteNotFound: - pass - raise click.UsageError( - "It looks like you're trying to load a SpatiaLite" - + " database without first loading the SpatiaLite module." - + suggestion - + "\n\nRead more: https://docs.datasette.io/en/stable/spatialite.html" - ) - except ConnectionProblem as e: - raise click.UsageError( - f"Connection to {database.path} failed check: {str(e.args[0])}" - ) - # If --crossdb and more than SQLITE_LIMIT_ATTACHED show warning - if ( - ds.crossdb - and len([db for db in ds.databases.values() if not db.is_memory]) - > SQLITE_LIMIT_ATTACHED - ): - msg = ( - "Warning: --crossdb only works with the first {} attached databases".format( - SQLITE_LIMIT_ATTACHED - ) - ) - click.echo(click.style(msg, bold=True, fg="yellow"), err=True) + # Force initial hashing/table counting + ds.inspect() + ds.app().run(host=host, port=port, debug=debug) diff --git a/datasette/column_types.py b/datasette/column_types.py deleted file mode 100644 index 7320e1d6..00000000 --- a/datasette/column_types.py +++ /dev/null @@ -1,83 +0,0 @@ -from enum import Enum - - -class SQLiteType(Enum): - TEXT = "TEXT" - INTEGER = "INTEGER" - REAL = "REAL" - BLOB = "BLOB" - NULL = "NULL" - - @classmethod - def from_declared_type(cls, declared_type: str | None) -> "SQLiteType | None": - if declared_type is None: - return cls.NULL - - normalized = declared_type.strip().upper() - if not normalized: - return cls.NULL - - if normalized == cls.NULL.value: - return cls.NULL - if "INT" in normalized: - return cls.INTEGER - if any(token in normalized for token in ("CHAR", "CLOB", "TEXT")): - return cls.TEXT - if "BLOB" in normalized: - return cls.BLOB - if any( - token in normalized - for token in ("REAL", "FLOA", "DOUB") # codespell:ignore doub - ): - return cls.REAL - - return None - - -class ColumnType: - """ - Base class for column types. - - Subclasses must define ``name`` and ``description`` as class attributes: - - - ``name``: Unique identifier string. Lowercase, no spaces. - Examples: "markdown", "file", "email", "url", "point", "image". - - ``description``: Human-readable label for admin UI dropdowns. - Examples: "Markdown text", "File reference", "Email address". - - ``sqlite_types``: Optional tuple of SQLiteType values restricting - which SQLite column types this ColumnType can be assigned to. - - Instantiate with an optional ``config`` dict to bind per-column - configuration:: - - ct = MyColumnType(config={"key": "value"}) - ct.config # {"key": "value"} - """ - - name: str - description: str - sqlite_types: tuple[SQLiteType, ...] | None = None - - def __init__(self, config=None): - self.config = config - - async def render_cell(self, value, column, table, database, datasette, request): - """ - Return an HTML string to render this cell value, or None to - fall through to the default render_cell plugin hook chain. - """ - return None - - async def validate(self, value, datasette): - """ - Validate a value before it is written. Return None if valid, - or a string error message if invalid. - """ - return None - - async def transform_value(self, value, datasette): - """ - Transform a value before it appears in JSON API output. - Return the transformed value. Default: return unchanged. - """ - return value diff --git a/datasette/csrf.py b/datasette/csrf.py deleted file mode 100644 index df239aee..00000000 --- a/datasette/csrf.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -Header-based CSRF (Cross-Origin) protection. - -Datasette uses the Sec-Fetch-Site + Origin header approach described in -Filippo Valsorda's article (https://words.filippo.io/csrf/) and implemented -in Go 1.25's http.CrossOriginProtection. This replaces the previous -token-based asgi-csrf mechanism. -""" - -from __future__ import annotations - -import secrets -import urllib.parse - -from .utils.asgi import asgi_send - -SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"}) - -DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443} - - -def _normalize_headers(raw_headers): - """Lowercase header names; for duplicates, last value wins.""" - result = {} - for name, value in raw_headers: - if isinstance(name, str): - name = name.encode("latin-1") - if isinstance(value, str): - value = value.encode("latin-1") - result[name.lower()] = value - return result - - -def _origin_tuple(value): - """ - Parse an origin-like string into ``(scheme, host, port)`` with default - ports filled in. Raises ``ValueError`` for malformed input. - """ - parsed = urllib.parse.urlsplit(value) - scheme = (parsed.scheme or "").lower() - host = (parsed.hostname or "").lower() - if not scheme or not host: - raise ValueError("missing scheme or host in {!r}".format(value)) - port = parsed.port # may raise ValueError on bad ports - if port is None: - port = DEFAULT_PORTS.get(scheme) - if port is None: - raise ValueError("unknown default port for scheme {!r}".format(scheme)) - return scheme, host, port - - -def _install_legacy_csrftoken(scope): - """ - Populate ``scope["csrftoken"]`` with a callable returning a per-request - random token. Provided for plugin compatibility only - core no longer - uses this value for CSRF enforcement. - """ - - def csrftoken(): - if "_datasette_legacy_csrftoken" not in scope: - scope["_datasette_legacy_csrftoken"] = secrets.token_urlsafe(32) - return scope["_datasette_legacy_csrftoken"] - - scope["csrftoken"] = csrftoken - - -class CrossOriginProtectionMiddleware: - """ - Modern CSRF protection using the Sec-Fetch-Site and Origin headers. - - Based on Filippo Valsorda's algorithm, as implemented in Go 1.25's - http.CrossOriginProtection. See https://words.filippo.io/csrf/ - - Unsafe-method requests are allowed through only if they look same-origin. - Non-browser clients (curl, etc.) send neither Sec-Fetch-Site nor Origin - and are passed through unchanged - CSRF is a browser-only attack. - """ - - SAFE_METHODS = SAFE_METHODS - - def __init__(self, app, datasette): - self.app = app - self.datasette = datasette - - async def __call__(self, scope, receive, send): - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - _install_legacy_csrftoken(scope) - - if scope.get("method", "GET") in self.SAFE_METHODS: - await self.app(scope, receive, send) - return - - headers = _normalize_headers(scope.get("headers") or []) - - authorization = headers.get(b"authorization", b"").decode("latin-1") - cookie_header = headers.get(b"cookie") - # Bearer-token requests are not ambient browser credentials, so they - # are not CSRF-vulnerable. Narrowly exempt them from the header check - # before evaluating Sec-Fetch-Site / Origin. Only "Bearer" is exempt; - # schemes like Basic or Digest can be browser-managed and ambient. - # If the request also carries a Cookie header, ambient cookie auth - # could be in play, so do NOT treat it as exempt. - if authorization and not cookie_header: - parts = authorization.split(None, 1) - if parts and parts[0].lower() == "bearer": - await self.app(scope, receive, send) - return - - origin_bytes = headers.get(b"origin") - sec_fetch_site_bytes = headers.get(b"sec-fetch-site") - host_bytes = headers.get(b"host", b"") - origin = origin_bytes.decode("latin-1") if origin_bytes else None - sec_fetch_site = ( - sec_fetch_site_bytes.decode("latin-1") if sec_fetch_site_bytes else None - ) - host = host_bytes.decode("latin-1") - - # Primary defense: Sec-Fetch-Site (set by browsers, unforgeable from JS) - if sec_fetch_site is not None: - if sec_fetch_site in ("same-origin", "none"): - await self.app(scope, receive, send) - return - await self._forbid( - send, - "Sec-Fetch-Site was {!r}, expected 'same-origin' or 'none'".format( - sec_fetch_site - ), - ) - return - - # No Sec-Fetch-Site and no Origin -> non-browser client (curl, API, etc.) - if origin is None: - await self.app(scope, receive, send) - return - - # Fallback for older browsers: Origin must match the request's own - # scheme + host + port. Compare full origin tuples, not host alone. - request_scheme = self._request_scheme(scope) - try: - origin_tuple = _origin_tuple(origin) - expected_tuple = _origin_tuple("{}://{}".format(request_scheme, host)) - except ValueError: - await self._forbid( - send, - "Malformed Origin {!r} or Host {!r}".format(origin, host), - ) - return - - if origin_tuple == expected_tuple: - await self.app(scope, receive, send) - return - - await self._forbid( - send, - "Origin {!r} does not match Host {!r}".format(origin, host), - ) - - def _request_scheme(self, scope): - if self.datasette is not None: - try: - if self.datasette.setting("force_https_urls"): - return "https" - except Exception: - pass - return scope.get("scheme") or "http" - - async def _forbid(self, send, reason): - await asgi_send( - send, - content=await self.datasette.render_template( - "csrf_error.html", {"reason": reason} - ), - status=403, - content_type="text/html; charset=utf-8", - ) diff --git a/datasette/database.py b/datasette/database.py deleted file mode 100644 index e7e9527e..00000000 --- a/datasette/database.py +++ /dev/null @@ -1,985 +0,0 @@ -import asyncio -import atexit -from collections import namedtuple -import inspect -import os -from pathlib import Path -import queue -import sqlite_utils -import sys -import tempfile -import threading -import uuid - -from .tracer import trace -from .utils import ( - call_with_supported_arguments, - detect_fts, - detect_primary_keys, - detect_spatialite, - get_all_foreign_keys, - get_outbound_foreign_keys, - md5_not_usedforsecurity, - sqlite_timelimit, - sqlite3, - table_columns, - table_column_details, -) -from .utils.sql_analysis import SQLAnalysis, analyze_sql_tables -from .utils.sqlite import sqlite_version -from .inspect import inspect_hash - -connections = threading.local() - -AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) - - -class DatasetteClosedError(RuntimeError): - """Raised when using a Datasette or Database instance after close().""" - - -_SHUTDOWN = object() - - -class Database: - # For table counts stop at this many rows: - count_limit = 10000 - _thread_local_id_counter = 1 - - def __init__( - self, - ds, - path=None, - is_mutable=True, - is_memory=False, - memory_name=None, - mode=None, - is_temp_disk=False, - ): - self.name = None - self._thread_local_id = f"x{self._thread_local_id_counter}" - Database._thread_local_id_counter += 1 - self.route = None - self.ds = ds - self.path = path - self.is_mutable = is_mutable - self.is_memory = is_memory - self.memory_name = memory_name - self.is_temp_disk = is_temp_disk - if memory_name is not None: - self.is_memory = True - if is_temp_disk: - fd, temp_path = tempfile.mkstemp(suffix=".db", prefix="datasette_temp_") - os.close(fd) - self.path = temp_path - self.is_mutable = True - self.mode = "rwc" - self._wal_enabled = False - atexit.register(self._cleanup_temp_file) - else: - self._wal_enabled = False - self.cached_hash = None - self.cached_size = None - self._cached_table_counts = None - self._write_thread = None - self._write_queue = None - self._closed = False - self._pending_execute_futures = set() - self._pending_execute_futures_lock = threading.Lock() - # These are used when in non-threaded mode: - self._read_connection = None - self._write_connection = None - # This is used to track all file connections so they can be closed - self._all_file_connections = [] - if not is_temp_disk: - self.mode = mode - - def _check_not_closed(self): - if self._closed: - raise DatasetteClosedError( - "Database {!r} has been closed".format(self.name) - ) - - def _remove_pending_execute_future(self, future): - with self._pending_execute_futures_lock: - self._pending_execute_futures.discard(future) - - @property - def cached_table_counts(self): - if self._cached_table_counts is not None: - return self._cached_table_counts - # Maybe use self.ds.inspect_data to populate cached_table_counts - if self.ds.inspect_data and self.ds.inspect_data.get(self.name): - self._cached_table_counts = { - key: value["count"] - for key, value in self.ds.inspect_data[self.name]["tables"].items() - } - return self._cached_table_counts - - @property - def color(self): - if self.hash: - return self.hash[:6] - return md5_not_usedforsecurity(self.name)[:6] - - def suggest_name(self): - if self.is_temp_disk: - return "_temp_disk" - if self.path: - return Path(self.path).stem - elif self.memory_name: - return self.memory_name - else: - return "db" - - def connect(self, write=False): - extra_kwargs = {} - if write: - extra_kwargs["isolation_level"] = "IMMEDIATE" - if self.memory_name: - uri = "file:{}?mode=memory&cache=shared".format(self.memory_name) - conn = sqlite3.connect( - uri, uri=True, check_same_thread=False, **extra_kwargs - ) - if not write: - conn.execute("PRAGMA query_only=1") - return conn - if self.is_memory: - return sqlite3.connect(":memory:", uri=True) - - # mode=ro or immutable=1? - if self.is_mutable: - qs = "?mode=ro" - if self.ds.nolock: - qs += "&nolock=1" - else: - qs = "?immutable=1" - assert not (write and not self.is_mutable) - if write: - qs = "" - if self.mode is not None: - qs = f"?mode={self.mode}" - conn = sqlite3.connect( - f"file:{self.path}{qs}", uri=True, check_same_thread=False, **extra_kwargs - ) - self._all_file_connections.append(conn) - if self.is_temp_disk and not self._wal_enabled: - conn.execute("PRAGMA journal_mode=WAL") - self._wal_enabled = True - return conn - - def close(self): - """Release all resources held by this database. - - Idempotent. After close() further calls to execute()/execute_fn()/ - execute_write()/execute_write_fn() raise DatasetteClosedError. - """ - if self._closed: - return - with self._pending_execute_futures_lock: - if self._closed: - return - self._closed = True - pending_execute_futures = tuple(self._pending_execute_futures) - # Shut down the write thread, if any, via a sentinel. The thread - # drains any writes already queued before the sentinel and then - # closes its own write connection and returns. - write_thread = self._write_thread - if write_thread is not None and self._write_queue is not None: - self._write_queue.put(_SHUTDOWN) - write_thread.join(timeout=10) - if write_thread.is_alive(): - sys.stderr.write( - "Datasette: write thread for {!r} did not exit within 10s\n".format( - self.name - ) - ) - sys.stderr.flush() - for future in pending_execute_futures: - try: - future.result() - except Exception: - pass - # Close anything still tracked in _all_file_connections - for connection in self._all_file_connections: - try: - connection.close() - except Exception: - pass - self._all_file_connections = [] - # Drop per-thread cached read connections we can reach - try: - delattr(connections, self._thread_local_id) - except AttributeError: - pass - # Close non-threaded-mode cached connections if still open - if self._read_connection is not None: - try: - self._read_connection.close() - except Exception: - pass - self._read_connection = None - if self._write_connection is not None: - try: - self._write_connection.close() - except Exception: - pass - self._write_connection = None - if self.is_temp_disk: - self._cleanup_temp_file() - - def _cleanup_temp_file(self): - if self.is_temp_disk and self.path: - for suffix in ("", "-wal", "-shm"): - try: - os.unlink(self.path + suffix) - except OSError: - pass - - async def execute_write(self, sql, params=None, block=True, request=None): - self._check_not_closed() - - def _inner(conn): - return conn.execute(sql, params or []) - - with trace("sql", database=self.name, sql=sql.strip(), params=params): - results = await self.execute_write_fn(_inner, block=block, request=request) - return results - - async def execute_write_script(self, sql, block=True, request=None): - self._check_not_closed() - - def _inner(conn): - return conn.executescript(sql) - - with trace("sql", database=self.name, sql=sql.strip(), executescript=True): - results = await self.execute_write_fn( - _inner, block=block, transaction=False, request=request - ) - return results - - async def execute_write_many(self, sql, params_seq, block=True, request=None): - self._check_not_closed() - - def _inner(conn): - count = 0 - - def count_params(params): - nonlocal count - for param in params: - count += 1 - yield param - - return conn.executemany(sql, count_params(params_seq)), count - - with trace( - "sql", database=self.name, sql=sql.strip(), executemany=True - ) as kwargs: - results, count = await self.execute_write_fn( - _inner, block=block, request=request - ) - kwargs["count"] = count - return results - - async def execute_isolated_fn(self, fn): - self._check_not_closed() - # Open a new connection just for the duration of this function - # blocking the write queue to avoid any writes occurring during it - if self.ds.executor is None: - # non-threaded mode - isolated_connection = self.connect(write=True) - try: - result = fn(isolated_connection) - finally: - isolated_connection.close() - try: - self._all_file_connections.remove(isolated_connection) - except ValueError: - # Was probably a memory connection - pass - return result - else: - # Threaded mode - send to write thread - return await self._send_to_write_thread(fn, isolated_connection=True) - - async def analyze_sql(self, sql, params=None) -> SQLAnalysis: - self._check_not_closed() - - return await self.execute_isolated_fn( - lambda conn: analyze_sql_tables(conn, sql, params, database_name=self.name) - ) - - async def execute_write_fn(self, fn, block=True, transaction=True, request=None): - self._check_not_closed() - pending_events = [] - - def track_event(event): - pending_events.append(event) - - fn = self._wrap_fn_with_hooks(fn, request, transaction, track_event) - if self.ds.executor is None: - # non-threaded mode - if self._write_connection is None: - self._write_connection = self.connect(write=True) - self.ds._prepare_connection(self._write_connection, self.name) - if transaction: - with self._write_connection: - result = fn(self._write_connection) - else: - result = fn(self._write_connection) - else: - result = await self._send_to_write_thread( - fn, block=block, transaction=transaction - ) - if block: - for event in pending_events: - await self.ds.track_event(event) - else: - # For non-blocking writes, spawn a background task to - # dispatch events after the write thread completes - task_id, reply_future = result - - async def _dispatch_events_after_write(): - try: - await reply_future - except Exception: - # if the write failed, don't emit success events - return - for event in pending_events: - await self.ds.track_event(event) - - asyncio.ensure_future(_dispatch_events_after_write()) - result = task_id - return result - - def _wrap_fn_with_hooks(self, fn, request, transaction, track_event): - from .plugins import pm - - # Wrap fn so it receives track_event if its signature supports it. - # Historically fn was called positionally, so any single-parameter - # name (conn, connection, db, ...) worked. Preserve that by only - # switching to keyword dependency injection when the callback - # explicitly opts in by declaring a `track_event` parameter. - original_fn = fn - - if "track_event" in inspect.signature(original_fn).parameters: - - def fn_with_track_event(conn): - return call_with_supported_arguments( - original_fn, conn=conn, track_event=track_event - ) - - fn = fn_with_track_event - - wrappers = pm.hook.write_wrapper( - datasette=self.ds, - database=self.name, - request=request, - transaction=transaction, - ) - wrappers = [w for w in wrappers if w is not None] - if not wrappers: - return fn - # Build the wrapped fn by nesting context manager generators. - # The first wrapper returned by pluggy is outermost. - for wrapper_factory in reversed(wrappers): - fn = _apply_write_wrapper(fn, wrapper_factory, track_event) - return fn - - async def _send_to_write_thread( - self, fn, block=True, isolated_connection=False, transaction=True - ): - if self._write_queue is None: - self._write_queue = queue.Queue() - if self._write_thread is None: - self._write_thread = threading.Thread( - target=self._execute_writes, daemon=True - ) - self._write_thread.name = "_execute_writes for database {}".format( - self.name - ) - self._write_thread.start() - task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") - loop = asyncio.get_running_loop() - reply_future = loop.create_future() - self._write_queue.put( - WriteTask(fn, task_id, loop, reply_future, isolated_connection, transaction) - ) - if block: - return await reply_future - else: - return task_id, reply_future - - def _execute_writes(self): - # Infinite looping thread that protects the single write connection - # to this database - conn_exception = None - conn = None - try: - conn = self.connect(write=True) - self.ds._prepare_connection(conn, self.name) - except Exception as e: - conn_exception = e - while True: - task = self._write_queue.get() - if task is _SHUTDOWN: - if conn is not None: - try: - conn.close() - except Exception: - pass - return - exception = None - result = None - if conn_exception is not None: - exception = conn_exception - elif task.isolated_connection: - isolated_connection = self.connect(write=True) - try: - result = task.fn(isolated_connection) - except Exception as e: - sys.stderr.write("{}\n".format(e)) - sys.stderr.flush() - exception = e - finally: - isolated_connection.close() - try: - self._all_file_connections.remove(isolated_connection) - except ValueError: - # Was probably a memory connection - pass - else: - try: - if task.transaction: - with conn: - result = task.fn(conn) - else: - result = task.fn(conn) - except Exception as e: - sys.stderr.write("{}\n".format(e)) - sys.stderr.flush() - exception = e - _deliver_write_result(task, result, exception) - - async def execute_fn(self, fn): - self._check_not_closed() - if self.ds.executor is None: - # non-threaded mode - if self._read_connection is None: - self._read_connection = self.connect() - self.ds._prepare_connection(self._read_connection, self.name) - return fn(self._read_connection) - - # threaded mode - def in_thread(): - conn = getattr(connections, self._thread_local_id, None) - if not conn: - conn = self.connect() - self.ds._prepare_connection(conn, self.name) - setattr(connections, self._thread_local_id, conn) - return fn(conn) - - with self._pending_execute_futures_lock: - self._check_not_closed() - future = self.ds.executor.submit(in_thread) - self._pending_execute_futures.add(future) - future.add_done_callback(self._remove_pending_execute_future) - return await asyncio.wrap_future(future) - - async def execute( - self, - sql, - params=None, - truncate=False, - custom_time_limit=None, - page_size=None, - log_sql_errors=True, - ): - """Executes sql against db_name in a thread""" - self._check_not_closed() - page_size = page_size or self.ds.page_size - - def sql_operation_in_thread(conn): - time_limit_ms = self.ds.sql_time_limit_ms - if custom_time_limit and custom_time_limit < time_limit_ms: - time_limit_ms = custom_time_limit - - with sqlite_timelimit(conn, time_limit_ms): - try: - cursor = conn.cursor() - cursor.execute(sql, params if params is not None else {}) - max_returned_rows = self.ds.max_returned_rows - if max_returned_rows == page_size: - max_returned_rows += 1 - if max_returned_rows and truncate: - rows = cursor.fetchmany(max_returned_rows + 1) - truncated = len(rows) > max_returned_rows - rows = rows[:max_returned_rows] - else: - rows = cursor.fetchall() - truncated = False - except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: - if e.args == ("interrupted",): - raise QueryInterrupted(e, sql, params) - if log_sql_errors: - sys.stderr.write( - "ERROR: conn={}, sql = {}, params = {}: {}\n".format( - conn, repr(sql), params, e - ) - ) - sys.stderr.flush() - raise - - if truncate: - return Results(rows, truncated, cursor.description) - - else: - return Results(rows, False, cursor.description) - - with trace("sql", database=self.name, sql=sql.strip(), params=params): - results = await self.execute_fn(sql_operation_in_thread) - return results - - @property - def hash(self): - if self.cached_hash is not None: - return self.cached_hash - elif self.is_mutable or self.is_memory or self.is_temp_disk: - return None - elif self.ds.inspect_data and self.ds.inspect_data.get(self.name): - self.cached_hash = self.ds.inspect_data[self.name]["hash"] - return self.cached_hash - else: - p = Path(self.path) - self.cached_hash = inspect_hash(p) - return self.cached_hash - - @property - def size(self): - if self.cached_size is not None: - return self.cached_size - elif self.is_memory: - return 0 - elif self.is_mutable: - return Path(self.path).stat().st_size - elif self.ds.inspect_data and self.ds.inspect_data.get(self.name): - self.cached_size = self.ds.inspect_data[self.name]["size"] - return self.cached_size - else: - self.cached_size = Path(self.path).stat().st_size - return self.cached_size - - async def table_counts(self, limit=10): - if not self.is_mutable and self.cached_table_counts is not None: - return self.cached_table_counts - # Try to get counts for each table, $limit timeout for each count - counts = {} - for table in await self.table_names(): - try: - table_count = ( - await self.execute( - f"select count(*) from (select * from [{table}] limit {self.count_limit + 1})", - custom_time_limit=limit, - ) - ).rows[0][0] - counts[table] = table_count - # In some cases I saw "SQL Logic Error" here in addition to - # QueryInterrupted - so we catch that too: - except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError): - counts[table] = None - if not self.is_mutable: - self._cached_table_counts = counts - return counts - - @property - def mtime_ns(self): - if self.is_memory: - return None - return Path(self.path).stat().st_mtime_ns - - async def attached_databases(self): - # This used to be: - # select seq, name, file from pragma_database_list() where seq > 0 - # But SQLite prior to 3.16.0 doesn't support pragma functions - results = await self.execute("PRAGMA database_list;") - # {'seq': 0, 'name': 'main', 'file': ''} - return [ - AttachedDatabase(*row) - for row in results.rows - # Filter out the SQLite internal "temp" database, refs #2557 - if row["seq"] > 0 and row["name"] != "temp" - ] - - async def table_exists(self, table): - results = await self.execute( - "select 1 from sqlite_master where type='table' and name=?", params=(table,) - ) - return bool(results.rows) - - async def view_exists(self, table): - results = await self.execute( - "select 1 from sqlite_master where type='view' and name=?", params=(table,) - ) - return bool(results.rows) - - async def table_names(self): - results = await self.execute( - "select name from sqlite_master where type='table' order by name" - ) - return [r[0] for r in results.rows] - - async def table_columns(self, table): - return await self.execute_fn(lambda conn: table_columns(conn, table)) - - async def table_column_details(self, table): - return await self.execute_fn(lambda conn: table_column_details(conn, table)) - - async def primary_keys(self, table): - return await self.execute_fn(lambda conn: detect_primary_keys(conn, table)) - - async def fts_table(self, table): - return await self.execute_fn(lambda conn: detect_fts(conn, table)) - - async def label_column_for_table(self, table): - explicit_label_column = (await self.ds.table_config(self.name, table)).get( - "label_column" - ) - if explicit_label_column: - return explicit_label_column - - def column_details(conn): - # Returns {column_name: (type, is_unique)} - db = sqlite_utils.Database(conn) - columns = db[table].columns_dict - indexes = db[table].indexes - details = {} - for name in columns: - is_unique = any( - index - for index in indexes - if index.columns == [name] and index.unique - ) - details[name] = (columns[name], is_unique) - return details - - column_details = await self.execute_fn(column_details) - # Is there just one unique column that's text? - unique_text_columns = [ - name - for name, (type_, is_unique) in column_details.items() - if is_unique and type_ is str - ] - if len(unique_text_columns) == 1: - return unique_text_columns[0] - - column_names = list(column_details.keys()) - # Is there a name or title column? - name_or_title = [c for c in column_names if c.lower() in ("name", "title")] - if name_or_title: - return name_or_title[0] - # If a table has two columns, one of which is ID, then label_column is the other one - if ( - column_names - and len(column_names) == 2 - and ("id" in column_names or "pk" in column_names) - and not set(column_names) == {"id", "pk"} - ): - return [c for c in column_names if c not in ("id", "pk")][0] - # Couldn't find a label: - return None - - async def foreign_keys_for_table(self, table): - return await self.execute_fn( - lambda conn: get_outbound_foreign_keys(conn, table) - ) - - async def hidden_table_names(self): - hidden_tables = [] - # Add any tables marked as hidden in config - db_config = self.ds.config.get("databases", {}).get(self.name, {}) - if "tables" in db_config: - hidden_tables += [ - t for t in db_config["tables"] if db_config["tables"][t].get("hidden") - ] - - if sqlite_version()[1] >= 37: - hidden_tables += [x[0] for x in await self.execute(""" - with shadow_tables as ( - select name - from pragma_table_list - where [type] = 'shadow' - order by name - ), - core_tables as ( - select name - from sqlite_master - WHERE name in ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4') - OR substr(name, 1, 1) == '_' - ), - combined as ( - select name from shadow_tables - union all - select name from core_tables - ) - select name from combined order by 1 - """)] - else: - hidden_tables += [x[0] for x in await self.execute(""" - WITH base AS ( - SELECT name - FROM sqlite_master - WHERE name IN ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4') - OR substr(name, 1, 1) == '_' - ), - fts_suffixes AS ( - SELECT column1 AS suffix - FROM (VALUES ('_data'), ('_idx'), ('_docsize'), ('_content'), ('_config')) - ), - fts5_names AS ( - SELECT name - FROM sqlite_master - WHERE sql LIKE '%VIRTUAL TABLE%USING FTS%' - ), - fts5_shadow_tables AS ( - SELECT - printf('%s%s', fts5_names.name, fts_suffixes.suffix) AS name - FROM fts5_names - JOIN fts_suffixes - ), - fts3_suffixes AS ( - SELECT column1 AS suffix - FROM (VALUES ('_content'), ('_segdir'), ('_segments'), ('_stat'), ('_docsize')) - ), - fts3_names AS ( - SELECT name - FROM sqlite_master - WHERE sql LIKE '%VIRTUAL TABLE%USING FTS3%' - OR sql LIKE '%VIRTUAL TABLE%USING FTS4%' - ), - fts3_shadow_tables AS ( - SELECT - printf('%s%s', fts3_names.name, fts3_suffixes.suffix) AS name - FROM fts3_names - JOIN fts3_suffixes - ), - final AS ( - SELECT name FROM base - UNION ALL - SELECT name FROM fts5_shadow_tables - UNION ALL - SELECT name FROM fts3_shadow_tables - ) - SELECT name FROM final ORDER BY 1 - """)] - # Also hide any FTS tables that have a content= argument - hidden_tables += [x[0] for x in await self.execute(""" - SELECT name - FROM sqlite_master - WHERE sql LIKE '%VIRTUAL TABLE%' - AND sql LIKE '%USING FTS%' - AND sql LIKE '%content=%' - """)] - - has_spatialite = await self.execute_fn(detect_spatialite) - if has_spatialite: - # Also hide Spatialite internal tables - hidden_tables += [ - "ElementaryGeometries", - "SpatialIndex", - "geometry_columns", - "spatial_ref_sys", - "spatialite_history", - "sql_statements_log", - "sqlite_sequence", - "views_geometry_columns", - "virts_geometry_columns", - "data_licenses", - "KNN", - "KNN2", - ] + [ - r[0] for r in (await self.execute(""" - select name from sqlite_master - where name like "idx_%" - and type = "table" - """)).rows - ] - - return hidden_tables - - async def view_names(self): - results = await self.execute("select name from sqlite_master where type='view'") - return [r[0] for r in results.rows] - - async def get_all_foreign_keys(self): - return await self.execute_fn(get_all_foreign_keys) - - async def get_table_definition(self, table, type_="table"): - table_definition_rows = list( - await self.execute( - "select sql from sqlite_master where name = :n and type=:t", - {"n": table, "t": type_}, - ) - ) - if not table_definition_rows: - return None - bits = [table_definition_rows[0][0] + ";"] - # Add on any indexes - index_rows = list( - await self.execute( - "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null", - {"n": table}, - ) - ) - for index_row in index_rows: - bits.append(index_row[0] + ";") - return "\n".join(bits) - - async def get_view_definition(self, view): - return await self.get_table_definition(view, "view") - - def __repr__(self): - tags = [] - if self.is_mutable: - tags.append("mutable") - if self.is_memory: - tags.append("memory") - if self.is_temp_disk: - tags.append("temp_disk") - if self.hash: - tags.append(f"hash={self.hash}") - if self.size is not None: - tags.append(f"size={self.size}") - tags_str = "" - if tags: - tags_str = f" ({', '.join(tags)})" - return f"" - - -def _apply_write_wrapper(fn, wrapper_factory, track_event): - """Apply a single write_wrapper context manager around fn. - - ``wrapper_factory`` is a callable that takes ``(conn)`` and optionally - ``track_event``, and returns a generator that yields exactly once. - Code before the yield runs before ``fn(conn)``, code after the yield - runs after. The result of ``fn(conn)`` is sent into the generator - via ``.send()``, and any exception raised by ``fn(conn)`` is thrown - via ``.throw()``. - """ - - def wrapped(conn): - gen = call_with_supported_arguments( - wrapper_factory, conn=conn, track_event=track_event - ) - # Advance to the yield point (run "before" code) - try: - next(gen) - except StopIteration: - # Generator didn't yield — just run fn unchanged - return fn(conn) - - # Execute the actual write - try: - result = fn(conn) - except Exception: - # Throw exception into generator so it can handle it - try: - gen.throw(*sys.exc_info()) - except StopIteration: - pass - # Re-raise the original exception - raise - else: - # Send the result back through the yield - try: - gen.send(result) - except StopIteration: - pass - return result - - return wrapped - - -class WriteTask: - __slots__ = ( - "fn", - "task_id", - "loop", - "reply_future", - "isolated_connection", - "transaction", - ) - - def __init__( - self, fn, task_id, loop, reply_future, isolated_connection, transaction - ): - self.fn = fn - self.task_id = task_id - self.loop = loop - self.reply_future = reply_future - self.isolated_connection = isolated_connection - self.transaction = transaction - - -def _deliver_write_result(task, result, exception): - # Called from the write thread. Delivers the result back to the - # awaiting coroutine on its event loop via call_soon_threadsafe. - def _set(): - if task.reply_future.done(): - # Awaiter was cancelled; nothing to do. - return - if exception is not None: - task.reply_future.set_exception(exception) - else: - task.reply_future.set_result(result) - - try: - task.loop.call_soon_threadsafe(_set) - except RuntimeError: - # Event loop has been closed; the awaiter is gone. - pass - - -class QueryInterrupted(Exception): - def __init__(self, e, sql, params): - self.e = e - self.sql = sql - self.params = params - - def __str__(self): - return "QueryInterrupted: {}".format(self.e) - - -class MultipleValues(Exception): - pass - - -class Results: - def __init__(self, rows, truncated, description): - self.rows = rows - self.truncated = truncated - self.description = description - - @property - def columns(self): - return [d[0] for d in self.description] - - def first(self): - if self.rows: - return self.rows[0] - else: - return None - - def single_value(self): - if self.rows and 1 == len(self.rows) and 1 == len(self.rows[0]): - return self.rows[0][0] - else: - raise MultipleValues - - def dicts(self): - return [dict(row) for row in self.rows] - - def __iter__(self): - return iter(self.rows) - - def __len__(self): - return len(self.rows) diff --git a/datasette/default_actions.py b/datasette/default_actions.py deleted file mode 100644 index 2f78570b..00000000 --- a/datasette/default_actions.py +++ /dev/null @@ -1,133 +0,0 @@ -from datasette import hookimpl -from datasette.permissions import Action -from datasette.resources import ( - DatabaseResource, - TableResource, - QueryResource, -) - - -@hookimpl -def register_actions(): - """Register the core Datasette actions.""" - return ( - # Global actions (no resource_class) - Action( - name="view-instance", - abbr="vi", - description="View Datasette instance", - ), - Action( - name="permissions-debug", - abbr="pd", - description="Access permission debug tool", - ), - Action( - name="debug-menu", - abbr="dm", - description="View debug menu items", - ), - # Database-level actions (parent-level) - Action( - name="view-database", - abbr="vd", - description="View database", - resource_class=DatabaseResource, - ), - Action( - name="view-database-download", - abbr="vdd", - description="Download database file", - resource_class=DatabaseResource, - also_requires="view-database", - ), - Action( - name="execute-sql", - abbr="es", - description="Execute read-only SQL queries", - resource_class=DatabaseResource, - also_requires="view-database", - ), - Action( - name="execute-write-sql", - abbr="ews", - description="Execute writable SQL queries", - resource_class=DatabaseResource, - also_requires="view-database", - ), - Action( - name="create-table", - abbr="ct", - description="Create tables", - resource_class=DatabaseResource, - ), - Action( - name="store-query", - abbr="sq", - description="Create stored queries", - resource_class=DatabaseResource, - also_requires="execute-sql", - ), - # Table-level actions (child-level) - Action( - name="view-table", - abbr="vt", - description="View table", - resource_class=TableResource, - ), - Action( - name="insert-row", - abbr="ir", - description="Insert rows", - resource_class=TableResource, - ), - Action( - name="delete-row", - abbr="dr", - description="Delete rows", - resource_class=TableResource, - ), - Action( - name="update-row", - abbr="ur", - description="Update rows", - resource_class=TableResource, - ), - Action( - name="alter-table", - abbr="at", - description="Alter tables", - resource_class=TableResource, - ), - Action( - name="set-column-type", - abbr="sct", - description="Set column type", - resource_class=TableResource, - ), - Action( - name="drop-table", - abbr="dt", - description="Drop tables", - resource_class=TableResource, - ), - # Query-level actions (child-level) - Action( - name="view-query", - abbr="vq", - description="View named query results", - resource_class=QueryResource, - ), - Action( - name="update-query", - abbr="uq", - description="Update stored queries", - resource_class=QueryResource, - ), - Action( - name="delete-query", - abbr="dq", - description="Delete stored queries", - resource_class=QueryResource, - ), - ) diff --git a/datasette/default_column_types.py b/datasette/default_column_types.py deleted file mode 100644 index 24493994..00000000 --- a/datasette/default_column_types.py +++ /dev/null @@ -1,81 +0,0 @@ -import json -import re - -import markupsafe - -from datasette import hookimpl -from datasette.column_types import ColumnType, SQLiteType - - -class UrlColumnType(ColumnType): - name = "url" - description = "URL" - sqlite_types = (SQLiteType.TEXT,) - - async def render_cell(self, value, column, table, database, datasette, request): - if not value or not isinstance(value, str): - return None - escaped = markupsafe.escape(value.strip()) - return markupsafe.Markup(f'{escaped}') - - async def validate(self, value, datasette): - if value is None or value == "": - return None - if not isinstance(value, str): - return "URL must be a string" - if not re.match(r"^https?://\S+$", value.strip()): - return "Invalid URL" - return None - - -class EmailColumnType(ColumnType): - name = "email" - description = "Email address" - sqlite_types = (SQLiteType.TEXT,) - - async def render_cell(self, value, column, table, database, datasette, request): - if not value or not isinstance(value, str): - return None - escaped = markupsafe.escape(value.strip()) - return markupsafe.Markup(f'{escaped}') - - async def validate(self, value, datasette): - if value is None or value == "": - return None - if not isinstance(value, str): - return "Email must be a string" - if not re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", value.strip()): - return "Invalid email address" - return None - - -class JsonColumnType(ColumnType): - name = "json" - description = "JSON data" - sqlite_types = (SQLiteType.TEXT,) - - async def render_cell(self, value, column, table, database, datasette, request): - if value is None: - return None - try: - parsed = json.loads(value) if isinstance(value, str) else value - formatted = json.dumps(parsed, indent=2) - escaped = markupsafe.escape(formatted) - return markupsafe.Markup(f"
{escaped}
") - except (json.JSONDecodeError, TypeError): - return None - - async def validate(self, value, datasette): - if value is None or value == "": - return None - if isinstance(value, str): - try: - json.loads(value) - except json.JSONDecodeError: - return "Invalid JSON" - return None - - -@hookimpl -def register_column_types(datasette): - return [UrlColumnType, EmailColumnType, JsonColumnType] diff --git a/datasette/default_database_actions.py b/datasette/default_database_actions.py deleted file mode 100644 index e0cb3cdf..00000000 --- a/datasette/default_database_actions.py +++ /dev/null @@ -1,24 +0,0 @@ -from datasette import hookimpl -from datasette.resources import DatabaseResource - - -@hookimpl -def database_actions(datasette, actor, database, request): - async def inner(): - if not datasette.get_database(database).is_mutable: - return [] - if not await datasette.allowed( - action="execute-write-sql", - resource=DatabaseResource(database), - actor=actor, - ): - return [] - return [ - { - "href": datasette.urls.database(database) + "/-/execute-write", - "label": "Execute write SQL", - "description": "Run writable SQL with table permission checks.", - } - ] - - return inner diff --git a/datasette/default_debug_menu.py b/datasette/default_debug_menu.py deleted file mode 100644 index 6127b2a6..00000000 --- a/datasette/default_debug_menu.py +++ /dev/null @@ -1,75 +0,0 @@ -from datasette import hookimpl -from datasette.jump import JumpSQL - -DEBUG_MENU_ITEMS = ( - ( - "/-/databases", - "Databases", - "List of databases known to this Datasette instance.", - ), - ( - "/-/plugins", - "Installed plugins", - "Review loaded plugins, their versions and their registered hooks.", - ), - ( - "/-/versions", - "Version info", - "Check the Python, SQLite and dependency versions used by this server.", - ), - ( - "/-/settings", - "Settings", - "Inspect the active Datasette settings and configuration values.", - ), - ( - "/-/permissions", - "Debug permissions", - "Test permission checks for actors, actions and resources.", - ), - ( - "/-/messages", - "Debug messages", - "Try out temporary flash messages shown to users.", - ), - ( - "/-/allow-debug", - "Debug allow rules", - "Explore how allow blocks match actors against permission rules.", - ), - ( - "/-/threads", - "Debug threads", - "Inspect worker threads and database tasks.", - ), - ( - "/-/actor", - "Debug actor", - "View the actor object for the current signed-in user.", - ), - ( - "/-/patterns", - "Pattern portfolio", - "Browse Datasette UI patterns.", - ), -) - - -@hookimpl -def jump_items_sql(datasette, actor, request): - async def inner(): - if not await datasette.allowed(action="debug-menu", actor=actor): - return [] - - return [ - JumpSQL.menu_item( - label=label, - url=datasette.urls.path(path), - description=description, - search_text=f"debug {label} {description}", - item_type="debug", - ) - for path, label, description in DEBUG_MENU_ITEMS - ] - - return inner diff --git a/datasette/default_jump_items.py b/datasette/default_jump_items.py deleted file mode 100644 index d215e7ec..00000000 --- a/datasette/default_jump_items.py +++ /dev/null @@ -1,82 +0,0 @@ -from datasette import hookimpl -from datasette.jump import JumpSQL - - -@hookimpl -def jump_items_sql(datasette, actor, request): - async def inner(): - database_sql, database_params = await datasette.allowed_resources_sql( - action="view-database", actor=actor - ) - table_sql, table_params = await datasette.allowed_resources_sql( - action="view-table", actor=actor - ) - query_sql, query_params = await datasette.allowed_resources_sql( - action="view-query", actor=actor - ) - return [ - JumpSQL( - sql=f""" - WITH allowed_databases AS ( - {database_sql} - ) - SELECT - 'database' AS type, - parent AS label, - NULL AS description, - json_object( - 'method', 'database', - 'database', parent - ) AS url, - parent AS search_text, - NULL AS display_name - FROM allowed_databases - """, - params=database_params, - ), - JumpSQL( - sql=f""" - WITH allowed_tables AS ( - {table_sql} - ) - SELECT - CASE WHEN catalog_views.view_name IS NULL THEN 'table' ELSE 'view' END AS type, - allowed_tables.parent || ': ' || allowed_tables.child AS label, - NULL AS description, - json_object( - 'method', 'table', - 'database', allowed_tables.parent, - 'table', allowed_tables.child - ) AS url, - allowed_tables.parent || ' ' || allowed_tables.child AS search_text, - NULL AS display_name - FROM allowed_tables - LEFT JOIN catalog_views - ON catalog_views.database_name = allowed_tables.parent - AND catalog_views.view_name = allowed_tables.child - """, - params=table_params, - ), - JumpSQL( - sql=f""" - WITH allowed_queries AS ( - {query_sql} - ) - SELECT - 'query' AS type, - allowed_queries.parent || ': ' || allowed_queries.child AS label, - NULL AS description, - json_object( - 'method', 'query', - 'database', allowed_queries.parent, - 'query', allowed_queries.child - ) AS url, - allowed_queries.parent || ' ' || allowed_queries.child AS search_text, - NULL AS display_name - FROM allowed_queries - """, - params=query_params, - ), - ] - - return inner diff --git a/datasette/default_magic_parameters.py b/datasette/default_magic_parameters.py deleted file mode 100644 index 91c1c5aa..00000000 --- a/datasette/default_magic_parameters.py +++ /dev/null @@ -1,57 +0,0 @@ -from datasette import hookimpl -import datetime -import os -import time - - -def header(key, request): - key = key.replace("_", "-").encode("utf-8") - headers_dict = dict(request.scope["headers"]) - return headers_dict.get(key, b"").decode("utf-8") - - -def actor(key, request): - if request.actor is None: - raise KeyError - return request.actor[key] - - -def cookie(key, request): - return request.cookies[key] - - -def now(key, request): - if key == "epoch": - return int(time.time()) - elif key == "date_utc": - return datetime.datetime.now(datetime.timezone.utc).date().isoformat() - elif key == "datetime_utc": - return ( - datetime.datetime.now(datetime.timezone.utc).strftime(r"%Y-%m-%dT%H:%M:%S") - + "Z" - ) - else: - raise KeyError - - -def random(key, request): - if key.startswith("chars_") and key.split("chars_")[-1].isdigit(): - num_chars = int(key.split("chars_")[-1]) - if num_chars % 2 == 1: - urandom_len = (num_chars + 1) / 2 - else: - urandom_len = num_chars / 2 - return os.urandom(int(urandom_len)).hex()[:num_chars] - else: - raise KeyError - - -@hookimpl -def register_magic_parameters(): - return [ - ("header", header), - ("actor", actor), - ("cookie", cookie), - ("now", now), - ("random", random), - ] diff --git a/datasette/default_permissions/__init__.py b/datasette/default_permissions/__init__.py deleted file mode 100644 index 6cd46f04..00000000 --- a/datasette/default_permissions/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Default permission implementations for Datasette. - -This module provides the built-in permission checking logic through implementations -of the permission_resources_sql hook. The hooks are organized by their purpose: - -1. Actor Restrictions - Enforces _r allowlists embedded in actor tokens -2. Root User - Grants full access when --root flag is used -3. Config Rules - Applies permissions from datasette.yaml -4. Default Settings - Enforces default_allow_sql and default view permissions - -IMPORTANT: These hooks return PermissionSQL objects that are combined using SQL -UNION/INTERSECT operations. The order of evaluation is: - - restriction_sql fields are INTERSECTed (all must match) - - Regular sql fields are UNIONed and evaluated with cascading priority -""" - -from __future__ import annotations - -# Re-export all hooks and public utilities -from .restrictions import ( - actor_restrictions_sql as actor_restrictions_sql, - restrictions_allow_action as restrictions_allow_action, - ActorRestrictions as ActorRestrictions, -) -from .root import root_user_permissions_sql as root_user_permissions_sql -from .config import config_permissions_sql as config_permissions_sql -from .defaults import ( - # Avoid "datasette.default_permissions" does not explicitly export attribute - default_allow_sql_check as default_allow_sql_check, - default_action_permissions_sql as default_action_permissions_sql, - default_query_permissions_sql as default_query_permissions_sql, - DEFAULT_ALLOW_ACTIONS as DEFAULT_ALLOW_ACTIONS, -) diff --git a/datasette/default_permissions/config.py b/datasette/default_permissions/config.py deleted file mode 100644 index aab87c1c..00000000 --- a/datasette/default_permissions/config.py +++ /dev/null @@ -1,442 +0,0 @@ -""" -Config-based permission handling for Datasette. - -Applies permission rules from datasette.yaml configuration. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple - -if TYPE_CHECKING: - from datasette.app import Datasette - -from datasette import hookimpl -from datasette.permissions import PermissionSQL -from datasette.utils import actor_matches_allow - -from .helpers import PermissionRowCollector, get_action_name_variants - - -class ConfigPermissionProcessor: - """ - Processes permission rules from datasette.yaml configuration. - - Configuration structure: - - permissions: # Root-level permissions block - view-instance: - id: admin - - databases: - mydb: - permissions: # Database-level permissions - view-database: - id: admin - allow: # Database-level allow block (for view-*) - id: viewer - allow_sql: # execute-sql allow block - id: analyst - tables: - users: - permissions: # Table-level permissions - view-table: - id: admin - allow: # Table-level allow block - id: viewer - queries: - my_query: - permissions: # Query-level permissions - view-query: - id: admin - allow: # Query-level allow block - id: viewer - """ - - def __init__( - self, - datasette: "Datasette", - actor: Optional[dict], - action: str, - ): - self.datasette = datasette - self.actor = actor - self.action = action - self.config = datasette.config or {} - self.collector = PermissionRowCollector(prefix="cfg") - - # Pre-compute action variants - self.action_checks = get_action_name_variants(datasette, action) - self.action_obj = datasette.actions.get(action) - - # Parse restrictions if present - self.has_restrictions = actor and "_r" in actor if actor else False - self.restrictions = actor.get("_r", {}) if actor else {} - - # Pre-compute restriction info for efficiency - self.restricted_databases: Set[str] = set() - self.restricted_tables: Set[Tuple[str, str]] = set() - - if self.has_restrictions: - self.restricted_databases = { - db_name - for db_name, db_actions in (self.restrictions.get("d") or {}).items() - if self.action_checks.intersection(db_actions) - } - self.restricted_tables = { - (db_name, table_name) - for db_name, tables in (self.restrictions.get("r") or {}).items() - for table_name, table_actions in tables.items() - if self.action_checks.intersection(table_actions) - } - # Tables implicitly reference their parent databases - self.restricted_databases.update(db for db, _ in self.restricted_tables) - - def evaluate_allow_block(self, allow_block: Any) -> Optional[bool]: - """Evaluate an allow block against the current actor.""" - if allow_block is None: - return None - return actor_matches_allow(self.actor, allow_block) - - def is_in_restriction_allowlist( - self, - parent: Optional[str], - child: Optional[str], - ) -> bool: - """Check if resource is allowed by actor restrictions.""" - if not self.has_restrictions: - return True # No restrictions, all resources allowed - - # Check global allowlist - if self.action_checks.intersection(self.restrictions.get("a", [])): - return True - - # Check database-level allowlist - if parent and self.action_checks.intersection( - self.restrictions.get("d", {}).get(parent, []) - ): - return True - - # Check table-level allowlist - if parent: - table_restrictions = (self.restrictions.get("r", {}) or {}).get(parent, {}) - if child: - table_actions = table_restrictions.get(child, []) - if self.action_checks.intersection(table_actions): - return True - else: - # Parent query should proceed if any child in this database is allowlisted - for table_actions in table_restrictions.values(): - if self.action_checks.intersection(table_actions): - return True - - # Parent/child both None: include if any restrictions exist for this action - if parent is None and child is None: - if self.action_checks.intersection(self.restrictions.get("a", [])): - return True - if self.restricted_databases: - return True - if self.restricted_tables: - return True - - return False - - def add_permissions_rule( - self, - parent: Optional[str], - child: Optional[str], - permissions_block: Optional[dict], - scope_desc: str, - ) -> None: - """Add a rule from a permissions:{action} block.""" - if permissions_block is None: - return - - action_allow_block = permissions_block.get(self.action) - result = self.evaluate_allow_block(action_allow_block) - - self.collector.add( - parent=parent, - child=child, - allow=result, - reason=f"config {'allow' if result else 'deny'} {scope_desc}", - if_not_none=True, - ) - - def add_allow_block_rule( - self, - parent: Optional[str], - child: Optional[str], - allow_block: Any, - scope_desc: str, - ) -> None: - """ - Add rules from an allow:{} block. - - For allow blocks, if the block exists but doesn't match the actor, - this is treated as a deny. We also handle the restriction-gate logic. - """ - if allow_block is None: - return - - # Skip if resource is not in restriction allowlist - if not self.is_in_restriction_allowlist(parent, child): - return - - result = self.evaluate_allow_block(allow_block) - bool_result = bool(result) - - self.collector.add( - parent, - child, - bool_result, - f"config {'allow' if result else 'deny'} {scope_desc}", - ) - - # Handle restriction-gate: add explicit denies for restricted resources - self._add_restriction_gate_denies(parent, child, bool_result, scope_desc) - - def _add_restriction_gate_denies( - self, - parent: Optional[str], - child: Optional[str], - is_allowed: bool, - scope_desc: str, - ) -> None: - """ - When a config rule denies at a higher level, add explicit denies - for restricted resources to prevent child-level allows from - incorrectly granting access. - """ - if is_allowed or child is not None or not self.has_restrictions: - return - - if not self.action_obj: - return - - reason = f"config deny {scope_desc} (restriction gate)" - - if parent is None: - # Root-level deny: add denies for all restricted resources - if self.action_obj.takes_parent: - for db_name in self.restricted_databases: - self.collector.add(db_name, None, False, reason) - if self.action_obj.takes_child: - for db_name, table_name in self.restricted_tables: - self.collector.add(db_name, table_name, False, reason) - else: - # Database-level deny: add denies for tables in that database - if self.action_obj.takes_child: - for db_name, table_name in self.restricted_tables: - if db_name == parent: - self.collector.add(db_name, table_name, False, reason) - - def process(self) -> Optional[PermissionSQL]: - """Process all config rules and return combined PermissionSQL.""" - self._process_root_permissions() - self._process_databases() - self._process_root_allow_blocks() - - return self.collector.to_permission_sql() - - def _process_root_permissions(self) -> None: - """Process root-level permissions block.""" - root_perms = self.config.get("permissions") or {} - self.add_permissions_rule( - None, - None, - root_perms, - f"permissions for {self.action}", - ) - - def _process_databases(self) -> None: - """Process database-level and nested configurations.""" - databases = self.config.get("databases") or {} - - for db_name, db_config in databases.items(): - self._process_database(db_name, db_config or {}) - - def _process_database(self, db_name: str, db_config: dict) -> None: - """Process a single database's configuration.""" - # Database-level permissions block - db_perms = db_config.get("permissions") or {} - self.add_permissions_rule( - db_name, - None, - db_perms, - f"permissions for {self.action} on {db_name}", - ) - - # Process tables - for table_name, table_config in (db_config.get("tables") or {}).items(): - self._process_table(db_name, table_name, table_config or {}) - - # Process queries - for query_name, query_config in (db_config.get("queries") or {}).items(): - self._process_query(db_name, query_name, query_config) - - # Database-level allow blocks - self._process_database_allow_blocks(db_name, db_config) - - def _process_table( - self, - db_name: str, - table_name: str, - table_config: dict, - ) -> None: - """Process a single table's configuration.""" - # Table-level permissions block - table_perms = table_config.get("permissions") or {} - self.add_permissions_rule( - db_name, - table_name, - table_perms, - f"permissions for {self.action} on {db_name}/{table_name}", - ) - - # Table-level allow block (for view-table) - if self.action == "view-table": - self.add_allow_block_rule( - db_name, - table_name, - table_config.get("allow"), - f"allow for {self.action} on {db_name}/{table_name}", - ) - - def _process_query( - self, - db_name: str, - query_name: str, - query_config: Any, - ) -> None: - """Process a single query's configuration.""" - # Query config can be a string (just SQL) or dict - if not isinstance(query_config, dict): - return - - # Query-level permissions block - query_perms = query_config.get("permissions") or {} - self.add_permissions_rule( - db_name, - query_name, - query_perms, - f"permissions for {self.action} on {db_name}/{query_name}", - ) - - # Query-level allow block (for view-query) - if self.action == "view-query": - self.add_allow_block_rule( - db_name, - query_name, - query_config.get("allow"), - f"allow for {self.action} on {db_name}/{query_name}", - ) - - def _process_database_allow_blocks( - self, - db_name: str, - db_config: dict, - ) -> None: - """Process database-level allow/allow_sql blocks.""" - # view-database allow block - if self.action == "view-database": - self.add_allow_block_rule( - db_name, - None, - db_config.get("allow"), - f"allow for {self.action} on {db_name}", - ) - - # execute-sql allow_sql block - if self.action == "execute-sql": - self.add_allow_block_rule( - db_name, - None, - db_config.get("allow_sql"), - f"allow_sql for {db_name}", - ) - - # view-table uses database-level allow for inheritance - if self.action == "view-table": - self.add_allow_block_rule( - db_name, - None, - db_config.get("allow"), - f"allow for {self.action} on {db_name}", - ) - - # view-query uses database-level allow for inheritance - if self.action == "view-query": - self.add_allow_block_rule( - db_name, - None, - db_config.get("allow"), - f"allow for {self.action} on {db_name}", - ) - - def _process_root_allow_blocks(self) -> None: - """Process root-level allow/allow_sql blocks.""" - root_allow = self.config.get("allow") - - if self.action == "view-instance": - self.add_allow_block_rule( - None, - None, - root_allow, - "allow for view-instance", - ) - - if self.action == "view-database": - self.add_allow_block_rule( - None, - None, - root_allow, - "allow for view-database", - ) - - if self.action == "view-table": - self.add_allow_block_rule( - None, - None, - root_allow, - "allow for view-table", - ) - - if self.action == "view-query": - self.add_allow_block_rule( - None, - None, - root_allow, - "allow for view-query", - ) - - if self.action == "execute-sql": - self.add_allow_block_rule( - None, - None, - self.config.get("allow_sql"), - "allow_sql", - ) - - -@hookimpl(specname="permission_resources_sql") -async def config_permissions_sql( - datasette: "Datasette", - actor: Optional[dict], - action: str, -) -> Optional[List[PermissionSQL]]: - """ - Apply permission rules from datasette.yaml configuration. - - This processes: - - permissions: blocks at root, database, table, and query levels - - allow: blocks for view-* actions - - allow_sql: blocks for execute-sql action - """ - processor = ConfigPermissionProcessor(datasette, actor, action) - result = processor.process() - - if result is None: - return [] - - return [result] diff --git a/datasette/default_permissions/defaults.py b/datasette/default_permissions/defaults.py deleted file mode 100644 index 5bc74425..00000000 --- a/datasette/default_permissions/defaults.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -Default permission settings for Datasette. - -Provides default allow rules for standard view/execute actions. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from datasette.app import Datasette - -from datasette import hookimpl -from datasette.permissions import PermissionSQL - -# Actions that are allowed by default (unless --default-deny is used) -DEFAULT_ALLOW_ACTIONS = frozenset( - { - "view-instance", - "view-database", - "view-database-download", - "view-table", - "view-query", - "execute-sql", - } -) - - -@hookimpl(specname="permission_resources_sql") -async def default_allow_sql_check( - datasette: "Datasette", - actor: Optional[dict], - action: str, -) -> Optional[PermissionSQL]: - """ - Enforce the default_allow_sql setting. - - When default_allow_sql is false (the default), execute-sql is denied - unless explicitly allowed by config or other rules. - """ - if action == "execute-sql": - if not datasette.setting("default_allow_sql"): - return PermissionSQL.deny(reason="default_allow_sql is false") - - return None - - -@hookimpl(specname="permission_resources_sql") -async def default_action_permissions_sql( - datasette: "Datasette", - actor: Optional[dict], - action: str, -) -> Optional[PermissionSQL]: - """ - Provide default allow rules for standard view/execute actions. - - These defaults are skipped when datasette is started with --default-deny. - The restriction_sql mechanism (from actor_restrictions_sql) will still - filter these results if the actor has restrictions. - """ - if datasette.default_deny: - return None - - if action in DEFAULT_ALLOW_ACTIONS: - reason = f"default allow for {action}".replace("'", "''") - return PermissionSQL.allow(reason=reason) - - return None - - -@hookimpl(specname="permission_resources_sql") -async def default_query_permissions_sql( - datasette: "Datasette", - actor: Optional[dict], - action: str, -) -> Optional[PermissionSQL]: - actor_id = actor.get("id") if isinstance(actor, dict) else None - - if action not in {"view-query", "update-query", "delete-query"}: - return None - - params = {"query_owner_id": actor_id} - rule_sqls = [] - if actor_id is not None: - if action in {"update-query", "delete-query"}: - # Query owner can update/delete query - rule_sqls.append(""" - SELECT database_name AS parent, name AS child, 1 AS allow, - 'query owner' AS reason - FROM queries - WHERE source = 'user' - AND owner_id = :query_owner_id - """) - else: - # Query owner can view-query - rule_sqls.append(""" - SELECT database_name AS parent, name AS child, 1 AS allow, - 'query owner' AS reason - FROM queries - WHERE owner_id = :query_owner_id - """) - - # restriction_sql enforces private queries ONLY visible/mutable by owner - return PermissionSQL( - sql="\nUNION ALL\n".join(rule_sqls) if rule_sqls else None, - restriction_sql=""" - SELECT database_name AS parent, name AS child - FROM queries - WHERE is_private = 0 - OR owner_id = :query_owner_id - """, - params=params, - ) diff --git a/datasette/default_permissions/helpers.py b/datasette/default_permissions/helpers.py deleted file mode 100644 index 47e03569..00000000 --- a/datasette/default_permissions/helpers.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Shared helper utilities for default permission implementations. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Set - -if TYPE_CHECKING: - from datasette.app import Datasette - -from datasette.permissions import PermissionSQL - - -def get_action_name_variants(datasette: "Datasette", action: str) -> Set[str]: - """ - Get all name variants for an action (full name and abbreviation). - - Example: - get_action_name_variants(ds, "view-table") -> {"view-table", "vt"} - """ - variants = {action} - action_obj = datasette.actions.get(action) - if action_obj and action_obj.abbr: - variants.add(action_obj.abbr) - return variants - - -def action_in_list(datasette: "Datasette", action: str, action_list: list) -> bool: - """Check if an action (or its abbreviation) is in a list.""" - return bool(get_action_name_variants(datasette, action).intersection(action_list)) - - -@dataclass -class PermissionRow: - """A single permission rule row.""" - - parent: Optional[str] - child: Optional[str] - allow: bool - reason: str - - -class PermissionRowCollector: - """Collects permission rows and converts them to PermissionSQL.""" - - def __init__(self, prefix: str = "row"): - self.rows: List[PermissionRow] = [] - self.prefix = prefix - - def add( - self, - parent: Optional[str], - child: Optional[str], - allow: Optional[bool], - reason: str, - if_not_none: bool = False, - ) -> None: - """Add a permission row. If if_not_none=True, only add if allow is not None.""" - if if_not_none and allow is None: - return - self.rows.append(PermissionRow(parent, child, allow, reason)) - - def to_permission_sql(self) -> Optional[PermissionSQL]: - """Convert collected rows to a PermissionSQL object.""" - if not self.rows: - return None - - parts = [] - params = {} - - for idx, row in enumerate(self.rows): - key = f"{self.prefix}_{idx}" - parts.append( - f"SELECT :{key}_parent AS parent, :{key}_child AS child, " - f":{key}_allow AS allow, :{key}_reason AS reason" - ) - params[f"{key}_parent"] = row.parent - params[f"{key}_child"] = row.child - params[f"{key}_allow"] = 1 if row.allow else 0 - params[f"{key}_reason"] = row.reason - - sql = "\nUNION ALL\n".join(parts) - return PermissionSQL(sql=sql, params=params) diff --git a/datasette/default_permissions/restrictions.py b/datasette/default_permissions/restrictions.py deleted file mode 100644 index a22cd7e5..00000000 --- a/datasette/default_permissions/restrictions.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -Actor restriction handling for Datasette permissions. - -This module handles the _r (restrictions) key in actor dictionaries, which -contains allowlists of resources the actor can access. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Set, Tuple - -if TYPE_CHECKING: - from datasette.app import Datasette - -from datasette import hookimpl -from datasette.permissions import PermissionSQL - -from .helpers import action_in_list, get_action_name_variants - - -@dataclass -class ActorRestrictions: - """Parsed actor restrictions from the _r key.""" - - global_actions: List[str] # _r.a - globally allowed actions - database_actions: dict # _r.d - {db_name: [actions]} - table_actions: dict # _r.r - {db_name: {table: [actions]}} - - @classmethod - def from_actor(cls, actor: Optional[dict]) -> Optional["ActorRestrictions"]: - """Parse restrictions from actor dict. Returns None if no restrictions.""" - if not actor: - return None - assert isinstance(actor, dict), "actor must be a dictionary" - - restrictions = actor.get("_r") - if restrictions is None: - return None - - return cls( - global_actions=restrictions.get("a", []), - database_actions=restrictions.get("d", {}), - table_actions=restrictions.get("r", {}), - ) - - def is_action_globally_allowed(self, datasette: "Datasette", action: str) -> bool: - """Check if action is in the global allowlist.""" - return action_in_list(datasette, action, self.global_actions) - - def get_allowed_databases(self, datasette: "Datasette", action: str) -> Set[str]: - """Get database names where this action is allowed.""" - allowed = set() - for db_name, db_actions in self.database_actions.items(): - if action_in_list(datasette, action, db_actions): - allowed.add(db_name) - return allowed - - def get_allowed_tables( - self, datasette: "Datasette", action: str - ) -> Set[Tuple[str, str]]: - """Get (database, table) pairs where this action is allowed.""" - allowed = set() - for db_name, tables in self.table_actions.items(): - for table_name, table_actions in tables.items(): - if action_in_list(datasette, action, table_actions): - allowed.add((db_name, table_name)) - return allowed - - -@hookimpl(specname="permission_resources_sql") -async def actor_restrictions_sql( - datasette: "Datasette", - actor: Optional[dict], - action: str, -) -> Optional[List[PermissionSQL]]: - """ - Handle actor restriction-based permission rules. - - When an actor has an "_r" key, it contains an allowlist of resources they - can access. This function returns restriction_sql that filters the final - results to only include resources in that allowlist. - - The _r structure: - { - "a": ["vi", "pd"], # Global actions allowed - "d": {"mydb": ["vt", "es"]}, # Database-level actions - "r": {"mydb": {"users": ["vt"]}} # Table-level actions - } - """ - if not actor: - return None - - restrictions = ActorRestrictions.from_actor(actor) - - if restrictions is None: - # No restrictions - all resources allowed - return [] - - # If globally allowed, no filtering needed - if restrictions.is_action_globally_allowed(datasette, action): - return [] - - # Build restriction SQL - allowed_dbs = restrictions.get_allowed_databases(datasette, action) - allowed_tables = restrictions.get_allowed_tables(datasette, action) - - # If nothing is allowed for this action, return empty-set restriction - if not allowed_dbs and not allowed_tables: - return [ - PermissionSQL( - params={"deny": f"actor restrictions: {action} not in allowlist"}, - restriction_sql="SELECT NULL AS parent, NULL AS child WHERE 0", - ) - ] - - # Build UNION of allowed resources - selects = [] - params = {} - counter = 0 - - # Database-level entries (parent, NULL) - allows all children - for db_name in allowed_dbs: - key = f"restr_{counter}" - counter += 1 - selects.append(f"SELECT :{key}_parent AS parent, NULL AS child") - params[f"{key}_parent"] = db_name - - # Table-level entries (parent, child) - for db_name, table_name in allowed_tables: - key = f"restr_{counter}" - counter += 1 - selects.append(f"SELECT :{key}_parent AS parent, :{key}_child AS child") - params[f"{key}_parent"] = db_name - params[f"{key}_child"] = table_name - - restriction_sql = "\nUNION ALL\n".join(selects) - - return [PermissionSQL(params=params, restriction_sql=restriction_sql)] - - -def restrictions_allow_action( - datasette: "Datasette", - restrictions: dict, - action: str, - resource: Optional[str | Tuple[str, str]], -) -> bool: - """ - Check if restrictions allow the requested action on the requested resource. - - This is a synchronous utility function for use by other code that needs - to quickly check restriction allowlists. - - Args: - datasette: The Datasette instance - restrictions: The _r dict from an actor - action: The action name to check - resource: None for global, str for database, (db, table) tuple for table - - Returns: - True if allowed, False if denied - """ - # Does this action have an abbreviation? - to_check = get_action_name_variants(datasette, action) - - # Check global level (any resource) - all_allowed = restrictions.get("a") - if all_allowed is not None: - assert isinstance(all_allowed, list) - if to_check.intersection(all_allowed): - return True - - # Check database level - if resource: - if isinstance(resource, str): - database_name = resource - else: - database_name = resource[0] - database_allowed = restrictions.get("d", {}).get(database_name) - if database_allowed is not None: - assert isinstance(database_allowed, list) - if to_check.intersection(database_allowed): - return True - - # Check table/resource level - if resource is not None and not isinstance(resource, str) and len(resource) == 2: - database, table = resource - table_allowed = restrictions.get("r", {}).get(database, {}).get(table) - if table_allowed is not None: - assert isinstance(table_allowed, list) - if to_check.intersection(table_allowed): - return True - - # This action is not explicitly allowed, so reject it - return False diff --git a/datasette/default_permissions/root.py b/datasette/default_permissions/root.py deleted file mode 100644 index 4931f7ff..00000000 --- a/datasette/default_permissions/root.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -Root user permission handling for Datasette. - -Grants full permissions to the root user when --root flag is used. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from datasette.app import Datasette - -from datasette import hookimpl -from datasette.permissions import PermissionSQL - - -@hookimpl(specname="permission_resources_sql") -async def root_user_permissions_sql( - datasette: "Datasette", - actor: Optional[dict], -) -> Optional[PermissionSQL]: - """ - Grant root user full permissions when --root flag is used. - """ - if not datasette.root_enabled: - return None - if actor is not None and actor.get("id") == "root": - return PermissionSQL.allow(reason="root user") diff --git a/datasette/default_permissions/tokens.py b/datasette/default_permissions/tokens.py deleted file mode 100644 index 7a359dc6..00000000 --- a/datasette/default_permissions/tokens.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Token authentication for Datasette. - -Registers the default SignedTokenHandler and delegates token verification -to datasette.verify_token() so all registered handlers are tried. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from datasette.app import Datasette - -from datasette import hookimpl -from datasette.tokens import SignedTokenHandler - - -@hookimpl -def register_token_handler(datasette: "Datasette"): - """Register the default signed token handler.""" - return SignedTokenHandler() - - -@hookimpl(specname="actor_from_request") -async def actor_from_signed_api_token( - datasette: "Datasette", request -) -> Optional[dict]: - """ - Authenticate requests using API tokens by delegating to all registered - token handlers via datasette.verify_token(). - """ - authorization = request.headers.get("authorization") - if not authorization: - return None - if not authorization.startswith("Bearer "): - return None - - token = authorization[len("Bearer ") :] - return await datasette.verify_token(token) diff --git a/datasette/events.py b/datasette/events.py deleted file mode 100644 index e8786da9..00000000 --- a/datasette/events.py +++ /dev/null @@ -1,293 +0,0 @@ -from abc import ABC, abstractproperty -from dataclasses import asdict, dataclass, field -from datasette.hookspecs import hookimpl -from datetime import datetime, timezone - - -@dataclass -class Event(ABC): - @abstractproperty - def name(self): - pass - - created: datetime = field( - init=False, default_factory=lambda: datetime.now(timezone.utc) - ) - actor: dict | None - - def properties(self): - properties = asdict(self) - properties.pop("actor", None) - properties.pop("created", None) - return properties - - -@dataclass -class LoginEvent(Event): - """ - Event name: ``login`` - - A user (represented by ``event.actor``) has logged in. - """ - - name = "login" - - -@dataclass -class LogoutEvent(Event): - """ - Event name: ``logout`` - - A user (represented by ``event.actor``) has logged out. - """ - - name = "logout" - - -@dataclass -class CreateTokenEvent(Event): - """ - Event name: ``create-token`` - - A user created an API token. - - :ivar expires_after: Number of seconds after which this token will expire. - :type expires_after: int or None - :ivar restrict_all: Restricted permissions for this token. - :type restrict_all: list - :ivar restrict_database: Restricted database permissions for this token. - :type restrict_database: dict - :ivar restrict_resource: Restricted resource permissions for this token. - :type restrict_resource: dict - """ - - name = "create-token" - expires_after: int | None - restrict_all: list - restrict_database: dict - restrict_resource: dict - - -@dataclass -class CreateTableEvent(Event): - """ - Event name: ``create-table`` - - A new table has been created in the database. - - :ivar database: The name of the database where the table was created. - :type database: str - :ivar table: The name of the table that was created - :type table: str - :ivar schema: The SQL schema definition for the new table. - :type schema: str - """ - - name = "create-table" - database: str - table: str - schema: str - - -@dataclass -class DropTableEvent(Event): - """ - Event name: ``drop-table`` - - A table has been dropped from the database. - - :ivar database: The name of the database where the table was dropped. - :type database: str - :ivar table: The name of the table that was dropped - :type table: str - """ - - name = "drop-table" - database: str - table: str - - -@dataclass -class AlterTableEvent(Event): - """ - Event name: ``alter-table`` - - A table has been altered. - - :ivar database: The name of the database where the table was altered - :type database: str - :ivar table: The name of the table that was altered - :type table: str - :ivar before_schema: The table's SQL schema before the alteration - :type before_schema: str - :ivar after_schema: The table's SQL schema after the alteration - :type after_schema: str - """ - - name = "alter-table" - database: str - table: str - before_schema: str - after_schema: str - - -@dataclass -class InsertRowsEvent(Event): - """ - Event name: ``insert-rows`` - - Rows were inserted into a table. - - :ivar database: The name of the database where the rows were inserted. - :type database: str - :ivar table: The name of the table where the rows were inserted. - :type table: str - :ivar num_rows: The number of rows that were requested to be inserted. - :type num_rows: int - :ivar ignore: Was ignore set? - :type ignore: bool - :ivar replace: Was replace set? - :type replace: bool - """ - - name = "insert-rows" - database: str - table: str - num_rows: int - ignore: bool - replace: bool - - -@dataclass -class UpsertRowsEvent(Event): - """ - Event name: ``upsert-rows`` - - Rows were upserted into a table. - - :ivar database: The name of the database where the rows were inserted. - :type database: str - :ivar table: The name of the table where the rows were inserted. - :type table: str - :ivar num_rows: The number of rows that were requested to be inserted. - :type num_rows: int - """ - - name = "upsert-rows" - database: str - table: str - num_rows: int - - -@dataclass -class UpdateRowEvent(Event): - """ - Event name: ``update-row`` - - A row was updated in a table. - - :ivar database: The name of the database where the row was updated. - :type database: str - :ivar table: The name of the table where the row was updated. - :type table: str - :ivar pks: The primary key values of the updated row. - """ - - name = "update-row" - database: str - table: str - pks: list - - -@dataclass -class RenameTableEvent(Event): - """ - Event name: ``rename-table`` - - A table has been renamed. - - :ivar database: The name of the database containing the renamed table. - :type database: str - :ivar old_table: The previous name of the table. - :type old_table: str - :ivar new_table: The new name of the table. - :type new_table: str - """ - - name = "rename-table" - database: str - old_table: str - new_table: str - - -@dataclass -class DeleteRowEvent(Event): - """ - Event name: ``delete-row`` - - A row was deleted from a table. - - :ivar database: The name of the database where the row was deleted. - :type database: str - :ivar table: The name of the table where the row was deleted. - :type table: str - :ivar pks: The primary key values of the deleted row. - """ - - name = "delete-row" - database: str - table: str - pks: list - - -@hookimpl -def write_wrapper(datasette, database, request, transaction): - def wrapper(conn, track_event): - # Snapshot rootpage -> name before the write - before = { - row[1]: row[0] - for row in conn.execute( - "select name, rootpage from sqlite_master" - " where type='table' and rootpage != 0" - ).fetchall() - } - yield - # Snapshot rootpage -> name after the write - after = { - row[1]: row[0] - for row in conn.execute( - "select name, rootpage from sqlite_master" - " where type='table' and rootpage != 0" - ).fetchall() - } - # Detect renames: same rootpage, different name - for rootpage, old_name in before.items(): - new_name = after.get(rootpage) - if new_name and new_name != old_name: - track_event( - RenameTableEvent( - actor=request.actor if request else None, - database=database, - old_table=old_name, - new_table=new_name, - ) - ) - - return wrapper - - -@hookimpl -def register_events(): - return [ - LoginEvent, - LogoutEvent, - CreateTableEvent, - CreateTokenEvent, - AlterTableEvent, - RenameTableEvent, - DropTableEvent, - InsertRowsEvent, - UpsertRowsEvent, - UpdateRowEvent, - DeleteRowEvent, - ] diff --git a/datasette/facets.py b/datasette/facets.py deleted file mode 100644 index abe0605e..00000000 --- a/datasette/facets.py +++ /dev/null @@ -1,576 +0,0 @@ -import json -import urllib -from datasette import hookimpl -from datasette.database import QueryInterrupted -from datasette.utils import ( - escape_sqlite, - path_with_added_args, - path_with_removed_args, - detect_json1, - sqlite3, -) - - -def load_facet_configs(request, table_config): - # Given a request and the configuration for a table, return - # a dictionary of selected facets, their lists of configs and for each - # config whether it came from the request or the metadata. - # - # return {type: [ - # {"source": "metadata", "config": config1}, - # {"source": "request", "config": config2}]} - facet_configs = {} - table_config = table_config or {} - table_facet_configs = table_config.get("facets", []) - for facet_config in table_facet_configs: - if isinstance(facet_config, str): - type = "column" - facet_config = {"simple": facet_config} - else: - assert ( - len(facet_config.values()) == 1 - ), "Metadata config dicts should be {type: config}" - type, facet_config = list(facet_config.items())[0] - if isinstance(facet_config, str): - facet_config = {"simple": facet_config} - facet_configs.setdefault(type, []).append( - {"source": "metadata", "config": facet_config} - ) - qs_pairs = urllib.parse.parse_qs(request.query_string, keep_blank_values=True) - for key, values in qs_pairs.items(): - if key.startswith("_facet"): - # Figure out the facet type - if key == "_facet": - type = "column" - elif key.startswith("_facet_"): - type = key[len("_facet_") :] - for value in values: - # The value is the facet_config - either JSON or not - facet_config = ( - json.loads(value) if value.startswith("{") else {"simple": value} - ) - facet_configs.setdefault(type, []).append( - {"source": "request", "config": facet_config} - ) - return facet_configs - - -@hookimpl -def register_facet_classes(): - classes = [ColumnFacet, DateFacet] - if detect_json1(): - classes.append(ArrayFacet) - return classes - - -class Facet: - type = None - # How many rows to consider when suggesting facets: - suggest_consider = 1000 - - def __init__( - self, - ds, - request, - database, - sql=None, - table=None, - params=None, - table_config=None, - row_count=None, - ): - assert table or sql, "Must provide either table= or sql=" - self.ds = ds - self.request = request - self.database = database - # For foreign key expansion. Can be None for e.g. stored SQL queries: - self.table = table - self.sql = sql or f"select * from [{table}]" - self.params = params or [] - self.table_config = table_config - # row_count can be None, in which case we calculate it ourselves: - self.row_count = row_count - - def get_configs(self): - configs = load_facet_configs(self.request, self.table_config) - return configs.get(self.type) or [] - - def get_querystring_pairs(self): - # ?_foo=bar&_foo=2&empty= becomes: - # [('_foo', 'bar'), ('_foo', '2'), ('empty', '')] - return urllib.parse.parse_qsl(self.request.query_string, keep_blank_values=True) - - def get_facet_size(self): - facet_size = self.ds.setting("default_facet_size") - max_returned_rows = self.ds.setting("max_returned_rows") - table_facet_size = None - if self.table: - config_facet_size = ( - self.ds.config.get("databases", {}) - .get(self.database, {}) - .get("tables", {}) - .get(self.table, {}) - .get("facet_size") - ) - if config_facet_size: - table_facet_size = config_facet_size - custom_facet_size = self.request.args.get("_facet_size") - if custom_facet_size: - if custom_facet_size == "max": - facet_size = max_returned_rows - elif custom_facet_size.isdigit(): - facet_size = int(custom_facet_size) - else: - # Invalid value, ignore it - custom_facet_size = None - if table_facet_size and not custom_facet_size: - if table_facet_size == "max": - facet_size = max_returned_rows - else: - facet_size = table_facet_size - return min(facet_size, max_returned_rows) - - async def suggest(self): - return [] - - async def facet_results(self): - # returns ([results], [timed_out]) - # TODO: Include "hideable" with each one somehow, which indicates if it was - # defined in metadata (in which case you cannot turn it off) - raise NotImplementedError - - async def get_columns(self, sql, params=None): - # Detect column names using the "limit 0" trick - return ( - await self.ds.execute( - self.database, f"select * from ({sql}) limit 0", params or [] - ) - ).columns - - -class ColumnFacet(Facet): - type = "column" - - async def suggest(self): - row_count = await self.get_row_count() - columns = await self.get_columns(self.sql, self.params) - facet_size = self.get_facet_size() - suggested_facets = [] - already_enabled = [c["config"]["simple"] for c in self.get_configs()] - for column in columns: - if column in already_enabled: - continue - suggested_facet_sql = """ - with limited as (select * from ({sql}) limit {suggest_consider}) - select {column} as value, count(*) as n from limited - where value is not null - group by value - limit {limit} - """.format( - column=escape_sqlite(column), - sql=self.sql, - limit=facet_size + 1, - suggest_consider=self.suggest_consider, - ) - distinct_values = None - try: - distinct_values = await self.ds.execute( - self.database, - suggested_facet_sql, - self.params, - truncate=False, - custom_time_limit=self.ds.setting("facet_suggest_time_limit_ms"), - ) - num_distinct_values = len(distinct_values) - if ( - 1 < num_distinct_values < row_count - and num_distinct_values <= facet_size - # And at least one has n > 1 - and any(r["n"] > 1 for r in distinct_values) - ): - suggested_facets.append( - { - "name": column, - "toggle_url": self.ds.absolute_url( - self.request, - self.ds.urls.path( - path_with_added_args( - self.request, {"_facet": column} - ) - ), - ), - } - ) - except QueryInterrupted: - continue - return suggested_facets - - async def get_row_count(self): - if self.row_count is None: - self.row_count = ( - await self.ds.execute( - self.database, - f"select count(*) from (select * from ({self.sql}) limit {self.suggest_consider})", - self.params, - ) - ).rows[0][0] - return self.row_count - - async def facet_results(self): - facet_results = [] - facets_timed_out = [] - - qs_pairs = self.get_querystring_pairs() - - facet_size = self.get_facet_size() - for source_and_config in self.get_configs(): - config = source_and_config["config"] - source = source_and_config["source"] - column = config.get("column") or config["simple"] - facet_sql = """ - select {col} as value, count(*) as count from ( - {sql} - ) - where {col} is not null - group by {col} order by count desc, value limit {limit} - """.format(col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1) - try: - facet_rows_results = await self.ds.execute( - self.database, - facet_sql, - self.params, - truncate=False, - custom_time_limit=self.ds.setting("facet_time_limit_ms"), - ) - facet_results_values = [] - facet_results.append( - { - "name": column, - "type": self.type, - "hideable": source != "metadata", - "toggle_url": self.ds.urls.path( - path_with_removed_args(self.request, {"_facet": column}) - ), - "results": facet_results_values, - "truncated": len(facet_rows_results) > facet_size, - } - ) - facet_rows = facet_rows_results.rows[:facet_size] - if self.table: - # Attempt to expand foreign keys into labels - values = [row["value"] for row in facet_rows] - expanded = await self.ds.expand_foreign_keys( - self.request.actor, self.database, self.table, column, values - ) - else: - expanded = {} - for row in facet_rows: - column_qs = column - if column.startswith("_"): - column_qs = "{}__exact".format(column) - selected = (column_qs, str(row["value"])) in qs_pairs - if selected: - toggle_path = path_with_removed_args( - self.request, {column_qs: str(row["value"])} - ) - else: - toggle_path = path_with_added_args( - self.request, {column_qs: row["value"]} - ) - facet_results_values.append( - { - "value": row["value"], - "label": expanded.get((column, row["value"]), row["value"]), - "count": row["count"], - "toggle_url": self.ds.absolute_url( - self.request, self.ds.urls.path(toggle_path) - ), - "selected": selected, - } - ) - except QueryInterrupted: - facets_timed_out.append(column) - - return facet_results, facets_timed_out - - -class ArrayFacet(Facet): - type = "array" - - def _is_json_array_of_strings(self, json_string): - try: - array = json.loads(json_string) - except ValueError: - return False - for item in array: - if not isinstance(item, str): - return False - return True - - async def suggest(self): - columns = await self.get_columns(self.sql, self.params) - suggested_facets = [] - already_enabled = [c["config"]["simple"] for c in self.get_configs()] - for column in columns: - if column in already_enabled: - continue - # Is every value in this column either null or a JSON array? - suggested_facet_sql = """ - with limited as (select * from ({sql}) limit {suggest_consider}) - select distinct json_type({column}) - from limited - where {column} is not null and {column} != '' - """.format( - column=escape_sqlite(column), - sql=self.sql, - suggest_consider=self.suggest_consider, - ) - try: - results = await self.ds.execute( - self.database, - suggested_facet_sql, - self.params, - truncate=False, - custom_time_limit=self.ds.setting("facet_suggest_time_limit_ms"), - log_sql_errors=False, - ) - types = tuple(r[0] for r in results.rows) - if types in (("array",), ("array", None)): - # Now check that first 100 arrays contain only strings - first_100 = [ - v[0] - for v in await self.ds.execute( - self.database, - ( - "select {column} from ({sql}) " - "where {column} is not null " - "and {column} != '' " - "and json_array_length({column}) > 0 " - "limit 100" - ).format(column=escape_sqlite(column), sql=self.sql), - self.params, - truncate=False, - custom_time_limit=self.ds.setting( - "facet_suggest_time_limit_ms" - ), - log_sql_errors=False, - ) - ] - if first_100 and all( - self._is_json_array_of_strings(r) for r in first_100 - ): - suggested_facets.append( - { - "name": column, - "type": "array", - "toggle_url": self.ds.absolute_url( - self.request, - self.ds.urls.path( - path_with_added_args( - self.request, {"_facet_array": column} - ) - ), - ), - } - ) - except (QueryInterrupted, sqlite3.OperationalError): - continue - return suggested_facets - - async def facet_results(self): - # self.configs should be a plain list of columns - facet_results = [] - facets_timed_out = [] - - facet_size = self.get_facet_size() - for source_and_config in self.get_configs(): - config = source_and_config["config"] - source = source_and_config["source"] - column = config.get("column") or config["simple"] - # https://github.com/simonw/datasette/issues/448 - facet_sql = """ - with inner as ({sql}), - deduped_array_items as ( - select - distinct j.value, - inner.* - from - json_each([inner].{col}) j - join inner - ) - select - value as value, - count(*) as count - from - deduped_array_items - group by - value - order by - count(*) desc, value limit {limit} - """.format( - col=escape_sqlite(column), - sql=self.sql, - limit=facet_size + 1, - ) - try: - facet_rows_results = await self.ds.execute( - self.database, - facet_sql, - self.params, - truncate=False, - custom_time_limit=self.ds.setting("facet_time_limit_ms"), - ) - facet_results_values = [] - facet_results.append( - { - "name": column, - "type": self.type, - "results": facet_results_values, - "hideable": source != "metadata", - "toggle_url": self.ds.urls.path( - path_with_removed_args( - self.request, {"_facet_array": column} - ) - ), - "truncated": len(facet_rows_results) > facet_size, - } - ) - facet_rows = facet_rows_results.rows[:facet_size] - pairs = self.get_querystring_pairs() - for row in facet_rows: - value = str(row["value"]) - selected = (f"{column}__arraycontains", value) in pairs - if selected: - toggle_path = path_with_removed_args( - self.request, {f"{column}__arraycontains": value} - ) - else: - toggle_path = path_with_added_args( - self.request, {f"{column}__arraycontains": value} - ) - facet_results_values.append( - { - "value": value, - "label": value, - "count": row["count"], - "toggle_url": self.ds.absolute_url( - self.request, toggle_path - ), - "selected": selected, - } - ) - except QueryInterrupted: - facets_timed_out.append(column) - - return facet_results, facets_timed_out - - -class DateFacet(Facet): - type = "date" - - async def suggest(self): - columns = await self.get_columns(self.sql, self.params) - already_enabled = [c["config"]["simple"] for c in self.get_configs()] - suggested_facets = [] - for column in columns: - if column in already_enabled: - continue - # Does this column contain any dates in the first 100 rows? - suggested_facet_sql = """ - select date({column}) from ( - select * from ({sql}) limit 100 - ) where {column} glob "????-??-*" - """.format(column=escape_sqlite(column), sql=self.sql) - try: - results = await self.ds.execute( - self.database, - suggested_facet_sql, - self.params, - truncate=False, - custom_time_limit=self.ds.setting("facet_suggest_time_limit_ms"), - log_sql_errors=False, - ) - values = tuple(r[0] for r in results.rows) - if any(values): - suggested_facets.append( - { - "name": column, - "type": "date", - "toggle_url": self.ds.absolute_url( - self.request, - self.ds.urls.path( - path_with_added_args( - self.request, {"_facet_date": column} - ) - ), - ), - } - ) - except (QueryInterrupted, sqlite3.OperationalError): - continue - return suggested_facets - - async def facet_results(self): - facet_results = [] - facets_timed_out = [] - args = dict(self.get_querystring_pairs()) - facet_size = self.get_facet_size() - for source_and_config in self.get_configs(): - config = source_and_config["config"] - source = source_and_config["source"] - column = config.get("column") or config["simple"] - # TODO: does this query break if inner sql produces value or count columns? - facet_sql = """ - select date({col}) as value, count(*) as count from ( - {sql} - ) - where date({col}) is not null - group by date({col}) order by count desc, value limit {limit} - """.format(col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1) - try: - facet_rows_results = await self.ds.execute( - self.database, - facet_sql, - self.params, - truncate=False, - custom_time_limit=self.ds.setting("facet_time_limit_ms"), - ) - facet_results_values = [] - facet_results.append( - { - "name": column, - "type": self.type, - "results": facet_results_values, - "hideable": source != "metadata", - "toggle_url": path_with_removed_args( - self.request, {"_facet_date": column} - ), - "truncated": len(facet_rows_results) > facet_size, - } - ) - facet_rows = facet_rows_results.rows[:facet_size] - for row in facet_rows: - selected = str(args.get(f"{column}__date")) == str(row["value"]) - if selected: - toggle_path = path_with_removed_args( - self.request, {f"{column}__date": str(row["value"])} - ) - else: - toggle_path = path_with_added_args( - self.request, {f"{column}__date": row["value"]} - ) - facet_results_values.append( - { - "value": row["value"], - "label": row["value"], - "count": row["count"], - "toggle_url": self.ds.absolute_url( - self.request, toggle_path - ), - "selected": selected, - } - ) - except QueryInterrupted: - facets_timed_out.append(column) - - return facet_results, facets_timed_out diff --git a/datasette/filters.py b/datasette/filters.py deleted file mode 100644 index 95cc5f37..00000000 --- a/datasette/filters.py +++ /dev/null @@ -1,427 +0,0 @@ -from datasette import hookimpl -from datasette.resources import DatabaseResource -from datasette.views.base import DatasetteError -from datasette.utils.asgi import BadRequest -import json -from .utils import detect_json1, escape_sqlite, path_with_removed_args - - -@hookimpl(specname="filters_from_request") -def where_filters(request, database, datasette): - # This one deals with ?_where= - async def inner(): - where_clauses = [] - extra_wheres_for_ui = [] - if "_where" in request.args: - if not await datasette.allowed( - action="execute-sql", - resource=DatabaseResource(database=database), - actor=request.actor, - ): - raise DatasetteError("_where= is not allowed", status=403) - else: - where_clauses.extend(request.args.getlist("_where")) - extra_wheres_for_ui = [ - { - "text": text, - "remove_url": path_with_removed_args(request, {"_where": text}), - } - for text in request.args.getlist("_where") - ] - - return FilterArguments( - where_clauses, - extra_context={ - "extra_wheres_for_ui": extra_wheres_for_ui, - }, - ) - - return inner - - -@hookimpl(specname="filters_from_request") -def search_filters(request, database, table, datasette): - # ?_search= and _search_colname= - async def inner(): - where_clauses = [] - params = {} - human_descriptions = [] - extra_context = {} - - # Figure out which fts_table to use - table_metadata = await datasette.table_config(database, table) - db = datasette.get_database(database) - fts_table = request.args.get("_fts_table") - fts_table = fts_table or table_metadata.get("fts_table") - fts_table = fts_table or await db.fts_table(table) - fts_pk = request.args.get("_fts_pk", table_metadata.get("fts_pk", "rowid")) - search_args = { - key: request.args[key] - for key in request.args - if key.startswith("_search") and key != "_searchmode" - } - search = "" - search_mode_raw = table_metadata.get("searchmode") == "raw" - # Or set search mode from the querystring - qs_searchmode = request.args.get("_searchmode") - if qs_searchmode == "escaped": - search_mode_raw = False - if qs_searchmode == "raw": - search_mode_raw = True - - extra_context["supports_search"] = bool(fts_table) - - if fts_table and search_args: - if "_search" in search_args: - # Simple ?_search=xxx - search = search_args["_search"] - where_clauses.append( - "{fts_pk} in (select rowid from {fts_table} where {fts_table} match {match_clause})".format( - fts_table=escape_sqlite(fts_table), - fts_pk=escape_sqlite(fts_pk), - match_clause=( - ":search" if search_mode_raw else "escape_fts(:search)" - ), - ) - ) - human_descriptions.append(f'search matches "{search}"') - params["search"] = search - extra_context["search"] = search - else: - # More complex: search against specific columns - for i, (key, search_text) in enumerate(search_args.items()): - search_col = key.split("_search_", 1)[1] - if search_col not in await db.table_columns(fts_table): - raise BadRequest("Cannot search by that column") - - where_clauses.append( - "rowid in (select rowid from {fts_table} where {search_col} match {match_clause})".format( - fts_table=escape_sqlite(fts_table), - search_col=escape_sqlite(search_col), - match_clause=( - ":search_{}".format(i) - if search_mode_raw - else "escape_fts(:search_{})".format(i) - ), - ) - ) - human_descriptions.append( - f'search column "{search_col}" matches "{search_text}"' - ) - params[f"search_{i}"] = search_text - extra_context["search"] = search_text - - return FilterArguments(where_clauses, params, human_descriptions, extra_context) - - return inner - - -@hookimpl(specname="filters_from_request") -def through_filters(request, database, table, datasette): - # ?_search= and _search_colname= - async def inner(): - where_clauses = [] - params = {} - human_descriptions = [] - extra_context = {} - - # Support for ?_through={table, column, value} - if "_through" in request.args: - for through in request.args.getlist("_through"): - through_data = json.loads(through) - through_table = through_data["table"] - other_column = through_data["column"] - value = through_data["value"] - db = datasette.get_database(database) - outgoing_foreign_keys = await db.foreign_keys_for_table(through_table) - try: - fk_to_us = [ - fk for fk in outgoing_foreign_keys if fk["other_table"] == table - ][0] - except IndexError: - raise DatasetteError( - "Invalid _through - could not find corresponding foreign key" - ) - param = f"p{len(params)}" - where_clauses.append( - "{our_pk} in (select {our_column} from {through_table} where {other_column} = :{param})".format( - through_table=escape_sqlite(through_table), - our_pk=escape_sqlite(fk_to_us["other_column"]), - our_column=escape_sqlite(fk_to_us["column"]), - other_column=escape_sqlite(other_column), - param=param, - ) - ) - params[param] = value - human_descriptions.append(f'{through_table}.{other_column} = "{value}"') - - return FilterArguments(where_clauses, params, human_descriptions, extra_context) - - return inner - - -class FilterArguments: - def __init__( - self, where_clauses, params=None, human_descriptions=None, extra_context=None - ): - self.where_clauses = where_clauses - self.params = params or {} - self.human_descriptions = human_descriptions or [] - self.extra_context = extra_context or {} - - -class Filter: - key = None - display = None - no_argument = False - - def where_clause(self, table, column, value, param_counter): - raise NotImplementedError - - def human_clause(self, column, value): - raise NotImplementedError - - -class TemplatedFilter(Filter): - def __init__( - self, - key, - display, - sql_template, - human_template, - format="{}", - numeric=False, - no_argument=False, - ): - self.key = key - self.display = display - self.sql_template = sql_template - self.human_template = human_template - self.format = format - self.numeric = numeric - self.no_argument = no_argument - - def where_clause(self, table, column, value, param_counter): - converted = self.format.format(value) - if self.numeric and converted.isdigit(): - converted = int(converted) - if self.no_argument: - kwargs = {"c": column} - converted = None - else: - kwargs = {"c": column, "p": f"p{param_counter}", "t": table} - return self.sql_template.format(**kwargs), converted - - def human_clause(self, column, value): - if callable(self.human_template): - template = self.human_template(column, value) - else: - template = self.human_template - if self.no_argument: - return template.format(c=column) - else: - return template.format(c=column, v=value) - - -class InFilter(Filter): - key = "in" - display = "in" - - def split_value(self, value): - if value.startswith("["): - return json.loads(value) - else: - return [v.strip() for v in value.split(",")] - - def where_clause(self, table, column, value, param_counter): - values = self.split_value(value) - params = [f":p{param_counter + i}" for i in range(len(values))] - sql = f"{escape_sqlite(column)} in ({', '.join(params)})" - return sql, values - - def human_clause(self, column, value): - return f"{column} in {json.dumps(self.split_value(value))}" - - -class NotInFilter(InFilter): - key = "notin" - display = "not in" - - def where_clause(self, table, column, value, param_counter): - values = self.split_value(value) - params = [f":p{param_counter + i}" for i in range(len(values))] - sql = f"{escape_sqlite(column)} not in ({', '.join(params)})" - return sql, values - - def human_clause(self, column, value): - return f"{column} not in {json.dumps(self.split_value(value))}" - - -class Filters: - _filters = ( - [ - # key, display, sql_template, human_template, format=, numeric=, no_argument= - TemplatedFilter( - "exact", - "=", - '"{c}" = :{p}', - lambda c, v: "{c} = {v}" if v.isdigit() else '{c} = "{v}"', - ), - TemplatedFilter( - "not", - "!=", - '"{c}" != :{p}', - lambda c, v: "{c} != {v}" if v.isdigit() else '{c} != "{v}"', - ), - TemplatedFilter( - "contains", - "contains", - '"{c}" like :{p}', - '{c} contains "{v}"', - format="%{}%", - ), - TemplatedFilter( - "notcontains", - "does not contain", - '"{c}" not like :{p}', - '{c} does not contain "{v}"', - format="%{}%", - ), - TemplatedFilter( - "endswith", - "ends with", - '"{c}" like :{p}', - '{c} ends with "{v}"', - format="%{}", - ), - TemplatedFilter( - "startswith", - "starts with", - '"{c}" like :{p}', - '{c} starts with "{v}"', - format="{}%", - ), - TemplatedFilter("gt", ">", '"{c}" > :{p}', "{c} > {v}", numeric=True), - TemplatedFilter( - "gte", "\u2265", '"{c}" >= :{p}', "{c} \u2265 {v}", numeric=True - ), - TemplatedFilter("lt", "<", '"{c}" < :{p}', "{c} < {v}", numeric=True), - TemplatedFilter( - "lte", "\u2264", '"{c}" <= :{p}', "{c} \u2264 {v}", numeric=True - ), - TemplatedFilter("like", "like", '"{c}" like :{p}', '{c} like "{v}"'), - TemplatedFilter( - "notlike", "not like", '"{c}" not like :{p}', '{c} not like "{v}"' - ), - TemplatedFilter("glob", "glob", '"{c}" glob :{p}', '{c} glob "{v}"'), - InFilter(), - NotInFilter(), - ] - + ( - [ - TemplatedFilter( - "arraycontains", - "array contains", - """:{p} in (select value from json_each([{t}].[{c}]))""", - '{c} contains "{v}"', - ), - TemplatedFilter( - "arraynotcontains", - "array does not contain", - """:{p} not in (select value from json_each([{t}].[{c}]))""", - '{c} does not contain "{v}"', - ), - ] - if detect_json1() - else [] - ) - + [ - TemplatedFilter( - "date", "date", 'date("{c}") = :{p}', '"{c}" is on date {v}' - ), - TemplatedFilter( - "isnull", "is null", '"{c}" is null', "{c} is null", no_argument=True - ), - TemplatedFilter( - "notnull", - "is not null", - '"{c}" is not null', - "{c} is not null", - no_argument=True, - ), - TemplatedFilter( - "isblank", - "is blank", - '("{c}" is null or "{c}" = "")', - "{c} is blank", - no_argument=True, - ), - TemplatedFilter( - "notblank", - "is not blank", - '("{c}" is not null and "{c}" != "")', - "{c} is not blank", - no_argument=True, - ), - ] - ) - _filters_by_key = {f.key: f for f in _filters} - - def __init__(self, pairs): - self.pairs = pairs - - def lookups(self): - """Yields (lookup, display, no_argument) pairs""" - for filter in self._filters: - yield filter.key, filter.display, filter.no_argument - - def human_description_en(self, extra=None): - bits = [] - if extra: - bits.extend(extra) - for column, lookup, value in self.selections(): - filter = self._filters_by_key.get(lookup, None) - if filter: - bits.append(filter.human_clause(column, value)) - # Comma separated, with an ' and ' at the end - and_bits = [] - commas, tail = bits[:-1], bits[-1:] - if commas: - and_bits.append(", ".join(commas)) - if tail: - and_bits.append(tail[0]) - s = " and ".join(and_bits) - if not s: - return "" - return f"where {s}" - - def selections(self): - """Yields (column, lookup, value) tuples""" - for key, value in self.pairs: - if "__" in key: - column, lookup = key.rsplit("__", 1) - else: - column = key - lookup = "exact" - yield column, lookup, value - - def has_selections(self): - return bool(self.pairs) - - def build_where_clauses(self, table): - sql_bits = [] - params = {} - i = 0 - for column, lookup, value in self.selections(): - filter = self._filters_by_key.get(lookup, None) - if filter: - sql_bit, param = filter.where_clause(table, column, value, i) - sql_bits.append(sql_bit) - if param is not None: - if not isinstance(param, list): - param = [param] - for individual_param in param: - param_id = f"p{i}" - params[param_id] = individual_param - i += 1 - return sql_bits, params diff --git a/datasette/fixtures.py b/datasette/fixtures.py deleted file mode 100644 index 7c85e16a..00000000 --- a/datasette/fixtures.py +++ /dev/null @@ -1,415 +0,0 @@ -from datasette.utils.sqlite import sqlite3 -from datasette.utils import documented -import itertools -import random -import string - -__all__ = [ - "EXTRA_DATABASE_SQL", - "TABLES", - "TABLE_PARAMETERIZED_SQL", - "generate_compound_rows", - "generate_sortable_rows", - "populate_extra_database", - "populate_fixture_database", - "write_extra_database", - "write_fixture_database", -] - - -def generate_compound_rows(num): - """Generate rows for the compound_three_primary_keys fixture table.""" - for a, b, c in itertools.islice( - itertools.product(string.ascii_lowercase, repeat=3), num - ): - yield a, b, c, f"{a}-{b}-{c}" - - -def generate_sortable_rows(num): - """Generate rows for the sortable fixture table.""" - rand = random.Random(42) - for a, b in itertools.islice( - itertools.product(string.ascii_lowercase, repeat=2), num - ): - yield { - "pk1": a, - "pk2": b, - "content": f"{a}-{b}", - "sortable": rand.randint(-100, 100), - "sortable_with_nulls": rand.choice([None, rand.random(), rand.random()]), - "sortable_with_nulls_2": rand.choice([None, rand.random(), rand.random()]), - "text": rand.choice(["$null", "$blah"]), - } - - -TABLES = ( - """ -CREATE TABLE simple_primary_key ( - id integer primary key, - content text -); - -CREATE TABLE primary_key_multiple_columns ( - id varchar(30) primary key, - content text, - content2 text -); - -CREATE TABLE primary_key_multiple_columns_explicit_label ( - id varchar(30) primary key, - content text, - content2 text -); - -CREATE TABLE compound_primary_key ( - pk1 varchar(30), - pk2 varchar(30), - content text, - PRIMARY KEY (pk1, pk2) -); - -INSERT INTO compound_primary_key VALUES ('a', 'b', 'c'); -INSERT INTO compound_primary_key VALUES ('a/b', '.c-d', 'c'); -INSERT INTO compound_primary_key VALUES ('d', 'e', 'RENDER_CELL_DEMO'); - -CREATE TABLE compound_three_primary_keys ( - pk1 varchar(30), - pk2 varchar(30), - pk3 varchar(30), - content text, - PRIMARY KEY (pk1, pk2, pk3) -); -CREATE INDEX idx_compound_three_primary_keys_content ON compound_three_primary_keys(content); - -CREATE TABLE foreign_key_references ( - pk varchar(30) primary key, - foreign_key_with_label integer, - foreign_key_with_blank_label integer, - foreign_key_with_no_label varchar(30), - foreign_key_compound_pk1 varchar(30), - foreign_key_compound_pk2 varchar(30), - FOREIGN KEY (foreign_key_with_label) REFERENCES simple_primary_key(id), - FOREIGN KEY (foreign_key_with_blank_label) REFERENCES simple_primary_key(id), - FOREIGN KEY (foreign_key_with_no_label) REFERENCES primary_key_multiple_columns(id) - FOREIGN KEY (foreign_key_compound_pk1, foreign_key_compound_pk2) REFERENCES compound_primary_key(pk1, pk2) -); - -CREATE TABLE sortable ( - pk1 varchar(30), - pk2 varchar(30), - content text, - sortable integer, - sortable_with_nulls real, - sortable_with_nulls_2 real, - text text, - PRIMARY KEY (pk1, pk2) -); - -CREATE TABLE no_primary_key ( - content text, - a text, - b text, - c text -); - -CREATE TABLE [123_starts_with_digits] ( - content text -); - -CREATE VIEW paginated_view AS - SELECT - content, - '- ' || content || ' -' AS content_extra - FROM no_primary_key; - -CREATE TABLE "Table With Space In Name" ( - pk varchar(30) primary key, - content text -); - -CREATE TABLE "table/with/slashes.csv" ( - pk varchar(30) primary key, - content text -); - -CREATE TABLE "complex_foreign_keys" ( - pk varchar(30) primary key, - f1 integer, - f2 integer, - f3 integer, - FOREIGN KEY ("f1") REFERENCES [simple_primary_key](id), - FOREIGN KEY ("f2") REFERENCES [simple_primary_key](id), - FOREIGN KEY ("f3") REFERENCES [simple_primary_key](id) -); - -CREATE TABLE "custom_foreign_key_label" ( - pk varchar(30) primary key, - foreign_key_with_custom_label text, - FOREIGN KEY ("foreign_key_with_custom_label") REFERENCES [primary_key_multiple_columns_explicit_label](id) -); - -CREATE TABLE tags ( - tag TEXT PRIMARY KEY -); - -CREATE TABLE searchable ( - pk integer primary key, - text1 text, - text2 text, - [name with . and spaces] text -); - -CREATE TABLE searchable_tags ( - searchable_id integer, - tag text, - PRIMARY KEY (searchable_id, tag), - FOREIGN KEY (searchable_id) REFERENCES searchable(pk), - FOREIGN KEY (tag) REFERENCES tags(tag) -); - -INSERT INTO searchable VALUES (1, 'barry cat', 'terry dog', 'panther'); -INSERT INTO searchable VALUES (2, 'terry dog', 'sara weasel', 'puma'); - -INSERT INTO tags VALUES ("canine"); -INSERT INTO tags VALUES ("feline"); - -INSERT INTO searchable_tags (searchable_id, tag) VALUES - (1, "feline"), - (2, "canine") -; - -CREATE VIRTUAL TABLE "searchable_fts" - USING FTS5 (text1, text2, [name with . and spaces], content="searchable", content_rowid="pk"); -INSERT INTO "searchable_fts" (searchable_fts) VALUES ('rebuild'); - -CREATE TABLE [select] ( - [group] text, - [having] text, - [and] text, - [json] text -); -INSERT INTO [select] VALUES ('group', 'having', 'and', - '{"href": "http://example.com/", "label":"Example"}' -); - -CREATE TABLE infinity ( - value REAL -); -INSERT INTO infinity VALUES - (1e999), - (-1e999), - (1.5) -; - -CREATE TABLE facet_cities ( - id integer primary key, - name text -); -INSERT INTO facet_cities (id, name) VALUES - (1, 'San Francisco'), - (2, 'Los Angeles'), - (3, 'Detroit'), - (4, 'Memnonia') -; - -CREATE TABLE facetable ( - pk integer primary key, - created text, - planet_int integer, - on_earth integer, - state text, - _city_id integer, - _neighborhood text, - tags text, - complex_array text, - distinct_some_null, - n text, - FOREIGN KEY ("_city_id") REFERENCES [facet_cities](id) -); -INSERT INTO facetable - (created, planet_int, on_earth, state, _city_id, _neighborhood, tags, complex_array, distinct_some_null, n) -VALUES - ("2019-01-14 08:00:00", 1, 1, 'CA', 1, 'Mission', '["tag1", "tag2"]', '[{"foo": "bar"}]', 'one', 'n1'), - ("2019-01-14 08:00:00", 1, 1, 'CA', 1, 'Dogpatch', '["tag1", "tag3"]', '[]', 'two', 'n2'), - ("2019-01-14 08:00:00", 1, 1, 'CA', 1, 'SOMA', '[]', '[]', null, null), - ("2019-01-14 08:00:00", 1, 1, 'CA', 1, 'Tenderloin', '[]', '[]', null, null), - ("2019-01-15 08:00:00", 1, 1, 'CA', 1, 'Bernal Heights', '[]', '[]', null, null), - ("2019-01-15 08:00:00", 1, 1, 'CA', 1, 'Hayes Valley', '[]', '[]', null, null), - ("2019-01-15 08:00:00", 1, 1, 'CA', 2, 'Hollywood', '[]', '[]', null, null), - ("2019-01-15 08:00:00", 1, 1, 'CA', 2, 'Downtown', '[]', '[]', null, null), - ("2019-01-16 08:00:00", 1, 1, 'CA', 2, 'Los Feliz', '[]', '[]', null, null), - ("2019-01-16 08:00:00", 1, 1, 'CA', 2, 'Koreatown', '[]', '[]', null, null), - ("2019-01-16 08:00:00", 1, 1, 'MI', 3, 'Downtown', '[]', '[]', null, null), - ("2019-01-17 08:00:00", 1, 1, 'MI', 3, 'Greektown', '[]', '[]', null, null), - ("2019-01-17 08:00:00", 1, 1, 'MI', 3, 'Corktown', '[]', '[]', null, null), - ("2019-01-17 08:00:00", 1, 1, 'MI', 3, 'Mexicantown', '[]', '[]', null, null), - ("2019-01-17 08:00:00", 2, 0, 'MC', 4, 'Arcadia Planitia', '[]', '[]', null, null) -; - -CREATE TABLE binary_data ( - data BLOB -); - --- Many 2 Many demo: roadside attractions! - -CREATE TABLE roadside_attractions ( - pk integer primary key, - name text, - address text, - url text, - latitude real, - longitude real -); -INSERT INTO roadside_attractions VALUES ( - 1, "The Mystery Spot", "465 Mystery Spot Road, Santa Cruz, CA 95065", "https://www.mysteryspot.com/", - 37.0167, -122.0024 -); -INSERT INTO roadside_attractions VALUES ( - 2, "Winchester Mystery House", "525 South Winchester Boulevard, San Jose, CA 95128", "https://winchestermysteryhouse.com/", - 37.3184, -121.9511 -); -INSERT INTO roadside_attractions VALUES ( - 3, "Burlingame Museum of PEZ Memorabilia", "214 California Drive, Burlingame, CA 94010", null, - 37.5793, -122.3442 -); -INSERT INTO roadside_attractions VALUES ( - 4, "Bigfoot Discovery Museum", "5497 Highway 9, Felton, CA 95018", "https://www.bigfootdiscoveryproject.com/", - 37.0414, -122.0725 -); - -CREATE TABLE attraction_characteristic ( - pk integer primary key, - name text -); -INSERT INTO attraction_characteristic VALUES ( - 1, "Museum" -); -INSERT INTO attraction_characteristic VALUES ( - 2, "Paranormal" -); - -CREATE TABLE roadside_attraction_characteristics ( - attraction_id INTEGER REFERENCES roadside_attractions(pk), - characteristic_id INTEGER REFERENCES attraction_characteristic(pk) -); -INSERT INTO roadside_attraction_characteristics VALUES ( - 1, 2 -); -INSERT INTO roadside_attraction_characteristics VALUES ( - 2, 2 -); -INSERT INTO roadside_attraction_characteristics VALUES ( - 4, 2 -); -INSERT INTO roadside_attraction_characteristics VALUES ( - 3, 1 -); -INSERT INTO roadside_attraction_characteristics VALUES ( - 4, 1 -); - -INSERT INTO simple_primary_key VALUES (1, 'hello'); -INSERT INTO simple_primary_key VALUES (2, 'world'); -INSERT INTO simple_primary_key VALUES (3, ''); -INSERT INTO simple_primary_key VALUES (4, 'RENDER_CELL_DEMO'); -INSERT INTO simple_primary_key VALUES (5, 'RENDER_CELL_ASYNC'); - -INSERT INTO primary_key_multiple_columns VALUES (1, 'hey', 'world'); -INSERT INTO primary_key_multiple_columns_explicit_label VALUES (1, 'hey', 'world2'); - -INSERT INTO foreign_key_references VALUES (1, 1, 3, 1, 'a', 'b'); -INSERT INTO foreign_key_references VALUES (2, null, null, null, null, null); - -INSERT INTO complex_foreign_keys VALUES (1, 1, 2, 1); -INSERT INTO custom_foreign_key_label VALUES (1, 1); - -INSERT INTO [table/with/slashes.csv] VALUES (3, 'hey'); - -CREATE VIEW simple_view AS - SELECT content, upper(content) AS upper_content FROM simple_primary_key; - -CREATE VIEW searchable_view AS - SELECT * from searchable; - -CREATE VIEW searchable_view_configured_by_metadata AS - SELECT * from searchable; - -""" - + "\n".join( - [ - 'INSERT INTO no_primary_key VALUES ({i}, "a{i}", "b{i}", "c{i}");'.format( - i=i + 1 - ) - for i in range(201) - ] - ) - + '\nINSERT INTO no_primary_key VALUES ("RENDER_CELL_DEMO", "a202", "b202", "c202");\n' - + "\n".join( - [ - 'INSERT INTO compound_three_primary_keys VALUES ("{a}", "{b}", "{c}", "{content}");'.format( - a=a, b=b, c=c, content=content - ) - for a, b, c, content in generate_compound_rows(1001) - ] - ) - + "\n".join(["""INSERT INTO sortable VALUES ( - "{pk1}", "{pk2}", "{content}", {sortable}, - {sortable_with_nulls}, {sortable_with_nulls_2}, "{text}"); - """.format(**row).replace("None", "null") for row in generate_sortable_rows(201)]) -) - -TABLE_PARAMETERIZED_SQL = [ - ("insert into binary_data (data) values (?);", [b"\x15\x1c\x02\xc7\xad\x05\xfe"]), - ("insert into binary_data (data) values (?);", [b"\x15\x1c\x03\xc7\xad\x05\xfe"]), - ("insert into binary_data (data) values (null);", []), -] - -EXTRA_DATABASE_SQL = """ -CREATE TABLE searchable ( - pk integer primary key, - text1 text, - text2 text -); - -CREATE VIEW searchable_view AS SELECT * FROM searchable; - -INSERT INTO searchable VALUES (1, 'barry cat', 'terry dog'); -INSERT INTO searchable VALUES (2, 'terry dog', 'sara weasel'); - -CREATE VIRTUAL TABLE "searchable_fts" - USING FTS3 (text1, text2, content="searchable"); -INSERT INTO "searchable_fts" (rowid, text1, text2) - SELECT rowid, text1, text2 FROM searchable; -""" - - -@documented(label="datasette_fixtures_populate_fixture_database") -def populate_fixture_database(conn): - """Populate a SQLite connection with Datasette's test fixture tables.""" - conn.executescript(TABLES) - for sql, params in TABLE_PARAMETERIZED_SQL: - with conn: - conn.execute(sql, params) - - -def populate_extra_database(conn): - """Populate a SQLite connection with the extra database used in tests.""" - conn.executescript(EXTRA_DATABASE_SQL) - - -def write_fixture_database(db_filename): - """Write Datasette's test fixture tables to a SQLite database file.""" - conn = sqlite3.connect(db_filename) - try: - populate_fixture_database(conn) - finally: - conn.close() - - -def write_extra_database(db_filename): - """Write the extra test database tables to a SQLite database file.""" - conn = sqlite3.connect(db_filename) - try: - populate_extra_database(conn) - finally: - conn.close() diff --git a/datasette/forbidden.py b/datasette/forbidden.py deleted file mode 100644 index 41c48396..00000000 --- a/datasette/forbidden.py +++ /dev/null @@ -1,19 +0,0 @@ -from datasette import hookimpl, Response - - -@hookimpl(trylast=True) -def forbidden(datasette, request, message): - async def inner(): - return Response.html( - await datasette.render_template( - "error.html", - { - "title": "Forbidden", - "error": message, - }, - request=request, - ), - status=403, - ) - - return inner diff --git a/datasette/handle_exception.py b/datasette/handle_exception.py deleted file mode 100644 index 96398a4c..00000000 --- a/datasette/handle_exception.py +++ /dev/null @@ -1,77 +0,0 @@ -from datasette import hookimpl, Response -from .utils import add_cors_headers -from .utils.asgi import ( - Base400, -) -from .views.base import DatasetteError -from markupsafe import Markup -import traceback - -try: - import ipdb as pdb -except ImportError: - import pdb - -try: - import rich -except ImportError: - rich = None - - -@hookimpl(trylast=True) -def handle_exception(datasette, request, exception): - async def inner(): - if datasette.pdb: - pdb.post_mortem(exception.__traceback__) - - if rich is not None: - rich.get_console().print_exception(show_locals=True) - - title = None - if isinstance(exception, Base400): - status = exception.status - info = {} - message = exception.args[0] - elif isinstance(exception, DatasetteError): - status = exception.status - info = exception.error_dict - message = exception.message - if exception.message_is_html: - message = Markup(message) - title = exception.title - else: - status = 500 - info = {} - message = str(exception) - traceback.print_exc() - templates = [f"{status}.html", "error.html"] - info.update( - { - "ok": False, - "error": message, - "status": status, - "title": title, - } - ) - headers = {} - if datasette.cors: - add_cors_headers(headers) - if request.path.split("?")[0].endswith(".json"): - return Response.json(info, status=status, headers=headers) - else: - environment = datasette.get_jinja_environment(request) - template = environment.select_template(templates) - return Response.html( - await template.render_async( - dict( - info, - urls=datasette.urls, - app_css_hash=datasette.app_css_hash(), - menu_links=lambda: [], - ) - ), - status=status, - headers=headers, - ) - - return inner diff --git a/datasette/hookspecs.py b/datasette/hookspecs.py index dcd502af..240b58db 100644 --- a/datasette/hookspecs.py +++ b/datasette/hookspecs.py @@ -1,265 +1,25 @@ from pluggy import HookimplMarker from pluggy import HookspecMarker -hookspec = HookspecMarker("datasette") -hookimpl = HookimplMarker("datasette") +hookspec = HookspecMarker('datasette') +hookimpl = HookimplMarker('datasette') @hookspec -def startup(datasette): - """Fires directly after Datasette first starts running""" +def prepare_connection(conn): + "Modify SQLite connection in some way e.g. register custom SQL functions" @hookspec -def asgi_wrapper(datasette): - """Returns an ASGI middleware callable to wrap our ASGI application with""" +def prepare_jinja2_environment(env): + "Modify Jinja2 template environment e.g. register custom template tags" @hookspec -def prepare_connection(conn, database, datasette): - """Modify SQLite connection in some way e.g. register custom SQL functions""" +def extra_css_urls(): + "Extra CSS URLs added by this plugin" @hookspec -def prepare_jinja2_environment(env, datasette): - """Modify Jinja2 template environment e.g. register custom template tags""" - - -@hookspec -def extra_css_urls(template, database, table, columns, view_name, request, datasette): - """Extra CSS URLs added by this plugin""" - - -@hookspec -def extra_js_urls(template, database, table, columns, view_name, request, datasette): - """Extra JavaScript URLs added by this plugin""" - - -@hookspec -def extra_body_script( - template, database, table, columns, view_name, request, datasette -): - """Extra JavaScript code to be included in diff --git a/datasette/templates/_codemirror.html b/datasette/templates/_codemirror.html index c4629aeb..237d6907 100644 --- a/datasette/templates/_codemirror.html +++ b/datasette/templates/_codemirror.html @@ -1,16 +1,7 @@ - - + + + diff --git a/datasette/templates/_codemirror_foot.html b/datasette/templates/_codemirror_foot.html index a624c8a4..1e07fc72 100644 --- a/datasette/templates/_codemirror_foot.html +++ b/datasette/templates/_codemirror_foot.html @@ -1,42 +1,13 @@ diff --git a/datasette/templates/_crumbs.html b/datasette/templates/_crumbs.html deleted file mode 100644 index bd1ff0da..00000000 --- a/datasette/templates/_crumbs.html +++ /dev/null @@ -1,15 +0,0 @@ -{% macro nav(request, database=None, table=None) -%} -{% if crumb_items is defined %} - {% set items=crumb_items(request=request, database=database, table=table) %} - {% if items %} -

- {% for item in items %} - {{ item.label }} - {% if not loop.last %} - / - {% endif %} - {% endfor %} -

- {% endif %} -{% endif %} -{%- endmacro %} diff --git a/datasette/templates/_debug_common_functions.html b/datasette/templates/_debug_common_functions.html deleted file mode 100644 index d988a2f3..00000000 --- a/datasette/templates/_debug_common_functions.html +++ /dev/null @@ -1,50 +0,0 @@ - diff --git a/datasette/templates/_description_source_license.html b/datasette/templates/_description_source_license.html index f852268f..eba4eb1a 100644 --- a/datasette/templates/_description_source_license.html +++ b/datasette/templates/_description_source_license.html @@ -1,6 +1,6 @@ -{% if metadata.get("description_html") or metadata.get("description") %} +{% if metadata.description_html or metadata.description %}
+ + + {% for column in display_columns %} + + {% endfor %} + + + + {% for row in display_rows %} + + {% for cell in row %} + + {% endfor %} + + {% endfor %} + +
+ {% if not column.sortable %} + {{ column.name }} + {% else %} + {% if column.name == sort %} + {{ column.name }} ▼ + {% else %} + {{ column.name }}{% if column.name == sort_desc %} ▲{% endif %} + {% endif %} + {% endif %} +
{{ cell.value }}
diff --git a/datasette/templates/_sql_parameter_scripts.html b/datasette/templates/_sql_parameter_scripts.html deleted file mode 100644 index 159a141c..00000000 --- a/datasette/templates/_sql_parameter_scripts.html +++ /dev/null @@ -1,293 +0,0 @@ - diff --git a/datasette/templates/_sql_parameter_styles.html b/datasette/templates/_sql_parameter_styles.html deleted file mode 100644 index bc6838f5..00000000 --- a/datasette/templates/_sql_parameter_styles.html +++ /dev/null @@ -1,58 +0,0 @@ - diff --git a/datasette/templates/_sql_parameters.html b/datasette/templates/_sql_parameters.html deleted file mode 100644 index 58801d40..00000000 --- a/datasette/templates/_sql_parameters.html +++ /dev/null @@ -1,9 +0,0 @@ -
- {% if parameter_names %} -

Parameters

- {% for parameter in parameter_names %} - {% set parameter_id = (sql_parameter_id_prefix|default("qp")) ~ loop.index %} -

{% if sql_parameters_allow_expand|default(false) %} {% endif %}

- {% endfor %} - {% endif %} -
diff --git a/datasette/templates/_suggested_facets.html b/datasette/templates/_suggested_facets.html deleted file mode 100644 index b80208c3..00000000 --- a/datasette/templates/_suggested_facets.html +++ /dev/null @@ -1,3 +0,0 @@ -

- Suggested facets: {% for facet in suggested_facets %}{{ facet.name }}{% if facet.get("type") %} ({{ facet.type }}){% endif %}{% if not loop.last %}, {% endif %}{% endfor %} -

diff --git a/datasette/templates/_table.html b/datasette/templates/_table.html deleted file mode 100644 index f47a325f..00000000 --- a/datasette/templates/_table.html +++ /dev/null @@ -1,37 +0,0 @@ - -
-{% if display_columns %} -
- - - - {% for column in display_columns %} - - {% endfor %} - - - - {% for row in display_rows %} - - {% for cell in row %} - - {% endfor %} - - {% endfor %} - -
- {% if not column.sortable %} - {{ column.name }} - {% else %} - {% if column.name == sort %} - {{ column.name }} ▼ - {% else %} - {{ column.name }}{% if column.name == sort_desc %} ▲{% endif %} - {% endif %} - {% endif %} -
{{ cell.value }}
-
-{% endif %} -{% if not display_rows %} -

0 records

-{% endif %} diff --git a/datasette/templates/allow_debug.html b/datasette/templates/allow_debug.html deleted file mode 100644 index 1ecc92df..00000000 --- a/datasette/templates/allow_debug.html +++ /dev/null @@ -1,61 +0,0 @@ -{% extends "base.html" %} - -{% block title %}Debug allow rules{% endblock %} - -{% block extra_head %} - -{% endblock %} - -{% block content %} - -

Debug allow rules

- -{% set current_tab = "allow_debug" %} -{% include "_permissions_debug_tabs.html" %} - -

Use this tool to try out different actor and allow combinations. See Defining permissions with "allow" blocks for documentation.

- -
-
-

- -
-
-

- -
-
- -
-
- -{% if error %}

{{ error }}

{% endif %} - -{% if result == "True" %}

Result: allow

{% endif %} - -{% if result == "False" %}

Result: deny

{% endif %} - -{% endblock %} diff --git a/datasette/templates/api_explorer.html b/datasette/templates/api_explorer.html deleted file mode 100644 index dc393c20..00000000 --- a/datasette/templates/api_explorer.html +++ /dev/null @@ -1,208 +0,0 @@ -{% extends "base.html" %} - -{% block title %}API Explorer{% endblock %} - -{% block extra_head %} - -{% endblock %} - -{% block content %} - -

API Explorer{% if private %} 🔒{% endif %}

- -

Use this tool to try out the - {% if datasette_version %} - Datasette API. - {% else %} - Datasette API. - {% endif %} -

-
- GET -
-
- - - -
-
-
-
- POST -
-
- - -
-
- - -
-

-
-
- - - - - -{% if example_links %} -

API endpoints

-
    - {% for database in example_links %} -
  • Database: {{ database.name }}
  • -
      - {% for link in database.links %} -
    • {{ link.path }} - {{ link.label }}
    • - {% endfor %} - {% for table in database.tables %} -
    • {{ table.name }} -
        - {% for link in table.links %} -
      • {{ link.path }} - {{ link.label }}
      • - {% endfor %} -
      -
    • - {% endfor %} -
    - {% endfor %} -
-{% endif %} - -{% endblock %} diff --git a/datasette/templates/base.html b/datasette/templates/base.html index e1767deb..382c8e92 100644 --- a/datasette/templates/base.html +++ b/datasette/templates/base.html @@ -1,76 +1,40 @@ -{% import "_crumbs.html" as crumbs with context %} - + + {% block title %}{% endblock %} - + {% for url in extra_css_urls %} - + {% endfor %} - - {% for url in extra_js_urls %} - + {% endfor %} -{%- if alternate_url_json -%} - -{%- endif -%} -{%- block extra_head %}{% endblock -%} +{% block extra_head %}{% endblock %} -
@@ -102,32 +81,23 @@ {% endif %} - {% for key, value in form_hidden_args %} - + {% for facet in sorted_facet_results %} + {% endfor %} -{% if extra_wheres_for_ui %} -
-

{{ extra_wheres_for_ui|length }} extra where clause{% if extra_wheres_for_ui|length != 1 %}s{% endif %}

-
    - {% for extra_where in extra_wheres_for_ui %} -
  • {{ extra_where.text }} [remove]
  • - {% endfor %} -
-
+{% if query.sql %} +

View and edit SQL

{% endif %} -{% if query.sql and allow_execute_sql %} -

View and edit SQL

-{% endif %} - - + {% if suggested_facets %} - {% include "_suggested_facets.html" %} +

+ Suggested facets: {% for facet in suggested_facets %}{{ facet.name }}{% if not loop.last %}, {% endif %}{% endfor %} +

{% endif %} {% if facets_timed_out %} @@ -135,103 +105,44 @@ {% endif %} {% if facet_results %} - {% include "_facet_results.html" %} +
+ {% for facet_info in sorted_facet_results %} +
+

+ {{ facet_info.name }} + {% if facet_hideable(facet_info.name) %} + + {% endif %} +

+
    + {% for facet_value in facet_info.results %} + {% if not facet_value.selected %} +
  • {{ facet_value.label or "_" }} {{ "{:,}".format(facet_value.count) }}
  • + {% else %} +
  • {{ facet_value.label }} · {{ "{:,}".format(facet_value.count) }}
  • + {% endif %} + {% endfor %} + {% if facet_info.truncated %} +
  • ...
  • + {% endif %} +
+
+ {% endfor %} +
{% endif %} -{% if all_columns %} - - - - -{% endif %} -{% if set_column_type_ui %} - -{% endif %} - -{% include custom_table_templates %} +{% include custom_rows_and_columns_templates %} {% if next_url %}

Next page

{% endif %} -{% if display_rows %} -
-

Advanced export

-

JSON shape: - default, - array, - newline-delimited{% if primary_keys %}, - object - {% endif %} -

-
-

- CSV options: - - {% if expandable_columns %}{% endif %} - {% if next_url and settings.allow_csv_stream %}{% endif %} - - {% for key, value in url_csv_hidden_args %} - - {% endfor %} -

-
-
-{% endif %} - {% if table_definition %} -
{{ table_definition }}
+
{{ table_definition }}
{% endif %} {% if view_definition %} -
{{ view_definition }}
-{% endif %} - -{% if allow_execute_sql and query.sql %} - +
{{ view_definition }}
{% endif %} {% endblock %} diff --git a/datasette/tokens.py b/datasette/tokens.py deleted file mode 100644 index 38a55529..00000000 --- a/datasette/tokens.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -Token handler system for Datasette. - -Provides a base class for token handlers and the default signed token handler. -Plugins can implement register_token_handler to provide custom token backends -(e.g. database-backed tokens that can be revoked and audited). -""" - -from __future__ import annotations - -import dataclasses -import time -from typing import TYPE_CHECKING, Optional - -import itsdangerous - -if TYPE_CHECKING: - from datasette.app import Datasette - - -@dataclasses.dataclass -class TokenRestrictions: - """ - Restrictions to apply to a token, limiting which actions it can perform. - - Use the builder methods to construct restrictions:: - - restrictions = (TokenRestrictions() - .allow_all("view-instance") - .allow_database("mydb", "create-table") - .allow_resource("mydb", "mytable", "insert-row")) - """ - - all: list[str] = dataclasses.field(default_factory=list) - database: dict[str, list[str]] = dataclasses.field(default_factory=dict) - resource: dict[str, dict[str, list[str]]] = dataclasses.field(default_factory=dict) - - def allow_all(self, action: str) -> "TokenRestrictions": - """Allow an action across all databases and resources.""" - self.all.append(action) - return self - - def allow_database(self, database: str, action: str) -> "TokenRestrictions": - """Allow an action on a specific database.""" - self.database.setdefault(database, []).append(action) - return self - - def allow_resource( - self, database: str, resource: str, action: str - ) -> "TokenRestrictions": - """Allow an action on a specific resource within a database.""" - self.resource.setdefault(database, {}).setdefault(resource, []).append(action) - return self - - def abbreviated(self, datasette: "Datasette") -> Optional[dict]: - """ - Return the abbreviated ``_r`` dictionary shape for this set of - restrictions, using action abbreviations registered with ``datasette``. - Returns ``None`` if no restrictions are set. - """ - if not (self.all or self.database or self.resource): - return None - - def abbreviate_action(action): - action_obj = datasette.actions.get(action) - if not action_obj: - return action - return action_obj.abbr or action - - result: dict = {} - if self.all: - result["a"] = [abbreviate_action(a) for a in self.all] - if self.database: - result["d"] = { - database: [abbreviate_action(a) for a in actions] - for database, actions in self.database.items() - } - if self.resource: - result["r"] = {} - for database, resources in self.resource.items(): - for resource, actions in resources.items(): - result["r"].setdefault(database, {})[resource] = [ - abbreviate_action(a) for a in actions - ] - return result - - -class TokenHandler: - """ - Base class for token handlers. - - Subclass this and implement create_token() and verify_token() to provide - a custom token backend. Return an instance from the register_token_handler hook. - """ - - name: str = "" - - async def create_token( - self, - datasette: "Datasette", - actor_id: str, - *, - expires_after: Optional[int] = None, - restrictions: Optional[TokenRestrictions] = None, - ) -> str: - """Create and return a token string for the given actor.""" - raise NotImplementedError - - async def verify_token(self, datasette: "Datasette", token: str) -> Optional[dict]: - """ - Verify a token and return an actor dict, or None if this handler - does not recognize the token. - """ - raise NotImplementedError - - -class SignedTokenHandler(TokenHandler): - """ - Default token handler using itsdangerous signed tokens (dstok_ prefix). - """ - - name = "signed" - - async def create_token( - self, - datasette: "Datasette", - actor_id: str, - *, - expires_after: Optional[int] = None, - restrictions: Optional[TokenRestrictions] = None, - ) -> str: - if not datasette.setting("allow_signed_tokens"): - raise ValueError( - "Signed tokens are not enabled for this Datasette instance" - ) - - token = {"a": actor_id, "t": int(time.time())} - - if expires_after: - token["d"] = expires_after - if restrictions is not None: - abbreviated = restrictions.abbreviated(datasette) - if abbreviated is not None: - token["_r"] = abbreviated - return "dstok_{}".format(datasette.sign(token, namespace="token")) - - async def verify_token(self, datasette: "Datasette", token: str) -> Optional[dict]: - prefix = "dstok_" - - if not datasette.setting("allow_signed_tokens"): - return None - - max_signed_tokens_ttl = datasette.setting("max_signed_tokens_ttl") - - if not token.startswith(prefix): - return None - - raw = token[len(prefix) :] - try: - decoded = datasette.unsign(raw, namespace="token") - except itsdangerous.BadSignature: - return None - - if "t" not in decoded: - return None - created = decoded["t"] - if not isinstance(created, int): - return None - - duration = decoded.get("d") - if duration is not None and not isinstance(duration, int): - return None - - if (duration is None and max_signed_tokens_ttl) or ( - duration is not None - and max_signed_tokens_ttl - and duration > max_signed_tokens_ttl - ): - duration = max_signed_tokens_ttl - - if duration: - if time.time() - created > duration: - return None - - actor = {"id": decoded["a"], "token": "dstok"} - - if "_r" in decoded: - actor["_r"] = decoded["_r"] - - if duration: - actor["token_expires"] = created + duration - - return actor diff --git a/datasette/tracer.py b/datasette/tracer.py deleted file mode 100644 index 9e66613b..00000000 --- a/datasette/tracer.py +++ /dev/null @@ -1,153 +0,0 @@ -import asyncio -from contextlib import contextmanager -from contextvars import ContextVar -from markupsafe import escape -import time -import json -import traceback - -tracers = {} - -TRACE_RESERVED_KEYS = {"type", "start", "end", "duration_ms", "traceback"} - -trace_task_id = ContextVar("trace_task_id", default=None) - - -def get_task_id(): - current = trace_task_id.get(None) - if current is not None: - return current - try: - loop = asyncio.get_event_loop() - except RuntimeError: - return None - return id(asyncio.current_task(loop=loop)) - - -@contextmanager -def trace_child_tasks(): - token = trace_task_id.set(get_task_id()) - yield - trace_task_id.reset(token) - - -@contextmanager -def trace(trace_type, **kwargs): - assert not TRACE_RESERVED_KEYS.intersection( - kwargs.keys() - ), f".trace() keyword parameters cannot include {TRACE_RESERVED_KEYS}" - task_id = get_task_id() - if task_id is None: - yield kwargs - return - tracer = tracers.get(task_id) - if tracer is None: - yield kwargs - return - start = time.perf_counter() - captured_error = None - try: - yield kwargs - except Exception as ex: - captured_error = ex - raise - finally: - end = time.perf_counter() - trace_info = { - "type": trace_type, - "start": start, - "end": end, - "duration_ms": (end - start) * 1000, - "traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]), - "error": str(captured_error) if captured_error else None, - } - trace_info.update(kwargs) - tracer.append(trace_info) - - -@contextmanager -def capture_traces(tracer): - # tracer is a list - task_id = get_task_id() - if task_id is None: - yield - return - tracers[task_id] = tracer - yield - del tracers[task_id] - - -class AsgiTracer: - # If the body is larger than this we don't attempt to append the trace - max_body_bytes = 1024 * 256 # 256 KB - - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - if b"_trace=1" not in scope.get("query_string", b"").split(b"&"): - await self.app(scope, receive, send) - return - trace_start = time.perf_counter() - traces = [] - - accumulated_body = b"" - size_limit_exceeded = False - response_headers = [] - - async def wrapped_send(message): - nonlocal accumulated_body, size_limit_exceeded, response_headers - - if message["type"] == "http.response.start": - response_headers = message["headers"] - await send(message) - return - - if message["type"] != "http.response.body" or size_limit_exceeded: - await send(message) - return - - # Accumulate body until the end or until size is exceeded - accumulated_body += message["body"] - if len(accumulated_body) > self.max_body_bytes: - # Send what we have accumulated so far - await send( - { - "type": "http.response.body", - "body": accumulated_body, - "more_body": bool(message.get("more_body")), - } - ) - size_limit_exceeded = True - return - - if not message.get("more_body"): - # We have all the body - modify it and send the result - # TODO: What to do about Content-Type or other cases? - trace_info = { - "request_duration_ms": 1000 * (time.perf_counter() - trace_start), - "sum_trace_duration_ms": sum(t["duration_ms"] for t in traces), - "num_traces": len(traces), - "traces": traces, - } - try: - content_type = [ - v.decode("utf8") - for k, v in response_headers - if k.lower() == b"content-type" - ][0] - except IndexError: - content_type = "" - if "text/html" in content_type and b"" in accumulated_body: - extra = escape(json.dumps(trace_info, indent=2)) - extra_html = f"
{extra}
".encode("utf8") - accumulated_body = accumulated_body.replace(b"", extra_html) - elif "json" in content_type and accumulated_body.startswith(b"{"): - data = json.loads(accumulated_body.decode("utf8")) - if "_trace" not in data: - data["_trace"] = trace_info - accumulated_body = json.dumps(data).encode("utf8") - await send({"type": "http.response.body", "body": accumulated_body}) - - with capture_traces(traces): - await self.app(scope, receive, wrapped_send) diff --git a/datasette/url_builder.py b/datasette/url_builder.py deleted file mode 100644 index 16b3d42b..00000000 --- a/datasette/url_builder.py +++ /dev/null @@ -1,61 +0,0 @@ -from .utils import tilde_encode, path_with_format, PrefixedUrlString -import urllib - - -class Urls: - def __init__(self, ds): - self.ds = ds - - def path(self, path, format=None): - if not isinstance(path, PrefixedUrlString): - if path.startswith("/"): - path = path[1:] - path = self.ds.setting("base_url") + path - if format is not None: - path = path_with_format(path=path, format=format) - return PrefixedUrlString(path) - - def instance(self, format=None): - return self.path("", format=format) - - def static(self, path): - return self.path(f"-/static/{path}") - - def static_plugins(self, plugin, path): - return self.path(f"-/static-plugins/{plugin}/{path}") - - def logout(self): - return self.path("-/logout") - - def database(self, database, format=None): - db = self.ds.get_database(database) - return self.path(tilde_encode(db.route), format=format) - - def database_query(self, database, sql, format=None): - path = f"{self.database(database)}/-/query?" + urllib.parse.urlencode( - {"sql": sql} - ) - return self.path(path, format=format) - - def table(self, database, table, format=None): - path = f"{self.database(database)}/{tilde_encode(table)}" - if format is not None: - path = path_with_format(path=path, format=format) - return PrefixedUrlString(path) - - def query(self, database, query, format=None): - path = f"{self.database(database)}/{tilde_encode(query)}" - if format is not None: - path = path_with_format(path=path, format=format) - return PrefixedUrlString(path) - - def row(self, database, table, row_path, format=None): - path = f"{self.table(database, table)}/{row_path}" - if format is not None: - path = path_with_format(path=path, format=format) - return PrefixedUrlString(path) - - def row_blob(self, database, table, row_path, column): - return self.table(database, table) + "/{}.blob?_blob_column={}".format( - row_path, urllib.parse.quote_plus(column) - ) diff --git a/datasette/utils.py b/datasette/utils.py new file mode 100644 index 00000000..61dbe910 --- /dev/null +++ b/datasette/utils.py @@ -0,0 +1,891 @@ +from contextlib import contextmanager +from collections import OrderedDict +import base64 +import hashlib +import imp +import json +import os +import pkg_resources +import re +import shlex +import sqlite3 +import tempfile +import time +import shutil +import urllib +import numbers + + +# From https://www.sqlite.org/lang_keywords.html +reserved_words = set(( + 'abort action add after all alter analyze and as asc attach autoincrement ' + 'before begin between by cascade case cast check collate column commit ' + 'conflict constraint create cross current_date current_time ' + 'current_timestamp database default deferrable deferred delete desc detach ' + 'distinct drop each else end escape except exclusive exists explain fail ' + 'for foreign from full glob group having if ignore immediate in index ' + 'indexed initially inner insert instead intersect into is isnull join key ' + 'left like limit match natural no not notnull null of offset on or order ' + 'outer plan pragma primary query raise recursive references regexp reindex ' + 'release rename replace restrict right rollback row savepoint select set ' + 'table temp temporary then to transaction trigger union unique update using ' + 'vacuum values view virtual when where with without' +).split()) + +SPATIALITE_DOCKERFILE_EXTRAS = r''' +RUN apt-get update && \ + apt-get install -y python3-dev gcc libsqlite3-mod-spatialite && \ + rm -rf /var/lib/apt/lists/* +ENV SQLITE_EXTENSIONS /usr/lib/x86_64-linux-gnu/mod_spatialite.so +''' + + +class InterruptedError(Exception): + pass + + +class Results: + def __init__(self, rows, truncated, description): + self.rows = rows + self.truncated = truncated + self.description = description + + def __iter__(self): + return iter(self.rows) + + def __len__(self): + return len(self.rows) + + +def urlsafe_components(token): + "Splits token on commas and URL decodes each component" + return [ + urllib.parse.unquote_plus(b) for b in token.split(',') + ] + + +def path_from_row_pks(row, pks, use_rowid, quote=True): + """ Generate an optionally URL-quoted unique identifier + for a row from its primary keys.""" + if use_rowid: + bits = [row['rowid']] + else: + bits = [row[pk] for pk in pks] + + if quote: + bits = [urllib.parse.quote_plus(str(bit)) for bit in bits] + else: + bits = [str(bit) for bit in bits] + + return ','.join(bits) + + +def compound_keys_after_sql(pks, start_index=0): + # Implementation of keyset pagination + # See https://github.com/simonw/datasette/issues/190 + # For pk1/pk2/pk3 returns: + # + # ([pk1] > :p0) + # or + # ([pk1] = :p0 and [pk2] > :p1) + # or + # ([pk1] = :p0 and [pk2] = :p1 and [pk3] > :p2) + or_clauses = [] + pks_left = pks[:] + while pks_left: + and_clauses = [] + last = pks_left[-1] + rest = pks_left[:-1] + and_clauses = ['{} = :p{}'.format( + escape_sqlite(pk), (i + start_index) + ) for i, pk in enumerate(rest)] + and_clauses.append('{} > :p{}'.format( + escape_sqlite(last), (len(rest) + start_index) + )) + or_clauses.append('({})'.format(' and '.join(and_clauses))) + pks_left.pop() + or_clauses.reverse() + return '({})'.format('\n or\n'.join(or_clauses)) + + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, sqlite3.Row): + return tuple(obj) + if isinstance(obj, sqlite3.Cursor): + return list(obj) + if isinstance(obj, bytes): + # Does it encode to utf8? + try: + return obj.decode('utf8') + except UnicodeDecodeError: + return { + '$base64': True, + 'encoded': base64.b64encode(obj).decode('latin1'), + } + return json.JSONEncoder.default(self, obj) + + +@contextmanager +def sqlite_timelimit(conn, ms): + deadline = time.time() + (ms / 1000) + # n is the number of SQLite virtual machine instructions that will be + # executed between each check. It's hard to know what to pick here. + # After some experimentation, I've decided to go with 1000 by default and + # 1 for time limits that are less than 50ms + n = 1000 + if ms < 50: + n = 1 + + def handler(): + if time.time() >= deadline: + return 1 + + conn.set_progress_handler(handler, n) + yield + conn.set_progress_handler(None, n) + + +class InvalidSql(Exception): + pass + + +allowed_sql_res = [ + re.compile(r'^select\b'), + re.compile(r'^explain select\b'), + re.compile(r'^explain query plan select\b'), + re.compile(r'^with\b'), +] +disallawed_sql_res = [ + (re.compile('pragma'), 'Statement may not contain PRAGMA'), +] + + +def validate_sql_select(sql): + sql = sql.strip().lower() + if not any(r.match(sql) for r in allowed_sql_res): + raise InvalidSql('Statement must be a SELECT') + for r, msg in disallawed_sql_res: + if r.search(sql): + raise InvalidSql(msg) + + +def path_with_added_args(qs, args, path=None): + path = path or qs.path + if isinstance(args, dict): + args = args.items() + args_to_remove = {k for k, v in args if v is None} + current = [] + for key, value in urllib.parse.parse_qsl(str(qs)): + if key not in args_to_remove: + current.append((key, value)) + current.extend([ + (key, value) + for key, value in args + if value is not None + ]) + query_string = urllib.parse.urlencode(current) + if query_string: + query_string = '?{}'.format(query_string) + return path + query_string + + +def path_with_removed_args(qs, args, path=None): + # args can be a dict or a set + path = path or qs.path + current = [] + if isinstance(args, set): + def should_remove(key, value): + return key in args + elif isinstance(args, dict): + # Must match key AND value + def should_remove(key, value): + return args.get(key) == value + for key, value in urllib.parse.parse_qsl(str(qs)): + if not should_remove(key, value): + current.append((key, value)) + query_string = urllib.parse.urlencode(current) + if query_string: + query_string = '?{}'.format(query_string) + return path + query_string + + +def path_with_replaced_args(qs, args, path=None): + path = path or qs.path + if isinstance(args, dict): + args = args.items() + keys_to_replace = {p[0] for p in args} + current = [] + for key, value in urllib.parse.parse_qsl(str(qs)): + if key not in keys_to_replace: + current.append((key, value)) + current.extend([p for p in args if p[1] is not None]) + query_string = urllib.parse.urlencode(current) + if query_string: + query_string = '?{}'.format(query_string) + return path + query_string + + +_css_re = re.compile(r'''['"\n\\]''') +_boring_keyword_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') + + +def escape_css_string(s): + return _css_re.sub(lambda m: '\\{:X}'.format(ord(m.group())), s) + + +def escape_sqlite(s): + if _boring_keyword_re.match(s) and (s.lower() not in reserved_words): + return s + else: + return '[{}]'.format(s) + + +def make_dockerfile(files, metadata_file, extra_options, branch, template_dir, plugins_dir, static, install, spatialite): + cmd = ['"datasette"', '"serve"', '"--host"', '"0.0.0.0"'] + cmd.append('"' + '", "'.join(files) + '"') + cmd.extend(['"--cors"', '"--port"', '"8001"', '"--inspect-file"', '"inspect-data.json"']) + if metadata_file: + cmd.extend(['"--metadata"', '"{}"'.format(metadata_file)]) + if template_dir: + cmd.extend(['"--template-dir"', '"templates/"']) + if plugins_dir: + cmd.extend(['"--plugins-dir"', '"plugins/"']) + if static: + for mount_point, _ in static: + cmd.extend(['"--static"', '"{}:{}"'.format(mount_point, mount_point)]) + if extra_options: + for opt in extra_options.split(): + cmd.append('"{}"'.format(opt)) + + if branch: + install = ['https://github.com/simonw/datasette/archive/{}.zip'.format( + branch + )] + list(install) + else: + install = ['datasette'] + list(install) + + return ''' +FROM python:3.6 +COPY . /app +WORKDIR /app +{spatialite_extras} +RUN pip install {install_from} +RUN datasette inspect {files} --inspect-file inspect-data.json +EXPOSE 8001 +CMD [{cmd}]'''.format( + files=' '.join(files), + cmd=', '.join(cmd), + install_from=' '.join(install), + spatialite_extras=SPATIALITE_DOCKERFILE_EXTRAS if spatialite else '', + ).strip() + + +@contextmanager +def temporary_docker_directory( + files, + name, + metadata, + extra_options, + branch, + template_dir, + plugins_dir, + static, + install, + spatialite, + extra_metadata=None +): + extra_metadata = extra_metadata or {} + tmp = tempfile.TemporaryDirectory() + # We create a datasette folder in there to get a nicer now deploy name + datasette_dir = os.path.join(tmp.name, name) + os.mkdir(datasette_dir) + saved_cwd = os.getcwd() + file_paths = [ + os.path.join(saved_cwd, file_path) + for file_path in files + ] + file_names = [os.path.split(f)[-1] for f in files] + if metadata: + metadata_content = json.load(metadata) + else: + metadata_content = {} + for key, value in extra_metadata.items(): + if value: + metadata_content[key] = value + try: + dockerfile = make_dockerfile( + file_names, + metadata_content and 'metadata.json', + extra_options, + branch, + template_dir, + plugins_dir, + static, + install, + spatialite, + ) + os.chdir(datasette_dir) + if metadata_content: + open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2)) + open('Dockerfile', 'w').write(dockerfile) + for path, filename in zip(file_paths, file_names): + link_or_copy(path, os.path.join(datasette_dir, filename)) + if template_dir: + link_or_copy_directory( + os.path.join(saved_cwd, template_dir), + os.path.join(datasette_dir, 'templates') + ) + if plugins_dir: + link_or_copy_directory( + os.path.join(saved_cwd, plugins_dir), + os.path.join(datasette_dir, 'plugins') + ) + for mount_point, path in static: + link_or_copy_directory( + os.path.join(saved_cwd, path), + os.path.join(datasette_dir, mount_point) + ) + yield datasette_dir + finally: + tmp.cleanup() + os.chdir(saved_cwd) + + +@contextmanager +def temporary_heroku_directory( + files, + name, + metadata, + extra_options, + branch, + template_dir, + plugins_dir, + static, + install, + extra_metadata=None +): + # FIXME: lots of duplicated code from above + + extra_metadata = extra_metadata or {} + tmp = tempfile.TemporaryDirectory() + saved_cwd = os.getcwd() + + file_paths = [ + os.path.join(saved_cwd, file_path) + for file_path in files + ] + file_names = [os.path.split(f)[-1] for f in files] + + if metadata: + metadata_content = json.load(metadata) + else: + metadata_content = {} + for key, value in extra_metadata.items(): + if value: + metadata_content[key] = value + + try: + os.chdir(tmp.name) + + if metadata_content: + open('metadata.json', 'w').write(json.dumps(metadata_content, indent=2)) + + open('runtime.txt', 'w').write('python-3.6.3') + + if branch: + install = ['https://github.com/simonw/datasette/archive/{branch}.zip'.format( + branch=branch + )] + list(install) + else: + install = ['datasette'] + list(install) + + open('requirements.txt', 'w').write('\n'.join(install)) + os.mkdir('bin') + open('bin/post_compile', 'w').write('datasette inspect --inspect-file inspect-data.json') + + extras = [] + if template_dir: + link_or_copy_directory( + os.path.join(saved_cwd, template_dir), + os.path.join(tmp.name, 'templates') + ) + extras.extend(['--template-dir', 'templates/']) + if plugins_dir: + link_or_copy_directory( + os.path.join(saved_cwd, plugins_dir), + os.path.join(tmp.name, 'plugins') + ) + extras.extend(['--plugins-dir', 'plugins/']) + + if metadata: + extras.extend(['--metadata', 'metadata.json']) + for mount_point, path in static: + link_or_copy_directory( + os.path.join(saved_cwd, path), + os.path.join(tmp.name, mount_point) + ) + extras.extend(['--static', '{}:{}'.format(mount_point, mount_point)]) + + quoted_files = " ".join(map(shlex.quote, file_names)) + procfile_cmd = 'web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}'.format( + quoted_files=quoted_files, + extras=' '.join(extras), + ) + open('Procfile', 'w').write(procfile_cmd) + + for path, filename in zip(file_paths, file_names): + link_or_copy(path, os.path.join(tmp.name, filename)) + + yield + + finally: + tmp.cleanup() + os.chdir(saved_cwd) + + +def get_all_foreign_keys(conn): + tables = [r[0] for r in conn.execute('select name from sqlite_master where type="table"')] + table_to_foreign_keys = {} + for table in tables: + table_to_foreign_keys[table] = { + 'incoming': [], + 'outgoing': [], + } + for table in tables: + infos = conn.execute( + 'PRAGMA foreign_key_list([{}])'.format(table) + ).fetchall() + for info in infos: + if info is not None: + id, seq, table_name, from_, to_, on_update, on_delete, match = info + if table_name not in table_to_foreign_keys: + # Weird edge case where something refers to a table that does + # not actually exist + continue + table_to_foreign_keys[table_name]['incoming'].append({ + 'other_table': table, + 'column': to_, + 'other_column': from_ + }) + table_to_foreign_keys[table]['outgoing'].append({ + 'other_table': table_name, + 'column': from_, + 'other_column': to_ + }) + + return table_to_foreign_keys + + +def detect_spatialite(conn): + rows = conn.execute('select 1 from sqlite_master where tbl_name = "geometry_columns"').fetchall() + return len(rows) > 0 + + +def detect_fts(conn, table): + "Detect if table has a corresponding FTS virtual table and return it" + rows = conn.execute(detect_fts_sql(table)).fetchall() + if len(rows) == 0: + return None + else: + return rows[0][0] + + +def detect_fts_sql(table): + return r''' + select name from sqlite_master + where rootpage = 0 + and ( + sql like '%VIRTUAL TABLE%USING FTS%content="{table}"%' + or ( + tbl_name = "{table}" + and sql like '%VIRTUAL TABLE%USING FTS%' + ) + ) + '''.format(table=table) + + +class Filter: + def __init__(self, key, display, sql_template, human_template, format='{}', numeric=False, no_argument=False): + self.key = key + self.display = display + self.sql_template = sql_template + self.human_template = human_template + self.format = format + self.numeric = numeric + self.no_argument = no_argument + + def where_clause(self, column, value, param_counter): + converted = self.format.format(value) + if self.numeric and converted.isdigit(): + converted = int(converted) + if self.no_argument: + kwargs = { + 'c': column, + } + converted = None + else: + kwargs = { + 'c': column, + 'p': 'p{}'.format(param_counter), + } + return self.sql_template.format(**kwargs), converted + + def human_clause(self, column, value): + if callable(self.human_template): + template = self.human_template(column, value) + else: + template = self.human_template + if self.no_argument: + return template.format(c=column) + else: + return template.format(c=column, v=value) + + +class Filters: + _filters = [ + Filter('exact', '=', '"{c}" = :{p}', lambda c, v: '{c} = {v}' if v.isdigit() else '{c} = "{v}"'), + Filter('not', '!=', '"{c}" != :{p}', lambda c, v: '{c} != {v}' if v.isdigit() else '{c} != "{v}"'), + Filter('contains', 'contains', '"{c}" like :{p}', '{c} contains "{v}"', format='%{}%'), + Filter('endswith', 'ends with', '"{c}" like :{p}', '{c} ends with "{v}"', format='%{}'), + Filter('startswith', 'starts with', '"{c}" like :{p}', '{c} starts with "{v}"', format='{}%'), + Filter('gt', '>', '"{c}" > :{p}', '{c} > {v}', numeric=True), + Filter('gte', '\u2265', '"{c}" >= :{p}', '{c} \u2265 {v}', numeric=True), + Filter('lt', '<', '"{c}" < :{p}', '{c} < {v}', numeric=True), + Filter('lte', '\u2264', '"{c}" <= :{p}', '{c} \u2264 {v}', numeric=True), + Filter('glob', 'glob', '"{c}" glob :{p}', '{c} glob "{v}"'), + Filter('like', 'like', '"{c}" like :{p}', '{c} like "{v}"'), + Filter('isnull', 'is null', '"{c}" is null', '{c} is null', no_argument=True), + Filter('notnull', 'is not null', '"{c}" is not null', '{c} is not null', no_argument=True), + Filter('isblank', 'is blank', '("{c}" is null or "{c}" = "")', '{c} is blank', no_argument=True), + Filter('notblank', 'is not blank', '("{c}" is not null and "{c}" != "")', '{c} is not blank', no_argument=True), + ] + _filters_by_key = { + f.key: f for f in _filters + } + + def __init__(self, pairs, units={}, ureg=None): + self.pairs = pairs + self.units = units + self.ureg = ureg + + def lookups(self): + "Yields (lookup, display, no_argument) pairs" + for filter in self._filters: + yield filter.key, filter.display, filter.no_argument + + def human_description_en(self, extra=None): + bits = [] + if extra: + bits.extend(extra) + for column, lookup, value in self.selections(): + filter = self._filters_by_key.get(lookup, None) + if filter: + bits.append(filter.human_clause(column, value)) + # Comma separated, with an ' and ' at the end + and_bits = [] + commas, tail = bits[:-1], bits[-1:] + if commas: + and_bits.append(', '.join(commas)) + if tail: + and_bits.append(tail[0]) + s = ' and '.join(and_bits) + if not s: + return '' + return 'where {}'.format(s) + + def selections(self): + "Yields (column, lookup, value) tuples" + for key, value in self.pairs: + if '__' in key: + column, lookup = key.rsplit('__', 1) + else: + column = key + lookup = 'exact' + yield column, lookup, value + + def has_selections(self): + return bool(self.pairs) + + def convert_unit(self, column, value): + "If the user has provided a unit in the quey, convert it into the column unit, if present." + if column not in self.units: + return value + + # Try to interpret the value as a unit + value = self.ureg(value) + if isinstance(value, numbers.Number): + # It's just a bare number, assume it's the column unit + return value + + column_unit = self.ureg(self.units[column]) + return value.to(column_unit).magnitude + + def build_where_clauses(self): + sql_bits = [] + params = {} + for i, (column, lookup, value) in enumerate(self.selections()): + filter = self._filters_by_key.get(lookup, None) + if filter: + sql_bit, param = filter.where_clause(column, self.convert_unit(column, value), i) + sql_bits.append(sql_bit) + if param is not None: + param_id = 'p{}'.format(i) + params[param_id] = param + return sql_bits, params + + +filter_column_re = re.compile(r'^_filter_column_\d+$') + + +def filters_should_redirect(special_args): + redirect_params = [] + # Handle _filter_column=foo&_filter_op=exact&_filter_value=... + filter_column = special_args.get('_filter_column') + filter_op = special_args.get('_filter_op') or '' + filter_value = special_args.get('_filter_value') or '' + if '__' in filter_op: + filter_op, filter_value = filter_op.split('__', 1) + if filter_column: + redirect_params.append( + ('{}__{}'.format(filter_column, filter_op), filter_value) + ) + for key in ('_filter_column', '_filter_op', '_filter_value'): + if key in special_args: + redirect_params.append((key, None)) + # Now handle _filter_column_1=name&_filter_op_1=contains&_filter_value_1=hello + column_keys = [k for k in special_args if filter_column_re.match(k)] + for column_key in column_keys: + number = column_key.split('_')[-1] + column = special_args[column_key] + op = special_args.get('_filter_op_{}'.format(number)) or 'exact' + value = special_args.get('_filter_value_{}'.format(number)) or '' + if '__' in op: + op, value = op.split('__', 1) + if column: + redirect_params.append(('{}__{}'.format(column, op), value)) + redirect_params.extend([ + ('_filter_column_{}'.format(number), None), + ('_filter_op_{}'.format(number), None), + ('_filter_value_{}'.format(number), None), + ]) + return redirect_params + + +whitespace_re = re.compile(r'\s') + + +def is_url(value): + "Must start with http:// or https:// and contain JUST a URL" + if not isinstance(value, str): + return False + if not value.startswith('http://') and not value.startswith('https://'): + return False + # Any whitespace at all is invalid + if whitespace_re.search(value): + return False + return True + + +css_class_re = re.compile(r'^[a-zA-Z]+[_a-zA-Z0-9-]*$') +css_invalid_chars_re = re.compile(r'[^a-zA-Z0-9_\-]') + + +def to_css_class(s): + """ + Given a string (e.g. a table name) returns a valid unique CSS class. + For simple cases, just returns the string again. If the string is not a + valid CSS class (we disallow - and _ prefixes even though they are valid + as they may be confused with browser prefixes) we strip invalid characters + and add a 6 char md5 sum suffix, to make sure two tables with identical + names after stripping characters don't end up with the same CSS class. + """ + if css_class_re.match(s): + return s + md5_suffix = hashlib.md5(s.encode('utf8')).hexdigest()[:6] + # Strip leading _, - + s = s.lstrip('_').lstrip('-') + # Replace any whitespace with hyphens + s = '-'.join(s.split()) + # Remove any remaining invalid characters + s = css_invalid_chars_re.sub('', s) + # Attach the md5 suffix + bits = [b for b in (s, md5_suffix) if b] + return '-'.join(bits) + + +def link_or_copy(src, dst): + # Intended for use in populating a temp directory. We link if possible, + # but fall back to copying if the temp directory is on a different device + # https://github.com/simonw/datasette/issues/141 + try: + os.link(src, dst) + except OSError: + shutil.copyfile(src, dst) + + +def link_or_copy_directory(src, dst): + try: + shutil.copytree(src, dst, copy_function=os.link) + except OSError: + shutil.copytree(src, dst) + + +def module_from_path(path, name): + # Adapted from http://sayspy.blogspot.com/2011/07/how-to-import-module-from-just-file.html + mod = imp.new_module(name) + mod.__file__ = path + with open(path, 'r') as file: + code = compile(file.read(), path, 'exec', dont_inherit=True) + exec(code, mod.__dict__) + return mod + + +def get_plugins(pm): + plugins = [] + plugin_to_distinfo = dict(pm.list_plugin_distinfo()) + for plugin in pm.get_plugins(): + static_path = None + templates_path = None + try: + if pkg_resources.resource_isdir(plugin.__name__, 'static'): + static_path = pkg_resources.resource_filename(plugin.__name__, 'static') + if pkg_resources.resource_isdir(plugin.__name__, 'templates'): + templates_path = pkg_resources.resource_filename(plugin.__name__, 'templates') + except (KeyError, ImportError): + # Caused by --plugins_dir= plugins - KeyError/ImportError thrown in Py3.5 + pass + plugin_info = { + 'name': plugin.__name__, + 'static_path': static_path, + 'templates_path': templates_path, + } + distinfo = plugin_to_distinfo.get(plugin) + if distinfo: + plugin_info['version'] = distinfo.version + plugins.append(plugin_info) + return plugins + + +FORMATS = ('csv', 'json', 'jsono') + + +def resolve_table_and_format(table_and_format, table_exists): + if '.' in table_and_format: + # Check if a table exists with this exact name + if table_exists(table_and_format): + return table_and_format, None + # Check if table ends with a known format + for _format in FORMATS: + if table_and_format.endswith(".{}".format(_format)): + table = table_and_format[:-(len(_format) + 1)] + return table, _format + return table_and_format, None + + +def path_with_format(qs, format, extra_qs=None): + new_qs = extra_qs or {} + path = qs.path + if "." in qs.path: + new_qs["_format"] = format + else: + path = "{}.{}".format(path, format) + if new_qs: + extra = urllib.parse.urlencode(sorted(new_qs.items())) + if qs.data: + path = "{}?{}&{}".format( + path, str(qs), extra + ) + else: + path = "{}?{}".format(path, extra) + elif qs.data: + path = "{}?{}".format(path, str(qs)) + return path + + +class CustomRow(OrderedDict): + # Loose imitation of sqlite3.Row which offers + # both index-based AND key-based lookups + def __init__(self, columns): + self.columns = columns + + def __getitem__(self, key): + if isinstance(key, int): + return super().__getitem__(self.columns[key]) + else: + return super().__getitem__(key) + + def __iter__(self): + for column in self.columns: + yield self[column] + + +def value_as_boolean(value): + if value.lower() not in ('on', 'off', 'true', 'false', '1', '0'): + raise ValueAsBooleanError + return value.lower() in ('on', 'true', '1') + + +class ValueAsBooleanError(ValueError): + pass + + +class Querystring: + def __init__(self, path, qs=None): + self.path = path + self.prev = None + self.data = [] + if qs: + self.data = urllib.parse.parse_qsl(qs, keep_blank_values=True) + + def first(self, key): + for item in self.data: + if item[0] == key: + return item[1] + raise KeyError + + def first_or_none(self, key): + try: + return self.first(key) + except KeyError: + return None + + def last(self, key): + for item in reversed(self.data): + if item[0] == key: + return item[1] + raise KeyError + + def getlist(self, key): + result = [] + for item in self.data: + if item[0] == key: + result.append(item[1]) + return result + + def first_dict(self): + return {k: v[0] for k, v in self.data} + + def append(self, key, value): + self.data.append((key, value)) + + def remove(self, key): + self.data = [item for item in self.data if item[0] != key] + + def replace(self, **kwargs): + for key, values in kwargs.items(): + if not isinstance(values, list): + kwargs[key] = [values] + new_data = [] + for key, value in self.data: + if key in kwargs: + new_data.append((key, kwargs[key])) + else: + new_data.append((key, value)) + self.data = new_data + + def __str__(self): + return urllib.parse.urlencode(self.data) + + def __repr__(self): + return str(self) diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py deleted file mode 100644 index 9d189459..00000000 --- a/datasette/utils/__init__.py +++ /dev/null @@ -1,1570 +0,0 @@ -import asyncio -from contextlib import contextmanager -import aiofiles -import click -from collections import OrderedDict, namedtuple, Counter -import copy -import dataclasses -import base64 -import hashlib -import inspect -import json -import markupsafe -import mergedeep -import os -import re -import shlex -import tempfile -import typing -import time -import types -import secrets -import shutil -from typing import Iterable, List, Tuple -import urllib -import yaml -from .shutil_backport import copytree -from .sqlite import sqlite3, supports_table_xinfo - -if typing.TYPE_CHECKING: - from datasette.database import Database - from datasette.permissions import Resource - - -@dataclasses.dataclass -class PaginatedResources: - """Paginated results from allowed_resources query.""" - - resources: List["Resource"] - next: str | None # Keyset token for next page (None if no more results) - _datasette: typing.Any = dataclasses.field(default=None, repr=False) - _action: str = dataclasses.field(default=None, repr=False) - _actor: typing.Any = dataclasses.field(default=None, repr=False) - _parent: str | None = dataclasses.field(default=None, repr=False) - _include_is_private: bool = dataclasses.field(default=False, repr=False) - _include_reasons: bool = dataclasses.field(default=False, repr=False) - _limit: int = dataclasses.field(default=100, repr=False) - - async def all(self): - """ - Async generator that yields all resources across all pages. - - Automatically handles pagination under the hood. This is useful when you need - to iterate through all results without manually managing pagination tokens. - - Yields: - Resource objects one at a time - - Example: - page = await datasette.allowed_resources("view-table", actor) - async for table in page.all(): - print(f"{table.parent}/{table.child}") - """ - # Yield all resources from current page - for resource in self.resources: - yield resource - - # Continue fetching subsequent pages if there are more - next_token = self.next - while next_token: - page = await self._datasette.allowed_resources( - self._action, - self._actor, - parent=self._parent, - include_is_private=self._include_is_private, - include_reasons=self._include_reasons, - limit=self._limit, - next=next_token, - ) - for resource in page.resources: - yield resource - next_token = page.next - - -# From https://www.sqlite.org/lang_keywords.html -reserved_words = set( - ( - "abort action add after all alter analyze and as asc attach autoincrement " - "before begin between by cascade case cast check collate column commit " - "conflict constraint create cross current_date current_time " - "current_timestamp database default deferrable deferred delete desc detach " - "distinct drop each else end escape except exclusive exists explain fail " - "for foreign from full glob group having if ignore immediate in index " - "indexed initially inner insert instead intersect into is isnull join key " - "left like limit match natural no not notnull null of offset on or order " - "outer plan pragma primary query raise recursive references regexp reindex " - "release rename replace restrict right rollback row savepoint select set " - "table temp temporary then to transaction trigger union unique update using " - "vacuum values view virtual when where with without" - ).split() -) - -APT_GET_DOCKERFILE_EXTRAS = r""" -RUN apt-get update && \ - apt-get install -y {} && \ - rm -rf /var/lib/apt/lists/* -""" - -# Can replace with sqlite-utils when I add that dependency -SPATIALITE_PATHS = ( - "/usr/lib/x86_64-linux-gnu/mod_spatialite.so", - "/usr/local/lib/mod_spatialite.dylib", - "/usr/local/lib/mod_spatialite.so", - "/opt/homebrew/lib/mod_spatialite.dylib", -) -# Used to display /-/versions.json SpatiaLite information -SPATIALITE_FUNCTIONS = ( - "spatialite_version", - "spatialite_target_cpu", - "check_strict_sql_quoting", - "freexl_version", - "proj_version", - "geos_version", - "rttopo_version", - "libxml2_version", - "HasIconv", - "HasMathSQL", - "HasGeoCallbacks", - "HasProj", - "HasProj6", - "HasGeos", - "HasGeosAdvanced", - "HasGeosTrunk", - "HasGeosReentrant", - "HasGeosOnlyReentrant", - "HasMiniZip", - "HasRtTopo", - "HasLibXML2", - "HasEpsg", - "HasFreeXL", - "HasGeoPackage", - "HasGCP", - "HasTopology", - "HasKNN", - "HasRouting", -) -# Length of hash subset used in hashed URLs: -HASH_LENGTH = 7 - - -# Can replace this with Column from sqlite_utils when I add that dependency -Column = namedtuple( - "Column", ("cid", "name", "type", "notnull", "default_value", "is_pk", "hidden") -) - -functions_marked_as_documented = [] - - -def documented(fn=None, *, label=None): - def decorate(fn): - fn._datasette_docs_label = label or "internals_utils_{}".format(fn.__name__) - functions_marked_as_documented.append(fn) - return fn - - if fn is None: - return decorate - return decorate(fn) - - -@documented -async def await_me_maybe(value: typing.Any) -> typing.Any: - "If value is callable, call it. If awaitable, await it. Otherwise return it." - if callable(value): - value = value() - if asyncio.iscoroutine(value): - value = await value - return value - - -def urlsafe_components(token): - """Splits token on commas and tilde-decodes each component""" - return [tilde_decode(b) for b in token.split(",")] - - -def path_from_row_pks(row, pks, use_rowid, quote=True): - """Generate an optionally tilde-encoded unique identifier - for a row from its primary keys.""" - if use_rowid: - bits = [row["rowid"]] - else: - bits = [ - row[pk]["value"] if isinstance(row[pk], dict) else row[pk] for pk in pks - ] - if quote: - bits = [tilde_encode(str(bit)) for bit in bits] - else: - bits = [str(bit) for bit in bits] - - return ",".join(bits) - - -def compound_keys_after_sql(pks, start_index=0): - # Implementation of keyset pagination - # See https://github.com/simonw/datasette/issues/190 - # For pk1/pk2/pk3 returns: - # - # ([pk1] > :p0) - # or - # ([pk1] = :p0 and [pk2] > :p1) - # or - # ([pk1] = :p0 and [pk2] = :p1 and [pk3] > :p2) - or_clauses = [] - pks_left = pks[:] - while pks_left: - and_clauses = [] - last = pks_left[-1] - rest = pks_left[:-1] - and_clauses = [ - f"{escape_sqlite(pk)} = :p{i + start_index}" for i, pk in enumerate(rest) - ] - and_clauses.append(f"{escape_sqlite(last)} > :p{len(rest) + start_index}") - or_clauses.append(f"({' and '.join(and_clauses)})") - pks_left.pop() - or_clauses.reverse() - return "({})".format("\n or\n".join(or_clauses)) - - -class CustomJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, sqlite3.Row): - return tuple(obj) - if isinstance(obj, sqlite3.Cursor): - return list(obj) - if isinstance(obj, bytes): - # Does it encode to utf8? - try: - return obj.decode("utf8") - except UnicodeDecodeError: - return { - "$base64": True, - "encoded": base64.b64encode(obj).decode("latin1"), - } - return json.JSONEncoder.default(self, obj) - - -@contextmanager -def sqlite_timelimit(conn, ms): - deadline = time.perf_counter() + (ms / 1000) - # n is the number of SQLite virtual machine instructions that will be - # executed between each check. It takes about 0.08ms to execute 1000. - # https://github.com/simonw/datasette/issues/1679 - n = 1000 - if ms <= 20: - # This mainly happens while executing our test suite - n = 1 - - def handler(): - if time.perf_counter() >= deadline: - # Returning 1 terminates the query with an error - return 1 - - conn.set_progress_handler(handler, n) - try: - yield - finally: - conn.set_progress_handler(None, n) - - -class InvalidSql(Exception): - pass - - -# Allow SQL to start with a /* */ or -- comment -comment_re = ( - # Start of string, then any amount of whitespace - r"^\s*(" - + - # Comment that starts with -- and ends at a newline - r"(?:\-\-.*?\n\s*)" - + - # Comment that starts with /* and ends with */ - but does not have */ in it - r"|(?:\/\*((?!\*\/)[\s\S])*\*\/)" - + - # Whitespace - r"\s*)*\s*" -) - -allowed_sql_res = [ - re.compile(comment_re + r"select\b"), - re.compile(comment_re + r"explain\s+select\b"), - re.compile(comment_re + r"explain\s+query\s+plan\s+select\b"), - re.compile(comment_re + r"with\b"), - re.compile(comment_re + r"explain\s+with\b"), - re.compile(comment_re + r"explain\s+query\s+plan\s+with\b"), -] - -allowed_pragmas = ( - "database_list", - "foreign_key_list", - "function_list", - "index_info", - "index_list", - "index_xinfo", - "page_count", - "max_page_count", - "page_size", - "schema_version", - "table_info", - "table_xinfo", - "table_list", -) -disallawed_sql_res = [ - ( - re.compile(f"pragma(?!_({'|'.join(allowed_pragmas)}))"), - "Statement contained a disallowed PRAGMA. Allowed pragma functions are {}".format( - ", ".join("pragma_{}()".format(pragma) for pragma in allowed_pragmas) - ), - ) -] - - -def validate_sql_select(sql): - sql = "\n".join( - line for line in sql.split("\n") if not line.strip().startswith("--") - ) - sql = sql.strip().lower() - if not any(r.match(sql) for r in allowed_sql_res): - raise InvalidSql("Statement must be a SELECT") - for r, msg in disallawed_sql_res: - if r.search(sql): - raise InvalidSql(msg) - - -def append_querystring(url, querystring): - op = "&" if ("?" in url) else "?" - return f"{url}{op}{querystring}" - - -def path_with_added_args(request, args, path=None): - path = path or request.path - if isinstance(args, dict): - args = args.items() - args_to_remove = {k for k, v in args if v is None} - current = [] - for key, value in urllib.parse.parse_qsl(request.query_string): - if key not in args_to_remove: - current.append((key, value)) - current.extend([(key, value) for key, value in args if value is not None]) - query_string = urllib.parse.urlencode(current) - if query_string: - query_string = f"?{query_string}" - return path + query_string - - -def path_with_removed_args(request, args, path=None): - query_string = request.query_string - if path is None: - path = request.path - else: - if "?" in path: - bits = path.split("?", 1) - path, query_string = bits - # args can be a dict or a set - current = [] - if isinstance(args, set): - - def should_remove(key, value): - return key in args - - elif isinstance(args, dict): - # Must match key AND value - def should_remove(key, value): - return args.get(key) == value - - for key, value in urllib.parse.parse_qsl(query_string): - if not should_remove(key, value): - current.append((key, value)) - query_string = urllib.parse.urlencode(current) - if query_string: - query_string = f"?{query_string}" - return path + query_string - - -def path_with_replaced_args(request, args, path=None): - path = path or request.path - if isinstance(args, dict): - args = args.items() - keys_to_replace = {p[0] for p in args} - current = [] - for key, value in urllib.parse.parse_qsl(request.query_string): - if key not in keys_to_replace: - current.append((key, value)) - current.extend([p for p in args if p[1] is not None]) - query_string = urllib.parse.urlencode(current) - if query_string: - query_string = f"?{query_string}" - return path + query_string - - -_css_re = re.compile(r"""['"\n\\]""") -_boring_keyword_re = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") - - -def escape_css_string(s): - return _css_re.sub( - lambda m: "\\" + (f"{ord(m.group()):X}".zfill(6)), - s.replace("\r\n", "\n"), - ) - - -def escape_sqlite(s): - if _boring_keyword_re.match(s) and (s.lower() not in reserved_words): - return s - else: - return f"[{s}]" - - -def make_dockerfile( - files, - metadata_file, - extra_options, - branch, - template_dir, - plugins_dir, - static, - install, - spatialite, - version_note, - secret, - environment_variables=None, - port=8001, - apt_get_extras=None, -): - cmd = ["datasette", "serve", "--host", "0.0.0.0"] - environment_variables = environment_variables or {} - environment_variables["DATASETTE_SECRET"] = secret - apt_get_extras = apt_get_extras or [] - for filename in files: - cmd.extend(["-i", filename]) - cmd.extend(["--cors", "--inspect-file", "inspect-data.json"]) - if metadata_file: - cmd.extend(["--metadata", f"{metadata_file}"]) - if template_dir: - cmd.extend(["--template-dir", "templates/"]) - if plugins_dir: - cmd.extend(["--plugins-dir", "plugins/"]) - if version_note: - cmd.extend(["--version-note", f"{version_note}"]) - if static: - for mount_point, _ in static: - cmd.extend(["--static", f"{mount_point}:{mount_point}"]) - if extra_options: - for opt in extra_options.split(): - cmd.append(f"{opt}") - cmd = [shlex.quote(part) for part in cmd] - # port attribute is a (fixed) env variable and should not be quoted - cmd.extend(["--port", "$PORT"]) - cmd = " ".join(cmd) - if branch: - install = [f"https://github.com/simonw/datasette/archive/{branch}.zip"] + list( - install - ) - else: - install = ["datasette"] + list(install) - - apt_get_extras_ = [] - apt_get_extras_.extend(apt_get_extras) - apt_get_extras = apt_get_extras_ - if spatialite: - apt_get_extras.extend(["python3-dev", "gcc", "libsqlite3-mod-spatialite"]) - environment_variables["SQLITE_EXTENSIONS"] = ( - "/usr/lib/x86_64-linux-gnu/mod_spatialite.so" - ) - return """ -FROM python:3.11.0-slim-bullseye -COPY . /app -WORKDIR /app -{apt_get_extras} -{environment_variables} -RUN pip install -U {install_from} -RUN datasette inspect {files} --inspect-file inspect-data.json -ENV PORT {port} -EXPOSE {port} -CMD {cmd}""".format( - apt_get_extras=( - APT_GET_DOCKERFILE_EXTRAS.format(" ".join(apt_get_extras)) - if apt_get_extras - else "" - ), - environment_variables="\n".join( - [ - "ENV {} '{}'".format(key, value) - for key, value in environment_variables.items() - ] - ), - install_from=" ".join(install), - files=" ".join(files), - port=port, - cmd=cmd, - ).strip() - - -@contextmanager -def temporary_docker_directory( - files, - name, - metadata, - extra_options, - branch, - template_dir, - plugins_dir, - static, - install, - spatialite, - version_note, - secret, - extra_metadata=None, - environment_variables=None, - port=8001, - apt_get_extras=None, -): - extra_metadata = extra_metadata or {} - tmp = tempfile.TemporaryDirectory() - # We create a datasette folder in there to get a nicer now deploy name - datasette_dir = os.path.join(tmp.name, name) - os.mkdir(datasette_dir) - saved_cwd = os.getcwd() - file_paths = [os.path.join(saved_cwd, file_path) for file_path in files] - file_names = [os.path.split(f)[-1] for f in files] - if metadata: - metadata_content = parse_metadata(metadata.read()) - else: - metadata_content = {} - # Merge in the non-null values in extra_metadata - mergedeep.merge( - metadata_content, - {key: value for key, value in extra_metadata.items() if value is not None}, - ) - try: - dockerfile = make_dockerfile( - file_names, - metadata_content and "metadata.json", - extra_options, - branch, - template_dir, - plugins_dir, - static, - install, - spatialite, - version_note, - secret, - environment_variables, - port=port, - apt_get_extras=apt_get_extras, - ) - os.chdir(datasette_dir) - if metadata_content: - with open("metadata.json", "w") as fp: - fp.write(json.dumps(metadata_content, indent=2)) - with open("Dockerfile", "w") as fp: - fp.write(dockerfile) - for path, filename in zip(file_paths, file_names): - link_or_copy(path, os.path.join(datasette_dir, filename)) - if template_dir: - link_or_copy_directory( - os.path.join(saved_cwd, template_dir), - os.path.join(datasette_dir, "templates"), - ) - if plugins_dir: - link_or_copy_directory( - os.path.join(saved_cwd, plugins_dir), - os.path.join(datasette_dir, "plugins"), - ) - for mount_point, path in static: - link_or_copy_directory( - os.path.join(saved_cwd, path), os.path.join(datasette_dir, mount_point) - ) - yield datasette_dir - finally: - tmp.cleanup() - os.chdir(saved_cwd) - - -def detect_primary_keys(conn, table): - """Figure out primary keys for a table.""" - columns = table_column_details(conn, table) - pks = [column for column in columns if column.is_pk] - pks.sort(key=lambda column: column.is_pk) - return [column.name for column in pks] - - -def get_outbound_foreign_keys(conn, table): - infos = conn.execute(f"PRAGMA foreign_key_list([{table}])").fetchall() - fks = [] - for info in infos: - if info is not None: - id, seq, table_name, from_, to_, on_update, on_delete, match = info - fks.append( - { - "column": from_, - "other_table": table_name, - "other_column": to_, - "id": id, - "seq": seq, - } - ) - # Filter out compound foreign keys by removing any where "id" is not unique - id_counts = Counter(fk["id"] for fk in fks) - return [ - { - "column": fk["column"], - "other_table": fk["other_table"], - "other_column": fk["other_column"], - } - for fk in fks - if id_counts[fk["id"]] == 1 - ] - - -def get_all_foreign_keys(conn): - tables = [ - r[0] - for r in conn.execute( - 'select name from sqlite_master where type="table" order by name' - ) - ] - table_to_foreign_keys = {} - for table in tables: - table_to_foreign_keys[table] = {"incoming": [], "outgoing": []} - for table in tables: - fks = get_outbound_foreign_keys(conn, table) - for fk in fks: - table_name = fk["other_table"] - from_ = fk["column"] - to_ = fk["other_column"] - if table_name not in table_to_foreign_keys: - # Weird edge case where something refers to a table that does - # not actually exist - continue - table_to_foreign_keys[table_name]["incoming"].append( - {"other_table": table, "column": to_, "other_column": from_} - ) - table_to_foreign_keys[table]["outgoing"].append( - {"other_table": table_name, "column": from_, "other_column": to_} - ) - - # Sort foreign keys for deterministic ordering - for table in table_to_foreign_keys: - table_to_foreign_keys[table]["incoming"].sort( - key=lambda fk: (fk["other_table"], fk["column"], fk["other_column"]) - ) - table_to_foreign_keys[table]["outgoing"].sort( - key=lambda fk: (fk["other_table"], fk["column"], fk["other_column"]) - ) - - return table_to_foreign_keys - - -def detect_spatialite(conn): - rows = conn.execute( - 'select 1 from sqlite_master where tbl_name = "geometry_columns"' - ).fetchall() - return len(rows) > 0 - - -def detect_fts(conn, table): - """Detect if table has a corresponding FTS virtual table and return it""" - rows = conn.execute(detect_fts_sql(table)).fetchall() - if len(rows) == 0: - return None - else: - return rows[0][0] - - -def detect_fts_sql(table): - return r""" - select name from sqlite_master - where rootpage = 0 - and ( - sql like '%VIRTUAL TABLE%USING FTS%content="{table}"%' - or sql like '%VIRTUAL TABLE%USING FTS%content=[{table}]%' - or ( - tbl_name = "{table}" - and sql like '%VIRTUAL TABLE%USING FTS%' - ) - ) - """.format(table=table.replace("'", "''")) - - -def detect_json1(conn=None): - close_conn = False - if conn is None: - conn = sqlite3.connect(":memory:") - close_conn = True - try: - conn.execute("SELECT json('{}')") - return True - except Exception: - return False - finally: - if close_conn: - conn.close() - - -def table_columns(conn, table): - return [column.name for column in table_column_details(conn, table)] - - -def table_column_details(conn, table): - if supports_table_xinfo(): - # table_xinfo was added in 3.26.0 - return [ - Column(*r) - for r in conn.execute( - f"PRAGMA table_xinfo({escape_sqlite(table)});" - ).fetchall() - ] - else: - # First trigger a query against sqlite_master to fix an intermittent - # test failure, see https://github.com/simonw/datasette/issues/2632 - conn.execute("select 1 from sqlite_master limit 1").fetchall() - return [ - # Treat hidden as 0 for all columns. - Column(*(list(r) + [0])) - for r in conn.execute( - f"PRAGMA table_info({escape_sqlite(table)});" - ).fetchall() - ] - - -filter_column_re = re.compile(r"^_filter_column_\d+$") - - -def filters_should_redirect(special_args): - redirect_params = [] - # Handle _filter_column=foo&_filter_op=exact&_filter_value=... - filter_column = special_args.get("_filter_column") - filter_op = special_args.get("_filter_op") or "" - filter_value = special_args.get("_filter_value") or "" - if "__" in filter_op: - filter_op, filter_value = filter_op.split("__", 1) - if filter_column: - redirect_params.append((f"{filter_column}__{filter_op}", filter_value)) - for key in ("_filter_column", "_filter_op", "_filter_value"): - if key in special_args: - redirect_params.append((key, None)) - # Now handle _filter_column_1=name&_filter_op_1=contains&_filter_value_1=hello - column_keys = [k for k in special_args if filter_column_re.match(k)] - for column_key in column_keys: - number = column_key.split("_")[-1] - column = special_args[column_key] - op = special_args.get(f"_filter_op_{number}") or "exact" - value = special_args.get(f"_filter_value_{number}") or "" - if "__" in op: - op, value = op.split("__", 1) - if column: - redirect_params.append((f"{column}__{op}", value)) - redirect_params.extend( - [ - (f"_filter_column_{number}", None), - (f"_filter_op_{number}", None), - (f"_filter_value_{number}", None), - ] - ) - return redirect_params - - -whitespace_re = re.compile(r"\s") - - -def is_url(value): - """Must start with http:// or https:// and contain JUST a URL""" - if not isinstance(value, str): - return False - if not value.startswith("http://") and not value.startswith("https://"): - return False - # Any whitespace at all is invalid - if whitespace_re.search(value): - return False - return True - - -css_class_re = re.compile(r"^[a-zA-Z]+[_a-zA-Z0-9-]*$") -css_invalid_chars_re = re.compile(r"[^a-zA-Z0-9_\-]") - - -def to_css_class(s): - """ - Given a string (e.g. a table name) returns a valid unique CSS class. - For simple cases, just returns the string again. If the string is not a - valid CSS class (we disallow - and _ prefixes even though they are valid - as they may be confused with browser prefixes) we strip invalid characters - and add a 6 char md5 sum suffix, to make sure two tables with identical - names after stripping characters don't end up with the same CSS class. - """ - if css_class_re.match(s): - return s - md5_suffix = md5_not_usedforsecurity(s)[:6] - # Strip leading _, - - s = s.lstrip("_").lstrip("-") - # Replace any whitespace with hyphens - s = "-".join(s.split()) - # Remove any remaining invalid characters - s = css_invalid_chars_re.sub("", s) - # Attach the md5 suffix - bits = [b for b in (s, md5_suffix) if b] - return "-".join(bits) - - -def link_or_copy(src, dst): - # Intended for use in populating a temp directory. We link if possible, - # but fall back to copying if the temp directory is on a different device - # https://github.com/simonw/datasette/issues/141 - try: - os.link(src, dst) - except OSError: - shutil.copyfile(src, dst) - - -def link_or_copy_directory(src, dst): - try: - copytree(src, dst, copy_function=os.link, dirs_exist_ok=True) - except OSError: - copytree(src, dst, dirs_exist_ok=True) - - -def module_from_path(path, name): - # Adapted from http://sayspy.blogspot.com/2011/07/how-to-import-module-from-just-file.html - mod = types.ModuleType(name) - mod.__file__ = path - with open(path, "r") as file: - code = compile(file.read(), path, "exec", dont_inherit=True) - exec(code, mod.__dict__) - return mod - - -def path_with_format( - *, request=None, path=None, format=None, extra_qs=None, replace_format=None -): - qs = extra_qs or {} - path = request.path if request else path - if replace_format and path.endswith(f".{replace_format}"): - path = path[: -(1 + len(replace_format))] - if "." in path: - qs["_format"] = format - else: - path = f"{path}.{format}" - if qs: - extra = urllib.parse.urlencode(sorted(qs.items())) - if request and request.query_string: - path = f"{path}?{request.query_string}&{extra}" - else: - path = f"{path}?{extra}" - elif request and request.query_string: - path = f"{path}?{request.query_string}" - return path - - -class CustomRow(OrderedDict): - # Loose imitation of sqlite3.Row which offers - # both index-based AND key-based lookups - def __init__(self, columns, values=None): - self.columns = columns - if values: - self.update(values) - - def __getitem__(self, key): - if isinstance(key, int): - return super().__getitem__(self.columns[key]) - else: - return super().__getitem__(key) - - def __iter__(self): - for column in self.columns: - yield self[column] - - -def value_as_boolean(value): - if value.lower() not in ("on", "off", "true", "false", "1", "0"): - raise ValueAsBooleanError - return value.lower() in ("on", "true", "1") - - -class ValueAsBooleanError(ValueError): - pass - - -class WriteLimitExceeded(Exception): - pass - - -class LimitedWriter: - def __init__(self, writer, limit_mb): - self.writer = writer - self.limit_bytes = limit_mb * 1024 * 1024 - self.bytes_count = 0 - - async def write(self, bytes): - self.bytes_count += len(bytes) - if self.limit_bytes and (self.bytes_count > self.limit_bytes): - raise WriteLimitExceeded(f"CSV contains more than {self.limit_bytes} bytes") - await self.writer.write(bytes) - - -class EscapeHtmlWriter: - def __init__(self, writer): - self.writer = writer - - async def write(self, content): - await self.writer.write(markupsafe.escape(content)) - - -_infinities = {float("inf"), float("-inf")} - - -def remove_infinites(row): - """ - Replace float('inf') and float('-inf') with None in a row. - - Returns the original row object unchanged if no infinities are found. - """ - if isinstance(row, dict): - for v in row.values(): - if isinstance(v, float) and v in _infinities: - return { - k: (None if isinstance(v2, float) and v2 in _infinities else v2) - for k, v2 in row.items() - } - else: - for v in row: - if isinstance(v, float) and v in _infinities: - return [ - None if isinstance(v2, float) and v2 in _infinities else v2 - for v2 in row - ] - return row - - -class StaticMount(click.ParamType): - name = "mount:directory" - - def convert(self, value, param, ctx): - if ":" not in value: - self.fail( - f'"{value}" should be of format mountpoint:directory', - param, - ctx, - ) - path, dirpath = value.split(":", 1) - dirpath = os.path.abspath(dirpath) - if not os.path.exists(dirpath) or not os.path.isdir(dirpath): - self.fail(f"{value} is not a valid directory path", param, ctx) - return path, dirpath - - -# The --load-extension parameter can optionally include a specific entrypoint. -# This is done by appending ":entrypoint_name" after supplying the path to the extension -class LoadExtension(click.ParamType): - name = "path:entrypoint?" - - def convert(self, value, param, ctx): - if ":" not in value: - return value - path, entrypoint = value.split(":", 1) - return path, entrypoint - - -def format_bytes(bytes): - current = float(bytes) - for unit in ("bytes", "KB", "MB", "GB", "TB"): - if current < 1024: - break - current = current / 1024 - if unit == "bytes": - return f"{int(current)} {unit}" - else: - return f"{current:.1f} {unit}" - - -_escape_fts_re = re.compile(r'\s+|(".*?")') - - -def escape_fts(query): - # If query has unbalanced ", add one at end - if query.count('"') % 2: - query += '"' - bits = _escape_fts_re.split(query) - bits = [b for b in bits if b and b != '""'] - return " ".join( - '"{}"'.format(bit) if not bit.startswith('"') else bit for bit in bits - ) - - -class MultiParams: - def __init__(self, data): - # data is a dictionary of key => [list, of, values] or a list of [["key", "value"]] pairs - if isinstance(data, dict): - for key in data: - assert isinstance( - data[key], (list, tuple) - ), "dictionary data should be a dictionary of key => [list]" - self._data = data - elif isinstance(data, list) or isinstance(data, tuple): - new_data = {} - for item in data: - assert ( - isinstance(item, (list, tuple)) and len(item) == 2 - ), "list data should be a list of [key, value] pairs" - key, value = item - new_data.setdefault(key, []).append(value) - self._data = new_data - - def __repr__(self): - return f"" - - def __contains__(self, key): - return key in self._data - - def __getitem__(self, key): - return self._data[key][0] - - def keys(self): - return self._data.keys() - - def __iter__(self): - yield from self._data.keys() - - def __len__(self): - return len(self._data) - - def get(self, name, default=None): - """Return first value in the list, if available""" - try: - return self._data.get(name)[0] - except (KeyError, TypeError): - return default - - def getlist(self, name): - """Return full list""" - return self._data.get(name) or [] - - -class ConnectionProblem(Exception): - pass - - -class SpatialiteConnectionProblem(ConnectionProblem): - pass - - -def check_connection(conn): - tables = [ - r[0] - for r in conn.execute( - "select name from sqlite_master where type='table'" - ).fetchall() - ] - for table in tables: - try: - conn.execute( - f"PRAGMA table_info({escape_sqlite(table)});", - ) - except sqlite3.OperationalError as e: - if e.args[0] == "no such module: VirtualSpatialIndex": - raise SpatialiteConnectionProblem(e) - else: - raise ConnectionProblem(e) - - -class BadMetadataError(Exception): - pass - - -@documented -def parse_metadata(content: str) -> dict: - "Detects if content is JSON or YAML and parses it appropriately." - # content can be JSON or YAML - try: - return json.loads(content) - except json.JSONDecodeError: - try: - return yaml.safe_load(content) - except yaml.YAMLError: - raise BadMetadataError("Metadata is not valid JSON or YAML") - - -def _gather_arguments(fn, kwargs): - parameters = inspect.signature(fn).parameters.keys() - call_with = [] - for parameter in parameters: - if parameter not in kwargs: - raise TypeError( - "{} requires parameters {}, missing: {}".format( - fn, tuple(parameters), set(parameters) - set(kwargs.keys()) - ) - ) - call_with.append(kwargs[parameter]) - return call_with - - -@documented -def call_with_supported_arguments(fn, **kwargs): - """ - Call ``fn`` with the subset of ``**kwargs`` matching its signature. - - This implements dependency injection: the caller provides all available - keyword arguments and the function receives only the ones it declares - as parameters. - - :param fn: A callable (sync function) - :param kwargs: All available keyword arguments - :returns: The return value of ``fn`` - """ - call_with = _gather_arguments(fn, kwargs) - return fn(*call_with) - - -@documented -async def async_call_with_supported_arguments(fn, **kwargs): - """ - Async version of :func:`call_with_supported_arguments`. - - Calls ``await fn(...)`` with the subset of ``**kwargs`` matching its - signature. - - :param fn: An async callable - :param kwargs: All available keyword arguments - :returns: The return value of ``await fn(...)`` - """ - call_with = _gather_arguments(fn, kwargs) - return await fn(*call_with) - - -def actor_matches_allow(actor, allow): - if allow is True: - return True - if allow is False: - return False - if actor is None and allow and allow.get("unauthenticated") is True: - return True - if allow is None: - return True - actor = actor or {} - for key, values in allow.items(): - if values == "*" and key in actor: - return True - if not isinstance(values, list): - values = [values] - actor_values = actor.get(key) - if actor_values is None: - continue - if not isinstance(actor_values, list): - actor_values = [actor_values] - actor_values = set(actor_values) - if actor_values.intersection(values): - return True - return False - - -def resolve_env_secrets(config, environ): - """Create copy that recursively replaces {"$env": "NAME"} with values from environ""" - if isinstance(config, dict): - if list(config.keys()) == ["$env"]: - return environ.get(list(config.values())[0]) - elif list(config.keys()) == ["$file"]: - with open(list(config.values())[0]) as fp: - return fp.read() - else: - return { - key: resolve_env_secrets(value, environ) - for key, value in config.items() - } - elif isinstance(config, list): - return [resolve_env_secrets(value, environ) for value in config] - else: - return config - - -def display_actor(actor): - for key in ("display", "name", "username", "login", "id"): - if actor.get(key): - return actor[key] - return str(actor) - - -class SpatialiteNotFound(Exception): - pass - - -# Can replace with sqlite-utils when I add that dependency -def find_spatialite(): - for path in SPATIALITE_PATHS: - if os.path.exists(path): - return path - raise SpatialiteNotFound - - -async def initial_path_for_datasette(datasette): - """Return suggested path for opening this Datasette, based on number of DBs and tables""" - databases = dict([p for p in datasette.databases.items() if p[0] != "_internal"]) - if len(databases) == 1: - db_name = next(iter(databases.keys())) - path = datasette.urls.database(db_name) - # Does this DB only have one table? - db = next(iter(databases.values())) - tables = await db.table_names() - if len(tables) == 1: - path = datasette.urls.table(db_name, tables[0]) - else: - path = datasette.urls.instance() - return path - - -class PrefixedUrlString(str): - def __add__(self, other): - return type(self)(super().__add__(other)) - - def __str__(self): - return super().__str__() - - def __getattribute__(self, name): - if not name.startswith("__") and name in dir(str): - - def method(self, *args, **kwargs): - value = getattr(super(), name)(*args, **kwargs) - if isinstance(value, str): - return type(self)(value) - elif isinstance(value, list): - return [type(self)(i) for i in value] - elif isinstance(value, tuple): - return tuple(type(self)(i) for i in value) - else: - return value - - return method.__get__(self) - else: - return super().__getattribute__(name) - - -class StartupError(Exception): - pass - - -_single_line_comment_re = re.compile(r"--.*") -_multi_line_comment_re = re.compile(r"/\*.*?\*/", re.DOTALL) -_single_quote_re = re.compile(r"'(?:''|[^'])*'") -_double_quote_re = re.compile(r'"(?:\"\"|[^"])*"') -_named_param_re = re.compile(r":(\w+)") - - -@documented -def named_parameters(sql: str) -> List[str]: - """ - Given a SQL statement, return a list of named parameters that are used in the statement - - e.g. for ``select * from foo where id=:id`` this would return ``["id"]`` - """ - sql = _single_line_comment_re.sub("", sql) - sql = _multi_line_comment_re.sub("", sql) - sql = _single_quote_re.sub("", sql) - sql = _double_quote_re.sub("", sql) - # Extract parameters from what is left - return _named_param_re.findall(sql) - - -async def derive_named_parameters(db: "Database", sql: str) -> List[str]: - """ - This undocumented but stable method exists for backwards compatibility - with plugins that were using it before it switched to named_parameters() - """ - return named_parameters(sql) - - -def add_cors_headers(headers): - headers["Access-Control-Allow-Origin"] = "*" - headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type" - headers["Access-Control-Expose-Headers"] = "Link" - headers["Access-Control-Allow-Methods"] = "GET, POST, HEAD, OPTIONS" - headers["Access-Control-Max-Age"] = "3600" - - -_TILDE_ENCODING_SAFE = frozenset( - b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" - b"abcdefghijklmnopqrstuvwxyz" - b"0123456789_-" - # This is the same as Python percent-encoding but I removed - # '.' and '~' -) - -_space = ord(" ") - - -class TildeEncoder(dict): - # Keeps a cache internally, via __missing__ - def __missing__(self, b): - # Handle a cache miss, store encoded string in cache and return. - if b in _TILDE_ENCODING_SAFE: - res = chr(b) - elif b == _space: - res = "+" - else: - res = "~{:02X}".format(b) - self[b] = res - return res - - -_tilde_encoder = TildeEncoder().__getitem__ - - -@documented -def tilde_encode(s: str) -> str: - "Returns tilde-encoded string - for example ``/foo/bar`` -> ``~2Ffoo~2Fbar``" - return "".join(_tilde_encoder(char) for char in s.encode("utf-8")) - - -@documented -def tilde_decode(s: str) -> str: - "Decodes a tilde-encoded string, so ``~2Ffoo~2Fbar`` -> ``/foo/bar``" - # Avoid accidentally decoding a %2f style sequence - temp = secrets.token_hex(16) - s = s.replace("%", temp) - decoded = urllib.parse.unquote_plus(s.replace("~", "%")) - return decoded.replace(temp, "%") - - -def resolve_routes(routes, path): - for regex, view in routes: - match = regex.match(path) - if match is not None: - return match, view - return None, None - - -def truncate_url(url, length): - if (not length) or (len(url) <= length): - return url - bits = url.rsplit(".", 1) - if len(bits) == 2 and 1 <= len(bits[1]) <= 4 and "/" not in bits[1]: - rest, ext = bits - return rest[: length - 1 - len(ext)] + "…." + ext - return url[: length - 1] + "…" - - -async def row_sql_params_pks(db, table, pk_values): - pks = await db.primary_keys(table) - use_rowid = not pks - select = "*" - if use_rowid: - select = "rowid, *" - pks = ["rowid"] - wheres = [f'"{pk}"=:p{i}' for i, pk in enumerate(pks)] - sql = f"select {select} from {escape_sqlite(table)} where {' AND '.join(wheres)}" - params = {} - for i, pk_value in enumerate(pk_values): - params[f"p{i}"] = pk_value - return sql, params, pks - - -def _handle_pair(key: str, value: str) -> dict: - """ - Turn a key-value pair into a nested dictionary. - foo, bar => {'foo': 'bar'} - foo.bar, baz => {'foo': {'bar': 'baz'}} - foo.bar, [1, 2, 3] => {'foo': {'bar': [1, 2, 3]}} - foo.bar, "baz" => {'foo': {'bar': 'baz'}} - foo.bar, '{"baz": "qux"}' => {'foo': {'bar': "{'baz': 'qux'}"}} - """ - try: - value = json.loads(value) - except json.JSONDecodeError: - # If it doesn't parse as JSON, treat it as a string - pass - - keys = key.split(".") - result = current_dict = {} - - for k in keys[:-1]: - current_dict[k] = {} - current_dict = current_dict[k] - - current_dict[keys[-1]] = value - return result - - -def _combine(base: dict, update: dict) -> dict: - """ - Recursively merge two dictionaries. - """ - for key, value in update.items(): - if isinstance(value, dict) and key in base and isinstance(base[key], dict): - base[key] = _combine(base[key], value) - else: - base[key] = value - return base - - -def pairs_to_nested_config(pairs: typing.List[typing.Tuple[str, typing.Any]]) -> dict: - """ - Parse a list of key-value pairs into a nested dictionary. - """ - result = {} - for key, value in pairs: - parsed_pair = _handle_pair(key, value) - result = _combine(result, parsed_pair) - return result - - -def make_slot_function(name, datasette, request, **kwargs): - from datasette.plugins import pm - - method = getattr(pm.hook, name, None) - assert method is not None, "No hook found for {}".format(name) - - async def inner(): - html_bits = [] - for hook in method(datasette=datasette, request=request, **kwargs): - html = await await_me_maybe(hook) - if html is not None: - html_bits.append(html) - return markupsafe.Markup("".join(html_bits)) - - return inner - - -def prune_empty_dicts(d: dict): - """ - Recursively prune all empty dictionaries from a given dictionary. - """ - for key, value in list(d.items()): - if isinstance(value, dict): - prune_empty_dicts(value) - if value == {}: - d.pop(key, None) - - -def move_plugins_and_allow(source: dict, destination: dict) -> Tuple[dict, dict]: - """ - Move 'plugins' and 'allow' keys from source to destination dictionary. Creates - hierarchy in destination if needed. After moving, recursively remove any keys - in the source that are left empty. - """ - source = copy.deepcopy(source) - destination = copy.deepcopy(destination) - - def recursive_move(src, dest, path=None): - if path is None: - path = [] - for key, value in list(src.items()): - new_path = path + [key] - if key in ("plugins", "allow"): - # Navigate and create the hierarchy in destination if needed - d = dest - for step in path: - d = d.setdefault(step, {}) - # Move the plugins - d[key] = value - # Remove the plugins from source - src.pop(key, None) - elif isinstance(value, dict): - recursive_move(value, dest, new_path) - # After moving, check if the current dictionary is empty and remove it if so - if not value: - src.pop(key, None) - - recursive_move(source, destination) - prune_empty_dicts(source) - return source, destination - - -_table_config_keys = ( - "hidden", - "sort", - "sort_desc", - "size", - "sortable_columns", - "label_column", - "facets", - "fts_table", - "fts_pk", - "searchmode", -) - - -def move_table_config(metadata: dict, config: dict): - """ - Move all known table configuration keys from metadata to config. - """ - if "databases" not in metadata: - return metadata, config - metadata = copy.deepcopy(metadata) - config = copy.deepcopy(config) - for database_name, database in metadata["databases"].items(): - if "tables" not in database: - continue - for table_name, table in database["tables"].items(): - for key in _table_config_keys: - if key in table: - config.setdefault("databases", {}).setdefault( - database_name, {} - ).setdefault("tables", {}).setdefault(table_name, {})[ - key - ] = table.pop( - key - ) - prune_empty_dicts(metadata) - return metadata, config - - -def redact_keys(original: dict, key_patterns: Iterable) -> dict: - """ - Recursively redact sensitive keys in a dictionary based on given patterns - - :param original: The original dictionary - :param key_patterns: A list of substring patterns to redact - :return: A copy of the original dictionary with sensitive values redacted - """ - - def redact(data): - if isinstance(data, dict): - return { - k: ( - redact(v) - if not any(pattern in k for pattern in key_patterns) - else "***" - ) - for k, v in data.items() - } - elif isinstance(data, list): - return [redact(item) for item in data] - else: - return data - - return redact(original) - - -def md5_not_usedforsecurity(s): - try: - return hashlib.md5(s.encode("utf8"), usedforsecurity=False).hexdigest() - except TypeError: - # For Python 3.8 which does not support usedforsecurity=False - return hashlib.md5(s.encode("utf8")).hexdigest() - - -_etag_cache = {} - - -async def calculate_etag(filepath, chunk_size=4096): - if filepath in _etag_cache: - return _etag_cache[filepath] - - hasher = hashlib.md5() - async with aiofiles.open(filepath, "rb") as f: - while True: - chunk = await f.read(chunk_size) - if not chunk: - break - hasher.update(chunk) - - etag = f'"{hasher.hexdigest()}"' - _etag_cache[filepath] = etag - - return etag - - -def deep_dict_update(dict1, dict2): - for key, value in dict2.items(): - if isinstance(value, dict): - dict1[key] = deep_dict_update(dict1.get(key, type(value)()), value) - else: - dict1[key] = value - return dict1 diff --git a/datasette/utils/actions_sql.py b/datasette/utils/actions_sql.py deleted file mode 100644 index 891ee913..00000000 --- a/datasette/utils/actions_sql.py +++ /dev/null @@ -1,591 +0,0 @@ -""" -SQL query builder for hierarchical permission checking. - -This module implements a cascading permission system based on the pattern -from https://github.com/simonw/research/tree/main/sqlite-permissions-poc - -It builds SQL queries that: - -1. Start with all resources of a given type (from resource_type.resources_sql()) -2. Gather permission rules from plugins (via permission_resources_sql hook) -3. Apply cascading logic: child → parent → global -4. Apply DENY-beats-ALLOW at each level - -The core pattern is: -- Resources are identified by (parent, child) tuples -- Rules are evaluated at three levels: - - child: exact match on (parent, child) - - parent: match on (parent, NULL) - - global: match on (NULL, NULL) -- At the same level, DENY (allow=0) beats ALLOW (allow=1) -- Across levels, child beats parent beats global -""" - -from typing import TYPE_CHECKING - -from datasette.utils.permissions import gather_permission_sql_from_hooks - -if TYPE_CHECKING: - from datasette.app import Datasette - - -async def build_allowed_resources_sql( - datasette: "Datasette", - actor: dict | None, - action: str, - *, - parent: str | None = None, - include_is_private: bool = False, -) -> tuple[str, dict]: - """ - Build a SQL query that returns all resources the actor can access for this action. - - Args: - datasette: The Datasette instance - actor: The actor dict (or None for unauthenticated) - action: The action name (e.g., "view-table", "view-database") - parent: Optional parent filter to limit results (e.g., database name) - include_is_private: If True, add is_private column showing if anonymous cannot access - - Returns: - A tuple of (sql_query, params_dict) - - The returned SQL query will have three columns (or four with include_is_private): - - parent: The parent resource identifier (or NULL) - - child: The child resource identifier (or NULL) - - reason: The reason from the rule that granted access - - is_private: (if include_is_private) 1 if anonymous cannot access, 0 otherwise - - Example: - For action="view-table", this might return: - SELECT parent, child, reason FROM ... WHERE is_allowed = 1 - - Results would be like: - ('analytics', 'users', 'role-based: analysts can access analytics DB') - ('analytics', 'events', 'role-based: analysts can access analytics DB') - ('production', 'orders', 'business-exception: allow production.orders for carol') - """ - # Get the Action object - action_obj = datasette.actions.get(action) - if not action_obj: - raise ValueError(f"Unknown action: {action}") - - # If this action also_requires another action, we need to combine the queries - if action_obj.also_requires: - # Build both queries - main_sql, main_params = await _build_single_action_sql( - datasette, - actor, - action, - parent=parent, - include_is_private=include_is_private, - ) - required_sql, required_params = await _build_single_action_sql( - datasette, - actor, - action_obj.also_requires, - parent=parent, - include_is_private=False, - ) - - # Merge parameters - they should have identical values for :actor, :actor_id, etc. - all_params = {**main_params, **required_params} - if parent is not None: - all_params["filter_parent"] = parent - - # Combine with INNER JOIN - only resources allowed by both actions - combined_sql = f""" -WITH -main_allowed AS ( -{main_sql} -), -required_allowed AS ( -{required_sql} -) -SELECT m.parent, m.child, m.reason""" - - if include_is_private: - combined_sql += ", m.is_private" - - combined_sql += """ -FROM main_allowed m -INNER JOIN required_allowed r - ON ((m.parent = r.parent) OR (m.parent IS NULL AND r.parent IS NULL)) - AND ((m.child = r.child) OR (m.child IS NULL AND r.child IS NULL)) -""" - - if parent is not None: - combined_sql += "WHERE m.parent = :filter_parent\n" - - combined_sql += "ORDER BY m.parent, m.child" - - return combined_sql, all_params - - # No also_requires, build single action query - return await _build_single_action_sql( - datasette, actor, action, parent=parent, include_is_private=include_is_private - ) - - -async def _build_single_action_sql( - datasette: "Datasette", - actor: dict | None, - action: str, - *, - parent: str | None = None, - include_is_private: bool = False, -) -> tuple[str, dict]: - """ - Build SQL for a single action (internal helper for build_allowed_resources_sql). - - This contains the original logic from build_allowed_resources_sql, extracted - to allow combining multiple actions when also_requires is used. - """ - # Get the Action object - action_obj = datasette.actions.get(action) - if not action_obj: - raise ValueError(f"Unknown action: {action}") - - # Get base resources SQL from the resource class - base_resources_sql = await action_obj.resource_class.resources_sql( - datasette, actor=actor - ) - - permission_sqls = await gather_permission_sql_from_hooks( - datasette=datasette, - actor=actor, - action=action, - ) - - # If permission_sqls is the sentinel, skip all permission checks - # Return SQL that allows all resources - from datasette.utils.permissions import SKIP_PERMISSION_CHECKS - - if permission_sqls is SKIP_PERMISSION_CHECKS: - cols = "parent, child, 'skip_permission_checks' AS reason" - if include_is_private: - cols += ", 0 AS is_private" - return f"SELECT {cols} FROM ({base_resources_sql})", {} - - all_params = {} - rule_sqls = [] - restriction_sqls = [] - - for permission_sql in permission_sqls: - # Always collect params (even from restriction-only plugins) - all_params.update(permission_sql.params or {}) - - # Collect restriction SQL filters - if permission_sql.restriction_sql: - restriction_sqls.append(permission_sql.restriction_sql) - - # Skip plugins that only provide restriction_sql (no permission rules) - if permission_sql.sql is None: - continue - rule_sqls.append(f""" - SELECT parent, child, allow, reason, '{permission_sql.source}' AS source_plugin FROM ( - {permission_sql.sql} - ) - """.strip()) - - # If no rules, return empty result (deny all) - if not rule_sqls: - empty_cols = "NULL AS parent, NULL AS child, NULL AS reason" - if include_is_private: - empty_cols += ", NULL AS is_private" - return f"SELECT {empty_cols} WHERE 0", {} - - # Build the cascading permission query - rules_union = " UNION ALL ".join(rule_sqls) - - # Build the main query - query_parts = [ - "WITH", - "base AS (", - f" {base_resources_sql}", - "),", - "all_rules AS (", - f" {rules_union}", - "),", - ] - - # If include_is_private, we need to build anonymous permissions too - if include_is_private: - anon_permission_sqls = await gather_permission_sql_from_hooks( - datasette=datasette, - actor=None, - action=action, - ) - anon_sqls_rewritten = [] - anon_params = {} - - for permission_sql in anon_permission_sqls: - # Skip plugins that only provide restriction_sql (no permission rules) - if permission_sql.sql is None: - continue - rewritten_sql = permission_sql.sql - for key, value in (permission_sql.params or {}).items(): - anon_key = f"anon_{key}" - anon_params[anon_key] = value - rewritten_sql = rewritten_sql.replace(f":{key}", f":{anon_key}") - anon_sqls_rewritten.append(rewritten_sql) - - all_params.update(anon_params) - - if anon_sqls_rewritten: - anon_rules_union = " UNION ALL ".join(anon_sqls_rewritten) - query_parts.extend( - [ - "anon_rules AS (", - f" {anon_rules_union}", - "),", - ] - ) - else: - query_parts.extend( - [ - "anon_rules AS (", - " SELECT NULL AS parent, NULL AS child, 0 AS allow, NULL AS reason WHERE 0", - "),", - ] - ) - - # Continue with the cascading logic - query_parts.extend( - [ - "child_lvl AS (", - " SELECT b.parent, b.child,", - " MAX(CASE WHEN ar.allow = 0 THEN 1 ELSE 0 END) AS any_deny,", - " MAX(CASE WHEN ar.allow = 1 THEN 1 ELSE 0 END) AS any_allow,", - " json_group_array(CASE WHEN ar.allow = 0 THEN ar.source_plugin || ': ' || ar.reason END) AS deny_reasons,", - " json_group_array(CASE WHEN ar.allow = 1 THEN ar.source_plugin || ': ' || ar.reason END) AS allow_reasons", - " FROM base b", - " LEFT JOIN all_rules ar ON ar.parent = b.parent AND ar.child = b.child", - " GROUP BY b.parent, b.child", - "),", - "parent_lvl AS (", - " SELECT b.parent, b.child,", - " MAX(CASE WHEN ar.allow = 0 THEN 1 ELSE 0 END) AS any_deny,", - " MAX(CASE WHEN ar.allow = 1 THEN 1 ELSE 0 END) AS any_allow,", - " json_group_array(CASE WHEN ar.allow = 0 THEN ar.source_plugin || ': ' || ar.reason END) AS deny_reasons,", - " json_group_array(CASE WHEN ar.allow = 1 THEN ar.source_plugin || ': ' || ar.reason END) AS allow_reasons", - " FROM base b", - " LEFT JOIN all_rules ar ON ar.parent = b.parent AND ar.child IS NULL", - " GROUP BY b.parent, b.child", - "),", - "global_lvl AS (", - " SELECT b.parent, b.child,", - " MAX(CASE WHEN ar.allow = 0 THEN 1 ELSE 0 END) AS any_deny,", - " MAX(CASE WHEN ar.allow = 1 THEN 1 ELSE 0 END) AS any_allow,", - " json_group_array(CASE WHEN ar.allow = 0 THEN ar.source_plugin || ': ' || ar.reason END) AS deny_reasons,", - " json_group_array(CASE WHEN ar.allow = 1 THEN ar.source_plugin || ': ' || ar.reason END) AS allow_reasons", - " FROM base b", - " LEFT JOIN all_rules ar ON ar.parent IS NULL AND ar.child IS NULL", - " GROUP BY b.parent, b.child", - "),", - ] - ) - - # Add anonymous decision logic if needed - if include_is_private: - query_parts.extend( - [ - "anon_child_lvl AS (", - " SELECT b.parent, b.child,", - " MAX(CASE WHEN ar.allow = 0 THEN 1 ELSE 0 END) AS any_deny,", - " MAX(CASE WHEN ar.allow = 1 THEN 1 ELSE 0 END) AS any_allow", - " FROM base b", - " LEFT JOIN anon_rules ar ON ar.parent = b.parent AND ar.child = b.child", - " GROUP BY b.parent, b.child", - "),", - "anon_parent_lvl AS (", - " SELECT b.parent, b.child,", - " MAX(CASE WHEN ar.allow = 0 THEN 1 ELSE 0 END) AS any_deny,", - " MAX(CASE WHEN ar.allow = 1 THEN 1 ELSE 0 END) AS any_allow", - " FROM base b", - " LEFT JOIN anon_rules ar ON ar.parent = b.parent AND ar.child IS NULL", - " GROUP BY b.parent, b.child", - "),", - "anon_global_lvl AS (", - " SELECT b.parent, b.child,", - " MAX(CASE WHEN ar.allow = 0 THEN 1 ELSE 0 END) AS any_deny,", - " MAX(CASE WHEN ar.allow = 1 THEN 1 ELSE 0 END) AS any_allow", - " FROM base b", - " LEFT JOIN anon_rules ar ON ar.parent IS NULL AND ar.child IS NULL", - " GROUP BY b.parent, b.child", - "),", - "anon_decisions AS (", - " SELECT", - " b.parent, b.child,", - " CASE", - " WHEN acl.any_deny = 1 THEN 0", - " WHEN acl.any_allow = 1 THEN 1", - " WHEN apl.any_deny = 1 THEN 0", - " WHEN apl.any_allow = 1 THEN 1", - " WHEN agl.any_deny = 1 THEN 0", - " WHEN agl.any_allow = 1 THEN 1", - " ELSE 0", - " END AS anon_is_allowed", - " FROM base b", - " JOIN anon_child_lvl acl ON b.parent = acl.parent AND (b.child = acl.child OR (b.child IS NULL AND acl.child IS NULL))", - " JOIN anon_parent_lvl apl ON b.parent = apl.parent AND (b.child = apl.child OR (b.child IS NULL AND apl.child IS NULL))", - " JOIN anon_global_lvl agl ON b.parent = agl.parent AND (b.child = agl.child OR (b.child IS NULL AND agl.child IS NULL))", - "),", - ] - ) - - # Final decisions - query_parts.extend( - [ - "decisions AS (", - " SELECT", - " b.parent, b.child,", - " -- Cascading permission logic: child → parent → global, DENY beats ALLOW at each level", - " -- Priority order:", - " -- 1. Child-level deny (most specific, blocks access)", - " -- 2. Child-level allow (most specific, grants access)", - " -- 3. Parent-level deny (intermediate, blocks access)", - " -- 4. Parent-level allow (intermediate, grants access)", - " -- 5. Global-level deny (least specific, blocks access)", - " -- 6. Global-level allow (least specific, grants access)", - " -- 7. Default deny (no rules match)", - " CASE", - " WHEN cl.any_deny = 1 THEN 0", - " WHEN cl.any_allow = 1 THEN 1", - " WHEN pl.any_deny = 1 THEN 0", - " WHEN pl.any_allow = 1 THEN 1", - " WHEN gl.any_deny = 1 THEN 0", - " WHEN gl.any_allow = 1 THEN 1", - " ELSE 0", - " END AS is_allowed,", - " CASE", - " WHEN cl.any_deny = 1 THEN cl.deny_reasons", - " WHEN cl.any_allow = 1 THEN cl.allow_reasons", - " WHEN pl.any_deny = 1 THEN pl.deny_reasons", - " WHEN pl.any_allow = 1 THEN pl.allow_reasons", - " WHEN gl.any_deny = 1 THEN gl.deny_reasons", - " WHEN gl.any_allow = 1 THEN gl.allow_reasons", - " ELSE '[]'", - " END AS reason", - ] - ) - - if include_is_private: - query_parts.append( - " , CASE WHEN ad.anon_is_allowed = 0 THEN 1 ELSE 0 END AS is_private" - ) - - query_parts.extend( - [ - " FROM base b", - " JOIN child_lvl cl ON b.parent = cl.parent AND (b.child = cl.child OR (b.child IS NULL AND cl.child IS NULL))", - " JOIN parent_lvl pl ON b.parent = pl.parent AND (b.child = pl.child OR (b.child IS NULL AND pl.child IS NULL))", - " JOIN global_lvl gl ON b.parent = gl.parent AND (b.child = gl.child OR (b.child IS NULL AND gl.child IS NULL))", - ] - ) - - if include_is_private: - query_parts.append( - " JOIN anon_decisions ad ON b.parent = ad.parent AND (b.child = ad.child OR (b.child IS NULL AND ad.child IS NULL))" - ) - - query_parts.append(")") - - # Add restriction list CTE if there are restrictions - if restriction_sqls: - # Wrap each restriction_sql in a subquery to avoid operator precedence issues - # with UNION ALL inside the restriction SQL statements - restriction_intersect = "\nINTERSECT\n".join( - f"SELECT * FROM ({sql})" for sql in restriction_sqls - ) - query_parts.extend( - [",", "restriction_list AS (", f" {restriction_intersect}", ")"] - ) - - # Final SELECT - select_cols = "parent, child, reason" - if include_is_private: - select_cols += ", is_private" - - query_parts.append(f"SELECT {select_cols}") - query_parts.append("FROM decisions") - query_parts.append("WHERE is_allowed = 1") - - # Add restriction filter if there are restrictions - if restriction_sqls: - query_parts.append(""" - AND EXISTS ( - SELECT 1 FROM restriction_list r - WHERE (r.parent = decisions.parent OR r.parent IS NULL) - AND (r.child = decisions.child OR r.child IS NULL) - )""") - - # Add parent filter if specified - if parent is not None: - query_parts.append(" AND parent = :filter_parent") - all_params["filter_parent"] = parent - - query_parts.append("ORDER BY parent, child") - - query = "\n".join(query_parts) - return query, all_params - - -async def build_permission_rules_sql( - datasette: "Datasette", actor: dict | None, action: str -) -> tuple[str, dict]: - """ - Build the UNION SQL and params for all permission rules for a given actor and action. - - Returns: - A tuple of (sql, params) where sql is a UNION ALL query that returns - (parent, child, allow, reason, source_plugin) rows. - """ - # Get the Action object - action_obj = datasette.actions.get(action) - if not action_obj: - raise ValueError(f"Unknown action: {action}") - - permission_sqls = await gather_permission_sql_from_hooks( - datasette=datasette, - actor=actor, - action=action, - ) - - # If permission_sqls is the sentinel, skip all permission checks - # Return SQL that allows everything - from datasette.utils.permissions import SKIP_PERMISSION_CHECKS - - if permission_sqls is SKIP_PERMISSION_CHECKS: - return ( - "SELECT NULL AS parent, NULL AS child, 1 AS allow, 'skip_permission_checks' AS reason, 'skip' AS source_plugin", - {}, - [], - ) - - if not permission_sqls: - return ( - "SELECT NULL AS parent, NULL AS child, 0 AS allow, NULL AS reason, NULL AS source_plugin WHERE 0", - {}, - [], - ) - - union_parts = [] - all_params = {} - restriction_sqls = [] - - for permission_sql in permission_sqls: - all_params.update(permission_sql.params or {}) - - # Collect restriction SQL filters - if permission_sql.restriction_sql: - restriction_sqls.append(permission_sql.restriction_sql) - - # Skip plugins that only provide restriction_sql (no permission rules) - if permission_sql.sql is None: - continue - - union_parts.append(f""" - SELECT parent, child, allow, reason, '{permission_sql.source}' AS source_plugin FROM ( - {permission_sql.sql} - ) - """.strip()) - - rules_union = " UNION ALL ".join(union_parts) - return rules_union, all_params, restriction_sqls - - -async def check_permission_for_resource( - *, - datasette: "Datasette", - actor: dict | None, - action: str, - parent: str | None, - child: str | None, -) -> bool: - """ - Check if an actor has permission for a specific action on a specific resource. - - Args: - datasette: The Datasette instance - actor: The actor dict (or None) - action: The action name - parent: The parent resource identifier (e.g., database name, or None) - child: The child resource identifier (e.g., table name, or None) - - Returns: - True if the actor is allowed, False otherwise - - This builds the cascading permission query and checks if the specific - resource is in the allowed set. - """ - rules_union, all_params, restriction_sqls = await build_permission_rules_sql( - datasette, actor, action - ) - - # If no rules (empty SQL), default deny - if not rules_union: - return False - - # Add parameters for the resource we're checking - all_params["_check_parent"] = parent - all_params["_check_child"] = child - - # If there are restriction filters, check if the resource passes them first - if restriction_sqls: - # Check if resource is in restriction allowlist - # Database-level restrictions (parent, NULL) should match all children (parent, *) - # Wrap each restriction_sql in a subquery to avoid operator precedence issues - restriction_check = "\nINTERSECT\n".join( - f"SELECT * FROM ({sql})" for sql in restriction_sqls - ) - restriction_query = f""" -WITH restriction_list AS ( - {restriction_check} -) -SELECT EXISTS ( - SELECT 1 FROM restriction_list - WHERE (parent = :_check_parent OR parent IS NULL) - AND (child = :_check_child OR child IS NULL) -) AS in_allowlist -""" - result = await datasette.get_internal_database().execute( - restriction_query, all_params - ) - if result.rows and not result.rows[0][0]: - # Resource not in restriction allowlist - deny - return False - - query = f""" -WITH -all_rules AS ( - {rules_union} -), -matched_rules AS ( - SELECT ar.*, - CASE - WHEN ar.child IS NOT NULL THEN 2 -- child-level (most specific) - WHEN ar.parent IS NOT NULL THEN 1 -- parent-level - ELSE 0 -- root/global - END AS depth - FROM all_rules ar - WHERE (ar.parent IS NULL OR ar.parent = :_check_parent) - AND (ar.child IS NULL OR ar.child = :_check_child) -), -winner AS ( - SELECT * - FROM matched_rules - ORDER BY - depth DESC, -- specificity first (higher depth wins) - CASE WHEN allow=0 THEN 0 ELSE 1 END, -- then deny over allow - source_plugin -- stable tie-break - LIMIT 1 -) -SELECT COALESCE((SELECT allow FROM winner), 0) AS is_allowed -""" - - # Execute the query against the internal database - result = await datasette.get_internal_database().execute(query, all_params) - if result.rows: - return bool(result.rows[0][0]) - return False diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py deleted file mode 100644 index 35f243b6..00000000 --- a/datasette/utils/asgi.py +++ /dev/null @@ -1,565 +0,0 @@ -import json -from typing import Optional -from datasette.utils import MultiParams, calculate_etag -from datasette.utils.multipart import ( - parse_form_data, - MultipartParseError, - FormData, - DEFAULT_MAX_FILE_SIZE, - DEFAULT_MAX_REQUEST_SIZE, - DEFAULT_MAX_FIELDS, - DEFAULT_MAX_FILES, - DEFAULT_MAX_PARTS, - DEFAULT_MAX_FIELD_SIZE, - DEFAULT_MAX_MEMORY_FILE_SIZE, - DEFAULT_MAX_PART_HEADER_BYTES, - DEFAULT_MAX_PART_HEADER_LINES, - DEFAULT_MIN_FREE_DISK_BYTES, -) -from mimetypes import guess_type -from urllib.parse import parse_qs, urlunparse, parse_qsl -from pathlib import Path -from http.cookies import SimpleCookie, Morsel -import aiofiles -import aiofiles.os -import re - -# Workaround for adding samesite support to pre 3.8 python -Morsel._reserved["samesite"] = "SameSite" -# Thanks, Starlette: -# https://github.com/encode/starlette/blob/519f575/starlette/responses.py#L17 - - -class Base400(Exception): - status = 400 - - -class NotFound(Base400): - status = 404 - - -class DatabaseNotFound(NotFound): - def __init__(self, database_name): - self.database_name = database_name - super().__init__("Database not found") - - -class TableNotFound(NotFound): - def __init__(self, database_name, table): - super().__init__("Table not found") - self.database_name = database_name - self.table = table - - -class RowNotFound(NotFound): - def __init__(self, database_name, table, pk_values): - super().__init__("Row not found") - self.database_name = database_name - self.table_name = table - self.pk_values = pk_values - - -class Forbidden(Base400): - status = 403 - - -class BadRequest(Base400): - status = 400 - - -SAMESITE_VALUES = ("strict", "lax", "none") - - -class Request: - def __init__(self, scope, receive): - self.scope = scope - self.receive = receive - - def __repr__(self): - return ''.format(self.method, self.url) - - @property - def method(self): - return self.scope["method"] - - @property - def url(self): - return urlunparse( - (self.scheme, self.host, self.path, None, self.query_string, None) - ) - - @property - def url_vars(self): - return (self.scope.get("url_route") or {}).get("kwargs") or {} - - @property - def scheme(self): - return self.scope.get("scheme") or "http" - - @property - def headers(self): - return { - k.decode("latin-1").lower(): v.decode("latin-1") - for k, v in self.scope.get("headers") or [] - } - - @property - def host(self): - return self.headers.get("host") or "localhost" - - @property - def cookies(self): - cookies = SimpleCookie() - cookies.load(self.headers.get("cookie", "")) - return {key: value.value for key, value in cookies.items()} - - @property - def path(self): - if self.scope.get("raw_path") is not None: - return self.scope["raw_path"].decode("latin-1").partition("?")[0] - else: - path = self.scope["path"] - if isinstance(path, str): - return path - else: - return path.decode("utf-8") - - @property - def query_string(self): - return (self.scope.get("query_string") or b"").decode("latin-1") - - @property - def full_path(self): - qs = self.query_string - return "{}{}".format(self.path, ("?" + qs) if qs else "") - - @property - def args(self): - return MultiParams(parse_qs(qs=self.query_string, keep_blank_values=True)) - - @property - def actor(self): - return self.scope.get("actor", None) - - async def post_body(self): - body = b"" - more_body = True - while more_body: - message = await self.receive() - assert message["type"] == "http.request", message - body += message.get("body", b"") - more_body = message.get("more_body", False) - return body - - async def post_vars(self): - body = await self.post_body() - return dict(parse_qsl(body.decode("utf-8"), keep_blank_values=True)) - - async def form( - self, - files: bool = False, - max_file_size: int = DEFAULT_MAX_FILE_SIZE, - max_request_size: int = DEFAULT_MAX_REQUEST_SIZE, - max_fields: int = DEFAULT_MAX_FIELDS, - max_files: int = DEFAULT_MAX_FILES, - max_parts: Optional[int] = DEFAULT_MAX_PARTS, - max_field_size: int = DEFAULT_MAX_FIELD_SIZE, - max_memory_file_size: int = DEFAULT_MAX_MEMORY_FILE_SIZE, - max_part_header_bytes: int = DEFAULT_MAX_PART_HEADER_BYTES, - max_part_header_lines: int = DEFAULT_MAX_PART_HEADER_LINES, - min_free_disk_bytes: int = DEFAULT_MIN_FREE_DISK_BYTES, - ) -> FormData: - """ - Parse form data from the request body. - - Supports both application/x-www-form-urlencoded and multipart/form-data. - - Args: - files: If True, store file uploads; if False (default), discard them - max_file_size: Maximum size per file in bytes (default 50MB) - max_request_size: Maximum total request size in bytes (default 100MB) - max_fields: Maximum number of form fields (default 1000) - max_files: Maximum number of file uploads (default 100) - max_parts: Maximum number of multipart parts (default max_fields + max_files) - max_field_size: Maximum size of a text field value in bytes (default 100KB) - max_memory_file_size: Threshold before files spill to disk (default 1MB) - max_part_header_bytes: Maximum bytes allowed in part headers (default 16KB) - max_part_header_lines: Maximum header lines per part (default 100) - min_free_disk_bytes: Minimum free bytes required in temp dir (default 50MB) - - Returns: - FormData object with dict-like access to fields and files. - Use form["key"] for first value, form.getlist("key") for all values. - - Raises: - BadRequest: If content-type is missing, unsupported, or parsing fails - """ - content_type = self.headers.get("content-type", "") - if not content_type: - raise BadRequest( - "Missing Content-Type header; expected application/x-www-form-urlencoded " - "or multipart/form-data" - ) - - try: - return await parse_form_data( - receive=self.receive, - content_type=content_type, - files=files, - max_file_size=max_file_size, - max_request_size=max_request_size, - max_fields=max_fields, - max_files=max_files, - max_parts=max_parts, - max_field_size=max_field_size, - max_memory_file_size=max_memory_file_size, - max_part_header_bytes=max_part_header_bytes, - max_part_header_lines=max_part_header_lines, - min_free_disk_bytes=min_free_disk_bytes, - ) - except MultipartParseError as e: - raise BadRequest(str(e)) - - @classmethod - def fake(cls, path_with_query_string, method="GET", scheme="http", url_vars=None): - """Useful for constructing Request objects for tests""" - path, _, query_string = path_with_query_string.partition("?") - scope = { - "http_version": "1.1", - "method": method, - "path": path, - "raw_path": path_with_query_string.encode("latin-1"), - "query_string": query_string.encode("latin-1"), - "scheme": scheme, - "type": "http", - } - if url_vars: - scope["url_route"] = {"kwargs": url_vars} - return cls(scope, None) - - -class AsgiLifespan: - def __init__(self, app, on_startup=None, on_shutdown=None): - self.app = app - on_startup = on_startup or [] - on_shutdown = on_shutdown or [] - if not isinstance(on_startup or [], list): - on_startup = [on_startup] - if not isinstance(on_shutdown or [], list): - on_shutdown = [on_shutdown] - self.on_startup = on_startup - self.on_shutdown = on_shutdown - - async def __call__(self, scope, receive, send): - if scope["type"] == "lifespan": - while True: - message = await receive() - if message["type"] == "lifespan.startup": - for fn in self.on_startup: - await fn() - await send({"type": "lifespan.startup.complete"}) - elif message["type"] == "lifespan.shutdown": - for fn in self.on_shutdown: - await fn() - await send({"type": "lifespan.shutdown.complete"}) - return - else: - await self.app(scope, receive, send) - - -class AsgiStream: - def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): - self.stream_fn = stream_fn - self.status = status - self.headers = headers or {} - self.content_type = content_type - - async def asgi_send(self, send): - # Remove any existing content-type header - headers = {k: v for k, v in self.headers.items() if k.lower() != "content-type"} - headers["content-type"] = self.content_type - await send( - { - "type": "http.response.start", - "status": self.status, - "headers": [ - [key.encode("utf-8"), value.encode("utf-8")] - for key, value in headers.items() - ], - } - ) - w = AsgiWriter(send) - await self.stream_fn(w) - await send({"type": "http.response.body", "body": b""}) - - -class AsgiWriter: - def __init__(self, send): - self.send = send - - async def write(self, chunk): - await self.send( - { - "type": "http.response.body", - "body": chunk.encode("utf-8"), - "more_body": True, - } - ) - - -async def asgi_send_json(send, info, status=200, headers=None): - headers = headers or {} - await asgi_send( - send, - json.dumps(info), - status=status, - headers=headers, - content_type="application/json; charset=utf-8", - ) - - -async def asgi_send_html(send, html, status=200, headers=None): - headers = headers or {} - await asgi_send( - send, - html, - status=status, - headers=headers, - content_type="text/html; charset=utf-8", - ) - - -async def asgi_send_redirect(send, location, status=302): - # Prevent open redirect vulnerability: strip multiple leading slashes - # //example.com would be interpreted as a protocol-relative URL (e.g., https://example.com/) - location = re.sub(r"^/+", "/", location) - await asgi_send( - send, - "", - status=status, - headers={"Location": location}, - content_type="text/html; charset=utf-8", - ) - - -async def asgi_send(send, content, status, headers=None, content_type="text/plain"): - await asgi_start(send, status, headers, content_type) - await send({"type": "http.response.body", "body": content.encode("utf-8")}) - - -async def asgi_start(send, status, headers=None, content_type="text/plain"): - headers = headers or {} - # Remove any existing content-type header - headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} - headers["content-type"] = content_type - await send( - { - "type": "http.response.start", - "status": status, - "headers": [ - [key.encode("latin1"), value.encode("latin1")] - for key, value in headers.items() - ], - } - ) - - -async def asgi_send_file( - send, filepath, filename=None, content_type=None, chunk_size=4096, headers=None -): - headers = headers or {} - if filename: - headers["content-disposition"] = f'attachment; filename="{filename}"' - - first = True - headers["content-length"] = str((await aiofiles.os.stat(str(filepath))).st_size) - async with aiofiles.open(str(filepath), mode="rb") as fp: - if first: - await asgi_start( - send, - 200, - headers, - content_type or guess_type(str(filepath))[0] or "text/plain", - ) - first = False - more_body = True - while more_body: - chunk = await fp.read(chunk_size) - more_body = len(chunk) == chunk_size - await send( - {"type": "http.response.body", "body": chunk, "more_body": more_body} - ) - - -def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None): - root_path = Path(root_path) - static_headers = {} - - if headers: - static_headers = headers.copy() - - async def inner_static(request, send): - path = request.scope["url_route"]["kwargs"]["path"] - headers = static_headers.copy() - try: - full_path = (root_path / path).resolve().absolute() - except FileNotFoundError: - await asgi_send_html(send, "404: Directory not found", 404) - return - if full_path.is_dir(): - await asgi_send_html(send, "403: Directory listing is not allowed", 403) - return - # Ensure full_path is within root_path to avoid weird "../" tricks - try: - full_path.relative_to(root_path.resolve()) - except ValueError: - await asgi_send_html(send, "404: Path not inside root path", 404) - return - try: - # Calculate ETag for filepath - etag = await calculate_etag(full_path, chunk_size=chunk_size) - headers["ETag"] = etag - if_none_match = request.headers.get("if-none-match") - if if_none_match and if_none_match == etag: - return await asgi_send(send, "", 304) - await asgi_send_file( - send, full_path, chunk_size=chunk_size, headers=headers - ) - except FileNotFoundError: - await asgi_send_html(send, "404: File not found", 404) - return - - return inner_static - - -class Response: - def __init__(self, body=None, status=200, headers=None, content_type="text/plain"): - self.body = body - self.status = status - self.headers = headers or {} - self._set_cookie_headers = [] - self.content_type = content_type - - async def asgi_send(self, send): - headers = {} - headers.update(self.headers) - headers["content-type"] = self.content_type - raw_headers = [ - [key.encode("utf-8"), value.encode("utf-8")] - for key, value in headers.items() - ] - for set_cookie in self._set_cookie_headers: - raw_headers.append([b"set-cookie", set_cookie.encode("utf-8")]) - await send( - { - "type": "http.response.start", - "status": self.status, - "headers": raw_headers, - } - ) - body = self.body - if not isinstance(body, bytes): - body = body.encode("utf-8") - await send({"type": "http.response.body", "body": body}) - - def set_cookie( - self, - key, - value="", - max_age=None, - expires=None, - path="/", - domain=None, - secure=False, - httponly=False, - samesite="lax", - ): - assert samesite in SAMESITE_VALUES, "samesite should be one of {}".format( - SAMESITE_VALUES - ) - cookie = SimpleCookie() - cookie[key] = value - for prop_name, prop_value in ( - ("max_age", max_age), - ("expires", expires), - ("path", path), - ("domain", domain), - ("samesite", samesite), - ): - if prop_value is not None: - cookie[key][prop_name.replace("_", "-")] = prop_value - for prop_name, prop_value in (("secure", secure), ("httponly", httponly)): - if prop_value: - cookie[key][prop_name] = True - self._set_cookie_headers.append(cookie.output(header="").strip()) - - @classmethod - def html(cls, body, status=200, headers=None): - return cls( - body, - status=status, - headers=headers, - content_type="text/html; charset=utf-8", - ) - - @classmethod - def text(cls, body, status=200, headers=None): - return cls( - str(body), - status=status, - headers=headers, - content_type="text/plain; charset=utf-8", - ) - - @classmethod - def json(cls, body, status=200, headers=None, default=None): - return cls( - json.dumps(body, default=default), - status=status, - headers=headers, - content_type="application/json; charset=utf-8", - ) - - @classmethod - def redirect(cls, path, status=302, headers=None): - headers = headers or {} - headers["Location"] = path - return cls("", status=status, headers=headers) - - -class AsgiFileDownload: - def __init__( - self, - filepath, - filename=None, - content_type="application/octet-stream", - headers=None, - ): - self.headers = headers or {} - self.filepath = filepath - self.filename = filename - self.content_type = content_type - - async def asgi_send(self, send): - return await asgi_send_file( - send, - self.filepath, - filename=self.filename, - content_type=self.content_type, - headers=self.headers, - ) - - -class AsgiRunOnFirstRequest: - def __init__(self, asgi, on_startup): - assert isinstance(on_startup, list) - self.asgi = asgi - self.on_startup = on_startup - self._started = False - - async def __call__(self, scope, receive, send): - if not self._started: - self._started = True - for hook in self.on_startup: - await hook() - return await self.asgi(scope, receive, send) diff --git a/datasette/utils/baseconv.py b/datasette/utils/baseconv.py deleted file mode 100644 index c4b64908..00000000 --- a/datasette/utils/baseconv.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -Convert numbers from base 10 integers to base X strings and back again. - -Sample usage: - ->>> base20 = BaseConverter('0123456789abcdefghij') ->>> base20.from_decimal(1234) -'31e' ->>> base20.to_decimal('31e') -1234 - -Originally shared here: https://www.djangosnippets.org/snippets/1431/ -""" - - -class BaseConverter(object): - decimal_digits = "0123456789" - - def __init__(self, digits): - self.digits = digits - - def encode(self, i): - return self.convert(i, self.decimal_digits, self.digits) - - def decode(self, s): - return int(self.convert(s, self.digits, self.decimal_digits)) - - def convert(number, fromdigits, todigits): - # Based on http://code.activestate.com/recipes/111286/ - if str(number)[0] == "-": - number = str(number)[1:] - neg = 1 - else: - neg = 0 - - # make an integer out of the number - x = 0 - for digit in str(number): - x = x * len(fromdigits) + fromdigits.index(digit) - - # create the result in base 'len(todigits)' - if x == 0: - res = todigits[0] - else: - res = "" - while x > 0: - digit = x % len(todigits) - res = todigits[digit] + res - x = int(x / len(todigits)) - if neg: - res = "-" + res - return res - - convert = staticmethod(convert) - - -bin = BaseConverter("01") -hexconv = BaseConverter("0123456789ABCDEF") -base62 = BaseConverter("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefghijklmnopqrstuvwxyz") diff --git a/datasette/utils/check_callable.py b/datasette/utils/check_callable.py deleted file mode 100644 index a0997d20..00000000 --- a/datasette/utils/check_callable.py +++ /dev/null @@ -1,25 +0,0 @@ -import inspect -import types -from typing import NamedTuple, Any - - -class CallableStatus(NamedTuple): - is_callable: bool - is_async_callable: bool - - -def check_callable(obj: Any) -> CallableStatus: - if not callable(obj): - return CallableStatus(False, False) - - if isinstance(obj, type): - # It's a class - return CallableStatus(True, False) - - if isinstance(obj, types.FunctionType): - return CallableStatus(True, inspect.iscoroutinefunction(obj)) - - if hasattr(obj, "__call__"): - return CallableStatus(True, inspect.iscoroutinefunction(obj.__call__)) - - assert False, "obj {} is somehow callable with no __call__ method".format(repr(obj)) diff --git a/datasette/utils/internal_db.py b/datasette/utils/internal_db.py deleted file mode 100644 index bf172667..00000000 --- a/datasette/utils/internal_db.py +++ /dev/null @@ -1,269 +0,0 @@ -import textwrap -from datasette.utils import table_column_details - - -async def init_internal_db(db): - create_tables_sql = textwrap.dedent(""" - CREATE TABLE IF NOT EXISTS catalog_databases ( - database_name TEXT PRIMARY KEY, - path TEXT, - is_memory INTEGER, - schema_version INTEGER - ); - CREATE TABLE IF NOT EXISTS catalog_tables ( - database_name TEXT, - table_name TEXT, - rootpage INTEGER, - sql TEXT, - PRIMARY KEY (database_name, table_name), - FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name) - ); - CREATE TABLE IF NOT EXISTS catalog_views ( - database_name TEXT, - view_name TEXT, - rootpage INTEGER, - sql TEXT, - PRIMARY KEY (database_name, view_name), - FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name) - ); - CREATE TABLE IF NOT EXISTS catalog_columns ( - database_name TEXT, - table_name TEXT, - cid INTEGER, - name TEXT, - type TEXT, - "notnull" INTEGER, - default_value TEXT, -- renamed from dflt_value - is_pk INTEGER, -- renamed from pk - hidden INTEGER, - PRIMARY KEY (database_name, table_name, name), - FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name), - FOREIGN KEY (database_name, table_name) REFERENCES catalog_tables(database_name, table_name) - ); - CREATE TABLE IF NOT EXISTS catalog_indexes ( - database_name TEXT, - table_name TEXT, - seq INTEGER, - name TEXT, - "unique" INTEGER, - origin TEXT, - partial INTEGER, - PRIMARY KEY (database_name, table_name, name), - FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name), - FOREIGN KEY (database_name, table_name) REFERENCES catalog_tables(database_name, table_name) - ); - CREATE TABLE IF NOT EXISTS catalog_foreign_keys ( - database_name TEXT, - table_name TEXT, - id INTEGER, - seq INTEGER, - "table" TEXT, - "from" TEXT, - "to" TEXT, - on_update TEXT, - on_delete TEXT, - match TEXT, - PRIMARY KEY (database_name, table_name, id, seq), - FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name), - FOREIGN KEY (database_name, table_name) REFERENCES catalog_tables(database_name, table_name) - ); - """).strip() - await db.execute_write_script(create_tables_sql) - await initialize_metadata_tables(db) - - -async def initialize_metadata_tables(db): - await db.execute_write_script(textwrap.dedent(""" - CREATE TABLE IF NOT EXISTS metadata_instance ( - key text, - value text, - unique(key) - ); - - CREATE TABLE IF NOT EXISTS metadata_databases ( - database_name text, - key text, - value text, - unique(database_name, key) - ); - - CREATE TABLE IF NOT EXISTS metadata_resources ( - database_name text, - resource_name text, - key text, - value text, - unique(database_name, resource_name, key) - ); - - CREATE TABLE IF NOT EXISTS metadata_columns ( - database_name text, - resource_name text, - column_name text, - key text, - value text, - unique(database_name, resource_name, column_name, key) - ); - - CREATE TABLE IF NOT EXISTS column_types ( - database_name TEXT NOT NULL, - resource_name TEXT NOT NULL, - column_name TEXT NOT NULL, - column_type TEXT NOT NULL, - config TEXT, - PRIMARY KEY (database_name, resource_name, column_name) - ); - - CREATE TABLE IF NOT EXISTS queries ( - database_name TEXT NOT NULL, - name TEXT NOT NULL, - sql TEXT NOT NULL, - title TEXT, - description TEXT, - description_html TEXT, - options TEXT NOT NULL DEFAULT '{}', - parameters TEXT NOT NULL DEFAULT '[]', - is_write INTEGER NOT NULL DEFAULT 0 CHECK (is_write IN (0, 1)), - is_private INTEGER NOT NULL DEFAULT 0 CHECK (is_private IN (0, 1)), - is_trusted INTEGER NOT NULL DEFAULT 0 CHECK (is_trusted IN (0, 1)), - source TEXT NOT NULL DEFAULT 'user', - owner_id TEXT, - created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (database_name, name) - ); - - CREATE INDEX IF NOT EXISTS queries_owner_idx - ON queries(owner_id); - """)) - - -async def populate_schema_tables(internal_db, db): - database_name = db.name - - def delete_everything(conn): - conn.execute( - "DELETE FROM catalog_tables WHERE database_name = ?", [database_name] - ) - conn.execute( - "DELETE FROM catalog_views WHERE database_name = ?", [database_name] - ) - conn.execute( - "DELETE FROM catalog_columns WHERE database_name = ?", [database_name] - ) - conn.execute( - "DELETE FROM catalog_foreign_keys WHERE database_name = ?", - [database_name], - ) - conn.execute( - "DELETE FROM catalog_indexes WHERE database_name = ?", [database_name] - ) - - await internal_db.execute_write_fn(delete_everything) - - tables = (await db.execute("select * from sqlite_master WHERE type = 'table'")).rows - views = (await db.execute("select * from sqlite_master WHERE type = 'view'")).rows - - def collect_info(conn): - tables_to_insert = [] - views_to_insert = [] - columns_to_insert = [] - foreign_keys_to_insert = [] - indexes_to_insert = [] - - for view in views: - view_name = view["name"] - views_to_insert.append( - (database_name, view_name, view["rootpage"], view["sql"]) - ) - - for table in tables: - table_name = table["name"] - tables_to_insert.append( - (database_name, table_name, table["rootpage"], table["sql"]) - ) - columns = table_column_details(conn, table_name) - columns_to_insert.extend( - { - **{"database_name": database_name, "table_name": table_name}, - **column._asdict(), - } - for column in columns - ) - foreign_keys = conn.execute( - f"PRAGMA foreign_key_list([{table_name}])" - ).fetchall() - foreign_keys_to_insert.extend( - { - **{"database_name": database_name, "table_name": table_name}, - **dict(foreign_key), - } - for foreign_key in foreign_keys - ) - indexes = conn.execute(f"PRAGMA index_list([{table_name}])").fetchall() - indexes_to_insert.extend( - { - **{"database_name": database_name, "table_name": table_name}, - **dict(index), - } - for index in indexes - ) - return ( - tables_to_insert, - views_to_insert, - columns_to_insert, - foreign_keys_to_insert, - indexes_to_insert, - ) - - ( - tables_to_insert, - views_to_insert, - columns_to_insert, - foreign_keys_to_insert, - indexes_to_insert, - ) = await db.execute_fn(collect_info) - - await internal_db.execute_write_many( - """ - INSERT INTO catalog_tables (database_name, table_name, rootpage, sql) - values (?, ?, ?, ?) - """, - tables_to_insert, - ) - await internal_db.execute_write_many( - """ - INSERT INTO catalog_views (database_name, view_name, rootpage, sql) - values (?, ?, ?, ?) - """, - views_to_insert, - ) - await internal_db.execute_write_many( - """ - INSERT INTO catalog_columns ( - database_name, table_name, cid, name, type, "notnull", default_value, is_pk, hidden - ) VALUES ( - :database_name, :table_name, :cid, :name, :type, :notnull, :default_value, :is_pk, :hidden - ) - """, - columns_to_insert, - ) - await internal_db.execute_write_many( - """ - INSERT INTO catalog_foreign_keys ( - database_name, table_name, "id", seq, "table", "from", "to", on_update, on_delete, match - ) VALUES ( - :database_name, :table_name, :id, :seq, :table, :from, :to, :on_update, :on_delete, :match - ) - """, - foreign_keys_to_insert, - ) - await internal_db.execute_write_many( - """ - INSERT INTO catalog_indexes ( - database_name, table_name, seq, name, "unique", origin, partial - ) VALUES ( - :database_name, :table_name, :seq, :name, :unique, :origin, :partial - ) - """, - indexes_to_insert, - ) diff --git a/datasette/utils/multipart.py b/datasette/utils/multipart.py deleted file mode 100644 index cfa77486..00000000 --- a/datasette/utils/multipart.py +++ /dev/null @@ -1,757 +0,0 @@ -""" -Streaming multipart/form-data parser for ASGI applications. - -Supports: -- Streaming parsing without buffering entire body in memory -- Files spill to disk above configurable threshold -- Security limits on request size, file size, field count -- Both multipart/form-data and application/x-www-form-urlencoded -""" - -import asyncio -import shutil -import tempfile -from dataclasses import dataclass, field -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Union, -) -from urllib.parse import parse_qsl - -# Centralized defaults for multipart/form-data parsing -DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB -DEFAULT_MAX_REQUEST_SIZE = 100 * 1024 * 1024 # 100MB -DEFAULT_MAX_FIELDS = 1000 -DEFAULT_MAX_FILES = 100 -# If max_parts is not specified, it defaults to max_fields + max_files -DEFAULT_MAX_PARTS: Optional[int] = None -DEFAULT_MAX_FIELD_SIZE = 100 * 1024 # 100KB -DEFAULT_MAX_MEMORY_FILE_SIZE = 1024 * 1024 # 1MB -DEFAULT_MAX_PART_HEADER_BYTES = 16 * 1024 # 16KB -DEFAULT_MAX_PART_HEADER_LINES = 100 -DEFAULT_MIN_FREE_DISK_BYTES = 50 * 1024 * 1024 # 50MB - - -class MultipartParseError(Exception): - """Raised when multipart parsing fails.""" - - pass - - -@dataclass -class UploadedFile: - """ - Represents an uploaded file from a multipart form. - - Attributes: - name: The form field name - filename: The original filename from the upload - content_type: The MIME type of the file - size: Size in bytes - """ - - name: str - filename: str - content_type: Optional[str] - size: int - _file: tempfile.SpooledTemporaryFile = field(repr=False) - - async def read(self, size: int = -1) -> bytes: - """Read file contents.""" - return await asyncio.to_thread(self._file.read, size) - - async def seek(self, offset: int, whence: int = 0) -> int: - """Seek to position in file.""" - return await asyncio.to_thread(self._file.seek, offset, whence) - - async def close(self) -> None: - """Close the underlying file.""" - await asyncio.to_thread(self._file.close) - - def close_sync(self) -> None: - """Close the underlying file synchronously.""" - self._file.close() - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.close() - - def __del__(self): - try: - self._file.close() - except Exception: - pass - - -class FormData: - """ - Container for parsed form data, supporting both fields and files. - - Provides dict-like access with support for multiple values per key. - """ - - def __init__(self): - self._data: List[Tuple[str, Union[str, UploadedFile]]] = [] - - def append(self, key: str, value: Union[str, UploadedFile]) -> None: - """Add a key-value pair.""" - self._data.append((key, value)) - - def __getitem__(self, key: str) -> Union[str, UploadedFile]: - """Get the first value for a key.""" - for k, v in self._data: - if k == key: - return v - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Optional[Union[str, UploadedFile]]: - """Get the first value for a key, or default if not found.""" - try: - return self[key] - except KeyError: - return default - - def getlist(self, key: str) -> List[Union[str, UploadedFile]]: - """Get all values for a key.""" - return [v for k, v in self._data if k == key] - - def __contains__(self, key: str) -> bool: - """Check if key exists.""" - return any(k == key for k, _ in self._data) - - def __len__(self) -> int: - """Return number of items.""" - return len(self._data) - - def __iter__(self): - """Iterate over unique keys.""" - seen = set() - for k, _ in self._data: - if k not in seen: - seen.add(k) - yield k - - def keys(self): - """Return unique keys.""" - return list(self) - - def items(self) -> List[Tuple[str, Union[str, UploadedFile]]]: - """Return all key-value pairs.""" - return list(self._data) - - def values(self) -> List[Union[str, UploadedFile]]: - """Return all values.""" - return [v for _, v in self._data] - - def _uploaded_files(self) -> List[UploadedFile]: - """Return UploadedFile instances contained in this form.""" - return [v for _, v in self._data if isinstance(v, UploadedFile)] - - def close(self) -> None: - """ - Close any uploaded files. - - This provides deterministic cleanup for spooled temp files. - """ - for uploaded in self._uploaded_files(): - try: - uploaded.close_sync() - except Exception: - # Best-effort cleanup; ignore close errors - pass - - async def aclose(self) -> None: - """Asynchronously close any uploaded files.""" - for uploaded in self._uploaded_files(): - try: - await uploaded.close() - except Exception: - # Best-effort cleanup; ignore close errors - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.aclose() - - -def parse_content_disposition(header: str) -> Dict[str, Optional[str]]: - """ - Parse Content-Disposition header value. - - Returns dict with 'name', 'filename' keys (filename may be None). - """ - result: Dict[str, Optional[str]] = {"name": None, "filename": None} - - # Split on semicolons, handling quoted strings - parts = [] - current = "" - in_quotes = False - i = 0 - while i < len(header): - char = header[i] - if char == '"' and (i == 0 or header[i - 1] != "\\"): - in_quotes = not in_quotes - current += char - elif char == ";" and not in_quotes: - parts.append(current.strip()) - current = "" - else: - current += char - i += 1 - if current.strip(): - parts.append(current.strip()) - - for part in parts[1:]: # Skip the "form-data" part - if "=" not in part: - continue - - key, _, value = part.partition("=") - key = key.strip().lower() - value = value.strip() - - # Handle filename* (RFC 5987 encoding) - if key == "filename*": - # Format: utf-8''encoded_filename or charset'language'encoded_filename - if "'" in value: - parts_star = value.split("'", 2) - if len(parts_star) >= 3: - # charset = parts_star[0] - # language = parts_star[1] - encoded = parts_star[2] - # URL decode - try: - from urllib.parse import unquote - - result["filename"] = unquote(encoded, encoding="utf-8") - except Exception: - pass - continue - - # Remove quotes if present - if value.startswith('"') and value.endswith('"'): - value = value[1:-1] - # Unescape backslash sequences - value = value.replace('\\"', '"').replace("\\\\", "\\") - - if key == "name": - result["name"] = value - elif key == "filename": - # Only set if filename* hasn't already set it - if result["filename"] is None: - # Strip path components (security) - # Handle both Unix and Windows paths - value = value.replace("\\", "/") - if "/" in value: - value = value.rsplit("/", 1)[-1] - result["filename"] = value - - return result - - -def parse_content_type(header: str) -> Tuple[str, Dict[str, str]]: - """ - Parse Content-Type header value. - - Returns (media_type, parameters_dict). - """ - parts = header.split(";") - media_type = parts[0].strip().lower() - params = {} - - for part in parts[1:]: - part = part.strip() - if "=" in part: - key, _, value = part.partition("=") - key = key.strip().lower() - value = value.strip() - # Remove quotes if present - if value.startswith('"') and value.endswith('"'): - value = value[1:-1] - params[key] = value - - return media_type, params - - -class MultipartParser: - """ - Streaming multipart/form-data parser. - - Processes the body chunk by chunk without loading everything into memory. - """ - - # Parser states - STATE_PREAMBLE = 0 - STATE_HEADER = 1 - STATE_BODY = 2 - STATE_DONE = 3 - - def __init__( - self, - boundary: bytes, - max_file_size: int = DEFAULT_MAX_FILE_SIZE, - max_request_size: int = DEFAULT_MAX_REQUEST_SIZE, - max_fields: int = DEFAULT_MAX_FIELDS, - max_files: int = DEFAULT_MAX_FILES, - max_parts: Optional[int] = DEFAULT_MAX_PARTS, - max_field_size: int = DEFAULT_MAX_FIELD_SIZE, - max_memory_file_size: int = DEFAULT_MAX_MEMORY_FILE_SIZE, - max_part_header_bytes: int = DEFAULT_MAX_PART_HEADER_BYTES, - max_part_header_lines: int = DEFAULT_MAX_PART_HEADER_LINES, - min_free_disk_bytes: int = DEFAULT_MIN_FREE_DISK_BYTES, - handle_files: bool = False, - ): - self.boundary = b"--" + boundary - self.end_boundary = self.boundary + b"--" - self.max_file_size = max_file_size - self.max_request_size = max_request_size - self.max_fields = max_fields - self.max_files = max_files - # If not specified, tie max_parts to the other cardinality limits - if max_parts is None: - max_parts = max_fields + max_files - self.max_parts = max_parts - self.max_field_size = max_field_size - self.max_memory_file_size = max_memory_file_size - self.max_part_header_bytes = max_part_header_bytes - self.max_part_header_lines = max_part_header_lines - self.min_free_disk_bytes = min_free_disk_bytes - self.handle_files = handle_files - - self.state = self.STATE_PREAMBLE - self.buffer = bytearray() - self.total_bytes = 0 - self.field_count = 0 - self.file_count = 0 - self.part_count = 0 - self.current_part_size = 0 - self.current_header_bytes = 0 - self.current_header_lines = 0 - - self.form_data = FormData() - self._disk_check_interval_bytes = 1024 * 1024 # 1MB between disk checks - self._bytes_since_disk_check = 0 - self._tempdir = tempfile.gettempdir() - - # Current part state - self.current_headers: Dict[str, str] = {} - self.current_file: Optional[tempfile.SpooledTemporaryFile] = None - self.current_body = bytearray() - self.current_name: Optional[str] = None - self.current_filename: Optional[str] = None - self.current_content_type: Optional[str] = None - - def feed(self, chunk: bytes) -> None: - """Feed a chunk of data to the parser.""" - self.total_bytes += len(chunk) - if self.total_bytes > self.max_request_size: - raise MultipartParseError("Request body too large") - - self.buffer.extend(chunk) - self._process() - - def _process(self) -> None: - """Process buffered data.""" - while True: - if self.state == self.STATE_PREAMBLE: - if not self._process_preamble(): - break - elif self.state == self.STATE_HEADER: - if not self._process_header(): - break - elif self.state == self.STATE_BODY: - if not self._process_body(): - break - elif self.state == self.STATE_DONE: - break - - def _process_preamble(self) -> bool: - """Skip preamble and find first boundary.""" - # Look for boundary (could be at start or after preamble) - # Try both \r\n prefixed and bare boundary at start - idx = self.buffer.find(self.boundary) - if idx == -1: - # Keep potential partial boundary at end - keep = len(self.boundary) - 1 - if len(self.buffer) > keep: - self.buffer = self.buffer[-keep:] - return False - - # Found boundary, skip to after it - after_boundary = idx + len(self.boundary) - - # Check for end boundary - if self.buffer[idx : idx + len(self.end_boundary)] == self.end_boundary: - self.state = self.STATE_DONE - return False - - # Skip CRLF or LF after boundary - if after_boundary < len(self.buffer): - if self.buffer[after_boundary : after_boundary + 2] == b"\r\n": - after_boundary += 2 - elif self.buffer[after_boundary : after_boundary + 1] == b"\n": - after_boundary += 1 - - self.buffer = self.buffer[after_boundary:] - self.state = self.STATE_HEADER - self.current_headers = {} - self.current_header_bytes = 0 - self.current_header_lines = 0 - return True - - def _process_header(self) -> bool: - """Parse part headers.""" - while True: - # Look for end of header line - crlf_idx = self.buffer.find(b"\r\n") - lf_idx = self.buffer.find(b"\n") - - if crlf_idx == -1 and lf_idx == -1: - # Guard against unbounded header buffering if no newline is ever sent - if len(self.buffer) > self.max_part_header_bytes: - raise MultipartParseError("Part headers too large") - return False # Need more data - - # Use whichever comes first - if crlf_idx != -1 and (lf_idx == -1 or crlf_idx < lf_idx): - idx = crlf_idx - line_end_len = 2 - else: - idx = lf_idx - line_end_len = 1 - - line = self.buffer[:idx] - self.buffer = self.buffer[idx + line_end_len :] - - self.current_header_lines += 1 - self.current_header_bytes += idx + line_end_len - if ( - self.current_header_lines > self.max_part_header_lines - or self.current_header_bytes > self.max_part_header_bytes - ): - raise MultipartParseError("Part headers too large") - - if not line: - # Empty line = end of headers - self._start_body() - self.state = self.STATE_BODY - return True - - # Parse header - try: - line_str = line.decode("utf-8", errors="replace") - except Exception: - line_str = line.decode("latin-1") - - if ":" in line_str: - name, _, value = line_str.partition(":") - self.current_headers[name.strip().lower()] = value.strip() - - def _start_body(self) -> None: - """Initialize body parsing for current part.""" - self.part_count += 1 - if self.part_count > self.max_parts: - raise MultipartParseError("Too many parts") - - # Parse Content-Disposition - cd = self.current_headers.get("content-disposition", "") - parsed = parse_content_disposition(cd) - self.current_name = parsed.get("name") - self.current_filename = parsed.get("filename") - self.current_content_type = self.current_headers.get("content-type") - self.current_part_size = 0 - - if self.current_filename is not None: - # It's a file - self.file_count += 1 - if self.file_count > self.max_files: - raise MultipartParseError("Too many files") - if self.handle_files: - self.current_file = tempfile.SpooledTemporaryFile( - max_size=self.max_memory_file_size - ) - else: - # Will discard file content - self.current_file = None - else: - # It's a text field - self.field_count += 1 - if self.field_count > self.max_fields: - raise MultipartParseError("Too many fields") - self.current_body = bytearray() - self.current_file = None - - # Check disk space before allocating a spooled temp file - if self.current_filename is not None and self.handle_files: - self._ensure_disk_space() - - def _process_body(self) -> bool: - """Process body data for current part.""" - # Look for boundary in buffer - # Need to handle boundary potentially split across chunks - - # The boundary is preceded by \r\n (or \n for lenient parsing) - search_boundary = b"\r\n" + self.boundary - - idx = self.buffer.find(search_boundary) - if idx == -1: - # Try LF-only boundary (lenient) - search_boundary_lf = b"\n" + self.boundary - idx = self.buffer.find(search_boundary_lf) - if idx != -1: - search_boundary = search_boundary_lf - - if idx == -1: - # No boundary found yet - # Keep potential partial boundary at end of buffer - safe_len = len(self.buffer) - len(search_boundary) - 1 - if safe_len > 0: - safe_data = self.buffer[:safe_len] - self._write_body_data(bytes(safe_data)) - self.buffer = self.buffer[safe_len:] - return False - - # Found boundary - write remaining body data - body_data = self.buffer[:idx] - self._write_body_data(bytes(body_data)) - - # Move past the boundary - after_boundary = idx + len(search_boundary) - - # Check for end boundary - remaining = self.buffer[after_boundary:] - if remaining.startswith(b"--"): - # End boundary - self._finish_part() - self.state = self.STATE_DONE - return False - - # Skip CRLF or LF after boundary - if remaining.startswith(b"\r\n"): - after_boundary += 2 - elif remaining.startswith(b"\n"): - after_boundary += 1 - - self.buffer = self.buffer[after_boundary:] - self._finish_part() - self.state = self.STATE_HEADER - self.current_headers = {} - self.current_header_bytes = 0 - self.current_header_lines = 0 - return True - - def _write_body_data(self, data: bytes) -> None: - """Write data to current part body.""" - if not data: - return - - self.current_part_size += len(data) - - if self.current_filename is not None: - # File data - if self.current_part_size > self.max_file_size: - raise MultipartParseError("File too large") - if self.handle_files and self.current_file: - self._bytes_since_disk_check += len(data) - if self._bytes_since_disk_check >= self._disk_check_interval_bytes: - self._ensure_disk_space() - self._bytes_since_disk_check = 0 - self.current_file.write(data) - # else: discard file data - else: - # Field data - if self.current_part_size > self.max_field_size: - raise MultipartParseError("Field value too large") - self.current_body.extend(data) - - def _finish_part(self) -> None: - """Finalize current part and add to form data.""" - if self.current_name is None: - return - - if self.current_filename is not None: - # File - if self.handle_files and self.current_file: - self.current_file.seek(0) - uploaded = UploadedFile( - name=self.current_name, - filename=self.current_filename, - content_type=self.current_content_type, - size=self.current_part_size, - _file=self.current_file, - ) - self.form_data.append(self.current_name, uploaded) - # else: file was discarded - else: - # Text field - try: - value = bytes(self.current_body).decode("utf-8") - except UnicodeDecodeError: - value = bytes(self.current_body).decode("latin-1") - self.form_data.append(self.current_name, value) - - # Reset part state - self.current_file = None - self.current_body = bytearray() - self.current_name = None - self.current_filename = None - self.current_content_type = None - - def finalize(self) -> FormData: - """Finalize parsing and return form data.""" - # Process any remaining data - self._process() - if self.state != self.STATE_DONE: - raise MultipartParseError( - "Truncated multipart body (missing closing boundary)" - ) - return self.form_data - - def _ensure_disk_space(self) -> None: - """ - Ensure there is enough free space on the temp filesystem. - - This is a best-effort guard against filling the disk with uploads. - """ - if not self.handle_files: - return - if self.min_free_disk_bytes <= 0: - return - free_bytes = shutil.disk_usage(self._tempdir).free - if free_bytes < self.min_free_disk_bytes: - raise MultipartParseError("Insufficient disk space for uploads") - - -async def parse_form_data( - receive: Callable, - content_type: str, - files: bool = False, - max_file_size: int = DEFAULT_MAX_FILE_SIZE, - max_request_size: int = DEFAULT_MAX_REQUEST_SIZE, - max_fields: int = DEFAULT_MAX_FIELDS, - max_files: int = DEFAULT_MAX_FILES, - max_parts: Optional[int] = DEFAULT_MAX_PARTS, - max_field_size: int = DEFAULT_MAX_FIELD_SIZE, - max_memory_file_size: int = DEFAULT_MAX_MEMORY_FILE_SIZE, - max_part_header_bytes: int = DEFAULT_MAX_PART_HEADER_BYTES, - max_part_header_lines: int = DEFAULT_MAX_PART_HEADER_LINES, - min_free_disk_bytes: int = DEFAULT_MIN_FREE_DISK_BYTES, -) -> FormData: - """ - Parse form data from an ASGI receive callable. - - Supports both application/x-www-form-urlencoded and multipart/form-data. - - Args: - receive: ASGI receive callable - content_type: Content-Type header value - files: If True, store file uploads; if False, discard them - max_file_size: Maximum size per file in bytes - max_request_size: Maximum total request size in bytes - max_fields: Maximum number of form fields - max_files: Maximum number of file uploads - max_field_size: Maximum size of a text field value - max_memory_file_size: File size threshold before spilling to disk - - Returns: - FormData object containing parsed fields and files - """ - media_type, params = parse_content_type(content_type) - - if media_type == "application/x-www-form-urlencoded": - # Read entire body for URL-encoded forms (they're typically small) - body = bytearray() - total = 0 - while True: - message = await receive() - message_type = message.get("type") - if message_type == "http.disconnect": - raise MultipartParseError("Client disconnected during request body") - if message_type is not None and message_type != "http.request": - continue - chunk = message.get("body", b"") - total += len(chunk) - if total > max_request_size: - raise MultipartParseError("Request body too large") - body.extend(chunk) - if not message.get("more_body", False): - break - - form_data = FormData() - try: - pairs = parse_qsl(bytes(body).decode("utf-8"), keep_blank_values=True) - except UnicodeDecodeError: - pairs = parse_qsl(bytes(body).decode("latin-1"), keep_blank_values=True) - - for key, value in pairs: - form_data.append(key, value) - - return form_data - - elif media_type == "multipart/form-data": - boundary = params.get("boundary") - if not boundary: - raise MultipartParseError("Missing boundary in Content-Type") - - parser = MultipartParser( - boundary=boundary.encode("utf-8"), - max_file_size=max_file_size, - max_request_size=max_request_size, - max_fields=max_fields, - max_files=max_files, - max_parts=max_parts, - max_field_size=max_field_size, - max_memory_file_size=max_memory_file_size, - max_part_header_bytes=max_part_header_bytes, - max_part_header_lines=max_part_header_lines, - min_free_disk_bytes=min_free_disk_bytes, - handle_files=files, - ) - - # Stream body through parser - batch_target = 64 * 1024 - batch = bytearray() - - async def flush_batch() -> None: - if batch: - data = bytes(batch) - batch.clear() - await asyncio.to_thread(parser.feed, data) - - while True: - message = await receive() - message_type = message.get("type") - if message_type == "http.disconnect": - raise MultipartParseError("Client disconnected during request body") - if message_type is not None and message_type != "http.request": - continue - chunk = message.get("body", b"") - if chunk: - batch.extend(chunk) - if len(batch) >= batch_target: - await flush_batch() - if not message.get("more_body", False): - break - - await flush_batch() - return await asyncio.to_thread(parser.finalize) - - else: - raise MultipartParseError( - f"Unsupported Content-Type: {media_type}. " - "Expected application/x-www-form-urlencoded or multipart/form-data" - ) diff --git a/datasette/utils/permissions.py b/datasette/utils/permissions.py deleted file mode 100644 index fd1e41a1..00000000 --- a/datasette/utils/permissions.py +++ /dev/null @@ -1,436 +0,0 @@ -# perm_utils.py -from __future__ import annotations - -import json -from typing import Any, Dict, Iterable, List, Sequence, Tuple -import sqlite3 - -from datasette.permissions import PermissionSQL -from datasette.plugins import pm -from datasette.utils import await_me_maybe - -# Sentinel object to indicate permission checks should be skipped -SKIP_PERMISSION_CHECKS = object() - - -async def gather_permission_sql_from_hooks( - *, datasette, actor: dict | None, action: str -) -> List[PermissionSQL] | object: - """Collect PermissionSQL objects from the permission_resources_sql hook. - - Ensures that each returned PermissionSQL has a populated ``source``. - - Returns SKIP_PERMISSION_CHECKS sentinel if skip_permission_checks context variable - is set, signaling that all permission checks should be bypassed. - """ - from datasette.permissions import _skip_permission_checks - - # Check if we should skip permission checks BEFORE calling hooks - # This avoids creating unawaited coroutines - if _skip_permission_checks.get(): - return SKIP_PERMISSION_CHECKS - - hook_caller = pm.hook.permission_resources_sql - hookimpls = hook_caller.get_hookimpls() - hook_results = list(hook_caller(datasette=datasette, actor=actor, action=action)) - - collected: List[PermissionSQL] = [] - actor_json = json.dumps(actor) if actor is not None else None - actor_id = actor.get("id") if isinstance(actor, dict) else None - - for index, result in enumerate(hook_results): - hookimpl = hookimpls[index] - resolved = await await_me_maybe(result) - default_source = _plugin_name_from_hookimpl(hookimpl) - for permission_sql in _iter_permission_sql_from_result(resolved, action=action): - if not permission_sql.source: - permission_sql.source = default_source - params = permission_sql.params or {} - params.setdefault("action", action) - params.setdefault("actor", actor_json) - params.setdefault("actor_id", actor_id) - collected.append(permission_sql) - - return collected - - -def _plugin_name_from_hookimpl(hookimpl) -> str: - if getattr(hookimpl, "plugin_name", None): - return hookimpl.plugin_name - plugin = getattr(hookimpl, "plugin", None) - if hasattr(plugin, "__name__"): - return plugin.__name__ - return repr(plugin) - - -def _iter_permission_sql_from_result( - result: Any, *, action: str -) -> Iterable[PermissionSQL]: - if result is None: - return [] - if isinstance(result, PermissionSQL): - return [result] - if isinstance(result, (list, tuple)): - collected: List[PermissionSQL] = [] - for item in result: - collected.extend(_iter_permission_sql_from_result(item, action=action)) - return collected - if callable(result): - permission_sql = result(action) # type: ignore[call-arg] - return _iter_permission_sql_from_result(permission_sql, action=action) - raise TypeError( - "Plugin providers must return PermissionSQL instances, sequences, or callables" - ) - - -# ----------------------------- -# Plugin interface & utilities -# ----------------------------- - - -def build_rules_union( - actor: dict | None, plugins: Sequence[PermissionSQL] -) -> Tuple[str, Dict[str, Any]]: - """ - Compose plugin SQL into a UNION ALL. - - Returns: - union_sql: a SELECT with columns (parent, child, allow, reason, source_plugin) - params: dict of bound parameters including :actor (JSON), :actor_id, and plugin params - - Note: Plugins are responsible for ensuring their parameter names don't conflict. - The system reserves these parameter names: :actor, :actor_id, :action, :filter_parent - Plugin parameters should be prefixed with a unique identifier (e.g., source name). - """ - parts: List[str] = [] - actor_json = json.dumps(actor) if actor else None - actor_id = actor.get("id") if actor else None - params: Dict[str, Any] = {"actor": actor_json, "actor_id": actor_id} - - for p in plugins: - # No namespacing - just use plugin params as-is - params.update(p.params or {}) - - # Skip plugins that only provide restriction_sql (no permission rules) - if p.sql is None: - continue - - parts.append(f""" - SELECT parent, child, allow, reason, '{p.source}' AS source_plugin FROM ( - {p.sql} - ) - """.strip()) - - if not parts: - # Empty UNION that returns no rows - union_sql = "SELECT NULL parent, NULL child, NULL allow, NULL reason, 'none' source_plugin WHERE 0" - else: - union_sql = "\nUNION ALL\n".join(parts) - - return union_sql, params - - -# ----------------------------------------------- -# Core resolvers (no temp tables, no custom UDFs) -# ----------------------------------------------- - - -async def resolve_permissions_from_catalog( - db, - actor: dict | None, - plugins: Sequence[Any], - action: str, - candidate_sql: str, - candidate_params: Dict[str, Any] | None = None, - *, - implicit_deny: bool = True, -) -> List[Dict[str, Any]]: - """ - Resolve permissions by embedding the provided *candidate_sql* in a CTE. - - Expectations: - - candidate_sql SELECTs: parent TEXT, child TEXT - (Use child=NULL for parent-scoped actions like "execute-sql".) - - *db* exposes: rows = await db.execute(sql, params) - where rows is an iterable of sqlite3.Row - - plugins: hook results handled by await_me_maybe - can be sync/async, - single PermissionSQL, list, or callable returning PermissionSQL - - actor is the actor dict (or None), made available as :actor (JSON), :actor_id, and :action - - Decision policy: - 1) Specificity first: child (depth=2) > parent (depth=1) > root (depth=0) - 2) Within the same depth: deny (0) beats allow (1) - 3) If no matching rule: - - implicit_deny=True -> treat as allow=0, reason='implicit deny' - - implicit_deny=False -> allow=None, reason=None - - Returns: list of dict rows - - parent, child, allow, reason, source_plugin, depth - - resource (rendered "/parent/child" or "/parent" or "/") - """ - resolved_plugins: List[PermissionSQL] = [] - restriction_sqls: List[str] = [] - - for plugin in plugins: - if callable(plugin) and not isinstance(plugin, PermissionSQL): - resolved = plugin(action) # type: ignore[arg-type] - else: - resolved = plugin # type: ignore[assignment] - if not isinstance(resolved, PermissionSQL): - raise TypeError("Plugin providers must return PermissionSQL instances") - resolved_plugins.append(resolved) - - # Collect restriction SQL filters - if resolved.restriction_sql: - restriction_sqls.append(resolved.restriction_sql) - - union_sql, rule_params = build_rules_union(actor, resolved_plugins) - all_params = { - **(candidate_params or {}), - **rule_params, - "action": action, - } - - sql = f""" - WITH - cands AS ( - {candidate_sql} - ), - rules AS ( - {union_sql} - ), - matched AS ( - SELECT - c.parent, c.child, - r.allow, r.reason, r.source_plugin, - CASE - WHEN r.child IS NOT NULL THEN 2 -- child-level (most specific) - WHEN r.parent IS NOT NULL THEN 1 -- parent-level - ELSE 0 -- root/global - END AS depth - FROM cands c - JOIN rules r - ON (r.parent IS NULL OR r.parent = c.parent) - AND (r.child IS NULL OR r.child = c.child) - ), - ranked AS ( - SELECT *, - ROW_NUMBER() OVER ( - PARTITION BY parent, child - ORDER BY - depth DESC, -- specificity first - CASE WHEN allow=0 THEN 0 ELSE 1 END, -- then deny over allow at same depth - source_plugin -- stable tie-break - ) AS rn - FROM matched - ), - winner AS ( - SELECT parent, child, - allow, reason, source_plugin, depth - FROM ranked WHERE rn = 1 - ) - SELECT - c.parent, c.child, - COALESCE(w.allow, CASE WHEN :implicit_deny THEN 0 ELSE NULL END) AS allow, - COALESCE(w.reason, CASE WHEN :implicit_deny THEN 'implicit deny' ELSE NULL END) AS reason, - w.source_plugin, - COALESCE(w.depth, -1) AS depth, - :action AS action, - CASE - WHEN c.parent IS NULL THEN '/' - WHEN c.child IS NULL THEN '/' || c.parent - ELSE '/' || c.parent || '/' || c.child - END AS resource - FROM cands c - LEFT JOIN winner w - ON ((w.parent = c.parent) OR (w.parent IS NULL AND c.parent IS NULL)) - AND ((w.child = c.child ) OR (w.child IS NULL AND c.child IS NULL)) - ORDER BY c.parent, c.child - """ - - # If there are restriction filters, wrap the query with INTERSECT - # This ensures only resources in the restriction allowlist are returned - if restriction_sqls: - # Start with the main query, but select only parent/child for the INTERSECT - main_query_for_intersect = f""" - WITH - cands AS ( - {candidate_sql} - ), - rules AS ( - {union_sql} - ), - matched AS ( - SELECT - c.parent, c.child, - r.allow, r.reason, r.source_plugin, - CASE - WHEN r.child IS NOT NULL THEN 2 -- child-level (most specific) - WHEN r.parent IS NOT NULL THEN 1 -- parent-level - ELSE 0 -- root/global - END AS depth - FROM cands c - JOIN rules r - ON (r.parent IS NULL OR r.parent = c.parent) - AND (r.child IS NULL OR r.child = c.child) - ), - ranked AS ( - SELECT *, - ROW_NUMBER() OVER ( - PARTITION BY parent, child - ORDER BY - depth DESC, -- specificity first - CASE WHEN allow=0 THEN 0 ELSE 1 END, -- then deny over allow at same depth - source_plugin -- stable tie-break - ) AS rn - FROM matched - ), - winner AS ( - SELECT parent, child, - allow, reason, source_plugin, depth - FROM ranked WHERE rn = 1 - ), - permitted_resources AS ( - SELECT c.parent, c.child - FROM cands c - LEFT JOIN winner w - ON ((w.parent = c.parent) OR (w.parent IS NULL AND c.parent IS NULL)) - AND ((w.child = c.child ) OR (w.child IS NULL AND c.child IS NULL)) - WHERE COALESCE(w.allow, CASE WHEN :implicit_deny THEN 0 ELSE NULL END) = 1 - ) - SELECT parent, child FROM permitted_resources - """ - - # Build restriction list with INTERSECT (all must match) - # Then filter to resources that match hierarchically - # Wrap each restriction_sql in a subquery to avoid operator precedence issues - # with UNION ALL inside the restriction SQL statements - restriction_intersect = "\nINTERSECT\n".join( - f"SELECT * FROM ({sql})" for sql in restriction_sqls - ) - - # Combine: resources allowed by permissions AND in restriction allowlist - # Database-level restrictions (parent, NULL) should match all children (parent, *) - filtered_resources = f""" - WITH restriction_list AS ( - {restriction_intersect} - ), - permitted AS ( - {main_query_for_intersect} - ), - filtered AS ( - SELECT p.parent, p.child - FROM permitted p - WHERE EXISTS ( - SELECT 1 FROM restriction_list r - WHERE (r.parent = p.parent OR r.parent IS NULL) - AND (r.child = p.child OR r.child IS NULL) - ) - ) - """ - - # Now join back to get full results for only the filtered resources - sql = f""" - {filtered_resources} - , cands AS ( - {candidate_sql} - ), - rules AS ( - {union_sql} - ), - matched AS ( - SELECT - c.parent, c.child, - r.allow, r.reason, r.source_plugin, - CASE - WHEN r.child IS NOT NULL THEN 2 -- child-level (most specific) - WHEN r.parent IS NOT NULL THEN 1 -- parent-level - ELSE 0 -- root/global - END AS depth - FROM cands c - JOIN rules r - ON (r.parent IS NULL OR r.parent = c.parent) - AND (r.child IS NULL OR r.child = c.child) - ), - ranked AS ( - SELECT *, - ROW_NUMBER() OVER ( - PARTITION BY parent, child - ORDER BY - depth DESC, -- specificity first - CASE WHEN allow=0 THEN 0 ELSE 1 END, -- then deny over allow at same depth - source_plugin -- stable tie-break - ) AS rn - FROM matched - ), - winner AS ( - SELECT parent, child, - allow, reason, source_plugin, depth - FROM ranked WHERE rn = 1 - ) - SELECT - c.parent, c.child, - COALESCE(w.allow, CASE WHEN :implicit_deny THEN 0 ELSE NULL END) AS allow, - COALESCE(w.reason, CASE WHEN :implicit_deny THEN 'implicit deny' ELSE NULL END) AS reason, - w.source_plugin, - COALESCE(w.depth, -1) AS depth, - :action AS action, - CASE - WHEN c.parent IS NULL THEN '/' - WHEN c.child IS NULL THEN '/' || c.parent - ELSE '/' || c.parent || '/' || c.child - END AS resource - FROM filtered c - LEFT JOIN winner w - ON ((w.parent = c.parent) OR (w.parent IS NULL AND c.parent IS NULL)) - AND ((w.child = c.child ) OR (w.child IS NULL AND c.child IS NULL)) - ORDER BY c.parent, c.child - """ - - rows_iter: Iterable[sqlite3.Row] = await db.execute( - sql, - {**all_params, "implicit_deny": 1 if implicit_deny else 0}, - ) - return [dict(r) for r in rows_iter] - - -async def resolve_permissions_with_candidates( - db, - actor: dict | None, - plugins: Sequence[Any], - candidates: List[Tuple[str, str | None]], - action: str, - *, - implicit_deny: bool = True, -) -> List[Dict[str, Any]]: - """ - Resolve permissions without any external candidate table by embedding - the candidates as a UNION of parameterized SELECTs in a CTE. - - candidates: list of (parent, child) where child can be None for parent-scoped actions. - actor: actor dict (or None), made available as :actor (JSON), :actor_id, and :action - """ - # Build a small CTE for candidates. - cand_rows_sql: List[str] = [] - cand_params: Dict[str, Any] = {} - for i, (parent, child) in enumerate(candidates): - pkey = f"cand_p_{i}" - ckey = f"cand_c_{i}" - cand_params[pkey] = parent - cand_params[ckey] = child - cand_rows_sql.append(f"SELECT :{pkey} AS parent, :{ckey} AS child") - candidate_sql = ( - "\nUNION ALL\n".join(cand_rows_sql) - if cand_rows_sql - else "SELECT NULL AS parent, NULL AS child WHERE 0" - ) - - return await resolve_permissions_from_catalog( - db, - actor, - plugins, - action, - candidate_sql=candidate_sql, - candidate_params=cand_params, - implicit_deny=implicit_deny, - ) diff --git a/datasette/utils/shutil_backport.py b/datasette/utils/shutil_backport.py deleted file mode 100644 index d1fd1bd7..00000000 --- a/datasette/utils/shutil_backport.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Backported from Python 3.8. - -This code is licensed under the Python License: -https://github.com/python/cpython/blob/v3.8.3/LICENSE -""" - -import os -from shutil import copy, copy2, copystat, Error - - -def _copytree( - entries, - src, - dst, - symlinks, - ignore, - copy_function, - ignore_dangling_symlinks, - dirs_exist_ok=False, -): - if ignore is not None: - ignored_names = ignore(src, set(os.listdir(src))) - else: - ignored_names = set() - - os.makedirs(dst, exist_ok=dirs_exist_ok) - errors = [] - use_srcentry = copy_function is copy2 or copy_function is copy - - for srcentry in entries: - if srcentry.name in ignored_names: - continue - srcname = os.path.join(src, srcentry.name) - dstname = os.path.join(dst, srcentry.name) - srcobj = srcentry if use_srcentry else srcname - try: - if srcentry.is_symlink(): - linkto = os.readlink(srcname) - if symlinks: - os.symlink(linkto, dstname) - copystat(srcobj, dstname, follow_symlinks=not symlinks) - else: - if not os.path.exists(linkto) and ignore_dangling_symlinks: - continue - if srcentry.is_dir(): - copytree( - srcobj, - dstname, - symlinks, - ignore, - copy_function, - dirs_exist_ok=dirs_exist_ok, - ) - else: - copy_function(srcobj, dstname) - elif srcentry.is_dir(): - copytree( - srcobj, - dstname, - symlinks, - ignore, - copy_function, - dirs_exist_ok=dirs_exist_ok, - ) - else: - copy_function(srcentry, dstname) - except Error as err: - errors.extend(err.args[0]) - except OSError as why: - errors.append((srcname, dstname, str(why))) - try: - copystat(src, dst) - except OSError as why: - # Copying file access times may fail on Windows - if getattr(why, "winerror", None) is None: - errors.append((src, dst, str(why))) - if errors: - raise Error(errors) - return dst - - -def copytree( - src, - dst, - symlinks=False, - ignore=None, - copy_function=copy2, - ignore_dangling_symlinks=False, - dirs_exist_ok=False, -): - with os.scandir(src) as entries: - return _copytree( - entries=entries, - src=src, - dst=dst, - symlinks=symlinks, - ignore=ignore, - copy_function=copy_function, - ignore_dangling_symlinks=ignore_dangling_symlinks, - dirs_exist_ok=dirs_exist_ok, - ) diff --git a/datasette/utils/sql_analysis.py b/datasette/utils/sql_analysis.py deleted file mode 100644 index b5317b62..00000000 --- a/datasette/utils/sql_analysis.py +++ /dev/null @@ -1,99 +0,0 @@ -from dataclasses import dataclass -from typing import Literal - -from datasette.utils.sqlite import sqlite3 - -SQLTableOperation = Literal["read", "insert", "update", "delete"] - - -@dataclass(frozen=True) -class SQLTableAccess: - operation: SQLTableOperation - database: str | None - table: str - sqlite_schema: str | None - columns: tuple[str, ...] = () - source: str | None = None - - -@dataclass(frozen=True) -class SQLAnalysis: - table_accesses: tuple[SQLTableAccess, ...] - - -_ACTION_TO_OPERATION: dict[int, SQLTableOperation] = { - sqlite3.SQLITE_READ: "read", - sqlite3.SQLITE_INSERT: "insert", - sqlite3.SQLITE_UPDATE: "update", - sqlite3.SQLITE_DELETE: "delete", -} - - -def analyze_sql_tables( - conn, - sql: str, - params=None, - *, - database_name: str | None = None, - schema_to_database: dict[str, str] | None = None, -) -> SQLAnalysis: - """ - Return tables accessed by a SQL statement according to SQLite's authorizer. - - This function is synchronous and connection-based. It temporarily installs a - SQLite authorizer, prepares ``EXPLAIN ``, and returns the table access - callbacks observed while SQLite compiles the statement. - """ - accesses: dict[ - tuple[SQLTableOperation, str | None, str, str | None, str | None], set[str] - ] = {} - - def database_for_schema(sqlite_schema): - if schema_to_database and sqlite_schema in schema_to_database: - return schema_to_database[sqlite_schema] - if sqlite_schema == "main" and database_name is not None: - return database_name - return sqlite_schema - - def authorizer(action, arg1, arg2, sqlite_schema, source): - operation = _ACTION_TO_OPERATION.get(action) - if operation is None or arg1 is None: - return sqlite3.SQLITE_OK - - key = ( - operation, - database_for_schema(sqlite_schema), - arg1, - sqlite_schema, - source, - ) - columns = accesses.setdefault(key, set()) - if operation in ("read", "update") and arg2 is not None: - columns.add(arg2) - return sqlite3.SQLITE_OK - - conn.set_authorizer(authorizer) - try: - conn.execute("EXPLAIN " + sql, params if params is not None else {}).fetchall() - finally: - conn.set_authorizer(None) - - return SQLAnalysis( - table_accesses=tuple( - SQLTableAccess( - operation=operation, - database=database, - table=table, - sqlite_schema=sqlite_schema, - columns=tuple(sorted(columns)), - source=source, - ) - for ( - operation, - database, - table, - sqlite_schema, - source, - ), columns in accesses.items() - ) - ) diff --git a/datasette/utils/sqlite.py b/datasette/utils/sqlite.py deleted file mode 100644 index d0a2d783..00000000 --- a/datasette/utils/sqlite.py +++ /dev/null @@ -1,40 +0,0 @@ -using_pysqlite3 = False -try: - import pysqlite3 as sqlite3 - - using_pysqlite3 = True -except ImportError: - import sqlite3 - -if hasattr(sqlite3, "enable_callback_tracebacks"): - sqlite3.enable_callback_tracebacks(True) - -_cached_sqlite_version = None - - -def sqlite_version(): - global _cached_sqlite_version - if _cached_sqlite_version is None: - _cached_sqlite_version = _sqlite_version() - return _cached_sqlite_version - - -def _sqlite_version(): - conn = sqlite3.connect(":memory:") - try: - return tuple( - map( - int, - conn.execute("select sqlite_version()").fetchone()[0].split("."), - ) - ) - finally: - conn.close() - - -def supports_table_xinfo(): - return sqlite_version() >= (3, 26, 0) - - -def supports_generated_columns(): - return sqlite_version() >= (3, 31, 0) diff --git a/datasette/utils/testing.py b/datasette/utils/testing.py deleted file mode 100644 index de7e94af..00000000 --- a/datasette/utils/testing.py +++ /dev/null @@ -1,174 +0,0 @@ -from asgiref.sync import async_to_sync -from urllib.parse import urlencode -import json - -# These wrapper classes pre-date the introduction of -# datasette.client and httpx to Datasette. They could -# be removed if the Datasette tests are modified to -# call datasette.client directly. - - -class TestResponse: - def __init__(self, httpx_response): - self.httpx_response = httpx_response - - @property - def status(self): - return self.httpx_response.status_code - - # Supports both for test-writing convenience - @property - def status_code(self): - return self.status - - @property - def headers(self): - return self.httpx_response.headers - - @property - def body(self): - return self.httpx_response.content - - @property - def content(self): - return self.body - - @property - def cookies(self): - return dict(self.httpx_response.cookies) - - @property - def json(self): - return json.loads(self.text) - - @property - def text(self): - return self.body.decode("utf8") - - -class TestClient: - max_redirects = 5 - - def __init__(self, ds): - self.ds = ds - - def actor_cookie(self, actor): - return self.ds.sign({"a": actor}, "actor") - - @async_to_sync - async def get( - self, - path, - follow_redirects=False, - redirect_count=0, - method="GET", - params=None, - cookies=None, - if_none_match=None, - headers=None, - ): - if params: - path += "?" + urlencode(params, doseq=True) - return await self._request( - path=path, - follow_redirects=follow_redirects, - redirect_count=redirect_count, - method=method, - cookies=cookies, - if_none_match=if_none_match, - headers=headers, - ) - - @async_to_sync - async def post( - self, - path, - post_data=None, - body=None, - follow_redirects=False, - redirect_count=0, - content_type="application/x-www-form-urlencoded", - cookies=None, - headers=None, - csrftoken_from=None, - ): - cookies = cookies or {} - post_data = post_data or {} - assert not (post_data and body), "Provide one or other of body= or post_data=" - # csrftoken_from is accepted for backward compatibility but is now a no-op. - # Datasette no longer uses CSRF tokens - see CrossOriginProtectionMiddleware. - if post_data: - body = urlencode(post_data, doseq=True) - return await self._request( - path=path, - follow_redirects=follow_redirects, - redirect_count=redirect_count, - method="POST", - cookies=cookies, - headers=headers, - post_body=body, - content_type=content_type, - ) - - @async_to_sync - async def request( - self, - path, - follow_redirects=True, - redirect_count=0, - method="GET", - cookies=None, - headers=None, - post_body=None, - content_type=None, - if_none_match=None, - ): - return await self._request( - path, - follow_redirects=follow_redirects, - redirect_count=redirect_count, - method=method, - cookies=cookies, - headers=headers, - post_body=post_body, - content_type=content_type, - if_none_match=if_none_match, - ) - - async def _request( - self, - path, - follow_redirects=True, - redirect_count=0, - method="GET", - cookies=None, - headers=None, - post_body=None, - content_type=None, - if_none_match=None, - ): - await self.ds.invoke_startup() - headers = headers or {} - if content_type: - headers["content-type"] = content_type - if if_none_match: - headers["if-none-match"] = if_none_match - httpx_response = await self.ds.client.request( - method, - path, - follow_redirects=follow_redirects, - avoid_path_rewrites=True, - cookies=cookies, - headers=headers, - content=post_body, - ) - response = TestResponse(httpx_response) - if follow_redirects and response.status in (301, 302): - assert ( - redirect_count < self.max_redirects - ), f"Redirected {redirect_count} times, max_redirects={self.max_redirects}" - location = response.headers["Location"] - return await self._request( - location, follow_redirects=True, redirect_count=redirect_count + 1 - ) - return response diff --git a/datasette/version.py b/datasette/version.py index 494d2dc0..b4033f08 100644 --- a/datasette/version.py +++ b/datasette/version.py @@ -1,2 +1,6 @@ -__version__ = "1.0a30" +from ._version import get_versions + +__version__ = get_versions()['version'] +del get_versions + __version_info__ = tuple(__version__.split(".")) diff --git a/datasette/views/__init__.py b/datasette/views/__init__.py index 88106737..e69de29b 100644 --- a/datasette/views/__init__.py +++ b/datasette/views/__init__.py @@ -1,2 +0,0 @@ -class Context: - "Base class for all documented contexts" diff --git a/datasette/views/base.py b/datasette/views/base.py index e4c1c738..ae1df581 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -1,322 +1,371 @@ import asyncio import csv -import hashlib -import sys -import textwrap +import json +import re +import sqlite3 import time import urllib -from markupsafe import escape +import pint +from sanic import response +from sanic.exceptions import NotFound +from sanic.views import HTTPMethodView -from datasette.database import QueryInterrupted -from datasette.utils.asgi import Request +from datasette import __version__ from datasette.utils import ( - add_cors_headers, - await_me_maybe, - EscapeHtmlWriter, + CustomJSONEncoder, + InterruptedError, InvalidSql, - LimitedWriter, - call_with_supported_arguments, + Querystring, path_from_row_pks, path_with_added_args, - path_with_removed_args, path_with_format, - sqlite3, -) -from datasette.utils.asgi import ( - AsgiStream, - NotFound, - Response, - BadRequest, + resolve_table_and_format, + to_css_class ) +ureg = pint.UnitRegistry() + +HASH_LENGTH = 7 + class DatasetteError(Exception): - def __init__( - self, - message, - title=None, - error_dict=None, - status=500, - template=None, - message_is_html=False, - ): + + def __init__(self, message, title=None, error_dict=None, status=500, template=None, messagge_is_html=False): self.message = message self.title = title self.error_dict = error_dict or {} self.status = status - self.message_is_html = message_is_html + self.messagge_is_html = messagge_is_html -class View: - async def head(self, request, datasette): - if not hasattr(self, "get"): - return await self.method_not_allowed(request) - response = await self.get(request, datasette) - response.body = "" - return response +class RenderMixin(HTTPMethodView): - async def method_not_allowed(self, request): - if ( - request.path.endswith(".json") - or request.headers.get("content-type") == "application/json" - ): - response = Response.json( - {"ok": False, "error": "Method not allowed"}, status=405 + def render(self, templates, **context): + template = self.jinja_env.select_template(templates) + select_templates = [ + "{}{}".format("*" if template_name == template.name else "", template_name) + for template_name in templates + ] + return response.html( + template.render( + { + **context, + **{ + "app_css_hash": self.ds.app_css_hash(), + "select_templates": select_templates, + "zip": zip, + } + } ) - else: - response = Response.text("Method not allowed", status=405) - return response - - async def options(self, request, datasette): - response = Response.text("ok") - response.headers["allow"] = ", ".join( - method.upper() - for method in ("head", "get", "post", "put", "patch", "delete") - if hasattr(self, method) ) - return response - - async def __call__(self, request, datasette): - try: - handler = getattr(self, request.method.lower()) - except AttributeError: - return await self.method_not_allowed(request) - return await handler(request, datasette) -class BaseView: - ds = None - has_json_alternate = True +class BaseView(RenderMixin): + re_named_parameter = re.compile(":([a-zA-Z0-9_]+)") def __init__(self, datasette): self.ds = datasette + self.files = datasette.files + self.jinja_env = datasette.jinja_env + self.executor = datasette.executor + self.page_size = datasette.page_size + self.max_returned_rows = datasette.max_returned_rows - async def head(self, *args, **kwargs): - response = await self.get(*args, **kwargs) - response.body = b"" - return response - - async def method_not_allowed(self, request): - if ( - request.path.endswith(".json") - or request.headers.get("content-type") == "application/json" - ): - response = Response.json( - {"ok": False, "error": "Method not allowed"}, status=405 - ) - else: - response = Response.text("Method not allowed", status=405) - return response - - async def options(self, request, *args, **kwargs): - return Response.text("ok") - - async def get(self, request, *args, **kwargs): - return await self.method_not_allowed(request) - - async def post(self, request, *args, **kwargs): - return await self.method_not_allowed(request) - - async def put(self, request, *args, **kwargs): - return await self.method_not_allowed(request) - - async def patch(self, request, *args, **kwargs): - return await self.method_not_allowed(request) - - async def delete(self, request, *args, **kwargs): - return await self.method_not_allowed(request) - - async def dispatch_request(self, request): - if self.ds: - await self.ds.refresh_schemas() - handler = getattr(self, request.method.lower(), None) - response = await handler(request) - if self.ds.cors: - add_cors_headers(response.headers) - return response - - async def render(self, templates, request, context=None): - context = context or {} - environment = self.ds.get_jinja_environment(request) - template = environment.select_template(templates) - template_context = { - **context, - **{ - "select_templates": [ - f"{'*' if template_name == template.name else ''}{template_name}" - for template_name in templates - ], - }, - } - headers = {} - if self.has_json_alternate: - alternate_url_json = self.ds.absolute_url( - request, - self.ds.urls.path(path_with_format(request=request, format="json")), - ) - template_context["alternate_url_json"] = alternate_url_json - headers.update( - { - "Link": '<{}>; rel="alternate"; type="application/json+datasette"'.format( - alternate_url_json - ) - } - ) - return Response.html( - await self.ds.render_template( - template, - template_context, - request=request, - view_name=self.name, - ), - headers=headers, + def table_metadata(self, database, table): + "Fetch table-specific metadata." + return self.ds.metadata.get("databases", {}).get(database, {}).get( + "tables", {} + ).get( + table, {} ) - @classmethod - def as_view(cls, *class_args, **class_kwargs): - async def view(request, send): - self = view.view_class(*class_args, **class_kwargs) - return await self.dispatch_request(request) - - view.view_class = cls - view.__doc__ = cls.__doc__ - view.__module__ = cls.__module__ - view.__name__ = cls.__name__ - return view - - -class DataView(BaseView): - name = "" - - def redirect(self, request, path, forward_querystring=True, remove_args=None): - if request.query_string and "?" not in path and forward_querystring: - path = f"{path}?{request.query_string}" - if remove_args: - path = path_with_removed_args(request, remove_args, path=path) - r = Response.redirect(path) - r.headers["Link"] = f"<{path}>; rel=preload" + def options(self, request, *args, **kwargs): + r = response.text("ok") if self.ds.cors: - add_cors_headers(r.headers) + r.headers["Access-Control-Allow-Origin"] = "*" return r - async def data(self, request): - raise NotImplementedError + def redirect(self, qs, path, forward_querystring=True): + if qs.data and "?" not in qs.path and forward_querystring: + path = "{}?{}".format(path, str(qs)) + r = response.redirect(path) + r.headers["Link"] = "<{}>; rel=preload".format(path) + if self.ds.cors: + r.headers["Access-Control-Allow-Origin"] = "*" + return r - async def as_csv(self, request, database): - return await stream_csv(self.ds, self.data, request, database) - - async def get(self, request): - db = await self.ds.resolve_database(request) - database = db.name - database_route = db.route - - _format = request.url_vars["format"] - data_kwargs = {} - - if _format == "csv": - return await self.as_csv(request, database_route) - - if _format is None: - # HTML views default to expanding all foreign key labels - data_kwargs["default_labels"] = True - - extra_template_data = {} - start = time.perf_counter() - status_code = None - templates = [] + def resolve_db_name(self, db_name, **kwargs): + databases = self.ds.inspect() + hash = None + name = None + if "-" in db_name: + # Might be name-and-hash, or might just be + # a name with a hyphen in it + name, hash = db_name.rsplit("-", 1) + if name not in databases: + # Try the whole name + name = db_name + hash = None + else: + name = db_name + # Verify the hash try: - response_or_template_contexts = await self.data(request, **data_kwargs) - if isinstance(response_or_template_contexts, Response): + info = databases[name] + except KeyError: + raise NotFound("Database not found: {}".format(name)) + + expected = info["hash"][:HASH_LENGTH] + if expected != hash: + if "table_and_format" in kwargs: + table, _format = resolve_table_and_format( + table_and_format=urllib.parse.unquote_plus( + kwargs["table_and_format"] + ), + table_exists=lambda t: self.ds.table_exists(name, t) + ) + kwargs["table"] = table + if _format: + kwargs["as_format"] = ".{}".format(_format) + should_redirect = "/{}-{}".format(name, expected) + if "table" in kwargs: + should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"]) + if "pk_path" in kwargs: + should_redirect += "/" + kwargs["pk_path"] + if "as_format" in kwargs: + should_redirect += kwargs["as_format"] + if "as_db" in kwargs: + should_redirect += kwargs["as_db"] + return name, expected, should_redirect + + return name, expected, None + + def get_templates(self, database, table=None): + assert NotImplemented + + async def asgi_get(self, receive, send): + kwargs = self.scope["url_route"]["kwargs"] + db_name = kwargs.pop("db_name") + name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) + qs = Querystring( + self.scope["path"], self.scope["query_string"].decode("utf-8") + ) + if should_redirect: + response = self.redirect(qs, should_redirect) + else: + response = await self.view_get(qs, name, hash, **kwargs) + # Send response over send() channel + await send({ + 'type': 'http.response.start', + 'status': 200, + 'headers': [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in response.headers.items() + ], + }) + await send({ + 'type': 'http.response.body', + 'body': response.body, + }) + + async def get(self, request, db_name, **kwargs): + name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) + qs = Querystring(request.path, request.query_string) + if should_redirect: + return self.redirect(qs, should_redirect) + return await self.view_get(qs, name, hash, **kwargs) + + async def as_csv(self, qs, name, hash, **kwargs): + try: + response_or_template_contexts = await self.data( + qs, name, hash, **kwargs + ) + if isinstance(response_or_template_contexts, response.HTTPResponse): return response_or_template_contexts - # If it has four items, it includes an HTTP status code - if len(response_or_template_contexts) == 4: - ( - data, - extra_template_data, - templates, - status_code, - ) = response_or_template_contexts + else: data, extra_template_data, templates = response_or_template_contexts - except QueryInterrupted as ex: - raise DatasetteError( - textwrap.dedent(""" -

SQL query took too long. The time limit is controlled by the - sql_time_limit_ms - configuration option.

- - - """.format(escape(ex.sql))).strip(), - title="SQL Interrupted", - status=400, - message_is_html=True, - ) except (sqlite3.OperationalError, InvalidSql) as e: raise DatasetteError(str(e), title="Invalid SQL", status=400) - except sqlite3.OperationalError as e: + except (sqlite3.OperationalError) as e: + raise DatasetteError(str(e)) + + except DatasetteError: + raise + # Convert rows and columns to CSV + headings = data["columns"] + # if there are columns_expanded we need to add additional headings + columns_expanded = set(data.get("columns_expanded") or []) + if columns_expanded: + headings = [] + for column in data["columns"]: + headings.append(column) + if column in columns_expanded: + headings.append("{}_label".format(column)) + + async def stream_fn(r): + writer = csv.writer(r) + writer.writerow(headings) + for row in data["rows"]: + if not columns_expanded: + # Simple path + writer.writerow(row) + else: + # Look for {"value": "label": } dicts and expand + new_row = [] + for cell in row: + if isinstance(cell, dict): + new_row.append(cell["value"]) + new_row.append(cell["label"]) + else: + new_row.append(cell) + writer.writerow(new_row) + + content_type = "text/plain; charset=utf-8" + headers = {} + if qs.first_or_none("_dl"): + content_type = "text/csv; charset=utf-8" + disposition = 'attachment; filename="{}.csv"'.format( + kwargs.get('table', name) + ) + headers["Content-Disposition"] = disposition + + return response.stream( + stream_fn, + headers=headers, + content_type=content_type + ) + + async def view_get(self, qs, name, hash, **kwargs): + # If ?_format= is provided, use that as the format + _format = qs.first_or_none("_format") + if not _format: + _format = (kwargs.pop("as_format", None) or "").lstrip(".") + if "table_and_format" in kwargs: + table, _ext_format = resolve_table_and_format( + table_and_format=urllib.parse.unquote_plus( + kwargs["table_and_format"] + ), + table_exists=lambda t: self.ds.table_exists(name, t) + ) + _format = _format or _ext_format + kwargs["table"] = table + del kwargs["table_and_format"] + + if _format == "csv": + return await self.as_csv(qs, name, hash, **kwargs) + + if _format is None: + # HTML views default to expanding all forign key labels + kwargs['default_labels'] = True + + extra_template_data = {} + start = time.time() + status_code = 200 + templates = [] + try: + response_or_template_contexts = await self.data( + qs, name, hash, **kwargs + ) + if isinstance(response_or_template_contexts, response.HTTPResponse): + return response_or_template_contexts + + else: + data, extra_template_data, templates = response_or_template_contexts + except InterruptedError as e: + raise DatasetteError(""" + SQL query took too long. The time limit is controlled by the + sql_time_limit_ms + configuration option. + """, title="SQL Interrupted", status=400, messagge_is_html=True) + except (sqlite3.OperationalError, InvalidSql) as e: + raise DatasetteError(str(e), title="Invalid SQL", status=400) + + except (sqlite3.OperationalError) as e: raise DatasetteError(str(e)) except DatasetteError: raise - end = time.perf_counter() + end = time.time() data["query_ms"] = (end - start) * 1000 - - # Special case for .jsono extension - redirect to _shape=objects - if _format == "jsono": - return self.redirect( - request, - path_with_added_args( - request, - {"_shape": "objects"}, - path=request.path.rsplit(".jsono", 1)[0] + ".json", - ), - forward_querystring=False, - ) - - if _format in self.ds.renderers.keys(): - # Dispatch request to the correct output format renderer - # (CSV is not handled here due to streaming) - result = call_with_supported_arguments( - self.ds.renderers[_format][0], - datasette=self.ds, - columns=data.get("columns") or [], - rows=data.get("rows") or [], - sql=data.get("query", {}).get("sql", None), - query_name=data.get("query_name"), - database=database, - table=data.get("table"), - request=request, - view_name=self.name, - truncated=False, # TODO: support this - error=data.get("error"), - # These will be deprecated in Datasette 1.0: - args=request.args, - data=data, - ) - if asyncio.iscoroutine(result): - result = await result - if result is None: - raise NotFound("No data") - if isinstance(result, dict): - r = Response( - body=result.get("body"), - status=result.get("status_code", status_code or 200), - content_type=result.get("content_type", "text/plain"), - headers=result.get("headers"), + for key in ("source", "source_url", "license", "license_url"): + value = self.ds.metadata.get(key) + if value: + data[key] = value + if _format in ("json", "jsono"): + # Special case for .jsono extension - redirect to _shape=objects + if _format == "jsono": + return self.redirect( + qs, + path_with_added_args( + qs, + {"_shape": "objects"}, + path=qs.path.rsplit(".jsono", 1)[0] + ".json", + ), + forward_querystring=False, ) - elif isinstance(result, Response): - r = result - if status_code is not None: - # Over-ride the status code - r.status = status_code + + # Handle the _json= parameter which may modify data["rows"] + json_cols = [] + json_cols = qs.getlist("_json") + if json_cols and "rows" in data and "columns" in data: + data["rows"] = convert_specific_columns_to_json( + data["rows"], data["columns"], json_cols, + ) + + # Deal with the _shape option + shape = qs.first_or_none("_shape") or "arrays" + if shape == "arrayfirst": + data = [row[0] for row in data["rows"]] + elif shape in ("objects", "object", "array"): + columns = data.get("columns") + rows = data.get("rows") + if rows and columns: + data["rows"] = [dict(zip(columns, row)) for row in rows] + if shape == "object": + error = None + if "primary_keys" not in data: + error = "_shape=object is only available on tables" + else: + pks = data["primary_keys"] + if not pks: + error = "_shape=object not available for tables with no primary keys" + else: + object_rows = {} + for row in data["rows"]: + pk_string = path_from_row_pks(row, pks, not pks) + object_rows[pk_string] = row + data = object_rows + if error: + data = { + "ok": False, + "error": error, + "database": name, + "database_hash": hash, + } + elif shape == "array": + data = data["rows"] + elif shape == "arrays": + pass else: - assert False, f"{result} should be dict or Response" + status_code = 400 + data = { + "ok": False, + "error": "Invalid _shape: {}".format(shape), + "status": 400, + "title": None, + } + headers = {} + if self.ds.cors: + headers["Access-Control-Allow-Origin"] = "*" + r = response.HTTPResponse( + json.dumps(data, cls=CustomJSONEncoder), + status=status_code, + content_type="application/json", + headers=headers, + ) else: extras = {} if callable(extra_template_data): @@ -325,246 +374,109 @@ class DataView(BaseView): extras = await extras else: extras = extra_template_data - url_labels_extra = {} - if data.get("expandable_columns"): - url_labels_extra = {"_labels": "on"} - - renderers = {} - for key, (_, can_render) in self.ds.renderers.items(): - it_can_render = call_with_supported_arguments( - can_render, - datasette=self.ds, - columns=data.get("columns") or [], - rows=data.get("rows") or [], - sql=data.get("query", {}).get("sql", None), - query_name=data.get("query_name"), - database=database, - table=data.get("table"), - request=request, - view_name=self.name, - ) - it_can_render = await await_me_maybe(it_can_render) - if it_can_render: - renderers[key] = self.ds.urls.path( - path_with_format( - request=request, format=key, extra_qs={**url_labels_extra} - ) - ) - - url_csv_args = {"_size": "max", **url_labels_extra} - url_csv = self.ds.urls.path( - path_with_format(request=request, format="csv", extra_qs=url_csv_args) - ) - url_csv_path = url_csv.split("?")[0] context = { **data, **extras, **{ - "renderers": renderers, - "url_csv": url_csv, - "url_csv_path": url_csv_path, - "url_csv_hidden_args": [ - (key, value) - for key, value in urllib.parse.parse_qsl(request.query_string) - if key not in ("_labels", "_facet", "_size") - ] - + [("_size", "max")], - "settings": self.ds.settings_dict(), - }, + "url_json": path_with_format(qs, "json"), + "url_csv": path_with_format(qs, "csv", { + "_size": "max" + }), + "url_csv_dl": path_with_format(qs, "csv", { + "_dl": "1", + "_size": "max" + }), + "extra_css_urls": self.ds.extra_css_urls(), + "extra_js_urls": self.ds.extra_js_urls(), + "datasette_version": __version__, + } } if "metadata" not in context: - context["metadata"] = await self.ds.get_instance_metadata() - r = await self.render(templates, request=request, context=context) - if status_code is not None: - r.status = status_code - - ttl = request.args.get("_ttl", None) - if ttl is None or not ttl.isdigit(): - ttl = self.ds.setting("default_cache_ttl") - - return self.set_response_headers(r, ttl) - - def set_response_headers(self, response, ttl): + context["metadata"] = self.ds.metadata + r = self.render(templates, **context) + r.status = status_code # Set far-future cache expiry - if self.ds.cache_headers and response.status == 200: - ttl = int(ttl) - if ttl == 0: - ttl_header = "no-cache" + if self.ds.cache_headers: + ttl = qs.first_or_none("_ttl") + if ttl is None or not ttl.isdigit(): + ttl = self.ds.config["default_cache_ttl"] else: - ttl_header = f"max-age={ttl}" - response.headers["Cache-Control"] = ttl_header - response.headers["Referrer-Policy"] = "no-referrer" - if self.ds.cors: - add_cors_headers(response.headers) - return response + ttl = int(ttl) + if ttl == 0: + ttl_header = 'no-cache' + else: + ttl_header = 'max-age={}'.format(ttl) + r.headers["Cache-Control"] = ttl_header + r.headers["Referrer-Policy"] = "no-referrer" + return r + async def custom_sql( + self, qs, name, hash, sql, editable=True, canned_query=None + ): + params = qs.first_dict() + if "sql" in params: + params.pop("sql") + if "_shape" in params: + params.pop("_shape") + # Extract any :named parameters + named_parameters = self.re_named_parameter.findall(sql) + named_parameter_values = { + named_parameter: params.get(named_parameter) or "" + for named_parameter in named_parameters + } -def _error(messages, status=400): - return Response.json({"ok": False, "errors": messages}, status=status) + # Set to blank string if missing from params + for named_parameter in named_parameters: + if named_parameter not in params: + params[named_parameter] = "" - -async def stream_csv(datasette, fetch_data, request, database): - kwargs = {} - stream = request.args.get("_stream") - # Do not calculate facets or counts: - extra_parameters = [ - "{}=1".format(key) - for key in ("_nofacet", "_nocount") - if not request.args.get(key) - ] - if extra_parameters: - # Replace request object with a new one with modified scope - if not request.query_string: - new_query_string = "&".join(extra_parameters) - else: - new_query_string = request.query_string + "&" + "&".join(extra_parameters) - new_scope = dict(request.scope, query_string=new_query_string.encode("latin-1")) - receive = request.receive - request = Request(new_scope, receive) - if stream: - # Some quick soundness checks - if not datasette.setting("allow_csv_stream"): - raise BadRequest("CSV streaming is disabled") - if request.args.get("_next"): - raise BadRequest("_next not allowed for CSV streaming") - kwargs["_size"] = "max" - # Fetch the first page - try: - response_or_template_contexts = await fetch_data(request) - if isinstance(response_or_template_contexts, Response): - return response_or_template_contexts - elif len(response_or_template_contexts) == 4: - data, _, _, _ = response_or_template_contexts - else: - data, _, _ = response_or_template_contexts - except (sqlite3.OperationalError, InvalidSql) as e: - raise DatasetteError(str(e), title="Invalid SQL", status=400) - - except sqlite3.OperationalError as e: - raise DatasetteError(str(e)) - - except DatasetteError: - raise - - # Convert rows and columns to CSV - headings = data["columns"] - # if there are expanded_columns we need to add additional headings - expanded_columns = set(data.get("expanded_columns") or []) - if expanded_columns: - headings = [] - for column in data["columns"]: - headings.append(column) - if column in expanded_columns: - headings.append(f"{column}_label") - - content_type = "text/plain; charset=utf-8" - preamble = "" - postamble = "" - - trace = request.args.get("_trace") - if trace: - content_type = "text/html; charset=utf-8" - preamble = ( - "CSV debug" - '" + columns = [r[0] for r in results.description] - async def stream_fn(r): - nonlocal data, trace - limited_writer = LimitedWriter(r, datasette.setting("max_csv_mb")) - if trace: - await limited_writer.write(preamble) - writer = csv.writer(EscapeHtmlWriter(limited_writer)) - else: - writer = csv.writer(limited_writer) - first = True - next = None - while first or (next and stream): - try: - kwargs = {} - if next: - kwargs["_next"] = next - if not first: - data, _, _ = await fetch_data(request, **kwargs) - if first: - if request.args.get("_header") != "off": - await writer.writerow(headings) - first = False - next = data.get("next") - for row in data["rows"]: - if any(isinstance(r, bytes) for r in row): - new_row = [] - for column, cell in zip(headings, row): - if isinstance(cell, bytes): - # If this is a table page, use .urls.row_blob() - if data.get("table"): - pks = data.get("primary_keys") or [] - cell = datasette.absolute_url( - request, - datasette.urls.row_blob( - database, - data["table"], - path_from_row_pks(row, pks, not pks), - column, - ), - ) - else: - # Otherwise generate URL for this query - url = datasette.absolute_url( - request, - path_with_format( - request=request, - format="blob", - extra_qs={ - "_blob_column": column, - "_blob_hash": hashlib.sha256( - cell - ).hexdigest(), - }, - replace_format="csv", - ), - ) - cell = url.replace("&_nocount=1", "").replace( - "&_nofacet=1", "" - ) - new_row.append(cell) - row = new_row - if not expanded_columns: - # Simple path - await writer.writerow(row) - else: - # Look for {"value": "label": } dicts and expand - new_row = [] - for heading, cell in zip(data["columns"], row): - if heading in expanded_columns: - if cell is None: - new_row.extend(("", "")) - else: - if not isinstance(cell, dict): - new_row.extend((cell, "")) - else: - new_row.append(cell["value"]) - new_row.append(cell["label"]) - else: - new_row.append(cell) - await writer.writerow(new_row) - except Exception as ex: - sys.stderr.write("Caught this error: {}\n".format(ex)) - sys.stderr.flush() - await r.write(str(ex)) - return - await limited_writer.write(postamble) + templates = ["query-{}.html".format(to_css_class(name)), "query.html"] + if canned_query: + templates.insert( + 0, + "query-{}-{}.html".format( + to_css_class(name), to_css_class(canned_query) + ), + ) - headers = {} - if datasette.cors: - add_cors_headers(headers) - if request.args.get("_dl", None): - if not trace: - content_type = "text/csv; charset=utf-8" - disposition = 'attachment; filename="{}.csv"'.format( - request.url_vars.get("table", database) - ) - headers["content-disposition"] = disposition + return { + "database": name, + "rows": results.rows, + "truncated": results.truncated, + "columns": columns, + "query": {"sql": sql, "params": params}, + }, { + "database_hash": hash, + "custom_sql": True, + "named_parameter_values": named_parameter_values, + "editable": editable, + "canned_query": canned_query, + "config": self.ds.config, + }, templates - return AsgiStream(stream_fn, headers=headers, content_type=content_type) + +def convert_specific_columns_to_json(rows, columns, json_cols): + json_cols = set(json_cols) + if not json_cols.intersection(columns): + return rows + new_rows = [] + for row in rows: + new_row = [] + for value, column in zip(row, columns): + if column in json_cols: + try: + value = json.loads(value) + except (TypeError, ValueError) as e: + print(e) + pass + new_row.append(value) + new_rows.append(new_row) + return new_rows diff --git a/datasette/views/database.py b/datasette/views/database.py index b558b002..b52b20d6 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1,1290 +1,55 @@ -from dataclasses import dataclass, field -from urllib.parse import parse_qsl, urlencode -import asyncio -import hashlib -import itertools -import json -import markupsafe import os -import re -import sqlite_utils -import textwrap -from datasette.events import AlterTableEvent, CreateTableEvent, InsertRowsEvent -from datasette.database import QueryInterrupted -from datasette.resources import DatabaseResource, QueryResource -from datasette.stored_queries import stored_query_to_dict -from datasette.utils import ( - add_cors_headers, - await_me_maybe, - call_with_supported_arguments, - named_parameters as derive_named_parameters, - format_bytes, - make_slot_function, - tilde_decode, - to_css_class, - validate_sql_select, - is_url, - path_with_added_args, - path_with_format, - path_with_removed_args, - sqlite3, - truncate_url, - InvalidSql, -) -from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden -from datasette.plugins import pm +from sanic import response -from .base import BaseView, DatasetteError, View, _error, stream_csv -from .query_helpers import _ensure_stored_query_execution_permissions, _table_columns -from . import Context +from datasette.utils import to_css_class, validate_sql_select + +from .base import BaseView, DatasetteError -class DatabaseView(View): - async def get(self, request, datasette): - format_ = request.url_vars.get("format") or "html" +class DatabaseView(BaseView): - await datasette.refresh_schemas() + async def data(self, qs, name, hash, default_labels=False): + if qs.first_or_none("sql"): + if not self.ds.config["allow_sql"]: + raise DatasetteError("sql= is not allowed", status=400) + sql = qs.first("sql") + validate_sql_select(sql) + return await self.custom_sql(qs, name, hash, sql) - db = await datasette.resolve_database(request) - database = db.name - - visible, private = await datasette.check_visibility( - request.actor, - action="view-database", - resource=DatabaseResource(database=database), - ) - if not visible: - raise Forbidden("You do not have permission to view this database") - - sql = (request.args.get("sql") or "").strip() - if sql: - redirect_url = "/" + request.url_vars.get("database") + "/-/query" - if request.url_vars.get("format"): - redirect_url += "." + request.url_vars.get("format") - redirect_url += "?" + request.query_string - response = Response.redirect(redirect_url) - if datasette.cors: - add_cors_headers(response.headers) - return response - - if format_ not in ("html", "json"): - raise NotFound("Invalid format: {}".format(format_)) - - metadata = await datasette.get_database_metadata(database) - - # Get all tables/views this actor can see in bulk with private flag - allowed_tables_page = await datasette.allowed_resources( - "view-table", - request.actor, - parent=database, - include_is_private=True, - limit=1000, - ) - # Create lookup dict for quick access - allowed_dict = {r.child: r for r in allowed_tables_page.resources} - - # Filter to just views - view_names_set = set(await db.view_names()) - sql_views = [ - {"name": name, "private": allowed_dict[name].private} - for name in allowed_dict - if name in view_names_set - ] - - tables = await get_tables(datasette, request, db, allowed_dict) - - queries_page = await datasette.list_queries( - database, - actor=request.actor, - limit=5, - include_private=True, - ) - stored_queries = queries_page.queries - queries_more = queries_page.has_more - queries_count = ( - await datasette.count_queries(database, actor=request.actor) - if queries_more - else len(stored_queries) - ) - - async def database_actions(): - links = [] - for hook in pm.hook.database_actions( - datasette=datasette, - database=database, - actor=request.actor, - request=request, - ): - extra_links = await await_me_maybe(hook) - if extra_links: - links.extend(extra_links) - return links - - attached_databases = [d.name for d in await db.attached_databases()] - - allow_execute_sql = await datasette.allowed( - action="execute-sql", - resource=DatabaseResource(database=database), - actor=request.actor, - ) - json_data = { - "ok": True, - "database": database, - "private": private, - "path": datasette.urls.database(database), - "size": db.size, + info = self.ds.inspect()[name] + metadata = self.ds.metadata.get("databases", {}).get(name, {}) + self.ds.update_with_inherited_metadata(metadata) + tables = list(info["tables"].values()) + tables.sort(key=lambda t: (t["hidden"], t["name"])) + return { + "database": name, "tables": tables, "hidden_count": len([t for t in tables if t["hidden"]]), - "views": sql_views, - "queries": [stored_query_to_dict(query) for query in stored_queries], - "queries_more": queries_more, - "queries_count": queries_count, - "allow_execute_sql": allow_execute_sql, - "table_columns": ( - await _table_columns(datasette, database) if allow_execute_sql else {} - ), - "metadata": await datasette.get_database_metadata(database), - } - - if format_ == "json": - response = Response.json(json_data) - if datasette.cors: - add_cors_headers(response.headers) - return response - - assert format_ == "html" - alternate_url_json = datasette.absolute_url( - request, - datasette.urls.path(path_with_format(request=request, format="json")), - ) - templates = (f"database-{to_css_class(database)}.html", "database.html") - environment = datasette.get_jinja_environment(request) - template = environment.select_template(templates) - return Response.html( - await datasette.render_template( - templates, - DatabaseContext( - database=database, - private=private, - path=datasette.urls.database(database), - size=db.size, - tables=tables, - hidden_count=len([t for t in tables if t["hidden"]]), - views=sql_views, - queries=stored_queries, - queries_more=queries_more, - queries_count=queries_count, - allow_execute_sql=allow_execute_sql, - table_columns=( - await _table_columns(datasette, database) - if allow_execute_sql - else {} - ), - metadata=metadata, - database_color=db.color, - database_actions=database_actions, - show_hidden=request.args.get("_show_hidden"), - editable=True, - count_limit=db.count_limit, - allow_download=datasette.setting("allow_download") - and not db.is_mutable - and not db.is_memory, - attached_databases=attached_databases, - alternate_url_json=alternate_url_json, - select_templates=[ - f"{'*' if template_name == template.name else ''}{template_name}" - for template_name in templates - ], - top_database=make_slot_function( - "top_database", datasette, request, database=database - ), - ), - request=request, - view_name="database", - ), - headers={ - "Link": '<{}>; rel="alternate"; type="application/json+datasette"'.format( - alternate_url_json - ) - }, + "views": info["views"], + "queries": [ + {"name": query_name, "sql": query_sql} + for query_name, query_sql in (metadata.get("queries") or {}).items() + ], + "config": self.ds.config, + }, { + "database_hash": hash, + "show_hidden": qs.first_or_none("_show_hidden"), + "editable": True, + "metadata": metadata, + }, ( + "database-{}.html".format(to_css_class(name)), "database.html" ) -@dataclass -class DatabaseContext(Context): - database: str = field(metadata={"help": "The name of the database"}) - private: bool = field( - metadata={"help": "Boolean indicating if this is a private database"} - ) - path: str = field(metadata={"help": "The URL path to this database"}) - size: int = field(metadata={"help": "The size of the database in bytes"}) - tables: list = field(metadata={"help": "List of table objects in the database"}) - hidden_count: int = field(metadata={"help": "Count of hidden tables"}) - views: list = field(metadata={"help": "List of view objects in the database"}) - queries: list = field(metadata={"help": "List of stored query objects"}) - queries_more: bool = field( - metadata={"help": "Boolean indicating if more stored queries are available"} - ) - queries_count: int = field(metadata={"help": "Count of visible stored queries"}) - allow_execute_sql: bool = field( - metadata={"help": "Boolean indicating if custom SQL can be executed"} - ) - table_columns: dict = field( - metadata={"help": "Dictionary mapping table names to their column lists"} - ) - metadata: dict = field(metadata={"help": "Metadata for the database"}) - database_color: str = field(metadata={"help": "The color assigned to the database"}) - database_actions: callable = field( - metadata={ - "help": "Callable returning list of action links for the database menu" - } - ) - show_hidden: str = field(metadata={"help": "Value of _show_hidden query parameter"}) - editable: bool = field( - metadata={"help": "Boolean indicating if the database is editable"} - ) - count_limit: int = field(metadata={"help": "The maximum number of rows to count"}) - allow_download: bool = field( - metadata={"help": "Boolean indicating if database download is allowed"} - ) - attached_databases: list = field( - metadata={"help": "List of names of attached databases"} - ) - alternate_url_json: str = field( - metadata={"help": "URL for the alternate JSON version of this page"} - ) - select_templates: list = field( - metadata={ - "help": "List of templates that were considered for rendering this page" - } - ) - top_database: callable = field( - metadata={"help": "Callable to render the top_database slot"} - ) +class DatabaseDownload(BaseView): - -@dataclass -class QueryContext(Context): - database: str = field(metadata={"help": "The name of the database being queried"}) - database_color: str = field(metadata={"help": "The color of the database"}) - query: dict = field( - metadata={"help": "The SQL query object containing the `sql` string"} - ) - stored_query: str = field( - metadata={"help": "The name of the stored query if this is a stored query"} - ) - private: bool = field( - metadata={"help": "Boolean indicating if this is a private database"} - ) - # urls: dict = field( - # metadata={"help": "Object containing URL helpers like `database()`"} - # ) - stored_query_write: bool = field( - metadata={ - "help": "Boolean indicating if this is a stored query that allows writes" - } - ) - metadata: dict = field( - metadata={"help": "Metadata about the database or the stored query"} - ) - db_is_immutable: bool = field( - metadata={"help": "Boolean indicating if this database is immutable"} - ) - error: str = field(metadata={"help": "Any query error message"}) - hide_sql: bool = field( - metadata={"help": "Boolean indicating if the SQL should be hidden"} - ) - show_hide_link: str = field( - metadata={"help": "The URL to toggle showing/hiding the SQL"} - ) - show_hide_text: str = field( - metadata={"help": "The text for the show/hide SQL link"} - ) - editable: bool = field( - metadata={"help": "Boolean indicating if the SQL can be edited"} - ) - allow_execute_sql: bool = field( - metadata={"help": "Boolean indicating if custom SQL can be executed"} - ) - save_query_url: str = field( - metadata={"help": "URL to save the current arbitrary SQL as a query"} - ) - tables: list = field(metadata={"help": "List of table objects in the database"}) - named_parameter_values: dict = field( - metadata={"help": "Dictionary of parameter names/values"} - ) - edit_sql_url: str = field( - metadata={"help": "URL to edit the SQL for a stored query"} - ) - display_rows: list = field(metadata={"help": "List of result rows to display"}) - columns: list = field(metadata={"help": "List of column names"}) - renderers: dict = field(metadata={"help": "Dictionary of renderer name to URL"}) - url_csv: str = field(metadata={"help": "URL for CSV export"}) - show_hide_hidden: str = field( - metadata={"help": "Hidden input field for the _show_sql parameter"} - ) - table_columns: dict = field( - metadata={"help": "Dictionary of table name to list of column names"} - ) - alternate_url_json: str = field( - metadata={"help": "URL for alternate JSON version of this page"} - ) - # TODO: refactor this to somewhere else, probably ds.render_template() - select_templates: list = field( - metadata={ - "help": "List of templates that were considered for rendering this page" - } - ) - top_query: callable = field( - metadata={"help": "Callable to render the top_query slot"} - ) - top_stored_query: callable = field( - metadata={"help": "Callable to render the top_stored_query slot"} - ) - query_actions: callable = field( - metadata={ - "help": "Callable returning a list of links for the query action menu" - } - ) - - -async def get_tables(datasette, request, db, allowed_dict): - """ - Get list of tables with metadata for the database view. - - Args: - datasette: The Datasette instance - request: The current request - db: The database - allowed_dict: Dict mapping table name -> Resource object with .private attribute - """ - tables = [] - table_counts = await db.table_counts(100) - hidden_table_names = set(await db.hidden_table_names()) - all_foreign_keys = await db.get_all_foreign_keys() - - for table in table_counts: - if table not in allowed_dict: - continue - - table_columns = await db.table_columns(table) - tables.append( - { - "name": table, - "columns": table_columns, - "primary_keys": await db.primary_keys(table), - "count": table_counts[table], - "hidden": table in hidden_table_names, - "fts_table": await db.fts_table(table), - "foreign_keys": all_foreign_keys[table], - "private": allowed_dict[table].private, - } + async def view_get(self, qs, name, hash, **kwargs): + if not self.ds.config["allow_download"]: + raise DatasetteError("Database download is forbidden", status=403) + filepath = self.ds.inspect()[name]["file"] + return await response.file_stream( + filepath, + filename=os.path.basename(filepath), + mime_type="application/octet-stream", ) - tables.sort(key=lambda t: (t["hidden"], t["name"])) - return tables - - -async def database_download(request, datasette): - from datasette.resources import DatabaseResource - - database = tilde_decode(request.url_vars["database"]) - await datasette.ensure_permission( - action="view-database-download", - resource=DatabaseResource(database=database), - actor=request.actor, - ) - try: - db = datasette.get_database(route=database) - except KeyError: - raise DatasetteError("Invalid database", status=404) - - if db.is_memory: - raise DatasetteError("Cannot download in-memory databases", status=404) - if not datasette.setting("allow_download") or db.is_mutable: - raise Forbidden("Database download is forbidden") - if not db.path: - raise DatasetteError("Cannot download database", status=404) - filepath = db.path - headers = {} - if datasette.cors: - add_cors_headers(headers) - if db.hash: - etag = '"{}"'.format(db.hash) - headers["Etag"] = etag - # Has user seen this already? - if_none_match = request.headers.get("if-none-match") - if if_none_match and if_none_match == etag: - return Response("", status=304) - headers["Transfer-Encoding"] = "chunked" - return AsgiFileDownload( - filepath, - filename=os.path.basename(filepath), - content_type="application/octet-stream", - headers=headers, - ) - - -class QueryView(View): - async def post(self, request, datasette): - from datasette.app import TableNotFound - - db = await datasette.resolve_database(request) - - # We must be a stored query - table_found = False - try: - await datasette.resolve_table(request) - table_found = True - except TableNotFound as table_not_found: - stored_query = await datasette.get_query( - table_not_found.database_name, table_not_found.table - ) - if stored_query is None: - raise - if table_found: - # That should not have happened - raise DatasetteError("Unexpected table found on POST", status=404) - - if not await datasette.allowed( - action="view-query", - resource=QueryResource(database=db.name, query=stored_query.name), - actor=request.actor, - ): - raise Forbidden("You do not have permission to view this query") - - await _ensure_stored_query_execution_permissions( - datasette, db, stored_query, request.actor - ) - - # If database is immutable, return an error - if not db.is_mutable: - raise Forbidden("Database is immutable") - - # Process the POST - body = await request.post_body() - body = body.decode("utf-8").strip() - if body.startswith("{") and body.endswith("}"): - params = json.loads(body) - # But we want key=value strings - for key, value in params.items(): - params[key] = str(value) - else: - params = dict(parse_qsl(body, keep_blank_values=True)) - - # Don't ever send csrftoken as a SQL parameter - params.pop("csrftoken", None) - - # Should we return JSON? - should_return_json = ( - request.headers.get("accept") == "application/json" - or request.args.get("_json") - or params.get("_json") - ) - params_for_query = MagicParameters(stored_query.sql, params, request, datasette) - await params_for_query.execute_params() - ok = None - redirect_url = None - try: - cursor = await db.execute_write( - stored_query.sql, params_for_query, request=request - ) - # success message can come from on_success_message or on_success_message_sql - message = None - message_type = datasette.INFO - on_success_message_sql = stored_query.on_success_message_sql - if on_success_message_sql: - try: - message_result = ( - await db.execute(on_success_message_sql, params_for_query) - ).first() - if message_result: - message = message_result[0] - except Exception as ex: - message = "Error running on_success_message_sql: {}".format(ex) - message_type = datasette.ERROR - if not message: - message = ( - stored_query.on_success_message - or "Query executed, {} row{} affected".format( - cursor.rowcount, "" if cursor.rowcount == 1 else "s" - ) - ) - - redirect_url = stored_query.on_success_redirect - ok = True - except Exception as ex: - message = stored_query.on_error_message or str(ex) - message_type = datasette.ERROR - redirect_url = stored_query.on_error_redirect - ok = False - if should_return_json: - return Response.json( - { - "ok": ok, - "message": message, - "redirect": redirect_url, - } - ) - else: - datasette.add_message(request, message, message_type) - return Response.redirect(redirect_url or request.path) - - async def get(self, request, datasette): - from datasette.app import TableNotFound - - await datasette.refresh_schemas() - - db = await datasette.resolve_database(request) - database = db.name - - # Get all tables/views this actor can see in bulk with private flag - allowed_tables_page = await datasette.allowed_resources( - "view-table", - request.actor, - parent=database, - include_is_private=True, - limit=1000, - ) - # Create lookup dict for quick access - allowed_dict = {r.child: r for r in allowed_tables_page.resources} - - # Are we a stored query? - stored_query = None - stored_query_write = False - if "table" in request.url_vars: - try: - await datasette.resolve_table(request) - except TableNotFound as table_not_found: - # Was this actually a stored query? - stored_query = await datasette.get_query( - table_not_found.database_name, table_not_found.table - ) - if stored_query is None: - raise - stored_query_write = stored_query.is_write - - private = False - if stored_query: - # Respect stored query permissions - visible, private = await datasette.check_visibility( - request.actor, - action="view-query", - resource=QueryResource(database=database, query=stored_query.name), - ) - if not visible: - raise Forbidden("You do not have permission to view this query") - if not stored_query_write: - await _ensure_stored_query_execution_permissions( - datasette, db, stored_query, request.actor - ) - - else: - await datasette.ensure_permission( - action="execute-sql", - resource=DatabaseResource(database=database), - actor=request.actor, - ) - - # Flattened because of ?sql=&name1=value1&name2=value2 feature - params = {key: request.args.get(key) for key in request.args} - sql = None - - if stored_query: - sql = stored_query.sql - elif "sql" in params: - sql = params.pop("sql") - - # Extract any :named parameters - named_parameters = [] - if stored_query and stored_query.parameters: - named_parameters = stored_query.parameters - if not named_parameters and sql: - named_parameters = derive_named_parameters(sql) - named_parameter_values = { - named_parameter: params.get(named_parameter) or "" - for named_parameter in named_parameters - if not named_parameter.startswith("_") - } - # Set to blank string if missing from params - for named_parameter in named_parameters: - if named_parameter not in params and not named_parameter.startswith("_"): - params[named_parameter] = "" - - extra_args = {} - if params.get("_timelimit"): - extra_args["custom_time_limit"] = int(params["_timelimit"]) - - format_ = request.url_vars.get("format") or "html" - - query_error = None - results = None - rows = [] - columns = [] - - params_for_query = params - - if sql and not stored_query_write: - try: - if not stored_query: - # For regular queries we only allow SELECT, plus other rules - validate_sql_select(sql) - else: - # Stored queries can run magic parameters - params_for_query = MagicParameters(sql, params, request, datasette) - await params_for_query.execute_params() - results = await datasette.execute( - database, sql, params_for_query, truncate=True, **extra_args - ) - columns = results.columns - rows = results.rows - except QueryInterrupted as ex: - raise DatasetteError( - textwrap.dedent(""" -

SQL query took too long. The time limit is controlled by the - sql_time_limit_ms - configuration option.

- - - """.format(markupsafe.escape(ex.sql))).strip(), - title="SQL Interrupted", - status=400, - message_is_html=True, - ) - except sqlite3.DatabaseError as ex: - query_error = str(ex) - results = None - rows = [] - columns = [] - except (sqlite3.OperationalError, InvalidSql) as ex: - raise DatasetteError(str(ex), title="Invalid SQL", status=400) - except sqlite3.OperationalError as ex: - raise DatasetteError(str(ex)) - except DatasetteError: - raise - - # Handle formats from plugins - if format_ == "csv": - if not sql: - raise DatasetteError("?sql= is required", status=400) - - async def fetch_data_for_csv(request, _next=None): - results = await db.execute(sql, params, truncate=True) - data = {"rows": results.rows, "columns": results.columns} - return data, None, None - - return await stream_csv(datasette, fetch_data_for_csv, request, db.name) - elif format_ in datasette.renderers.keys(): - # Dispatch request to the correct output format renderer - # (CSV is not handled here due to streaming) - result = call_with_supported_arguments( - datasette.renderers[format_][0], - datasette=datasette, - columns=columns, - rows=rows, - sql=sql, - query_name=stored_query.name if stored_query else None, - database=database, - table=None, - request=request, - view_name="table", - truncated=results.truncated if results else False, - error=query_error, - # These will be deprecated in Datasette 1.0: - args=request.args, - data={"ok": True, "rows": rows, "columns": columns}, - ) - if asyncio.iscoroutine(result): - result = await result - if result is None: - raise NotFound("No data") - if isinstance(result, dict): - r = Response( - body=result.get("body"), - status=result.get("status_code") or 200, - content_type=result.get("content_type", "text/plain"), - headers=result.get("headers"), - ) - elif isinstance(result, Response): - r = result - # if status_code is not None: - # # Over-ride the status code - # r.status = status_code - else: - assert False, f"{result} should be dict or Response" - elif format_ == "html": - headers = {} - templates = [f"query-{to_css_class(database)}.html", "query.html"] - if stored_query: - templates.insert( - 0, - f"query-{to_css_class(database)}-{to_css_class(stored_query.name)}.html", - ) - - environment = datasette.get_jinja_environment(request) - template = environment.select_template(templates) - alternate_url_json = datasette.absolute_url( - request, - datasette.urls.path(path_with_format(request=request, format="json")), - ) - data = {} - headers.update( - { - "Link": '<{}>; rel="alternate"; type="application/json+datasette"'.format( - alternate_url_json - ) - } - ) - metadata = await datasette.get_database_metadata(database) - if stored_query: - metadata = stored_query_to_dict(stored_query) - metadata.pop("source", None) - - renderers = {} - for key, (_, can_render) in datasette.renderers.items(): - it_can_render = call_with_supported_arguments( - can_render, - datasette=datasette, - columns=data.get("columns") or [], - rows=data.get("rows") or [], - sql=data.get("query", {}).get("sql", None), - query_name=data.get("query_name"), - database=database, - table=data.get("table"), - request=request, - view_name="database", - ) - it_can_render = await await_me_maybe(it_can_render) - if it_can_render: - renderers[key] = datasette.urls.path( - path_with_format(request=request, format=key) - ) - - allow_execute_sql = await datasette.allowed( - action="execute-sql", - resource=DatabaseResource(database=database), - actor=request.actor, - ) - allow_store_query = await datasette.allowed( - action="store-query", - resource=DatabaseResource(database=database), - actor=request.actor, - ) - - show_hide_hidden = "" - if stored_query and stored_query.hide_sql: - if bool(params.get("_show_sql")): - show_hide_link = path_with_removed_args(request, {"_show_sql"}) - show_hide_text = "hide" - show_hide_hidden = ( - '' - ) - else: - show_hide_link = path_with_added_args(request, {"_show_sql": 1}) - show_hide_text = "show" - else: - if bool(params.get("_hide_sql")): - show_hide_link = path_with_removed_args(request, {"_hide_sql"}) - show_hide_text = "show" - show_hide_hidden = ( - '' - ) - else: - show_hide_link = path_with_added_args(request, {"_hide_sql": 1}) - show_hide_text = "hide" - hide_sql = show_hide_text == "show" - - # Show 'Edit SQL' button only if: - # - User is allowed to execute SQL - # - SQL is an approved SELECT statement - # - No magic parameters, so no :_ in the SQL string - edit_sql_url = None - is_validated_sql = False - if sql: - try: - validate_sql_select(sql) - is_validated_sql = True - except InvalidSql: - pass - if allow_execute_sql and is_validated_sql and ":_" not in sql: - edit_sql_url = ( - datasette.urls.database(database) - + "/-/query" - + "?" - + urlencode( - { - **{ - "sql": sql, - }, - **named_parameter_values, - } - ) - ) - save_query_url = None - if ( - not stored_query - and allow_execute_sql - and allow_store_query - and is_validated_sql - and ":_" not in sql - ): - save_query_url = ( - datasette.urls.database(database) - + "/-/queries/store?" - + urlencode({"sql": sql}) - ) - - async def query_actions(): - query_actions = [] - for hook in pm.hook.query_actions( - datasette=datasette, - actor=request.actor, - database=database, - query_name=stored_query.name if stored_query else None, - request=request, - sql=sql, - params=params, - ): - extra_links = await await_me_maybe(hook) - if extra_links: - query_actions.extend(extra_links) - return query_actions - - r = Response.html( - await datasette.render_template( - template, - QueryContext( - database=database, - database_color=db.color, - query={ - "sql": sql, - "params": params, - }, - stored_query=stored_query.name if stored_query else None, - private=private, - stored_query_write=stored_query_write, - db_is_immutable=not db.is_mutable, - error=query_error, - hide_sql=hide_sql, - show_hide_link=datasette.urls.path(show_hide_link), - show_hide_text=show_hide_text, - editable=not stored_query, - allow_execute_sql=allow_execute_sql, - save_query_url=save_query_url, - tables=await get_tables(datasette, request, db, allowed_dict), - named_parameter_values=named_parameter_values, - edit_sql_url=edit_sql_url, - display_rows=await display_rows( - datasette, database, request, rows, columns - ), - table_columns=( - await _table_columns(datasette, database) - if allow_execute_sql - else {} - ), - columns=columns, - renderers=renderers, - url_csv=datasette.urls.path( - path_with_format( - request=request, format="csv", extra_qs={"_size": "max"} - ) - ), - show_hide_hidden=markupsafe.Markup(show_hide_hidden), - metadata=metadata, - alternate_url_json=alternate_url_json, - select_templates=[ - f"{'*' if template_name == template.name else ''}{template_name}" - for template_name in templates - ], - top_query=make_slot_function( - "top_query", datasette, request, database=database, sql=sql - ), - top_stored_query=make_slot_function( - "top_stored_query", - datasette, - request, - database=database, - query_name=stored_query.name if stored_query else None, - ), - query_actions=query_actions, - ), - request=request, - view_name="database", - ), - headers=headers, - ) - else: - assert False, "Invalid format: {}".format(format_) - if datasette.cors: - add_cors_headers(r.headers) - return r - - -class MagicParameters(dict): - def __init__(self, sql, data, request, datasette): - super().__init__(data) - self._sql = sql - self._request = request - self._magics = dict( - itertools.chain.from_iterable( - pm.hook.register_magic_parameters(datasette=datasette) - ) - ) - self._prepared = {} - - async def execute_params(self): - for key in derive_named_parameters(self._sql): - if key.startswith("_") and key.count("_") >= 2: - prefix, suffix = key[1:].split("_", 1) - if prefix in self._magics: - result = await await_me_maybe( - self._magics[prefix](suffix, self._request) - ) - self._prepared[key] = result - - def __len__(self): - # Workaround for 'Incorrect number of bindings' error - # https://github.com/simonw/datasette/issues/967#issuecomment-692951144 - return super().__len__() or 1 - - def __getitem__(self, key): - if key.startswith("_") and key.count("_") >= 2: - if key in self._prepared: - return self._prepared[key] - # Try the other route - prefix, suffix = key[1:].split("_", 1) - if prefix in self._magics: - try: - return self._magics[prefix](suffix, self._request) - except KeyError: - return super().__getitem__(key) - else: - return super().__getitem__(key) - - -class TableCreateView(BaseView): - name = "table-create" - - _valid_keys = { - "table", - "rows", - "row", - "columns", - "pk", - "pks", - "ignore", - "replace", - "alter", - } - _supported_column_types = { - "text", - "integer", - "float", - "blob", - } - # Any string that does not contain a newline or start with sqlite_ - _table_name_re = re.compile(r"^(?!sqlite_)[^\n]+$") - - def __init__(self, datasette): - self.ds = datasette - - async def post(self, request): - db = await self.ds.resolve_database(request) - database_name = db.name - - # Must have create-table permission - if not await self.ds.allowed( - action="create-table", - resource=DatabaseResource(database=database_name), - actor=request.actor, - ): - return _error(["Permission denied"], 403) - - body = await request.post_body() - try: - data = json.loads(body) - except json.JSONDecodeError as e: - return _error(["Invalid JSON: {}".format(e)]) - - if not isinstance(data, dict): - return _error(["JSON must be an object"]) - - invalid_keys = set(data.keys()) - self._valid_keys - if invalid_keys: - return _error(["Invalid keys: {}".format(", ".join(invalid_keys))]) - - # ignore and replace are mutually exclusive - if data.get("ignore") and data.get("replace"): - return _error(["ignore and replace are mutually exclusive"]) - - # ignore and replace only allowed with row or rows - if "ignore" in data or "replace" in data: - if not data.get("row") and not data.get("rows"): - return _error(["ignore and replace require row or rows"]) - - # ignore and replace require pk or pks - if "ignore" in data or "replace" in data: - if not data.get("pk") and not data.get("pks"): - return _error(["ignore and replace require pk or pks"]) - - ignore = data.get("ignore") - replace = data.get("replace") - - if replace: - # Must have update-row permission - if not await self.ds.allowed( - action="update-row", - resource=DatabaseResource(database=database_name), - actor=request.actor, - ): - return _error(["Permission denied: need update-row"], 403) - - table_name = data.get("table") - if not table_name: - return _error(["Table is required"]) - - if not self._table_name_re.match(table_name): - return _error(["Invalid table name"]) - - table_exists = await db.table_exists(data["table"]) - columns = data.get("columns") - rows = data.get("rows") - row = data.get("row") - if not columns and not rows and not row: - return _error(["columns, rows or row is required"]) - - if rows and row: - return _error(["Cannot specify both rows and row"]) - - if rows or row: - # Must have insert-row permission - if not await self.ds.allowed( - action="insert-row", - resource=DatabaseResource(database=database_name), - actor=request.actor, - ): - return _error(["Permission denied: need insert-row"], 403) - - alter = False - if rows or row: - if not table_exists: - # if table is being created for the first time, alter=True - alter = True - else: - # alter=True only if they request it AND they have permission - if data.get("alter"): - if not await self.ds.allowed( - action="alter-table", - resource=DatabaseResource(database=database_name), - actor=request.actor, - ): - return _error(["Permission denied: need alter-table"], 403) - alter = True - - if columns: - if rows or row: - return _error(["Cannot specify columns with rows or row"]) - if not isinstance(columns, list): - return _error(["columns must be a list"]) - for column in columns: - if not isinstance(column, dict): - return _error(["columns must be a list of objects"]) - if not column.get("name") or not isinstance(column.get("name"), str): - return _error(["Column name is required"]) - if not column.get("type"): - column["type"] = "text" - if column["type"] not in self._supported_column_types: - return _error( - ["Unsupported column type: {}".format(column["type"])] - ) - # No duplicate column names - dupes = {c["name"] for c in columns if columns.count(c) > 1} - if dupes: - return _error(["Duplicate column name: {}".format(", ".join(dupes))]) - - if row: - rows = [row] - - if rows: - if not isinstance(rows, list): - return _error(["rows must be a list"]) - for row in rows: - if not isinstance(row, dict): - return _error(["rows must be a list of objects"]) - - pk = data.get("pk") - pks = data.get("pks") - - if pk and pks: - return _error(["Cannot specify both pk and pks"]) - if pk: - if not isinstance(pk, str): - return _error(["pk must be a string"]) - if pks: - if not isinstance(pks, list): - return _error(["pks must be a list"]) - for pk in pks: - if not isinstance(pk, str): - return _error(["pks must be a list of strings"]) - - # If table exists already, read pks from that instead - if table_exists: - actual_pks = await db.primary_keys(table_name) - # if pk passed and table already exists check it does not change - bad_pks = False - if len(actual_pks) == 1 and data.get("pk") and data["pk"] != actual_pks[0]: - bad_pks = True - elif ( - len(actual_pks) > 1 - and data.get("pks") - and set(data["pks"]) != set(actual_pks) - ): - bad_pks = True - if bad_pks: - return _error(["pk cannot be changed for existing table"]) - pks = actual_pks - - initial_schema = None - if table_exists: - initial_schema = await db.execute_fn( - lambda conn: sqlite_utils.Database(conn)[table_name].schema - ) - - def create_table(conn): - table = sqlite_utils.Database(conn)[table_name] - if rows: - table.insert_all( - rows, pk=pks or pk, ignore=ignore, replace=replace, alter=alter - ) - else: - table.create( - {c["name"]: c["type"] for c in columns}, - pk=pks or pk, - ) - return table.schema - - try: - schema = await db.execute_write_fn(create_table, request=request) - except Exception as e: - return _error([str(e)]) - - if initial_schema is not None and initial_schema != schema: - await self.ds.track_event( - AlterTableEvent( - request.actor, - database=database_name, - table=table_name, - before_schema=initial_schema, - after_schema=schema, - ) - ) - - table_url = self.ds.absolute_url( - request, self.ds.urls.table(db.name, table_name) - ) - table_api_url = self.ds.absolute_url( - request, self.ds.urls.table(db.name, table_name, format="json") - ) - details = { - "ok": True, - "database": db.name, - "table": table_name, - "table_url": table_url, - "table_api_url": table_api_url, - "schema": schema, - } - if rows: - details["row_count"] = len(rows) - - if not table_exists: - # Only log creation if we created a table - await self.ds.track_event( - CreateTableEvent( - request.actor, database=db.name, table=table_name, schema=schema - ) - ) - if rows: - await self.ds.track_event( - InsertRowsEvent( - request.actor, - database=db.name, - table=table_name, - num_rows=len(rows), - ignore=ignore, - replace=replace, - ) - ) - return Response.json(details, status=201) - - -async def display_rows(datasette, database, request, rows, columns): - display_rows = [] - truncate_cells = datasette.setting("truncate_cells_html") - for row in rows: - display_row = [] - for column, value in zip(columns, row): - display_value = value - # Let the plugins have a go - # pylint: disable=no-member - plugin_display_value = None - for candidate in pm.hook.render_cell( - row=row, - value=value, - column=column, - table=None, - pks=[], - database=database, - datasette=datasette, - request=request, - column_type=None, - ): - candidate = await await_me_maybe(candidate) - if candidate is not None: - plugin_display_value = candidate - break - if plugin_display_value is not None: - display_value = plugin_display_value - else: - if value in ("", None): - display_value = markupsafe.Markup(" ") - elif is_url(str(display_value).strip()): - display_value = markupsafe.Markup( - '{truncated_url}'.format( - url=markupsafe.escape(value.strip()), - truncated_url=markupsafe.escape( - truncate_url(value.strip(), truncate_cells) - ), - ) - ) - elif isinstance(display_value, bytes): - blob_url = path_with_format( - request=request, - format="blob", - extra_qs={ - "_blob_column": column, - "_blob_hash": hashlib.sha256(display_value).hexdigest(), - }, - ) - formatted = format_bytes(len(value)) - display_value = markupsafe.Markup( - '<Binary: {:,} byte{}>'.format( - blob_url, - ( - ' title="{}"'.format(formatted) - if "bytes" not in formatted - else "" - ), - len(value), - "" if len(value) == 1 else "s", - ) - ) - else: - display_value = str(value) - if truncate_cells and len(display_value) > truncate_cells: - display_value = display_value[:truncate_cells] + "\u2026" - display_row.append(display_value) - display_rows.append(display_row) - return display_rows diff --git a/datasette/views/execute_write.py b/datasette/views/execute_write.py deleted file mode 100644 index 0054300c..00000000 --- a/datasette/views/execute_write.py +++ /dev/null @@ -1,257 +0,0 @@ -from urllib.parse import urlencode - -from datasette.resources import DatabaseResource -from datasette.utils import sqlite3 -from datasette.utils.asgi import Response - -from .base import BaseView, _error -from .query_helpers import ( - QueryValidationError, - _analysis_is_write, - _analysis_rows, - _analysis_rows_with_permissions, - _block_framing, - _coerce_execute_write_payload, - _derived_query_parameters, - _execute_write_analysis_data, - _inserted_row_url, - _json_or_form_payload, - _prepare_execute_write, - _table_columns, - _wants_json, -) - - -class ExecuteWriteView(BaseView): - name = "execute-write" - has_json_alternate = False - - async def _render_form( - self, - request, - db, - *, - sql="", - parameter_values=None, - analysis=None, - analysis_error=None, - execution_message=None, - execution_links=None, - execution_ok=None, - status=200, - ): - parameter_values = parameter_values or {} - execution_links = execution_links or [] - parameter_names = [] - analysis_rows = [] - table_columns = await _table_columns(self.ds, db.name) - hidden_table_names = set(await db.hidden_table_names()) - write_template_tables = { - table: columns - for table, columns in table_columns.items() - if columns and table not in hidden_table_names - } - if sql and analysis_error is None: - try: - parameter_names = _derived_query_parameters(sql) - if analysis is None: - params = {parameter: "" for parameter in parameter_names} - analysis = await db.analyze_sql(sql, params) - if _analysis_is_write(analysis): - analysis_rows = await _analysis_rows_with_permissions( - self.ds, analysis, request.actor - ) - else: - analysis_error = ( - "Use /-/query for read-only SQL; " - "this endpoint only executes writes" - ) - except (QueryValidationError, sqlite3.DatabaseError) as ex: - analysis_error = getattr(ex, "message", str(ex)) - - allow_save_query = await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ) and await self.ds.allowed( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - save_query_base_url = None - save_query_url = None - if allow_save_query: - save_query_base_url = self.ds.urls.database(db.name) + "/-/queries/store" - if ( - sql - and analysis_error is None - and not any(row["allowed"] is False for row in analysis_rows) - ): - save_query_url = save_query_base_url + "?" + urlencode({"sql": sql}) - - response = await self.render( - ["execute_write.html"], - request, - { - "database": db.name, - "database_color": db.color, - "sql": sql, - "parameter_names": parameter_names, - "parameter_values": parameter_values, - "analysis_error": analysis_error, - "analysis_rows": [ - row for row in analysis_rows if row["operation"] != "read" - ], - "execution_message": execution_message, - "execution_links": execution_links, - "execution_ok": execution_ok, - "execute_disabled": bool( - (not sql) - or analysis_error - or any(row["allowed"] is False for row in analysis_rows) - ), - "table_columns": table_columns, - "write_template_tables": write_template_tables, - "save_query_url": save_query_url, - "save_query_base_url": save_query_base_url, - }, - ) - response.status = status - return _block_framing(response) - - async def get(self, request): - db = await self.ds.resolve_database(request) - await self.ds.ensure_permission( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - if not db.is_mutable: - return _block_framing( - _error( - ["Cannot execute write SQL because this database is immutable."], - 403, - ) - ) - return await self._render_form( - request, - db, - sql=request.args.get("sql") or "", - ) - - async def post(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing( - _error(["Permission denied: need execute-write-sql"], 403) - ) - if not db.is_mutable: - return _block_framing(_error(["Database is immutable"], 403)) - - data = {} - is_json = request.headers.get("content-type", "").startswith("application/json") - sql = "" - provided_params = {} - try: - data, is_json = await _json_or_form_payload(request) - sql, provided_params = _coerce_execute_write_payload(data, is_json) - parameter_names, params, analysis = await _prepare_execute_write( - self.ds, db, sql, provided_params, request.actor - ) - except QueryValidationError as ex: - if _wants_json(request, is_json, data): - return _block_framing(_error([ex.message], ex.status)) - return await self._render_form( - request, - db, - sql=sql or "", - parameter_values=provided_params, - analysis_error=ex.message, - execution_message=ex.message, - execution_ok=False, - status=ex.status, - ) - - try: - cursor = await db.execute_write(sql, params, request=request) - except sqlite3.DatabaseError as ex: - message = str(ex) - if _wants_json(request, is_json, data): - return _block_framing(_error([message], 400)) - return await self._render_form( - request, - db, - sql=sql, - parameter_values=params, - analysis=analysis, - execution_message=message, - execution_ok=False, - status=400, - ) - - message = "Query executed, {} row{} affected".format( - cursor.rowcount, "" if cursor.rowcount == 1 else "s" - ) - if _wants_json(request, is_json, data): - return _block_framing( - Response.json( - { - "ok": True, - "message": message, - "rowcount": cursor.rowcount, - "analysis": _analysis_rows(analysis), - } - ) - ) - - inserted_row_url = await _inserted_row_url(self.ds, db, analysis, cursor) - execution_links = ( - [{"href": inserted_row_url, "label": "View row"}] - if inserted_row_url - else [] - ) - return await self._render_form( - request, - db, - sql=sql, - parameter_values={name: params.get(name, "") for name in parameter_names}, - analysis=analysis, - execution_message=message, - execution_links=execution_links, - execution_ok=True, - ) - - -class ExecuteWriteAnalyzeView(BaseView): - name = "execute-write-analyze" - has_json_alternate = False - - async def get(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing( - _error(["Permission denied: need execute-write-sql"], 403) - ) - - invalid_keys = set(request.args) - {"sql"} - if invalid_keys: - return _block_framing( - _error( - ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], - 400, - ) - ) - sql = request.args.get("sql") or "" - return _block_framing( - Response.json( - await _execute_write_analysis_data(self.ds, db, sql, request.actor) - ) - ) diff --git a/datasette/views/index.py b/datasette/views/index.py index 6a9462ac..66776e1c 100644 --- a/datasette/views/index.py +++ b/datasette/views/index.py @@ -1,189 +1,59 @@ import json -from datasette.plugins import pm -from datasette.utils import ( - add_cors_headers, - await_me_maybe, - make_slot_function, - CustomJSONEncoder, -) -from datasette.utils.asgi import Response +from sanic import response + +from datasette.utils import CustomJSONEncoder from datasette.version import __version__ -from .base import BaseView - -# Truncate table list on homepage at: -TRUNCATE_AT = 5 - -# Only attempt counts if database less than this size in bytes: -COUNT_DB_SIZE_LIMIT = 100 * 1024 * 1024 +from .base import HASH_LENGTH, RenderMixin -class IndexView(BaseView): - name = "index" +class IndexView(RenderMixin): - async def get(self, request): - as_format = request.url_vars["format"] - await self.ds.ensure_permission(action="view-instance", actor=request.actor) - - # Get all allowed databases and tables in bulk - db_page = await self.ds.allowed_resources( - "view-database", request.actor, include_is_private=True - ) - allowed_databases = [r async for r in db_page.all()] - allowed_db_dict = {r.parent: r for r in allowed_databases} - - # Group tables by database - tables_by_db = {} - table_page = await self.ds.allowed_resources( - "view-table", request.actor, include_is_private=True - ) - async for t in table_page.all(): - if t.parent not in tables_by_db: - tables_by_db[t.parent] = {} - tables_by_db[t.parent][t.child] = t + def __init__(self, datasette): + self.ds = datasette + self.files = datasette.files + self.jinja_env = datasette.jinja_env + self.executor = datasette.executor + async def get(self, request, as_format): databases = [] - # Iterate over allowed databases instead of all databases - for name in allowed_db_dict.keys(): - db = self.ds.databases[name] - database_private = allowed_db_dict[name].private - - # Get allowed tables/views for this database - allowed_for_db = tables_by_db.get(name, {}) - - # Get table names from allowed set instead of db.table_names() - table_names = [child_name for child_name in allowed_for_db.keys()] - - hidden_table_names = set(await db.hidden_table_names()) - - # Determine which allowed items are views - view_names_set = set(await db.view_names()) - views = [ - {"name": child_name, "private": resource.private} - for child_name, resource in allowed_for_db.items() - if child_name in view_names_set - ] - - # Filter to just tables (not views) for table processing - table_names = [name for name in table_names if name not in view_names_set] - - # Perform counts only for immutable or DBS with <= COUNT_TABLE_LIMIT tables - table_counts = {} - if not db.is_mutable or db.size < COUNT_DB_SIZE_LIMIT: - table_counts = await db.table_counts(10) - # If any of these are None it means at least one timed out - ignore them all - if any(v is None for v in table_counts.values()): - table_counts = {} - - tables = {} - for table in table_names: - # Check if table is in allowed set - if table not in allowed_for_db: - continue - - table_columns = await db.table_columns(table) - tables[table] = { - "name": table, - "columns": table_columns, - "primary_keys": await db.primary_keys(table), - "count": table_counts.get(table), - "hidden": table in hidden_table_names, - "fts_table": await db.fts_table(table), - "num_relationships_for_sorting": 0, - "private": allowed_for_db[table].private, - } - - if request.args.get("_sort") == "relationships" or not table_counts: - # We will be sorting by number of relationships, so populate that field - all_foreign_keys = await db.get_all_foreign_keys() - for table, foreign_keys in all_foreign_keys.items(): - if table in tables.keys(): - count = len(foreign_keys["incoming"] + foreign_keys["outgoing"]) - tables[table]["num_relationships_for_sorting"] = count - - hidden_tables = [t for t in tables.values() if t["hidden"]] - visible_tables = [t for t in tables.values() if not t["hidden"]] - - tables_and_views_truncated = list( - sorted( - (t for t in tables.values() if t not in hidden_tables), - key=lambda t: ( - t["num_relationships_for_sorting"], - t["count"] or 0, - t["name"], - ), - reverse=True, - )[:TRUNCATE_AT] - ) - - # Only add views if this is less than TRUNCATE_AT - if len(tables_and_views_truncated) < TRUNCATE_AT: - num_views_to_add = TRUNCATE_AT - len(tables_and_views_truncated) - for view in views[:num_views_to_add]: - tables_and_views_truncated.append(view) - - databases.append( - { - "name": name, - "hash": db.hash, - "color": db.color, - "path": self.ds.urls.database(name), - "tables_and_views_truncated": tables_and_views_truncated, - "tables_and_views_more": (len(visible_tables) + len(views)) - > TRUNCATE_AT, - "tables_count": len(visible_tables), - "table_rows_sum": sum((t["count"] or 0) for t in visible_tables), - "show_table_row_counts": bool(table_counts), - "hidden_table_rows_sum": sum( - t["count"] for t in hidden_tables if t["count"] is not None - ), - "hidden_tables_count": len(hidden_tables), - "views_count": len(views), - "private": database_private, - } - ) - + for key, info in sorted(self.ds.inspect().items()): + tables = [t for t in info["tables"].values() if not t["hidden"]] + hidden_tables = [t for t in info["tables"].values() if t["hidden"]] + database = { + "name": key, + "hash": info["hash"], + "path": "{}-{}".format(key, info["hash"][:HASH_LENGTH]), + "tables_truncated": sorted( + tables, key=lambda t: t["count"], reverse=True + )[ + :5 + ], + "tables_count": len(tables), + "tables_more": len(tables) > 5, + "table_rows_sum": sum(t["count"] for t in tables), + "hidden_table_rows_sum": sum(t["count"] for t in hidden_tables), + "hidden_tables_count": len(hidden_tables), + "views_count": len(info["views"]), + } + databases.append(database) if as_format: headers = {} if self.ds.cors: - add_cors_headers(headers) - return Response( - json.dumps( - { - "databases": {db["name"]: db for db in databases}, - "metadata": await self.ds.get_instance_metadata(), - }, - cls=CustomJSONEncoder, - ), - content_type="application/json; charset=utf-8", + headers["Access-Control-Allow-Origin"] = "*" + return response.HTTPResponse( + json.dumps({db["name"]: db for db in databases}, cls=CustomJSONEncoder), + content_type="application/json", headers=headers, ) + else: - homepage_actions = [] - for hook in pm.hook.homepage_actions( - datasette=self.ds, - actor=request.actor, - request=request, - ): - extra_links = await await_me_maybe(hook) - if extra_links: - homepage_actions.extend(extra_links) - alternative_homepage = request.path == "/-/" - return await self.render( - ["default:index.html" if alternative_homepage else "index.html"], - request=request, - context={ - "databases": databases, - "metadata": await self.ds.get_instance_metadata(), - "datasette_version": __version__, - "private": not await self.ds.allowed( - action="view-instance", actor=None - ), - "top_homepage": make_slot_function( - "top_homepage", self.ds, request - ), - "homepage_actions": homepage_actions, - "noindex": request.path == "/-/", - }, + return self.render( + ["index.html"], + databases=databases, + metadata=self.ds.metadata, + datasette_version=__version__, + extra_css_urls=self.ds.extra_css_urls(), + extra_js_urls=self.ds.extra_js_urls(), ) diff --git a/datasette/views/query_helpers.py b/datasette/views/query_helpers.py deleted file mode 100644 index 46d71b8e..00000000 --- a/datasette/views/query_helpers.py +++ /dev/null @@ -1,556 +0,0 @@ -import json -import re - -from datasette.resources import DatabaseResource, TableResource -from datasette.stored_queries import StoredQuery -from datasette.utils import ( - named_parameters as derive_named_parameters, - escape_sqlite, - path_from_row_pks, - sqlite3, - validate_sql_select, - InvalidSql, -) -from datasette.utils.asgi import Forbidden - -_query_name_re = re.compile(r"^[^/\.\n]+$") - -_query_fields = { - "sql", - "title", - "description", - "hide_sql", - "fragment", - "parameters", - "params", - "is_private", - "on_success_message", - "on_success_redirect", - "on_error_message", - "on_error_redirect", -} - -_query_create_fields = _query_fields | {"name", "mode", "csrftoken"} -_query_update_fields = _query_fields -_query_write_fields = { - "on_success_message", - "on_success_redirect", - "on_error_message", - "on_error_redirect", -} - - -class QueryValidationError(Exception): - def __init__(self, message, status=400): - self.message = message - self.status = status - - -def _actor_id(actor): - if isinstance(actor, dict): - return actor.get("id") - return None - - -def _as_bool(value): - if isinstance(value, bool): - return value - if value is None: - return False - if isinstance(value, int): - return bool(value) - if isinstance(value, str): - return value.lower() in {"1", "true", "t", "yes", "on"} - return bool(value) - - -def _as_optional_bool(value, name): - if value is None or value == "": - return None - if isinstance(value, bool): - return value - if isinstance(value, int): - return bool(value) - if isinstance(value, str): - lowered = value.lower() - if lowered in {"1", "true", "t", "yes", "on"}: - return True - if lowered in {"0", "false", "f", "no", "off"}: - return False - raise QueryValidationError("{} must be 0 or 1".format(name)) - - -def _query_list_limit(value, default=50): - if value in (None, ""): - return default - try: - return min(max(1, int(value)), 1000) - except ValueError as ex: - raise QueryValidationError("_size must be an integer") from ex - - -def _derived_query_parameters(sql): - parameters = [] - seen = set() - for parameter in derive_named_parameters(sql): - if parameter.startswith("_"): - raise QueryValidationError("Magic parameters are not allowed") - if parameter not in seen: - parameters.append(parameter) - seen.add(parameter) - return parameters - - -def _coerce_query_parameters(value, derived): - if value is None: - return derived - if isinstance(value, str): - parameters = [ - parameter.strip() - for parameter in re.split(r"[\s,]+", value) - if parameter.strip() - ] - elif isinstance(value, list): - parameters = value - else: - raise QueryValidationError("parameters must be a list of strings") - if not all(isinstance(parameter, str) for parameter in parameters): - raise QueryValidationError("parameters must be a list of strings") - if any(parameter.startswith("_") for parameter in parameters): - raise QueryValidationError("Magic parameters are not allowed") - if set(parameters) != set(derived): - raise QueryValidationError("parameters must match SQL named parameters") - return parameters - - -def _analysis_is_write(analysis): - return any( - access.operation in {"insert", "update", "delete"} - for access in analysis.table_accesses - ) - - -def _block_framing(response): - response.headers["Content-Security-Policy"] = "frame-ancestors 'none'" - response.headers["X-Frame-Options"] = "DENY" - return response - - -def _wants_json(request, is_json, data): - return ( - is_json - or request.headers.get("accept") == "application/json" - or (isinstance(data, dict) and data.get("_json")) - ) - - -def _query_create_form_error_message(message): - return { - "Query name is required": "URL is required", - "Invalid query name": "Invalid URL", - "Query name conflicts with a table or view": ( - "URL conflicts with an existing table or view" - ), - "Query already exists": "A query already exists at that URL", - }.get(message, message) - - -async def _json_or_form_payload(request): - content_type = request.headers.get("content-type", "") - if content_type.startswith("application/json"): - body = await request.post_body() - try: - return json.loads(body or b"{}"), True - except json.JSONDecodeError as e: - raise QueryValidationError("Invalid JSON: {}".format(e)) - return await request.post_vars(), False - - -async def _check_query_name(db, name, *, existing=False): - if not name or not isinstance(name, str): - raise QueryValidationError("Query name is required") - if not _query_name_re.match(name): - raise QueryValidationError("Invalid query name") - if not existing and (await db.table_exists(name) or await db.view_exists(name)): - raise QueryValidationError("Query name conflicts with a table or view") - - -async def _analyze_user_query(datasette, db, sql, *, actor): - if not sql or not isinstance(sql, str): - raise QueryValidationError("SQL is required") - derived = _derived_query_parameters(sql) - params = {parameter: "" for parameter in derived} - try: - analysis = await db.analyze_sql(sql, params) - except sqlite3.DatabaseError as ex: - raise QueryValidationError("Could not analyze query: {}".format(ex)) from ex - - is_write = _analysis_is_write(analysis) - if is_write: - try: - await datasette.ensure_query_write_permissions( - db.name, sql, actor=actor, analysis=analysis - ) - except Forbidden as ex: - raise QueryValidationError(str(ex), status=403) from ex - else: - try: - validate_sql_select(sql) - except InvalidSql as ex: - raise QueryValidationError(str(ex)) from ex - return is_write, derived, analysis - - -def _analysis_rows(analysis): - write_actions = { - "insert": "insert-row", - "update": "update-row", - "delete": "delete-row", - } - return [ - { - "operation": access.operation, - "database": access.database, - "table": access.table, - "required_permission": write_actions.get(access.operation, ""), - "source": access.source, - } - for access in analysis.table_accesses - ] - - -async def _analysis_rows_with_permissions(datasette, analysis, actor): - rows = _analysis_rows(analysis) - for row in rows: - permission = row["required_permission"] - if permission: - row["allowed"] = await datasette.allowed( - action=permission, - resource=TableResource(row["database"], row["table"]), - actor=actor, - ) - else: - row["allowed"] = None - return rows - - -def _coerce_execute_write_payload(data, is_json): - if not isinstance(data, dict): - raise QueryValidationError("JSON must be a dictionary") - if is_json: - invalid_keys = set(data) - {"sql", "params"} - if invalid_keys: - raise QueryValidationError( - "Invalid keys: {}".format(", ".join(sorted(invalid_keys))) - ) - params = data.get("params") or {} - else: - params = { - key: value - for key, value in data.items() - if key not in {"sql", "csrftoken", "_json"} - } - if not isinstance(params, dict): - raise QueryValidationError("params must be a dictionary") - return data.get("sql"), params - - -async def _prepare_execute_write(datasette, db, sql, params, actor): - if not sql or not isinstance(sql, str): - raise QueryValidationError("SQL is required") - parameter_names = _derived_query_parameters(sql) - extra_params = set(params) - set(parameter_names) - if extra_params: - raise QueryValidationError( - "Unknown parameters: {}".format(", ".join(sorted(extra_params))) - ) - params = {name: params.get(name, "") for name in parameter_names} - try: - analysis = await db.analyze_sql(sql, params) - except sqlite3.DatabaseError as ex: - raise QueryValidationError("Could not analyze query: {}".format(ex)) from ex - if not _analysis_is_write(analysis): - raise QueryValidationError( - "Use /-/query for read-only SQL; this endpoint only executes writes" - ) - try: - await datasette.ensure_query_write_permissions( - db.name, sql, actor=actor, analysis=analysis - ) - except Forbidden as ex: - raise QueryValidationError(str(ex), status=403) from ex - return parameter_names, params, analysis - - -async def _ensure_stored_query_execution_permissions( - datasette, db, query: StoredQuery, actor -): - if query.is_trusted: - return - if query.is_write: - await datasette.ensure_permission( - action="execute-write-sql", - resource=DatabaseResource(db.name), - actor=actor, - ) - await datasette.ensure_query_write_permissions(db.name, query.sql, actor=actor) - else: - await datasette.ensure_permission( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=actor, - ) - - -async def _execute_write_analysis_data(datasette, db, sql, actor): - parameter_names = [] - analysis_rows = [] - analysis_error = None - if sql: - try: - parameter_names = _derived_query_parameters(sql) - params = {parameter: "" for parameter in parameter_names} - analysis = await db.analyze_sql(sql, params) - if _analysis_is_write(analysis): - analysis_rows = await _analysis_rows_with_permissions( - datasette, analysis, actor - ) - else: - analysis_error = ( - "Use /-/query for read-only SQL; " - "this endpoint only executes writes" - ) - except (QueryValidationError, sqlite3.DatabaseError) as ex: - analysis_error = getattr(ex, "message", str(ex)) - return { - "ok": analysis_error is None, - "parameters": parameter_names, - "analysis_error": analysis_error, - "analysis_rows": [row for row in analysis_rows if row["operation"] != "read"], - "execute_disabled": bool( - (not sql) - or analysis_error - or any(row["allowed"] is False for row in analysis_rows) - ), - } - - -async def _query_create_analysis_data(datasette, db, sql, actor): - has_sql = bool(sql and sql.strip()) - parameter_names = [] - analysis_rows = [] - analysis_error = None - if has_sql: - try: - parameter_names = _derived_query_parameters(sql) - params = {parameter: "" for parameter in parameter_names} - analysis = await db.analyze_sql(sql, params) - analysis_rows = await _analysis_rows_with_permissions( - datasette, analysis, actor - ) - except (QueryValidationError, sqlite3.DatabaseError) as ex: - analysis_error = getattr(ex, "message", str(ex)) - return { - "ok": analysis_error is None, - "parameters": parameter_names, - "analysis_error": analysis_error, - "analysis_rows": analysis_rows, - "has_sql": has_sql, - "analysis_is_write": bool( - analysis_rows and any(row["required_permission"] for row in analysis_rows) - ), - "save_disabled": bool( - (not has_sql) - or analysis_error - or any(row["allowed"] is False for row in analysis_rows) - ), - } - - -async def _query_create_form_context( - datasette, - request, - db, - *, - sql="", - name="", - title="", - description="", - is_private=True, -): - analysis_data = await _query_create_analysis_data(datasette, db, sql, request.actor) - return { - "database": db.name, - "database_color": db.color, - "sql": sql, - "name": name, - "title": title, - "description": description, - "is_private": is_private, - **analysis_data, - } - - -async def _inserted_row_url(datasette, db, analysis, cursor): - if cursor.rowcount != 1: - return None - lastrowid = getattr(cursor, "lastrowid", None) - if lastrowid is None: - return None - direct_inserts = [ - access - for access in analysis.table_accesses - if access.operation == "insert" - and access.source is None - and access.database == db.name - ] - if len(direct_inserts) != 1: - return None - table = direct_inserts[0].table - pks = await db.primary_keys(table) - use_rowid = not pks - select = ( - "rowid" - if use_rowid - else ", ".join(escape_sqlite(primary_key) for primary_key in pks) - ) - try: - result = await db.execute( - "select {} from {} where rowid = ?".format(select, escape_sqlite(table)), - [lastrowid], - ) - except sqlite3.DatabaseError: - return None - row = result.first() - if row is None: - return None - row_path = path_from_row_pks(row, pks, use_rowid) - return datasette.urls.row(db.name, table, row_path) - - -def _apply_query_data_types(data): - typed = dict(data) - for key in ("hide_sql", "is_private"): - if key in typed: - typed[key] = _as_bool(typed[key]) - return typed - - -async def _prepare_query_create(datasette, request, db, data): - invalid_keys = set(data) - _query_create_fields - if invalid_keys: - raise QueryValidationError( - "Invalid keys: {}".format(", ".join(sorted(invalid_keys))) - ) - - data = _apply_query_data_types(data) - name = data.get("name") - await _check_query_name(db, name) - if await datasette.get_query(db.name, name) is not None: - raise QueryValidationError("Query already exists") - - is_write, derived, analysis = await _analyze_user_query( - datasette, - db, - data.get("sql"), - actor=request.actor, - ) - if not is_write and any(data.get(field) for field in _query_write_fields): - raise QueryValidationError("Writable query fields require writable SQL") - - parameters = _coerce_query_parameters( - data.get("parameters", data.get("params")), - derived, - ) - return { - "name": name, - "sql": data["sql"], - "title": data.get("title"), - "description": data.get("description"), - "hide_sql": _as_bool(data.get("hide_sql")), - "fragment": data.get("fragment"), - "parameters": parameters, - "is_write": is_write, - "is_private": _as_bool(data.get("is_private", True)), - "is_trusted": False, - "source": "user", - "owner_id": _actor_id(request.actor), - "on_success_message": data.get("on_success_message"), - "on_success_redirect": data.get("on_success_redirect"), - "on_error_message": data.get("on_error_message"), - "on_error_redirect": data.get("on_error_redirect"), - "analysis": analysis, - } - - -async def _prepare_query_update(datasette, request, db, existing: StoredQuery, update): - invalid_keys = set(update) - _query_update_fields - if invalid_keys: - raise QueryValidationError( - "Invalid keys: {}".format(", ".join(sorted(invalid_keys))) - ) - - update = _apply_query_data_types(update) - sql = update.get("sql", existing.sql) - query_is_write = existing.is_write - derived = _derived_query_parameters(sql) - parameters = None - - if "sql" in update: - query_is_write, derived, _ = await _analyze_user_query( - datasette, - db, - sql, - actor=request.actor, - ) - - if "parameters" in update or "params" in update: - parameters = _coerce_query_parameters( - update.get("parameters", update.get("params")), - derived, - ) - elif "sql" in update: - parameters = derived - - if not query_is_write and any(update.get(field) for field in _query_write_fields): - raise QueryValidationError("Writable query fields require writable SQL") - - field_values = { - "sql": sql, - "title": update.get("title"), - "description": update.get("description"), - "hide_sql": update.get("hide_sql"), - "fragment": update.get("fragment"), - "parameters": parameters, - "is_write": query_is_write, - "is_private": update.get("is_private"), - "on_success_message": update.get("on_success_message"), - "on_success_redirect": update.get("on_success_redirect"), - "on_error_message": update.get("on_error_message"), - "on_error_redirect": update.get("on_error_redirect"), - } - update_kwargs = {} - for field_name, value in field_values.items(): - if field_name in update: - update_kwargs[field_name] = value - if parameters is not None: - update_kwargs["parameters"] = parameters - if "sql" in update: - update_kwargs["is_write"] = query_is_write - return update_kwargs - - -async def _table_columns(datasette, database_name): - internal_db = datasette.get_internal_database() - result = await internal_db.execute( - "select table_name, name from catalog_columns where database_name = ?", - [database_name], - ) - table_columns = {} - for row in result.rows: - table_columns.setdefault(row["table_name"], []).append(row["name"]) - # Add views - db = datasette.get_database(database_name) - for view_name in await db.view_names(): - table_columns[view_name] = [] - return table_columns diff --git a/datasette/views/row.py b/datasette/views/row.py deleted file mode 100644 index 4eacfe49..00000000 --- a/datasette/views/row.py +++ /dev/null @@ -1,413 +0,0 @@ -from datasette.utils.asgi import NotFound, Forbidden, Response -from datasette.database import QueryInterrupted -from datasette.events import UpdateRowEvent, DeleteRowEvent -from datasette.resources import TableResource -from .base import DataView, BaseView, _error -from datasette.utils import ( - await_me_maybe, - CustomRow, - make_slot_function, - to_css_class, - escape_sqlite, -) -from datasette.plugins import pm -import json -import markupsafe -import sqlite_utils -from .table import display_columns_and_rows, _get_extras - - -class RowView(DataView): - name = "row" - - async def data(self, request, default_labels=False): - resolved = await self.ds.resolve_row(request) - db = resolved.db - database = db.name - table = resolved.table - pk_values = resolved.pk_values - - # Ensure user has permission to view this row - visible, private = await self.ds.check_visibility( - request.actor, - action="view-table", - resource=TableResource(database=database, table=table), - ) - if not visible: - raise Forbidden("You do not have permission to view this table") - - results = await resolved.db.execute( - resolved.sql, resolved.params, truncate=True - ) - columns = [r[0] for r in results.description] - rows = list(results.rows) - if not rows: - raise NotFound(f"Record not found: {pk_values}") - - pks = resolved.pks - - async def template_data(): - # Reorder columns so primary keys come first - pk_set = set(pks) - pk_cols = [d for d in results.description if d[0] in pk_set] - non_pk_cols = [d for d in results.description if d[0] not in pk_set] - reordered_description = pk_cols + non_pk_cols - reordered_columns = [d[0] for d in reordered_description] - - # Reorder row data to match - reordered_rows = [] - for row in rows: - new_row = CustomRow(reordered_columns) - for col in reordered_columns: - new_row[col] = row[col] - reordered_rows.append(new_row) - - # Expand foreign key columns into dicts so display_columns_and_rows - # renders them as hyperlinks, matching the table view behavior - expanded_rows = reordered_rows - for fk in await db.foreign_keys_for_table(table): - column = fk["column"] - if column not in reordered_columns: - continue - column_index = reordered_columns.index(column) - values = [row[column_index] for row in expanded_rows] - expanded_labels = await self.ds.expand_foreign_keys( - request.actor, database, table, column, values - ) - if expanded_labels: - new_rows = [] - for row in expanded_rows: - new_row = CustomRow(reordered_columns) - for col in reordered_columns: - value = row[col] - if ( - col == column - and (col, value) in expanded_labels - and value is not None - ): - new_row[col] = { - "value": value, - "label": expanded_labels[(col, value)], - } - else: - new_row[col] = value - new_rows.append(new_row) - expanded_rows = new_rows - - display_columns, display_rows = await display_columns_and_rows( - self.ds, - database, - table, - reordered_description, - expanded_rows, - link_column=False, - truncate_cells=0, - request=request, - ) - for column in display_columns: - column["sortable"] = False - - # Bold primary key cell values - for row in display_rows: - for cell in row: - if cell["column"] in pk_set: - cell["value"] = markupsafe.Markup( - "{}".format(cell["value"]) - ) - - row_actions = [] - for hook in pm.hook.row_actions( - datasette=self.ds, - actor=request.actor, - request=request, - database=database, - table=table, - row=rows[0], - ): - extra_links = await await_me_maybe(hook) - if extra_links: - row_actions.extend(extra_links) - - return { - "private": private, - "columns": reordered_columns, - "foreign_key_tables": await self.foreign_key_tables( - database, table, pk_values - ), - "database_color": db.color, - "display_columns": display_columns, - "display_rows": display_rows, - "custom_table_templates": [ - f"_table-{to_css_class(database)}-{to_css_class(table)}.html", - f"_table-row-{to_css_class(database)}-{to_css_class(table)}.html", - "_table.html", - ], - "row_actions": row_actions, - "top_row": make_slot_function( - "top_row", - self.ds, - request, - database=resolved.db.name, - table=resolved.table, - row=rows[0], - ), - "metadata": {}, - } - - data = { - "ok": True, - "database": database, - "table": table, - "rows": rows, - "columns": columns, - "primary_keys": resolved.pks, - "primary_key_values": pk_values, - } - - # Handle _extra parameter (new style) - extras = _get_extras(request) - - # Also support legacy _extras parameter for backward compatibility - if "foreign_key_tables" in (request.args.get("_extras") or "").split(","): - extras.add("foreign_key_tables") - - # Process extras - if "foreign_key_tables" in extras: - data["foreign_key_tables"] = await self.foreign_key_tables( - database, table, pk_values - ) - - if "render_cell" in extras: - # Call render_cell plugin hook for each cell - ct_map = await self.ds.get_column_types(database, table) - rendered_rows = [] - for row in rows: - rendered_row = {} - for value, column in zip(row, columns): - ct = ct_map.get(column) - plugin_display_value = None - # Try column type render_cell first - if ct: - candidate = await ct.render_cell( - value=value, - column=column, - table=table, - database=database, - datasette=self.ds, - request=request, - ) - if candidate is not None: - plugin_display_value = candidate - if plugin_display_value is None: - for candidate in pm.hook.render_cell( - row=row, - value=value, - column=column, - table=table, - pks=resolved.pks, - database=database, - datasette=self.ds, - request=request, - column_type=ct, - ): - candidate = await await_me_maybe(candidate) - if candidate is not None: - plugin_display_value = candidate - break - if plugin_display_value: - rendered_row[column] = str(plugin_display_value) - rendered_rows.append(rendered_row) - data["render_cell"] = rendered_rows - - return ( - data, - template_data, - ( - f"row-{to_css_class(database)}-{to_css_class(table)}.html", - "row.html", - ), - ) - - async def foreign_key_tables(self, database, table, pk_values): - if len(pk_values) != 1: - return [] - db = self.ds.databases[database] - all_foreign_keys = await db.get_all_foreign_keys() - foreign_keys = all_foreign_keys[table]["incoming"] - if len(foreign_keys) == 0: - return [] - - sql = "select " + ", ".join( - [ - "(select count(*) from {table} where {column}=:id)".format( - table=escape_sqlite(fk["other_table"]), - column=escape_sqlite(fk["other_column"]), - ) - for fk in foreign_keys - ] - ) - try: - rows = list(await db.execute(sql, {"id": pk_values[0]})) - except QueryInterrupted: - # Almost certainly hit the timeout - return [] - - foreign_table_counts = dict( - zip( - [(fk["other_table"], fk["other_column"]) for fk in foreign_keys], - list(rows[0]), - ) - ) - foreign_key_tables = [] - for fk in foreign_keys: - count = ( - foreign_table_counts.get((fk["other_table"], fk["other_column"])) or 0 - ) - key = fk["other_column"] - if key.startswith("_"): - key += "__exact" - link = "{}?{}={}".format( - self.ds.urls.table(database, fk["other_table"]), - key, - ",".join(pk_values), - ) - foreign_key_tables.append({**fk, **{"count": count, "link": link}}) - return foreign_key_tables - - -class RowError(Exception): - def __init__(self, error): - self.error = error - - -async def _resolve_row_and_check_permission(datasette, request, permission): - from datasette.app import DatabaseNotFound, TableNotFound, RowNotFound - - try: - resolved = await datasette.resolve_row(request) - except DatabaseNotFound as e: - return False, _error(["Database not found: {}".format(e.database_name)], 404) - except TableNotFound as e: - return False, _error(["Table not found: {}".format(e.table)], 404) - except RowNotFound as e: - return False, _error(["Record not found: {}".format(e.pk_values)], 404) - - # Ensure user has permission to delete this row - if not await datasette.allowed( - action=permission, - resource=TableResource(database=resolved.db.name, table=resolved.table), - actor=request.actor, - ): - return False, _error(["Permission denied"], 403) - - return True, resolved - - -class RowDeleteView(BaseView): - name = "row-delete" - - def __init__(self, datasette): - self.ds = datasette - - async def post(self, request): - ok, resolved = await _resolve_row_and_check_permission( - self.ds, request, "delete-row" - ) - if not ok: - return resolved - - # Delete table - def delete_row(conn): - sqlite_utils.Database(conn)[resolved.table].delete(resolved.pk_values) - - try: - await resolved.db.execute_write_fn(delete_row, request=request) - except Exception as e: - return _error([str(e)], 500) - - await self.ds.track_event( - DeleteRowEvent( - actor=request.actor, - database=resolved.db.name, - table=resolved.table, - pks=resolved.pk_values, - ) - ) - - return Response.json({"ok": True}, status=200) - - -class RowUpdateView(BaseView): - name = "row-update" - - def __init__(self, datasette): - self.ds = datasette - - async def post(self, request): - ok, resolved = await _resolve_row_and_check_permission( - self.ds, request, "update-row" - ) - if not ok: - return resolved - - body = await request.post_body() - try: - data = json.loads(body) - except json.JSONDecodeError as e: - return _error(["Invalid JSON: {}".format(e)]) - - if not isinstance(data, dict): - return _error(["JSON must be a dictionary"]) - if "update" not in data or not isinstance(data["update"], dict): - return _error(["JSON must contain an update dictionary"]) - - invalid_keys = set(data.keys()) - {"update", "return", "alter"} - if invalid_keys: - return _error(["Invalid keys: {}".format(", ".join(invalid_keys))]) - - update = data["update"] - - # Validate column types - from datasette.views.table import _validate_column_types - - ct_errors = await _validate_column_types( - self.ds, resolved.db.name, resolved.table, [update] - ) - if ct_errors: - return _error(ct_errors, 400) - - alter = data.get("alter") - if alter and not await self.ds.allowed( - action="alter-table", - resource=TableResource(database=resolved.db.name, table=resolved.table), - actor=request.actor, - ): - return _error(["Permission denied for alter-table"], 403) - - def update_row(conn): - sqlite_utils.Database(conn)[resolved.table].update( - resolved.pk_values, update, alter=alter - ) - - try: - await resolved.db.execute_write_fn(update_row, request=request) - except Exception as e: - return _error([str(e)], 400) - - result = {"ok": True} - if data.get("return"): - results = await resolved.db.execute( - resolved.sql, resolved.params, truncate=True - ) - result["row"] = results.dicts()[0] - - await self.ds.track_event( - UpdateRowEvent( - actor=request.actor, - database=resolved.db.name, - table=resolved.table, - pks=resolved.pk_values, - ) - ) - - return Response.json(result, status=200) diff --git a/datasette/views/special.py b/datasette/views/special.py index 6c82983c..7fde5ee9 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -1,1271 +1,30 @@ import json -import logging -from datasette.jump import JumpSQL, namespace_sql_params -from datasette.plugins import pm -from datasette.events import LogoutEvent, LoginEvent, CreateTokenEvent -from datasette.resources import DatabaseResource, TableResource -from datasette.utils.asgi import Response, Forbidden -from datasette.utils import ( - actor_matches_allow, - add_cors_headers, - await_me_maybe, - tilde_encode, - tilde_decode, -) -from .base import BaseView, View -import secrets -import urllib - -logger = logging.getLogger(__name__) +from sanic import response +from .base import RenderMixin -def _resource_path(parent, child): - if parent is None: - return "/" - if child is None: - return f"/{parent}" - return f"/{parent}/{child}" - - -class JsonDataView(BaseView): - name = "json_data" - template = "show_json.html" # Can be overridden in subclasses - - def __init__( - self, - datasette, - filename, - data_callback, - needs_request=False, - permission="view-instance", - template=None, - ): +class JsonDataView(RenderMixin): + def __init__(self, datasette, filename, data_callback): self.ds = datasette + self.jinja_env = datasette.jinja_env self.filename = filename self.data_callback = data_callback - self.needs_request = needs_request - self.permission = permission - if template is not None: - self.template = template - async def get(self, request): - if self.permission: - await self.ds.ensure_permission(action=self.permission, actor=request.actor) - if self.needs_request: - data = self.data_callback(request) - else: - data = self.data_callback() - - # Return JSON or HTML depending on format parameter - as_format = request.url_vars.get("format") + async def get(self, request, as_format): + data = self.data_callback() if as_format: headers = {} if self.ds.cors: - add_cors_headers(headers) - return Response.json(data, headers=headers) + headers["Access-Control-Allow-Origin"] = "*" + return response.HTTPResponse( + json.dumps(data), + content_type="application/json", + headers=headers + ) + else: - context = { - "filename": self.filename, - "data": data, - "data_json": json.dumps(data, indent=2, default=repr), - } - # Add has_debug_permission if this view requires permissions-debug - if self.permission == "permissions-debug": - context["has_debug_permission"] = True - return await self.render( - [self.template], - request=request, - context=context, - ) - - -class PatternPortfolioView(View): - async def get(self, request, datasette): - await datasette.ensure_permission(action="view-instance", actor=request.actor) - return Response.html( - await datasette.render_template( - "patterns.html", - request=request, - view_name="patterns", - ) - ) - - -class AuthTokenView(BaseView): - name = "auth_token" - has_json_alternate = False - - async def get(self, request): - # If already signed in as root, redirect - if request.actor and request.actor.get("id") == "root": - return Response.redirect(self.ds.urls.instance()) - token = request.args.get("token") or "" - if not self.ds._root_token: - raise Forbidden("Root token has already been used") - if secrets.compare_digest(token, self.ds._root_token): - self.ds._root_token = None - response = Response.redirect(self.ds.urls.instance()) - root_actor = {"id": "root"} - self.ds.set_actor_cookie(response, root_actor) - await self.ds.track_event(LoginEvent(actor=root_actor)) - return response - else: - raise Forbidden("Invalid token") - - -class LogoutView(BaseView): - name = "logout" - has_json_alternate = False - - async def get(self, request): - if not request.actor: - return Response.redirect(self.ds.urls.instance()) - return await self.render( - ["logout.html"], - request, - {"actor": request.actor}, - ) - - async def post(self, request): - response = Response.redirect(self.ds.urls.instance()) - self.ds.delete_actor_cookie(response) - self.ds.add_message(request, "You are now logged out", self.ds.WARNING) - await self.ds.track_event(LogoutEvent(actor=request.actor)) - return response - - -class PermissionsDebugView(BaseView): - name = "permissions_debug" - has_json_alternate = False - - async def get(self, request): - await self.ds.ensure_permission(action="view-instance", actor=request.actor) - await self.ds.ensure_permission(action="permissions-debug", actor=request.actor) - filter_ = request.args.get("filter") or "all" - permission_checks = list(reversed(self.ds._permission_checks)) - if filter_ == "exclude-yours": - permission_checks = [ - check - for check in permission_checks - if (check.actor or {}).get("id") != request.actor["id"] - ] - elif filter_ == "only-yours": - permission_checks = [ - check - for check in permission_checks - if (check.actor or {}).get("id") == request.actor["id"] - ] - return await self.render( - ["debug_permissions_playground.html"], - request, - # list() avoids error if check is performed during template render: - { - "permission_checks": permission_checks, - "filter": filter_, - "has_debug_permission": True, - "permissions": [ - { - "name": p.name, - "abbr": p.abbr, - "description": p.description, - "takes_parent": p.takes_parent, - "takes_child": p.takes_child, - } - for p in self.ds.actions.values() - ], - }, - ) - - async def post(self, request): - await self.ds.ensure_permission(action="view-instance", actor=request.actor) - await self.ds.ensure_permission(action="permissions-debug", actor=request.actor) - form = await request.form() - actor = json.loads(form["actor"]) - permission = form["permission"] - parent = form.get("resource_1") or None - child = form.get("resource_2") or None - - response, status = await _check_permission_for_actor( - self.ds, permission, parent, child, actor - ) - return Response.json(response, status=status) - - -class AllowedResourcesView(BaseView): - name = "allowed" - has_json_alternate = False - - async def get(self, request): - await self.ds.refresh_schemas() - - # Check if user has permissions-debug (to show sensitive fields) - has_debug_permission = await self.ds.allowed( - action="permissions-debug", actor=request.actor - ) - - # Check if this is a request for JSON (has .json extension) - as_format = request.url_vars.get("format") - - if not as_format: - # Render the HTML form (even if query parameters are present) - # Put most common/interesting actions first - priority_actions = [ - "view-instance", - "view-database", - "view-table", - "view-query", - "execute-sql", - "insert-row", - "update-row", - "delete-row", - ] - actions = list(self.ds.actions.keys()) - # Priority actions first (in order), then remaining alphabetically - sorted_actions = [a for a in priority_actions if a in actions] - sorted_actions.extend( - sorted(a for a in actions if a not in priority_actions) - ) - - return await self.render( - ["debug_allowed.html"], - request, - { - "supported_actions": sorted_actions, - "has_debug_permission": has_debug_permission, - }, - ) - - payload, status = await self._allowed_payload(request, has_debug_permission) - headers = {} - if self.ds.cors: - add_cors_headers(headers) - return Response.json(payload, status=status, headers=headers) - - async def _allowed_payload(self, request, has_debug_permission): - action = request.args.get("action") - if not action: - return {"error": "action parameter is required"}, 400 - if action not in self.ds.actions: - return {"error": f"Unknown action: {action}"}, 404 - - actor = request.actor if isinstance(request.actor, dict) else None - actor_id = actor.get("id") if actor else None - parent_filter = request.args.get("parent") - child_filter = request.args.get("child") - if child_filter and not parent_filter: - return {"error": "parent must be provided when child is specified"}, 400 - - try: - page = int(request.args.get("page", "1")) - page_size = int(request.args.get("page_size", "50")) - except ValueError: - return {"error": "page and page_size must be integers"}, 400 - if page < 1: - return {"error": "page must be >= 1"}, 400 - if page_size < 1: - return {"error": "page_size must be >= 1"}, 400 - max_page_size = 200 - if page_size > max_page_size: - page_size = max_page_size - offset = (page - 1) * page_size - - # Use the simplified allowed_resources method - # Collect all resources with optional reasons for debugging - try: - allowed_rows = [] - result = await self.ds.allowed_resources( - action=action, - actor=actor, - parent=parent_filter, - include_reasons=has_debug_permission, - ) - async for resource in result.all(): - parent_val = resource.parent - child_val = resource.child - - # Build resource path - if parent_val is None: - resource_path = "/" - elif child_val is None: - resource_path = f"/{parent_val}" - else: - resource_path = f"/{parent_val}/{child_val}" - - row = { - "parent": parent_val, - "child": child_val, - "resource": resource_path, - } - - # Add reason if we have it (from include_reasons=True) - if has_debug_permission and hasattr(resource, "reasons"): - row["reason"] = resource.reasons - - allowed_rows.append(row) - except Exception: - # If catalog tables don't exist yet, return empty results - return ( - { - "action": action, - "actor_id": actor_id, - "page": page, - "page_size": page_size, - "total": 0, - "items": [], - }, - 200, - ) - - # Apply child filter if specified - if child_filter is not None: - allowed_rows = [row for row in allowed_rows if row["child"] == child_filter] - - # Pagination - total = len(allowed_rows) - paged_rows = allowed_rows[offset : offset + page_size] - - # Items are already in the right format - items = paged_rows - - def build_page_url(page_number): - pairs = [] - for key in request.args: - if key in {"page", "page_size"}: - continue - for value in request.args.getlist(key): - pairs.append((key, value)) - pairs.append(("page", str(page_number))) - pairs.append(("page_size", str(page_size))) - query = urllib.parse.urlencode(pairs) - return f"{request.path}?{query}" - - response = { - "action": action, - "actor_id": actor_id, - "page": page, - "page_size": page_size, - "total": total, - "items": items, - } - - if total > offset + page_size: - response["next_url"] = build_page_url(page + 1) - if page > 1: - response["previous_url"] = build_page_url(page - 1) - - return response, 200 - - -class PermissionRulesView(BaseView): - name = "permission_rules" - has_json_alternate = False - - async def get(self, request): - await self.ds.ensure_permission(action="view-instance", actor=request.actor) - await self.ds.ensure_permission(action="permissions-debug", actor=request.actor) - - # Check if this is a request for JSON (has .json extension) - as_format = request.url_vars.get("format") - - if not as_format: - # Render the HTML form (even if query parameters are present) - return await self.render( - ["debug_rules.html"], - request, - { - "sorted_actions": sorted(self.ds.actions.keys()), - "has_debug_permission": True, - }, - ) - - # JSON API - action parameter is required - action = request.args.get("action") - if not action: - return Response.json({"error": "action parameter is required"}, status=400) - if action not in self.ds.actions: - return Response.json({"error": f"Unknown action: {action}"}, status=404) - - actor = request.actor if isinstance(request.actor, dict) else None - - try: - page = int(request.args.get("page", "1")) - page_size = int(request.args.get("page_size", "50")) - except ValueError: - return Response.json( - {"error": "page and page_size must be integers"}, status=400 - ) - if page < 1: - return Response.json({"error": "page must be >= 1"}, status=400) - if page_size < 1: - return Response.json({"error": "page_size must be >= 1"}, status=400) - max_page_size = 200 - if page_size > max_page_size: - page_size = max_page_size - offset = (page - 1) * page_size - - from datasette.utils.actions_sql import build_permission_rules_sql - - union_sql, union_params, restriction_sqls = await build_permission_rules_sql( - self.ds, actor, action - ) - await self.ds.refresh_schemas() - db = self.ds.get_internal_database() - - count_query = f""" - WITH rules AS ( - {union_sql} - ) - SELECT COUNT(*) AS count - FROM rules - """ - count_row = (await db.execute(count_query, union_params)).first() - total = count_row["count"] if count_row else 0 - - data_query = f""" - WITH rules AS ( - {union_sql} - ) - SELECT parent, child, allow, reason, source_plugin - FROM rules - ORDER BY allow DESC, (parent IS NOT NULL), parent, child - LIMIT :limit OFFSET :offset - """ - params = {**union_params, "limit": page_size, "offset": offset} - rows = await db.execute(data_query, params) - - items = [] - for row in rows: - parent = row["parent"] - child = row["child"] - items.append( - { - "parent": parent, - "child": child, - "resource": _resource_path(parent, child), - "allow": row["allow"], - "reason": row["reason"], - "source_plugin": row["source_plugin"], - } - ) - - def build_page_url(page_number): - pairs = [] - for key in request.args: - if key in {"page", "page_size"}: - continue - for value in request.args.getlist(key): - pairs.append((key, value)) - pairs.append(("page", str(page_number))) - pairs.append(("page_size", str(page_size))) - query = urllib.parse.urlencode(pairs) - return f"{request.path}?{query}" - - response = { - "action": action, - "actor_id": (actor or {}).get("id") if actor else None, - "page": page, - "page_size": page_size, - "total": total, - "items": items, - } - - if total > offset + page_size: - response["next_url"] = build_page_url(page + 1) - if page > 1: - response["previous_url"] = build_page_url(page - 1) - - headers = {} - if self.ds.cors: - add_cors_headers(headers) - return Response.json(response, headers=headers) - - -async def _check_permission_for_actor(ds, action, parent, child, actor): - """Shared logic for checking permissions. Returns a dict with check results.""" - if action not in ds.actions: - return {"error": f"Unknown action: {action}"}, 404 - - if child and not parent: - return {"error": "parent is required when child is provided"}, 400 - - # Use the action's properties to create the appropriate resource object - action_obj = ds.actions.get(action) - if not action_obj: - return {"error": f"Unknown action: {action}"}, 400 - - # Global actions (no resource_class) don't have a resource - if action_obj.resource_class is None: - resource_obj = None - elif action_obj.takes_parent and action_obj.takes_child: - # Child-level resource (e.g., TableResource, QueryResource) - resource_obj = action_obj.resource_class(database=parent, table=child) - elif action_obj.takes_parent: - # Parent-level resource (e.g., DatabaseResource) - resource_obj = action_obj.resource_class(database=parent) - else: - # This shouldn't happen given validation in Action.__post_init__ - return {"error": f"Invalid action configuration: {action}"}, 500 - - allowed = await ds.allowed(action=action, resource=resource_obj, actor=actor) - - response = { - "action": action, - "allowed": bool(allowed), - "resource": { - "parent": parent, - "child": child, - "path": _resource_path(parent, child), - }, - } - - if actor and "id" in actor: - response["actor_id"] = actor["id"] - - return response, 200 - - -class PermissionCheckView(BaseView): - name = "permission_check" - has_json_alternate = False - - async def get(self, request): - await self.ds.ensure_permission(action="permissions-debug", actor=request.actor) - as_format = request.url_vars.get("format") - - if not as_format: - return await self.render( - ["debug_check.html"], - request, - { - "sorted_actions": sorted(self.ds.actions.keys()), - "has_debug_permission": True, - }, - ) - - # JSON API - action parameter is required - action = request.args.get("action") - if not action: - return Response.json({"error": "action parameter is required"}, status=400) - - parent = request.args.get("parent") - child = request.args.get("child") - - response, status = await _check_permission_for_actor( - self.ds, action, parent, child, request.actor - ) - return Response.json(response, status=status) - - -class AllowDebugView(BaseView): - name = "allow_debug" - has_json_alternate = False - - async def get(self, request): - errors = [] - actor_input = request.args.get("actor") or '{"id": "root"}' - try: - actor = json.loads(actor_input) - actor_input = json.dumps(actor, indent=4) - except json.decoder.JSONDecodeError as ex: - errors.append(f"Actor JSON error: {ex}") - allow_input = request.args.get("allow") or '{"id": "*"}' - try: - allow = json.loads(allow_input) - allow_input = json.dumps(allow, indent=4) - except json.decoder.JSONDecodeError as ex: - errors.append(f"Allow JSON error: {ex}") - - result = None - if not errors: - result = str(actor_matches_allow(actor, allow)) - - return await self.render( - ["allow_debug.html"], - request, - { - "result": result, - "error": "\n\n".join(errors) if errors else "", - "actor_input": actor_input, - "allow_input": allow_input, - "has_debug_permission": await self.ds.allowed( - action="permissions-debug", actor=request.actor - ), - }, - ) - - -class MessagesDebugView(BaseView): - name = "messages_debug" - has_json_alternate = False - - async def get(self, request): - await self.ds.ensure_permission(action="view-instance", actor=request.actor) - return await self.render(["messages_debug.html"], request) - - async def post(self, request): - await self.ds.ensure_permission(action="view-instance", actor=request.actor) - form = await request.form() - message = form.get("message", "") - message_type = form.get("message_type") or "INFO" - assert message_type in ("INFO", "WARNING", "ERROR", "all") - datasette = self.ds - if message_type == "all": - datasette.add_message(request, message, datasette.INFO) - datasette.add_message(request, message, datasette.WARNING) - datasette.add_message(request, message, datasette.ERROR) - else: - datasette.add_message(request, message, getattr(datasette, message_type)) - return Response.redirect(self.ds.urls.instance()) - - -class CreateTokenView(BaseView): - name = "create_token" - has_json_alternate = False - - def check_permission(self, request): - if not self.ds.setting("allow_signed_tokens"): - raise Forbidden("Signed tokens are not enabled for this Datasette instance") - if not request.actor: - raise Forbidden("You must be logged in to create a token") - if not request.actor.get("id"): - raise Forbidden( - "You must be logged in as an actor with an ID to create a token" - ) - if request.actor.get("token"): - raise Forbidden( - "Token authentication cannot be used to create additional tokens" - ) - - async def shared(self, request): - self.check_permission(request) - # Build list of databases and tables the user has permission to view - db_page = await self.ds.allowed_resources("view-database", request.actor) - allowed_databases = [r async for r in db_page.all()] - - table_page = await self.ds.allowed_resources("view-table", request.actor) - allowed_tables = [r async for r in table_page.all()] - - # Build database -> tables mapping - database_with_tables = [] - for db_resource in allowed_databases: - database_name = db_resource.parent - if database_name == "_memory": - continue - - # Find tables for this database - tables = [] - for table_resource in allowed_tables: - if table_resource.parent == database_name: - tables.append( - { - "name": table_resource.child, - "encoded": tilde_encode(table_resource.child), - } - ) - - database_with_tables.append( - { - "name": database_name, - "encoded": tilde_encode(database_name), - "tables": tables, - } - ) - return { - "actor": request.actor, - "all_actions": self.ds.actions.keys(), - "database_actions": [ - key for key, value in self.ds.actions.items() if value.takes_parent - ], - "child_actions": [ - key for key, value in self.ds.actions.items() if value.takes_child - ], - "database_with_tables": database_with_tables, - } - - async def get(self, request): - self.check_permission(request) - return await self.render( - ["create_token.html"], request, await self.shared(request) - ) - - async def post(self, request): - self.check_permission(request) - form = await request.form() - errors = [] - expires_after = None - if form.get("expire_type"): - duration_string = form.get("expire_duration") - if ( - not duration_string - or not duration_string.isdigit() - or not int(duration_string) > 0 - ): - errors.append("Invalid expire duration") - else: - unit = form["expire_type"] - if unit == "minutes": - expires_after = int(duration_string) * 60 - elif unit == "hours": - expires_after = int(duration_string) * 60 * 60 - elif unit == "days": - expires_after = int(duration_string) * 60 * 60 * 24 - else: - errors.append("Invalid expire duration unit") - - # Are there any restrictions? - from datasette.tokens import TokenRestrictions - - restrictions = TokenRestrictions() - - for key in form: - if key.startswith("all:") and key.count(":") == 1: - restrictions.allow_all(key.split(":")[1]) - elif key.startswith("database:") and key.count(":") == 2: - bits = key.split(":") - restrictions.allow_database(tilde_decode(bits[1]), bits[2]) - elif key.startswith("resource:") and key.count(":") == 3: - bits = key.split(":") - restrictions.allow_resource( - tilde_decode(bits[1]), tilde_decode(bits[2]), bits[3] - ) - - token = await self.ds.create_token( - request.actor["id"], - expires_after=expires_after, - restrictions=restrictions, - handler="signed", - ) - token_bits = self.ds.unsign(token[len("dstok_") :], namespace="token") - await self.ds.track_event( - CreateTokenEvent( - actor=request.actor, - expires_after=expires_after, - restrict_all=restrictions.all, - restrict_database=restrictions.database, - restrict_resource=restrictions.resource, - ) - ) - context = await self.shared(request) - context.update({"errors": errors, "token": token, "token_bits": token_bits}) - return await self.render(["create_token.html"], request, context) - - -class ApiExplorerView(BaseView): - name = "api_explorer" - has_json_alternate = False - - async def example_links(self, request): - databases = [] - for name, db in self.ds.databases.items(): - database_visible, _ = await self.ds.check_visibility( - request.actor, - action="view-database", - resource=DatabaseResource(database=name), - ) - if not database_visible: - continue - tables = [] - table_names = await db.table_names() - for table in table_names: - visible, _ = await self.ds.check_visibility( - request.actor, - action="view-table", - resource=TableResource(database=name, table=table), - ) - if not visible: - continue - table_links = [] - tables.append({"name": table, "links": table_links}) - table_links.append( - { - "label": "Get rows for {}".format(table), - "method": "GET", - "path": self.ds.urls.table(name, table, format="json"), - } - ) - # If not mutable don't show any write APIs - if not db.is_mutable: - continue - - if await self.ds.allowed( - action="insert-row", - resource=TableResource(database=name, table=table), - actor=request.actor, - ): - pks = await db.primary_keys(table) - table_links.extend( - [ - { - "path": self.ds.urls.table(name, table) + "/-/insert", - "method": "POST", - "label": "Insert rows into {}".format(table), - "json": { - "rows": [ - { - column: None - for column in await db.table_columns(table) - if column not in pks - } - ] - }, - }, - { - "path": self.ds.urls.table(name, table) + "/-/upsert", - "method": "POST", - "label": "Upsert rows into {}".format(table), - "json": { - "rows": [ - { - column: "<{}{}>".format( - column, - ( - " (primary key)" - if column in (pks or ["rowid"]) - else "" - ), - ) - for column in ( - (["rowid"] if not pks else []) - + await db.table_columns(table) - ) - } - ] - }, - }, - ] - ) - if await self.ds.allowed( - action="drop-table", - resource=TableResource(database=name, table=table), - actor=request.actor, - ): - table_links.append( - { - "path": self.ds.urls.table(name, table) + "/-/drop", - "label": "Drop table {}".format(table), - "json": {"confirm": False}, - "method": "POST", - } - ) - database_links = [] - if ( - await self.ds.allowed( - action="create-table", - resource=DatabaseResource(database=name), - actor=request.actor, - ) - and db.is_mutable - ): - database_links.append( - { - "path": self.ds.urls.database(name) + "/-/create", - "label": "Create table in {}".format(name), - "json": { - "table": "new_table", - "columns": [ - {"name": "id", "type": "integer"}, - {"name": "name", "type": "text"}, - ], - "pk": "id", - }, - "method": "POST", - } - ) - if database_links or tables: - databases.append( - { - "name": name, - "links": database_links, - "tables": tables, - } - ) - # Sort so that mutable databases are first - databases.sort(key=lambda d: not self.ds.databases[d["name"]].is_mutable) - return databases - - async def get(self, request): - visible, private = await self.ds.check_visibility( - request.actor, - action="view-instance", - ) - if not visible: - raise Forbidden("You do not have permission to view this instance") - - def api_path(link): - return "/-/api#{}".format( - urllib.parse.urlencode( - { - key: json.dumps(value, indent=2) if key == "json" else value - for key, value in link.items() - if key in ("path", "method", "json") - } - ) - ) - - return await self.render( - ["api_explorer.html"], - request, - { - "example_links": await self.example_links(request), - "api_path": api_path, - "private": private, - }, - ) - - -class JumpView(BaseView): - """ - Endpoint for the jump menu. Returns JSON navigation items the actor can use. - """ - - name = "jump" - has_json_alternate = False - - async def _fragments(self, request): - fragments = [] - for hook in pm.hook.jump_items_sql( - datasette=self.ds, - actor=request.actor, - request=request, - ): - value = await await_me_maybe(hook) - if value is None: - continue - if isinstance(value, JumpSQL): - fragments.append(value) - elif isinstance(value, (list, tuple)): - for fragment in value: - if fragment is not None: - assert isinstance( - fragment, JumpSQL - ), "jump_items_sql must return JumpSQL instances" - fragments.append(fragment) - else: - raise TypeError("jump_items_sql must return JumpSQL instances") - return fragments - - def _resolve_url(self, url): - if not url or url.startswith("/"): - return url - - descriptor = json.loads(url) - if not isinstance(descriptor, dict): - raise TypeError("jump item url JSON must be an object") - method_name = descriptor.get("method") - if not isinstance(method_name, str) or not method_name: - raise TypeError("jump item url JSON must include a method") - if method_name.startswith("_"): - raise AttributeError(f"datasette.urls has no method named {method_name!r}") - try: - method = getattr(self.ds.urls, method_name) - except AttributeError as ex: - raise AttributeError( - f"datasette.urls has no method named {method_name!r}" - ) from ex - if not callable(method): - raise TypeError(f"datasette.urls.{method_name} is not callable") - kwargs = {key: value for key, value in descriptor.items() if key != "method"} - try: - return method(**kwargs) - except TypeError as ex: - raise TypeError( - f"Invalid arguments for datasette.urls.{method_name}(): {ex}" - ) from ex - - def _sort_key(self, row, q): - display_label = row["display_name"] or row["label"] - display_label_lower = display_label.lower() - q_lower = q.lower() - if display_label_lower == q_lower: - relevance = 0 - elif display_label_lower.startswith(q_lower): - relevance = 1 - else: - relevance = 2 - type_sort = { - "database": 10, - "table": 20, - "view": 25, - "query": 30, - }.get(row["type"], 50) - return (relevance, type_sort, len(display_label), row["label"]) - - async def _rows_for_database(self, database_name, indexed_fragments, q, pattern): - params = {"q": q, "pattern": pattern} - union_parts = [] - for index, fragment in indexed_fragments: - fragment_sql, fragment_params = namespace_sql_params( - fragment.sql, - fragment.params or {}, - f"jump_{index}", - ) - union_parts.append(f""" - SELECT - type, - label, - description, - url, - search_text, - display_name - FROM ( - {fragment_sql} - ) - """) - params.update(fragment_params) - sql = f""" - WITH jump_items AS ( - {" UNION ALL ".join(union_parts)} - ) - SELECT - type, - label, - description, - url, - search_text, - display_name - FROM jump_items - WHERE :q = '' - OR search_text LIKE :pattern COLLATE NOCASE - ORDER BY - CASE - WHEN lower(COALESCE(display_name, label)) = lower(:q) THEN 0 - WHEN lower(COALESCE(display_name, label)) LIKE lower(:q || '%') THEN 1 - ELSE 2 - END, - CASE type - WHEN 'database' THEN 10 - WHEN 'table' THEN 20 - WHEN 'view' THEN 25 - WHEN 'query' THEN 30 - ELSE 50 - END, - length(COALESCE(display_name, label)), - label - LIMIT 101 - """ - db = ( - self.ds.get_internal_database() - if database_name is None - else self.ds.get_database(database_name) - ) - result = await db.execute(sql, params) - return list(result.rows) - - async def get(self, request): - q = request.args.get("q", "").strip() - terms = q.split() - pattern = "%" + "%".join(terms) + "%" if terms else "%" - fragments = await self._fragments(request) - - fragments_by_database = {} - for index, fragment in enumerate(fragments): - fragments_by_database.setdefault(fragment.database, []).append( - (index, fragment) - ) - - rows = [] - truncated = False - for database_name, indexed_fragments in fragments_by_database.items(): - database_rows = await self._rows_for_database( - database_name, indexed_fragments, q, pattern - ) - if len(database_rows) > 100: - truncated = True - database_rows = database_rows[:100] - rows.extend(database_rows) - rows.sort(key=lambda row: self._sort_key(row, q)) - - if len(rows) > 100: - truncated = True - rows = rows[:100] - - matches = [] - for row in rows: - match = { - "name": row["label"], - "url": self._resolve_url(row["url"]), - "type": row["type"], - "description": row["description"], - } - if row["display_name"]: - match["display_name"] = row["display_name"] - matches.append(match) - - return Response.json({"matches": matches, "truncated": truncated}) - - -class SchemaBaseView(BaseView): - """Base class for schema views with common response formatting.""" - - has_json_alternate = False - - async def get_database_schema(self, database_name): - """Get schema SQL for a database.""" - db = self.ds.databases[database_name] - result = await db.execute( - "select group_concat(sql, ';' || CHAR(10)) as schema from sqlite_master where sql is not null" - ) - row = result.first() - return row["schema"] if row and row["schema"] else "" - - def format_json_response(self, data): - """Format data as JSON response with CORS headers if needed.""" - headers = {} - if self.ds.cors: - add_cors_headers(headers) - return Response.json(data, headers=headers) - - def format_error_response(self, error_message, format_, status=404): - """Format error response based on requested format.""" - if format_ == "json": - headers = {} - if self.ds.cors: - add_cors_headers(headers) - return Response.json( - {"ok": False, "error": error_message}, status=status, headers=headers - ) - else: - return Response.text(error_message, status=status) - - def format_markdown_response(self, heading, schema): - """Format schema as Markdown response.""" - md_output = f"# {heading}\n\n```sql\n{schema}\n```\n" - return Response.text( - md_output, headers={"content-type": "text/markdown; charset=utf-8"} - ) - - async def format_html_response( - self, request, schemas, is_instance=False, table_name=None - ): - """Format schema as HTML response.""" - context = { - "schemas": schemas, - "is_instance": is_instance, - } - if table_name: - context["table_name"] = table_name - return await self.render(["schema.html"], request=request, context=context) - - -class InstanceSchemaView(SchemaBaseView): - """ - Displays schema for all databases in the instance. - Supports HTML, JSON, and Markdown formats. - """ - - name = "instance_schema" - - async def get(self, request): - format_ = request.url_vars.get("format") or "html" - - # Get all databases the actor can view - allowed_databases_page = await self.ds.allowed_resources( - "view-database", - request.actor, - ) - allowed_databases = [r.parent async for r in allowed_databases_page.all()] - - # Get schema for each database - schemas = [] - for database_name in allowed_databases: - schema = await self.get_database_schema(database_name) - schemas.append({"database": database_name, "schema": schema}) - - if format_ == "json": - return self.format_json_response({"schemas": schemas}) - elif format_ == "md": - md_parts = [ - f"# Schema for {item['database']}\n\n```sql\n{item['schema']}\n```" - for item in schemas - ] - return Response.text( - "\n\n".join(md_parts), - headers={"content-type": "text/markdown; charset=utf-8"}, - ) - else: - return await self.format_html_response(request, schemas, is_instance=True) - - -class DatabaseSchemaView(SchemaBaseView): - """ - Displays schema for a specific database. - Supports HTML, JSON, and Markdown formats. - """ - - name = "database_schema" - - async def get(self, request): - database_name = request.url_vars["database"] - format_ = request.url_vars.get("format") or "html" - - # Check if database exists - if database_name not in self.ds.databases: - return self.format_error_response("Database not found", format_) - - # Check view-database permission - await self.ds.ensure_permission( - action="view-database", - resource=DatabaseResource(database=database_name), - actor=request.actor, - ) - - schema = await self.get_database_schema(database_name) - - if format_ == "json": - return self.format_json_response( - {"database": database_name, "schema": schema} - ) - elif format_ == "md": - return self.format_markdown_response(f"Schema for {database_name}", schema) - else: - schemas = [{"database": database_name, "schema": schema}] - return await self.format_html_response(request, schemas) - - -class TableSchemaView(SchemaBaseView): - """ - Displays schema for a specific table. - Supports HTML, JSON, and Markdown formats. - """ - - name = "table_schema" - - async def get(self, request): - database_name = request.url_vars["database"] - table_name = request.url_vars["table"] - format_ = request.url_vars.get("format") or "html" - - # Check view-table permission - await self.ds.ensure_permission( - action="view-table", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ) - - # Get schema for the table - db = self.ds.databases[database_name] - result = await db.execute( - "select sql from sqlite_master where name = ? and sql is not null", - [table_name], - ) - row = result.first() - - # Return 404 if table doesn't exist - if not row or not row["sql"]: - return self.format_error_response("Table not found", format_) - - schema = row["sql"] - - if format_ == "json": - return self.format_json_response( - {"database": database_name, "table": table_name, "schema": schema} - ) - elif format_ == "md": - return self.format_markdown_response( - f"Schema for {database_name}.{table_name}", schema - ) - else: - schemas = [{"database": database_name, "schema": schema}] - return await self.format_html_response( - request, schemas, table_name=table_name + return self.render( + ["show_json.html"], + filename=self.filename, + data=data ) diff --git a/datasette/views/stored_queries.py b/datasette/views/stored_queries.py deleted file mode 100644 index 8c4e849e..00000000 --- a/datasette/views/stored_queries.py +++ /dev/null @@ -1,483 +0,0 @@ -from urllib.parse import parse_qsl, urlencode - -from datasette.resources import DatabaseResource, QueryResource -from datasette.stored_queries import stored_query_to_dict -from datasette.utils import sqlite3, tilde_decode -from datasette.utils.asgi import Response - -from .base import BaseView, _error -from .query_helpers import ( - QueryValidationError, - _as_bool, - _as_optional_bool, - _block_framing, - _derived_query_parameters, - _json_or_form_payload, - _prepare_query_create, - _prepare_query_update, - _query_create_analysis_data, - _query_create_form_context, - _query_create_form_error_message, - _query_list_limit, -) - - -class QueryParametersView(BaseView): - name = "query-parameters" - has_json_alternate = False - - async def get(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing(_error(["Permission denied: need execute-sql"], 403)) - - invalid_keys = set(request.args) - {"sql"} - if invalid_keys: - return _block_framing( - _error( - ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], - 400, - ) - ) - try: - parameters = _derived_query_parameters(request.args.get("sql") or "") - except QueryValidationError as ex: - return _block_framing(_error([ex.message], ex.status)) - return _block_framing(Response.json({"ok": True, "parameters": parameters})) - - -def _query_list_url(path, query_string, *, set_args=None, remove_args=None): - set_args = set_args or {} - remove_args = set(remove_args or ()) - skip = set(set_args) | remove_args | {"_next"} - pairs = [ - (key, value) - for key, value in parse_qsl(query_string, keep_blank_values=True) - if key not in skip - ] - for key, value in set_args.items(): - if value not in (None, ""): - pairs.append((key, value)) - return path + (("?" + urlencode(pairs)) if pairs else "") - - -class QueryListView(BaseView): - name = "query-list" - - async def database_name(self, request): - return (await self.ds.resolve_database(request)).name - - def query_list_path(self, database): - return self.ds.urls.database(database) + "/-/queries" - - async def get(self, request): - database = await self.database_name(request) - format_ = request.url_vars.get("format") or "html" - try: - limit = _query_list_limit( - request.args.get("_size"), - default=20 if format_ == "html" else 50, - ) - is_write = _as_optional_bool(request.args.get("is_write"), "is_write") - is_private = _as_optional_bool(request.args.get("is_private"), "is_private") - except QueryValidationError as ex: - return _error([ex.message], ex.status) - - page = await self.ds.list_queries( - database, - actor=request.actor, - limit=limit, - cursor=request.args.get("_next"), - q=request.args.get("q") or None, - is_write=is_write, - is_private=is_private, - source=request.args.get("source") or None, - owner_id=request.args.get("owner_id") or None, - include_private=True, - ) - query_list_path = self.query_list_path(database) - next_url = None - if page.next: - pairs = [ - (key, value) - for key, value in parse_qsl( - request.query_string, keep_blank_values=True - ) - if key != "_next" - ] - pairs.append(("_next", page.next)) - next_url = "{}?{}".format( - query_list_path, - urlencode(pairs), - ) - - current_filters = { - "actor": request.actor, - "q": request.args.get("q") or None, - "is_write": is_write, - "is_private": is_private, - "source": request.args.get("source") or None, - "owner_id": request.args.get("owner_id") or None, - } - - async def facet_count(field, value): - if current_filters[field] is not None and current_filters[field] != value: - return 0 - filters = dict(current_filters) - filters[field] = value - return await self.ds.count_queries(database, **filters) - - def facet_href(field, value): - if current_filters[field] == value: - return _query_list_url( - query_list_path, - request.query_string, - remove_args=[field], - ) - if current_filters[field] is not None: - return None - return _query_list_url( - query_list_path, - request.query_string, - set_args={field: str(int(value))}, - ) - - async def facet_item(label, field, value): - count = await facet_count(field, value) - active = current_filters[field] == value - if not active and not count: - return None - return { - "label": label, - "count": count, - "href": facet_href(field, value) if active or count else None, - "active": active, - } - - async def facet_items(items): - return [ - item - for item in [ - await facet_item(label, field, value) - for label, field, value in items - ] - if item is not None - ] - - facets = [ - { - "title": "Mode", - "items": await facet_items( - [ - ("Read-only", "is_write", False), - ("Writable", "is_write", True), - ] - ), - }, - { - "title": "Visibility", - "items": await facet_items( - [ - ("Not private", "is_private", False), - ("Private", "is_private", True), - ] - ), - }, - ] - - data = { - "ok": True, - "database": database, - "database_color": ( - self.ds.get_database(database).color if database is not None else None - ), - "queries": page.queries, - "next": page.next, - "next_url": next_url, - "has_more": page.has_more, - "limit": page.limit, - "show_private_note": any(query.is_private for query in page.queries), - "show_trusted_note": any(query.is_trusted for query in page.queries), - "query_list_path": query_list_path, - "show_database": database is None, - "facets": facets, - "filters": { - "q": request.args.get("q") or "", - "is_write": request.args.get("is_write") or "", - "is_private": request.args.get("is_private") or "", - "source": request.args.get("source") or "", - "owner_id": request.args.get("owner_id") or "", - }, - } - if format_ == "json": - return Response.json( - { - **data, - "queries": [stored_query_to_dict(query) for query in page.queries], - } - ) - return await self.render( - ["query_list.html"], - request, - data, - ) - - -class GlobalQueryListView(QueryListView): - name = "global-query-list" - - async def database_name(self, request): - return None - - def query_list_path(self, database): - return self.ds.urls.path("/-/queries") - - -class QueryCreateView(BaseView): - name = "query-create" - has_json_alternate = False - - async def _render_form( - self, - request, - db, - *, - sql="", - name="", - title="", - description="", - is_private=True, - status=200, - ): - response = await self.render( - ["query_create.html"], - request, - await _query_create_form_context( - self.ds, - request, - db, - sql=sql, - name=name, - title=title, - description=description, - is_private=is_private, - ), - ) - response.status = status - return response - - async def get(self, request): - db = await self.ds.resolve_database(request) - await self.ds.ensure_permission( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - await self.ds.ensure_permission( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ) - - return await self._render_form(request, db, sql=request.args.get("sql") or "") - - -class QueryCreateAnalyzeView(BaseView): - name = "query-create-analyze" - has_json_alternate = False - - async def get(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing(_error(["Permission denied: need execute-sql"], 403)) - if not await self.ds.allowed( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _block_framing(_error(["Permission denied: need store-query"], 403)) - - invalid_keys = set(request.args) - {"sql"} - if invalid_keys: - return _block_framing( - _error( - ["Invalid keys: {}".format(", ".join(sorted(invalid_keys)))], - 400, - ) - ) - sql = request.args.get("sql") or "" - return _block_framing( - Response.json( - await _query_create_analysis_data(self.ds, db, sql, request.actor) - ) - ) - - -class QueryStoreView(QueryCreateView): - name = "query-store" - - async def _error_response(self, request, db, query_data, message, status): - message = _query_create_form_error_message(message) - self.ds.add_message(request, message, self.ds.ERROR) - return await self._render_form( - request, - db, - sql=query_data.get("sql") or "", - name=query_data.get("name") or "", - title=query_data.get("title") or "", - description=query_data.get("description") or "", - is_private=_as_bool(query_data.get("is_private", True)), - status=status, - ) - - async def post(self, request): - db = await self.ds.resolve_database(request) - if not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _error(["Permission denied: need execute-sql"], 403) - if not await self.ds.allowed( - action="store-query", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - return _error(["Permission denied: need store-query"], 403) - - is_json = False - query_data = {} - try: - data, is_json = await _json_or_form_payload(request) - if not isinstance(data, dict): - raise QueryValidationError("JSON must be a dictionary") - query_data = data.get("query") if is_json else data - if not isinstance(query_data, dict): - raise QueryValidationError("JSON must contain a query dictionary") - prepared = await _prepare_query_create(self.ds, request, db, query_data) - except QueryValidationError as ex: - if not is_json and isinstance(query_data, dict): - return await self._error_response( - request, db, query_data, ex.message, ex.status - ) - return _error([ex.message], ex.status) - - prepared.pop("analysis") - name = prepared.pop("name") - try: - await self.ds.add_query(db.name, name, replace=False, **prepared) - except sqlite3.IntegrityError as ex: - if not is_json and isinstance(query_data, dict): - return await self._error_response(request, db, query_data, str(ex), 400) - return _error([str(ex)], 400) - - query = await self.ds.get_query(db.name, name) - assert query is not None - if is_json: - return Response.json( - {"ok": True, "query": stored_query_to_dict(query)}, status=201 - ) - self.ds.add_message(request, "Query saved", self.ds.INFO) - return Response.redirect(self.ds.urls.path(self.ds.urls.table(db.name, name))) - - -class QueryDefinitionView(BaseView): - name = "query-definition" - - async def get(self, request): - db = await self.ds.resolve_database(request) - query_name = tilde_decode(request.url_vars["query"]) - query = await self.ds.get_query(db.name, query_name) - if query is None: - return _error(["Query not found: {}".format(query_name)], 404) - if not await self.ds.allowed( - action="view-query", - resource=QueryResource(db.name, query_name), - actor=request.actor, - ): - return _error(["Permission denied"], 403) - return Response.json({"ok": True, "query": stored_query_to_dict(query)}) - - -class QueryUpdateView(BaseView): - name = "query-update" - - async def post(self, request): - db = await self.ds.resolve_database(request) - query_name = tilde_decode(request.url_vars["query"]) - existing = await self.ds.get_query(db.name, query_name) - if existing is None: - return _error(["Query not found: {}".format(query_name)], 404) - if not await self.ds.allowed( - action="update-query", - resource=QueryResource(db.name, query_name), - actor=request.actor, - ): - return _error(["Permission denied: need update-query"], 403) - if existing.is_trusted: - return _error(["Trusted queries cannot be updated using the API"], 403) - - try: - data, _ = await _json_or_form_payload(request) - if not isinstance(data, dict): - raise QueryValidationError("JSON must be a dictionary") - invalid_keys = set(data) - {"update", "return"} - if invalid_keys: - raise QueryValidationError( - "Invalid keys: {}".format(", ".join(invalid_keys)) - ) - update = data.get("update") - if not isinstance(update, dict): - raise QueryValidationError("JSON must contain an update dictionary") - if "sql" in update and not await self.ds.allowed( - action="execute-sql", - resource=DatabaseResource(db.name), - actor=request.actor, - ): - raise QueryValidationError( - "Permission denied: need execute-sql", status=403 - ) - update_kwargs = await _prepare_query_update( - self.ds, request, db, existing, update - ) - except QueryValidationError as ex: - return _error([ex.message], ex.status) - - await self.ds.update_query(db.name, query_name, **update_kwargs) - if data.get("return"): - query = await self.ds.get_query(db.name, query_name) - assert query is not None - return Response.json( - { - "ok": True, - "query": stored_query_to_dict(query), - } - ) - return Response.json({"ok": True}) - - -class QueryDeleteView(BaseView): - name = "query-delete" - - async def post(self, request): - db = await self.ds.resolve_database(request) - query_name = tilde_decode(request.url_vars["query"]) - existing = await self.ds.get_query(db.name, query_name) - if existing is None: - return _error(["Query not found: {}".format(query_name)], 404) - if not await self.ds.allowed( - action="delete-query", - resource=QueryResource(db.name, query_name), - actor=request.actor, - ): - return _error(["Permission denied: need delete-query"], 403) - await self.ds.remove_query(db.name, query_name) - return Response.json({"ok": True}) diff --git a/datasette/views/table.py b/datasette/views/table.py index da69c6b5..c139cf30 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -1,2134 +1,904 @@ -import asyncio -import itertools -import json +import sqlite3 import urllib -from asyncinject import Registry -import markupsafe +import jinja2 +from sanic.exceptions import NotFound +from sanic.request import RequestParameters -from datasette.plugins import pm -from datasette.database import QueryInterrupted -from datasette.events import ( - AlterTableEvent, - DropTableEvent, - InsertRowsEvent, - UpsertRowsEvent, -) -from datasette import tracer -from datasette.resources import DatabaseResource, TableResource from datasette.utils import ( - add_cors_headers, - await_me_maybe, - call_with_supported_arguments, CustomRow, - append_querystring, + Filters, + InterruptedError, compound_keys_after_sql, - format_bytes, - make_slot_function, - tilde_encode, escape_sqlite, filters_should_redirect, is_url, path_from_row_pks, path_with_added_args, - path_with_format, path_with_removed_args, path_with_replaced_args, to_css_class, - truncate_url, urlsafe_components, value_as_boolean, - InvalidSql, - sqlite3, ) -from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response -from datasette.filters import Filters -import sqlite_utils -from .base import BaseView, DatasetteError, _error, stream_csv -from .database import QueryView -LINK_WITH_LABEL = ( - '{label} {id}' -) -LINK_WITH_VALUE = '{id}' +from .base import BaseView, DatasetteError, ureg + +LINK_WITH_LABEL = '{label} {id}' +LINK_WITH_VALUE = '{id}' -class Row: - def __init__(self, cells): - self.cells = cells +class RowTableShared(BaseView): - def __iter__(self): - return iter(self.cells) - - def __getitem__(self, key): - for cell in self.cells: - if cell["column"] == key: - return cell["raw"] - raise KeyError - - def display(self, key): - for cell in self.cells: - if cell["column"] == key: - return cell["value"] - return None - - def __str__(self): - d = { - key: self[key] - for key in [ - c["column"] for c in self.cells if not c.get("is_special_link_column") - ] - } - return json.dumps(d, default=repr, indent=2) - - -async def run_sequential(*args): - # This used to be swappable for asyncio.gather() to run things in - # parallel, but this lead to hard-to-debug locking issues with - # in-memory databases: https://github.com/simonw/datasette/issues/2189 - results = [] - for fn in args: - results.append(await fn) - return results - - -def _redirect(datasette, request, path, forward_querystring=True, remove_args=None): - if request.query_string and "?" not in path and forward_querystring: - path = f"{path}?{request.query_string}" - if remove_args: - path = path_with_removed_args(request, remove_args, path=path) - r = Response.redirect(path) - r.headers["Link"] = f"<{path}>; rel=preload" - if datasette.cors: - add_cors_headers(r.headers) - return r - - -async def _redirect_if_needed(datasette, request, resolved): - # Handle ?_filter_column - redirect_params = filters_should_redirect(request.args) - if redirect_params: - return _redirect( - datasette, - request, - datasette.urls.path(path_with_added_args(request, redirect_params)), - forward_querystring=False, - ) - - # If ?_sort_by_desc=on (from checkbox) redirect to _sort_desc=(_sort) - if "_sort_by_desc" in request.args: - return _redirect( - datasette, - request, - datasette.urls.path( - path_with_added_args( - request, - { - "_sort_desc": request.args.get("_sort"), - "_sort_by_desc": None, - "_sort": None, - }, - ) - ), - forward_querystring=False, - ) - - -async def _validate_column_types(datasette, database_name, table_name, rows): - """Validate row values against assigned column types. Returns list of error strings.""" - ct_map = await datasette.get_column_types(database_name, table_name) - if not ct_map: - return [] - errors = [] - for row in rows: - for col_name, ct in ct_map.items(): - if col_name not in row: - continue - error = await ct.validate(row[col_name], datasette) - if error: - errors.append(f"{col_name}: {error}") - return errors - - -async def display_columns_and_rows( - datasette, - database_name, - table_name, - description, - rows, - link_column=False, - truncate_cells=0, - sortable_columns=None, - request=None, -): - """Returns columns, rows for specified table - including fancy foreign key treatment""" - sortable_columns = sortable_columns or set() - db = datasette.databases[database_name] - column_descriptions = dict( - await datasette.get_internal_database().execute( - """ - SELECT - column_name, - value - FROM metadata_columns - WHERE database_name = ? - AND resource_name = ? - AND key = 'description' - """, - [database_name, table_name], - ) - ) - - # Look up column types for this table - column_types_map = await datasette.get_column_types(database_name, table_name) - - column_details = { - col.name: col for col in await db.table_column_details(table_name) - } - pks = await db.primary_keys(table_name) - pks_for_display = pks - if not pks_for_display: - pks_for_display = ["rowid"] - - columns = [] - for r in description: - if r[0] == "rowid" and "rowid" not in column_details: - type_ = "integer" - notnull = 0 + def sortable_columns_for_table(self, database, table, use_rowid): + table_metadata = self.table_metadata(database, table) + if "sortable_columns" in table_metadata: + sortable_columns = set(table_metadata["sortable_columns"]) else: - type_ = column_details[r[0]].type - notnull = column_details[r[0]].notnull - col_dict = { - "name": r[0], - "sortable": r[0] in sortable_columns, - "is_pk": r[0] in pks_for_display, - "type": type_, - "notnull": notnull, - "description": column_descriptions.get(r[0]), - "column_type": None, - "column_type_config": None, - } - ct = column_types_map.get(r[0]) - if ct: - col_dict["column_type"] = ct.name - col_dict["column_type_config"] = ct.config - columns.append(col_dict) - - column_to_foreign_key_table = { - fk["column"]: fk["other_table"] - for fk in await db.foreign_keys_for_table(table_name) - } - - cell_rows = [] - base_url = datasette.setting("base_url") - for row in rows: - cells = [] - # Unless we are a view, the first column is a link - either to the rowid - # or to the simple or compound primary key - if link_column: - is_special_link_column = len(pks) != 1 - pk_path = path_from_row_pks(row, pks, not pks, False) - cells.append( - { - "column": pks[0] if len(pks) == 1 else "Link", - "value_type": "pk", - "is_special_link_column": is_special_link_column, - "raw": pk_path, - "value": markupsafe.Markup( - '{flat_pks}'.format( - table_path=datasette.urls.table(database_name, table_name), - flat_pks=str(markupsafe.escape(pk_path)), - flat_pks_quoted=path_from_row_pks(row, pks, not pks), - ) - ), - } - ) - - for value, column_dict in zip(row, columns): - column = column_dict["name"] - if link_column and len(pks) == 1 and column == pks[0]: - # If there's a simple primary key, don't repeat the value as it's - # already shown in the link column. - continue - - # First try column type render_cell, then plugins - # pylint: disable=no-member - plugin_display_value = None - ct = column_types_map.get(column) - if ct: - candidate = await ct.render_cell( - value=value, - column=column, - table=table_name, - database=database_name, - datasette=datasette, - request=request, - ) - if candidate is not None: - plugin_display_value = candidate - if plugin_display_value is None: - for candidate in pm.hook.render_cell( - row=row, - value=value, - column=column, - table=table_name, - pks=pks_for_display, - database=database_name, - datasette=datasette, - request=request, - column_type=ct, - ): - candidate = await await_me_maybe(candidate) - if candidate is not None: - plugin_display_value = candidate - break - if plugin_display_value: - display_value = plugin_display_value - elif isinstance(value, bytes): - formatted = format_bytes(len(value)) - display_value = markupsafe.Markup( - '<Binary: {:,} byte{}>'.format( - datasette.urls.row_blob( - database_name, - table_name, - path_from_row_pks(row, pks, not pks), - column, - ), - ( - ' title="{}"'.format(formatted) - if "bytes" not in formatted - else "" - ), - len(value), - "" if len(value) == 1 else "s", - ) - ) - elif isinstance(value, dict): - # It's an expanded foreign key - display link to other row - label = value["label"] - value = value["value"] - # The table we link to depends on the column - other_table = column_to_foreign_key_table[column] - link_template = LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE - display_value = markupsafe.Markup( - link_template.format( - database=tilde_encode(database_name), - base_url=base_url, - table=tilde_encode(other_table), - link_id=tilde_encode(str(value)), - id=str(markupsafe.escape(value)), - label=str(markupsafe.escape(label)) or "-", - ) - ) - elif value in ("", None): - display_value = markupsafe.Markup(" ") - elif is_url(str(value).strip()): - display_value = markupsafe.Markup( - '{truncated_url}'.format( - url=markupsafe.escape(value.strip()), - truncated_url=markupsafe.escape( - truncate_url(value.strip(), truncate_cells) - ), - ) - ) - else: - display_value = str(value) - if truncate_cells and len(display_value) > truncate_cells: - display_value = display_value[:truncate_cells] + "\u2026" - - cells.append( - { - "column": column, - "value": display_value, - "raw": value, - "value_type": ( - "none" if value is None else str(type(value).__name__) - ), - } - ) - cell_rows.append(Row(cells)) - - if link_column: - # Add the link column header. - # If it's a simple primary key, we have to remove and re-add that column name at - # the beginning of the header row. - first_column = None - if len(pks) == 1: - columns = [col for col in columns if col["name"] != pks[0]] - first_column = { - "name": pks[0], - "sortable": len(pks) == 1, - "is_pk": True, - "type": column_details[pks[0]].type, - "notnull": column_details[pks[0]].notnull, - } - else: - first_column = { - "name": "Link", - "sortable": False, - "is_pk": False, - "type": "", - "notnull": 0, - "is_special_link_column": True, - } - columns = [first_column] + columns - return columns, cell_rows - - -class TableInsertView(BaseView): - name = "table-insert" - - def __init__(self, datasette): - self.ds = datasette - - async def _validate_data(self, request, db, table_name, pks, upsert): - errors = [] - - pks_list = [] - if isinstance(pks, str): - pks_list = [pks] - else: - pks_list = list(pks) - - if not pks_list: - pks_list = ["rowid"] - - def _errors(errors): - return None, errors, {} - - if not request.headers.get("content-type").startswith("application/json"): - # TODO: handle form-encoded data - return _errors(["Invalid content-type, must be application/json"]) - body = await request.post_body() - try: - data = json.loads(body) - except json.JSONDecodeError as e: - return _errors(["Invalid JSON: {}".format(e)]) - if not isinstance(data, dict): - return _errors(["JSON must be a dictionary"]) - keys = data.keys() - - # keys must contain "row" or "rows" - if "row" not in keys and "rows" not in keys: - return _errors(['JSON must have one or other of "row" or "rows"']) - rows = [] - if "row" in keys: - if "rows" in keys: - return _errors(['Cannot use "row" and "rows" at the same time']) - row = data["row"] - if not isinstance(row, dict): - return _errors(['"row" must be a dictionary']) - rows = [row] - data["return"] = True - else: - rows = data["rows"] - if not isinstance(rows, list): - return _errors(['"rows" must be a list']) - for row in rows: - if not isinstance(row, dict): - return _errors(['"rows" must be a list of dictionaries']) - - # Does this exceed max_insert_rows? - max_insert_rows = self.ds.setting("max_insert_rows") - if len(rows) > max_insert_rows: - return _errors( - ["Too many rows, maximum allowed is {}".format(max_insert_rows)] - ) - - # Validate other parameters - extras = { - key: value for key, value in data.items() if key not in ("row", "rows") - } - valid_extras = {"return", "ignore", "replace", "alter"} - invalid_extras = extras.keys() - valid_extras - if invalid_extras: - return _errors( - ['Invalid parameter: "{}"'.format('", "'.join(sorted(invalid_extras)))] - ) - if extras.get("ignore") and extras.get("replace"): - return _errors(['Cannot use "ignore" and "replace" at the same time']) - - columns = set(await db.table_columns(table_name)) - columns.update(pks_list) - - for i, row in enumerate(rows): - if upsert: - # It MUST have the primary key - missing_pks = [pk for pk in pks_list if pk not in row] - if missing_pks: - errors.append( - 'Row {} is missing primary key column(s): "{}"'.format( - i, '", "'.join(missing_pks) - ) - ) - null_pks = [pk for pk in pks_list if pk in row and row[pk] is None] - if null_pks: - errors.append( - 'Row {} has null primary key column(s): "{}"'.format( - i, '", "'.join(null_pks) - ) - ) - invalid_columns = set(row.keys()) - columns - if invalid_columns and not extras.get("alter"): - errors.append( - "Row {} has invalid columns: {}".format( - i, ", ".join(sorted(invalid_columns)) - ) - ) - if errors: - return _errors(errors) - return rows, errors, extras - - async def post(self, request, upsert=False): - try: - resolved = await self.ds.resolve_table(request) - except NotFound as e: - return _error([e.args[0]], 404) - db = resolved.db - database_name = db.name - table_name = resolved.table - - # Table must exist (may handle table creation in the future) - db = self.ds.get_database(database_name) - if not await db.table_exists(table_name): - return _error(["Table not found: {}".format(table_name)], 404) - - if upsert: - # Must have insert-row AND upsert-row permissions - if not ( - await self.ds.allowed( - action="insert-row", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ) - and await self.ds.allowed( - action="update-row", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ) - ): - return _error( - ["Permission denied: need both insert-row and update-row"], 403 - ) - else: - # Must have insert-row permission - if not await self.ds.allowed( - action="insert-row", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ): - return _error(["Permission denied"], 403) - - if not db.is_mutable: - return _error(["Database is immutable"], 403) - - pks = await db.primary_keys(table_name) - - rows, errors, extras = await self._validate_data( - request, db, table_name, pks, upsert - ) - if errors: - return _error(errors, 400) - - # Validate column types - ct_errors = await _validate_column_types( - self.ds, database_name, table_name, rows - ) - if ct_errors: - return _error(ct_errors, 400) - - num_rows = len(rows) - - # No that we've passed pks to _validate_data it's safe to - # fix the rowids case: - if not pks: - pks = ["rowid"] - - ignore = extras.get("ignore") - replace = extras.get("replace") - alter = extras.get("alter") - - if upsert and (ignore or replace): - return _error(["Upsert does not support ignore or replace"], 400) - - if replace and not await self.ds.allowed( - action="update-row", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ): - return _error(['Permission denied: need update-row to use "replace"'], 403) - - initial_schema = None - if alter: - # Must have alter-table permission - if not await self.ds.allowed( - action="alter-table", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ): - return _error(["Permission denied for alter-table"], 403) - # Track initial schema to check if it changed later - initial_schema = await db.execute_fn( - lambda conn: sqlite_utils.Database(conn)[table_name].schema - ) - - should_return = bool(extras.get("return", False)) - row_pk_values_for_later = [] - if should_return and upsert: - row_pk_values_for_later = [tuple(row[pk] for pk in pks) for row in rows] - - def insert_or_upsert_rows(conn): - table = sqlite_utils.Database(conn)[table_name] - kwargs = {} - if upsert: - kwargs = { - "pk": pks[0] if len(pks) == 1 else pks, - "alter": alter, - } - else: - # Insert - kwargs = {"ignore": ignore, "replace": replace, "alter": alter} - if should_return and not upsert: - rowids = [] - method = table.upsert if upsert else table.insert - for row in rows: - rowids.append(method(row, **kwargs).last_rowid) - return list( - table.rows_where( - "rowid in ({})".format(",".join("?" for _ in rowids)), - rowids, - ) - ) - else: - method_all = table.upsert_all if upsert else table.insert_all - method_all(rows, **kwargs) - - try: - rows = await db.execute_write_fn(insert_or_upsert_rows, request=request) - except Exception as e: - return _error([str(e)]) - result = {"ok": True} - if should_return: - if upsert: - # Fetch based on initial input IDs - where_clause = " OR ".join( - ["({})".format(" AND ".join("{} = ?".format(pk) for pk in pks))] - * len(row_pk_values_for_later) - ) - args = list(itertools.chain.from_iterable(row_pk_values_for_later)) - fetched_rows = await db.execute( - "select {}* from [{}] where {}".format( - "rowid, " if pks == ["rowid"] else "", table_name, where_clause - ), - args, - ) - result["rows"] = fetched_rows.dicts() - else: - result["rows"] = rows - # We track the number of rows requested, but do not attempt to show which were actually - # inserted or upserted v.s. ignored - if upsert: - await self.ds.track_event( - UpsertRowsEvent( - actor=request.actor, - database=database_name, - table=table_name, - num_rows=num_rows, - ) - ) - else: - await self.ds.track_event( - InsertRowsEvent( - actor=request.actor, - database=database_name, - table=table_name, - num_rows=num_rows, - ignore=bool(ignore), - replace=bool(replace), - ) - ) - - if initial_schema is not None: - after_schema = await db.execute_fn( - lambda conn: sqlite_utils.Database(conn)[table_name].schema - ) - if initial_schema != after_schema: - await self.ds.track_event( - AlterTableEvent( - request.actor, - database=database_name, - table=table_name, - before_schema=initial_schema, - after_schema=after_schema, - ) - ) - - return Response.json(result, status=200 if upsert else 201) - - -class TableUpsertView(TableInsertView): - name = "table-upsert" - - async def post(self, request): - return await super().post(request, upsert=True) - - -class TableSetColumnTypeView(BaseView): - name = "table-set-column-type" - - def __init__(self, datasette): - self.ds = datasette - - async def post(self, request): - try: - resolved = await self.ds.resolve_table(request) - except NotFound as e: - return _error([e.args[0]], 404) - - database_name = resolved.db.name - table_name = resolved.table - - if not await self.ds.allowed( - action="set-column-type", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ): - return _error(["Permission denied"], 403) - - content_type = request.headers.get("content-type") or "" - if not content_type.startswith("application/json"): - return _error(["Invalid content-type, must be application/json"], 400) - - try: - data = json.loads(await request.post_body()) - except json.JSONDecodeError as e: - return _error(["Invalid JSON: {}".format(e)], 400) - - if not isinstance(data, dict): - return _error(["JSON must be a dictionary"], 400) - - invalid_keys = set(data.keys()) - {"column", "column_type"} - if invalid_keys: - return _error( - ['Invalid parameter: "{}"'.format('", "'.join(sorted(invalid_keys)))], - 400, - ) - - if "column" not in data: - return _error(['"column" is required'], 400) - column = data["column"] - if not isinstance(column, str): - return _error(['"column" must be a string'], 400) - - if "column_type" not in data: - return _error(['"column_type" is required'], 400) - - column_details = await self.ds._get_resource_column_details( - database_name, table_name - ) - if column not in column_details: - return _error(["Column not found: {}".format(column)], 400) - - column_type_data = data["column_type"] - if column_type_data is None: - await self.ds.remove_column_type(database_name, table_name, column) - return Response.json( - { - "ok": True, - "database": database_name, - "table": table_name, - "column": column, - "column_type": None, - }, - status=200, - ) - - if not isinstance(column_type_data, dict): - return _error(['"column_type" must be an object or null'], 400) - - invalid_column_type_keys = set(column_type_data.keys()) - {"type", "config"} - if invalid_column_type_keys: - return _error( - [ - 'Invalid column_type parameter: "{}"'.format( - '", "'.join(sorted(invalid_column_type_keys)) - ) - ], - 400, - ) - - if "type" not in column_type_data: - return _error(['"column_type.type" is required'], 400) - column_type = column_type_data["type"] - if not isinstance(column_type, str): - return _error(['"column_type.type" must be a string'], 400) - - config = column_type_data.get("config") - if config is not None and not isinstance(config, dict): - return _error(['"column_type.config" must be a dictionary'], 400) - - if column_type not in self.ds._column_types: - return _error(["Unknown column type: {}".format(column_type)], 400) - - try: - await self.ds.set_column_type( - database_name, table_name, column, column_type, config - ) - except ValueError as e: - return _error([str(e)], 400) - - return Response.json( - { - "ok": True, - "database": database_name, - "table": table_name, - "column": column, - "column_type": {"type": column_type, "config": config}, - }, - status=200, - ) - - -class TableDropView(BaseView): - name = "table-drop" - - def __init__(self, datasette): - self.ds = datasette - - async def post(self, request): - try: - resolved = await self.ds.resolve_table(request) - except NotFound as e: - return _error([e.args[0]], 404) - db = resolved.db - database_name = db.name - table_name = resolved.table - # Table must exist - db = self.ds.get_database(database_name) - if not await db.table_exists(table_name): - return _error(["Table not found: {}".format(table_name)], 404) - if not await self.ds.allowed( - action="drop-table", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ): - return _error(["Permission denied"], 403) - if not db.is_mutable: - return _error(["Database is immutable"], 403) - confirm = False - try: - data = json.loads(await request.post_body()) - confirm = data.get("confirm") - except json.JSONDecodeError: - pass - - if not confirm: - return Response.json( - { - "ok": True, - "database": database_name, - "table": table_name, - "row_count": ( - await db.execute("select count(*) from [{}]".format(table_name)) - ).single_value(), - "message": 'Pass "confirm": true to confirm', - }, - status=200, - ) - - # Drop table - def drop_table(conn): - sqlite_utils.Database(conn)[table_name].drop() - - await db.execute_write_fn(drop_table, request=request) - await self.ds.track_event( - DropTableEvent( - actor=request.actor, database=database_name, table=table_name - ) - ) - return Response.json({"ok": True}, status=200) - - -def _get_extras(request): - extra_bits = request.args.getlist("_extra") - extras = set() - for bit in extra_bits: - extras.update(bit.split(",")) - return extras - - -async def _columns_to_select(table_columns, pks, request): - columns = list(table_columns) - if "_col" in request.args: - columns = list(pks) - _cols = request.args.getlist("_col") - bad_columns = [column for column in _cols if column not in table_columns] - if bad_columns: - raise DatasetteError( - "_col={} - invalid columns".format(", ".join(bad_columns)), - status=400, - ) - # De-duplicate maintaining order: - columns.extend(dict.fromkeys(_cols)) - if "_nocol" in request.args: - # Return all columns EXCEPT these - bad_columns = [ - column - for column in request.args.getlist("_nocol") - if (column not in table_columns) or (column in pks) - ] - if bad_columns: - raise DatasetteError( - "_nocol={} - invalid columns".format(", ".join(bad_columns)), - status=400, - ) - tmp_columns = [ - column for column in columns if column not in request.args.getlist("_nocol") - ] - columns = tmp_columns - return columns - - -async def _sortable_columns_for_table(datasette, database_name, table_name, use_rowid): - db = datasette.databases[database_name] - table_metadata = await datasette.table_config(database_name, table_name) - if "sortable_columns" in table_metadata: - sortable_columns = set(table_metadata["sortable_columns"]) - else: - sortable_columns = set(await db.table_columns(table_name)) - if use_rowid: - sortable_columns.add("rowid") - return sortable_columns - - -async def _sort_order(table_metadata, sortable_columns, request, order_by): - sort = request.args.get("_sort") - sort_desc = request.args.get("_sort_desc") - - if not sort and not sort_desc: - sort = table_metadata.get("sort") - sort_desc = table_metadata.get("sort_desc") - - if sort and sort_desc: - raise DatasetteError( - "Cannot use _sort and _sort_desc at the same time", status=400 - ) - - if sort: - if sort not in sortable_columns: - raise DatasetteError(f"Cannot sort table by {sort}", status=400) - - order_by = escape_sqlite(sort) - - if sort_desc: - if sort_desc not in sortable_columns: - raise DatasetteError(f"Cannot sort table by {sort_desc}", status=400) - - order_by = f"{escape_sqlite(sort_desc)} desc" - - return sort, sort_desc, order_by - - -async def table_view(datasette, request): - await datasette.refresh_schemas() - with tracer.trace_child_tasks(): - response = await table_view_traced(datasette, request) - - # CORS - if datasette.cors: - add_cors_headers(response.headers) - - # Cache TTL header - ttl = request.args.get("_ttl", None) - if ttl is None or not ttl.isdigit(): - ttl = datasette.setting("default_cache_ttl") - - if datasette.cache_headers and response.status == 200: - ttl = int(ttl) - if ttl == 0: - ttl_header = "no-cache" - else: - ttl_header = f"max-age={ttl}" - response.headers["Cache-Control"] = ttl_header - - # Referrer policy - response.headers["Referrer-Policy"] = "no-referrer" - - return response - - -async def table_view_traced(datasette, request): - from datasette.app import TableNotFound - - try: - resolved = await datasette.resolve_table(request) - except TableNotFound as not_found: - # Was this actually a stored query? - stored_query = await datasette.get_query( - not_found.database_name, not_found.table - ) - # If this is a stored query, not a table, then dispatch to QueryView instead - if stored_query: - return await QueryView()(request, datasette) - else: - raise - - if request.method == "POST": - return Response.text("Method not allowed", status=405) - - format_ = request.url_vars.get("format") or "html" - extra_extras = None - context_for_html_hack = False - default_labels = False - if format_ == "html": - extra_extras = {"_html"} - context_for_html_hack = True - default_labels = True - - view_data = await table_view_data( - datasette, - request, - resolved, - extra_extras=extra_extras, - context_for_html_hack=context_for_html_hack, - default_labels=default_labels, - ) - if isinstance(view_data, Response): - return view_data - data, rows, columns, expanded_columns, sql, next_url = view_data - - # Handle formats from plugins - if format_ == "csv": - - async def fetch_data(request, _next=None): - ( - data, - rows, - columns, - expanded_columns, - sql, - next_url, - ) = await table_view_data( - datasette, - request, - resolved, - extra_extras=extra_extras, - context_for_html_hack=context_for_html_hack, - default_labels=default_labels, - _next=_next, - ) - data["rows"] = rows - data["table"] = resolved.table - data["columns"] = columns - data["expanded_columns"] = expanded_columns - return data, None, None - - return await stream_csv(datasette, fetch_data, request, resolved.db.name) - elif format_ in datasette.renderers.keys(): - # Dispatch request to the correct output format renderer - # (CSV is not handled here due to streaming) - result = call_with_supported_arguments( - datasette.renderers[format_][0], - datasette=datasette, - columns=columns, - rows=rows, - sql=sql, - query_name=None, - database=resolved.db.name, - table=resolved.table, - request=request, - view_name="table", - truncated=False, - error=None, - # These will be deprecated in Datasette 1.0: - args=request.args, - data=data, - ) - if asyncio.iscoroutine(result): - result = await result - if result is None: - raise NotFound("No data") - if isinstance(result, dict): - r = Response( - body=result.get("body"), - status=result.get("status_code") or 200, - content_type=result.get("content_type", "text/plain"), - headers=result.get("headers"), - ) - elif isinstance(result, Response): - r = result - # if status_code is not None: - # # Over-ride the status code - # r.status = status_code - else: - assert False, f"{result} should be dict or Response" - elif format_ == "html": - headers = {} - templates = [ - f"table-{to_css_class(resolved.db.name)}-{to_css_class(resolved.table)}.html", - "table.html", - ] - environment = datasette.get_jinja_environment(request) - template = environment.select_template(templates) - alternate_url_json = datasette.absolute_url( - request, - datasette.urls.path(path_with_format(request=request, format="json")), - ) - headers.update( - { - "Link": '<{}>; rel="alternate"; type="application/json+datasette"'.format( - alternate_url_json - ) - } - ) - r = Response.html( - await datasette.render_template( - template, - dict( - data, - append_querystring=append_querystring, - path_with_replaced_args=path_with_replaced_args, - fix_path=datasette.urls.path, - settings=datasette.settings_dict(), - # TODO: review up all of these hacks: - alternate_url_json=alternate_url_json, - datasette_allow_facet=( - "true" if datasette.setting("allow_facet") else "false" - ), - is_sortable=any(c["sortable"] for c in data["display_columns"]), - allow_execute_sql=await datasette.allowed( - action="execute-sql", - resource=DatabaseResource(database=resolved.db.name), - actor=request.actor, - ), - query_ms=1.2, - select_templates=[ - f"{'*' if template_name == template.name else ''}{template_name}" - for template_name in templates - ], - top_table=make_slot_function( - "top_table", - datasette, - request, - database=resolved.db.name, - table=resolved.table, - ), - count_limit=resolved.db.count_limit, - ), - request=request, - view_name="table", - ), - headers=headers, - ) - else: - assert False, "Invalid format: {}".format(format_) - if next_url: - r.headers["link"] = f'<{next_url}>; rel="next"' - return r - - -async def table_view_data( - datasette, - request, - resolved, - extra_extras=None, - context_for_html_hack=False, - default_labels=False, - _next=None, -): - extra_extras = extra_extras or set() - # We have a table or view - db = resolved.db - database_name = resolved.db.name - table_name = resolved.table - is_view = resolved.is_view - - # Can this user view it? - visible, private = await datasette.check_visibility( - request.actor, - action="view-table", - resource=TableResource(database=database_name, table=table_name), - ) - if not visible: - raise Forbidden("You do not have permission to view this table") - - # Redirect based on request.args, if necessary - redirect_response = await _redirect_if_needed(datasette, request, resolved) - if redirect_response: - return redirect_response - - # Introspect columns and primary keys for table - pks = await db.primary_keys(table_name) - table_columns = await db.table_columns(table_name) - - # Take ?_col= and ?_nocol= into account - specified_columns = await _columns_to_select(table_columns, pks, request) - select_specified_columns = ", ".join(escape_sqlite(t) for t in specified_columns) - select_all_columns = ", ".join(escape_sqlite(t) for t in table_columns) - - # rowid tables (no specified primary key) need a different SELECT - use_rowid = not pks and not is_view - order_by = "" - if use_rowid: - select_specified_columns = f"rowid, {select_specified_columns}" - select_all_columns = f"rowid, {select_all_columns}" - order_by = "rowid" - order_by_pks = "rowid" - else: - order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks]) - order_by = order_by_pks - - if is_view: - order_by = "" - - # TODO: This logic should turn into logic about which ?_extras get - # executed instead: - nocount = request.args.get("_nocount") - nofacet = request.args.get("_nofacet") - nosuggest = request.args.get("_nosuggest") - if request.args.get("_shape") in ("array", "object"): - nocount = True - nofacet = True - - table_metadata = await datasette.table_config(database_name, table_name) - - # Arguments that start with _ and don't contain a __ are - # special - things like ?_search= - and should not be - # treated as filters. - filter_args = [] - for key in request.args: - if not (key.startswith("_") and "__" not in key): - for v in request.args.getlist(key): - filter_args.append((key, v)) - - # Build where clauses from query string arguments - filters = Filters(sorted(filter_args)) - where_clauses, params = filters.build_where_clauses(table_name) - - # Execute filters_from_request plugin hooks - including the default - # ones that live in datasette/filters.py - extra_context_from_filters = {} - extra_human_descriptions = [] - - for hook in pm.hook.filters_from_request( - request=request, - table=table_name, - database=database_name, - datasette=datasette, - ): - filter_arguments = await await_me_maybe(hook) - if filter_arguments: - where_clauses.extend(filter_arguments.where_clauses) - params.update(filter_arguments.params) - extra_human_descriptions.extend(filter_arguments.human_descriptions) - extra_context_from_filters.update(filter_arguments.extra_context) - - # Deal with custom sort orders - sortable_columns = await _sortable_columns_for_table( - datasette, database_name, table_name, use_rowid - ) - - sort, sort_desc, order_by = await _sort_order( - table_metadata, sortable_columns, request, order_by - ) - - from_sql = "from {table_name} {where}".format( - table_name=escape_sqlite(table_name), - where=( - ("where {} ".format(" and ".join(where_clauses))) if where_clauses else "" - ), - ) - # Copy of params so we can mutate them later: - from_sql_params = dict(**params) - - count_sql = f"select count(*) {from_sql}" - - # Handle pagination driven by ?_next= - _next = _next or request.args.get("_next") - - offset = "" - if _next: - sort_value = None - if is_view: - # _next is an offset - offset = f" offset {int(_next)}" - else: - components = urlsafe_components(_next) - # If a sort order is applied and there are multiple components, - # the first of these is the sort value - if (sort or sort_desc) and (len(components) > 1): - sort_value = components[0] - # Special case for if non-urlencoded first token was $null - if _next.split(",")[0] == "$null": - sort_value = None - components = components[1:] - - # Figure out the SQL for next-based-on-primary-key first - next_by_pk_clauses = [] - if use_rowid: - next_by_pk_clauses.append(f"rowid > :p{len(params)}") - params[f"p{len(params)}"] = components[0] - else: - # Apply the tie-breaker based on primary keys - if len(components) == len(pks): - param_len = len(params) - next_by_pk_clauses.append(compound_keys_after_sql(pks, param_len)) - for i, pk_value in enumerate(components): - params[f"p{param_len + i}"] = pk_value - - # Now add the sort SQL, which may incorporate next_by_pk_clauses - if sort or sort_desc: - if sort_value is None: - if sort_desc: - # Just items where column is null ordered by pk - where_clauses.append( - "({column} is null and {next_clauses})".format( - column=escape_sqlite(sort_desc), - next_clauses=" and ".join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - "({column} is not null or ({column} is null and {next_clauses}))".format( - column=escape_sqlite(sort), - next_clauses=" and ".join(next_by_pk_clauses), - ) - ) - else: - where_clauses.append( - "({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))".format( - column=escape_sqlite(sort or sort_desc), - op=">" if sort else "<", - p=len(params), - extra_desc_only=( - "" - if sort - else " or {column2} is null".format( - column2=escape_sqlite(sort or sort_desc) - ) - ), - next_clauses=" and ".join(next_by_pk_clauses), - ) - ) - params[f"p{len(params)}"] = sort_value - order_by = f"{order_by}, {order_by_pks}" - else: - where_clauses.extend(next_by_pk_clauses) - - where_clause = "" - if where_clauses: - where_clause = f"where {' and '.join(where_clauses)} " - - if order_by: - order_by = f"order by {order_by}" - - extra_args = {} - # Handle ?_size=500 - # TODO: This was: - # page_size = _size or request.args.get("_size") or table_metadata.get("size") - page_size = request.args.get("_size") or table_metadata.get("size") - if page_size: - if page_size == "max": - page_size = datasette.max_returned_rows - try: - page_size = int(page_size) - if page_size < 0: - raise ValueError - - except ValueError: - raise BadRequest("_size must be a positive integer") - - if page_size > datasette.max_returned_rows: - raise BadRequest(f"_size must be <= {datasette.max_returned_rows}") - - extra_args["page_size"] = page_size - else: - page_size = datasette.page_size - - # Facets are calculated against SQL without order by or limit - sql_no_order_no_limit = ( - "select {select_all_columns} from {table_name} {where}".format( - select_all_columns=select_all_columns, - table_name=escape_sqlite(table_name), - where=where_clause, - ) - ) - - # This is the SQL that populates the main table on the page - sql = "select {select_specified_columns} from {table_name} {where}{order_by} limit {page_size}{offset}".format( - select_specified_columns=select_specified_columns, - table_name=escape_sqlite(table_name), - where=where_clause, - order_by=order_by, - page_size=page_size + 1, - offset=offset, - ) - - if request.args.get("_timelimit"): - extra_args["custom_time_limit"] = int(request.args.get("_timelimit")) - - # Execute the main query! - try: - results = await db.execute(sql, params, truncate=True, **extra_args) - except (sqlite3.OperationalError, InvalidSql) as e: - raise DatasetteError(str(e), title="Invalid SQL", status=400) - - except sqlite3.OperationalError as e: - raise DatasetteError(str(e)) - - columns = [r[0] for r in results.description] - rows = list(results.rows) - - # Expand labeled columns if requested - expanded_columns = [] - # List of (fk_dict, label_column-or-None) pairs for that table - expandable_columns = [] - for fk in await db.foreign_keys_for_table(table_name): - label_column = await db.label_column_for_table(fk["other_table"]) - expandable_columns.append((fk, label_column)) - - columns_to_expand = None - try: - all_labels = value_as_boolean(request.args.get("_labels", "")) - except ValueError: - all_labels = default_labels - # Check for explicit _label= - if "_label" in request.args: - columns_to_expand = request.args.getlist("_label") - if columns_to_expand is None and all_labels: - # expand all columns with foreign keys - columns_to_expand = [fk["column"] for fk, _ in expandable_columns] - - if columns_to_expand: - expanded_labels = {} - for fk, _ in expandable_columns: - column = fk["column"] - if column not in columns_to_expand: - continue - if column not in columns: - continue - expanded_columns.append(column) - # Gather the values - column_index = columns.index(column) - values = [row[column_index] for row in rows] - # Expand them - expanded_labels.update( - await datasette.expand_foreign_keys( - request.actor, database_name, table_name, column, values - ) - ) - if expanded_labels: - # Rewrite the rows - new_rows = [] - for row in rows: - new_row = CustomRow(columns) - for column in row.keys(): - value = row[column] - if (column, value) in expanded_labels and value is not None: - new_row[column] = { - "value": value, - "label": expanded_labels[(column, value)], - } - else: - new_row[column] = value - new_rows.append(new_row) - rows = new_rows - - _next = request.args.get("_next") - - # Pagination next link - next_value, next_url = await _next_value_and_url( - datasette, - db, - request, - table_name, - _next, - rows, - pks, - use_rowid, - sort, - sort_desc, - page_size, - is_view, - ) - rows = rows[:page_size] - - # Resolve extras - extras = _get_extras(request) - if any(k for k in request.args.keys() if k == "_facet" or k.startswith("_facet_")): - extras.add("facet_results") - if request.args.get("_shape") == "object": - extras.add("primary_keys") - if extra_extras: - extras.update(extra_extras) - - async def extra_count_sql(): - return count_sql - - async def extra_count(): - "Total count of rows matching these filters" - # Calculate the total count for this query - count = None - if ( - not db.is_mutable - and datasette.inspect_data - and count_sql == f"select count(*) from {table_name} " - ): - # We can use a previously cached table row count - try: - count = datasette.inspect_data[database_name]["tables"][table_name][ - "count" - ] - except KeyError: - pass - - # Otherwise run a select count(*) ... - if count_sql and count is None and not nocount: - count_sql_limited = ( - f"select count(*) from (select * {from_sql} limit 10001)" - ) - try: - count_rows = list(await db.execute(count_sql_limited, from_sql_params)) - count = count_rows[0][0] - except QueryInterrupted: - pass - return count - - async def facet_instances(extra_count): - facet_instances = [] - facet_classes = list( - itertools.chain.from_iterable(pm.hook.register_facet_classes()) - ) - for facet_class in facet_classes: - facet_instances.append( - facet_class( - datasette, - request, - database_name, - sql=sql_no_order_no_limit, - params=params, - table=table_name, - table_config=table_metadata, - row_count=extra_count, - ) - ) - return facet_instances - - async def extra_facet_results(facet_instances): - "Results of facets calculated against this data" - facet_results = {} - facets_timed_out = [] - - if not nofacet: - # Run them in parallel - facet_awaitables = [facet.facet_results() for facet in facet_instances] - facet_awaitable_results = await run_sequential(*facet_awaitables) - for ( - instance_facet_results, - instance_facets_timed_out, - ) in facet_awaitable_results: - for facet_info in instance_facet_results: - base_key = facet_info["name"] - key = base_key - i = 1 - while key in facet_results: - i += 1 - key = f"{base_key}_{i}" - facet_results[key] = facet_info - facets_timed_out.extend(instance_facets_timed_out) - - return { - "results": facet_results, - "timed_out": facets_timed_out, - } - - async def extra_suggested_facets(facet_instances): - "Suggestions for facets that might return interesting results" - suggested_facets = [] - # Calculate suggested facets - if ( - datasette.setting("suggest_facets") - and datasette.setting("allow_facet") - and not _next - and not nofacet - and not nosuggest - ): - # Run them in parallel - facet_suggest_awaitables = [facet.suggest() for facet in facet_instances] - for suggest_result in await run_sequential(*facet_suggest_awaitables): - suggested_facets.extend(suggest_result) - return suggested_facets - - # Faceting - if not datasette.setting("allow_facet") and any( - arg.startswith("_facet") for arg in request.args - ): - raise BadRequest("_facet= is not allowed") - - # human_description_en combines filters AND search, if provided - async def extra_human_description_en(): - "Human-readable description of the filters" - human_description_en = filters.human_description_en( - extra=extra_human_descriptions - ) - if sort or sort_desc: - human_description_en = " ".join( - [b for b in [human_description_en, sorted_by] if b] - ) - return human_description_en - - if sort or sort_desc: - sorted_by = "sorted by {}{}".format( - (sort or sort_desc), " descending" if sort_desc else "" - ) - - async def extra_next_url(): - "Full URL for the next page of results" - return next_url - - async def extra_columns(): - "Column names returned by this query" - return columns - - async def extra_all_columns(): - "All columns in the table, regardless of _col/_nocol filtering" - return list(table_columns) - - async def extra_primary_keys(): - "Primary keys for this table" - return pks - - async def extra_actions(): - async def actions(): - links = [] - kwargs = { - "datasette": datasette, - "database": database_name, - "actor": request.actor, - "request": request, - } - if is_view: - kwargs["view"] = table_name - method = pm.hook.view_actions - else: - kwargs["table"] = table_name - method = pm.hook.table_actions - for hook in method(**kwargs): - extra_links = await await_me_maybe(hook) - if extra_links: - links.extend(extra_links) - return links - - return actions - - async def extra_is_view(): - return is_view - - async def extra_debug(): - "Extra debug information" - return { - "resolved": repr(resolved), - "url_vars": request.url_vars, - "nofacet": nofacet, - "nosuggest": nosuggest, - } - - async def extra_request(): - "Full information about the request" - return { - "url": request.url, - "path": request.path, - "full_path": request.full_path, - "host": request.host, - "args": request.args._data, - } - - async def run_display_columns_and_rows(): - display_columns, display_rows = await display_columns_and_rows( - datasette, - database_name, - table_name, - results.description, - rows, - link_column=not is_view, - truncate_cells=datasette.setting("truncate_cells_html"), - sortable_columns=sortable_columns, - request=request, - ) - return { - "columns": display_columns, - "rows": display_rows, - } - - async def extra_display_columns(run_display_columns_and_rows): - return run_display_columns_and_rows["columns"] - - async def extra_display_rows(run_display_columns_and_rows): - return run_display_columns_and_rows["rows"] - - async def extra_render_cell(): - "Rendered HTML for each cell using the render_cell plugin hook" - pks_for_display = pks if pks else (["rowid"] if not is_view else []) - col_names = [col[0] for col in results.description] - ct_map = await datasette.get_column_types(database_name, table_name) - rendered_rows = [] - for row in rows: - rendered_row = {} - for value, column in zip(row, col_names): - ct = ct_map.get(column) - plugin_display_value = None - # Try column type render_cell first - if ct: - candidate = await ct.render_cell( - value=value, - column=column, - table=table_name, - database=database_name, - datasette=datasette, - request=request, - ) - if candidate is not None: - plugin_display_value = candidate - if plugin_display_value is None: - for candidate in pm.hook.render_cell( - row=row, - value=value, - column=column, - table=table_name, - pks=pks_for_display, - database=database_name, - datasette=datasette, - request=request, - column_type=ct, - ): - candidate = await await_me_maybe(candidate) - if candidate is not None: - plugin_display_value = candidate - break - if plugin_display_value: - rendered_row[column] = str(plugin_display_value) - rendered_rows.append(rendered_row) - return rendered_rows - - async def extra_query(): - "Details of the underlying SQL query" - return { - "sql": sql, - "params": params, - } - - async def extra_column_types(): - "Column type assignments for this table" - ct_map = await datasette.get_column_types(database_name, table_name) - return { - col_name: { - "type": ct.name, - "config": ct.config, - } - for col_name, ct in ct_map.items() - } - - async def extra_set_column_type_ui(): - "Column type UI metadata for this table" - if is_view: - return None - - if not await datasette.allowed( - action="set-column-type", - resource=TableResource(database=database_name, table=table_name), - actor=request.actor, - ): - return None - - column_details = await datasette._get_resource_column_details( - database_name, table_name - ) - ct_map = await datasette.get_column_types(database_name, table_name) - columns = {} - for column_name, column_detail in column_details.items(): - current = ct_map.get(column_name) - columns[column_name] = { - "current": ( - {"type": current.name, "config": current.config} - if current is not None - else None - ), - "options": [ - { - "name": name, - "description": ct_cls.description, - } - for name, ct_cls in sorted(datasette._column_types.items()) - if datasette._column_type_is_applicable(ct_cls, column_detail) - ], - } - return { - "path": "{}/-/set-column-type".format( - datasette.urls.table(database_name, table_name) - ), - "columns": columns, - } - - async def extra_metadata(): - "Metadata about the table and database" - tablemetadata = await datasette.get_resource_metadata(database_name, table_name) - - rows = await datasette.get_internal_database().execute( - """ - SELECT - column_name, - value - FROM metadata_columns - WHERE database_name = ? - AND resource_name = ? - AND key = 'description' - """, - [database_name, table_name], - ) - tablemetadata["columns"] = dict(rows) - return tablemetadata - - async def extra_database(): - return database_name - - async def extra_table(): - return table_name - - async def extra_database_color(): - return db.color - - async def extra_form_hidden_args(): - form_hidden_args = [] - for key in request.args: - if ( - key.startswith("_") - and key not in ("_sort", "_sort_desc", "_search", "_next") - and "__" not in key - ): - for value in request.args.getlist(key): - form_hidden_args.append((key, value)) - return form_hidden_args - - async def extra_filters(): - return filters - - async def extra_custom_table_templates(): - return [ - f"_table-{to_css_class(database_name)}-{to_css_class(table_name)}.html", - f"_table-table-{to_css_class(database_name)}-{to_css_class(table_name)}.html", - "_table.html", - ] - - async def extra_sorted_facet_results(extra_facet_results): - facet_configs = table_metadata.get("facets", []) - if facet_configs: - # Build ordered list of facet names from metadata config - metadata_facet_names = [] - for fc in facet_configs: - if isinstance(fc, str): - metadata_facet_names.append(fc) - elif isinstance(fc, dict): - metadata_facet_names.append(list(fc.values())[0]) - metadata_order = {name: i for i, name in enumerate(metadata_facet_names)} - metadata_facets = [] - request_facets = [] - for f in extra_facet_results["results"].values(): - if f["name"] in metadata_order: - metadata_facets.append(f) - else: - request_facets.append(f) - metadata_facets.sort(key=lambda f: metadata_order[f["name"]]) - request_facets.sort( - key=lambda f: (len(f["results"]), f["name"]), - reverse=True, - ) - return metadata_facets + request_facets - else: - return sorted( - extra_facet_results["results"].values(), - key=lambda f: (len(f["results"]), f["name"]), - reverse=True, - ) - - async def extra_table_definition(): - return await db.get_table_definition(table_name) - - async def extra_view_definition(): - return await db.get_view_definition(table_name) - - async def extra_renderers(extra_expandable_columns, extra_query): - renderers = {} - url_labels_extra = {} - if extra_expandable_columns: - url_labels_extra = {"_labels": "on"} - for key, (_, can_render) in datasette.renderers.items(): - it_can_render = call_with_supported_arguments( - can_render, - datasette=datasette, - columns=columns or [], - rows=rows or [], - sql=extra_query.get("sql", None), - query_name=None, - database=database_name, - table=table_name, - request=request, - view_name="table", - ) - it_can_render = await await_me_maybe(it_can_render) - if it_can_render: - renderers[key] = datasette.urls.path( - path_with_format( - request=request, format=key, extra_qs={**url_labels_extra} - ) - ) - return renderers - - async def extra_private(): - return private - - async def extra_expandable_columns(): + table_info = self.ds.inspect()[database]["tables"].get(table) or {} + sortable_columns = set(table_info.get("columns", [])) + if use_rowid: + sortable_columns.add("rowid") + return sortable_columns + + def expandable_columns(self, database, table): + # Returns list of (fk_dict, label_column-or-None) pairs for that table + tables = self.ds.inspect()[database].get("tables", {}) + table_info = tables.get(table) + if not table_info: + return [] expandables = [] - db = datasette.databases[database_name] - for fk in await db.foreign_keys_for_table(table_name): - label_column = await db.label_column_for_table(fk["other_table"]) + for fk in table_info["foreign_keys"]["outgoing"]: + label_column = ( + self.table_metadata( + database, fk["other_table"] + ).get("label_column") + or tables.get(fk["other_table"], {}).get("label_column") + ) or None expandables.append((fk, label_column)) return expandables - async def extra_extras(): - "Available ?_extra= blocks" - all_extras = [ - (key[len("extra_") :], fn.__doc__) - for key, fn in registry._registry.items() - if key.startswith("extra_") - ] - return [ - { - "name": name, - "description": doc, - "toggle_url": datasette.absolute_url( - request, - datasette.urls.path( - path_with_added_args(request, {"_extra": name}) - if name not in extras - else path_with_removed_args(request, {"_extra": name}) - ), - ), - "selected": name in extras, + async def expand_foreign_keys(self, database, table, column, values): + "Returns dict mapping (column, value) -> label" + labeled_fks = {} + tables_info = self.ds.inspect()[database]["tables"] + table_info = tables_info.get(table) or {} + if not table_info: + return {} + foreign_keys = table_info["foreign_keys"]["outgoing"] + # Find the foreign_key for this column + try: + fk = [ + foreign_key for foreign_key in foreign_keys + if foreign_key["column"] == column + ][0] + except IndexError: + return {} + label_column = ( + # First look in metadata.json for this foreign key table: + self.table_metadata( + database, fk["other_table"] + ).get("label_column") + or tables_info.get(fk["other_table"], {}).get("label_column") + ) + if not label_column: + return { + (fk["column"], value): str(value) + for value in values } - for name, doc in all_extras + labeled_fks = {} + sql = ''' + select {other_column}, {label_column} + from {other_table} + where {other_column} in ({placeholders}) + '''.format( + other_column=escape_sqlite(fk["other_column"]), + label_column=escape_sqlite(label_column), + other_table=escape_sqlite(fk["other_table"]), + placeholders=", ".join(["?"] * len(set(values))), + ) + try: + results = await self.ds.execute( + database, sql, list(set(values)) + ) + except InterruptedError: + pass + else: + for id, value in results: + labeled_fks[(fk["column"], id)] = value + return labeled_fks + + async def display_columns_and_rows( + self, + database, + table, + description, + rows, + link_column=False, + ): + "Returns columns, rows for specified table - including fancy foreign key treatment" + table_metadata = self.table_metadata(database, table) + info = self.ds.inspect()[database] + sortable_columns = self.sortable_columns_for_table(database, table, True) + columns = [ + {"name": r[0], "sortable": r[0] in sortable_columns} for r in description ] - - async def extra_facets_timed_out(extra_facet_results): - return extra_facet_results["timed_out"] - - bundles = { - "html": [ - "suggested_facets", - "facet_results", - "facets_timed_out", - "count", - "count_sql", - "human_description_en", - "next_url", - "metadata", - "query", - "columns", - "display_columns", - "display_rows", - "database", - "table", - "database_color", - "actions", - "filters", - "renderers", - "custom_table_templates", - "sorted_facet_results", - "table_definition", - "view_definition", - "is_view", - "private", - "primary_keys", - "all_columns", - "expandable_columns", - "form_hidden_args", - "set_column_type_ui", - ] - } - - for key, values in bundles.items(): - if f"_{key}" in extras: - extras.update(values) - extras.discard(f"_{key}") - - registry = Registry( - extra_count, - extra_count_sql, - extra_facet_results, - extra_facets_timed_out, - extra_suggested_facets, - facet_instances, - extra_human_description_en, - extra_next_url, - extra_columns, - extra_all_columns, - extra_primary_keys, - run_display_columns_and_rows, - extra_display_columns, - extra_display_rows, - extra_render_cell, - extra_debug, - extra_request, - extra_query, - extra_column_types, - extra_set_column_type_ui, - extra_metadata, - extra_extras, - extra_database, - extra_table, - extra_database_color, - extra_actions, - extra_filters, - extra_renderers, - extra_custom_table_templates, - extra_sorted_facet_results, - extra_table_definition, - extra_view_definition, - extra_is_view, - extra_private, - extra_expandable_columns, - extra_form_hidden_args, - ) - - results = await registry.resolve_multi( - ["extra_{}".format(extra) for extra in extras] - ) - data = { - "ok": True, - "next": next_value and str(next_value) or None, - } - data.update( - { - key.replace("extra_", ""): value - for key, value in results.items() - if key.startswith("extra_") and key.replace("extra_", "") in extras + tables = info["tables"] + table_info = tables.get(table) or {} + pks = table_info.get("primary_keys") or [] + column_to_foreign_key_table = { + fk["column"]: fk["other_table"] + for fk in table_info.get( + "foreign_keys", {} + ).get("outgoing", None) or [] } - ) - raw_sqlite_rows = rows[:page_size] - # Apply transform_value for columns with assigned types - ct_map = await datasette.get_column_types(database_name, table_name) - transformed_rows = [] - for r in raw_sqlite_rows: - row_dict = dict(r) - for col_name, ct in ct_map.items(): - if col_name in row_dict: - row_dict[col_name] = await ct.transform_value( - row_dict[col_name], datasette + + cell_rows = [] + for row in rows: + cells = [] + # Unless we are a view, the first column is a link - either to the rowid + # or to the simple or compound primary key + if link_column: + cells.append( + { + "column": pks[0] if len(pks) == 1 else "Link", + "value": jinja2.Markup( + '{flat_pks}'.format( + database=database, + table=urllib.parse.quote_plus(table), + flat_pks=str( + jinja2.escape( + path_from_row_pks(row, pks, not pks, False) + ) + ), + flat_pks_quoted=path_from_row_pks(row, pks, not pks), + ) + ), + } ) - transformed_rows.append(row_dict) - data["rows"] = transformed_rows - if context_for_html_hack: - data.update(extra_context_from_filters) - # filter_columns combine the columns we know are available - # in the table with any additional columns (such as rowid) - # which are available in the query - data["filter_columns"] = list(columns) + [ - table_column - for table_column in table_columns - if table_column not in columns - ] - url_labels_extra = {} - if data.get("expandable_columns"): - url_labels_extra = {"_labels": "on"} - url_csv_args = {"_size": "max", **url_labels_extra} - url_csv = datasette.urls.path( - path_with_format(request=request, format="csv", extra_qs=url_csv_args) - ) - url_csv_path = url_csv.split("?")[0] - data.update( - { - "url_csv": url_csv, - "url_csv_path": url_csv_path, - "url_csv_hidden_args": [ - (key, value) - for key, value in urllib.parse.parse_qsl(request.query_string) - if key not in ("_labels", "_facet", "_size") - ] - + [("_size", "max")], - } - ) - # if no sort specified AND table has a single primary key, - # set sort to that so arrow is displayed - if not sort and not sort_desc: - if 1 == len(pks): - sort = pks[0] - elif use_rowid: - sort = "rowid" - data["sort"] = sort - data["sort_desc"] = sort_desc + for value, column_dict in zip(row, columns): + column = column_dict["name"] + if link_column and len(pks) == 1 and column == pks[0]: + # If there's a simple primary key, don't repeat the value as it's + # already shown in the link column. + continue - return data, rows[:page_size], columns, expanded_columns, sql, next_url - - -async def _next_value_and_url( - datasette, - db, - request, - table_name, - _next, - rows, - pks, - use_rowid, - sort, - sort_desc, - page_size, - is_view, -): - next_value = None - next_url = None - if 0 < page_size < len(rows): - if is_view: - next_value = int(_next or 0) + page_size - else: - next_value = path_from_row_pks(rows[-2], pks, use_rowid) - # If there's a sort or sort_desc, add that value as a prefix - if (sort or sort_desc) and not is_view: - try: - prefix = rows[-2][sort or sort_desc] - except IndexError: - # sort/sort_desc column missing from SELECT - look up value by PK instead - prefix_where_clause = " and ".join( - "[{}] = :pk{}".format(pk, i) for i, pk in enumerate(pks) - ) - prefix_lookup_sql = "select [{}] from [{}] where {}".format( - sort or sort_desc, table_name, prefix_where_clause - ) - prefix = ( - await db.execute( - prefix_lookup_sql, - { - **{ - "pk{}".format(i): rows[-2][pk] - for i, pk in enumerate(pks) - } - }, + if isinstance(value, dict): + # It's an expanded foreign key - display link to other row + label = value["label"] + value = value["value"] + # The table we link to depends on the column + other_table = column_to_foreign_key_table[column] + link_template = ( + LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE ) - ).single_value() - if isinstance(prefix, dict) and "value" in prefix: - prefix = prefix["value"] - if prefix is None: - prefix = "$null" - else: - prefix = tilde_encode(str(prefix)) - next_value = f"{prefix},{next_value}" - added_args = {"_next": next_value} - if sort: - added_args["_sort"] = sort - else: - added_args["_sort_desc"] = sort_desc + display_value = jinja2.Markup(link_template.format( + database=database, + table=urllib.parse.quote_plus(other_table), + link_id=urllib.parse.quote_plus(str(value)), + id=str(jinja2.escape(value)), + label=str(jinja2.escape(label)), + )) + elif value is None: + display_value = jinja2.Markup(" ") + elif is_url(str(value).strip()): + display_value = jinja2.Markup( + '{url}'.format( + url=jinja2.escape(value.strip()) + ) + ) + elif column in table_metadata.get("units", {}) and value != "": + # Interpret units using pint + value = value * ureg(table_metadata["units"][column]) + # Pint uses floating point which sometimes introduces errors in the compact + # representation, which we have to round off to avoid ugliness. In the vast + # majority of cases this rounding will be inconsequential. I hope. + value = round(value.to_compact(), 6) + display_value = jinja2.Markup( + "{:~P}".format(value).replace(" ", " ") + ) + else: + display_value = str(value) + + cells.append({"column": column, "value": display_value}) + cell_rows.append(cells) + + if link_column: + # Add the link column header. + # If it's a simple primary key, we have to remove and re-add that column name at + # the beginning of the header row. + if len(pks) == 1: + columns = [col for col in columns if col["name"] != pks[0]] + + columns = [ + {"name": pks[0] if len(pks) == 1 else "Link", "sortable": len(pks) == 1} + ] + columns + return columns, cell_rows + + +class TableView(RowTableShared): + + async def data(self, qs, name, hash, table, default_labels=False): + canned_query = self.ds.get_canned_query(name, table) + if canned_query is not None: + return await self.custom_sql( + qs, + name, + hash, + canned_query["sql"], + editable=False, + canned_query=table, + ) + + is_view = bool(await self.ds.get_view_definition(name, table)) + info = self.ds.inspect() + table_info = info[name]["tables"].get(table) or {} + if not is_view and not table_info: + raise NotFound("Table not found: {}".format(table)) + + pks = table_info.get("primary_keys") or [] + use_rowid = not pks and not is_view + if use_rowid: + select = "rowid, *" + order_by = "rowid" + order_by_pks = "rowid" else: - added_args = {"_next": next_value} - next_url = datasette.absolute_url( - request, datasette.urls.path(path_with_replaced_args(request, added_args)) + select = "*" + order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks]) + order_by = order_by_pks + + if is_view: + order_by = "" + + # We roll our own query_string decoder because by default Sanic + # drops anything with an empty value e.g. ?name__exact= + args = RequestParameters( + urllib.parse.parse_qs(str(qs), keep_blank_values=True) ) - return next_value, next_url + + # Special args start with _ and do not contain a __ + # That's so if there is a column that starts with _ + # it can still be queried using ?_col__exact=blah + special_args = {} + special_args_lists = {} + other_args = {} + for key, value in args.items(): + if key.startswith("_") and "__" not in key: + special_args[key] = value[0] + special_args_lists[key] = value + else: + other_args[key] = value[0] + + # Handle ?_filter_column and redirect, if present + redirect_params = filters_should_redirect(special_args) + if redirect_params: + return self.redirect( + qs, + path_with_added_args(qs, redirect_params), + forward_querystring=False, + ) + + # Spot ?_sort_by_desc and redirect to _sort_desc=(_sort) + if "_sort_by_desc" in special_args: + return self.redirect( + qs, + path_with_added_args( + qs, + { + "_sort_desc": special_args.get("_sort"), + "_sort_by_desc": None, + "_sort": None, + }, + ), + forward_querystring=False, + ) + + table_metadata = self.table_metadata(name, table) + units = table_metadata.get("units", {}) + filters = Filters(sorted(other_args.items()), units, ureg) + where_clauses, params = filters.build_where_clauses() + + # _search support: + fts_table = info[name]["tables"].get(table, {}).get("fts_table") + search_args = dict( + pair for pair in special_args.items() if pair[0].startswith("_search") + ) + search_descriptions = [] + search = "" + if fts_table and search_args: + if "_search" in search_args: + # Simple ?_search=xxx + search = search_args["_search"] + where_clauses.append( + "rowid in (select rowid from {fts_table} where {fts_table} match :search)".format( + fts_table=escape_sqlite(fts_table), + ) + ) + search_descriptions.append('search matches "{}"'.format(search)) + params["search"] = search + else: + # More complex: search against specific columns + valid_columns = set(info[name]["tables"][fts_table]["columns"]) + for i, (key, search_text) in enumerate(search_args.items()): + search_col = key.split("_search_", 1)[1] + if search_col not in valid_columns: + raise DatasetteError("Cannot search by that column", status=400) + + where_clauses.append( + "rowid in (select rowid from {fts_table} where {search_col} match :search_{i})".format( + fts_table=escape_sqlite(fts_table), + search_col=escape_sqlite(search_col), + i=i + ) + ) + search_descriptions.append( + 'search column "{}" matches "{}"'.format( + search_col, search_text + ) + ) + params["search_{}".format(i)] = search_text + + table_rows_count = None + sortable_columns = set() + if not is_view: + table_rows_count = table_info["count"] + sortable_columns = self.sortable_columns_for_table(name, table, use_rowid) + + # Allow for custom sort order + sort = special_args.get("_sort") + if sort: + if sort not in sortable_columns: + raise DatasetteError("Cannot sort table by {}".format(sort)) + + order_by = escape_sqlite(sort) + sort_desc = special_args.get("_sort_desc") + if sort_desc: + if sort_desc not in sortable_columns: + raise DatasetteError("Cannot sort table by {}".format(sort_desc)) + + if sort: + raise DatasetteError("Cannot use _sort and _sort_desc at the same time") + + order_by = "{} desc".format(escape_sqlite(sort_desc)) + + from_sql = "from {table_name} {where}".format( + table_name=escape_sqlite(table), + where=( + "where {} ".format(" and ".join(where_clauses)) + ) if where_clauses else "", + ) + # Store current params and where_clauses for later: + from_sql_params = dict(**params) + from_sql_where_clauses = where_clauses[:] + + count_sql = "select count(*) {}".format(from_sql) + + _next = special_args.get("_next") + offset = "" + if _next: + if is_view: + # _next is an offset + offset = " offset {}".format(int(_next)) + else: + components = urlsafe_components(_next) + # If a sort order is applied, the first of these is the sort value + if sort or sort_desc: + sort_value = components[0] + # Special case for if non-urlencoded first token was $null + if _next.split(",")[0] == "$null": + sort_value = None + components = components[1:] + + # Figure out the SQL for next-based-on-primary-key first + next_by_pk_clauses = [] + if use_rowid: + next_by_pk_clauses.append("rowid > :p{}".format(len(params))) + params["p{}".format(len(params))] = components[0] + else: + # Apply the tie-breaker based on primary keys + if len(components) == len(pks): + param_len = len(params) + next_by_pk_clauses.append( + compound_keys_after_sql(pks, param_len) + ) + for i, pk_value in enumerate(components): + params["p{}".format(param_len + i)] = pk_value + + # Now add the sort SQL, which may incorporate next_by_pk_clauses + if sort or sort_desc: + if sort_value is None: + if sort_desc: + # Just items where column is null ordered by pk + where_clauses.append( + "({column} is null and {next_clauses})".format( + column=escape_sqlite(sort_desc), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + "({column} is not null or ({column} is null and {next_clauses}))".format( + column=escape_sqlite(sort), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + else: + where_clauses.append( + "({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))".format( + column=escape_sqlite(sort or sort_desc), + op=">" if sort else "<", + p=len(params), + extra_desc_only="" if sort else " or {column2} is null".format( + column2=escape_sqlite(sort or sort_desc) + ), + next_clauses=" and ".join(next_by_pk_clauses), + ) + ) + params["p{}".format(len(params))] = sort_value + order_by = "{}, {}".format(order_by, order_by_pks) + else: + where_clauses.extend(next_by_pk_clauses) + + where_clause = "" + if where_clauses: + where_clause = "where {} ".format(" and ".join(where_clauses)) + + if order_by: + order_by = "order by {} ".format(order_by) + + # _group_count=col1&_group_count=col2 + group_count = special_args_lists.get("_group_count") or [] + if group_count: + sql = 'select {group_cols}, count(*) as "count" from {table_name} {where} group by {group_cols} order by "count" desc limit 100'.format( + group_cols=", ".join( + '"{}"'.format(group_count_col) for group_count_col in group_count + ), + table_name=escape_sqlite(table), + where=where_clause, + ) + return await self.custom_sql(qs, name, hash, sql, editable=True) + + extra_args = {} + # Handle ?_size=500 + page_size = qs.first_or_none("_size") + if page_size: + if page_size == "max": + page_size = self.max_returned_rows + try: + page_size = int(page_size) + if page_size < 0: + raise ValueError + + except ValueError: + raise DatasetteError("_size must be a positive integer", status=400) + + if page_size > self.max_returned_rows: + raise DatasetteError( + "_size must be <= {}".format(self.max_returned_rows), status=400 + ) + + extra_args["page_size"] = page_size + else: + page_size = self.page_size + + sql = "select {select} from {table_name} {where}{order_by}limit {limit}{offset}".format( + select=select, + table_name=escape_sqlite(table), + where=where_clause, + order_by=order_by, + limit=page_size + 1, + offset=offset, + ) + + if qs.first_or_none("_timelimit"): + extra_args["custom_time_limit"] = int(qs.first("_timelimit")) + + results = await self.ds.execute( + name, sql, params, truncate=True, **extra_args + ) + + # facets support + facet_size = self.ds.config["default_facet_size"] + metadata_facets = table_metadata.get("facets", []) + facets = metadata_facets[:] + if qs.first_or_none("_facet") and not self.ds.config["allow_facet"]: + raise DatasetteError("_facet= is not allowed", status=400) + try: + facets.extend(qs.getlist("_facet")) + except KeyError: + pass + facet_results = {} + facets_timed_out = [] + for column in facets: + facet_sql = """ + select {col} as value, count(*) as count + {from_sql} {and_or_where} {col} is not null + group by {col} order by count desc limit {limit} + """.format( + col=escape_sqlite(column), + from_sql=from_sql, + and_or_where='and' if from_sql_where_clauses else 'where', + limit=facet_size+1, + ) + try: + facet_rows_results = await self.ds.execute( + name, facet_sql, params, + truncate=False, + custom_time_limit=self.ds.config["facet_time_limit_ms"], + ) + facet_results_values = [] + facet_results[column] = { + "name": column, + "results": facet_results_values, + "truncated": len(facet_rows_results) > facet_size, + } + facet_rows = facet_rows_results.rows[:facet_size] + # Attempt to expand foreign keys into labels + values = [row["value"] for row in facet_rows] + expanded = (await self.expand_foreign_keys( + name, table, column, values + )) + for row in facet_rows: + selected = str(other_args.get(column)) == str(row["value"]) + if selected: + toggle_path = path_with_removed_args( + qs, {column: str(row["value"])} + ) + else: + toggle_path = path_with_added_args( + qs, {column: row["value"]} + ) + facet_results_values.append({ + "value": row["value"], + "label": expanded.get( + (column, row["value"]), + row["value"] + ), + "count": row["count"], + "toggle_url": urllib.parse.urljoin( + qs.path, toggle_path + ), + "selected": selected, + }) + except InterruptedError: + facets_timed_out.append(column) + + columns = [r[0] for r in results.description] + rows = list(results.rows) + + filter_columns = columns[:] + if use_rowid and filter_columns[0] == "rowid": + filter_columns = filter_columns[1:] + + # Expand labeled columns if requested + columns_expanded = [] + expandable_columns = self.expandable_columns(name, table) + columns_to_expand = None + try: + all_labels = value_as_boolean(special_args.get("_labels", "")) + except ValueError: + all_labels = default_labels + # Check for explicit _label= + if qs.first_or_none("_label"): + columns_to_expand = qs.first("_label") + if columns_to_expand is None and all_labels: + # expand all columns with foreign keys + columns_to_expand = [ + fk["column"] for fk, _ in expandable_columns + ] + + if columns_to_expand: + expanded_labels = {} + for fk, label_column in expandable_columns: + column = fk["column"] + if column not in columns_to_expand: + continue + columns_expanded.append(column) + # Gather the values + column_index = columns.index(column) + values = [row[column_index] for row in rows] + # Expand them + expanded_labels.update(await self.expand_foreign_keys( + name, table, column, values + )) + if expanded_labels: + # Rewrite the rows + new_rows = [] + for row in rows: + new_row = CustomRow(columns) + for column in row.keys(): + value = row[column] + if (column, value) in expanded_labels: + new_row[column] = { + 'value': value, + 'label': expanded_labels[(column, value)] + } + else: + new_row[column] = value + new_rows.append(new_row) + rows = new_rows + + # Pagination next link + next_value = None + next_url = None + if len(rows) > page_size and page_size > 0: + if is_view: + next_value = int(_next or 0) + page_size + else: + next_value = path_from_row_pks(rows[-2], pks, use_rowid) + # If there's a sort or sort_desc, add that value as a prefix + if (sort or sort_desc) and not is_view: + prefix = rows[-2][sort or sort_desc] + if prefix is None: + prefix = "$null" + else: + prefix = urllib.parse.quote_plus(str(prefix)) + next_value = "{},{}".format(prefix, next_value) + added_args = {"_next": next_value} + if sort: + added_args["_sort"] = sort + else: + added_args["_sort_desc"] = sort_desc + else: + added_args = {"_next": next_value} + next_url = urllib.parse.urljoin( + qs.path, path_with_replaced_args(qs, added_args) + ) + rows = rows[:page_size] + + # Number of filtered rows in whole set: + filtered_table_rows_count = None + if count_sql: + try: + count_rows = list(await self.ds.execute( + name, count_sql, from_sql_params + )) + filtered_table_rows_count = count_rows[0][0] + except InterruptedError: + pass + + # Detect suggested facets + suggested_facets = [] + if self.ds.config["suggest_facets"] and self.ds.config["allow_facet"]: + for facet_column in columns: + if facet_column in facets: + continue + if not self.ds.config["suggest_facets"]: + continue + suggested_facet_sql = ''' + select distinct {column} {from_sql} + {and_or_where} {column} is not null + limit {limit} + '''.format( + column=escape_sqlite(facet_column), + from_sql=from_sql, + and_or_where='and' if from_sql_where_clauses else 'where', + limit=facet_size+1 + ) + distinct_values = None + try: + distinct_values = await self.ds.execute( + name, suggested_facet_sql, from_sql_params, + truncate=False, + custom_time_limit=self.ds.config["facet_suggest_time_limit_ms"], + ) + num_distinct_values = len(distinct_values) + if ( + num_distinct_values and + num_distinct_values > 1 and + num_distinct_values <= facet_size and + num_distinct_values < filtered_table_rows_count + ): + suggested_facets.append({ + 'name': facet_column, + 'toggle_url': path_with_added_args( + qs, {'_facet': facet_column} + ), + }) + except InterruptedError: + pass + + # human_description_en combines filters AND search, if provided + human_description_en = filters.human_description_en(extra=search_descriptions) + + if sort or sort_desc: + sorted_by = "sorted by {}{}".format( + (sort or sort_desc), " descending" if sort_desc else "" + ) + human_description_en = " ".join( + [b for b in [human_description_en, sorted_by] if b] + ) + + async def extra_template(): + display_columns, display_rows = await self.display_columns_and_rows( + name, + table, + results.description, + rows, + link_column=not is_view, + ) + metadata = self.ds.metadata.get("databases", {}).get(name, {}).get( + "tables", {} + ).get( + table, {} + ) + self.ds.update_with_inherited_metadata(metadata) + return { + "database_hash": hash, + "supports_search": bool(fts_table), + "search": search or "", + "use_rowid": use_rowid, + "filters": filters, + "display_columns": display_columns, + "filter_columns": filter_columns, + "display_rows": display_rows, + "facets_timed_out": facets_timed_out, + "sorted_facet_results": sorted( + facet_results.values(), + key=lambda f: (len(f["results"]), f["name"]), + reverse=True + ), + "facet_hideable": lambda facet: facet not in metadata_facets, + "is_sortable": any(c["sortable"] for c in display_columns), + "path_with_replaced_args": path_with_replaced_args, + "path_with_removed_args": path_with_removed_args, + "qs": qs, + "sort": sort, + "sort_desc": sort_desc, + "disable_sort": is_view, + "custom_rows_and_columns_templates": [ + "_rows_and_columns-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns-table-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns.html", + ], + "metadata": metadata, + "view_definition": await self.ds.get_view_definition(name, table), + "table_definition": await self.ds.get_table_definition(name, table), + } + + return { + "database": name, + "table": table, + "is_view": is_view, + "human_description_en": human_description_en, + "rows": rows[:page_size], + "truncated": results.truncated, + "table_rows_count": table_rows_count, + "filtered_table_rows_count": filtered_table_rows_count, + "columns_expanded": columns_expanded, + "columns": columns, + "primary_keys": pks, + "units": units, + "query": {"sql": sql, "params": params}, + "facet_results": facet_results, + "suggested_facets": suggested_facets, + "next": next_value and str(next_value) or None, + "next_url": next_url, + }, extra_template, ( + "table-{}-{}.html".format(to_css_class(name), to_css_class(table)), + "table.html", + ) + + +class RowView(RowTableShared): + + async def data(self, qs, name, hash, table, pk_path, default_labels=False): + pk_values = urlsafe_components(pk_path) + info = self.ds.inspect()[name] + table_info = info["tables"].get(table) or {} + pks = table_info.get("primary_keys") or [] + use_rowid = not pks + select = "*" + if use_rowid: + select = "rowid, *" + pks = ["rowid"] + wheres = ['"{}"=:p{}'.format(pk, i) for i, pk in enumerate(pks)] + sql = 'select {} from "{}" where {}'.format(select, table, " AND ".join(wheres)) + params = {} + for i, pk_value in enumerate(pk_values): + params["p{}".format(i)] = pk_value + results = await self.ds.execute( + name, sql, params, truncate=True + ) + columns = [r[0] for r in results.description] + rows = list(results.rows) + if not rows: + raise NotFound("Record not found: {}".format(pk_values)) + + async def template_data(): + display_columns, display_rows = await self.display_columns_and_rows( + name, + table, + results.description, + rows, + link_column=False, + ) + for column in display_columns: + column["sortable"] = False + return { + "database_hash": hash, + "foreign_key_tables": await self.foreign_key_tables( + name, table, pk_values + ), + "display_columns": display_columns, + "display_rows": display_rows, + "custom_rows_and_columns_templates": [ + "_rows_and_columns-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns-row-{}-{}.html".format( + to_css_class(name), to_css_class(table) + ), + "_rows_and_columns.html", + ], + "metadata": self.ds.metadata.get("databases", {}).get(name, {}).get( + "tables", {} + ).get( + table, {} + ), + } + + data = { + "database": name, + "table": table, + "rows": rows, + "columns": columns, + "primary_keys": pks, + "primary_key_values": pk_values, + "units": self.table_metadata(name, table).get("units", {}), + } + + if "foreign_key_tables" in (qs.first_or_none("_extras") or "").split(","): + data["foreign_key_tables"] = await self.foreign_key_tables( + name, table, pk_values + ) + + return data, template_data, ( + "row-{}-{}.html".format(to_css_class(name), to_css_class(table)), "row.html" + ) + + async def foreign_key_tables(self, name, table, pk_values): + if len(pk_values) != 1: + return [] + + table_info = self.ds.inspect()[name]["tables"].get(table) + if not table_info: + return [] + + foreign_keys = table_info["foreign_keys"]["incoming"] + if len(foreign_keys) == 0: + return [] + + sql = "select " + ", ".join( + [ + '(select count(*) from {table} where {column}=:id)'.format( + table=escape_sqlite(fk["other_table"]), + column=escape_sqlite(fk["other_column"]), + ) + for fk in foreign_keys + ] + ) + try: + rows = list(await self.ds.execute(name, sql, {"id": pk_values[0]})) + except sqlite3.OperationalError: + # Almost certainly hit the timeout + return [] + + foreign_table_counts = dict( + zip( + [(fk["other_table"], fk["other_column"]) for fk in foreign_keys], + list(rows[0]), + ) + ) + foreign_key_tables = [] + for fk in foreign_keys: + count = foreign_table_counts.get( + (fk["other_table"], fk["other_column"]) + ) or 0 + foreign_key_tables.append({**fk, **{"count": count}}) + return foreign_key_tables diff --git a/demos/apache-proxy/000-default.conf b/demos/apache-proxy/000-default.conf deleted file mode 100644 index 5b6607a3..00000000 --- a/demos/apache-proxy/000-default.conf +++ /dev/null @@ -1,13 +0,0 @@ - - Options Indexes FollowSymLinks - AllowOverride None - Require all granted - - - - ServerName localhost - DocumentRoot /app/html - ProxyPreserveHost On - ProxyPass /prefix/ http://127.0.0.1:8001/ - Header add X-Proxied-By "Apache2 Debian" - diff --git a/demos/apache-proxy/Dockerfile b/demos/apache-proxy/Dockerfile deleted file mode 100644 index 9a8448da..00000000 --- a/demos/apache-proxy/Dockerfile +++ /dev/null @@ -1,56 +0,0 @@ -FROM python:3.11.0-slim-bullseye - -RUN apt-get update && \ - apt-get install -y apache2 supervisor && \ - apt clean && \ - rm -rf /var/lib/apt && \ - rm -rf /var/lib/dpkg/info/* - -# Apache environment, copied from -# https://github.com/ijklim/laravel-benfords-law-app/blob/e9bf385dcaddb62ea466a7b245ab6e4ef708c313/docker/os/Dockerfile -ENV APACHE_DOCUMENT_ROOT=/var/www/html/public -ENV APACHE_RUN_USER www-data -ENV APACHE_RUN_GROUP www-data -ENV APACHE_PID_FILE /var/run/apache2.pid -ENV APACHE_RUN_DIR /var/run/apache2 -ENV APACHE_LOCK_DIR /var/lock/apache2 -ENV APACHE_LOG_DIR /var/log -RUN ln -sf /dev/stdout /var/log/apache2-access.log -RUN ln -sf /dev/stderr /var/log/apache2-error.log -RUN mkdir -p $APACHE_RUN_DIR $APACHE_LOCK_DIR - -RUN a2enmod proxy -RUN a2enmod proxy_http -RUN a2enmod headers - -ARG DATASETTE_REF - -RUN pip install \ - https://github.com/simonw/datasette/archive/${DATASETTE_REF}.zip \ - datasette-redirect-to-https datasette-debug-asgi - -ADD 000-default.conf /etc/apache2/sites-enabled/000-default.conf - -WORKDIR /app -RUN mkdir -p /app/html -RUN echo '

Demo is at /prefix/

' > /app/html/index.html - -ADD https://latest.datasette.io/fixtures.db /app/fixtures.db - -EXPOSE 80 - -# Dynamically build supervisord config since it includes $DATASETTE_REF: -RUN echo "[supervisord]" >> /app/supervisord.conf -RUN echo "nodaemon=true" >> /app/supervisord.conf -RUN echo "" >> /app/supervisord.conf -RUN echo "[program:apache2]" >> /app/supervisord.conf -RUN echo "command=apache2 -D FOREGROUND" >> /app/supervisord.conf -RUN echo "stdout_logfile=/dev/stdout" >> /app/supervisord.conf -RUN echo "stdout_logfile_maxbytes=0" >> /app/supervisord.conf -RUN echo "" >> /app/supervisord.conf -RUN echo "[program:datasette]" >> /app/supervisord.conf -RUN echo "command=datasette /app/fixtures.db --setting base_url '/prefix/' --version-note '${DATASETTE_REF}' -h 0.0.0.0 -p 8001" >> /app/supervisord.conf -RUN echo "stdout_logfile=/dev/stdout" >> /app/supervisord.conf -RUN echo "stdout_logfile_maxbytes=0" >> /app/supervisord.conf - -CMD ["/usr/bin/supervisord", "-c", "/app/supervisord.conf"] diff --git a/demos/apache-proxy/README.md b/demos/apache-proxy/README.md deleted file mode 100644 index c76e440d..00000000 --- a/demos/apache-proxy/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Datasette running behind an Apache proxy - -See also [Running Datasette behind a proxy](https://docs.datasette.io/en/latest/deploying.html#running-datasette-behind-a-proxy) - -This live demo is running at https://datasette-apache-proxy-demo.fly.dev/prefix/ - -To build locally, passing in a Datasette commit hash (or `main` for the main branch): - - docker build -t datasette-apache-proxy-demo . \ - --build-arg DATASETTE_REF=c617e1769ea27e045b0f2907ef49a9a1244e577d - -Then run it like this: - - docker run -p 5000:80 datasette-apache-proxy-demo - -And visit `http://localhost:5000/` or `http://localhost:5000/prefix/` - -## Deployment to Fly - -To deploy to [Fly](https://fly.io/) first create an application there by running: - - flyctl apps create --name datasette-apache-proxy-demo - -You will need a different name, since I have already taken that one. - -Then run this command to deploy: - - flyctl deploy --build-arg DATASETTE_REF=main - -This uses `fly.toml` in this directory, which hard-codes the `datasette-apache-proxy-demo` name - so you would need to edit that file to match your application name before running this. - -## Deployment to Cloud Run - -Deployments to Cloud Run currently result in intermittent 503 errors and I'm not sure why, see [issue #1522](https://github.com/simonw/datasette/issues/1522). - -You can deploy like this: - - DATASETTE_REF=main ./deploy-to-cloud-run.sh diff --git a/demos/apache-proxy/deploy-to-cloud-run.sh b/demos/apache-proxy/deploy-to-cloud-run.sh deleted file mode 100755 index 2846590a..00000000 --- a/demos/apache-proxy/deploy-to-cloud-run.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# https://til.simonwillison.net/cloudrun/using-build-args-with-cloud-run - -if [[ -z "$DATASETTE_REF" ]]; then - echo "Must provide DATASETTE_REF environment variable" 1>&2 - exit 1 -fi - -NAME="datasette-apache-proxy-demo" -PROJECT=$(gcloud config get-value project) -IMAGE="gcr.io/$PROJECT/$NAME" - -# Need YAML so we can set --build-arg -echo " -steps: -- name: 'gcr.io/cloud-builders/docker' - args: ['build', '-t', '$IMAGE', '.', '--build-arg', 'DATASETTE_REF=$DATASETTE_REF'] -- name: 'gcr.io/cloud-builders/docker' - args: ['push', '$IMAGE'] -" > /tmp/cloudbuild.yml - -gcloud builds submit --config /tmp/cloudbuild.yml - -rm /tmp/cloudbuild.yml - -gcloud run deploy $NAME \ - --allow-unauthenticated \ - --platform=managed \ - --image $IMAGE \ - --port 80 diff --git a/demos/apache-proxy/fly.toml b/demos/apache-proxy/fly.toml deleted file mode 100644 index 52e6af5d..00000000 --- a/demos/apache-proxy/fly.toml +++ /dev/null @@ -1,37 +0,0 @@ -app = "datasette-apache-proxy-demo" - -kill_signal = "SIGINT" -kill_timeout = 5 -processes = [] - -[env] - -[experimental] - allowed_public_ports = [] - auto_rollback = true - -[[services]] - http_checks = [] - internal_port = 80 - processes = ["app"] - protocol = "tcp" - script_checks = [] - - [services.concurrency] - hard_limit = 25 - soft_limit = 20 - type = "connections" - - [[services.ports]] - handlers = ["http"] - port = 80 - - [[services.ports]] - handlers = ["tls", "http"] - port = 443 - - [[services.tcp_checks]] - grace_period = "1s" - interval = "15s" - restart_limit = 0 - timeout = "2s" diff --git a/demos/plugins/example_js_manager_plugins.py b/demos/plugins/example_js_manager_plugins.py deleted file mode 100644 index 2705f2c5..00000000 --- a/demos/plugins/example_js_manager_plugins.py +++ /dev/null @@ -1,21 +0,0 @@ -from datasette import hookimpl - -# Test command: -# datasette fixtures.db \ --plugins-dir=demos/plugins/ -# \ --static static:demos/plugins/static - -# Create a set with view names that qualify for this JS, since plugins won't do anything on other pages -# Same pattern as in Nteract data explorer -# https://github.com/hydrosquall/datasette-nteract-data-explorer/blob/main/datasette_nteract_data_explorer/__init__.py#L77 -PERMITTED_VIEWS = {"table", "query", "database"} - - -@hookimpl -def extra_js_urls(view_name): - print(view_name) - if view_name in PERMITTED_VIEWS: - return [ - { - "url": "/static/table-example-plugins.js", - } - ] diff --git a/demos/plugins/static/table-example-plugins.js b/demos/plugins/static/table-example-plugins.js deleted file mode 100644 index 8c19d9a6..00000000 --- a/demos/plugins/static/table-example-plugins.js +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Example usage of Datasette JS Manager API - */ - -document.addEventListener("datasette_init", function (evt) { - const { detail: manager } = evt; - // === Demo plugins: remove before merge=== - addPlugins(manager); -}); - -/** - * Examples for to test datasette JS api - */ -const addPlugins = (manager) => { - - manager.registerPlugin("column-name-plugin", { - version: 0.1, - makeColumnActions: (columnMeta) => { - const { column } = columnMeta; - - return [ - { - label: "Copy name to clipboard", - onClick: (evt) => copyToClipboard(column), - }, - { - label: "Log column metadata to console", - onClick: (evt) => console.log(column), - }, - ]; - }, - }); - - manager.registerPlugin("panel-plugin-graphs", { - version: 0.1, - makeAboveTablePanelConfigs: () => { - return [ - { - id: 'first-panel', - label: "First", - render: node => { - const description = document.createElement('p'); - description.innerText = 'Hello world'; - node.appendChild(description); - } - }, - { - id: 'second-panel', - label: "Second", - render: node => { - const iframe = document.createElement('iframe'); - iframe.src = "https://observablehq.com/embed/@d3/sortable-bar-chart?cell=viewof+order&cell=chart"; - iframe.width = 800; - iframe.height = 635; - iframe.frameborder = '0'; - node.appendChild(iframe); - } - }, - ]; - }, - }); - - manager.registerPlugin("panel-plugin-maps", { - version: 0.1, - makeAboveTablePanelConfigs: () => { - return [ - { - // ID only has to be unique within a plugin, manager namespaces for you - id: 'first-map-panel', - label: "Map plugin", - // datasette-vega, leafleft can provide a "render" function - render: node => node.innerHTML = "Here sits a map", - }, - { - id: 'second-panel', - label: "Image plugin", - render: node => { - const img = document.createElement('img'); - img.src = 'https://datasette.io/static/datasette-logo.svg' - node.appendChild(img); - }, - } - ]; - }, - }); - - // Future: dispatch message to some other part of the page with CustomEvent API - // Could use to drive filter/sort query builder actions without page refresh. -} - - - -async function copyToClipboard(str) { - try { - await navigator.clipboard.writeText(str); - } catch (err) { - /** Rejected - text failed to copy to the clipboard. Browsers didn't give permission */ - console.error('Failed to copy: ', err); - } -} diff --git a/docs/Makefile b/docs/Makefile index 2b092179..dbb89483 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -17,7 +17,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -livehtml: - sphinx-autobuild -b html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(0) + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css deleted file mode 100644 index 0a6f8799..00000000 --- a/docs/_static/css/custom.css +++ /dev/null @@ -1,8 +0,0 @@ -a.external { - overflow-wrap: anywhere; -} -body[data-theme="dark"] .sidebar-logo-container { - background-color: white; - padding: 5px; - opacity: 0.6; -} diff --git a/docs/_static/datasette-favicon.png b/docs/_static/datasette-favicon.png deleted file mode 100644 index 4993163f..00000000 Binary files a/docs/_static/datasette-favicon.png and /dev/null differ diff --git a/docs/_static/js/custom.js b/docs/_static/js/custom.js deleted file mode 100644 index 91c3e306..00000000 --- a/docs/_static/js/custom.js +++ /dev/null @@ -1,23 +0,0 @@ -jQuery(function ($) { - // Show banner linking to /stable/ if this is a /latest/ page - if (!/\/latest\//.test(location.pathname)) { - return; - } - var stableUrl = location.pathname.replace("/latest/", "/stable/"); - // Check it's not a 404 - fetch(stableUrl, { method: "HEAD" }).then((response) => { - if (response.status == 200) { - var warning = $( - `
-

Note

-

- This documentation covers the development version of Datasette.

-

See this page for the current stable release. -

-
` - ); - warning.find("a").attr("href", stableUrl); - $("article[role=main]").prepend(warning); - } - }); -}); diff --git a/docs/_templates/base.html b/docs/_templates/base.html deleted file mode 100644 index 9dea86eb..00000000 --- a/docs/_templates/base.html +++ /dev/null @@ -1,37 +0,0 @@ -{%- extends "!base.html" %} - -{% block site_meta %} -{{ super() }} - -{% endblock %} - -{% block scripts %} -{{ super() }} - -{% endblock %} diff --git a/docs/_templates/sidebar/brand.html b/docs/_templates/sidebar/brand.html deleted file mode 100644 index 8be9e8ee..00000000 --- a/docs/_templates/sidebar/brand.html +++ /dev/null @@ -1,16 +0,0 @@ - diff --git a/docs/_templates/sidebar/navigation.html b/docs/_templates/sidebar/navigation.html deleted file mode 100644 index c460a17e..00000000 --- a/docs/_templates/sidebar/navigation.html +++ /dev/null @@ -1,11 +0,0 @@ - \ No newline at end of file diff --git a/docs/authentication.rst b/docs/authentication.rst deleted file mode 100644 index f720c12f..00000000 --- a/docs/authentication.rst +++ /dev/null @@ -1,1445 +0,0 @@ -.. _authentication: - -================================ - Authentication and permissions -================================ - -Datasette doesn't require authentication by default. Any visitor to a Datasette instance can explore the full data and execute read-only SQL queries. - -Datasette can be configured to only allow authenticated users, or to control which databases, tables, and queries can be accessed by the public or by specific users. Datasette's plugin system can be used to add many different styles of authentication, such as user accounts, single sign-on or API keys. - -.. _authentication_actor: - -Actors -====== - -Through plugins, Datasette can support both authenticated users (with cookies) and authenticated API clients (via authentication tokens). The word "actor" is used to cover both of these cases. - -Every request to Datasette has an associated actor value, available in the code as ``request.actor``. This can be ``None`` for unauthenticated requests, or a JSON compatible Python dictionary for authenticated users or API clients. - -The actor dictionary can be any shape - the design of that data structure is left up to the plugins. Actors should always include a unique ``"id"`` string, as demonstrated by the "root" actor below. - -Plugins can use the :ref:`plugin_hook_actor_from_request` hook to implement custom logic for authenticating an actor based on the incoming HTTP request. - -.. _authentication_root: - -Using the "root" actor ----------------------- - -Datasette currently leaves almost all forms of authentication to plugins - `datasette-auth-github `__ for example. - -The one exception is the "root" account, which you can sign into while using Datasette on your local machine. The root user has **all permissions** - they can perform any action regardless of other permission rules. - -The ``--root`` flag is designed for local development and testing. When you start Datasette with ``--root``, the root user automatically receives every permission, including: - -* All view permissions (``view-instance``, ``view-database``, ``view-table``, etc.) -* All write permissions (``insert-row``, ``update-row``, ``delete-row``, ``create-table``, ``alter-table``, ``set-column-type``, ``drop-table``) -* Debug permissions (``permissions-debug``, ``debug-menu``) -* Any custom permissions defined by plugins - -If you add explicit deny rules in ``datasette.yaml`` those can still block the -root actor from specific databases or tables. - -The ``--root`` flag sets an internal ``root_enabled`` switch—without it, a signed-in user with ``{"id": "root"}`` is treated like any other actor. - -To sign in as root, start Datasette using the ``--root`` command-line option, like this:: - - datasette --root - -Datasette will output a single-use-only login URL on startup:: - - http://127.0.0.1:8001/-/auth-token?token=786fc524e0199d70dc9a581d851f466244e114ca92f33aa3b42a139e9388daa7 - INFO: Started server process [25801] - INFO: Waiting for application startup. - INFO: Application startup complete. - INFO: Uvicorn running on http://127.0.0.1:8001 (Press CTRL+C to quit) - -Click on that link and then visit ``http://127.0.0.1:8001/-/actor`` to confirm that you are authenticated as an actor that looks like this: - -.. code-block:: json - - { - "id": "root" - } - -.. _authentication_permissions: - -Permissions -=========== - -Datasette's permissions system is built around SQL queries. Datasette and its plugins construct SQL queries to resolve the list of resources that an actor cas access. - -The key question the permissions system answers is this: - - Is this **actor** allowed to perform this **action**, optionally against this particular **resource**? - -**Actors** are :ref:`described above `. - -An **action** is a string describing the action the actor would like to perform. A full list is :ref:`provided below ` - examples include ``view-table`` and ``execute-sql``. - -A **resource** is the item the actor wishes to interact with - for example a specific database or table. Some actions, such as ``permissions-debug``, are not associated with a particular resource. - -Datasette's built-in view actions (``view-database``, ``view-table`` etc) are allowed by Datasette's default configuration: unless you :ref:`configure additional permission rules ` unauthenticated users will be allowed to access content. - -Other actions, including those introduced by plugins, will default to *deny*. - -.. _authentication_default_deny: - -Denying all permissions by default ----------------------------------- - -By default, Datasette allows unauthenticated access to view databases, tables, and execute SQL queries. - -You may want to run Datasette in a mode where **all** access is denied by default, and you explicitly grant permissions only to authenticated users, either using the :ref:`--root mechanism ` or through :ref:`configuration file rules ` or plugins. - -Use the ``--default-deny`` command-line option to run Datasette in this mode:: - - datasette --default-deny data.db --root - -With ``--default-deny`` enabled: - -* Anonymous users are denied access to view the instance, databases, tables, and queries -* Authenticated users are also denied access unless they're explicitly granted permissions -* The root user (when using ``--root``) still has access to everything -* You can grant permissions using :ref:`configuration file rules ` or plugins - -For example, to allow only a specific user to access your instance:: - - datasette --default-deny data.db --config datasette.yaml - -Where ``datasette.yaml`` contains: - -.. code-block:: yaml - - allow: - id: alice - -This configuration will deny access to everyone except the user with ``id`` of ``alice``. - -.. _authentication_permissions_explained: - -How permissions are resolved ----------------------------- - -Datasette performs permission checks using the internal :ref:`datasette_allowed`, method which accepts keyword arguments for ``action``, ``resource`` and an optional ``actor``. - -``resource`` should be an instance of the appropriate ``Resource`` subclass from :mod:`datasette.resources`—for example ``InstanceResource()``, ``DatabaseResource(database="...``)`` or ``TableResource(database="...", table="...")``. This defaults to ``InstanceResource()`` if not specified. - -When a check runs Datasette gathers allow/deny rules from multiple sources and -compiles them into a SQL query. The resulting query describes all of the -resources an actor may access for that action, together with the reasons those -resources were allowed or denied. The combined sources are: - -* ``allow`` blocks configured in :ref:`datasette.yaml `. -* :ref:`Actor restrictions ` encoded into the actor dictionary or API token. -* The "root" user shortcut when ``--root`` (or :attr:`Datasette.root_enabled `) is active, replying ``True`` to all permission chucks unless configuration rules deny them at a more specific level. -* Any additional SQL provided by plugins implementing :ref:`plugin_hook_permission_resources_sql`. - -Datasette evaluates the SQL to determine if the requested ``resource`` is -included. Explicit deny rules returned by configuration or plugins will block -access even if other rules allowed it. - -.. _authentication_permissions_allow: - -Defining permissions with "allow" blocks ----------------------------------------- - -One way to define permissions in Datasette is to use an ``"allow"`` block :ref:`in the datasette.yaml file `. This is a JSON document describing which actors are allowed to perform an action against a specific resource. - -Each ``allow`` block is compiled into SQL and combined with any -:ref:`plugin-provided rules ` to produce -the cascading allow/deny decisions that power :ref:`datasette_allowed`. - -The most basic form of allow block is this (`allow demo `__, `deny demo `__): - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: - id: root - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: - id: root - -.. tab:: JSON - - .. code-block:: json - - { - "allow": { - "id": "root" - } - } -.. [[[end]]] - -This will match any actors with an ``"id"`` property of ``"root"`` - for example, an actor that looks like this: - -.. code-block:: json - - { - "id": "root", - "name": "Root User" - } - -An allow block can specify "deny all" using ``false`` (`demo `__): - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: false - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: false - -.. tab:: JSON - - .. code-block:: json - - { - "allow": false - } -.. [[[end]]] - -An ``"allow"`` of ``true`` allows all access (`demo `__): - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: true - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: true - -.. tab:: JSON - - .. code-block:: json - - { - "allow": true - } -.. [[[end]]] - -Allow keys can provide a list of values. These will match any actor that has any of those values (`allow demo `__, `deny demo `__): - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: - id: - - simon - - cleopaws - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: - id: - - simon - - cleopaws - -.. tab:: JSON - - .. code-block:: json - - { - "allow": { - "id": [ - "simon", - "cleopaws" - ] - } - } -.. [[[end]]] - -This will match any actor with an ``"id"`` of either ``"simon"`` or ``"cleopaws"``. - -Actors can have properties that feature a list of values. These will be matched against the list of values in an allow block. Consider the following actor: - -.. code-block:: json - - { - "id": "simon", - "roles": ["staff", "developer"] - } - -This allow block will provide access to any actor that has ``"developer"`` as one of their roles (`allow demo `__, `deny demo `__): - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: - roles: - - developer - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: - roles: - - developer - -.. tab:: JSON - - .. code-block:: json - - { - "allow": { - "roles": [ - "developer" - ] - } - } -.. [[[end]]] - -Note that "roles" is not a concept that is baked into Datasette - it's a convention that plugins can choose to implement and act on. - -If you want to provide access to any actor with a value for a specific key, use ``"*"``. For example, to match any logged-in user specify the following (`allow demo `__, `deny demo `__): - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: - id: "*" - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: - id: "*" - -.. tab:: JSON - - .. code-block:: json - - { - "allow": { - "id": "*" - } - } -.. [[[end]]] - -You can specify that only unauthenticated actors (from anonymous HTTP requests) should be allowed access using the special ``"unauthenticated": true`` key in an allow block (`allow demo `__, `deny demo `__): - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: - unauthenticated: true - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: - unauthenticated: true - -.. tab:: JSON - - .. code-block:: json - - { - "allow": { - "unauthenticated": true - } - } -.. [[[end]]] - -Allow keys act as an "or" mechanism. An actor will be able to execute the query if any of their JSON properties match any of the values in the corresponding lists in the ``allow`` block. The following block will allow users with either a ``role`` of ``"ops"`` OR users who have an ``id`` of ``"simon"`` or ``"cleopaws"``: - -.. [[[cog - from metadata_doc import config_example - import textwrap - config_example(cog, textwrap.dedent( - """ - allow: - id: - - simon - - cleopaws - role: ops - """).strip(), - "YAML", "JSON" - ) -.. ]]] - -.. tab:: YAML - - .. code-block:: yaml - - allow: - id: - - simon - - cleopaws - role: ops - -.. tab:: JSON - - .. code-block:: json - - { - "allow": { - "id": [ - "simon", - "cleopaws" - ], - "role": "ops" - } - } -.. [[[end]]] - -`Demo for cleopaws `__, `demo for ops role `__, `demo for an actor matching neither rule `__. - -.. _AllowDebugView: - -The /-/allow-debug tool ------------------------ - -The ``/-/allow-debug`` tool lets you try out different ``"action"`` blocks against different ``"actor"`` JSON objects. You can try that out here: https://latest.datasette.io/-/allow-debug - -.. _authentication_permissions_config: - -Access permissions in ``datasette.yaml`` -======================================== - -There are two ways to configure permissions using ``datasette.yaml`` (or ``datasette.json``). - -For simple visibility permissions you can use ``"allow"`` blocks in the root, database, table and query sections. - -For other permissions you can use a ``"permissions"`` block, described :ref:`in the next section `. - -You can limit who is allowed to view different parts of your Datasette instance using ``"allow"`` keys in your :ref:`configuration`. - -You can control the following: - -* Access to the entire Datasette instance -* Access to specific databases -* Access to specific tables and views -* Access to specific :ref:`queries ` - -If a user has permission to view a table they will be able to view that table, independent of if they have permission to view the database or instance that the table exists within. - -.. _authentication_permissions_instance: - -Access to an instance ---------------------- - -Here's how to restrict access to your entire Datasette instance to just the ``"id": "root"`` user: - -.. [[[cog - from metadata_doc import config_example - config_example(cog, """ - title: My private Datasette instance - allow: - id: root - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - title: My private Datasette instance - allow: - id: root - - -.. tab:: datasette.json - - .. code-block:: json - - { - "title": "My private Datasette instance", - "allow": { - "id": "root" - } - } -.. [[[end]]] - -To deny access to all users, you can use ``"allow": false``: - -.. [[[cog - config_example(cog, """ - title: My entirely inaccessible instance - allow: false - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - title: My entirely inaccessible instance - allow: false - - -.. tab:: datasette.json - - .. code-block:: json - - { - "title": "My entirely inaccessible instance", - "allow": false - } -.. [[[end]]] - -One reason to do this is if you are using a Datasette plugin - such as `datasette-permissions-sql `__ - to control permissions instead. - -.. _authentication_permissions_database: - -Access to specific databases ----------------------------- - -To limit access to a specific ``private.db`` database to just authenticated users, use the ``"allow"`` block like this: - -.. [[[cog - config_example(cog, """ - databases: - private: - allow: - id: "*" - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - databases: - private: - allow: - id: "*" - - -.. tab:: datasette.json - - .. code-block:: json - - { - "databases": { - "private": { - "allow": { - "id": "*" - } - } - } - } -.. [[[end]]] - -.. _authentication_permissions_table: - -Access to specific tables and views ------------------------------------ - -To limit access to the ``users`` table in your ``bakery.db`` database: - -.. [[[cog - config_example(cog, """ - databases: - bakery: - tables: - users: - allow: - id: '*' - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - databases: - bakery: - tables: - users: - allow: - id: '*' - - -.. tab:: datasette.json - - .. code-block:: json - - { - "databases": { - "bakery": { - "tables": { - "users": { - "allow": { - "id": "*" - } - } - } - } - } - } -.. [[[end]]] - -This works for SQL views as well - you can list their names in the ``"tables"`` block above in the same way as regular tables. - -.. warning:: - Restricting access to tables and views in this way will NOT prevent users from querying them using arbitrary SQL queries, `like this `__ for example. - - If you are restricting access to specific tables you should also use the ``"allow_sql"`` block to prevent users from bypassing the limit with their own SQL queries - see :ref:`authentication_permissions_execute_sql`. - -.. _authentication_permissions_query: - -Access to specific queries --------------------------- - -:ref:`Queries ` allow you to configure named SQL queries in your ``datasette.yaml`` that can be executed by users. These queries can be set up to both read and write to the database, so controlling who can execute them can be important. - -To limit access to the ``add_name`` query in your ``dogs.db`` database to just the :ref:`root user`: - -.. [[[cog - config_example(cog, """ - databases: - dogs: - queries: - add_name: - sql: INSERT INTO names (name) VALUES (:name) - write: true - allow: - id: - - root - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - databases: - dogs: - queries: - add_name: - sql: INSERT INTO names (name) VALUES (:name) - write: true - allow: - id: - - root - - -.. tab:: datasette.json - - .. code-block:: json - - { - "databases": { - "dogs": { - "queries": { - "add_name": { - "sql": "INSERT INTO names (name) VALUES (:name)", - "write": true, - "allow": { - "id": [ - "root" - ] - } - } - } - } - } - } -.. [[[end]]] - -.. _authentication_permissions_execute_sql: - -Controlling the ability to execute arbitrary SQL ------------------------------------------------- - -Datasette defaults to allowing any site visitor to execute their own custom SQL queries, for example using the form on `the database page `__ or by appending a ``?_where=`` parameter to the table page `like this `__. - -Access to this ability is controlled by the :ref:`actions_execute_sql` permission. - -The easiest way to disable arbitrary SQL queries is using the :ref:`default_allow_sql setting ` when you first start Datasette running. - -You can alternatively use an ``"allow_sql"`` block to control who is allowed to execute arbitrary SQL queries. - -To prevent any user from executing arbitrary SQL queries, use this: - -.. [[[cog - config_example(cog, """ - allow_sql: false - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - allow_sql: false - - -.. tab:: datasette.json - - .. code-block:: json - - { - "allow_sql": false - } -.. [[[end]]] - -To enable just the :ref:`root user` to execute SQL for all databases in your instance, use the following: - -.. [[[cog - config_example(cog, """ - allow_sql: - id: root - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - allow_sql: - id: root - - -.. tab:: datasette.json - - .. code-block:: json - - { - "allow_sql": { - "id": "root" - } - } -.. [[[end]]] - -To limit this ability for just one specific database, use this: - -.. [[[cog - config_example(cog, """ - databases: - mydatabase: - allow_sql: - id: root - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - databases: - mydatabase: - allow_sql: - id: root - - -.. tab:: datasette.json - - .. code-block:: json - - { - "databases": { - "mydatabase": { - "allow_sql": { - "id": "root" - } - } - } - } -.. [[[end]]] - -.. _authentication_permissions_other: - -Other permissions in ``datasette.yaml`` -======================================= - -For all other permissions, you can use one or more ``"permissions"`` blocks in your ``datasette.yaml`` configuration file. - -To grant access to the :ref:`permissions debug tool ` to all signed in users, you can grant ``permissions-debug`` to any actor with an ``id`` matching the wildcard ``*`` by adding this a the root of your configuration: - -.. [[[cog - config_example(cog, """ - permissions: - debug-menu: - id: '*' - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - permissions: - debug-menu: - id: '*' - - -.. tab:: datasette.json - - .. code-block:: json - - { - "permissions": { - "debug-menu": { - "id": "*" - } - } - } -.. [[[end]]] - -To grant ``create-table`` to the user with ``id`` of ``editor`` for the ``docs`` database: - -.. [[[cog - config_example(cog, """ - databases: - docs: - permissions: - create-table: - id: editor - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - databases: - docs: - permissions: - create-table: - id: editor - - -.. tab:: datasette.json - - .. code-block:: json - - { - "databases": { - "docs": { - "permissions": { - "create-table": { - "id": "editor" - } - } - } - } - } -.. [[[end]]] - -Other table-scoped write permissions, including ``set-column-type``, can be configured in the same place. - -And for ``insert-row`` against the ``reports`` table in that ``docs`` database: - -.. [[[cog - config_example(cog, """ - databases: - docs: - tables: - reports: - permissions: - insert-row: - id: editor - """) -.. ]]] - -.. tab:: datasette.yaml - - .. code-block:: yaml - - - databases: - docs: - tables: - reports: - permissions: - insert-row: - id: editor - - -.. tab:: datasette.json - - .. code-block:: json - - { - "databases": { - "docs": { - "tables": { - "reports": { - "permissions": { - "insert-row": { - "id": "editor" - } - } - } - } - } - } - } -.. [[[end]]] - -The :ref:`permissions debug tool ` can be useful for helping test permissions that you have configured in this way. - -.. _CreateTokenView: - -API Tokens -========== - -Datasette includes a default mechanism for generating API tokens that can be used to authenticate requests. - -Authenticated users can create new API tokens using a form on the ``/-/create-token`` page. - -Tokens created in this way can be further restricted to only allow access to specific actions, or to limit those actions to specific databases, tables or queries. - -Created tokens can then be passed in the ``Authorization: Bearer $token`` header of HTTP requests to Datasette. - -A token created by a user will include that user's ``"id"`` in the token payload, so any permissions granted to that user based on their ID can be made available to the token as well. - -When one of these a token accompanies a request, the actor for that request will have the following shape: - -.. code-block:: json - - { - "id": "user_id", - "token": "dstok", - "token_expires": 1667717426 - } - -The ``"id"`` field duplicates the ID of the actor who first created the token. - -The ``"token"`` field identifies that this actor was authenticated using a Datasette signed token (``dstok``). - -The ``"token_expires"`` field, if present, indicates that the token will expire after that integer timestamp. - -The ``/-/create-token`` page cannot be accessed by actors that are authenticated with a ``"token": "some-value"`` property. This is to prevent API tokens from being used to create more tokens. - -Datasette plugins that implement their own form of API token authentication should follow this convention. - -You can disable the signed token feature entirely using the :ref:`allow_signed_tokens ` setting. - -.. _authentication_cli_create_token: - -datasette create-token ----------------------- - -You can also create tokens on the command line using the ``datasette create-token`` command. - -This command takes one required argument - the ID of the actor to be associated with the created token. - -You can specify a ``-e/--expires-after`` option in seconds. If omitted, the token will never expire. - -The command will sign the token using the ``DATASETTE_SECRET`` environment variable, if available. You can also pass the secret using the ``--secret`` option. - -This means you can run the command locally to create tokens for use with a deployed Datasette instance, provided you know that instance's secret. - -To create a token for the ``root`` actor that will expire in one hour:: - - datasette create-token root --expires-after 3600 - -To create a token that never expires using a specific secret:: - - datasette create-token root --secret my-secret-goes-here - -.. _authentication_cli_create_token_restrict: - -Restricting the actions that a token can perform -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Tokens created using ``datasette create-token ACTOR_ID`` will inherit all of the permissions of the actor that they are associated with. - -You can pass additional options to create tokens that are restricted to a subset of that actor's permissions. - -To restrict the token to just specific permissions against all available databases, use the ``--all`` option:: - - datasette create-token root --all insert-row --all update-row - -This option can be passed as many times as you like. In the above example the token will only be allowed to insert and update rows. - -You can also restrict permissions such that they can only be used within specific databases:: - - datasette create-token root --database mydatabase insert-row - -The resulting token will only be able to insert rows, and only to tables in the ``mydatabase`` database. - -Finally, you can restrict permissions to individual resources - tables, SQL views and :ref:`named queries ` - within a specific database:: - - datasette create-token root --resource mydatabase mytable insert-row - -These options have short versions: ``-a`` for ``--all``, ``-d`` for ``--database`` and ``-r`` for ``--resource``. - -You can add ``--debug`` to see a JSON representation of the token that has been created. Here's a full example:: - - datasette create-token root \ - --secret mysecret \ - --all view-instance \ - --all view-table \ - --database docs view-query \ - --resource docs documents insert-row \ - --resource docs documents update-row \ - --debug - -This example outputs the following:: - - dstok_.eJxFizEKgDAMRe_y5w4qYrFXERGxDkVsMI0uxbubdjFL8l_ez1jhwEQCA6Fjjxp90qtkuHawzdjYrh8MFobLxZ_wBH0_gtnAF-hpS5VfmF8D_lnd97lHqUJgLd6sls4H1qwlhA.nH_7RecYHj5qSzvjhMU95iy0Xlc - - Decoded: - - { - "a": "root", - "token": "dstok", - "t": 1670907246, - "_r": { - "a": [ - "vi", - "vt" - ], - "d": { - "docs": [ - "vq" - ] - }, - "r": { - "docs": { - "documents": [ - "ir", - "ur" - ] - } - } - } - } - -Restrictions act as an allowlist layered on top of the actor's existing -permissions. They can only remove access the actor would otherwise have—they -cannot grant new access. If the underlying actor is denied by ``allow`` rules in -``datasette.yaml`` or by a plugin, a token that lists that resource in its -``"_r"`` section will still be denied. - -To create tokens with restrictions in Python code, use the :ref:`TokenRestrictions ` builder and pass it to :ref:`datasette.create_token() `. - -.. _permissions_plugins: - -Checking permissions in plugins -=============================== - -Datasette plugins can check if an actor has permission to perform an action using :ref:`datasette_allowed`—for example:: - - from datasette.resources import TableResource - - can_edit = await datasette.allowed( - action="update-row", - resource=TableResource(database="fixtures", table="facetable"), - actor=request.actor, - ) - -Use :ref:`datasette_ensure_permission` when you need to enforce a permission and -raise a ``Forbidden`` error automatically. - -Plugins that define new operations should return :class:`~datasette.permissions.Action` -objects from :ref:`plugin_register_actions` and can supply additional allow/deny -rules by returning :class:`~datasette.permissions.PermissionSQL` objects from the -:ref:`plugin_hook_permission_resources_sql` hook. Those rules are merged with -configuration ``allow`` blocks and actor restrictions to determine the final -result for each check. - -.. _authentication_actor_matches_allow: - -actor_matches_allow() -===================== - -Plugins that wish to implement this same ``"allow"`` block permissions scheme can take advantage of the ``datasette.utils.actor_matches_allow(actor, allow)`` function: - -.. code-block:: python - - from datasette.utils import actor_matches_allow - - actor_matches_allow({"id": "root"}, {"id": "*"}) - # returns True - -The currently authenticated actor is made available to plugins as ``request.actor``. - -.. _PermissionsDebugView: - -Permissions debug tools -======================= - -The debug tool at ``/-/permissions`` is available to any actor with the ``permissions-debug`` permission. By default this is just the :ref:`authenticated root user ` but you can open it up to all users by starting Datasette like this:: - - datasette -s permissions.permissions-debug true data.db - -The page shows the permission checks that have been carried out by the Datasette instance. - -It also provides an interface for running hypothetical permission checks against a hypothetical actor. This is a useful way of confirming that your configured permissions work in the way you expect. - -This is designed to help administrators and plugin authors understand exactly how permission checks are being carried out, in order to effectively configure Datasette's permission system. - -.. _AllowedResourcesView: - -Allowed resources view ----------------------- - -The ``/-/allowed`` endpoint displays resources that the current actor can access for a specified ``action``. - -This endpoint provides an interactive HTML form interface. Add ``.json`` to the URL path (e.g. ``/-/allowed.json``) to get the raw JSON response instead. - -Pass ``?action=view-table`` (or another action) to select the action. Optional ``parent=`` and ``child=`` query parameters can narrow the results to a specific database/table pair. - -This endpoint is publicly accessible to help users understand their own permissions. The potentially sensitive ``reason`` field is only shown to users with the ``permissions-debug`` permission - it shows the plugins and explanatory reasons that were responsible for each decision. - -.. _PermissionRulesView: - -Permission rules view ---------------------- - -The ``/-/rules`` endpoint displays all permission rules (both allow and deny) for each candidate resource for the requested action. - -This endpoint provides an interactive HTML form interface. Add ``.json`` to the URL path (e.g. ``/-/rules.json?action=view-table``) to get the raw JSON response instead. - -Pass ``?action=`` as a query parameter to specify which action to check. - -This endpoint requires the ``permissions-debug`` permission. - -.. _PermissionCheckView: - -Permission check view ---------------------- - -The ``/-/check`` endpoint evaluates a single action/resource pair and returns information indicating whether the access was allowed along with diagnostic information. - -This endpoint provides an interactive HTML form interface. Add ``.json`` to the URL path (e.g. ``/-/check.json?action=view-instance``) to get the raw JSON response instead. - -Pass ``?action=`` to specify the action to check, and optional ``?parent=`` and ``?child=`` parameters to specify the resource. - -.. _authentication_ds_actor: - -The ds_actor cookie -=================== - -Datasette includes a default authentication plugin which looks for a signed ``ds_actor`` cookie containing a JSON actor dictionary. This is how the :ref:`root actor ` mechanism works. - -Authentication plugins can set signed ``ds_actor`` cookies themselves like so: - -.. code-block:: python - - response = Response.redirect("/") - datasette.set_actor_cookie(response, {"id": "cleopaws"}) - -The shape of data encoded in the cookie is as follows: - -.. code-block:: json - - { - "a": { - "id": "cleopaws" - } - } - -To implement logout in a plugin, use the ``delete_actor_cookie()`` method: - -.. code-block:: python - - response = Response.redirect("/") - datasette.delete_actor_cookie(response) - -.. _authentication_ds_actor_expiry: - -Including an expiry time ------------------------- - -``ds_actor`` cookies can optionally include a signed expiry timestamp, after which the cookies will no longer be valid. Authentication plugins may chose to use this mechanism to limit the lifetime of the cookie. For example, if a plugin implements single-sign-on against another source it may decide to set short-lived cookies so that if the user is removed from the SSO system their existing Datasette cookies will stop working shortly afterwards. - -To include an expiry pass ``expire_after=`` to ``datasette.set_actor_cookie()`` with a number of seconds. For example, to expire in 24 hours: - -.. code-block:: python - - response = Response.redirect("/") - datasette.set_actor_cookie( - response, {"id": "cleopaws"}, expire_after=60 * 60 * 24 - ) - -The resulting cookie will encode data that looks something like this: - -.. code-block:: json - - { - "a": { - "id": "cleopaws" - }, - "e": "1jjSji" - } - -.. _LogoutView: - -The /-/logout page ------------------- - -The page at ``/-/logout`` provides the ability to log out of a ``ds_actor`` cookie authentication session. - -.. _actions: - -Built-in actions -================ - -This section lists all of the permission checks that are carried out by Datasette core, along with the ``resource`` if it was passed. - -.. _actions_view_instance: - -view-instance -------------- - -Top level permission - Actor is allowed to view any pages within this instance, starting at https://latest.datasette.io/ - -.. _actions_view_database: - -view-database -------------- - -Actor is allowed to view a database page, e.g. https://latest.datasette.io/fixtures - -``resource`` - ``datasette.permissions.DatabaseResource(database)`` - ``database`` is the name of the database (string) - -.. _actions_view_database_download: - -view-database-download ----------------------- - -Actor is allowed to download a database, e.g. https://latest.datasette.io/fixtures.db - -``resource`` - ``datasette.resources.DatabaseResource(database)`` - ``database`` is the name of the database (string) - -.. _actions_view_table: - -view-table ----------- - -Actor is allowed to view a table (or view) page, e.g. https://latest.datasette.io/fixtures/complex_foreign_keys - -``resource`` - ``datasette.resources.TableResource(database, table)`` - ``database`` is the name of the database (string) - - ``table`` is the name of the table (string) - -.. _actions_view_query: - -view-query ----------- - -Actor is allowed to view a stored query page, e.g. https://latest.datasette.io/fixtures/pragma_cache_size. Executing an untrusted stored query also requires ``execute-sql`` or the relevant write permissions; :ref:`trusted stored queries ` can execute with ``view-query`` alone. - -``resource`` - ``datasette.resources.QueryResource(database, query)`` - ``database`` is the name of the database (string) - - ``query`` is the name of the query (string) - -.. _actions_store_query: - -store-query ------------ - -Actor is allowed to create stored queries against a database. - -``resource`` - ``datasette.resources.DatabaseResource(database)`` - ``database`` is the name of the database (string) - -.. _actions_update_query: - -update-query ------------- - -Actor is allowed to update a stored query. - -``resource`` - ``datasette.resources.QueryResource(database, query)`` - ``database`` is the name of the database (string) - - ``query`` is the name of the query (string) - -.. _actions_delete_query: - -delete-query ------------- - -Actor is allowed to delete a stored query. - -``resource`` - ``datasette.resources.QueryResource(database, query)`` - ``database`` is the name of the database (string) - - ``query`` is the name of the query (string) - -.. _actions_insert_row: - -insert-row ----------- - -Actor is allowed to insert rows into a table. - -``resource`` - ``datasette.resources.TableResource(database, table)`` - ``database`` is the name of the database (string) - - ``table`` is the name of the table (string) - -.. _actions_delete_row: - -delete-row ----------- - -Actor is allowed to delete rows from a table. - -``resource`` - ``datasette.resources.TableResource(database, table)`` - ``database`` is the name of the database (string) - - ``table`` is the name of the table (string) - -.. _actions_update_row: - -update-row ----------- - -Actor is allowed to update rows in a table. - -``resource`` - ``datasette.resources.TableResource(database, table)`` - ``database`` is the name of the database (string) - - ``table`` is the name of the table (string) - -.. _actions_create_table: - -create-table ------------- - -Actor is allowed to create a database table. - -``resource`` - ``datasette.resources.DatabaseResource(database)`` - ``database`` is the name of the database (string) - -.. _actions_alter_table: - -alter-table ------------ - -Actor is allowed to alter a database table. - -``resource`` - ``datasette.resources.TableResource(database, table)`` - ``database`` is the name of the database (string) - - ``table`` is the name of the table (string) - -.. _actions_set_column_type: - -set-column-type ---------------- - -Actor is allowed to set assigned :ref:`column types ` for columns in a table. - -``resource`` - ``datasette.resources.TableResource(database, table)`` - ``database`` is the name of the database (string) - - ``table`` is the name of the table (string) - -.. _actions_drop_table: - -drop-table ----------- - -Actor is allowed to drop a database table. - -``resource`` - ``datasette.resources.TableResource(database, table)`` - ``database`` is the name of the database (string) - - ``table`` is the name of the table (string) - -.. _actions_execute_sql: - -execute-sql ------------ - -Actor is allowed to run arbitrary read-only SQL queries against a specific database, e.g. https://latest.datasette.io/fixtures/-/query?sql=select+100 - -``resource`` - ``datasette.resources.DatabaseResource(database)`` - ``database`` is the name of the database (string) - -See also :ref:`the default_allow_sql setting `. - -.. _actions_execute_write_sql: - -execute-write-sql ------------------ - -Actor is allowed to run arbitrary writable SQL queries against a specific database, subject to table-level write permissions such as ``insert-row``, ``update-row`` and ``delete-row``. - -``resource`` - ``datasette.resources.DatabaseResource(database)`` - ``database`` is the name of the database (string) - -.. _actions_permissions_debug: - -permissions-debug ------------------ - -Actor is allowed to view the ``/-/permissions`` debug tools. - -.. _actions_debug_menu: - -debug-menu ----------- - -Controls if the various debug pages are displayed in the jump menu. diff --git a/docs/binary_data.rst b/docs/binary_data.rst deleted file mode 100644 index 0c890fe5..00000000 --- a/docs/binary_data.rst +++ /dev/null @@ -1,68 +0,0 @@ -.. _binary: - -============= - Binary data -============= - -SQLite tables can contain binary data in ``BLOB`` columns. - -Datasette includes special handling for these binary values. The Datasette interface detects binary values and provides a link to download their content, for example on https://latest.datasette.io/fixtures/binary_data - -.. image:: https://raw.githubusercontent.com/simonw/datasette-screenshots/0.62/binary-data.png - :width: 311px - :alt: Screenshot showing download links next to binary data in the table view - -Binary data is represented in ``.json`` exports using Base64 encoding. - -https://latest.datasette.io/fixtures/binary_data.json?_shape=array - -.. code-block:: json - - [ - { - "rowid": 1, - "data": { - "$base64": true, - "encoded": "FRwCx60F/g==" - } - }, - { - "rowid": 2, - "data": { - "$base64": true, - "encoded": "FRwDx60F/g==" - } - }, - { - "rowid": 3, - "data": null - } - ] - -.. _binary_linking: - -Linking to binary downloads ---------------------------- - -The ``.blob`` output format is used to return binary data. It requires a ``_blob_column=`` query string argument specifying which BLOB column should be downloaded, for example: - -https://latest.datasette.io/fixtures/binary_data/1.blob?_blob_column=data - -This output format can also be used to return binary data from an arbitrary SQL query. Since such queries do not specify an exact row, an additional ``?_blob_hash=`` parameter can be used to specify the SHA-256 hash of the value that is being linked to. - -Consider the query ``select data from binary_data`` - `demonstrated here `__. - -That page links to the binary value downloads. Those links look like this: - -https://latest.datasette.io/fixtures.blob?sql=select+data+from+binary_data&_blob_column=data&_blob_hash=f3088978da8f9aea479ffc7f631370b968d2e855eeb172bea7f6c7a04262bb6d - -These ``.blob`` links are also returned in the ``.csv`` exports Datasette provides for binary tables and queries, since the CSV format does not have a mechanism for representing binary data. - -Binary plugins --------------- - -Several Datasette plugins are available that change the way Datasette treats binary data. - -- `datasette-render-binary `__ modifies Datasette's default interface to show an automatic guess at what type of binary data is being stored, along with a visual representation of the binary value that displays ASCII strings directly in the interface. -- `datasette-render-images `__ detects common image formats and renders them as images directly in the Datasette interface. -- `datasette-media `__ allows Datasette interfaces to be configured to serve binary files from configured SQL queries, and includes the ability to resize images directly before serving them. diff --git a/docs/changelog.rst b/docs/changelog.rst index 2ba713ee..aa64a84e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,2429 +1,6 @@ -.. _changelog: - -========= Changelog ========= -.. _v1_0_unreleased: - -Unreleased ----------- - -Stored queries -~~~~~~~~~~~~~~ - -- The previous "canned queries" feature has been renamed and expanded into :ref:`stored queries `. Queries configured in ``datasette.yaml`` are now loaded into a new ``queries`` table in Datasette's :ref:`internal database `, alongside user-created stored queries. (:issue:`2735`) -- New stored query management APIs: ``datasette.add_query()``, ``datasette.update_query()``, ``datasette.remove_query()``, ``datasette.get_query()``, ``datasette.list_queries()`` and ``datasette.count_queries()``. These replace the removed ``datasette.get_canned_query()`` and ``datasette.get_canned_queries()`` methods. (:issue:`2735`) -- Users with :ref:`store-query ` and :ref:`execute-sql ` permission can create stored queries from the SQL query page or the new ``GET //-/queries/store`` form. (:issue:`2735`) -- The database page now shows a count and preview of stored queries, capped at five, and links to new paginated query browsers at ``/-/queries`` and ``//-/queries``. Those browsers support search. (:issue:`2735`) -- Stored queries created by users default to private and untrusted. Private stored queries can only be viewed, updated or deleted by their owner, even if another actor has broad ``view-query``, ``update-query`` or ``delete-query`` permission. Untrusted stored queries execute using the permissions of the actor running them. See :ref:`stored_queries` and :ref:`trusted_stored_queries` for details. (:issue:`2735`) -- New ``store-query``, ``update-query`` and ``delete-query`` permissions, plus updated semantics for :ref:`view-query `. Trusted stored queries can still execute with ``view-query`` alone; untrusted read queries also require :ref:`execute-sql ` and untrusted writable queries require :ref:`execute-write-sql ` plus the relevant table-level write permissions. (:issue:`2735`) - -Write SQL UI -~~~~~~~~~~~~ - -- New "Write to this database" interface at ``//-/execute-write`` for running arbitrary writable SQL against mutable databases. The form extracts named parameters, analyzes the SQL, shows the table operations that will be attempted and links to a newly inserted row when a single-row insert succeeds. (:issue:`2742`) -- Added the new :ref:`execute-write-sql ` permission for running arbitrary writable SQL. Execution is also gated by table-level permissions such as :ref:`insert-row `, :ref:`update-row ` and :ref:`delete-row `, and writes to attached databases are rejected. (:issue:`2742`) - -Plugin API changes -~~~~~~~~~~~~~~~~~~ - -- The ``top_canned_query()`` plugin hook has been renamed to :ref:`top_stored_query() `. (:issue:`2747`) -- The ``canned_queries()`` plugin hook has been removed. Plugins can use the new :ref:`stored query management methods ` together with :ref:`startup() ` to register queries. (:issue:`2735`) - -Bug fixes -~~~~~~~~~ - -- Fixed a bug where visiting ``//-/query`` without a ``?sql=`` parameter returned a 500 error. (:issue:`2743`) - -.. _v1_0_a30: - -1.0a30 (2026-05-24) -------------------- - -The "Jump to" menu, activated by hitting ``/`` or through the application menu, can now be extended by plugins. - -- New "Jump to..." menu item, always visible, for triggering the previously undocumented ``/`` menu. (:issue:`2725`) -- The ``/`` jump-to search interface now covers databases, views, canned queries and plugin-provided items in addition to tables. The endpoint backing it has been renamed from ``/-/tables`` to ``/-/jump``. -- New :ref:`plugin_hook_jump_items_sql` plugin hook, allowing plugins to contribute additional items to the jump-to menu by returning SQL. ``JumpSQL`` queries run against Datasette's internal database by default, or can target another database using the optional ``database=`` argument. (:issue:`2731`) -- ``datasette.jump.JumpSQL.menu_item()`` is a shortcut for adding individual jump menu items that are not backed by resources in the internal catalog. -- New :ref:`javascript_plugins_makeJumpSections` JavaScript plugin hook, allowing plugins to add custom blank-state sections to the jump-to menu before the user has typed a query. -- Debug menu links now appear in the jump-to menu instead of the top-right app menu, with descriptions for each debug item. -- Dropped Janus as a dependency, previously used to manage the write queue. This should not have any impact on plugin developers or end-users. (:issue:`1752`) -- Fixed a bug where stale tables and other related resources were not removed from ``catalog_*`` tables when a database was removed. (:issue:`2723`) -- New documented :ref:`datasette.fixtures.populate_fixture_database(conn) ` helper for creating the fixture database tables used by Datasette's own tests, intended for plugin test suites. -- Keyboard accessibility and ARIA roles for actions menus, thanks `pintaste `__. (:pr:`2727`) - -.. _v1_0_a29: - -1.0a29 (2026-05-12) -------------------- - -- New ``TokenRestrictions.abbreviated(datasette)`` :ref:`utility method ` for creating ``"_r"`` dictionaries. (:issue:`2695`) -- Table headers and column options are now visible even if a table contains zero rows. (:issue:`2701`) -- Fixed bug with display of column actions dialog on Mobile Safari. (:issue:`2708`) -- Fixed bug where tests could crash with a segfault due to a race condition between ``Datasette.close()`` and ``Datasette.close()``. (:issue:`2709`) - -.. _v1_0_a28: - -1.0a28 (2026-04-16) -------------------- - -- Fixed a compatibility bug introduced in 1.0a27 where ``execute_write_fn()`` callbacks with a parameter name other than ``conn`` were seeing errors. (:issue:`2691`) -- The :ref:`database.close() ` method now also shuts down the write connection for that database. -- New :ref:`datasette.close() ` method for closing down all databases and resources associated with a Datasette instance. This is called automatically when the server shuts down. (:pr:`2693`) -- Datasette now includes a pytest plugin which automatically calls ``datasette.close()`` on temporary instances created in function-scoped fixtures and during tests. See :ref:`testing_plugins_autoclose` for details. This helps avoid running out of file descriptors in plugin test suites that were written before the ``Database(is_temp_disk=True)`` feature introduced in Datasette 1.0a27. (:issue:`2692`) - -.. _v1_0_a27: - -1.0a27 (2026-04-15) -------------------- - -CSRF protection no longer uses CSRF tokens -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Datasette's token-based CSRF protection has been replaced with a mechanism based on the ``Sec-Fetch-Site`` and ``Origin`` request headers, which are `supported by all modern browsers `__. See `this article by Filippo Valsorda `__ for more details of this approach. This removes the need for CSRF tokens in forms and AJAX requests. (:pr:`2689`) - -``RenameTableEvent`` when a table is renamed -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Renaming a table within Datasette will now fire a new :class:`~datasette.events.RenameTableEvent`, which plugins can use to react by updating ACL records or re-assigning comments or other associated records to the new table name. (:issue:`2681`) - -This event will not be fired if the table is renamed by SQL running in some other process. - -The ``datasette.track_event()`` method can now be called from within a write operation (using :ref:`database.execute_write() ` and related methods) and the event will be fired after the write transaction has successfully committed. (:pr:`2682`) - -Other changes -~~~~~~~~~~~~~ - -- New :ref:`actor= parameter ` for ``datasette.client`` methods, allowing internal requests to be made as a specific actor. This is particularly useful for writing automated tests. (:pr:`2688`) -- New ``Database(is_temp_disk=True)`` option, used internally for the internal database. This helps resolve intermittent database locked errors caused by the internal database being in-memory as opposed to on-disk. (:issue:`2683`) (:pr:`2684`) -- The ``///-/upsert`` API (:ref:`docs `) now rejects rows with ``null`` primary key values. (:issue:`1936`) -- Improved example in the API explorer for the ``/-/upsert`` endpoint (:ref:`docs `). (:issue:`1936`) -- The ``/.json`` endpoint now includes an ``"ok": true`` key, for consistency with other JSON API responses. -- :ref:`call_with_supported_arguments() ` is now documented as a supported public API. (:pr:`2678`) - -.. _v1_0_a26: - -1.0a26 (2026-03-18) -------------------- - -New ``column_types`` system -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Table columns can now have custom column types assigned to them, using the new ``column_types`` table configuration option or at runtime using a new UI and ``POST //
/-/set-column-type`` JSON API. - -Built-in column types include ``url``, ``email``, and ``json``, and plugins can register additional types using the new :ref:`register_column_types() ` plugin hook. (:issue:`2664`, :issue:`2671`) - -Column types can customize HTML rendering, validate values written through the insert, update, and upsert APIs, and transform values returned by the JSON API. They can optionally restrict themselves to specific SQLite column types using ``sqlite_types``. This feature also introduces a new :ref:`set-column-type ` permission for assigning column types to a table. (:issue:`2672`) - -The :ref:`render_cell() ` plugin hook now receives a ``column_type`` argument containing the assigned type instance, and a column type's own ``render_cell()`` method takes priority over the plugin hook chain. - -The `datasette-files `__ plugin will be the first to use this new feature. - -UI for selecting columns and their order -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Table and view pages now include a dialog for selecting and re-ordering visible columns. (:issue:`2661`) - -Other changes -~~~~~~~~~~~~~ - -- Fixed ``allowed_resources("view-query", actor)`` so actor-specific canned queries are returned correctly. Any plugin that defines a ``resources_sql()`` method on a ``Resource`` subclass needs to update to the new signature, see :ref:`the resources_sql() method` documentation for details. -- Column actions can now be accessed in mobile view via a new "Column actions" button. Previously they were not available on mobile because table headers are not displayed there. (:issue:`2669`, :issue:`2670`) -- Row pages now render foreign key values as links to the referenced row. (:issue:`1592`) -- The ``startup()`` plugin hook now fires after metadata and internal schema tables have been populated, so plugins can reliably inspect that state during startup. (:issue:`2666`) - -.. _v1_0_a25: - -1.0a25 (2026-02-25) -------------------- - -``write_wrapper()`` plugin hook for intercepting write operations -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -A new :ref:`write_wrapper() ` plugin hook allows plugins to intercept and wrap database write operations. (:pr:`2636`) - -Plugins implement the hook as a generator-based context manager: - -.. code-block:: python - - @hookimpl - def write_wrapper(datasette, database, request): - def wrapper(conn): - # Setup code runs before the write - yield - # Cleanup code runs after the write - - return wrapper - -``register_token_handler()`` plugin hook for custom API token backends -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -A new :ref:`register_token_handler() ` plugin hook allows plugins to provide custom token backends for API authentication. (:pr:`2650`) - -This includes a **backwards incompatible change**: the ``datasette.create_token()`` internal method is now an ``async`` method. Consult the :ref:`upgrade guide ` for details on how to update your code. - -``render_cell()`` now receives a ``pks`` parameter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The :ref:`render_cell() ` plugin hook now receives a ``pks`` parameter containing the list of primary key column names for the table being rendered. This avoids plugins needing to make redundant async calls to look up primary keys. (:pr:`2641`) - -Other changes -~~~~~~~~~~~~~ - -- Facets defined in metadata now preserve their configured order, instead of being sorted by result count. Request-based facets added via the ``_facet`` parameter are still sorted by result count and appear after metadata-defined facets. (:issue:`2647`) -- Fixed ``--reload`` incorrectly interpreting the ``serve`` command as a file argument. Thanks, `Daniel Bates `__. (:pr:`2646`) - -.. _v1_0_a24: - -1.0a24 (2026-01-29) -------------------- - -``request.form()`` method for POST data and file uploads -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Datasette now includes a ``request.form()`` method for parsing form submissions, including handling file uploads. (:pr:`2626`) - -This supports both ``application/x-www-form-urlencoded`` and ``multipart/form-data`` content types, and uses a new streaming multipart parser that processes uploads without buffering entire request bodies in memory. - -.. code-block:: python - - # Parse form fields (files are discarded by default) - form = await request.form() - username = form["username"] - - # Parse form fields AND file uploads - form = await request.form(files=True) - uploaded = form["avatar"] - content = await uploaded.read() - -The returned :ref:`FormData ` object provides dictionary-style access with support for multiple values per key via ``form.getlist("key")``. Uploaded files are represented as :ref:`UploadedFile ` objects with ``filename``, ``content_type``, ``size`` properties and async ``read()`` and ``seek()`` methods. - -Files smaller than 1MB are held in memory; larger files automatically spill to temporary files on disk. Configurable limits control maximum file size, request size, field counts and more. - -Several internal views (permissions debug, messages debug, create token) now use ``request.form()`` instead of ``request.post_vars()``. - -``request.post_vars()`` remains available for backwards compatibility but is no longer the recommended API for handling POST data. - -``render_cell`` and ``foreign_key_tables`` extras for the JSON API -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The table JSON API now supports ``?_extra=render_cell``, which returns the rendered HTML for each cell as produced by the :ref:`render_cell plugin hook `. Only columns whose rendered output differs from the default are included. (:issue:`2619`) - -The row JSON API also gains ``?_extra=render_cell`` and ``?_extra=foreign_key_tables`` extras, bringing it closer to parity with the table API. - -The row JSON API now returns ``"ok": true`` in its response, for consistency with the table API. - -``uv run pytest`` with a ``dev=`` dependency group -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The recommended development environment for Datasette now uses `uv `__. You can now set up a development environment and run the test suite with just ``uv run pytest`` — no manual virtualenv or ``pip install`` step required. (:issue:`2611`) - -Other changes -~~~~~~~~~~~~~ - -- Plugins that raise ``datasette.utils.StartupError()`` during startup now display a clean error message instead of a full traceback. (:issue:`2624`) -- Schema refreshes are now throttled to at most once per second, providing a small performance increase. (:issue:`2629`) -- Minor performance improvement to ``remove_infinites`` — rows without infinity values now skip the list/dict reconstruction step. (:issue:`2629`) -- Filter inputs and the search input no longer trigger unwanted zoom on iOS Safari. Thanks, `Daniel Olasubomi Sobowale `__. (:issue:`2346`) -- ``table_names()`` and ``get_all_foreign_keys()`` now return results in deterministic sorted order. (:issue:`2628`) -- Switched linting to `ruff `__ and fixed all lint errors. (:issue:`2630`) - -.. _v1_0_a23: - -1.0a23 (2025-12-02) -------------------- - -- Fix for bug where a stale database entry in ``internal.db`` could cause a 500 error on the homepage. (:issue:`2605`) -- Cosmetic improvement to ``/-/actions`` page. (:issue:`2599`) - -.. _v1_0_a22: - -1.0a22 (2025-11-13) -------------------- - -- ``datasette serve --default-deny`` option for running Datasette configured to :ref:`deny all permissions by default `. (:issue:`2592`) -- ``datasette.is_client()`` method for detecting if code is :ref:`executing inside a datasette.client request `. (:issue:`2594`) -- ``datasette.pm`` property can now be used to :ref:`register and unregister plugins in tests `. (:issue:`2595`) - -.. _v1_0_a21: - -1.0a21 (2025-11-05) -------------------- - -- Fixes an **open redirect** security issue: Datasette instances would redirect to ``example.com/foo/bar`` if you accessed the path ``//example.com/foo/bar``. Thanks to `James Jefferies `__ for the fix. (:issue:`2429`) -- Fixed ``datasette publish cloudrun`` to work with changes to the underlying Cloud Run architecture. (:issue:`2511`) -- New ``datasette --get /path --headers`` option for inspecting the headers returned by a path. (:issue:`2578`) -- New ``datasette.client.get(..., skip_permission_checks=True)`` parameter to bypass permission checks when making requests using the internal client. (:issue:`2583`) - -.. _v0_65_2: - -0.65.2 (2025-11-05) -------------------- - -- Fixes an **open redirect** security issue: Datasette instances would redirect to ``example.com/foo/bar`` if you accessed the path ``//example.com/foo/bar``. Thanks to `James Jefferies `__ for the fix. (:issue:`2429`) -- Upgraded for compatibility with Python 3.14. -- Fixed ``datasette publish cloudrun`` to work with changes to the underlying Cloud Run architecture. (:issue:`2511`) -- Minor upgrades to fix warnings, including ``pkg_resources`` deprecation. - -.. _v1_0_a20: - -1.0a20 (2025-11-03) -------------------- - -This alpha introduces a major breaking change prior to the 1.0 release of Datasette concerning how Datasette's permission system works. - -Permission system redesign -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Previously the permission system worked using ``datasette.permission_allowed()`` checks which consulted all available plugins in turn to determine whether a given actor was allowed to perform a given action on a given resource. - -This approach could become prohibitively expensive for large lists of items - for example to determine the list of tables that a user could view in a large Datasette instance each plugin implementation of that hook would be fired for every table. - -The new design uses SQL queries against Datasette's internal :ref:`catalog tables ` to derive the list of resources for which an actor has permission for a given action. This turns an N x M problem (N resources, M plugins) into a single SQL query. - -Plugins can use the new :ref:`plugin_hook_permission_resources_sql` hook to return SQL fragments which will be used as part of that query. - -Plugins that use any of the following features will need to be updated to work with this and following alphas (and Datasette 1.0 stable itself): - -- Checking permissions with ``datasette.permission_allowed()`` - this method has been replaced with :ref:`datasette.allowed() `. -- Implementing the ``permission_allowed()`` plugin hook - this hook has been removed in favor of :ref:`permission_resources_sql() `. -- Using ``register_permissions()`` to register permissions - this hook has been removed in favor of :ref:`register_actions() `. - -Consult the :ref:`v1.0a20 upgrade guide ` for further details on how to upgrade affected plugins. - -Plugins can now make use of two new internal methods to help resolve permission checks: - -- :ref:`datasette.allowed_resources() ` returns a ``PaginatedResources`` object with a ``.resources`` list of ``Resource`` instances that an actor is allowed to access for a given action (and a ``.next`` token for pagination). -- :ref:`datasette.allowed_resources_sql() ` returns the SQL and parameters that can be executed against the internal catalog tables to determine which resources an actor is allowed to access for a given action. This can be combined with further SQL to perform advanced custom filtering. - -Related changes: - -- The way ``datasette --root`` works has changed. Running Datasette with this flag now causes the root actor to pass *all* permission checks. (:issue:`2521`) - -- Permission debugging improvements: - - - The ``/-/allowed`` endpoint shows resources the user is allowed to interact with for different actions. - - ``/-/rules`` shows the raw allow/deny rules that apply to different permission checks. - - ``/-/actions`` lists every available action. - - ``/-/check`` can be used to try out different permission checks for the current actor. - -Other changes -~~~~~~~~~~~~~ - -- The internal ``catalog_views`` table now tracks SQLite views alongside tables in the introspection database. (:issue:`2495`) -- Hitting the ``/`` brings up a search interface for navigating to tables that the current user can view. A new ``/-/tables`` endpoint supports this functionality. (:issue:`2523`) -- Datasette attempts to detect some configuration errors on startup. -- Datasette now supports Python 3.14 and no longer tests against Python 3.9. - -.. _v1_0_a19: - -1.0a19 (2025-04-21) -------------------- - -- Tiny cosmetic bug fix for mobile display of table rows. (:issue:`2479`) - -.. _v1_0_a18: - -1.0a18 (2025-04-16) -------------------- - -- Fix for incorrect foreign key references in the internal database schema. (:issue:`2466`) -- The ``prepare_connection()`` hook no longer runs for the internal database. (:issue:`2468`) -- Fixed bug where ``link:`` HTTP headers used invalid syntax. (:issue:`2470`) -- No longer tested against Python 3.8. Now tests against Python 3.13. -- FTS tables are now hidden by default if they correspond to a content table. (:issue:`2477`) -- Fixed bug with foreign key links to rows in databases with filenames containing a special character. Thanks, `Jack Stratton `__. (:pr:`2476`) - -.. _v1_0_a17: - -1.0a17 (2025-02-06) -------------------- - -- ``DATASETTE_SSL_KEYFILE`` and ``DATASETTE_SSL_CERTFILE`` environment variables as alternatives to ``--ssl-keyfile`` and ``--ssl-certfile``. Thanks, Alex Garcia. (:issue:`2422`) -- ``SQLITE_EXTENSIONS`` environment variable has been renamed to ``DATASETTE_LOAD_EXTENSION``. (:issue:`2424`) -- ``datasette serve`` environment variables are now :ref:`documented here `. -- The :ref:`plugin_hook_register_magic_parameters` plugin hook can now register async functions. (:issue:`2441`) -- Datasette is now tested against Python 3.13. -- Breadcrumbs on database and table pages now include a consistent self-link for resetting query string parameters. (:issue:`2454`) -- Fixed issue where Datasette could crash on ``metadata.json`` with nested values. (:issue:`2455`) -- New internal methods ``datasette.set_actor_cookie()`` and ``datasette.delete_actor_cookie()``, :ref:`described here `. (:issue:`1690`) -- ``/-/permissions`` page now shows a list of all permissions registered by plugins. (:issue:`1943`) -- If a table has a single unique text column Datasette now detects that as the foreign key label for that table. (:issue:`2458`) -- The ``/-/permissions`` page now includes options for filtering or exclude permission checks recorded against the current user. (:issue:`2460`) -- Fixed a bug where replacing a database with a new one with the same name did not pick up the new database correctly. (:issue:`2465`) - -.. _v0_65_1: - -0.65.1 (2024-11-28) -------------------- - -- Fixed bug with upgraded HTTPX 0.28.0 dependency. (:issue:`2443`) - -.. _v0_65: - -0.65 (2024-10-07) ------------------ - -- Upgrade for compatibility with Python 3.13 (by vendoring Pint dependency). (:issue:`2434`) -- Dropped support for Python 3.8. - -.. _v1_0_a16: - -1.0a16 (2024-09-05) -------------------- - -This release focuses on performance, in particular against large tables, and introduces some minor breaking changes for CSS styling in Datasette plugins. - -- Removed the unit conversions feature and its dependency, Pint. This means Datasette is now compatible with the upcoming Python 3.13. (:issue:`2400`, :issue:`2320`) -- The ``datasette --pdb`` option now uses the `ipdb `__ debugger if it is installed. You can install it using ``datasette install ipdb``. Thanks, `Tiago Ilieve `__. (:pr:`2342`) -- Fixed a confusing error that occurred if ``metadata.json`` contained nested objects. (:issue:`2403`) -- Fixed a bug with ``?_trace=1`` where it returned a blank page if the response was larger than 256KB. (:issue:`2404`) -- Tracing mechanism now also displays SQL queries that returned errors or ran out of time. `datasette-pretty-traces 0.5 `__ includes support for displaying this new type of trace. (:issue:`2405`) -- Fixed a text spacing with table descriptions on the homepage. (:issue:`2399`) -- Performance improvements for large tables: - - Suggested facets now only consider the first 1000 rows. (:issue:`2406`) - - Improved performance of date facet suggestion against large tables. (:issue:`2407`) - - Row counts stop at 10,000 rows when listing tables. (:issue:`2398`) - - On table page the count stops at 10,000 rows too, with a "count all" button to execute the full count. (:issue:`2408`) -- New ``.dicts()`` internal method on :ref:`database_results` that returns a list of dictionaries representing the results from a SQL query: (:issue:`2414`) - - .. code-block:: bash - - rows = (await db.execute("select * from t")).dicts() - -- Default Datasette core CSS that styles inputs and buttons now requires a class of ``"core"`` on the element or a containing element, for example ``
``. (:issue:`2415`) -- Similarly, default table styles now only apply to ``
``. (:issue:`2420`) - -.. _v1_0_a15: - -1.0a15 (2024-08-15) -------------------- - -- Datasette now defaults to hiding SQLite "shadow" tables, as seen in extensions such as SQLite FTS and `sqlite-vec `__. Virtual tables that it makes sense to display, such as FTS core tables, are no longer hidden. Thanks, `Alex Garcia `__. (:issue:`2296`) -- Fixed bug where running Datasette with one or more ``-s/--setting`` options could over-ride settings that were present in ``datasette.yml``. (:issue:`2389`) -- The Datasette homepage is now duplicated at ``/-/``, using the default ``index.html`` template. This ensures that the information on that page is still accessible even if the Datasette homepage has been customized using a custom ``index.html`` template, for example on sites like `datasette.io `__. (:issue:`2393`) -- Failed CSRF checks now display a more user-friendly error page. (:issue:`2390`) -- Fixed a bug where the ``json1`` extension was not correctly detected on the ``/-/versions`` page. Thanks, `Seb Bacon `__. (:issue:`2326`) -- Fixed a bug where the Datasette write API did not correctly accept ``Content-Type: application/json; charset=utf-8``. (:issue:`2384`) -- Fixed a bug where Datasette would fail to start if ``metadata.yml`` contained a ``queries`` block. (:pr:`2386`) - -.. _v1_0_a14: - -1.0a14 (2024-08-05) -------------------- - -This alpha introduces significant changes to Datasette's :ref:`metadata` system, some of which represent breaking changes in advance of the full 1.0 release. The new :ref:`upgrade_guide` document provides detailed coverage of those breaking changes and how they affect plugin authors and Datasette API consumers. - -- The ``/databasename?sql=`` interface and JSON API for executing arbitrary SQL queries can now be found at ``/databasename/-/query?sql=``. Requests with a ``?sql=`` parameter to the old endpoints will be redirected. Thanks, `Alex Garcia `__. (:issue:`2360`) -- Metadata about tables, databases, instances and columns is now stored in :ref:`internals_internal`. Thanks, Alex Garcia. (:issue:`2341`) -- Database write connections now execute using the ``IMMEDIATE`` isolation level for SQLite. This should help avoid a rare ``SQLITE_BUSY`` error that could occur when a transaction upgraded to a write mid-flight. (:issue:`2358`) -- Fix for a bug where canned queries with named parameters could fail against SQLite 3.46. (:issue:`2353`) -- Datasette now serves ``E-Tag`` headers for static files. Thanks, `Agustin Bacigalup `__. (:pr:`2306`) -- Dropdown menus now use a ``z-index`` that should avoid them being hidden by plugins. (:issue:`2311`) -- Incorrect table and row names are no longer reflected back on the resulting 404 page. (:issue:`2359`) -- Improved documentation for async usage of the :ref:`plugin_hook_track_event` hook. (:issue:`2319`) -- Fixed some HTTPX deprecation warnings. (:issue:`2307`) -- Datasette now serves a ```` attribute. Thanks, `Charles Nepote `__. (:issue:`2348`) -- Datasette's automated tests now run against the maximum and minimum supported versions of SQLite: 3.25 (from September 2018) and 3.46 (from May 2024). Thanks, Alex Garcia. (:pr:`2352`) -- Fixed an issue where clicking twice on the URL output by ``datasette --root`` produced a confusing error. (:issue:`2375`) - -.. _v0_64_8: - -0.64.8 (2024-06-21) -------------------- - -- Security improvement: 404 pages used to reflect content from the URL path, which could be used to display misleading information to Datasette users. 404 errors no longer display additional information from the URL. (:issue:`2359`) -- Backported a better fix for correctly extracting named parameters from canned query SQL against SQLite 3.46.0. (:issue:`2353`) - -.. _v0_64_7: - -0.64.7 (2024-06-12) -------------------- - -- Fixed a bug where canned queries with named parameters threw an error when run against SQLite 3.46.0. (:issue:`2353`) - -.. _v1_0_a13: - -1.0a13 (2024-03-12) -------------------- - -Each of the key concepts in Datasette now has an :ref:`actions menu `, which plugins can use to add additional functionality targeting that entity. - -- Plugin hook: :ref:`view_actions() ` for actions that can be applied to a SQL view. (:issue:`2297`) -- Plugin hook: :ref:`homepage_actions() ` for actions that apply to the instance homepage. (:issue:`2298`) -- Plugin hook: :ref:`row_actions() ` for actions that apply to the row page. (:issue:`2299`) -- Action menu items for all of the ``*_actions()`` plugin hooks can now return an optional ``"description"`` key, which will be displayed in the menu below the action label. (:issue:`2294`) -- :ref:`Plugin hooks ` documentation page is now organized with additional headings. (:issue:`2300`) -- Improved the display of action buttons on pages that also display metadata. (:issue:`2286`) -- The header and footer of the page now uses a subtle gradient effect, and options in the navigation menu are better visually defined. (:issue:`2302`) -- Table names that start with an underscore now default to hidden. (:issue:`2104`) -- ``pragma_table_list`` has been added to the allow-list of SQLite pragma functions supported by Datasette. ``select * from pragma_table_list()`` is no longer blocked. (`#2104 `__) - -.. _v1_0_a12: - -1.0a12 (2024-02-29) -------------------- - -- New :ref:`query_actions() ` plugin hook, similar to :ref:`table_actions() ` and :ref:`database_actions() `. Can be used to add a menu of actions to the canned query or arbitrary SQL query page. (:issue:`2283`) -- New design for the button that opens the query, table and database actions menu. (:issue:`2281`) -- "does not contain" table filter for finding rows that do not contain a string. (:issue:`2287`) -- Fixed a bug in the :ref:`javascript_plugins_makeColumnActions` JavaScript plugin mechanism where the column action menu was not fully reset in between each interaction. (:issue:`2289`) - -.. _v1_0_a11: - -1.0a11 (2024-02-19) -------------------- - -- The ``"replace": true`` argument to the ``/db/table/-/insert`` API now requires the actor to have the ``update-row`` permission. (:issue:`2279`) -- Fixed some UI bugs in the interactive permissions debugging tool. (:issue:`2278`) -- The column action menu now aligns better with the cog icon, and positions itself taking into account the width of the browser window. (:issue:`2263`) - -.. _v1_0_a10: - -1.0a10 (2024-02-17) -------------------- - -The only changes in this alpha correspond to the way Datasette handles database transactions. (:issue:`2277`) - -- The :ref:`database.execute_write_fn() ` method has a new ``transaction=True`` parameter. This defaults to ``True`` which means all functions executed using this method are now automatically wrapped in a transaction - previously the functions needed to roll transaction handling on their own, and many did not. -- Pass ``transaction=False`` to ``execute_write_fn()`` if you want to manually handle transactions in your function. -- Several internal Datasette features, including parts of the :ref:`JSON write API `, had been failing to wrap their operations in a transaction. This has been fixed by the new ``transaction=True`` default. - -.. _v1_0_a9: - -1.0a9 (2024-02-16) ------------------- - -This alpha release adds basic alter table support to the Datasette Write API and fixes a permissions bug relating to the ``/upsert`` API endpoint. - -Alter table support for create, insert, upsert and update -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The :ref:`JSON write API ` can now be used to apply simple alter table schema changes, provided the acting actor has the new :ref:`actions_alter_table` permission. (:issue:`2101`) - -The only alter operation supported so far is adding new columns to an existing table. - -* The :ref:`/db/-/create ` API now adds new columns during large operations to create a table based on incoming example ``"rows"``, in the case where one of the later rows includes columns that were not present in the earlier batches. This requires the ``create-table`` but not the ``alter-table`` permission. -* When ``/db/-/create`` is called with rows in a situation where the table may have been already created, an ``"alter": true`` key can be included to indicate that any missing columns from the new rows should be added to the table. This requires the ``alter-table`` permission. -* :ref:`/db/table/-/insert ` and :ref:`/db/table/-/upsert ` and :ref:`/db/table/row-pks/-/update ` all now also accept ``"alter": true``, depending on the ``alter-table`` permission. - -Operations that alter a table now fire the new :ref:`alter-table event `. - -Permissions fix for the upsert API -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The :ref:`/database/table/-/upsert API ` had a minor permissions bug, only affecting Datasette instances that had configured the ``insert-row`` and ``update-row`` permissions to apply to a specific table rather than the database or instance as a whole. Full details in issue :issue:`2262`. - -To avoid similar mistakes in the future the ``datasette.permission_allowed()`` method now specifies ``default=`` as a keyword-only argument. - -Permission checks now consider opinions from every plugin -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The ``datasette.permission_allowed()`` method previously consulted every plugin that implemented the ``permission_allowed()`` plugin hook and obeyed the opinion of the last plugin to return a value. (:issue:`2275`) - -Datasette now consults every plugin and checks to see if any of them returned ``False`` (the veto rule), and if none of them did, it then checks to see if any of them returned ``True``. - -This is explained at length in the new documentation covering :ref:`authentication_permissions_explained`. - -Other changes -~~~~~~~~~~~~~ - -- The new :ref:`DATASETTE_TRACE_PLUGINS=1 environment variable ` turns on detailed trace output for every executed plugin hook, useful for debugging and understanding how the plugin system works at a low level. (:issue:`2274`) -- Datasette on Python 3.9 or above marks its non-cryptographic uses of the MD5 hash function as ``usedforsecurity=False``, for compatibility with FIPS systems. (:issue:`2270`) -- SQL relating to :ref:`internals_internal` now executes inside a transaction, avoiding a potential database locked error. (:issue:`2273`) -- The ``/-/threads`` debug page now identifies the database in the name associated with each dedicated write thread. (:issue:`2265`) -- The ``/db/-/create`` API now fires a ``insert-rows`` event if rows were inserted after the table was created. (:issue:`2260`) - -.. _v1_0_a8: - -1.0a8 (2024-02-07) ------------------- - -This alpha release continues the migration of Datasette's configuration from ``metadata.yaml`` to the new ``datasette.yaml`` configuration file, introduces a new system for JavaScript plugins and adds several new plugin hooks. - -See `Datasette 1.0a8: JavaScript plugins, new plugin hooks and plugin configuration in datasette.yaml `__ for an annotated version of these release notes. - -Configuration -~~~~~~~~~~~~~ - -- Plugin configuration now lives in the :ref:`datasette.yaml configuration file `, passed to Datasette using the ``-c/--config`` option. Thanks, Alex Garcia. (:issue:`2093`) - - .. code-block:: bash - - datasette -c datasette.yaml - - Where ``datasette.yaml`` contains configuration that looks like this: - - .. code-block:: yaml - - plugins: - datasette-cluster-map: - latitude_column: xlat - longitude_column: xlon - - Previously plugins were configured in ``metadata.yaml``, which was confusing as plugin settings were unrelated to database and table metadata. -- The ``-s/--setting`` option can now be used to set plugin configuration as well. See :ref:`configuration_cli` for details. (:issue:`2252`) - - The above YAML configuration example using ``-s/--setting`` looks like this: - - .. code-block:: bash - - datasette mydatabase.db \ - -s plugins.datasette-cluster-map.latitude_column xlat \ - -s plugins.datasette-cluster-map.longitude_column xlon - -- The new ``/-/config`` page shows the current instance configuration, after redacting keys that could contain sensitive data such as API keys or passwords. (:issue:`2254`) - -- Existing Datasette installations may already have configuration set in ``metadata.yaml`` that should be migrated to ``datasette.yaml``. To avoid breaking these installations, Datasette will silently treat table configuration, plugin configuration and allow blocks in metadata as if they had been specified in configuration instead. (:issue:`2247`) (:issue:`2248`) (:issue:`2249`) - -Note that the ``datasette publish`` command has not yet been updated to accept a ``datasette.yaml`` configuration file. This will be addressed in :issue:`2195` but for the moment you can include those settings in ``metadata.yaml`` instead. - -JavaScript plugins -~~~~~~~~~~~~~~~~~~ - -Datasette now includes a :ref:`JavaScript plugins mechanism `, allowing JavaScript to customize Datasette in a way that can collaborate with other plugins. - -This provides two initial hooks, with more to come in the future: - -- :ref:`makeAboveTablePanelConfigs() ` can add additional panels to the top of the table page. -- :ref:`makeColumnActions() ` can add additional actions to the column menu. - -Thanks `Cameron Yick `__ for contributing this feature. (:pr:`2052`) - -Plugin hooks -~~~~~~~~~~~~ - -- New :ref:`plugin_hook_jinja2_environment_from_request` plugin hook, which can be used to customize the current Jinja environment based on the incoming request. This can be used to modify the template lookup path based on the incoming request hostname, among other things. (:issue:`2225`) -- New :ref:`family of template slot plugin hooks `: ``top_homepage``, ``top_database``, ``top_table``, ``top_row``, ``top_query``, ``top_canned_query``. Plugins can use these to provide additional HTML to be injected at the top of the corresponding pages. (:issue:`1191`) -- New :ref:`track_event() mechanism ` for plugins to emit and receive events when certain events occur within Datasette. (:issue:`2240`) - - Plugins can register additional event classes using :ref:`plugin_hook_register_events`. - - They can then trigger those events with the :ref:`datasette.track_event(event) ` internal method. - - Plugins can subscribe to notifications of events using the :ref:`plugin_hook_track_event` plugin hook. - - Datasette core now emits ``login``, ``logout``, ``create-token``, ``create-table``, ``drop-table``, ``insert-rows``, ``upsert-rows``, ``update-row``, ``delete-row`` events, :ref:`documented here `. -- New internal function for plugin authors: :ref:`database_execute_isolated_fn`, for creating a new SQLite connection, executing code and then closing that connection, all while preventing other code from writing to that particular database. This connection will not have the :ref:`prepare_connection() ` plugin hook executed against it, allowing plugins to perform actions that might otherwise be blocked by existing connection configuration. (:issue:`2218`) - -Documentation -~~~~~~~~~~~~~ - -- Documentation describing :ref:`how to write tests that use signed actor cookies ` using ``datasette.client.actor_cookie()``. (:issue:`1830`) -- Documentation on how to :ref:`register a plugin for the duration of a test `. (:issue:`2234`) -- The :ref:`configuration documentation ` now shows examples of both YAML and JSON for each setting. - -Minor fixes -~~~~~~~~~~~ - -- Datasette no longer attempts to run SQL queries in parallel when rendering a table page, as this was leading to some rare crashing bugs. (:issue:`2189`) -- Fixed warning: ``DeprecationWarning: pkg_resources is deprecated as an API`` (:issue:`2057`) -- Fixed bug where ``?_extra=columns`` parameter returned an incorrectly shaped response. (:issue:`2230`) - -.. _v0_64_6: - -0.64.6 (2023-12-22) -------------------- - -- Fixed a bug where CSV export with expanded labels could fail if a foreign key reference did not correctly resolve. (:issue:`2214`) - -.. _v0_64_5: - -0.64.5 (2023-10-08) -------------------- - -- Dropped dependency on ``click-default-group-wheel``, which could cause a dependency conflict. (:issue:`2197`) - -.. _v1_0_a7: - -1.0a7 (2023-09-21) ------------------- - -- Fix for a crashing bug caused by viewing the table page for a named in-memory database. (:issue:`2189`) - -.. _v0_64_4: - -0.64.4 (2023-09-21) -------------------- - -- Fix for a crashing bug caused by viewing the table page for a named in-memory database. (:issue:`2189`) - -.. _v1_0_a6: - -1.0a6 (2023-09-07) ------------------- - -- New plugin hook: :ref:`plugin_hook_actors_from_ids` and an internal method to accompany it, :ref:`datasette_actors_from_ids`. This mechanism is intended to be used by plugins that may need to display the actor who was responsible for something managed by that plugin: they can now resolve the recorded IDs of actors into the full actor objects. (:issue:`2181`) -- ``DATASETTE_LOAD_PLUGINS`` environment variable for :ref:`controlling which plugins ` are loaded by Datasette. (:issue:`2164`) -- Datasette now checks if the user has permission to view a table linked to by a foreign key before turning that foreign key into a clickable link. (:issue:`2178`) -- The ``execute-sql`` permission now implies that the actor can also view the database and instance. (:issue:`2169`) -- Documentation describing a pattern for building plugins that themselves :ref:`define further hooks ` for other plugins. (:issue:`1765`) -- Datasette is now tested against the Python 3.12 preview. (:pr:`2175`) - -.. _v1_0_a5: - -1.0a5 (2023-08-29) ------------------- - -- When restrictions are applied to :ref:`API tokens `, those restrictions now behave slightly differently: applying the ``view-table`` restriction will imply the ability to ``view-database`` for the database containing that table, and both ``view-table`` and ``view-database`` will imply ``view-instance``. Previously you needed to create a token with restrictions that explicitly listed ``view-instance`` and ``view-database`` and ``view-table`` in order to view a table without getting a permission denied error. (:issue:`2102`) -- New ``datasette.yaml`` (or ``.json``) configuration file, which can be specified using ``datasette -c path-to-file``. The goal here to consolidate settings, plugin configuration, permissions, canned queries, and other Datasette configuration into a single single file, separate from ``metadata.yaml``. The legacy ``settings.json`` config file used for :ref:`config_dir` has been removed, and ``datasette.yaml`` has a ``"settings"`` section where the same settings key/value pairs can be included. In the next future alpha release, more configuration such as plugins/permissions/canned queries will be moved to the ``datasette.yaml`` file. See :issue:`2093` for more details. Thanks, Alex Garcia. -- The ``-s/--setting`` option can now take dotted paths to nested settings. These will then be used to set or over-ride the same options as are present in the new configuration file. (:issue:`2156`) -- New ``--actor '{"id": "json-goes-here"}'`` option for use with ``datasette --get`` to treat the simulated request as being made by a specific actor, see :ref:`cli_datasette_get`. (:issue:`2153`) -- The Datasette ``_internal`` database has had some changes. It no longer shows up in the ``datasette.databases`` list by default, and is now instead available to plugins using the ``datasette.get_internal_database()``. Plugins are invited to use this as a private database to store configuration and settings and secrets that should not be made visible through the default Datasette interface. Users can pass the new ``--internal internal.db`` option to persist that internal database to disk. Thanks, Alex Garcia. (:issue:`2157`). - -.. _v1_0_a4: - -1.0a4 (2023-08-21) ------------------- - -This alpha fixes a security issue with the ``/-/api`` API explorer. On authenticated Datasette instances (instances protected using plugins such as `datasette-auth-passwords `__) the API explorer interface could reveal the names of databases and tables within the protected instance. The data stored in those tables was not revealed. - -For more information and workarounds, read `the security advisory `__. The issue has been present in every previous alpha version of Datasette 1.0: versions 1.0a0, 1.0a1, 1.0a2 and 1.0a3. - -Also in this alpha: - -- The new ``datasette plugins --requirements`` option outputs a list of currently installed plugins in Python ``requirements.txt`` format, useful for duplicating that installation elsewhere. (:issue:`2133`) -- :ref:`queries_writable` can now define a ``on_success_message_sql`` field in their configuration, containing a SQL query that should be executed upon successful completion of the write operation in order to generate a message to be shown to the user. (:issue:`2138`) -- The automatically generated border color for a database is now shown in more places around the application. (:issue:`2119`) -- Every instance of example shell script code in the documentation should now include a working copy button, free from additional syntax. (:issue:`2140`) - -.. _v1_0_a3: - -1.0a3 (2023-08-09) ------------------- - -This alpha release previews the updated design for Datasette's default JSON API. (:issue:`782`) - -The new :ref:`default JSON representation ` for both table pages (``/dbname/table.json``) and arbitrary SQL queries (``/dbname.json?sql=...``) is now shaped like this: - -.. code-block:: json - - { - "ok": true, - "rows": [ - { - "id": 3, - "name": "Detroit" - }, - { - "id": 2, - "name": "Los Angeles" - }, - { - "id": 4, - "name": "Memnonia" - }, - { - "id": 1, - "name": "San Francisco" - } - ], - "truncated": false - } - -Tables will include an additional ``"next"`` key for pagination, which can be passed to ``?_next=`` to fetch the next page of results. - -The various ``?_shape=`` options continue to work as before - see :ref:`json_api_shapes` for details. - -A new ``?_extra=`` mechanism is available for tables, but has not yet been stabilized or documented. Details on that are available in :issue:`262`. - -Smaller changes -~~~~~~~~~~~~~~~ - -- Datasette documentation now shows YAML examples for :ref:`metadata` by default, with a tab interface for switching to JSON. (:issue:`1153`) -- :ref:`plugin_register_output_renderer` plugins now have access to ``error`` and ``truncated`` arguments, allowing them to display error messages and take into account truncated results. (:issue:`2130`) -- ``render_cell()`` plugin hook now also supports an optional ``request`` argument. (:issue:`2007`) -- New ``Justfile`` to support development workflows for Datasette using `Just `__. -- ``datasette.render_template()`` can now accepts a ``datasette.views.Context`` subclass as an alternative to a dictionary. (:issue:`2127`) -- ``datasette install -e path`` option for editable installations, useful while developing plugins. (:issue:`2106`) -- When started with the ``--cors`` option Datasette now serves an ``Access-Control-Max-Age: 3600`` header, ensuring CORS OPTIONS requests are repeated no more than once an hour. (:issue:`2079`) -- Fixed a bug where the ``_internal`` database could display ``None`` instead of ``null`` for in-memory databases. (:issue:`1970`) - -.. _v0_64_2: - -0.64.2 (2023-03-08) -------------------- - -- Fixed a bug with ``datasette publish cloudrun`` where deploys all used the same Docker image tag. This was mostly inconsequential as the service is deployed as soon as the image has been pushed to the registry, but could result in the incorrect image being deployed if two different deploys for two separate services ran at exactly the same time. (:issue:`2036`) - -.. _v0_64_1: - -0.64.1 (2023-01-11) -------------------- - -- Documentation now links to a current source of information for installing Python 3. (:issue:`1987`) -- Incorrectly calling the Datasette constructor using ``Datasette("path/to/data.db")`` instead of ``Datasette(["path/to/data.db"])`` now returns a useful error message. (:issue:`1985`) - -.. _v0_64: - -0.64 (2023-01-09) ------------------ - -- Datasette now **strongly recommends against allowing arbitrary SQL queries if you are using SpatiaLite**. SpatiaLite includes SQL functions that could cause the Datasette server to crash. See :ref:`spatialite` for more details. -- New :ref:`setting_default_allow_sql` setting, providing an easier way to disable all arbitrary SQL execution by end users: ``datasette --setting default_allow_sql off``. See also :ref:`authentication_permissions_execute_sql`. (:issue:`1409`) -- `Building a location to time zone API with SpatiaLite `__ is a new Datasette tutorial showing how to safely use SpatiaLite to create a location to time zone API. -- New documentation about :ref:`how to debug problems loading SQLite extensions `. The error message shown when an extension cannot be loaded has also been improved. (:issue:`1979`) -- Fixed an accessibility issue: the ``\n' - "" + assert 'SQL Interrupted' == response.json['title'] + + +def test_custom_sql_time_limit(app_client): + response = app_client.get( + '/test_tables.json?sql=select+sleep(0.01)' + ) + assert 200 == response.status + response = app_client.get( + '/test_tables.json?sql=select+sleep(0.01)&_timelimit=5' + ) + assert 400 == response.status + assert 'SQL Interrupted' == response.json['title'] + + +def test_invalid_custom_sql(app_client): + response = app_client.get( + '/test_tables.json?sql=.schema' + ) + assert response.status == 400 + assert response.json['ok'] is False + assert 'Statement must be a SELECT' == response.json['error'] + + +def test_allow_sql_off(): + for client in app_client(config={ + 'allow_sql': False, + }): + assert 400 == client.get( + "/test_tables.json?sql=select+sleep(0.01)" + ).status + + +def test_table_json(app_client): + response = app_client.get('/test_tables/simple_primary_key.json?_shape=objects') + assert response.status == 200 + data = response.json + assert data['query']['sql'] == 'select * from simple_primary_key order by id limit 51' + assert data['query']['params'] == {} + assert data['rows'] == [{ + 'id': '1', + 'content': 'hello', + }, { + 'id': '2', + 'content': 'world', + }, { + 'id': '3', + 'content': '', + }] + + +def test_table_not_exists_json(app_client): + assert { + 'ok': False, + 'error': 'Table not found: blah', + 'status': 404, + 'title': None, + } == app_client.get('/test_tables/blah.json').json + + +def test_jsono_redirects_to_shape_objects(app_client): + response_1 = app_client.get( + '/test_tables/simple_primary_key.jsono', + allow_redirects=False + ) + response = app_client.get( + response_1.headers['Location'], + allow_redirects=False + ) + assert response.status == 302 + assert response.headers['Location'].endswith('?_shape=objects') + + +def test_table_shape_arrays(app_client): + response = app_client.get( + '/test_tables/simple_primary_key.json?_shape=arrays' + ) + assert [ + ['1', 'hello'], + ['2', 'world'], + ['3', ''], + ] == response.json['rows'] + + +def test_table_shape_arrayfirst(app_client): + response = app_client.get( + '/test_tables.json?' + urllib.parse.urlencode({ + 'sql': 'select content from simple_primary_key order by id', + '_shape': 'arrayfirst' + }) + ) + assert ['hello', 'world', ''] == response.json + + +def test_table_shape_objects(app_client): + response = app_client.get( + '/test_tables/simple_primary_key.json?_shape=objects' + ) + assert [{ + 'id': '1', + 'content': 'hello', + }, { + 'id': '2', + 'content': 'world', + }, { + 'id': '3', + 'content': '', + }] == response.json['rows'] + + +def test_table_shape_array(app_client): + response = app_client.get( + '/test_tables/simple_primary_key.json?_shape=array' + ) + assert [{ + 'id': '1', + 'content': 'hello', + }, { + 'id': '2', + 'content': 'world', + }, { + 'id': '3', + 'content': '', + }] == response.json + + +def test_table_shape_invalid(app_client): + response = app_client.get( + '/test_tables/simple_primary_key.json?_shape=invalid' + ) + assert { + 'ok': False, + 'error': 'Invalid _shape: invalid', + 'status': 400, + 'title': None, + } == response.json + + +def test_table_shape_object(app_client): + response = app_client.get( + '/test_tables/simple_primary_key.json?_shape=object' + ) + assert { + '1': { + 'id': '1', + 'content': 'hello', + }, + '2': { + 'id': '2', + 'content': 'world', + }, + '3': { + 'id': '3', + 'content': '', + } + } == response.json + + +def test_table_shape_object_compound_primary_Key(app_client): + response = app_client.get( + '/test_tables/compound_primary_key.json?_shape=object' + ) + assert { + 'a,b': { + 'pk1': 'a', + 'pk2': 'b', + 'content': 'c', + } + } == response.json + + +def test_table_with_slashes_in_name(app_client): + response = app_client.get('/test_tables/table%2Fwith%2Fslashes.csv?_shape=objects&_format=json') + assert response.status == 200 + data = response.json + assert data['rows'] == [{ + 'pk': '3', + 'content': 'hey', + }] + + +def test_table_with_reserved_word_name(app_client): + response = app_client.get('/test_tables/select.json?_shape=objects') + assert response.status == 200 + data = response.json + assert data['rows'] == [{ + 'rowid': 1, + 'group': 'group', + 'having': 'having', + 'and': 'and', + }] + + +@pytest.mark.parametrize('path,expected_rows,expected_pages', [ + ('/test_tables/no_primary_key.json', 201, 5), + ('/test_tables/paginated_view.json', 201, 5), + ('/test_tables/no_primary_key.json?_size=25', 201, 9), + ('/test_tables/paginated_view.json?_size=25', 201, 9), + ('/test_tables/paginated_view.json?_size=max', 201, 3), + ('/test_tables/123_starts_with_digits.json', 0, 1), + # Ensure faceting doesn't break pagination: + ('/test_tables/compound_three_primary_keys.json?_facet=pk1', 1001, 21), +]) +def test_paginate_tables_and_views(app_client, path, expected_rows, expected_pages): + fetched = [] + count = 0 + while path: + response = app_client.get(path) + assert 200 == response.status + count += 1 + fetched.extend(response.json['rows']) + path = response.json['next_url'] + if path: + assert response.json['next'] + assert urllib.parse.urlencode({ + '_next': response.json['next'] + }) in path + assert count < 30, 'Possible infinite loop detected' + + assert expected_rows == len(fetched) + assert expected_pages == count + + +@pytest.mark.parametrize('path,expected_error', [ + ('/test_tables/no_primary_key.json?_size=-4', '_size must be a positive integer'), + ('/test_tables/no_primary_key.json?_size=dog', '_size must be a positive integer'), + ('/test_tables/no_primary_key.json?_size=1001', '_size must be <= 100'), +]) +def test_validate_page_size(app_client, path, expected_error): + response = app_client.get(path) + assert expected_error == response.json['error'] + assert 400 == response.status + + +def test_page_size_zero(app_client): + "For _size=0 we return the counts, empty rows and no continuation token" + response = app_client.get('/test_tables/no_primary_key.json?_size=0') + assert 200 == response.status + assert [] == response.json['rows'] + assert 201 == response.json['table_rows_count'] + assert 201 == response.json['filtered_table_rows_count'] + assert None is response.json['next'] + assert None is response.json['next_url'] + + +def test_paginate_compound_keys(app_client): + fetched = [] + path = '/test_tables/compound_three_primary_keys.json?_shape=objects' + page = 0 + while path: + page += 1 + response = app_client.get(path) + fetched.extend(response.json['rows']) + path = response.json['next_url'] + assert page < 100 + assert 1001 == len(fetched) + assert 21 == page + # Should be correctly ordered + contents = [f['content'] for f in fetched] + expected = [r[3] for r in generate_compound_rows(1001)] + assert expected == contents + + +def test_paginate_compound_keys_with_extra_filters(app_client): + fetched = [] + path = '/test_tables/compound_three_primary_keys.json?content__contains=d&_shape=objects' + page = 0 + while path: + page += 1 + assert page < 100 + response = app_client.get(path) + fetched.extend(response.json['rows']) + path = response.json['next_url'] + assert 2 == page + expected = [ + r[3] for r in generate_compound_rows(1001) + if 'd' in r[3] + ] + assert expected == [f['content'] for f in fetched] + + +@pytest.mark.parametrize('query_string,sort_key,human_description_en', [ + ('_sort=sortable', lambda row: row['sortable'], 'sorted by sortable'), + ('_sort_desc=sortable', lambda row: -row['sortable'], 'sorted by sortable descending'), + ( + '_sort=sortable_with_nulls', + lambda row: ( + 1 if row['sortable_with_nulls'] is not None else 0, + row['sortable_with_nulls'] ), - "status": 400, - "title": "SQL Interrupted", - } - - -@pytest.mark.asyncio -async def test_custom_sql_time_limit(ds_client): - response = await ds_client.get( - "/fixtures/-/query.json?sql=select+sleep(0.01)", - ) - assert response.status_code == 200 - response = await ds_client.get( - "/fixtures/-/query.json?sql=select+sleep(0.01)&_timelimit=5", - ) - assert response.status_code == 400 - assert response.json()["title"] == "SQL Interrupted" - - -@pytest.mark.asyncio -async def test_invalid_custom_sql(ds_client): - response = await ds_client.get( - "/fixtures/-/query.json?sql=.schema", - ) - assert response.status_code == 400 - assert response.json()["ok"] is False - assert "Statement must be a SELECT" == response.json()["error"] - - -@pytest.mark.asyncio -async def test_row(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key/1.json?_shape=objects") - assert response.status_code == 200 - assert response.json()["ok"] is True - assert response.json()["rows"] == [{"id": 1, "content": "hello"}] - - -@pytest.mark.asyncio -async def test_row_strange_table_name(ds_client): - response = await ds_client.get( - "/fixtures/table~2Fwith~2Fslashes~2Ecsv/3.json?_shape=objects" - ) - assert response.status_code == 200 - assert response.json()["rows"] == [{"pk": "3", "content": "hey"}] - - -@pytest.mark.asyncio -async def test_row_foreign_key_tables(ds_client): - response = await ds_client.get( - "/fixtures/simple_primary_key/1.json?_extras=foreign_key_tables" - ) - assert response.status_code == 200 - # Foreign keys are sorted by (other_table, column, other_column) - assert response.json()["foreign_key_tables"] == [ - { - "other_table": "complex_foreign_keys", - "column": "id", - "other_column": "f1", - "count": 1, - "link": "/fixtures/complex_foreign_keys?f1=1", - }, - { - "other_table": "complex_foreign_keys", - "column": "id", - "other_column": "f2", - "count": 0, - "link": "/fixtures/complex_foreign_keys?f2=1", - }, - { - "other_table": "complex_foreign_keys", - "column": "id", - "other_column": "f3", - "count": 1, - "link": "/fixtures/complex_foreign_keys?f3=1", - }, - { - "other_table": "foreign_key_references", - "column": "id", - "other_column": "foreign_key_with_blank_label", - "count": 0, - "link": "/fixtures/foreign_key_references?foreign_key_with_blank_label=1", - }, - { - "other_table": "foreign_key_references", - "column": "id", - "other_column": "foreign_key_with_label", - "count": 1, - "link": "/fixtures/foreign_key_references?foreign_key_with_label=1", - }, + 'sorted by sortable_with_nulls' + ), + ( + '_sort_desc=sortable_with_nulls', + lambda row: ( + 1 if row['sortable_with_nulls'] is None else 0, + -row['sortable_with_nulls'] if row['sortable_with_nulls'] is not None else 0, + row['content'] + ), + 'sorted by sortable_with_nulls descending' + ), + # text column contains '$null' - ensure it doesn't confuse pagination: + ('_sort=text', lambda row: row['text'], 'sorted by text'), +]) +def test_sortable(app_client, query_string, sort_key, human_description_en): + path = '/test_tables/sortable.json?_shape=objects&{}'.format(query_string) + fetched = [] + page = 0 + while path: + page += 1 + assert page < 100 + response = app_client.get(path) + assert human_description_en == response.json['human_description_en'] + fetched.extend(response.json['rows']) + path = response.json['next_url'] + assert 5 == page + expected = list(generate_sortable_rows(201)) + expected.sort(key=sort_key) + assert [ + r['content'] for r in expected + ] == [ + r['content'] for r in fetched ] -@pytest.mark.asyncio -async def test_row_extra_render_cell(): - """Test that _extra=render_cell returns rendered HTML from render_cell plugin hook on row pages""" - from datasette import hookimpl - from datasette.app import Datasette - - class TestRenderCellPlugin: - __name__ = "TestRenderCellPlugin" - - @hookimpl - def render_cell(self, value, column, table, database): - # Only modify cells in our test table - if table == "test_render" and column == "name": - return f"{value}" - return None - - ds = Datasette(memory=True) - await ds.invoke_startup() - db = ds.add_memory_database("test_row_render") - await db.execute_write( - "create table test_render (id integer primary key, name text)" +def test_sortable_and_filtered(app_client): + path = ( + '/test_tables/sortable.json' + '?content__contains=d&_sort_desc=sortable&_shape=objects' ) - await db.execute_write("insert into test_render values (1, 'Alice')") + response = app_client.get(path) + fetched = response.json['rows'] + assert 'where content contains "d" sorted by sortable descending' \ + == response.json['human_description_en'] + expected = [ + row for row in generate_sortable_rows(201) + if 'd' in row['content'] + ] + assert len(expected) == response.json['filtered_table_rows_count'] + assert 201 == response.json['table_rows_count'] + expected.sort(key=lambda row: -row['sortable']) + assert [ + r['content'] for r in expected + ] == [ + r['content'] for r in fetched + ] - # Register our test plugin - ds.pm.register(TestRenderCellPlugin(), name="TestRenderCellPlugin") - try: - # Request row with _extra=render_cell - response = await ds.client.get( - "/test_row_render/test_render/1.json?_extra=render_cell" +def test_sortable_argument_errors(app_client): + response = app_client.get( + '/test_tables/sortable.json?_sort=badcolumn' + ) + assert 'Cannot sort table by badcolumn' == response.json['error'] + response = app_client.get( + '/test_tables/sortable.json?_sort_desc=badcolumn2' + ) + assert 'Cannot sort table by badcolumn2' == response.json['error'] + response = app_client.get( + '/test_tables/sortable.json?_sort=sortable_with_nulls&_sort_desc=sortable' + ) + assert 'Cannot use _sort and _sort_desc at the same time' == response.json['error'] + + +def test_sortable_columns_metadata(app_client): + response = app_client.get( + '/test_tables/sortable.json?_sort=content' + ) + assert 'Cannot sort table by content' == response.json['error'] + # no_primary_key has ALL sort options disabled + for column in ('content', 'a', 'b', 'c'): + response = app_client.get( + '/test_tables/sortable.json?_sort={}'.format(column) ) - assert response.status_code == 200 - data = response.json() - - # Verify the response structure - assert "render_cell" in data - assert "rows" in data - - # render_cell should be a list with one row (since this is a row page) - # Only columns modified by plugins are included (sparse output) - render_cell = data["render_cell"] - assert len(render_cell) == 1 - - # The row: id=1, name='Alice' - # The 'name' column should be rendered by our plugin as Alice - assert render_cell[0]["name"] == "Alice" - # The 'id' column is not included since no plugin modified it - assert "id" not in render_cell[0] - - # The regular rows should still contain raw values - assert data["rows"] == [{"id": 1, "name": "Alice"}] - - finally: - ds.pm.unregister(name="TestRenderCellPlugin") + assert 'Cannot sort table by {}'.format(column) == response.json['error'] -def test_databases_json(app_client_two_attached_databases_one_immutable): - response = app_client_two_attached_databases_one_immutable.get("/-/databases.json") - databases = response.json - assert 2 == len(databases) - extra_database, fixtures_database = databases - assert "extra database" == extra_database["name"] - assert extra_database["hash"] is None - assert extra_database["is_mutable"] is True - assert extra_database["is_memory"] is False - - assert "fixtures" == fixtures_database["name"] - assert fixtures_database["hash"] is not None - assert fixtures_database["is_mutable"] is False - assert fixtures_database["is_memory"] is False +@pytest.mark.parametrize('path,expected_rows', [ + ('/test_tables/searchable.json?_search=dog', [ + [1, 'barry cat', 'terry dog', 'panther'], + [2, 'terry dog', 'sara weasel', 'puma'], + ]), + ('/test_tables/searchable.json?_search=weasel', [ + [2, 'terry dog', 'sara weasel', 'puma'], + ]), + ('/test_tables/searchable.json?_search_text2=dog', [ + [1, 'barry cat', 'terry dog', 'panther'], + ]), + ('/test_tables/searchable.json?_search_name%20with%20.%20and%20spaces=panther', [ + [1, 'barry cat', 'terry dog', 'panther'], + ]), +]) +def test_searchable(app_client, path, expected_rows): + response = app_client.get(path) + assert expected_rows == response.json['rows'] -@pytest.mark.asyncio -async def test_threads_json(ds_client): - response = await ds_client.get("/-/threads.json") - expected_keys = {"threads", "num_threads"} - if sys.version_info >= (3, 7, 0): - expected_keys.update({"tasks", "num_tasks"}) - data = response.json() - assert set(data.keys()) == expected_keys - # Should be at least one _execute_writes thread for __INTERNAL__ - thread_names = [thread["name"] for thread in data["threads"]] - assert "_execute_writes for database __INTERNAL__" in thread_names - - -@pytest.mark.asyncio -async def test_plugins_json(ds_client): - response = await ds_client.get("/-/plugins.json") - # Filter out TrackEventPlugin - actual_plugins = sorted( - [p for p in response.json() if p["name"] != "TrackEventPlugin"], - key=lambda p: p["name"], +def test_searchable_invalid_column(app_client): + response = app_client.get( + '/test_tables/searchable.json?_search_invalid=x' ) - assert EXPECTED_PLUGINS == actual_plugins - # Try with ?all=1 - response = await ds_client.get("/-/plugins.json?all=1") - names = {p["name"] for p in response.json()} - assert names.issuperset(p["name"] for p in EXPECTED_PLUGINS) - assert names.issuperset(DEFAULT_PLUGINS) + assert 400 == response.status + assert { + 'ok': False, + 'error': 'Cannot search by that column', + 'status': 400, + 'title': None + } == response.json -@pytest.mark.asyncio -async def test_versions_json(ds_client): - response = await ds_client.get("/-/versions.json") - data = response.json() - assert "python" in data - assert "3.0" == data.get("asgi") - assert "version" in data["python"] - assert "full" in data["python"] - assert "datasette" in data - assert "version" in data["datasette"] - assert data["datasette"]["version"] == __version__ - assert "sqlite" in data - assert "version" in data["sqlite"] - assert "fts_versions" in data["sqlite"] - assert "compile_options" in data["sqlite"] - # By default, the json1 extension is enabled in the SQLite - # provided by the `ubuntu-latest` github actions runner, and - # all versions of SQLite from 3.38.0 onwards - assert data["sqlite"]["extensions"]["json1"] +@pytest.mark.parametrize('path,expected_rows', [ + ('/test_tables/simple_primary_key.json?content=hello', [ + ['1', 'hello'], + ]), + ('/test_tables/simple_primary_key.json?content__contains=o', [ + ['1', 'hello'], + ['2', 'world'], + ]), + ('/test_tables/simple_primary_key.json?content__exact=', [ + ['3', ''], + ]), + ('/test_tables/simple_primary_key.json?content__not=world', [ + ['1', 'hello'], + ['3', ''], + ]), +]) +def test_table_filter_queries(app_client, path, expected_rows): + response = app_client.get(path) + assert expected_rows == response.json['rows'] -@pytest.mark.asyncio -async def test_actions_json(ds_client): - original_root_enabled = ds_client.ds.root_enabled - try: - ds_client.ds.root_enabled = True - response = await ds_client.get("/-/actions.json", actor={"id": "root"}) - data = response.json() - finally: - ds_client.ds.root_enabled = original_root_enabled - assert isinstance(data, list) - assert len(data) > 0 - # Check structure of first action - action = data[0] - for key in ( - "name", - "abbr", - "description", - "takes_parent", - "takes_child", - "resource_class", - "also_requires", - ): - assert key in action - # Check that some expected actions exist - action_names = {a["name"] for a in data} - for expected_action in ( - "view-instance", - "view-database", - "view-table", - "execute-sql", - ): - assert expected_action in action_names +def test_max_returned_rows(app_client): + response = app_client.get( + '/test_tables.json?sql=select+content+from+no_primary_key' + ) + data = response.json + assert { + 'sql': 'select content from no_primary_key', + 'params': {} + } == data['query'] + assert data['truncated'] + assert 100 == len(data['rows']) -@pytest.mark.asyncio -async def test_settings_json(ds_client): - response = await ds_client.get("/-/settings.json") - assert response.json() == { +def test_view(app_client): + response = app_client.get('/test_tables/simple_view.json?_shape=objects') + assert response.status == 200 + data = response.json + assert data['rows'] == [{ + 'upper_content': 'HELLO', + 'content': 'hello', + }, { + 'upper_content': 'WORLD', + 'content': 'world', + }, { + 'upper_content': '', + 'content': '', + }] + + +def test_row(app_client): + response = app_client.get('/test_tables/simple_primary_key/1.json?_shape=objects') + assert response.status == 200 + assert [{'id': '1', 'content': 'hello'}] == response.json['rows'] + + +def test_row_foreign_key_tables(app_client): + response = app_client.get('/test_tables/simple_primary_key/1.json?_extras=foreign_key_tables') + assert response.status == 200 + assert [{ + 'column': 'id', + 'count': 1, + 'other_column': 'foreign_key_with_label', + 'other_table': 'foreign_key_references' + }, { + 'column': 'id', + 'count': 1, + 'other_column': 'f3', + 'other_table': 'complex_foreign_keys' + }, { + 'column': 'id', + 'count': 0, + 'other_column': 'f2', + 'other_table': 'complex_foreign_keys' + }, { + 'column': 'id', + 'count': 1, + 'other_column': 'f1', + 'other_table': 'complex_foreign_keys' + }] == response.json['foreign_key_tables'] + + +def test_unit_filters(app_client): + response = app_client.get( + '/test_tables/units.json?distance__lt=75km&frequency__gt=1kHz' + ) + assert response.status == 200 + data = response.json + + assert data['units']['distance'] == 'm' + assert data['units']['frequency'] == 'Hz' + + assert len(data['rows']) == 1 + assert data['rows'][0][0] == 2 + + +def test_metadata_json(app_client): + response = app_client.get( + "/-/metadata.json" + ) + assert METADATA == response.json + + +def test_inspect_json(app_client): + response = app_client.get( + "/-/inspect.json" + ) + assert app_client.ds.inspect() == response.json + + +def test_plugins_json(app_client): + response = app_client.get( + "/-/plugins.json" + ) + # This will include any plugins that have been installed into the + # current virtual environment, so we only check for the presence of + # the one we know will definitely be There + assert { + 'name': 'my_plugin.py', + 'static': False, + 'templates': False, + 'version': None, + } in response.json + + +def test_versions_json(app_client): + response = app_client.get( + "/-/versions.json" + ) + assert 'python' in response.json + assert 'version' in response.json['python'] + assert 'full' in response.json['python'] + assert 'datasette' in response.json + assert 'version' in response.json['datasette'] + assert 'sqlite' in response.json + assert 'version' in response.json['sqlite'] + assert 'fts_versions' in response.json['sqlite'] + + +def test_config_json(app_client): + response = app_client.get( + "/-/config.json" + ) + assert { "default_page_size": 50, "default_facet_size": 30, - "default_allow_sql": True, - "facet_suggest_time_limit_ms": 200, + "facet_suggest_time_limit_ms": 50, "facet_time_limit_ms": 200, "max_returned_rows": 100, - "max_insert_rows": 100, "sql_time_limit_ms": 200, "allow_download": True, - "allow_signed_tokens": True, - "max_signed_tokens_ttl": 0, "allow_facet": True, "suggest_facets": True, - "default_cache_ttl": 5, - "num_sql_threads": 1, + "allow_sql": True, + "default_cache_ttl": 365 * 24 * 60 * 60, + "num_sql_threads": 3, "cache_size_kb": 0, - "allow_csv_stream": True, - "max_csv_mb": 100, - "truncate_cells_html": 2048, - "force_https_urls": False, - "template_debug": False, - "trace_debug": False, - "base_url": "/", - } + } == response.json -test_json_columns_default_expected = [ - {"intval": 1, "strval": "s", "floatval": 0.5, "jsonval": '{"foo": "bar"}'} -] +def test_page_size_matching_max_returned_rows(app_client_returned_rows_matches_page_size): + fetched = [] + path = '/test_tables/no_primary_key.json' + while path: + response = app_client_returned_rows_matches_page_size.get(path) + fetched.extend(response.json['rows']) + assert len(response.json['rows']) in (1, 50) + path = response.json['next_url'] + assert 201 == len(fetched) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "extra_args,expected", - [ - ("", test_json_columns_default_expected), - ("&_json=intval", test_json_columns_default_expected), - ("&_json=strval", test_json_columns_default_expected), - ("&_json=floatval", test_json_columns_default_expected), - ( - "&_json=jsonval", - [{"intval": 1, "strval": "s", "floatval": 0.5, "jsonval": {"foo": "bar"}}], - ), - ], -) -async def test_json_columns(ds_client, extra_args, expected): - sql = """ +@pytest.mark.parametrize('path,expected_facet_results', [ + ( + "/test_tables/facetable.json?_facet=state&_facet=city_id", + { + "state": { + "name": "state", + "results": [ + { + "value": "CA", + "label": "CA", + "count": 10, + "toggle_url": "_facet=state&_facet=city_id&state=CA", + "selected": False, + }, + { + "value": "MI", + "label": "MI", + "count": 4, + "toggle_url": "_facet=state&_facet=city_id&state=MI", + "selected": False, + }, + { + "value": "MC", + "label": "MC", + "count": 1, + "toggle_url": "_facet=state&_facet=city_id&state=MC", + "selected": False, + } + ], + "truncated": False, + }, + "city_id": { + "name": "city_id", + "results": [ + { + "value": 1, + "label": "San Francisco", + "count": 6, + "toggle_url": "_facet=state&_facet=city_id&city_id=1", + "selected": False, + }, + { + "value": 2, + "label": "Los Angeles", + "count": 4, + "toggle_url": "_facet=state&_facet=city_id&city_id=2", + "selected": False, + }, + { + "value": 3, + "label": "Detroit", + "count": 4, + "toggle_url": "_facet=state&_facet=city_id&city_id=3", + "selected": False, + }, + { + "value": 4, + "label": "Memnonia", + "count": 1, + "toggle_url": "_facet=state&_facet=city_id&city_id=4", + "selected": False, + } + ], + "truncated": False, + } + } + ), ( + "/test_tables/facetable.json?_facet=state&_facet=city_id&state=MI", + { + "state": { + "name": "state", + "results": [ + { + "value": "MI", + "label": "MI", + "count": 4, + "selected": True, + "toggle_url": "_facet=state&_facet=city_id", + }, + ], + "truncated": False, + }, + "city_id": { + "name": "city_id", + "results": [ + { + "value": 3, + "label": "Detroit", + "count": 4, + "selected": False, + "toggle_url": "_facet=state&_facet=city_id&state=MI&city_id=3", + }, + ], + "truncated": False, + }, + }, + ), ( + "/test_tables/facetable.json?_facet=planet_int", + { + "planet_int": { + "name": "planet_int", + "results": [ + { + "value": 1, + "label": 1, + "count": 14, + "selected": False, + "toggle_url": "_facet=planet_int&planet_int=1", + }, + { + "value": 2, + "label": 2, + "count": 1, + "selected": False, + "toggle_url": "_facet=planet_int&planet_int=2", + }, + ], + "truncated": False, + } + }, + ), ( + # planet_int is an integer field: + "/test_tables/facetable.json?_facet=planet_int&planet_int=1", + { + "planet_int": { + "name": "planet_int", + "results": [ + { + "value": 1, + "label": 1, + "count": 14, + "selected": True, + "toggle_url": "_facet=planet_int", + } + ], + "truncated": False, + }, + }, + ) +]) +def test_facets(app_client, path, expected_facet_results): + response = app_client.get(path) + facet_results = response.json['facet_results'] + # We only compare the querystring portion of the taggle_url + for facet_name, facet_info in facet_results.items(): + assert facet_name == facet_info["name"] + assert False is facet_info["truncated"] + for facet_value in facet_info["results"]: + facet_value['toggle_url'] = facet_value['toggle_url'].split('?')[1] + assert expected_facet_results == facet_results + + +def test_suggested_facets(app_client): + assert len(app_client.get( + "/test_tables/facetable.json" + ).json["suggested_facets"]) > 0 + + +def test_allow_facet_off(): + for client in app_client(config={ + 'allow_facet': False, + }): + assert 400 == client.get( + "/test_tables/facetable.json?_facet=planet_int" + ).status + # Should not suggest any facets either: + assert [] == client.get( + "/test_tables/facetable.json" + ).json["suggested_facets"] + + +def test_suggest_facets_off(): + for client in app_client(config={ + 'suggest_facets': False, + }): + # Now suggested_facets should be [] + assert [] == client.get( + "/test_tables/facetable.json" + ).json["suggested_facets"] + + +def test_expand_labels(app_client): + response = app_client.get( + "/test_tables/facetable.json?_shape=object&_labels=1&_size=2" + "&neighborhood__contains=c" + ) + assert { + "2": { + "pk": 2, + "planet_int": 1, + "state": "CA", + "city_id": { + "value": 1, + "label": "San Francisco" + }, + "neighborhood": "Dogpatch" + }, + "13": { + "pk": 13, + "planet_int": 1, + "state": "MI", + "city_id": { + "value": 3, + "label": "Detroit" + }, + "neighborhood": "Corktown" + } + } == response.json + + +def test_expand_label(app_client): + response = app_client.get( + "/test_tables/foreign_key_references.json?_shape=object" + "&_label=foreign_key_with_label" + ) + assert { + "1": { + "pk": "1", + "foreign_key_with_label": { + "value": "1", + "label": "hello" + }, + "foreign_key_with_no_label": "1" + } + } == response.json + + +@pytest.mark.parametrize('path,expected_cache_control', [ + ("/test_tables/facetable.json", "max-age=31536000"), + ("/test_tables/facetable.json?_ttl=invalid", "max-age=31536000"), + ("/test_tables/facetable.json?_ttl=10", "max-age=10"), + ("/test_tables/facetable.json?_ttl=0", "no-cache"), +]) +def test_ttl_parameter(app_client, path, expected_cache_control): + response = app_client.get(path) + assert expected_cache_control == response.headers['Cache-Control'] + + +test_json_columns_default_expected = [{ + "intval": 1, + "strval": "s", + "floatval": 0.5, + "jsonval": "{\"foo\": \"bar\"}" +}] + + +@pytest.mark.parametrize("extra_args,expected", [ + ("", test_json_columns_default_expected), + ("&_json=intval", test_json_columns_default_expected), + ("&_json=strval", test_json_columns_default_expected), + ("&_json=floatval", test_json_columns_default_expected), + ("&_json=jsonval", [{ + "intval": 1, + "strval": "s", + "floatval": 0.5, + "jsonval": { + "foo": "bar" + } + }]) +]) +def test_json_columns(app_client, extra_args, expected): + sql = ''' select 1 as intval, "s" as strval, 0.5 as floatval, '{"foo": "bar"}' as jsonval - """ - path = "/fixtures/-/query.json?" + urllib.parse.urlencode( - {"sql": sql, "_shape": "array"} - ) + ''' + path = "/test_tables.json?" + urllib.parse.urlencode({ + "sql": sql, + "_shape": "array" + }) path += extra_args - response = await ds_client.get( - path, - ) - assert response.json() == expected + response = app_client.get(path) + assert expected == response.json def test_config_cache_size(app_client_larger_cache_size): - response = app_client_larger_cache_size.get("/fixtures/pragma_cache_size.json") - assert response.json["rows"] == [{"cache_size": -2500}] - - -def test_config_force_https_urls(): - with make_app_client(settings={"force_https_urls": True}) as client: - response = client.get( - "/fixtures/facetable.json?_size=3&_facet=state&_extra=next_url,suggested_facets" - ) - assert response.json["next_url"].startswith("https://") - assert response.json["facet_results"]["results"]["state"]["results"][0][ - "toggle_url" - ].startswith("https://") - assert response.json["suggested_facets"][0]["toggle_url"].startswith("https://") - # Also confirm that request.url and request.scheme are set correctly - response = client.get("/") - assert client.ds._last_request.url.startswith("https://") - assert client.ds._last_request.scheme == "https" - - -@pytest.mark.parametrize( - "path,status_code", - [ - ("/fixtures.db", 200), - ("/fixtures.json", 200), - ("/fixtures/no_primary_key.json", 200), - # A 400 invalid SQL query should still have the header: - ("/fixtures/-/query.json?sql=select+blah", 400), - # Write APIs - ("/fixtures/-/create", 405), - ("/fixtures/facetable/-/insert", 405), - ("/fixtures/facetable/-/drop", 405), - ], -) -def test_cors( - app_client_with_cors, - app_client_two_attached_databases_one_immutable, - path, - status_code, -): - response = app_client_with_cors.get( - path, + response = app_client_larger_cache_size.get( + '/test_tables/pragma_cache_size.json' ) - assert response.status == status_code - assert response.headers["Access-Control-Allow-Origin"] == "*" - assert ( - response.headers["Access-Control-Allow-Headers"] - == "Authorization, Content-Type" - ) - assert response.headers["Access-Control-Expose-Headers"] == "Link" - assert ( - response.headers["Access-Control-Allow-Methods"] == "GET, POST, HEAD, OPTIONS" - ) - assert response.headers["Access-Control-Max-Age"] == "3600" - # Same request to app_client_two_attached_databases_one_immutable - # should not have those headers - I'm using that fixture because - # regular app_client doesn't have immutable fixtures.db which means - # the test for /fixtures.db returns a 403 error - response = app_client_two_attached_databases_one_immutable.get( - path, - ) - assert response.status == status_code - assert "Access-Control-Allow-Origin" not in response.headers - assert "Access-Control-Allow-Headers" not in response.headers - assert "Access-Control-Expose-Headers" not in response.headers - assert "Access-Control-Allow-Methods" not in response.headers - assert "Access-Control-Max-Age" not in response.headers - - -def test_cors_query_redirect(app_client_with_cors): - # /db?sql= redirects to /db/-/query - the redirect itself needs CORS - # headers, otherwise browsers refuse to follow it cross-origin - response = app_client_with_cors.get( - "/fixtures?sql=select+1", follow_redirects=False - ) - assert response.status == 302 - assert response.headers["Location"] == "/fixtures/-/query?sql=select+1" - assert response.headers["Access-Control-Allow-Origin"] == "*" - - -@pytest.mark.parametrize( - "path", - ( - "/", - ".json", - "/searchable", - "/searchable.json", - "/searchable_view", - "/searchable_view.json", - ), -) -def test_database_with_space_in_name(app_client_two_attached_databases, path): - response = app_client_two_attached_databases.get( - "/extra~20database" + path, follow_redirects=True - ) - assert response.status == 200 - - -def test_common_prefix_database_names(app_client_conflicting_database_names): - # https://github.com/simonw/datasette/issues/597 - assert ["foo-bar", "foo", "fixtures"] == [ - d["name"] - for d in app_client_conflicting_database_names.get("/-/databases.json").json - ] - for db_name, path in (("foo", "/foo.json"), ("foo-bar", "/foo-bar.json")): - data = app_client_conflicting_database_names.get(path).json - assert db_name == data["database"] - - -def test_inspect_file_used_for_count(app_client_immutable_and_inspect_file): - response = app_client_immutable_and_inspect_file.get( - "/fixtures/sortable.json?_extra=count" - ) - assert response.json["count"] == 100 - - -@pytest.mark.asyncio -async def test_http_options_request(ds_client): - response = await ds_client.options("/fixtures") - assert response.status_code == 200 - assert response.text == "ok" - - -@pytest.mark.asyncio -async def test_db_path(app_client): - # Needs app_client because needs file based database - db = app_client.ds.get_database() - path = pathlib.Path(db.path) - - assert path.exists() - - datasette = Datasette([path]) - - # Previously this broke if path was a pathlib.Path: - await datasette.refresh_schemas() - - -@pytest.mark.asyncio -async def test_hidden_sqlite_stat1_table(): - ds = Datasette() - db = ds.add_memory_database("db") - await db.execute_write("create table normal (id integer primary key, name text)") - await db.execute_write("create index idx on normal (name)") - await db.execute_write("analyze") - data = (await ds.client.get("/db.json?_show_hidden=1")).json() - tables = [(t["name"], t["hidden"]) for t in data["tables"]] - assert tables in ( - [("normal", False), ("sqlite_stat1", True)], - [("normal", False), ("sqlite_stat1", True), ("sqlite_stat4", True)], - ) - - -@pytest.mark.asyncio -async def test_hide_tables_starting_with_underscore(): - ds = Datasette() - db = ds.add_memory_database("test_hide_tables_starting_with_underscore") - await db.execute_write("create table normal (id integer primary key, name text)") - await db.execute_write("create table _hidden (id integer primary key, name text)") - data = ( - await ds.client.get( - "/test_hide_tables_starting_with_underscore.json?_show_hidden=1" - ) - ).json() - tables = [(t["name"], t["hidden"]) for t in data["tables"]] - assert tables == [("normal", False), ("_hidden", True)] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("db_name", ("foo", r"fo%o", "f~/c.d")) -async def test_tilde_encoded_database_names(db_name): - ds = Datasette() - ds.add_memory_database(db_name) - response = await ds.client.get("/.json") - assert db_name in response.json()["databases"].keys() - path = response.json()["databases"][db_name]["path"] - # And the JSON for that database - response2 = await ds.client.get(path + ".json") - assert response2.status_code == 200 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "config,expected", - ( - ({}, {}), - ({"plugins": {"datasette-foo": "bar"}}, {"plugins": {"datasette-foo": "bar"}}), - # Test redaction - ( - { - "plugins": { - "datasette-auth": {"secret_key": "key"}, - "datasette-foo": "bar", - "datasette-auth2": {"password": "password"}, - "datasette-sentry": { - "dsn": "sentry:///foo", - }, - } - }, - { - "plugins": { - "datasette-auth": {"secret_key": "***"}, - "datasette-foo": "bar", - "datasette-auth2": {"password": "***"}, - "datasette-sentry": {"dsn": "***"}, - } - }, - ), - ), -) -async def test_config_json(config, expected): - "/-/config.json should return redacted configuration" - ds = Datasette(config=config) - response = await ds.client.get("/-/config.json") - assert response.json() == expected - - -@pytest.mark.asyncio -@pytest.mark.skip(reason="rm?") -@pytest.mark.parametrize( - "metadata,expected_config,expected_metadata", - ( - ({}, {}, {}), - ( - # Metadata input - { - "title": "Datasette Fixtures", - "databases": { - "fixtures": { - "tables": { - "sortable": { - "sortable_columns": [ - "sortable", - "sortable_with_nulls", - "sortable_with_nulls_2", - "text", - ], - }, - "no_primary_key": {"sortable_columns": [], "hidden": True}, - "primary_key_multiple_columns_explicit_label": { - "label_column": "content2" - }, - "simple_view": {"sortable_columns": ["content"]}, - "searchable_view_configured_by_metadata": { - "fts_table": "searchable_fts", - "fts_pk": "pk", - }, - "roadside_attractions": { - "columns": { - "name": "The name of the attraction", - "address": "The street address for the attraction", - } - }, - "attraction_characteristic": {"sort_desc": "pk"}, - "facet_cities": {"sort": "name"}, - "paginated_view": {"size": 25}, - }, - } - }, - }, - # Should produce a config with just the table configuration keys - { - "databases": { - "fixtures": { - "tables": { - "sortable": { - "sortable_columns": [ - "sortable", - "sortable_with_nulls", - "sortable_with_nulls_2", - "text", - ] - }, - # These one get redacted: - "no_primary_key": "***", - "primary_key_multiple_columns_explicit_label": "***", - "simple_view": {"sortable_columns": ["content"]}, - "searchable_view_configured_by_metadata": { - "fts_table": "searchable_fts", - "fts_pk": "pk", - }, - "attraction_characteristic": {"sort_desc": "pk"}, - "facet_cities": {"sort": "name"}, - "paginated_view": {"size": 25}, - } - } - } - }, - # And metadata with everything else - { - "title": "Datasette Fixtures", - "databases": { - "fixtures": { - "tables": { - "roadside_attractions": { - "columns": { - "name": "The name of the attraction", - "address": "The street address for the attraction", - } - }, - } - } - }, - }, - ), - ), -) -async def test_upgrade_metadata(metadata, expected_config, expected_metadata): - ds = Datasette(metadata=metadata) - response = await ds.client.get("/-/config.json") - assert response.json() == expected_config - response2 = await ds.client.get("/-/metadata.json") - assert response2.json() == expected_metadata - - -class Either: - def __init__(self, a, b): - self.a = a - self.b = b - - def __eq__(self, other): - return other == self.a or other == self.b + assert [[-2500]] == response.json['rows'] diff --git a/tests/test_api_write.py b/tests/test_api_write.py deleted file mode 100644 index 64f91701..00000000 --- a/tests/test_api_write.py +++ /dev/null @@ -1,1659 +0,0 @@ -from datasette.app import Datasette -from datasette.utils import sqlite3 -from .utils import last_event -import pytest -import time - - -@pytest.fixture -def ds_write(tmp_path_factory): - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db_path_immutable = str(db_directory / "immutable.db") - db1 = sqlite3.connect(str(db_path)) - db2 = sqlite3.connect(str(db_path_immutable)) - for db in (db1, db2): - db.execute("vacuum") - db.execute( - "create table docs (id integer primary key, title text, score float, age integer)" - ) - db1.close() - db2.close() - ds = Datasette([db_path], immutables=[db_path_immutable]) - ds.root_enabled = True - yield ds - ds.close() - - -def write_token(ds, actor_id="root", permissions=None): - to_sign = {"a": actor_id, "token": "dstok", "t": int(time.time())} - if permissions: - to_sign["_r"] = {"a": permissions} - return "dstok_{}".format(ds.sign(to_sign, namespace="token")) - - -def _headers(token): - return { - "Authorization": "Bearer {}".format(token), - "Content-Type": "application/json", - } - - -@pytest.mark.asyncio -async def test_api_explorer_upsert_example_json(ds_write): - response = await ds_write.client.get("/-/api", actor={"id": "root"}) - assert response.status_code == 200 - import urllib.parse - - text = urllib.parse.unquote_plus(response.text) - upsert_idx = text.index("/data/docs/-/upsert") - upsert_chunk = text[upsert_idx : upsert_idx + 500] - assert '"id": ""' in upsert_chunk - assert '"title": ""' in upsert_chunk - assert '"score": "<score>"' in upsert_chunk - assert '"age": "<age>"' in upsert_chunk - - -@pytest.mark.asyncio -async def test_api_explorer_upsert_example_json_rowid_table(tmp_path_factory): - db_path = str(tmp_path_factory.mktemp("dbs") / "data.db") - conn = sqlite3.connect(db_path) - conn.execute("create table things (title text, score float)") - conn.close() - ds = Datasette([db_path]) - ds.root_enabled = True - response = await ds.client.get("/-/api", actor={"id": "root"}) - assert response.status_code == 200 - import urllib.parse - - text = urllib.parse.unquote_plus(response.text) - upsert_idx = text.index("/data/things/-/upsert") - upsert_chunk = text[upsert_idx : upsert_idx + 500] - assert '"rowid": "<rowid (primary key)>"' in upsert_chunk - assert '"title": "<title>"' in upsert_chunk - assert '"score": "<score>"' in upsert_chunk - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "content_type", - ( - "application/json", - "application/json; charset=utf-8", - ), -) -async def test_insert_row(ds_write, content_type): - token = write_token(ds_write) - response = await ds_write.client.post( - "/data/docs/-/insert", - json={"row": {"title": "Test", "score": 1.2, "age": 5}}, - headers={ - "Authorization": "Bearer {}".format(token), - "Content-Type": content_type, - }, - ) - expected_row = {"id": 1, "title": "Test", "score": 1.2, "age": 5} - assert response.status_code == 201 - assert response.json()["ok"] is True - assert response.json()["rows"] == [expected_row] - rows = (await ds_write.get_database("data").execute("select * from docs")).dicts() - assert rows[0] == expected_row - # Analytics event - event = last_event(ds_write) - assert event.name == "insert-rows" - assert event.num_rows == 1 - assert event.database == "data" - assert event.table == "docs" - assert not event.ignore - assert not event.replace - - -@pytest.mark.asyncio -async def test_insert_row_alter(ds_write): - token = write_token(ds_write) - response = await ds_write.client.post( - "/data/docs/-/insert", - json={ - "row": {"title": "Test", "score": 1.2, "age": 5, "extra": "extra"}, - "alter": True, - }, - headers=_headers(token), - ) - assert response.status_code == 201 - assert response.json()["ok"] is True - assert response.json()["rows"][0]["extra"] == "extra" - # Analytics event - event = last_event(ds_write) - assert event.name == "alter-table" - assert "extra" not in event.before_schema - assert "extra" in event.after_schema - - -@pytest.mark.asyncio -@pytest.mark.parametrize("return_rows", (True, False)) -async def test_insert_rows(ds_write, return_rows): - token = write_token(ds_write) - data = { - "rows": [ - {"title": "Test {}".format(i), "score": 1.0, "age": 5} for i in range(20) - ] - } - if return_rows: - data["return"] = True - response = await ds_write.client.post( - "/data/docs/-/insert", - json=data, - headers=_headers(token), - ) - assert response.status_code == 201 - - # Analytics event - event = last_event(ds_write) - assert event.name == "insert-rows" - assert event.num_rows == 20 - assert event.database == "data" - assert event.table == "docs" - assert not event.ignore - assert not event.replace - - actual_rows = ( - await ds_write.get_database("data").execute("select * from docs") - ).dicts() - assert len(actual_rows) == 20 - assert actual_rows == [ - {"id": i + 1, "title": "Test {}".format(i), "score": 1.0, "age": 5} - for i in range(20) - ] - assert response.json()["ok"] is True - if return_rows: - assert response.json()["rows"] == actual_rows - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,input,special_case,expected_status,expected_errors", - ( - ( - "/data2/docs/-/insert", - {}, - None, - 404, - ["Database not found"], - ), - ( - "/data/docs2/-/insert", - {}, - None, - 404, - ["Table not found"], - ), - ( - "/data/docs/-/insert", - {"rows": [{"title": "Test"} for i in range(10)]}, - "bad_token", - 403, - ["Permission denied"], - ), - ( - "/data/docs/-/insert", - {}, - "invalid_json", - 400, - [ - "Invalid JSON: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)" - ], - ), - ( - "/data/docs/-/insert", - {}, - "invalid_content_type", - 400, - ["Invalid content-type, must be application/json"], - ), - ( - "/data/docs/-/insert", - [], - None, - 400, - ["JSON must be a dictionary"], - ), - ( - "/data/docs/-/insert", - {"row": "blah"}, - None, - 400, - ['"row" must be a dictionary'], - ), - ( - "/data/docs/-/insert", - {"blah": "blah"}, - None, - 400, - ['JSON must have one or other of "row" or "rows"'], - ), - ( - "/data/docs/-/insert", - {"rows": "blah"}, - None, - 400, - ['"rows" must be a list'], - ), - ( - "/data/docs/-/insert", - {"rows": ["blah"]}, - None, - 400, - ['"rows" must be a list of dictionaries'], - ), - ( - "/data/docs/-/insert", - {"rows": [{"title": "Test"} for i in range(101)]}, - None, - 400, - ["Too many rows, maximum allowed is 100"], - ), - ( - "/data/docs/-/insert", - {"rows": [{"id": 1, "title": "Test"}, {"id": 2, "title": "Test"}]}, - "duplicate_id", - 400, - ["UNIQUE constraint failed: docs.id"], - ), - ( - "/data/docs/-/insert", - {"rows": [{"title": "Test"}], "ignore": True, "replace": True}, - None, - 400, - ['Cannot use "ignore" and "replace" at the same time'], - ), - ( - # Replace is not allowed if you don't have update-row - "/data/docs/-/insert", - {"rows": [{"title": "Test"}], "replace": True}, - "insert-but-not-update", - 403, - ['Permission denied: need update-row to use "replace"'], - ), - ( - "/data/docs/-/insert", - {"rows": [{"title": "Test"}], "invalid_param": True}, - None, - 400, - ['Invalid parameter: "invalid_param"'], - ), - ( - "/data/docs/-/insert", - {"rows": [{"title": "Test"}], "one": True, "two": True}, - None, - 400, - ['Invalid parameter: "one", "two"'], - ), - ( - "/immutable/docs/-/insert", - {"rows": [{"title": "Test"}]}, - None, - 403, - ["Database is immutable"], - ), - # Validate columns of each row - ( - "/data/docs/-/insert", - {"rows": [{"title": "Test", "bad": 1, "worse": 2} for i in range(2)]}, - None, - 400, - [ - "Row 0 has invalid columns: bad, worse", - "Row 1 has invalid columns: bad, worse", - ], - ), - ## UPSERT ERRORS: - ( - "/immutable/docs/-/upsert", - {"rows": [{"title": "Test"}]}, - None, - 403, - ["Database is immutable"], - ), - ( - "/data/badtable/-/upsert", - {"rows": [{"title": "Test"}]}, - None, - 404, - ["Table not found"], - ), - # missing primary key - ( - "/data/docs/-/upsert", - {"rows": [{"title": "Missing PK"}]}, - None, - 400, - ['Row 0 is missing primary key column(s): "id"'], - ), - # null primary key - ( - "/data/docs/-/upsert", - {"rows": [{"id": None, "title": "Null PK"}]}, - None, - 400, - ['Row 0 has null primary key column(s): "id"'], - ), - # Upsert does not support ignore or replace - ( - "/data/docs/-/upsert", - {"rows": [{"id": 1, "title": "Bad"}], "ignore": True}, - None, - 400, - ["Upsert does not support ignore or replace"], - ), - # Upsert permissions - ( - "/data/docs/-/upsert", - {"rows": [{"id": 1, "title": "Disallowed"}]}, - "insert-but-not-update", - 403, - ["Permission denied: need both insert-row and update-row"], - ), - ( - "/data/docs/-/upsert", - {"rows": [{"id": 1, "title": "Disallowed"}]}, - "update-but-not-insert", - 403, - ["Permission denied: need both insert-row and update-row"], - ), - # Alter table forbidden without alter permission - ( - "/data/docs/-/upsert", - {"rows": [{"id": 1, "title": "One", "extra": "extra"}], "alter": True}, - "update-and-insert-but-no-alter", - 403, - ["Permission denied for alter-table"], - ), - ), -) -async def test_insert_or_upsert_row_errors( - ds_write, path, input, special_case, expected_status, expected_errors -): - token_permissions = [] - if special_case == "insert-but-not-update": - token_permissions = ["ir", "vi"] - if special_case == "update-but-not-insert": - token_permissions = ["ur", "vi"] - if special_case == "update-and-insert-but-no-alter": - token_permissions = ["ur", "ir"] - token = write_token(ds_write, permissions=token_permissions) - if special_case == "duplicate_id": - await ds_write.get_database("data").execute_write( - "insert into docs (id) values (1)" - ) - if special_case == "bad_token": - token += "bad" - kwargs = dict( - json=input, - headers={ - "Authorization": "Bearer {}".format(token), - "Content-Type": ( - "text/plain" - if special_case == "invalid_content_type" - else "application/json" - ), - }, - ) - - actor_response = ( - await ds_write.client.get("/-/actor.json", headers=kwargs["headers"]) - ).json() - assert set((actor_response["actor"] or {}).get("_r", {}).get("a") or []) == set( - token_permissions - ) - - if special_case == "invalid_json": - del kwargs["json"] - kwargs["content"] = "{bad json" - before_count = ( - await ds_write.get_database("data").execute("select count(*) from docs") - ).rows[0][0] == 0 - response = await ds_write.client.post( - path, - **kwargs, - ) - assert response.status_code == expected_status - assert response.json()["ok"] is False - assert response.json()["errors"] == expected_errors - # Check that no rows were inserted - after_count = ( - await ds_write.get_database("data").execute("select count(*) from docs") - ).rows[0][0] == 0 - assert before_count == after_count - - -@pytest.mark.asyncio -@pytest.mark.parametrize("allowed", (True, False)) -async def test_upsert_permissions_per_table(ds_write, allowed): - # https://github.com/simonw/datasette/issues/2262 - token = "dstok_{}".format( - ds_write.sign( - { - "a": "root", - "token": "dstok", - "t": int(time.time()), - "_r": { - "r": { - "data": { - "docs" if allowed else "other": ["ir", "ur"], - } - } - }, - }, - namespace="token", - ) - ) - response = await ds_write.client.post( - "/data/docs/-/upsert", - json={"rows": [{"id": 1, "title": "One"}]}, - headers={ - "Authorization": "Bearer {}".format(token), - }, - ) - if allowed: - assert response.status_code == 200 - assert response.json()["ok"] is True - else: - assert response.status_code == 403 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "ignore,replace,expected_rows", - ( - ( - True, - False, - [ - {"id": 1, "title": "Exists", "score": None, "age": None}, - ], - ), - ( - False, - True, - [ - {"id": 1, "title": "One", "score": None, "age": None}, - ], - ), - ), -) -@pytest.mark.parametrize("should_return", (True, False)) -async def test_insert_ignore_replace( - ds_write, ignore, replace, expected_rows, should_return -): - await ds_write.get_database("data").execute_write( - "insert into docs (id, title) values (1, 'Exists')" - ) - token = write_token(ds_write) - data = {"rows": [{"id": 1, "title": "One"}]} - if ignore: - data["ignore"] = True - if replace: - data["replace"] = True - if should_return: - data["return"] = True - response = await ds_write.client.post( - "/data/docs/-/insert", - json=data, - headers=_headers(token), - ) - assert response.status_code == 201 - - # Analytics event - event = last_event(ds_write) - assert event.name == "insert-rows" - assert event.num_rows == 1 - assert event.database == "data" - assert event.table == "docs" - assert event.ignore == ignore - assert event.replace == replace - - actual_rows = ( - await ds_write.get_database("data").execute("select * from docs") - ).dicts() - - assert actual_rows == expected_rows - assert response.json()["ok"] is True - if should_return: - assert response.json()["rows"] == expected_rows - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "initial,input,expected_rows", - ( - ( - # Simple primary key update - {"rows": [{"id": 1, "title": "One"}], "pk": "id"}, - {"rows": [{"id": 1, "title": "Two"}]}, - [ - {"id": 1, "title": "Two"}, - ], - ), - ( - # Multiple rows update one of them - { - "rows": [{"id": 1, "title": "One"}, {"id": 2, "title": "Two"}], - "pk": "id", - }, - {"rows": [{"id": 1, "title": "Three"}]}, - [ - {"id": 1, "title": "Three"}, - {"id": 2, "title": "Two"}, - ], - ), - ( - # rowid update - {"rows": [{"title": "One"}]}, - {"rows": [{"rowid": 1, "title": "Two"}]}, - [ - {"rowid": 1, "title": "Two"}, - ], - ), - ( - # Compound primary key update - {"rows": [{"id": 1, "title": "One", "score": 1}], "pks": ["id", "score"]}, - {"rows": [{"id": 1, "title": "Two", "score": 1}]}, - [ - {"id": 1, "title": "Two", "score": 1}, - ], - ), - ( - # Upsert with an alter - {"rows": [{"id": 1, "title": "One"}], "pk": "id"}, - {"rows": [{"id": 1, "title": "Two", "extra": "extra"}], "alter": True}, - [{"id": 1, "title": "Two", "extra": "extra"}], - ), - ), -) -@pytest.mark.parametrize("should_return", (False, True)) -async def test_upsert(ds_write, initial, input, expected_rows, should_return): - token = write_token(ds_write) - # Insert initial data - initial["table"] = "upsert_test" - create_response = await ds_write.client.post( - "/data/-/create", - json=initial, - headers=_headers(token), - ) - assert create_response.status_code == 201 - if should_return: - input["return"] = True - response = await ds_write.client.post( - "/data/upsert_test/-/upsert", - json=input, - headers=_headers(token), - ) - assert response.status_code == 200, response.text - assert response.json()["ok"] is True - - # Analytics event - event = last_event(ds_write) - assert event.database == "data" - assert event.table == "upsert_test" - if input.get("alter"): - assert event.name == "alter-table" - assert "extra" in event.after_schema - else: - assert event.name == "upsert-rows" - assert event.num_rows == 1 - - if should_return: - # We only expect it to return rows corresponding to those we sent - expected_returned_rows = expected_rows[: len(input["rows"])] - assert response.json()["rows"] == expected_returned_rows - # Check the database too - actual_rows = ( - await ds_write.client.get("/data/upsert_test.json?_shape=array") - ).json() - assert actual_rows == expected_rows - # Drop the upsert_test table - await ds_write.get_database("data").execute_write("drop table upsert_test") - - -async def _insert_row(ds): - insert_response = await ds.client.post( - "/data/docs/-/insert", - json={"row": {"title": "Row one", "score": 1.2, "age": 5}, "return": True}, - headers=_headers(write_token(ds)), - ) - assert insert_response.status_code == 201 - return insert_response.json()["rows"][0]["id"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("scenario", ("no_token", "no_perm", "bad_table")) -async def test_delete_row_errors(ds_write, scenario): - if scenario == "no_token": - token = "bad_token" - elif scenario == "no_perm": - token = write_token(ds_write, actor_id="not-root") - else: - token = write_token(ds_write) - - pk = await _insert_row(ds_write) - - path = "/data/{}/{}/-/delete".format( - "docs" if scenario != "bad_table" else "bad_table", pk - ) - response = await ds_write.client.post( - path, - headers=_headers(token), - ) - assert response.status_code == 403 if scenario in ("no_token", "bad_token") else 404 - assert response.json()["ok"] is False - assert ( - response.json()["errors"] == ["Permission denied"] - if scenario == "no_token" - else ["Table not found"] - ) - assert len((await ds_write.client.get("/data/docs.json?_shape=array")).json()) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "table,row_for_create,pks,delete_path", - ( - ("rowid_table", {"name": "rowid row"}, None, None), - ("pk_table", {"id": 1, "name": "ID table"}, "id", "1"), - ( - "compound_pk_table", - {"type": "article", "key": "k"}, - ["type", "key"], - "article,k", - ), - ), -) -async def test_delete_row(ds_write, table, row_for_create, pks, delete_path): - # First create the table with that example row - create_data = { - "table": table, - "row": row_for_create, - } - if pks: - if isinstance(pks, str): - create_data["pk"] = pks - else: - create_data["pks"] = pks - create_response = await ds_write.client.post( - "/data/-/create", - json=create_data, - headers=_headers(write_token(ds_write)), - ) - assert create_response.status_code == 201, create_response.json() - # Should be a single row - assert ( - await ds_write.client.get( - "/data/-/query.json?_shape=arrayfirst&sql=select+count(*)+from+{}".format( - table - ) - ) - ).json() == [1] - # Now delete the row - if delete_path is None: - # Special case for that rowid table - delete_path = ( - await ds_write.client.get( - "/data/-/query.json?_shape=arrayfirst&sql=select+rowid+from+{}".format( - table - ) - ) - ).json()[0] - - delete_response = await ds_write.client.post( - "/data/{}/{}/-/delete".format(table, delete_path), - headers=_headers(write_token(ds_write)), - ) - assert delete_response.status_code == 200 - - # Analytics event - event = last_event(ds_write) - assert event.name == "delete-row" - assert event.database == "data" - assert event.table == table - assert event.pks == str(delete_path).split(",") - assert ( - await ds_write.client.get( - "/data/-/query.json?_shape=arrayfirst&sql=select+count(*)+from+{}".format( - table - ) - ) - ).json() == [0] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "scenario", ("no_token", "no_perm", "bad_table", "cannot_alter") -) -async def test_update_row_check_permission(ds_write, scenario): - if scenario == "no_token": - token = "bad_token" - elif scenario == "no_perm": - token = write_token(ds_write, actor_id="not-root") - elif scenario == "cannot_alter": - # update-row but no alter-table: - token = write_token(ds_write, permissions=["ur"]) - else: - token = write_token(ds_write) - - pk = await _insert_row(ds_write) - - path = "/data/{}/{}/-/update".format( - "docs" if scenario != "bad_table" else "bad_table", pk - ) - - json_body = {"update": {"title": "New title"}} - if scenario == "cannot_alter": - json_body["alter"] = True - - response = await ds_write.client.post( - path, - json=json_body, - headers=_headers(token), - ) - assert response.status_code == 403 if scenario in ("no_token", "bad_token") else 404 - assert response.json()["ok"] is False - assert ( - response.json()["errors"] == ["Permission denied"] - if scenario == "no_token" - else ["Table not found"] - ) - - -@pytest.mark.asyncio -async def test_update_row_invalid_key(ds_write): - token = write_token(ds_write) - - pk = await _insert_row(ds_write) - - path = "/data/docs/{}/-/update".format(pk) - response = await ds_write.client.post( - path, - json={"update": {"title": "New title"}, "bad_key": 1}, - headers=_headers(token), - ) - assert response.status_code == 400 - assert response.json() == {"ok": False, "errors": ["Invalid keys: bad_key"]} - - -@pytest.mark.asyncio -async def test_update_row_alter(ds_write): - token = write_token(ds_write, permissions=["ur", "at"]) - pk = await _insert_row(ds_write) - path = "/data/docs/{}/-/update".format(pk) - response = await ds_write.client.post( - path, - json={"update": {"title": "New title", "extra": "extra"}, "alter": True}, - headers=_headers(token), - ) - assert response.status_code == 200 - assert response.json() == {"ok": True} - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input,expected_errors", - ( - ({"title": "New title"}, None), - ({"title": None}, None), - ({"score": 1.6}, None), - ({"age": 10}, None), - ({"title": "New title", "score": 1.6}, None), - ({"title2": "New title"}, ["no such column: title2"]), - ), -) -@pytest.mark.parametrize("use_return", (True, False)) -async def test_update_row(ds_write, input, expected_errors, use_return): - token = write_token(ds_write) - pk = await _insert_row(ds_write) - - path = "/data/docs/{}/-/update".format(pk) - - data = {"update": input} - if use_return: - data["return"] = True - - response = await ds_write.client.post( - path, - json=data, - headers=_headers(token), - ) - if expected_errors: - assert response.status_code == 400 - assert response.json()["ok"] is False - assert response.json()["errors"] == expected_errors - return - - assert response.json()["ok"] is True - if not use_return: - assert "row" not in response.json() - else: - returned_row = response.json()["row"] - assert returned_row["id"] == pk - for k, v in input.items(): - assert returned_row[k] == v - - # Analytics event - event = last_event(ds_write) - assert event.actor == {"id": "root", "token": "dstok"} - assert event.database == "data" - assert event.table == "docs" - assert event.pks == [str(pk)] - - # And fetch the row to check it's updated - response = await ds_write.client.get( - "/data/docs/{}.json?_shape=array".format(pk), - ) - assert response.status_code == 200 - row = response.json()[0] - assert row["id"] == pk - for k, v in input.items(): - assert row[k] == v - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "scenario", ("no_token", "no_perm", "bad_table", "has_perm", "immutable") -) -async def test_drop_table(ds_write, scenario): - if scenario == "no_token": - token = "bad_token" - elif scenario == "no_perm": - token = write_token(ds_write, actor_id="not-root") - else: - token = write_token(ds_write) - should_work = scenario == "has_perm" - await ds_write.get_database("data").execute_write( - "insert into docs (id, title) values (1, 'Row 1')" - ) - path = "/{database}/{table}/-/drop".format( - database="immutable" if scenario == "immutable" else "data", - table="docs" if scenario != "bad_table" else "bad_table", - ) - response = await ds_write.client.post( - path, - headers=_headers(token), - ) - if not should_work: - assert ( - response.status_code == 403 - if scenario in ("no_token", "bad_token") - else 404 - ) - assert response.json()["ok"] is False - expected_error = "Permission denied" - if scenario == "bad_table": - expected_error = "Table not found" - elif scenario == "immutable": - expected_error = "Database is immutable" - assert response.json()["errors"] == [expected_error] - assert (await ds_write.client.get("/data/docs")).status_code == 200 - else: - # It should show a confirmation page - assert response.status_code == 200 - assert response.json() == { - "ok": True, - "database": "data", - "table": "docs", - "row_count": 1, - "message": 'Pass "confirm": true to confirm', - } - assert (await ds_write.client.get("/data/docs")).status_code == 200 - # Now send confirm: true - response2 = await ds_write.client.post( - path, - json={"confirm": True}, - headers=_headers(token), - ) - assert response2.json() == {"ok": True} - # Check event - event = last_event(ds_write) - assert event.name == "drop-table" - assert event.actor == {"id": "root", "token": "dstok"} - assert event.table == "docs" - assert event.database == "data" - # Table should 404 - assert (await ds_write.client.get("/data/docs")).status_code == 404 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "input,expected_status,expected_response,expected_events", - ( - # Permission error with a bad token - ( - {"table": "bad", "row": {"id": 1}}, - 403, - {"ok": False, "errors": ["Permission denied"]}, - [], - ), - # Successful creation with columns: - ( - { - "table": "one", - "columns": [ - { - "name": "id", - "type": "integer", - }, - { - "name": "title", - "type": "text", - }, - { - "name": "score", - "type": "integer", - }, - { - "name": "weight", - "type": "float", - }, - { - "name": "thumbnail", - "type": "blob", - }, - ], - "pk": "id", - }, - 201, - { - "ok": True, - "database": "data", - "table": "one", - "table_url": "http://localhost/data/one", - "table_api_url": "http://localhost/data/one.json", - "schema": ( - "CREATE TABLE [one] (\n" - " [id] INTEGER PRIMARY KEY,\n" - " [title] TEXT,\n" - " [score] INTEGER,\n" - " [weight] FLOAT,\n" - " [thumbnail] BLOB\n" - ")" - ), - }, - ["create-table"], - ), - # Successful creation with rows: - ( - { - "table": "two", - "rows": [ - { - "id": 1, - "title": "Row 1", - "score": 1.5, - }, - { - "id": 2, - "title": "Row 2", - "score": 1.5, - }, - ], - "pk": "id", - }, - 201, - { - "ok": True, - "database": "data", - "table": "two", - "table_url": "http://localhost/data/two", - "table_api_url": "http://localhost/data/two.json", - "schema": ( - "CREATE TABLE [two] (\n" - " [id] INTEGER PRIMARY KEY,\n" - " [title] TEXT,\n" - " [score] FLOAT\n" - ")" - ), - "row_count": 2, - }, - ["create-table", "insert-rows"], - ), - # Successful creation with row: - ( - { - "table": "three", - "row": { - "id": 1, - "title": "Row 1", - "score": 1.5, - }, - "pk": "id", - }, - 201, - { - "ok": True, - "database": "data", - "table": "three", - "table_url": "http://localhost/data/three", - "table_api_url": "http://localhost/data/three.json", - "schema": ( - "CREATE TABLE [three] (\n" - " [id] INTEGER PRIMARY KEY,\n" - " [title] TEXT,\n" - " [score] FLOAT\n" - ")" - ), - "row_count": 1, - }, - ["create-table", "insert-rows"], - ), - # Create with row and no primary key - ( - { - "table": "four", - "row": { - "name": "Row 1", - }, - }, - 201, - { - "ok": True, - "database": "data", - "table": "four", - "table_url": "http://localhost/data/four", - "table_api_url": "http://localhost/data/four.json", - "schema": ("CREATE TABLE [four] (\n" " [name] TEXT\n" ")"), - "row_count": 1, - }, - ["create-table", "insert-rows"], - ), - # Create table with compound primary key - ( - { - "table": "five", - "row": {"type": "article", "key": 123, "title": "Article 1"}, - "pks": ["type", "key"], - }, - 201, - { - "ok": True, - "database": "data", - "table": "five", - "table_url": "http://localhost/data/five", - "table_api_url": "http://localhost/data/five.json", - "schema": ( - "CREATE TABLE [five] (\n [type] TEXT,\n [key] INTEGER,\n" - " [title] TEXT,\n PRIMARY KEY ([type], [key])\n)" - ), - "row_count": 1, - }, - ["create-table", "insert-rows"], - ), - # Error: Table is required - ( - { - "row": {"id": 1}, - }, - 400, - { - "ok": False, - "errors": ["Table is required"], - }, - [], - ), - # Error: Invalid table name - ( - { - "table": "sqlite_bad_name", - "row": {"id": 1}, - }, - 400, - { - "ok": False, - "errors": ["Invalid table name"], - }, - [], - ), - # Error: JSON must be an object - ( - [], - 400, - { - "ok": False, - "errors": ["JSON must be an object"], - }, - [], - ), - # Error: Cannot specify columns with rows or row - ( - { - "table": "bad", - "columns": [{"name": "id", "type": "integer"}], - "rows": [{"id": 1}], - }, - 400, - { - "ok": False, - "errors": ["Cannot specify columns with rows or row"], - }, - [], - ), - # Error: columns, rows or row is required - ( - { - "table": "bad", - }, - 400, - { - "ok": False, - "errors": ["columns, rows or row is required"], - }, - [], - ), - # Error: columns must be a list - ( - { - "table": "bad", - "columns": {"name": "id", "type": "integer"}, - }, - 400, - { - "ok": False, - "errors": ["columns must be a list"], - }, - [], - ), - # Error: columns must be a list of objects - ( - { - "table": "bad", - "columns": ["id"], - }, - 400, - { - "ok": False, - "errors": ["columns must be a list of objects"], - }, - [], - ), - # Error: Column name is required - ( - { - "table": "bad", - "columns": [{"type": "integer"}], - }, - 400, - { - "ok": False, - "errors": ["Column name is required"], - }, - [], - ), - # Error: Unsupported column type - ( - { - "table": "bad", - "columns": [{"name": "id", "type": "bad"}], - }, - 400, - { - "ok": False, - "errors": ["Unsupported column type: bad"], - }, - [], - ), - # Error: Duplicate column name - ( - { - "table": "bad", - "columns": [ - {"name": "id", "type": "integer"}, - {"name": "id", "type": "integer"}, - ], - }, - 400, - { - "ok": False, - "errors": ["Duplicate column name: id"], - }, - [], - ), - # Error: rows must be a list - ( - { - "table": "bad", - "rows": {"id": 1}, - }, - 400, - { - "ok": False, - "errors": ["rows must be a list"], - }, - [], - ), - # Error: rows must be a list of objects - ( - { - "table": "bad", - "rows": ["id"], - }, - 400, - { - "ok": False, - "errors": ["rows must be a list of objects"], - }, - [], - ), - # Error: pk must be a string - ( - { - "table": "bad", - "row": {"id": 1}, - "pk": 1, - }, - 400, - { - "ok": False, - "errors": ["pk must be a string"], - }, - [], - ), - # Error: Cannot specify both pk and pks - ( - { - "table": "bad", - "row": {"id": 1, "name": "Row 1"}, - "pk": "id", - "pks": ["id", "name"], - }, - 400, - { - "ok": False, - "errors": ["Cannot specify both pk and pks"], - }, - [], - ), - # Error: pks must be a list - ( - { - "table": "bad", - "row": {"id": 1, "name": "Row 1"}, - "pks": "id", - }, - 400, - { - "ok": False, - "errors": ["pks must be a list"], - }, - [], - ), - # Error: pks must be a list of strings - ( - {"table": "bad", "row": {"id": 1, "name": "Row 1"}, "pks": [1, 2]}, - 400, - {"ok": False, "errors": ["pks must be a list of strings"]}, - [], - ), - # Error: ignore and replace are mutually exclusive - ( - { - "table": "bad", - "row": {"id": 1, "name": "Row 1"}, - "pk": "id", - "ignore": True, - "replace": True, - }, - 400, - { - "ok": False, - "errors": ["ignore and replace are mutually exclusive"], - }, - [], - ), - # ignore and replace require row or rows - ( - { - "table": "bad", - "columns": [{"name": "id", "type": "integer"}], - "ignore": True, - }, - 400, - { - "ok": False, - "errors": ["ignore and replace require row or rows"], - }, - [], - ), - # ignore and replace require pk or pks - ( - { - "table": "bad", - "row": {"id": 1}, - "ignore": True, - }, - 400, - { - "ok": False, - "errors": ["ignore and replace require pk or pks"], - }, - [], - ), - ( - { - "table": "bad", - "row": {"id": 1}, - "replace": True, - }, - 400, - { - "ok": False, - "errors": ["ignore and replace require pk or pks"], - }, - [], - ), - ), -) -async def test_create_table( - ds_write, input, expected_status, expected_response, expected_events -): - ds_write._tracked_events = [] - # Special case for expected status of 403 - if expected_status == 403: - token = "bad_token" - else: - token = write_token(ds_write) - response = await ds_write.client.post( - "/data/-/create", - json=input, - headers=_headers(token), - ) - assert response.status_code == expected_status - data = response.json() - assert data == expected_response - # Should have tracked the expected events - events = ds_write._tracked_events - assert [e.name for e in events] == expected_events - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "permissions,body,expected_status,expected_errors", - ( - (["create-table"], {"table": "t", "columns": [{"name": "c"}]}, 201, None), - # Need insert-row too if you use "rows": - ( - ["create-table"], - {"table": "t", "rows": [{"name": "c"}]}, - 403, - ["Permission denied: need insert-row"], - ), - # This should work: - ( - ["create-table", "insert-row"], - {"table": "t", "rows": [{"name": "c"}]}, - 201, - None, - ), - # If you use replace: true you need update-row too: - ( - ["create-table", "insert-row"], - {"table": "t", "rows": [{"id": 1}], "pk": "id", "replace": True}, - 403, - ["Permission denied: need update-row"], - ), - ), -) -async def test_create_table_permissions( - ds_write, permissions, body, expected_status, expected_errors -): - from datasette.tokens import TokenRestrictions - - restrictions = TokenRestrictions() - for action in ["view-instance"] + permissions: - restrictions.allow_all(action) - token = await ds_write.create_token( - "root", handler="signed", restrictions=restrictions - ) - response = await ds_write.client.post( - "/data/-/create", - json=body, - headers=_headers(token), - ) - assert response.status_code == expected_status - if expected_errors: - data = response.json() - assert data["ok"] is False - assert data["errors"] == expected_errors - - -@pytest.mark.asyncio -@pytest.mark.xfail(reason="Flaky, see https://github.com/simonw/datasette/issues/2356") -@pytest.mark.parametrize( - "input,expected_rows_after", - ( - ( - { - "table": "test_insert_replace", - "rows": [ - {"id": 1, "name": "Row 1 new"}, - {"id": 3, "name": "Row 3 new"}, - ], - "pk": "id", - "ignore": True, - }, - [ - {"id": 1, "name": "Row 1"}, - {"id": 2, "name": "Row 2"}, - {"id": 3, "name": "Row 3 new"}, - ], - ), - ( - { - "table": "test_insert_replace", - "rows": [ - {"id": 1, "name": "Row 1 new"}, - {"id": 3, "name": "Row 3 new"}, - ], - "pk": "id", - "replace": True, - }, - [ - {"id": 1, "name": "Row 1 new"}, - {"id": 2, "name": "Row 2"}, - {"id": 3, "name": "Row 3 new"}, - ], - ), - ), -) -async def test_create_table_ignore_replace(ds_write, input, expected_rows_after): - # Create table with two rows - token = write_token(ds_write) - first_response = await ds_write.client.post( - "/data/-/create", - json={ - "rows": [{"id": 1, "name": "Row 1"}, {"id": 2, "name": "Row 2"}], - "table": "test_insert_replace", - "pk": "id", - }, - headers=_headers(token), - ) - assert first_response.status_code == 201 - - ds_write._tracked_events = [] - - # Try a second time - second_response = await ds_write.client.post( - "/data/-/create", - json=input, - headers=_headers(token), - ) - assert second_response.status_code == 201 - # Check that the rows are as expected - rows = await ds_write.client.get("/data/test_insert_replace.json?_shape=array") - assert rows.json() == expected_rows_after - - # Check it fired the right events - event_names = [e.name for e in ds_write._tracked_events] - assert event_names == ["insert-rows"] - - -@pytest.mark.asyncio -async def test_create_table_error_if_pk_changed(ds_write): - token = write_token(ds_write) - first_response = await ds_write.client.post( - "/data/-/create", - json={ - "rows": [{"id": 1, "name": "Row 1"}, {"id": 2, "name": "Row 2"}], - "table": "test_insert_replace", - "pk": "id", - }, - headers=_headers(token), - ) - assert first_response.status_code == 201 - # Try a second time with a different pk - second_response = await ds_write.client.post( - "/data/-/create", - json={ - "rows": [{"id": 1, "name": "Row 1"}, {"id": 2, "name": "Row 2"}], - "table": "test_insert_replace", - "pk": "name", - "replace": True, - }, - headers=_headers(token), - ) - assert second_response.status_code == 400 - assert second_response.json() == { - "ok": False, - "errors": ["pk cannot be changed for existing table"], - } - - -@pytest.mark.asyncio -async def test_create_table_error_rows_twice_with_duplicates(ds_write): - # Error if you don't send ignore: True or replace: True - token = write_token(ds_write) - input = { - "rows": [{"id": 1, "name": "Row 1"}, {"id": 2, "name": "Row 2"}], - "table": "test_create_twice", - "pk": "id", - } - first_response = await ds_write.client.post( - "/data/-/create", - json=input, - headers=_headers(token), - ) - assert first_response.status_code == 201 - second_response = await ds_write.client.post( - "/data/-/create", - json=input, - headers=_headers(token), - ) - assert second_response.status_code == 400 - assert second_response.json() == { - "ok": False, - "errors": ["UNIQUE constraint failed: test_create_twice.id"], - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path", - ( - "/data/-/create", - "/data/docs/-/drop", - "/data/docs/-/insert", - ), -) -async def test_method_not_allowed(ds_write, path): - response = await ds_write.client.get( - path, - headers={ - "Content-Type": "application/json", - }, - ) - assert response.status_code == 405 - assert response.json() == { - "ok": False, - "error": "Method not allowed", - } - - -@pytest.mark.asyncio -async def test_create_uses_alter_by_default_for_new_table(ds_write): - ds_write._tracked_events = [] - token = write_token(ds_write) - response = await ds_write.client.post( - "/data/-/create", - json={ - "table": "new_table", - "rows": [ - { - "name": "Row 1", - } - ] - * 100 - + [ - {"name": "Row 2", "extra": "Extra"}, - ], - "pk": "id", - }, - headers=_headers(token), - ) - assert response.status_code == 201 - event_names = [e.name for e in ds_write._tracked_events] - assert event_names == ["create-table", "insert-rows"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("has_alter_permission", (True, False)) -async def test_create_using_alter_against_existing_table( - ds_write, has_alter_permission -): - token = write_token( - ds_write, permissions=["ir", "ct"] + (["at"] if has_alter_permission else []) - ) - # First create the table - response = await ds_write.client.post( - "/data/-/create", - json={ - "table": "new_table", - "rows": [ - { - "name": "Row 1", - } - ], - "pk": "id", - }, - headers=_headers(token), - ) - assert response.status_code == 201 - - ds_write._tracked_events = [] - # Now try to insert more rows using /-/create with alter=True - response2 = await ds_write.client.post( - "/data/-/create", - json={ - "table": "new_table", - "rows": [{"name": "Row 2", "extra": "extra"}], - "pk": "id", - "alter": True, - }, - headers=_headers(token), - ) - if not has_alter_permission: - assert response2.status_code == 403 - assert response2.json() == { - "ok": False, - "errors": ["Permission denied: need alter-table"], - } - else: - assert response2.status_code == 201 - - event_names = [e.name for e in ds_write._tracked_events] - assert event_names == ["alter-table", "insert-rows"] - - # It should have altered the table - alter_event = ds_write._tracked_events[0] - assert alter_event.name == "alter-table" - assert "extra" not in alter_event.before_schema - assert "extra" in alter_event.after_schema - - insert_rows_event = ds_write._tracked_events[1] - assert insert_rows_event.name == "insert-rows" - assert insert_rows_event.num_rows == 1 diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 5868a21c..00000000 --- a/tests/test_auth.py +++ /dev/null @@ -1,512 +0,0 @@ -from bs4 import BeautifulSoup as Soup -from .utils import cookie_was_deleted, last_event -from click.testing import CliRunner -from datasette.utils import baseconv -from datasette.cli import cli -from datasette.resources import ( - DatabaseResource, - TableResource, -) -import pytest -import time - - -@pytest.mark.asyncio -async def test_auth_token(ds_client): - """The /-/auth-token endpoint sets the correct cookie""" - assert ds_client.ds._root_token is not None - path = f"/-/auth-token?token={ds_client.ds._root_token}" - response = await ds_client.get(path) - assert response.status_code == 302 - assert "/" == response.headers["Location"] - assert {"a": {"id": "root"}} == ds_client.ds.unsign( - response.cookies["ds_actor"], "actor" - ) - # Should have recorded a login event - event = last_event(ds_client.ds) - assert event.name == "login" - assert event.actor == {"id": "root"} - # Check that a second with same token fails - assert ds_client.ds._root_token is None - assert (await ds_client.get(path)).status_code == 403 - # But attempting with same token while logged in as root should redirect to / - response = await ds_client.get( - path, cookies={"ds_actor": ds_client.actor_cookie({"id": "root"})} - ) - assert response.status_code == 302 - assert response.headers["Location"] == "/" - - -@pytest.mark.asyncio -async def test_actor_cookie(ds_client): - """A valid actor cookie sets request.scope['actor']""" - cookie = ds_client.actor_cookie({"id": "test"}) - await ds_client.get("/", cookies={"ds_actor": cookie}) - assert ds_client.ds._last_request.scope["actor"] == {"id": "test"} - - -@pytest.mark.asyncio -async def test_actor_cookie_invalid(ds_client): - cookie = ds_client.actor_cookie({"id": "test"}) - # Break the signature - await ds_client.get("/", cookies={"ds_actor": cookie[:-1] + "."}) - assert ds_client.ds._last_request.scope["actor"] is None - # Break the cookie format - cookie = ds_client.ds.sign({"b": {"id": "test"}}, "actor") - await ds_client.get("/", cookies={"ds_actor": cookie}) - assert ds_client.ds._last_request.scope["actor"] is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "offset,expected", - [ - ((24 * 60 * 60), {"id": "test"}), - (-(24 * 60 * 60), None), - ], -) -async def test_actor_cookie_that_expires(ds_client, offset, expected): - expires_at = int(time.time()) + offset - cookie = ds_client.ds.sign( - {"a": {"id": "test"}, "e": baseconv.base62.encode(expires_at)}, "actor" - ) - await ds_client.get("/", cookies={"ds_actor": cookie}) - assert ds_client.ds._last_request.scope["actor"] == expected - - -def test_logout(app_client): - # Keeping app_client for the moment because of csrftoken_from - response = app_client.get( - "/-/logout", cookies={"ds_actor": app_client.actor_cookie({"id": "test"})} - ) - assert 200 == response.status - assert "<p>You are logged in as <strong>test</strong></p>" in response.text - # Actors without an id get full serialization - response2 = app_client.get( - "/-/logout", cookies={"ds_actor": app_client.actor_cookie({"name2": "bob"})} - ) - assert 200 == response2.status - assert ( - "<p>You are logged in as <strong>{'name2': 'bob'}</strong></p>" - in response2.text - ) - # If logged out you get a redirect to / - response3 = app_client.get("/-/logout") - assert 302 == response3.status - # A POST to that page should log the user out - response4 = app_client.post( - "/-/logout", - csrftoken_from=True, - cookies={"ds_actor": app_client.actor_cookie({"id": "test"})}, - ) - # Should have recorded a logout event - event = last_event(app_client.ds) - assert event.name == "logout" - assert event.actor == {"id": "test"} - # The ds_actor cookie should have been unset - assert cookie_was_deleted(response4, "ds_actor") - # Should also have set a message - messages = app_client.ds.unsign(response4.cookies["ds_messages"], "messages") - assert [["You are now logged out", 2]] == messages - - -@pytest.mark.asyncio -@pytest.mark.parametrize("path", ["/", "/fixtures", "/fixtures/facetable"]) -async def test_logout_button_in_navigation(ds_client, path): - response = await ds_client.get( - path, cookies={"ds_actor": ds_client.actor_cookie({"id": "test"})} - ) - anon_response = await ds_client.get(path) - for fragment in ( - "<strong>test</strong>", - '<form class="nav-menu-logout" action="/-/logout" method="post">', - ): - assert fragment in response.text - assert fragment not in anon_response.text - - -@pytest.mark.asyncio -@pytest.mark.parametrize("path", ["/", "/fixtures", "/fixtures/facetable"]) -async def test_no_logout_button_in_navigation_if_no_ds_actor_cookie(ds_client, path): - response = await ds_client.get(path + "?_bot=1") - assert "<strong>bot</strong>" in response.text - assert ( - '<form class="nav-menu-logout" action="/-/logout" method="post">' - not in response.text - ) - - -@pytest.mark.parametrize( - "post_data,errors,expected_duration,expected_r", - ( - ({"expire_type": ""}, [], None, None), - ({"expire_type": "x"}, ["Invalid expire duration"], None, None), - ({"expire_type": "minutes"}, ["Invalid expire duration"], None, None), - ( - {"expire_type": "minutes", "expire_duration": "x"}, - ["Invalid expire duration"], - None, - None, - ), - ( - {"expire_type": "minutes", "expire_duration": "-1"}, - ["Invalid expire duration"], - None, - None, - ), - ( - {"expire_type": "minutes", "expire_duration": "0"}, - ["Invalid expire duration"], - None, - None, - ), - ({"expire_type": "minutes", "expire_duration": "10"}, [], 600, None), - ({"expire_type": "hours", "expire_duration": "10"}, [], 10 * 60 * 60, None), - ({"expire_type": "days", "expire_duration": "3"}, [], 60 * 60 * 24 * 3, None), - # Token restrictions - ({"all:view-instance": "on"}, [], None, {"a": ["vi"]}), - ({"database:fixtures:view-query": "on"}, [], None, {"d": {"fixtures": ["vq"]}}), - ( - {"resource:fixtures:facetable:insert-row": "on"}, - [], - None, - {"r": {"fixtures": {"facetable": ["ir"]}}}, - ), - ), -) -def test_auth_create_token( - app_client, post_data, errors, expected_duration, expected_r -): - assert app_client.get("/-/create-token").status == 403 - ds_actor = app_client.actor_cookie({"id": "test"}) - response = app_client.get("/-/create-token", cookies={"ds_actor": ds_actor}) - assert response.status == 200 - assert ">Create an API token<" in response.text - # Confirm some aspects of expected set of checkboxes - soup = Soup(response.text, "html.parser") - checkbox_names = {el["name"] for el in soup.select('input[type="checkbox"]')} - assert checkbox_names.issuperset( - { - "all:view-instance", - "all:view-query", - "database:fixtures:drop-table", - "resource:fixtures:foreign_key_references:insert-row", - "resource:fixtures:facetable:set-column-type", - } - ) - # Now try actually creating one - response2 = app_client.post( - "/-/create-token", - post_data, - csrftoken_from=True, - cookies={"ds_actor": ds_actor}, - ) - assert response2.status == 200 - if errors: - for error in errors: - assert '<p class="message-error">{}</p>'.format(error) in response2.text - else: - # Check create-token event - event = last_event(app_client.ds) - assert event.name == "create-token" - assert event.expires_after == expected_duration - assert isinstance(event.restrict_all, list) - assert isinstance(event.restrict_database, dict) - assert isinstance(event.restrict_resource, dict) - # Extract token from page - token = response2.text.split('value="dstok_')[1].split('"')[0] - details = app_client.ds.unsign(token, "token") - if expected_r: - r = details.pop("_r") - assert r == expected_r - assert details.keys() == {"a", "t", "d"} or details.keys() == {"a", "t"} - assert details["a"] == "test" - if expected_duration is None: - assert "d" not in details - else: - assert details["d"] == expected_duration - # And test that token - response3 = app_client.get( - "/-/actor.json", - headers={"Authorization": "Bearer {}".format("dstok_{}".format(token))}, - ) - assert response3.status == 200 - assert response3.json["actor"]["id"] == "test" - - -@pytest.mark.asyncio -async def test_auth_create_token_not_allowed_for_tokens(ds_client): - ds_tok = ds_client.ds.sign({"a": "test", "token": "dstok"}, "token") - response = await ds_client.get( - "/-/create-token", - headers={"Authorization": "Bearer dstok_{}".format(ds_tok)}, - ) - assert response.status_code == 403 - - -@pytest.mark.asyncio -async def test_auth_create_token_not_allowed_if_allow_signed_tokens_off(ds_client): - ds_client.ds._settings["allow_signed_tokens"] = False - try: - ds_actor = ds_client.actor_cookie({"id": "test"}) - response = await ds_client.get( - "/-/create-token", cookies={"ds_actor": ds_actor} - ) - assert response.status_code == 403 - finally: - ds_client.ds._settings["allow_signed_tokens"] = True - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "scenario,should_work", - ( - ("allow_signed_tokens_off", False), - ("no_token", False), - ("no_timestamp", False), - ("invalid_token", False), - ("expired_token", False), - ("valid_unlimited_token", True), - ("valid_expiring_token", True), - ), -) -async def test_auth_with_dstok_token(ds_client, scenario, should_work): - token = None - _time = int(time.time()) - if scenario in ("valid_unlimited_token", "allow_signed_tokens_off"): - token = ds_client.ds.sign({"a": "test", "t": _time}, "token") - elif scenario == "valid_expiring_token": - token = ds_client.ds.sign({"a": "test", "t": _time - 50, "d": 1000}, "token") - elif scenario == "expired_token": - token = ds_client.ds.sign({"a": "test", "t": _time - 2000, "d": 1000}, "token") - elif scenario == "no_timestamp": - token = ds_client.ds.sign({"a": "test"}, "token") - elif scenario == "invalid_token": - token = "invalid" - if token: - token = "dstok_{}".format(token) - if scenario == "allow_signed_tokens_off": - ds_client.ds._settings["allow_signed_tokens"] = False - headers = {} - if token: - headers["Authorization"] = "Bearer {}".format(token) - response = await ds_client.get("/-/actor.json", headers=headers) - try: - if should_work: - data = response.json() - assert data.keys() == {"actor"} - actor = data["actor"] - expected_keys = {"id", "token"} - if scenario != "valid_unlimited_token": - expected_keys.add("token_expires") - assert actor.keys() == expected_keys - assert actor["id"] == "test" - assert actor["token"] == "dstok" - if scenario != "valid_unlimited_token": - assert isinstance(actor["token_expires"], int) - else: - assert response.json() == {"actor": None} - finally: - ds_client.ds._settings["allow_signed_tokens"] = True - - -@pytest.mark.parametrize("expires", (None, 1000, -1000)) -def test_cli_create_token(app_client, expires): - secret = app_client.ds._secret - runner = CliRunner() - args = ["create-token", "--secret", secret, "test"] - if expires: - args += ["--expires-after", str(expires)] - result = runner.invoke(cli, args) - assert result.exit_code == 0 - token = result.output.strip() - assert token.startswith("dstok_") - details = app_client.ds.unsign(token[len("dstok_") :], "token") - expected_keys = {"a", "t"} - if expires: - expected_keys.add("d") - assert details.keys() == expected_keys - assert details["a"] == "test" - response = app_client.get( - "/-/actor.json", headers={"Authorization": "Bearer {}".format(token)} - ) - if expires is None or expires > 0: - expected_actor = { - "id": "test", - "token": "dstok", - } - if expires and expires > 0: - expected_actor["token_expires"] = details["t"] + expires - assert response.json == {"actor": expected_actor} - else: - expected_actor = None - assert response.json == {"actor": expected_actor} - - -@pytest.mark.asyncio -async def test_root_with_root_enabled_gets_all_permissions(ds_client): - """Root user with root_enabled=True gets all permissions""" - # Ensure catalog tables are populated - await ds_client.ds.invoke_startup() - await ds_client.ds._refresh_schemas() - - # Set root_enabled to simulate --root flag - ds_client.ds.root_enabled = True - - root_actor = {"id": "root"} - - # Test instance-level permissions (no resource) - assert ( - await ds_client.ds.allowed(action="permissions-debug", actor=root_actor) is True - ) - assert await ds_client.ds.allowed(action="debug-menu", actor=root_actor) is True - - # Test view permissions using the new ds.allowed() method - assert await ds_client.ds.allowed(action="view-instance", actor=root_actor) is True - - assert ( - await ds_client.ds.allowed( - action="view-database", - resource=DatabaseResource("fixtures"), - actor=root_actor, - ) - is True - ) - - assert ( - await ds_client.ds.allowed( - action="view-table", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is True - ) - - # Test write permissions using ds.allowed() - assert ( - await ds_client.ds.allowed( - action="insert-row", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is True - ) - - assert ( - await ds_client.ds.allowed( - action="delete-row", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is True - ) - - assert ( - await ds_client.ds.allowed( - action="update-row", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is True - ) - - assert ( - await ds_client.ds.allowed( - action="create-table", - resource=DatabaseResource("fixtures"), - actor=root_actor, - ) - is True - ) - - assert ( - await ds_client.ds.allowed( - action="alter-table", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is True - ) - - assert ( - await ds_client.ds.allowed( - action="set-column-type", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is True - ) - - assert ( - await ds_client.ds.allowed( - action="drop-table", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is True - ) - - -@pytest.mark.asyncio -async def test_root_without_root_enabled_no_special_permissions(ds_client): - """Root user without root_enabled doesn't get automatic permissions""" - # Ensure catalog tables are populated - await ds_client.ds.invoke_startup() - await ds_client.ds._refresh_schemas() - - # Ensure root_enabled is NOT set (or is False) - ds_client.ds.root_enabled = False - - root_actor = {"id": "root"} - - # Test permissions that normally require special access - # Without root_enabled, root should follow normal permission rules - - # View permissions should still work (default=True) - assert ( - await ds_client.ds.allowed(action="view-instance", actor=root_actor) is True - ) # Default permission - - assert ( - await ds_client.ds.allowed( - action="view-database", - resource=DatabaseResource("fixtures"), - actor=root_actor, - ) - is True - ) # Default permission - - # But restricted permissions should NOT automatically be granted - # Test with instance-level permission (no resource class) - result = await ds_client.ds.allowed(action="permissions-debug", actor=root_actor) - assert ( - result is not True - ), "Root without root_enabled should not automatically get permissions-debug" - - # Test with resource-based permissions using ds.allowed() - assert ( - await ds_client.ds.allowed( - action="create-table", - resource=DatabaseResource("fixtures"), - actor=root_actor, - ) - is not True - ), "Root without root_enabled should not automatically get create-table" - - assert ( - await ds_client.ds.allowed( - action="drop-table", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is not True - ), "Root without root_enabled should not automatically get drop-table" - - assert ( - await ds_client.ds.allowed( - action="set-column-type", - resource=TableResource("fixtures", "facetable"), - actor=root_actor, - ) - is not True - ), "Root without root_enabled should not automatically get set-column-type" diff --git a/tests/test_base_view.py b/tests/test_base_view.py deleted file mode 100644 index 2cd4d601..00000000 --- a/tests/test_base_view.py +++ /dev/null @@ -1,84 +0,0 @@ -from datasette.views.base import View -from datasette import Request, Response -from datasette.app import Datasette -import json -import pytest - - -class GetView(View): - async def get(self, request, datasette): - return Response.json( - { - "absolute_url": datasette.absolute_url(request, "/"), - "request_path": request.path, - } - ) - - -class GetAndPostView(GetView): - async def post(self, request, datasette): - return Response.json( - { - "method": request.method, - "absolute_url": datasette.absolute_url(request, "/"), - "request_path": request.path, - } - ) - - -@pytest.mark.asyncio -async def test_get_view(): - v = GetView() - datasette = Datasette() - response = await v(Request.fake("/foo"), datasette) - assert json.loads(response.body) == { - "absolute_url": "http://localhost/", - "request_path": "/foo", - } - # Try a HEAD request - head_response = await v(Request.fake("/foo", method="HEAD"), datasette) - assert head_response.body == "" - assert head_response.status == 200 - # And OPTIONS - options_response = await v(Request.fake("/foo", method="OPTIONS"), datasette) - assert options_response.body == "ok" - assert options_response.status == 200 - assert options_response.headers["allow"] == "HEAD, GET" - # And POST - post_response = await v(Request.fake("/foo", method="POST"), datasette) - assert post_response.body == "Method not allowed" - assert post_response.status == 405 - # And POST with .json extension - post_json_response = await v(Request.fake("/foo.json", method="POST"), datasette) - assert json.loads(post_json_response.body) == { - "ok": False, - "error": "Method not allowed", - } - assert post_json_response.status == 405 - - -@pytest.mark.asyncio -async def test_post_view(): - v = GetAndPostView() - datasette = Datasette() - response = await v(Request.fake("/foo"), datasette) - assert json.loads(response.body) == { - "absolute_url": "http://localhost/", - "request_path": "/foo", - } - # Try a HEAD request - head_response = await v(Request.fake("/foo", method="HEAD"), datasette) - assert head_response.body == "" - assert head_response.status == 200 - # And OPTIONS - options_response = await v(Request.fake("/foo", method="OPTIONS"), datasette) - assert options_response.body == "ok" - assert options_response.status == 200 - assert options_response.headers["allow"] == "HEAD, GET, POST" - # And POST - post_response = await v(Request.fake("/foo", method="POST"), datasette) - assert json.loads(post_response.body) == { - "method": "POST", - "absolute_url": "http://localhost/", - "request_path": "/foo", - } diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index 1d3a2b28..00000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,583 +0,0 @@ -from .fixtures import ( - make_app_client, - TestClient as _TestClient, - EXPECTED_PLUGINS, -) -from datasette.app import SETTINGS -from datasette.plugins import DEFAULT_PLUGINS, pm -from datasette.cli import cli, serve -from datasette.version import __version__ -from datasette.utils import tilde_encode -from datasette.utils.sqlite import sqlite3 -from click.testing import CliRunner -import io -import json -import pathlib -import pytest -import sys -import textwrap -from unittest import mock - - -def test_inspect_cli(app_client): - runner = CliRunner() - result = runner.invoke(cli, ["inspect", "fixtures.db"]) - data = json.loads(result.output) - assert ["fixtures"] == list(data.keys()) - database = data["fixtures"] - assert "fixtures.db" == database["file"] - assert isinstance(database["hash"], str) - assert 64 == len(database["hash"]) - for table_name, expected_count in { - "Table With Space In Name": 0, - "facetable": 15, - }.items(): - assert expected_count == database["tables"][table_name]["count"] - - -def test_inspect_cli_writes_to_file(app_client): - runner = CliRunner() - result = runner.invoke( - cli, ["inspect", "fixtures.db", "--inspect-file", "foo.json"] - ) - assert 0 == result.exit_code, result.output - with open("foo.json") as fp: - data = json.load(fp) - assert ["fixtures"] == list(data.keys()) - - -def test_serve_with_inspect_file_prepopulates_table_counts_cache(): - inspect_data = {"fixtures": {"tables": {"hithere": {"count": 44}}}} - with make_app_client(inspect_data=inspect_data, is_immutable=True) as client: - assert inspect_data == client.ds.inspect_data - db = client.ds.databases["fixtures"] - assert {"hithere": 44} == db.cached_table_counts - - -@pytest.mark.parametrize( - "spatialite_paths,should_suggest_load_extension", - ( - ([], False), - (["/tmp"], True), - ), -) -def test_spatialite_error_if_attempt_to_open_spatialite( - spatialite_paths, should_suggest_load_extension -): - with mock.patch("datasette.utils.SPATIALITE_PATHS", spatialite_paths): - runner = CliRunner() - result = runner.invoke( - cli, ["serve", str(pathlib.Path(__file__).parent / "spatialite.db")] - ) - assert result.exit_code != 0 - assert "It looks like you're trying to load a SpatiaLite" in result.output - suggestion = "--load-extension=spatialite" - if should_suggest_load_extension: - assert suggestion in result.output - else: - assert suggestion not in result.output - - -@mock.patch("datasette.utils.SPATIALITE_PATHS", ["/does/not/exist"]) -def test_spatialite_error_if_cannot_find_load_extension_spatialite(): - runner = CliRunner() - result = runner.invoke( - cli, - [ - "serve", - str(pathlib.Path(__file__).parent / "spatialite.db"), - "--load-extension", - "spatialite", - ], - ) - assert result.exit_code != 0 - assert "Could not find SpatiaLite extension" in result.output - - -def test_plugins_cli(app_client): - runner = CliRunner() - result1 = runner.invoke(cli, ["plugins"]) - actual_plugins = sorted( - [p for p in json.loads(result1.output) if p["name"] != "TrackEventPlugin"], - key=lambda p: p["name"], - ) - assert actual_plugins == EXPECTED_PLUGINS - # Try with --all - result2 = runner.invoke(cli, ["plugins", "--all"]) - names = [p["name"] for p in json.loads(result2.output)] - # Should have all the EXPECTED_PLUGINS - assert set(names).issuperset({p["name"] for p in EXPECTED_PLUGINS}) - # And the following too: - assert set(names).issuperset(DEFAULT_PLUGINS) - # --requirements should be empty because there are no installed non-plugins-dir plugins - result3 = runner.invoke(cli, ["plugins", "--requirements"]) - assert result3.output == "" - - -def test_metadata_yaml(): - yaml_file = io.StringIO(textwrap.dedent(""" - title: Hello from YAML - """)) - # Annoyingly we have to provide all default arguments here: - ds = serve.callback( - [], - metadata=yaml_file, - immutable=[], - host="127.0.0.1", - port=8001, - uds=None, - reload=False, - cors=False, - sqlite_extensions=[], - inspect_file=None, - template_dir=None, - plugins_dir=None, - static=[], - memory=False, - config=[], - settings=[], - secret=None, - root=False, - default_deny=False, - token=None, - actor=None, - version_note=None, - get=None, - headers=False, - help_settings=False, - pdb=False, - crossdb=False, - nolock=False, - open_browser=False, - create=False, - ssl_keyfile=None, - ssl_certfile=None, - return_instance=True, - internal=None, - ) - client = _TestClient(ds) - response = client.get("/.json") - assert {"title": "Hello from YAML"} == response.json["metadata"] - - -@mock.patch("datasette.cli.run_module") -def test_install(run_module): - runner = CliRunner() - runner.invoke(cli, ["install", "datasette-mock-plugin", "datasette-mock-plugin2"]) - run_module.assert_called_once_with("pip", run_name="__main__") - assert sys.argv == [ - "pip", - "install", - "datasette-mock-plugin", - "datasette-mock-plugin2", - ] - - -@pytest.mark.parametrize("flag", ["-U", "--upgrade"]) -@mock.patch("datasette.cli.run_module") -def test_install_upgrade(run_module, flag): - runner = CliRunner() - runner.invoke(cli, ["install", flag, "datasette"]) - run_module.assert_called_once_with("pip", run_name="__main__") - assert sys.argv == ["pip", "install", "--upgrade", "datasette"] - - -@mock.patch("datasette.cli.run_module") -def test_install_requirements(run_module, tmpdir): - path = tmpdir.join("requirements.txt") - path.write("datasette-mock-plugin\ndatasette-plugin-2") - runner = CliRunner() - runner.invoke(cli, ["install", "-r", str(path)]) - run_module.assert_called_once_with("pip", run_name="__main__") - assert sys.argv == ["pip", "install", "-r", str(path)] - - -def test_install_error_if_no_packages(): - runner = CliRunner() - result = runner.invoke(cli, ["install"]) - assert result.exit_code == 2 - assert "Error: Please specify at least one package to install" in result.output - - -@mock.patch("datasette.cli.run_module") -def test_uninstall(run_module): - runner = CliRunner() - runner.invoke(cli, ["uninstall", "datasette-mock-plugin", "-y"]) - run_module.assert_called_once_with("pip", run_name="__main__") - assert sys.argv == ["pip", "uninstall", "datasette-mock-plugin", "-y"] - - -def test_version(): - runner = CliRunner() - result = runner.invoke(cli, ["--version"]) - assert result.output == f"cli, version {__version__}\n" - - -@pytest.mark.parametrize("invalid_port", ["-1", "0.5", "dog", "65536"]) -def test_serve_invalid_ports(invalid_port): - runner = CliRunner() - result = runner.invoke(cli, ["--port", invalid_port]) - assert result.exit_code == 2 - assert "Invalid value for '-p'" in result.stderr - - -@pytest.mark.parametrize( - "args", - ( - ["--setting", "default_page_size", "5"], - ["--setting", "settings.default_page_size", "5"], - ["-s", "settings.default_page_size", "5"], - ), -) -def test_setting(args): - runner = CliRunner() - result = runner.invoke(cli, ["--get", "/-/settings.json"] + args) - assert result.exit_code == 0, result.output - settings = json.loads(result.output) - assert settings["default_page_size"] == 5 - - -def test_setting_compatible_with_config(tmp_path): - # https://github.com/simonw/datasette/issues/2389 - runner = CliRunner() - config_path = tmp_path / "config.json" - config_path.write_text( - '{"settings": {"default_page_size": 5, "sql_time_limit_ms": 50}}', "utf-8" - ) - result = runner.invoke( - cli, - [ - "--get", - "/-/settings.json", - "--config", - str(config_path), - "--setting", - "default_page_size", - "10", - ], - ) - assert result.exit_code == 0, result.output - settings = json.loads(result.output) - assert settings["default_page_size"] == 10 - assert settings["sql_time_limit_ms"] == 50 - - -def test_plugin_s_overwrite(): - runner = CliRunner() - plugins_dir = str(pathlib.Path(__file__).parent / "plugins") - - result = runner.invoke( - cli, - [ - "--plugins-dir", - plugins_dir, - "--get", - "/_memory/-/query.json?sql=select+prepare_connection_args()", - ], - ) - assert result.exit_code == 0, result.output - assert ( - json.loads(result.output).get("rows")[0].get("prepare_connection_args()") - == 'database=_memory, datasette.plugin_config("name-of-plugin")=None' - ) - - result = runner.invoke( - cli, - [ - "--plugins-dir", - plugins_dir, - "--get", - "/_memory/-/query.json?sql=select+prepare_connection_args()", - "-s", - "plugins.name-of-plugin", - "OVERRIDE", - ], - ) - assert result.exit_code == 0, result.output - assert ( - json.loads(result.output).get("rows")[0].get("prepare_connection_args()") - == 'database=_memory, datasette.plugin_config("name-of-plugin")=OVERRIDE' - ) - - -def test_startup_error_from_plugin_is_click_exception(tmp_path): - plugins_dir = tmp_path / "plugins" - plugins_dir.mkdir() - (plugins_dir / "startup_error.py").write_text( - "from datasette import hookimpl\n" - "from datasette.utils import StartupError\n" - "\n" - "@hookimpl\n" - "def startup(datasette):\n" - ' raise StartupError("boom")\n', - "utf-8", - ) - runner = CliRunner() - result = runner.invoke( - cli, - [ - "--plugins-dir", - str(plugins_dir), - "--get", - "/", - ], - ) - try: - assert result.exit_code == 1 - assert "Error: boom" in result.output - finally: - # Cleanup: Unregister the plugin to avoid test isolation issues - to_unregister = [ - p for p in pm.get_plugins() if p.__name__ == "startup_error.py" - ] - if to_unregister: - pm.unregister(to_unregister[0]) - - -def test_setting_type_validation(): - runner = CliRunner() - result = runner.invoke(cli, ["--setting", "default_page_size", "dog"]) - assert result.exit_code == 2 - assert '"settings.default_page_size" should be an integer' in result.output - - -def test_setting_boolean_validation_invalid(): - """Test that invalid boolean values are rejected""" - runner = CliRunner() - result = runner.invoke( - cli, ["--setting", "default_allow_sql", "invalid", "--get", "/-/settings.json"] - ) - assert result.exit_code == 2 - assert ( - '"settings.default_allow_sql" should be on/off/true/false/1/0' in result.output - ) - - -@pytest.mark.parametrize("value", ("off", "false", "0")) -def test_setting_boolean_validation_false_values(value): - """Test that 'off', 'false', '0' work for boolean settings""" - runner = CliRunner() - result = runner.invoke( - cli, - [ - "--setting", - "default_allow_sql", - value, - "--get", - "/_memory/-/query.json?sql=select+1", - ], - ) - # Should be forbidden (setting is false) - assert result.exit_code == 1, result.output - assert "Forbidden" in result.output - - -@pytest.mark.parametrize("value", ("on", "true", "1")) -def test_setting_boolean_validation_true_values(value): - """Test that 'on', 'true', '1' work for boolean settings""" - runner = CliRunner() - result = runner.invoke( - cli, - [ - "--setting", - "default_allow_sql", - value, - "--get", - "/_memory/-/query.json?sql=select+1&_shape=objects", - ], - ) - # Should succeed (setting is true) - assert result.exit_code == 0, result.output - assert json.loads(result.output)["rows"][0] == {"1": 1} - - -@pytest.mark.parametrize("default_allow_sql", (True, False)) -def test_setting_default_allow_sql(default_allow_sql): - runner = CliRunner() - result = runner.invoke( - cli, - [ - "--setting", - "default_allow_sql", - "on" if default_allow_sql else "off", - "--get", - "/_memory/-/query.json?sql=select+21&_shape=objects", - ], - ) - if default_allow_sql: - assert result.exit_code == 0, result.output - assert json.loads(result.output)["rows"][0] == {"21": 21} - else: - assert result.exit_code == 1, result.output - # This isn't JSON at the moment, maybe it should be though - assert "Forbidden" in result.output - - -def test_sql_errors_logged_to_stderr(): - runner = CliRunner() - result = runner.invoke(cli, ["--get", "/_memory/-/query.json?sql=select+blah"]) - assert result.exit_code == 1 - assert "sql = 'select blah', params = {}: no such column: blah\n" in result.stderr - - -def test_serve_create(tmpdir): - runner = CliRunner() - db_path = tmpdir / "does_not_exist_yet.db" - assert not db_path.exists() - result = runner.invoke( - cli, [str(db_path), "--create", "--get", "/-/databases.json"] - ) - assert result.exit_code == 0, result.output - databases = json.loads(result.output) - assert { - "name": "does_not_exist_yet", - "is_mutable": True, - "is_memory": False, - "hash": None, - }.items() <= databases[0].items() - assert db_path.exists() - - -@pytest.mark.parametrize("argument", ("-c", "--config")) -@pytest.mark.parametrize("format_", ("json", "yaml")) -def test_serve_config(tmpdir, argument, format_): - config_path = tmpdir / "datasette.{}".format(format_) - config_path.write_text( - ( - "settings:\n default_page_size: 5\n" - if format_ == "yaml" - else '{"settings": {"default_page_size": 5}}' - ), - "utf-8", - ) - runner = CliRunner() - result = runner.invoke( - cli, - [ - argument, - str(config_path), - "--get", - "/-/settings.json", - ], - ) - assert result.exit_code == 0, result.output - assert json.loads(result.output)["default_page_size"] == 5 - - -def test_serve_duplicate_database_names(tmpdir): - "'datasette db.db nested/db.db' should attach two databases, /db and /db_2" - runner = CliRunner() - db_1_path = str(tmpdir / "db.db") - nested = tmpdir / "nested" - nested.mkdir() - db_2_path = str(tmpdir / "nested" / "db.db") - for path in (db_1_path, db_2_path): - conn = sqlite3.connect(path) - conn.execute("vacuum") - conn.close() - result = runner.invoke(cli, [db_1_path, db_2_path, "--get", "/-/databases.json"]) - assert result.exit_code == 0, result.output - databases = json.loads(result.output) - assert {db["name"] for db in databases} == {"db", "db_2"} - - -@pytest.mark.parametrize( - "filename", ["test-database (1).sqlite", "database (1).sqlite"] -) -def test_weird_database_names(tmpdir, filename): - # https://github.com/simonw/datasette/issues/1181 - runner = CliRunner() - db_path = str(tmpdir / filename) - conn = sqlite3.connect(db_path) - conn.execute("vacuum") - conn.close() - result1 = runner.invoke(cli, [db_path, "--get", "/"]) - assert result1.exit_code == 0, result1.output - filename_no_stem = filename.rsplit(".", 1)[0] - expected_link = '<a href="/{}">{}</a>'.format( - tilde_encode(filename_no_stem), filename_no_stem - ) - assert expected_link in result1.output - # Now try hitting that database page - result2 = runner.invoke( - cli, [db_path, "--get", "/{}".format(tilde_encode(filename_no_stem))] - ) - assert result2.exit_code == 0, result2.output - - -def test_help_settings(): - runner = CliRunner() - result = runner.invoke(cli, ["--help-settings"]) - for setting in SETTINGS: - assert setting.name in result.output - - -def test_internal_db(tmpdir): - runner = CliRunner() - internal_path = tmpdir / "internal.db" - assert not internal_path.exists() - result = runner.invoke( - cli, ["--memory", "--internal", str(internal_path), "--get", "/"] - ) - assert result.exit_code == 0 - assert internal_path.exists() - - -def test_duplicate_database_files_error(tmpdir): - """Test that passing the same database file multiple times raises an error""" - runner = CliRunner() - db_path = str(tmpdir / "test.db") - conn = sqlite3.connect(db_path) - conn.execute("vacuum") - conn.close() - - # Test with exact duplicate - result = runner.invoke(cli, ["serve", db_path, db_path, "--get", "/"]) - assert result.exit_code == 1 - assert "Duplicate database file" in result.output - assert "both refer to" in result.output - - # Test with different paths to same file (relative vs absolute) - result2 = runner.invoke( - cli, ["serve", db_path, str(pathlib.Path(db_path).resolve()), "--get", "/"] - ) - assert result2.exit_code == 1 - assert "Duplicate database file" in result2.output - - # Test that a file in the config_dir can't also be passed explicitly - config_dir = tmpdir / "config" - config_dir.mkdir() - config_db_path = str(config_dir / "data.db") - conn = sqlite3.connect(config_db_path) - conn.execute("vacuum") - conn.close() - - result3 = runner.invoke( - cli, ["serve", config_db_path, str(config_dir), "--get", "/"] - ) - assert result3.exit_code == 1 - assert "Duplicate database file" in result3.output - assert "both refer to" in result3.output - - # Test that mixing a file NOT in the directory with a directory works fine - other_db_path = str(tmpdir / "other.db") - conn = sqlite3.connect(other_db_path) - conn.execute("vacuum") - conn.close() - - result4 = runner.invoke( - cli, ["serve", other_db_path, str(config_dir), "--get", "/-/databases.json"] - ) - assert result4.exit_code == 0 - databases = json.loads(result4.output) - assert {db["name"] for db in databases} == {"other", "data"} - - # Test that multiple directories raise an error - config_dir2 = tmpdir / "config2" - config_dir2.mkdir() - - result5 = runner.invoke( - cli, ["serve", str(config_dir), str(config_dir2), "--get", "/"] - ) - assert result5.exit_code == 1 - assert "Cannot pass multiple directories" in result5.output diff --git a/tests/test_cli_serve_get.py b/tests/test_cli_serve_get.py deleted file mode 100644 index dc852201..00000000 --- a/tests/test_cli_serve_get.py +++ /dev/null @@ -1,137 +0,0 @@ -from datasette.cli import cli -from datasette.plugins import pm -from click.testing import CliRunner -import textwrap -import json - - -def test_serve_with_get(tmp_path_factory): - plugins_dir = tmp_path_factory.mktemp("plugins_for_serve_with_get") - (plugins_dir / "init_for_serve_with_get.py").write_text( - textwrap.dedent( - """ - from datasette import hookimpl - - @hookimpl - def startup(datasette): - with open("{}", "w") as fp: - fp.write("hello") - """.format(str(plugins_dir / "hello.txt")), - ), - "utf-8", - ) - runner = CliRunner() - result = runner.invoke( - cli, - [ - "serve", - "--memory", - "--plugins-dir", - str(plugins_dir), - "--get", - "/_memory/-/query.json?sql=select+sqlite_version()", - ], - ) - assert result.exit_code == 0, result.output - data = json.loads(result.output) - # Should have a single row with a single column - assert len(data["rows"]) == 1 - assert list(data["rows"][0].keys()) == ["sqlite_version()"] - assert set(data.keys()) == {"rows", "ok", "truncated"} - - # The plugin should have created hello.txt - assert (plugins_dir / "hello.txt").read_text() == "hello" - - # Annoyingly that new test plugin stays resident - we need - # to manually unregister it to avoid conflict with other tests - to_unregister = [ - p for p in pm.get_plugins() if p.__name__ == "init_for_serve_with_get.py" - ][0] - pm.unregister(to_unregister) - - -def test_serve_with_get_headers(): - runner = CliRunner() - result = runner.invoke( - cli, - [ - "serve", - "--memory", - "--get", - "/_memory/", - "--headers", - ], - ) - # exit_code is 1 because it wasn't a 200 response - assert result.exit_code == 1, result.output - lines = result.output.splitlines() - assert lines and lines[0] == "HTTP/1.1 302" - assert "location: /_memory" in lines - assert "content-type: text/html; charset=utf-8" in lines - - -def test_serve_with_get_and_token(): - runner = CliRunner() - result1 = runner.invoke( - cli, - [ - "create-token", - "--secret", - "sekrit", - "root", - ], - ) - token = result1.output.strip() - result2 = runner.invoke( - cli, - [ - "serve", - "--secret", - "sekrit", - "--get", - "/-/actor.json", - "--token", - token, - ], - ) - assert 0 == result2.exit_code, result2.output - assert json.loads(result2.output) == {"actor": {"id": "root", "token": "dstok"}} - - -def test_serve_with_get_exit_code_for_error(): - runner = CliRunner() - result = runner.invoke( - cli, - [ - "serve", - "--memory", - "--get", - "/this-is-404", - ], - catch_exceptions=False, - ) - assert result.exit_code == 1 - assert "404" in result.output - - -def test_serve_get_actor(): - runner = CliRunner() - result = runner.invoke( - cli, - [ - "serve", - "--memory", - "--get", - "/-/actor.json", - "--actor", - '{"id": "root", "extra": "x"}', - ], - catch_exceptions=False, - ) - assert result.exit_code == 0 - assert json.loads(result.output) == { - "actor": { - "id": "root", - "extra": "x", - } - } diff --git a/tests/test_cli_serve_server.py b/tests/test_cli_serve_server.py deleted file mode 100644 index 47f23c08..00000000 --- a/tests/test_cli_serve_server.py +++ /dev/null @@ -1,29 +0,0 @@ -import httpx -import pytest -import socket - - -@pytest.mark.serial -def test_serve_localhost_http(ds_localhost_http_server): - response = httpx.get("http://localhost:8041/_memory.json") - assert { - "database": "_memory", - "path": "/_memory", - "tables": [], - }.items() <= response.json().items() - - -@pytest.mark.serial -@pytest.mark.skipif( - not hasattr(socket, "AF_UNIX"), reason="Requires socket.AF_UNIX support" -) -def test_serve_unix_domain_socket(ds_unix_domain_socket_server): - _, uds = ds_unix_domain_socket_server - transport = httpx.HTTPTransport(uds=uds) - client = httpx.Client(transport=transport) - response = client.get("http://localhost/_memory.json") - assert { - "database": "_memory", - "path": "/_memory", - "tables": [], - }.items() <= response.json().items() diff --git a/tests/test_column_types.py b/tests/test_column_types.py deleted file mode 100644 index d77f2cf5..00000000 --- a/tests/test_column_types.py +++ /dev/null @@ -1,1113 +0,0 @@ -import json -import logging - -from bs4 import BeautifulSoup as Soup -from datasette.app import Datasette -from datasette.column_types import ( - ColumnType, - SQLiteType, -) -from datasette.hookspecs import hookimpl -from datasette.plugins import pm -from datasette.utils import sqlite3 -from datasette.utils import StartupError -import markupsafe -import pytest -import time - - -@pytest.fixture -def ds_ct(tmp_path_factory): - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute( - "create table posts (id integer primary key, title text, body text, " - "author_email text, website text, metadata text)" - ) - db.execute( - "insert into posts values (1, 'Hello', '# World', 'test@example.com', " - "'https://example.com', '{\"key\": \"value\"}')" - ) - db.commit() - ds = Datasette( - [db_path], - config={ - "databases": { - "data": { - "tables": { - "posts": { - "column_types": { - "body": "markdown", - "author_email": "email", - "website": "url", - "metadata": "json", - } - } - } - } - } - }, - ) - ds.root_enabled = True - yield ds - ds.close() - - -@pytest.fixture -def ds_ct_editor_permission(tmp_path_factory): - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute( - "create table posts (id integer primary key, title text, body text, " - "author_email text, website text, metadata text)" - ) - db.execute( - "insert into posts values (1, 'Hello', '# World', 'test@example.com', " - "'https://example.com', '{\"key\": \"value\"}')" - ) - db.commit() - ds = Datasette( - [db_path], - config={ - "databases": { - "data": { - "tables": { - "posts": { - "permissions": {"set-column-type": {"id": "editor"}}, - "column_types": { - "body": "markdown", - "author_email": "email", - "website": "url", - "metadata": "json", - }, - } - } - } - } - }, - ) - ds.root_enabled = True - yield ds - ds.close() - - -def write_token(ds, actor_id="root", permissions=None): - to_sign = {"a": actor_id, "token": "dstok", "t": int(time.time())} - if permissions: - to_sign["_r"] = {"a": permissions} - return "dstok_{}".format(ds.sign(to_sign, namespace="token")) - - -def _headers(token): - return { - "Authorization": "Bearer {}".format(token), - "Content-Type": "application/json", - } - - -def _window_data_from_html(html, variable_name): - soup = Soup(html, "html.parser") - scripts = soup.find_all("script") - matching_scripts = [ - script for script in scripts if variable_name in (script.string or "") - ] - assert len(matching_scripts) == 1 - script_text = matching_scripts[0].string.strip() - prefix = f"window.{variable_name} = " - assert script_text.startswith(prefix) - return json.loads(script_text[len(prefix) :].rstrip(";")) - - -# --- Internal DB and config loading --- - - -@pytest.mark.asyncio -async def test_column_types_table_created(ds_ct): - await ds_ct.invoke_startup() - internal = ds_ct.get_internal_database() - result = await internal.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='column_types'" - ) - assert len(result.rows) == 1 - - -@pytest.mark.asyncio -async def test_config_loaded_into_internal_db(ds_ct): - await ds_ct.invoke_startup() - ct_map = await ds_ct.get_column_types("data", "posts") - # "markdown" is not a registered type, so it won't appear - assert "body" not in ct_map - assert ct_map["author_email"].name == "email" - assert ct_map["author_email"].config is None - assert ct_map["website"].name == "url" - assert ct_map["metadata"].name == "json" - - -@pytest.mark.asyncio -async def test_config_with_type_and_config(tmp_path_factory): - class PointColumnType(ColumnType): - name = "point" - description = "Geographic point" - - class _Plugin: - @hookimpl - def register_column_types(self, datasette): - return [PointColumnType] - - plugin = _Plugin() - pm.register(plugin, name="test_point_ct") - try: - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table geo (id integer primary key, location text)") - ds = Datasette( - [db_path], - config={ - "databases": { - "data": { - "tables": { - "geo": { - "column_types": { - "location": { - "type": "point", - "config": {"srid": 4326}, - } - } - } - } - } - } - }, - ) - await ds.invoke_startup() - ct = await ds.get_column_type("data", "geo", "location") - assert ct.name == "point" - assert ct.config == {"srid": 4326} - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - finally: - pm.unregister(plugin, name="test_point_ct") - - -# --- Datasette API methods --- - - -@pytest.mark.asyncio -async def test_get_column_type(ds_ct): - await ds_ct.invoke_startup() - ct = await ds_ct.get_column_type("data", "posts", "author_email") - assert isinstance(ct, ColumnType) - assert ct.name == "email" - assert ct.config is None - - -@pytest.mark.asyncio -async def test_get_column_type_missing(ds_ct): - await ds_ct.invoke_startup() - ct = await ds_ct.get_column_type("data", "posts", "title") - assert ct is None - - -@pytest.mark.asyncio -async def test_set_and_remove_column_type(ds_ct): - await ds_ct.invoke_startup() - await ds_ct.set_column_type("data", "posts", "title", "email") - ct = await ds_ct.get_column_type("data", "posts", "title") - assert ct.name == "email" - assert ct.config is None - - await ds_ct.remove_column_type("data", "posts", "title") - ct = await ds_ct.get_column_type("data", "posts", "title") - assert ct is None - - -@pytest.mark.asyncio -async def test_set_column_type_with_config(ds_ct): - await ds_ct.invoke_startup() - await ds_ct.set_column_type("data", "posts", "title", "url", {"max_length": 200}) - ct = await ds_ct.get_column_type("data", "posts", "title") - assert ct.name == "url" - assert ct.config == {"max_length": 200} - - -@pytest.mark.asyncio -async def test_set_column_type_api(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct, permissions=["sct"]) - response = await ds_ct.client.post( - "/data/posts/-/set-column-type", - json={"column": "title", "column_type": {"type": "email"}}, - headers=_headers(token), - ) - assert response.status_code == 200 - assert response.json() == { - "ok": True, - "database": "data", - "table": "posts", - "column": "title", - "column_type": {"type": "email", "config": None}, - } - ct = await ds_ct.get_column_type("data", "posts", "title") - assert ct.name == "email" - assert ct.config is None - - -@pytest.mark.asyncio -async def test_set_column_type_api_with_config(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct, permissions=["sct"]) - response = await ds_ct.client.post( - "/data/posts/-/set-column-type", - json={ - "column": "title", - "column_type": {"type": "url", "config": {"max_length": 200}}, - }, - headers=_headers(token), - ) - assert response.status_code == 200 - assert response.json()["column_type"] == { - "type": "url", - "config": {"max_length": 200}, - } - ct = await ds_ct.get_column_type("data", "posts", "title") - assert ct.name == "url" - assert ct.config == {"max_length": 200} - - -@pytest.mark.asyncio -async def test_clear_column_type_api(ds_ct): - await ds_ct.invoke_startup() - await ds_ct.set_column_type("data", "posts", "title", "email") - token = write_token(ds_ct, permissions=["sct"]) - response = await ds_ct.client.post( - "/data/posts/-/set-column-type", - json={"column": "title", "column_type": None}, - headers=_headers(token), - ) - assert response.status_code == 200 - assert response.json() == { - "ok": True, - "database": "data", - "table": "posts", - "column": "title", - "column_type": None, - } - ct = await ds_ct.get_column_type("data", "posts", "title") - assert ct is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "body,special_case,expected_status,expected_errors", - ( - ( - {"column": "title", "column_type": {"type": "email"}}, - "no_permission", - 403, - ["Permission denied"], - ), - ( - None, - "invalid_json", - 400, - [ - "Invalid JSON: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)" - ], - ), - ( - {"column": "title", "column_type": {"type": "email"}}, - "invalid_content_type", - 400, - ["Invalid content-type, must be application/json"], - ), - ( - [], - None, - 400, - ["JSON must be a dictionary"], - ), - ( - {"column_type": {"type": "email"}}, - None, - 400, - ['"column" is required'], - ), - ( - {"column": 1, "column_type": {"type": "email"}}, - None, - 400, - ['"column" must be a string'], - ), - ( - {"column": "not_a_column", "column_type": {"type": "email"}}, - None, - 400, - ["Column not found: not_a_column"], - ), - ( - {"column": "title", "column_type": "email"}, - None, - 400, - ['"column_type" must be an object or null'], - ), - ( - {"column": "title", "column_type": {}}, - None, - 400, - ['"column_type.type" is required'], - ), - ( - {"column": "title", "column_type": {"type": 1}}, - None, - 400, - ['"column_type.type" must be a string'], - ), - ( - {"column": "title", "column_type": {"type": "url", "config": []}}, - None, - 400, - ['"column_type.config" must be a dictionary'], - ), - ( - {"column": "title", "column_type": {"type": "markdown"}}, - None, - 400, - ["Unknown column type: markdown"], - ), - ( - {"column": "id", "column_type": {"type": "json"}}, - None, - 400, - [ - "Column type 'json' is only applicable to SQLite types TEXT but data.posts.id has SQLite type INTEGER" - ], - ), - ( - { - "column": "title", - "column_type": {"type": "email"}, - "extra": True, - }, - None, - 400, - ['Invalid parameter: "extra"'], - ), - ), -) -async def test_set_column_type_api_errors( - ds_ct, body, special_case, expected_status, expected_errors -): - await ds_ct.invoke_startup() - token = write_token( - ds_ct, - permissions=(["sct"] if special_case != "no_permission" else ["vi"]), - ) - kwargs = { - "headers": { - "Authorization": f"Bearer {token}", - "Content-Type": ( - "text/plain" - if special_case == "invalid_content_type" - else "application/json" - ), - } - } - if special_case == "invalid_json": - kwargs["content"] = "{bad json" - else: - kwargs["json"] = body - response = await ds_ct.client.post("/data/posts/-/set-column-type", **kwargs) - assert response.status_code == expected_status - assert response.json() == {"ok": False, "errors": expected_errors} - - -@pytest.mark.asyncio -async def test_set_column_type_api_works_for_immutable_database(tmp_path_factory): - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "immutable.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table posts (id integer primary key, title text)") - db.commit() - ds = Datasette([], immutables=[db_path]) - ds.root_enabled = True - try: - await ds.invoke_startup() - token = write_token(ds, permissions=["sct"]) - response = await ds.client.post( - "/immutable/posts/-/set-column-type", - json={"column": "title", "column_type": {"type": "email"}}, - headers=_headers(token), - ) - assert response.status_code == 200 - assert response.json()["column_type"] == {"type": "email", "config": None} - ct = await ds.get_column_type("immutable", "posts", "title") - assert ct.name == "email" - finally: - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - - -@pytest.mark.asyncio -async def test_set_column_type_rejects_incompatible_sqlite_type(ds_ct): - await ds_ct.invoke_startup() - with pytest.raises(ValueError, match="only applicable to SQLite types TEXT"): - await ds_ct.set_column_type("data", "posts", "id", "json") - - -@pytest.mark.asyncio -async def test_set_column_type_allows_varchar_for_text_only_type(tmp_path_factory): - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table links (id integer primary key, url varchar(255))") - db.commit() - ds = Datasette([db_path]) - await ds.invoke_startup() - await ds.set_column_type("data", "links", "url", "url") - ct = await ds.get_column_type("data", "links", "url") - assert ct.name == "url" - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - - -# --- Plugin registration --- - - -@pytest.mark.asyncio -async def test_builtin_column_types_registered(ds_ct): - """register_column_types returns classes; _column_types stores them by name.""" - await ds_ct.invoke_startup() - assert "url" in ds_ct._column_types - assert "email" in ds_ct._column_types - assert "json" in ds_ct._column_types - assert "nonexistent" not in ds_ct._column_types - - -@pytest.mark.asyncio -async def test_column_type_class_attributes(ds_ct): - await ds_ct.invoke_startup() - url_cls = ds_ct._column_types["url"] - assert url_cls.name == "url" - assert url_cls.description == "URL" - assert url_cls.sqlite_types == (SQLiteType.TEXT,) - email_cls = ds_ct._column_types["email"] - assert email_cls.name == "email" - assert email_cls.description == "Email address" - assert email_cls.sqlite_types == (SQLiteType.TEXT,) - json_cls = ds_ct._column_types["json"] - assert json_cls.sqlite_types == (SQLiteType.TEXT,) - - -def test_sqlite_type_from_declared_type(): - assert SQLiteType.from_declared_type("text") == SQLiteType.TEXT - assert SQLiteType.from_declared_type("varchar(255)") == SQLiteType.TEXT - assert SQLiteType.from_declared_type("integer") == SQLiteType.INTEGER - assert SQLiteType.from_declared_type("float") == SQLiteType.REAL - assert SQLiteType.from_declared_type("blob") == SQLiteType.BLOB - assert SQLiteType.from_declared_type("") == SQLiteType.NULL - assert SQLiteType.from_declared_type("numeric") is None - - -# --- JSON API --- - - -@pytest.mark.asyncio -async def test_column_types_extra(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts.json?_extra=column_types") - assert response.status_code == 200 - data = response.json() - assert "column_types" in data - assert data["column_types"]["author_email"] == {"type": "email", "config": None} - assert data["column_types"]["website"] == {"type": "url", "config": None} - assert data["column_types"]["metadata"] == {"type": "json", "config": None} - # "markdown" is not a registered type, so body should not appear - assert "body" not in data["column_types"] - # title has no column type, should not appear - assert "title" not in data["column_types"] - - -@pytest.mark.asyncio -async def test_display_columns_include_column_type(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts.json?_extra=display_columns") - assert response.status_code == 200 - data = response.json() - cols = {c["name"]: c for c in data["display_columns"]} - assert cols["author_email"]["column_type"] == "email" - assert cols["author_email"]["column_type_config"] is None - assert cols["website"]["column_type"] == "url" - assert cols["title"]["column_type"] is None - - -# --- Rendering --- - - -@pytest.mark.asyncio -async def test_url_render_cell(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts.json?_extra=render_cell") - assert response.status_code == 200 - data = response.json() - rendered = data["render_cell"][0] - assert "href" in rendered["website"] - assert "https://example.com" in rendered["website"] - - -@pytest.mark.asyncio -async def test_email_render_cell(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts.json?_extra=render_cell") - assert response.status_code == 200 - data = response.json() - rendered = data["render_cell"][0] - assert "mailto:" in rendered["author_email"] - assert "test@example.com" in rendered["author_email"] - - -@pytest.mark.asyncio -async def test_json_render_cell(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts.json?_extra=render_cell") - assert response.status_code == 200 - data = response.json() - rendered = data["render_cell"][0] - assert "<pre>" in rendered["metadata"] - - -# --- Validation --- - - -@pytest.mark.asyncio -async def test_email_validation_on_insert(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/insert", - json={"row": {"title": "Test", "author_email": "not-an-email"}}, - headers=_headers(token), - ) - assert response.status_code == 400 - assert "author_email" in response.json()["errors"][0] - - -@pytest.mark.asyncio -async def test_email_validation_passes_valid(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/insert", - json={"row": {"title": "Test", "author_email": "valid@example.com"}}, - headers=_headers(token), - ) - assert response.status_code == 201 - - -@pytest.mark.asyncio -async def test_url_validation_on_insert(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/insert", - json={"row": {"title": "Test", "website": "not-a-url"}}, - headers=_headers(token), - ) - assert response.status_code == 400 - assert "website" in response.json()["errors"][0] - - -@pytest.mark.asyncio -async def test_json_validation_on_insert(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/insert", - json={"row": {"title": "Test", "metadata": "not-json{"}}, - headers=_headers(token), - ) - assert response.status_code == 400 - assert "metadata" in response.json()["errors"][0] - - -@pytest.mark.asyncio -async def test_validation_on_update(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/1/-/update", - json={"update": {"author_email": "invalid"}}, - headers=_headers(token), - ) - assert response.status_code == 400 - assert "author_email" in response.json()["errors"][0] - - -@pytest.mark.asyncio -async def test_validation_allows_null(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/insert", - json={"row": {"title": "Test", "author_email": None}}, - headers=_headers(token), - ) - assert response.status_code == 201 - - -@pytest.mark.asyncio -async def test_validation_allows_empty_string(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/insert", - json={"row": {"title": "Test", "author_email": ""}}, - headers=_headers(token), - ) - assert response.status_code == 201 - - -# --- ColumnType base class --- - - -@pytest.mark.asyncio -async def test_column_type_base_defaults(): - class TestType(ColumnType): - name = "test" - description = "Test type" - - ct = TestType() - assert ct.config is None - assert await ct.render_cell("val", "col", "tbl", "db", None, None) is None - assert await ct.validate("val", None) is None - assert await ct.transform_value("val", None) == "val" - - -# --- render_cell extra with column types --- - - -@pytest.mark.asyncio -async def test_render_cell_extra_with_column_types(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts.json?_extra=render_cell") - assert response.status_code == 200 - data = response.json() - rendered = data["render_cell"][0] - assert "mailto:" in rendered["author_email"] - assert "href" in rendered["website"] - - -# --- Duplicate column type name --- - - -@pytest.mark.asyncio -async def test_duplicate_column_type_name_raises_error(): - class DuplicateUrlType(ColumnType): - name = "url" - description = "Duplicate URL" - - async def render_cell(self, value, column, table, database, datasette, request): - return None - - class _Plugin: - @hookimpl - def register_column_types(self, datasette): - return [DuplicateUrlType] - - plugin = _Plugin() - pm.register(plugin, name="test_duplicate_ct") - try: - ds = Datasette() - with pytest.raises(StartupError, match="Duplicate column type name: url"): - await ds.invoke_startup() - finally: - pm.unregister(plugin, name="test_duplicate_ct") - - -# --- Row endpoint --- - - -@pytest.mark.asyncio -async def test_row_endpoint_render_cell_with_column_types(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts/1.json?_extra=render_cell") - assert response.status_code == 200 - data = response.json() - rendered = data["render_cell"][0] - assert "mailto:" in rendered["author_email"] - assert "href" in rendered["website"] - - -# --- transform_value in JSON output --- - - -@pytest.mark.asyncio -async def test_transform_value_in_json_output(tmp_path_factory): - """A column type with transform_value should modify rows in JSON API.""" - - class UpperColumnType(ColumnType): - name = "upper" - description = "Uppercase" - - async def transform_value(self, value, datasette): - if isinstance(value, str): - return value.upper() - return value - - class _Plugin: - @hookimpl - def register_column_types(self, datasette): - return [UpperColumnType] - - plugin = _Plugin() - pm.register(plugin, name="test_transform_ct") - try: - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table t (id integer primary key, name text)") - db.execute("insert into t values (1, 'hello')") - db.commit() - ds = Datasette( - [db_path], - config={ - "databases": { - "data": {"tables": {"t": {"column_types": {"name": "upper"}}}} - } - }, - ) - await ds.invoke_startup() - response = await ds.client.get("/data/t.json") - assert response.status_code == 200 - data = response.json() - assert data["rows"][0]["name"] == "HELLO" - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - finally: - pm.unregister(plugin, name="test_transform_ct") - - -# --- Column type priority over plugins --- - - -@pytest.mark.asyncio -async def test_column_type_render_cell_has_priority_over_plugins(tmp_path_factory): - """Column type render_cell should take priority over render_cell plugin hook.""" - - class PriorityColumnType(ColumnType): - name = "priority_test" - description = "Priority test" - - async def render_cell(self, value, column, table, database, datasette, request): - if value is not None: - return markupsafe.Markup( - f"<b>COLUMN_TYPE:{markupsafe.escape(value)}</b>" - ) - return None - - class _ColumnTypePlugin: - @hookimpl - def register_column_types(self, datasette): - return [PriorityColumnType] - - class _RenderCellPlugin: - @hookimpl - def render_cell( - self, - row, - value, - column, - table, - pks, - database, - datasette, - request, - column_type, - ): - if column == "name": - return markupsafe.Markup(f"<i>PLUGIN:{markupsafe.escape(value)}</i>") - - ct_plugin = _ColumnTypePlugin() - rc_plugin = _RenderCellPlugin() - pm.register(ct_plugin, name="test_priority_ct") - pm.register(rc_plugin, name="test_priority_render") - try: - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table t (id integer primary key, name text)") - db.execute("insert into t values (1, 'hello')") - db.commit() - ds = Datasette( - [db_path], - config={ - "databases": { - "data": { - "tables": {"t": {"column_types": {"name": "priority_test"}}} - } - } - }, - ) - await ds.invoke_startup() - response = await ds.client.get("/data/t.json?_extra=render_cell") - assert response.status_code == 200 - data = response.json() - rendered = data["render_cell"][0] - # Column type should win over the plugin - assert "COLUMN_TYPE:" in rendered["name"] - assert "PLUGIN:" not in rendered["name"] - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - finally: - pm.unregister(ct_plugin, name="test_priority_ct") - pm.unregister(rc_plugin, name="test_priority_render") - - -# --- Row detail page rendering --- - - -@pytest.mark.asyncio -async def test_row_detail_page_html_rendering(ds_ct): - """Row detail HTML page should use column type rendering.""" - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts/1") - assert response.status_code == 200 - html = response.text - # The email column should be rendered with mailto: link - assert "mailto:test@example.com" in html - # The url column should be rendered with href - assert 'href="https://example.com"' in html - - -# --- HTML table page rendering --- - - -@pytest.mark.asyncio -async def test_html_table_page_rendering(ds_ct): - """HTML table page should use column type rendering.""" - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts") - assert response.status_code == 200 - html = response.text - assert "mailto:test@example.com" in html - assert 'href="https://example.com"' in html - - -@pytest.mark.asyncio -async def test_set_column_type_ui_data_hidden_without_permission(ds_ct): - await ds_ct.invoke_startup() - response = await ds_ct.client.get("/data/posts") - assert response.status_code == 200 - assert "window._setColumnTypeData" not in response.text - - -@pytest.mark.asyncio -async def test_set_column_type_ui_data_includes_applicable_types( - ds_ct_editor_permission, -): - await ds_ct_editor_permission.invoke_startup() - response = await ds_ct_editor_permission.client.get( - "/data/posts", - actor={"id": "editor"}, - ) - assert response.status_code == 200 - data = _window_data_from_html(response.text, "_setColumnTypeData") - assert data["path"] == "/data/posts/-/set-column-type" - assert data["columns"]["id"] == { - "current": None, - "options": [], - } - assert data["columns"]["title"] == { - "current": None, - "options": [ - {"name": "email", "description": "Email address"}, - {"name": "json", "description": "JSON data"}, - {"name": "url", "description": "URL"}, - ], - } - assert data["columns"]["author_email"] == { - "current": {"type": "email", "config": None}, - "options": [ - {"name": "email", "description": "Email address"}, - {"name": "json", "description": "JSON data"}, - {"name": "url", "description": "URL"}, - ], - } - - -# --- Validation on upsert --- - - -@pytest.mark.asyncio -async def test_validation_on_upsert(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/upsert", - json={ - "rows": [{"id": 1, "title": "Updated", "author_email": "invalid"}], - }, - headers=_headers(token), - ) - assert response.status_code == 400 - assert "author_email" in response.json()["errors"][0] - - -@pytest.mark.asyncio -async def test_validation_on_upsert_passes_valid(ds_ct): - await ds_ct.invoke_startup() - token = write_token(ds_ct) - response = await ds_ct.client.post( - "/data/posts/-/upsert", - json={ - "rows": [{"id": 1, "title": "Updated", "author_email": "valid@test.com"}], - }, - headers=_headers(token), - ) - assert response.status_code == 200 - - -# --- Unknown type warning logged --- - - -@pytest.mark.asyncio -async def test_unknown_type_warning_logged(tmp_path_factory, caplog): - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table t (id integer primary key, col text)") - db.commit() - ds = Datasette( - [db_path], - config={ - "databases": { - "data": {"tables": {"t": {"column_types": {"col": "nonexistent_type"}}}} - } - }, - ) - with caplog.at_level(logging.WARNING): - await ds.invoke_startup() - assert "unknown type" in caplog.text.lower() - assert "nonexistent_type" in caplog.text - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - - -@pytest.mark.asyncio -async def test_incompatible_sqlite_type_warning_logged(tmp_path_factory, caplog): - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table t (id integer primary key, col integer)") - db.commit() - ds = Datasette( - [db_path], - config={ - "databases": {"data": {"tables": {"t": {"column_types": {"col": "json"}}}}} - }, - ) - with caplog.at_level(logging.WARNING): - await ds.invoke_startup() - assert "only applicable to sqlite types text" in caplog.text.lower() - assert await ds.get_column_type("data", "t", "col") is None - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - - -# --- Config overwrites on restart --- - - -@pytest.mark.asyncio -async def test_config_overwrites_on_restart(tmp_path_factory): - """Config values should overwrite any existing column types in internal DB on startup.""" - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table t (id integer primary key, col text)") - db.commit() - ds = Datasette( - [db_path], - config={ - "databases": {"data": {"tables": {"t": {"column_types": {"col": "email"}}}}} - }, - ) - await ds.invoke_startup() - ct = await ds.get_column_type("data", "t", "col") - assert ct.name == "email" - - # Manually change the column type in the internal DB - await ds.set_column_type("data", "t", "col", "url") - ct = await ds.get_column_type("data", "t", "col") - assert ct.name == "url" - - # Re-apply config (simulating what happens on restart) - await ds._apply_column_types_config() - ct = await ds.get_column_type("data", "t", "col") - assert ct.name == "email" # Config wins - - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() - - -# --- No column_types in config --- - - -@pytest.mark.asyncio -async def test_no_column_types_in_config(tmp_path_factory): - """Datasette should work fine without any column_types configuration.""" - db_directory = tmp_path_factory.mktemp("dbs") - db_path = str(db_directory / "data.db") - db = sqlite3.connect(str(db_path)) - db.execute("vacuum") - db.execute("create table t (id integer primary key, col text)") - db.execute("insert into t values (1, 'hello')") - db.commit() - ds = Datasette([db_path]) - await ds.invoke_startup() - - # No column types assigned - ct_map = await ds.get_column_types("data", "t") - assert ct_map == {} - - # JSON endpoint should work without column_types extra - response = await ds.client.get("/data/t.json") - assert response.status_code == 200 - assert response.json()["rows"][0]["col"] == "hello" - - # column_types extra should return empty - response = await ds.client.get("/data/t.json?_extra=column_types") - assert response.status_code == 200 - assert response.json()["column_types"] == {} - - db.close() - for database in ds.databases.values(): - if not database.is_memory: - database.close() diff --git a/tests/test_config_dir.py b/tests/test_config_dir.py deleted file mode 100644 index 0a9b30d8..00000000 --- a/tests/test_config_dir.py +++ /dev/null @@ -1,152 +0,0 @@ -import json -import pathlib -import pytest - -from datasette.app import Datasette -from datasette.utils.sqlite import sqlite3 -from datasette.utils import StartupError -from .fixtures import TestClient as _TestClient - -PLUGIN = """ -from datasette import hookimpl - -@hookimpl -def extra_template_vars(): - return { - "from_plugin": "hooray" - } -""" -METADATA = {"title": "This is from metadata"} -CONFIG = { - "settings": { - "default_cache_ttl": 60, - } -} -CSS = """ -body { margin-top: 3em} -""" - - -@pytest.fixture(scope="session") -def config_dir(tmp_path_factory): - config_dir = tmp_path_factory.mktemp("config-dir") - plugins_dir = config_dir / "plugins" - plugins_dir.mkdir() - (plugins_dir / "hooray.py").write_text(PLUGIN, "utf-8") - (plugins_dir / "non_py_file.txt").write_text(PLUGIN, "utf-8") - (plugins_dir / ".mypy_cache").mkdir() - - templates_dir = config_dir / "templates" - templates_dir.mkdir() - (templates_dir / "row.html").write_text( - "Show row here. Plugin says {{ from_plugin }}", "utf-8" - ) - - static_dir = config_dir / "static" - static_dir.mkdir() - (static_dir / "hello.css").write_text(CSS, "utf-8") - - (config_dir / "metadata.json").write_text(json.dumps(METADATA), "utf-8") - (config_dir / "datasette.json").write_text(json.dumps(CONFIG), "utf-8") - - for dbname in ("demo.db", "immutable.db", "j.sqlite3", "k.sqlite"): - db = sqlite3.connect(str(config_dir / dbname)) - db.executescript(""" - CREATE TABLE cities ( - id integer primary key, - name text - ); - INSERT INTO cities (id, name) VALUES - (1, 'San Francisco') - ; - """) - db.close() - - # Mark "immutable.db" as immutable - (config_dir / "inspect-data.json").write_text( - json.dumps( - { - "immutable": { - "hash": "hash", - "size": 8192, - "file": "immutable.db", - "tables": {"cities": {"count": 1}}, - } - } - ), - "utf-8", - ) - return config_dir - - -def test_invalid_settings(config_dir): - previous = (config_dir / "datasette.json").read_text("utf-8") - (config_dir / "datasette.json").write_text( - json.dumps({"settings": {"invalid": "invalid-setting"}}), "utf-8" - ) - try: - with pytest.raises(StartupError) as ex: - Datasette([], config_dir=config_dir) - assert ex.value.args[0] == "Invalid setting 'invalid' in config file" - finally: - (config_dir / "datasette.json").write_text(previous, "utf-8") - - -@pytest.fixture(scope="session") -def config_dir_client(config_dir): - ds = Datasette([], config_dir=config_dir) - yield _TestClient(ds) - for db in ds.databases.values(): - db.close() - - -def test_settings(config_dir_client): - response = config_dir_client.get("/-/settings.json") - assert 200 == response.status - assert 60 == response.json["default_cache_ttl"] - - -def test_plugins(config_dir_client): - response = config_dir_client.get("/-/plugins.json") - assert 200 == response.status - assert "hooray.py" in {p["name"] for p in response.json} - assert "non_py_file.txt" not in {p["name"] for p in response.json} - assert "mypy_cache" not in {p["name"] for p in response.json} - - -def test_templates_and_plugin(config_dir_client): - response = config_dir_client.get("/demo/cities/1") - assert 200 == response.status - assert "Show row here. Plugin says hooray" == response.text - - -def test_static(config_dir_client): - response = config_dir_client.get("/static/hello.css") - assert 200 == response.status - assert CSS == response.text - assert "text/css" == response.headers["content-type"] - - -def test_static_directory_browsing_not_allowed(config_dir_client): - response = config_dir_client.get("/static/") - assert 403 == response.status - assert "403: Directory listing is not allowed" == response.text - - -def test_databases(config_dir_client): - response = config_dir_client.get("/-/databases.json") - assert 200 == response.status - databases = response.json - assert 4 == len(databases) - databases.sort(key=lambda d: d["name"]) - for db, expected_name in zip(databases, ("demo", "immutable", "j", "k")): - assert expected_name == db["name"] - assert db["is_mutable"] == (expected_name != "immutable") - - -def test_store_config_dir(config_dir_client): - ds = config_dir_client.ds - - assert hasattr(ds, "config_dir") - assert ds.config_dir is not None - assert isinstance(ds.config_dir, pathlib.Path) diff --git a/tests/test_config_permission_rules.py b/tests/test_config_permission_rules.py deleted file mode 100644 index 8327ecbf..00000000 --- a/tests/test_config_permission_rules.py +++ /dev/null @@ -1,163 +0,0 @@ -import pytest - -from datasette.app import Datasette -from datasette.database import Database -from datasette.resources import DatabaseResource, TableResource - - -async def setup_datasette(config=None, databases=None): - ds = Datasette(memory=True, config=config) - for name in databases or []: - ds.add_database(Database(ds, memory_name=f"{name}_memory"), name=name) - await ds.invoke_startup() - await ds.refresh_schemas() - return ds - - -@pytest.mark.asyncio -async def test_root_permissions_allow(): - config = {"permissions": {"execute-sql": {"id": "alice"}}} - ds = await setup_datasette(config=config, databases=["content"]) - - assert await ds.allowed( - action="execute-sql", - resource=DatabaseResource(database="content"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action="execute-sql", - resource=DatabaseResource(database="content"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_database_permission(): - config = { - "databases": { - "content": { - "permissions": { - "insert-row": {"id": "alice"}, - } - } - } - } - ds = await setup_datasette(config=config, databases=["content"]) - - assert await ds.allowed( - action="insert-row", - resource=TableResource(database="content", table="repos"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action="insert-row", - resource=TableResource(database="content", table="repos"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_table_permission(): - config = { - "databases": { - "content": { - "tables": {"repos": {"permissions": {"delete-row": {"id": "alice"}}}} - } - } - } - ds = await setup_datasette(config=config, databases=["content"]) - - assert await ds.allowed( - action="delete-row", - resource=TableResource(database="content", table="repos"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action="delete-row", - resource=TableResource(database="content", table="repos"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_view_table_allow_block(): - config = { - "databases": {"content": {"tables": {"repos": {"allow": {"id": "alice"}}}}} - } - ds = await setup_datasette(config=config, databases=["content"]) - - assert await ds.allowed( - action="view-table", - resource=TableResource(database="content", table="repos"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action="view-table", - resource=TableResource(database="content", table="repos"), - actor={"id": "bob"}, - ) - assert await ds.allowed( - action="view-table", - resource=TableResource(database="content", table="other"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_view_table_allow_false_blocks(): - config = {"databases": {"content": {"tables": {"repos": {"allow": False}}}}} - ds = await setup_datasette(config=config, databases=["content"]) - - assert not await ds.allowed( - action="view-table", - resource=TableResource(database="content", table="repos"), - actor={"id": "alice"}, - ) - - -@pytest.mark.asyncio -async def test_allow_sql_blocks(): - config = {"allow_sql": {"id": "alice"}} - ds = await setup_datasette(config=config, databases=["content"]) - - assert await ds.allowed( - action="execute-sql", - resource=DatabaseResource(database="content"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action="execute-sql", - resource=DatabaseResource(database="content"), - actor={"id": "bob"}, - ) - - config = {"databases": {"content": {"allow_sql": {"id": "bob"}}}} - ds = await setup_datasette(config=config, databases=["content"]) - - assert await ds.allowed( - action="execute-sql", - resource=DatabaseResource(database="content"), - actor={"id": "bob"}, - ) - assert not await ds.allowed( - action="execute-sql", - resource=DatabaseResource(database="content"), - actor={"id": "alice"}, - ) - - config = {"allow_sql": False} - ds = await setup_datasette(config=config, databases=["content"]) - assert not await ds.allowed( - action="execute-sql", - resource=DatabaseResource(database="content"), - actor={"id": "alice"}, - ) - - -@pytest.mark.asyncio -async def test_view_instance_allow_block(): - config = {"allow": {"id": "alice"}} - ds = await setup_datasette(config=config) - - assert await ds.allowed(action="view-instance", actor={"id": "alice"}) - assert not await ds.allowed(action="view-instance", actor={"id": "bob"}) diff --git a/tests/test_crossdb.py b/tests/test_crossdb.py deleted file mode 100644 index 11e53224..00000000 --- a/tests/test_crossdb.py +++ /dev/null @@ -1,77 +0,0 @@ -from datasette.cli import cli -from click.testing import CliRunner -import urllib -import sqlite3 - - -def test_crossdb_join(app_client_two_attached_databases_crossdb_enabled): - app_client = app_client_two_attached_databases_crossdb_enabled - sql = """ - select - 'extra database' as db, - pk, - text1, - text2 - from - [extra database].searchable - union all - select - 'fixtures' as db, - pk, - text1, - text2 - from - fixtures.searchable - """ - response = app_client.get( - "/_memory/-/query.json?" - + urllib.parse.urlencode({"sql": sql, "_shape": "array"}) - ) - assert response.status == 200 - assert response.json == [ - {"db": "extra database", "pk": 1, "text1": "barry cat", "text2": "terry dog"}, - {"db": "extra database", "pk": 2, "text1": "terry dog", "text2": "sara weasel"}, - {"db": "fixtures", "pk": 1, "text1": "barry cat", "text2": "terry dog"}, - {"db": "fixtures", "pk": 2, "text1": "terry dog", "text2": "sara weasel"}, - ] - - -def test_crossdb_warning_if_too_many_databases(tmp_path_factory): - db_dir = tmp_path_factory.mktemp("dbs") - dbs = [] - for i in range(11): - path = str(db_dir / "db_{}.db".format(i)) - conn = sqlite3.connect(path) - conn.execute("vacuum") - conn.close() - dbs.append(path) - runner = CliRunner() - result = runner.invoke( - cli, - [ - "serve", - "--crossdb", - "--get", - "/", - ] - + dbs, - catch_exceptions=False, - ) - assert ( - "Warning: --crossdb only works with the first 10 attached databases" - in result.stderr - ) - - -def test_crossdb_attached_database_list_display( - app_client_two_attached_databases_crossdb_enabled, -): - app_client = app_client_two_attached_databases_crossdb_enabled - response = app_client.get("/_memory") - app_client.get("/") - for fragment in ( - "databases are attached to this connection", - "<li><strong>fixtures</strong> - ", - '<li><strong>extra database</strong> - <a href="/extra+database/-/query?sql=', - ): - assert fragment in response.text diff --git a/tests/test_csrf_middleware.py b/tests/test_csrf_middleware.py deleted file mode 100644 index 2fcfb216..00000000 --- a/tests/test_csrf_middleware.py +++ /dev/null @@ -1,334 +0,0 @@ -""" -Tests for the header-based CSRF (Cross-Origin) protection middleware. - -Datasette uses the Sec-Fetch-Site + Origin header approach described in -Filippo Valsorda's article (https://words.filippo.io/csrf/) and implemented -in Go 1.25's http.CrossOriginProtection. This replaces the previous -token-based asgi-csrf mechanism. -""" - -import pluggy -import pytest - -from datasette import hookimpl -from datasette.csrf import CrossOriginProtectionMiddleware, _install_legacy_csrftoken - - -async def _post(bare_ds, **kwargs): - kwargs.setdefault("data", {"message": "hello", "message_class": "info"}) - return await bare_ds.client.post("/-/messages", **kwargs) - - -async def _run_middleware(scope): - """ - Run CrossOriginProtectionMiddleware against a scope and return - ("allowed",) if the inner app was called, or ("blocked", status) - if the middleware sent a response itself. - """ - - class FakeDs: - async def render_template(self, name, ctx): - return "BLOCKED" - - inner_called = [] - - async def app(scope, receive, send): - inner_called.append(True) - - sent = [] - - async def send(msg): - sent.append(msg) - - mw = CrossOriginProtectionMiddleware(app, FakeDs()) - await mw(scope, None, send) - if inner_called: - return ("allowed",) - start = [m for m in sent if m["type"] == "http.response.start"][0] - return ("blocked", start["status"]) - - -def _http_scope(headers, method="POST"): - return { - "type": "http", - "method": method, - "headers": [(k.encode(), v.encode()) for k, v in headers.items()], - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize("method", ["GET", "HEAD", "OPTIONS"]) -async def test_safe_methods_always_pass(bare_ds, method): - # Safe methods bypass CSRF entirely, even with hostile headers - response = await bare_ds.client.request( - method, - "/-/messages", - headers={"sec-fetch-site": "cross-site", "origin": "http://evil.example"}, - ) - assert response.status_code != 403 or "origin" not in response.text.lower() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("sec_fetch_site", ["same-origin", "none"]) -async def test_post_with_trusted_sec_fetch_site_allowed(bare_ds, sec_fetch_site): - # "same-origin" = first-party; "none" = user-initiated direct navigation - response = await _post(bare_ds, headers={"sec-fetch-site": sec_fetch_site}) - assert response.status_code != 403 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("sec_fetch_site", ["cross-site", "same-site", "cross-origin"]) -async def test_post_with_untrusted_sec_fetch_site_blocked(bare_ds, sec_fetch_site): - # same-site is blocked too: different subdomains must not bypass CSRF - response = await _post( - bare_ds, data={"message": "hi"}, headers={"sec-fetch-site": sec_fetch_site} - ) - assert response.status_code == 403 - assert response.headers["content-type"].startswith("text/html") - - -@pytest.mark.asyncio -async def test_post_with_no_browser_headers_allowed(bare_ds): - # curl / requests / server-to-server: no Sec-Fetch-Site, no Origin. - # CSRF is browser-specific so these pass through. - response = await _post(bare_ds) - assert response.status_code != 403 - - -@pytest.mark.asyncio -async def test_post_with_matching_origin_allowed(bare_ds): - # Fallback for older browsers without Sec-Fetch-Site: Origin must match Host - response = await _post(bare_ds, headers={"origin": "http://localhost"}) - assert response.status_code != 403 - - -@pytest.mark.asyncio -async def test_post_with_mismatched_origin_blocked(bare_ds): - response = await _post( - bare_ds, data={"message": "hi"}, headers={"origin": "http://evil.example.com"} - ) - assert response.status_code == 403 - - -@pytest.mark.asyncio -async def test_csrf_error_page_renders(bare_ds): - response = await _post( - bare_ds, data={"message": "hi"}, headers={"sec-fetch-site": "cross-site"} - ) - assert response.status_code == 403 - assert "origin" in response.text.lower() - - -@pytest.mark.asyncio -async def test_csrf_error_page_title_has_no_typo(bare_ds): - response = await _post( - bare_ds, data={"message": "hi"}, headers={"sec-fetch-site": "cross-site"} - ) - assert "<title>CSRF check failed" in response.text - assert "CSRF check failed)" not in response.text - - -@pytest.mark.asyncio -@pytest.mark.parametrize("scope_type", ["websocket", "lifespan"]) -async def test_non_http_scope_passes_through(scope_type): - called = [] - - async def app(scope, receive, send): - called.append(scope["type"]) - - mw = CrossOriginProtectionMiddleware(app, datasette=None) - await mw({"type": scope_type}, None, None) - assert called == [scope_type] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "label,headers,expected", - [ - ( - "plain cross-site blocked", - {"sec-fetch-site": "cross-site", "host": "example.com"}, - ("blocked", 403), - ), - ( - "basic auth does not bypass", - { - "sec-fetch-site": "cross-site", - "host": "example.com", - "authorization": "Basic dXNlcjpwYXNz", - }, - ("blocked", 403), - ), - ( - "bearer auth bypasses", - { - "sec-fetch-site": "cross-site", - "origin": "https://evil.example", - "host": "example.com", - "authorization": "Bearer dstok_abc", - }, - ("allowed",), - ), - ( - "bearer scheme case-insensitive", - { - "sec-fetch-site": "cross-site", - "host": "example.com", - "authorization": "bearer dstok_abc", - }, - ("allowed",), - ), - ( - "non-browser (no Sec-Fetch-Site, no Origin) allowed", - {"host": "example.com"}, - ("allowed",), - ), - ], -) -async def test_middleware_unit(label, headers, expected): - assert await _run_middleware(_http_scope(headers)) == expected - - -def test_legacy_csrftoken_scope_value_nonempty(app_client): - # GET /post/ calls request.scope["csrftoken"]() - must not 500 - response = app_client.get("/post/") - assert response.status == 200 - assert response.text.strip() != "" - assert len(response.text.strip()) >= 20 - - -def test_legacy_csrftoken_no_ds_csrftoken_cookie(app_client): - response = app_client.get("/post/") - assert "ds_csrftoken" not in response.cookies - - -def test_legacy_csrftoken_varies_across_requests(app_client): - r1 = app_client.get("/post/").text.strip() - r2 = app_client.get("/post/").text.strip() - assert r1 != r2 - - -def test_legacy_csrftoken_stable_within_request(): - # Two calls in the same request return the same value - scope = {} - _install_legacy_csrftoken(scope) - assert scope["csrftoken"]() == scope["csrftoken"]() - - -@pytest.mark.asyncio -async def test_cross_site_post_blocked_even_with_ds_csrftoken_cookie(bare_ds): - # A stale ds_csrftoken cookie + csrftoken body field must NOT bypass - # the header-based CSRF check. - response = await _post( - bare_ds, - data={"message": "hi", "message_class": "info", "csrftoken": "abc"}, - headers={"sec-fetch-site": "cross-site"}, - cookies={"ds_csrftoken": "abc"}, - ) - assert response.status_code == 403 - - -@pytest.mark.asyncio -async def test_bearer_invalid_token_not_csrf_error(bare_ds): - # Cross-site POST with bogus bearer must pass CSRF and be rejected - # by auth/permission handling, not by the CSRF middleware. - response = await _post( - bare_ds, - headers={ - "sec-fetch-site": "cross-site", - "authorization": "Bearer totally-invalid-token", - }, - ) - if response.status_code == 403: - assert "origin" not in response.text.lower() - assert "sec-fetch-site" not in response.text.lower() - - -@pytest.mark.asyncio -async def test_cross_site_post_without_auth_still_blocked(bare_ds): - response = await _post( - bare_ds, data={"message": "hi"}, headers={"sec-fetch-site": "cross-site"} - ) - assert response.status_code == 403 - - -@pytest.mark.asyncio -async def test_bearer_with_cookie_does_not_bypass(): - # Bearer + Cookie => ambient cookie auth is in play, not exempt. - scope = _http_scope( - { - "sec-fetch-site": "cross-site", - "host": "example.com", - "authorization": "Bearer dstok_abc", - "cookie": "ds_actor=anything", - } - ) - assert await _run_middleware(scope) == ("blocked", 403) - - -@pytest.mark.asyncio -async def test_origin_scheme_must_match(): - # http Origin against an https request must be blocked even when host matches. - scope = _http_scope({"origin": "http://example.com", "host": "example.com"}) - scope["scheme"] = "https" - assert await _run_middleware(scope) == ("blocked", 403) - - -@pytest.mark.asyncio -async def test_origin_port_must_match(): - scope = _http_scope({"origin": "http://example.com:8001", "host": "example.com"}) - scope["scheme"] = "http" - assert await _run_middleware(scope) == ("blocked", 403) - - -@pytest.mark.asyncio -async def test_origin_default_port_normalized(): - # http://example.com:80 == http://example.com - scope = _http_scope({"origin": "http://example.com:80", "host": "example.com"}) - scope["scheme"] = "http" - assert await _run_middleware(scope) == ("allowed",) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "headers", - [ - {"authorization": " ", "host": "example.com", "origin": "http://evil"}, - {"origin": "http://example.com:notaport", "host": "example.com"}, - {"origin": "not-a-url", "host": "example.com"}, - ], -) -async def test_malformed_headers_do_not_500(headers): - # Should be a clean 403, not an unhandled exception. - result = await _run_middleware(_http_scope(headers)) - assert result[0] == "blocked" - assert result[1] == 403 - - -@pytest.mark.asyncio -async def test_uppercase_header_names_normalized(): - # ASGI servers should lowercase, but middleware normalizes defensively. - scope = { - "type": "http", - "method": "POST", - "headers": [(b"Sec-Fetch-Site", b"same-origin")], - } - assert await _run_middleware(scope) == ("allowed",) - - -def test_legacy_skip_csrf_hookimpl_does_not_break_loading(): - # Plugins that still define skip_csrf must load cleanly - pluggy ignores - # unknown hook implementations - even though the hook is no longer - # consulted by core. Use a throwaway PluginManager so that registering - # this hookimpl does not leak a _HookCaller onto the real datasette.pm. - class LegacyPlugin: - __name__ = "legacy-skip-csrf-plugin" - - @hookimpl - def skip_csrf(self, datasette, scope): - return True - - throwaway = pluggy.PluginManager("datasette") - plugin = LegacyPlugin() - throwaway.register(plugin, name=LegacyPlugin.__name__) - assert throwaway.is_registered(plugin) diff --git a/tests/test_csv.py b/tests/test_csv.py index a2f03776..2353b9ce 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -1,250 +1,61 @@ -from datasette.app import Datasette -from bs4 import BeautifulSoup as Soup -import pytest -import urllib.parse +from .fixtures import app_client # noqa -EXPECTED_TABLE_CSV = """id,content +EXPECTED_TABLE_CSV = '''id,content 1,hello 2,world 3, -4,RENDER_CELL_DEMO -5,RENDER_CELL_ASYNC -""".replace("\n", "\r\n") +'''.replace('\n', '\r\n') -EXPECTED_CUSTOM_CSV = """content +EXPECTED_CUSTOM_CSV = '''content hello world -""".replace("\n", "\r\n") +'''.replace('\n', '\r\n') -EXPECTED_TABLE_WITH_LABELS_CSV = """ -pk,created,planet_int,on_earth,state,_city_id,_city_id_label,_neighborhood,tags,complex_array,distinct_some_null,n -1,2019-01-14 08:00:00,1,1,CA,1,San Francisco,Mission,"[""tag1"", ""tag2""]","[{""foo"": ""bar""}]",one,n1 -2,2019-01-14 08:00:00,1,1,CA,1,San Francisco,Dogpatch,"[""tag1"", ""tag3""]",[],two,n2 -3,2019-01-14 08:00:00,1,1,CA,1,San Francisco,SOMA,[],[],, -4,2019-01-14 08:00:00,1,1,CA,1,San Francisco,Tenderloin,[],[],, -5,2019-01-15 08:00:00,1,1,CA,1,San Francisco,Bernal Heights,[],[],, -6,2019-01-15 08:00:00,1,1,CA,1,San Francisco,Hayes Valley,[],[],, -7,2019-01-15 08:00:00,1,1,CA,2,Los Angeles,Hollywood,[],[],, -8,2019-01-15 08:00:00,1,1,CA,2,Los Angeles,Downtown,[],[],, -9,2019-01-16 08:00:00,1,1,CA,2,Los Angeles,Los Feliz,[],[],, -10,2019-01-16 08:00:00,1,1,CA,2,Los Angeles,Koreatown,[],[],, -11,2019-01-16 08:00:00,1,1,MI,3,Detroit,Downtown,[],[],, -12,2019-01-17 08:00:00,1,1,MI,3,Detroit,Greektown,[],[],, -13,2019-01-17 08:00:00,1,1,MI,3,Detroit,Corktown,[],[],, -14,2019-01-17 08:00:00,1,1,MI,3,Detroit,Mexicantown,[],[],, -15,2019-01-17 08:00:00,2,0,MC,4,Memnonia,Arcadia Planitia,[],[],, -""".lstrip().replace("\n", "\r\n") +EXPECTED_TABLE_WITH_LABELS_CSV = ''' +pk,planet_int,state,city_id,city_id_label,neighborhood +1,1,CA,1,San Francisco,Mission +2,1,CA,1,San Francisco,Dogpatch +3,1,CA,1,San Francisco,SOMA +4,1,CA,1,San Francisco,Tenderloin +5,1,CA,1,San Francisco,Bernal Heights +6,1,CA,1,San Francisco,Hayes Valley +7,1,CA,2,Los Angeles,Hollywood +8,1,CA,2,Los Angeles,Downtown +9,1,CA,2,Los Angeles,Los Feliz +10,1,CA,2,Los Angeles,Koreatown +11,1,MI,3,Detroit,Downtown +12,1,MI,3,Detroit,Greektown +13,1,MI,3,Detroit,Corktown +14,1,MI,3,Detroit,Mexicantown +15,2,MC,4,Memnonia,Arcadia Planitia +'''.lstrip().replace('\n', '\r\n') -EXPECTED_TABLE_WITH_NULLABLE_LABELS_CSV = """ -pk,foreign_key_with_label,foreign_key_with_label_label,foreign_key_with_blank_label,foreign_key_with_blank_label_label,foreign_key_with_no_label,foreign_key_with_no_label_label,foreign_key_compound_pk1,foreign_key_compound_pk2 -1,1,hello,3,,1,1,a,b -2,,,,,,,, -""".lstrip().replace("\n", "\r\n") - - -@pytest.mark.asyncio -async def test_table_csv(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.csv?_oh=1") - assert response.status_code == 200 - assert not response.headers.get("Access-Control-Allow-Origin") - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == EXPECTED_TABLE_CSV - - -def test_table_csv_cors_headers(app_client_with_cors): - response = app_client_with_cors.get("/fixtures/simple_primary_key.csv") +def test_table_csv(app_client): + response = app_client.get('/test_tables/simple_primary_key.csv') assert response.status == 200 - assert response.headers["Access-Control-Allow-Origin"] == "*" + assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] + assert EXPECTED_TABLE_CSV == response.text -@pytest.mark.asyncio -async def test_table_csv_no_header(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.csv?_header=off") - assert response.status_code == 200 - assert not response.headers.get("Access-Control-Allow-Origin") - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == EXPECTED_TABLE_CSV.split("\r\n", 1)[1] - - -@pytest.mark.asyncio -async def test_table_csv_with_labels(ds_client): - response = await ds_client.get("/fixtures/facetable.csv?_labels=1") - assert response.status_code == 200 - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == EXPECTED_TABLE_WITH_LABELS_CSV - - -@pytest.mark.asyncio -async def test_table_csv_with_nullable_labels(ds_client): - response = await ds_client.get("/fixtures/foreign_key_references.csv?_labels=1") - assert response.status_code == 200 - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == EXPECTED_TABLE_WITH_NULLABLE_LABELS_CSV - - -@pytest.mark.asyncio -async def test_table_csv_with_invalid_labels(): - # https://github.com/simonw/datasette/issues/2214 - ds = Datasette( - config={ - "databases": { - "db_2214": { - "tables": { - "t2": { - "label_column": "name", - } - } - } - } - } - ) - await ds.invoke_startup() - db = ds.add_memory_database("db_2214") - await db.execute_write_script(""" - create table t1 (id integer primary key, name text); - insert into t1 (id, name) values (1, 'one'); - insert into t1 (id, name) values (2, 'two'); - create table t2 (textid text primary key, name text); - insert into t2 (textid, name) values ('a', 'alpha'); - insert into t2 (textid, name) values ('b', 'beta'); - create table if not exists maintable ( - id integer primary key, - fk_integer integer references t1(id), - fk_text text references t2(textid) - ); - insert into maintable (id, fk_integer, fk_text) values (1, 1, 'a'); - insert into maintable (id, fk_integer, fk_text) values (2, 3, 'b'); -- invalid fk_integer - insert into maintable (id, fk_integer, fk_text) values (3, 2, 'c'); -- invalid fk_text - """) - response = await ds.client.get("/db_2214/maintable.csv?_labels=1") - assert response.status_code == 200 - assert response.text == ( - "id,fk_integer,fk_integer_label,fk_text,fk_text_label\r\n" - "1,1,one,a,alpha\r\n" - "2,3,,b,beta\r\n" - "3,2,two,c,\r\n" - ) - - -@pytest.mark.asyncio -async def test_table_csv_blob_columns(ds_client): - response = await ds_client.get("/fixtures/binary_data.csv") - assert response.status_code == 200 - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == ( - "rowid,data\r\n" - "1,http://localhost/fixtures/binary_data/1.blob?_blob_column=data\r\n" - "2,http://localhost/fixtures/binary_data/2.blob?_blob_column=data\r\n" - "3,\r\n" - ) - - -@pytest.mark.asyncio -async def test_custom_sql_csv_blob_columns(ds_client): - response = await ds_client.get( - "/fixtures/-/query.csv?sql=select+rowid,+data+from+binary_data" - ) - assert response.status_code == 200 - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == ( - "rowid,data\r\n" - '1,"http://localhost/fixtures/-/query.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=f3088978da8f9aea479ffc7f631370b968d2e855eeb172bea7f6c7a04262bb6d"\r\n' - '2,"http://localhost/fixtures/-/query.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=b835b0483cedb86130b9a2c280880bf5fadc5318ddf8c18d0df5204d40df1724"\r\n' - "3,\r\n" - ) - - -@pytest.mark.asyncio -async def test_custom_sql_csv(ds_client): - response = await ds_client.get( - "/fixtures/-/query.csv?sql=select+content+from+simple_primary_key+limit+2" - ) - assert response.status_code == 200 - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == EXPECTED_CUSTOM_CSV - - -@pytest.mark.asyncio -async def test_table_csv_download(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.csv?_dl=1") - assert response.status_code == 200 - assert response.headers["content-type"] == "text/csv; charset=utf-8" - assert ( - response.headers["content-disposition"] - == 'attachment; filename="simple_primary_key.csv"' - ) - - -@pytest.mark.asyncio -async def test_csv_with_non_ascii_characters(ds_client): - response = await ds_client.get( - "/fixtures/-/query.csv?sql=select%0D%0A++%27%F0%9D%90%9C%F0%9D%90%A2%F0%9D%90%AD%F0%9D%90%A2%F0%9D%90%9E%F0%9D%90%AC%27+as+text%2C%0D%0A++1+as+number%0D%0Aunion%0D%0Aselect%0D%0A++%27bob%27+as+text%2C%0D%0A++2+as+number%0D%0Aorder+by%0D%0A++number" - ) - assert response.status_code == 200 - assert response.headers["content-type"] == "text/plain; charset=utf-8" - assert response.text == "text,number\r\n𝐜𝐢𝐭𝐢𝐞𝐬,1\r\nbob,2\r\n" - - -@pytest.mark.xfail(reason="Flaky, see https://github.com/simonw/datasette/issues/2355") -def test_max_csv_mb(app_client_csv_max_mb_one): - # This query deliberately generates a really long string - # should be 100*100*100*2 = roughly 2MB - response = app_client_csv_max_mb_one.get( - "/fixtures.csv?" - + urllib.parse.urlencode( - { - "sql": """ - select group_concat('ab', '') - from json_each(json_array({lots})), - json_each(json_array({lots})), - json_each(json_array({lots})) - """.format( - lots=", ".join(str(i) for i in range(100)) - ), - "_stream": 1, - "_size": "max", - } - ), - ) - # It's a 200 because we started streaming before we knew the error +def test_table_csv_with_labels(app_client): + response = app_client.get('/test_tables/facetable.csv?_labels=1') assert response.status == 200 - # Last line should be an error message - last_line = [line for line in response.body.split(b"\r\n") if line][-1] - assert last_line.startswith(b"CSV contains more than") + assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] + assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text -@pytest.mark.asyncio -async def test_table_csv_stream(ds_client): - # Without _stream should return header + 100 rows: - response = await ds_client.get( - "/fixtures/compound_three_primary_keys.csv?_size=max" +def test_custom_sql_csv(app_client): + response = app_client.get( + '/test_tables.csv?sql=select+content+from+simple_primary_key+limit+2' ) - assert len([b for b in response.content.split(b"\r\n") if b]) == 101 - # With _stream=1 should return header + 1001 rows - response = await ds_client.get( - "/fixtures/compound_three_primary_keys.csv?_stream=1" - ) - assert len([b for b in response.content.split(b"\r\n") if b]) == 1002 + assert response.status == 200 + assert 'text/plain; charset=utf-8' == response.headers['Content-Type'] + assert EXPECTED_CUSTOM_CSV == response.text -def test_csv_trace(app_client_with_trace): - response = app_client_with_trace.get("/fixtures/simple_primary_key.csv?_trace=1") - assert response.headers["content-type"] == "text/html; charset=utf-8" - soup = Soup(response.text, "html.parser") - assert ( - soup.find("textarea").text - == "id,content\r\n1,hello\r\n2,world\r\n3,\r\n4,RENDER_CELL_DEMO\r\n5,RENDER_CELL_ASYNC\r\n" - ) - assert "select id, content from simple_primary_key" in soup.find("pre").text - - -def test_table_csv_stream_does_not_calculate_facets(app_client_with_trace): - response = app_client_with_trace.get("/fixtures/simple_primary_key.csv?_trace=1") - soup = Soup(response.text, "html.parser") - assert "select content, count(*) as n" not in soup.find("pre").text - - -def test_table_csv_stream_does_not_calculate_counts(app_client_with_trace): - response = app_client_with_trace.get("/fixtures/simple_primary_key.csv?_trace=1") - soup = Soup(response.text, "html.parser") - assert "select count(*)" not in soup.find("pre").text +def test_table_csv_download(app_client): + response = app_client.get('/test_tables/simple_primary_key.csv?_dl=1') + assert response.status == 200 + assert 'text/csv; charset=utf-8' == response.headers['Content-Type'] + expected_disposition = 'attachment; filename="simple_primary_key.csv"' + assert expected_disposition == response.headers['Content-Disposition'] diff --git a/tests/test_custom_pages.py b/tests/test_custom_pages.py deleted file mode 100644 index 39a4c06b..00000000 --- a/tests/test_custom_pages.py +++ /dev/null @@ -1,106 +0,0 @@ -import pathlib -import pytest -from .fixtures import make_app_client - -TEST_TEMPLATE_DIRS = str(pathlib.Path(__file__).parent / "test_templates") - - -@pytest.fixture(scope="session") -def custom_pages_client(): - with make_app_client(template_dir=TEST_TEMPLATE_DIRS) as client: - yield client - - -@pytest.fixture(scope="session") -def custom_pages_client_with_base_url(): - with make_app_client( - template_dir=TEST_TEMPLATE_DIRS, settings={"base_url": "/prefix/"} - ) as client: - yield client - - -def test_custom_pages_view_name(custom_pages_client): - response = custom_pages_client.get("/about") - assert response.status == 200 - assert response.text == "ABOUT! view_name:page" - - -def test_request_is_available(custom_pages_client): - response = custom_pages_client.get("/request") - assert response.status == 200 - assert response.text == "path:/request" - - -def test_custom_pages_with_base_url(custom_pages_client_with_base_url): - response = custom_pages_client_with_base_url.get("/prefix/request") - assert response.status == 200 - assert response.text == "path:/prefix/request" - - -def test_custom_pages_nested(custom_pages_client): - response = custom_pages_client.get("/nested/nest") - assert response.status == 200 - assert response.text == "Nest!" - response = custom_pages_client.get("/nested/nest2") - assert response.status == 404 - - -def test_custom_status(custom_pages_client): - response = custom_pages_client.get("/202") - assert response.status == 202 - assert response.text == "202!" - - -def test_custom_headers(custom_pages_client): - response = custom_pages_client.get("/headers") - assert response.status == 200 - assert response.headers["x-this-is-foo"] == "foo" - assert response.headers["x-this-is-bar"] == "bar" - assert response.text == "FOOBAR" - - -def test_custom_content_type(custom_pages_client): - response = custom_pages_client.get("/atom") - assert response.status == 200 - assert response.headers["content-type"] == "application/xml" - assert response.text == "" - - -def test_redirect(custom_pages_client): - response = custom_pages_client.get("/redirect") - assert response.status == 302 - assert response.headers["Location"] == "/example" - - -def test_redirect2(custom_pages_client): - response = custom_pages_client.get("/redirect2") - assert response.status == 301 - assert response.headers["Location"] == "/example" - - -@pytest.mark.parametrize( - "path,expected", - [ - ("/route_Sally", "

Hello from Sally

"), - ("/topic_python", "Topic page for python"), - ("/topic_python/info", "Slug: info, Topic: python"), - ], -) -def test_custom_route_pattern(custom_pages_client, path, expected): - response = custom_pages_client.get(path) - assert response.status == 200 - assert response.text.strip() == expected - - -def test_custom_route_pattern_404(custom_pages_client): - response = custom_pages_client.get("/route_OhNo") - assert response.status == 404 - assert "

Error 404

" in response.text - assert ">Oh no /dev/null ); do - if [ $waiting -eq 4 ]; then - echo "$server_pid does still exist, server failed to stop" - cleanup - exit 1 - fi - let waiting=waiting+1 - sleep 1 -done - -# Clean up the certificates -cleanup - -echo $curl_exit_code -exit $curl_exit_code diff --git a/tests/test_default_deny.py b/tests/test_default_deny.py deleted file mode 100644 index 81e95b84..00000000 --- a/tests/test_default_deny.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -from datasette.app import Datasette -from datasette.resources import DatabaseResource, TableResource - - -@pytest.mark.asyncio -async def test_default_deny_denies_default_permissions(): - """Test that default_deny=True denies default permissions""" - # Without default_deny, anonymous users can view instance/database/tables - ds_normal = Datasette() - await ds_normal.invoke_startup() - - # Add a test database - db = ds_normal.add_memory_database("test_db_normal") - await db.execute_write("create table test_table (id integer primary key)") - await ds_normal._refresh_schemas() # Trigger catalog refresh - - # Test default behavior - anonymous user should be able to view - response = await ds_normal.client.get("/") - assert response.status_code == 200 - - response = await ds_normal.client.get("/test_db_normal") - assert response.status_code == 200 - - response = await ds_normal.client.get("/test_db_normal/test_table") - assert response.status_code == 200 - - # With default_deny=True, anonymous users should be denied - ds_deny = Datasette(default_deny=True) - await ds_deny.invoke_startup() - - # Add the same test database - db = ds_deny.add_memory_database("test_db_deny") - await db.execute_write("create table test_table (id integer primary key)") - await ds_deny._refresh_schemas() # Trigger catalog refresh - - # Anonymous user should be denied - response = await ds_deny.client.get("/") - assert response.status_code == 403 - - response = await ds_deny.client.get("/test_db_deny") - assert response.status_code == 403 - - response = await ds_deny.client.get("/test_db_deny/test_table") - assert response.status_code == 403 - - -@pytest.mark.asyncio -async def test_default_deny_with_root_user(): - """Test that root user still has access when default_deny=True""" - ds = Datasette(default_deny=True) - ds.root_enabled = True - await ds.invoke_startup() - - root_actor = {"id": "root"} - - # Root user should have all permissions even with default_deny - assert await ds.allowed(action="view-instance", actor=root_actor) is True - assert ( - await ds.allowed( - action="view-database", - actor=root_actor, - resource=DatabaseResource("test_db"), - ) - is True - ) - assert ( - await ds.allowed( - action="view-table", - actor=root_actor, - resource=TableResource("test_db", "test_table"), - ) - is True - ) - assert ( - await ds.allowed( - action="execute-sql", actor=root_actor, resource=DatabaseResource("test_db") - ) - is True - ) - - -@pytest.mark.asyncio -async def test_default_deny_with_config_allow(): - """Test that config allow rules still work with default_deny=True""" - ds = Datasette(default_deny=True, config={"allow": {"id": "user1"}}) - await ds.invoke_startup() - - # Anonymous user should be denied - assert await ds.allowed(action="view-instance", actor=None) is False - - # Authenticated user with explicit permission should have access - assert await ds.allowed(action="view-instance", actor={"id": "user1"}) is True - - # Different user should be denied - assert await ds.allowed(action="view-instance", actor={"id": "user2"}) is False - - -@pytest.mark.asyncio -async def test_default_deny_basic_permissions(): - """Test that default_deny=True denies basic permissions""" - ds = Datasette(default_deny=True) - await ds.invoke_startup() - - # Anonymous user should be denied all default permissions - assert await ds.allowed(action="view-instance", actor=None) is False - assert ( - await ds.allowed( - action="view-database", actor=None, resource=DatabaseResource("test_db") - ) - is False - ) - assert ( - await ds.allowed( - action="view-table", - actor=None, - resource=TableResource("test_db", "test_table"), - ) - is False - ) - assert ( - await ds.allowed( - action="execute-sql", actor=None, resource=DatabaseResource("test_db") - ) - is False - ) - - # Authenticated user without explicit permission should also be denied - assert await ds.allowed(action="view-instance", actor={"id": "user"}) is False diff --git a/tests/test_docs.py b/tests/test_docs.py index 9cf39f41..6f7d0b4c 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -1,216 +1,16 @@ """ Tests to ensure certain things are documented. """ - -from datasette import app, utils -import datasette.fixtures # noqa: F401 -from datasette.app import Datasette -from datasette.filters import Filters +from datasette import app from pathlib import Path import pytest import re -docs_path = Path(__file__).parent.parent / "docs" -label_re = re.compile(r"\.\. _([^\s:]+):") +markdown = (Path(__file__).parent.parent / 'docs' / 'config.rst').open().read() +setting_heading_re = re.compile(r'(\w+)\n\-+\n') +setting_headings = set(setting_heading_re.findall(markdown)) -def get_headings(content, underline="-"): - heading_re = re.compile(r"(\w+)(\([^)]*\))?\n\{}+\n".format(underline)) - return {h[0] for h in heading_re.findall(content)} - - -def get_labels(filename): - content = (docs_path / filename).read_text() - return set(label_re.findall(content)) - - -@pytest.fixture(scope="session") -def settings_headings(): - return get_headings((docs_path / "settings.rst").read_text(), "~") - - -def test_settings_are_documented(settings_headings, subtests): - for setting in app.SETTINGS: - with subtests.test(setting=setting.name): - assert setting.name in settings_headings - - -@pytest.fixture(scope="session") -def plugin_hooks_content(): - return (docs_path / "plugin_hooks.rst").read_text() - - -def test_plugin_hooks_are_documented(plugin_hooks_content, subtests): - headings = set() - headings.update(get_headings(plugin_hooks_content, "-")) - headings.update(get_headings(plugin_hooks_content, "~")) - plugins = [name for name in dir(app.pm.hook) if not name.startswith("_")] - for plugin in plugins: - with subtests.test(plugin=plugin): - assert plugin in headings - hook_caller = getattr(app.pm.hook, plugin) - arg_names = [a for a in hook_caller.spec.argnames if a != "__multicall__"] - # Check for plugin_name(arg1, arg2, arg3) - expected = f"{plugin}({', '.join(arg_names)})" - assert ( - expected in plugin_hooks_content - ), f"Missing from plugin hook documentation: {expected}" - - -@pytest.fixture(scope="session") -def documented_views(): - view_labels = set() - for filename in docs_path.glob("*.rst"): - for label in get_labels(filename): - first_word = label.split("_")[0] - if first_word.endswith("View"): - view_labels.add(first_word) - # We deliberately don't document these: - view_labels.update( - ( - "PatternPortfolioView", - "AuthTokenView", - "ApiExplorerView", - "ExecuteWriteAnalyzeView", - "ExecuteWriteView", - "GlobalQueryListView", - "QueryCreateAnalyzeView", - "QueryDeleteView", - "QueryDefinitionView", - "QueryListView", - "QueryParametersView", - "QueryStoreView", - "QueryUpdateView", - ) - ) - return view_labels - - -def test_view_classes_are_documented(documented_views, subtests): - view_classes = [v for v in dir(app) if v.endswith("View")] - for view_class in view_classes: - with subtests.test(view_class=view_class): - assert view_class in documented_views - - -@pytest.fixture(scope="session") -def documented_table_filters(): - json_api_rst = (docs_path / "json_api.rst").read_text() - section = json_api_rst.split(".. _table_arguments:")[-1] - # Lines starting with ``?column__exact= are docs for filters - return { - line.split("__")[1].split("=")[0] - for line in section.split("\n") - if line.startswith("``?column__") - } - - -def test_table_filters_are_documented(documented_table_filters, subtests): - for f in Filters._filters: - with subtests.test(filter=f.key): - assert f.key in documented_table_filters - - -@pytest.fixture(scope="session") -def documented_labels(): - labels = set() - for filename in docs_path.glob("*.rst"): - labels.update(get_labels(filename.name)) - return labels - - -def test_functions_marked_with_documented_are_documented(documented_labels, subtests): - for fn in utils.functions_marked_as_documented: - with subtests.test(fn=fn.__name__): - assert fn._datasette_docs_label in documented_labels - - -def test_rst_heading_underlines_match_title_length(): - """Test that RST heading underlines are the same length as their titles.""" - # Common RST underline characters - underline_chars = ["-", "=", "~", "^", "+", "*", "#"] - - errors = [] - - for rst_file in docs_path.glob("*.rst"): - content = rst_file.read_text() - lines = content.split("\n") - - for i in range(len(lines) - 1): - current_line = lines[i] - next_line = lines[i + 1] - - # Check if next line is entirely made of a single underline character - # and is at least 5 characters long (to avoid false positives) - if ( - next_line - and len(next_line) >= 5 - and len(set(next_line)) == 1 - and next_line[0] in underline_chars - ): - # Skip if the previous line is empty (blank line before underline) - if not current_line: - continue - - # Check if this is an overline+underline style heading - # Look at the line before current_line to see if it's also an underline - if i > 0: - prev_line = lines[i - 1] - if ( - prev_line - and len(prev_line) >= 5 - and len(set(prev_line)) == 1 - and prev_line[0] in underline_chars - and len(prev_line) == len(next_line) - ): - # This is overline+underline style, skip it - continue - - # This is a heading underline - title_length = len(current_line) - underline_length = len(next_line) - - if title_length != underline_length: - errors.append( - f"{rst_file.name}:{i+1}: Title length {title_length} != underline length {underline_length}\n" - f" Title: {current_line!r}\n" - f" Underline: {next_line!r}" - ) - - if errors: - raise AssertionError( - f"Found {len(errors)} RST heading(s) with mismatched underline length:\n\n" - + "\n\n".join(errors) - ) - - -# Tests for testing_plugins.rst documentation - -# fmt: off -# -- start test_homepage -- -@pytest.mark.asyncio -async def test_homepage(): - ds = Datasette(memory=True) - response = await ds.client.get("/") - html = response.text - assert "

" in html -# -- end test_homepage -- - - -# -- start test_actor_is_null -- -@pytest.mark.asyncio -async def test_actor_is_null(): - ds = Datasette(memory=True) - response = await ds.client.get("/-/actor.json") - assert response.json() == {"actor": None} -# -- end test_actor_is_null -- - - -# -- start test_signed_cookie_actor -- -@pytest.mark.asyncio -async def test_signed_cookie_actor(): - ds = Datasette(memory=True) - cookies = {"ds_actor": ds.client.actor_cookie({"id": "root"})} - response = await ds.client.get("/-/actor.json", cookies=cookies) - assert response.json() == {"actor": {"id": "root"}} -# -- end test_signed_cookie_actor -- +@pytest.mark.parametrize('config', app.CONFIG_OPTIONS) +def test_config_options_are_documented(config): + assert config.name in setting_headings diff --git a/tests/test_docs_plugins.py b/tests/test_docs_plugins.py deleted file mode 100644 index 613160ac..00000000 --- a/tests/test_docs_plugins.py +++ /dev/null @@ -1,35 +0,0 @@ -# fmt: off -# -- start datasette_with_plugin_fixture -- -from datasette import hookimpl -from datasette.app import Datasette -import pytest -import pytest_asyncio - - -@pytest_asyncio.fixture -async def datasette_with_plugin(): - class TestPlugin: - __name__ = "TestPlugin" - - @hookimpl - def register_routes(self): - return [ - (r"^/error$", lambda: 1 / 0), - ] - - datasette = Datasette() - datasette.pm.register(TestPlugin(), name="undo") - try: - yield datasette - finally: - datasette.pm.unregister(name="undo") - datasette.close() -# -- end datasette_with_plugin_fixture -- - - -# -- start datasette_with_plugin_test -- -@pytest.mark.asyncio -async def test_error(datasette_with_plugin): - response = await datasette_with_plugin.client.get("/error") - assert response.status_code == 500 -# -- end datasette_with_plugin_test -- diff --git a/tests/test_facets.py b/tests/test_facets.py deleted file mode 100644 index 8c22ffce..00000000 --- a/tests/test_facets.py +++ /dev/null @@ -1,736 +0,0 @@ -from datasette.app import Datasette -from datasette.database import Database -from datasette.facets import Facet, ColumnFacet, ArrayFacet, DateFacet -from datasette.utils.asgi import Request -from datasette.utils import detect_json1 -from .fixtures import make_app_client -import json -import pytest - - -@pytest.mark.asyncio -async def test_column_facet_suggest(ds_client): - facet = ColumnFacet( - ds_client.ds, - Request.fake("/"), - database="fixtures", - sql="select * from facetable", - table="facetable", - ) - suggestions = await facet.suggest() - assert [ - {"name": "created", "toggle_url": "http://localhost/?_facet=created"}, - {"name": "planet_int", "toggle_url": "http://localhost/?_facet=planet_int"}, - {"name": "on_earth", "toggle_url": "http://localhost/?_facet=on_earth"}, - {"name": "state", "toggle_url": "http://localhost/?_facet=state"}, - {"name": "_city_id", "toggle_url": "http://localhost/?_facet=_city_id"}, - { - "name": "_neighborhood", - "toggle_url": "http://localhost/?_facet=_neighborhood", - }, - {"name": "tags", "toggle_url": "http://localhost/?_facet=tags"}, - { - "name": "complex_array", - "toggle_url": "http://localhost/?_facet=complex_array", - }, - ] == suggestions - - -@pytest.mark.asyncio -async def test_column_facet_suggest_skip_if_already_selected(ds_client): - facet = ColumnFacet( - ds_client.ds, - Request.fake("/?_facet=planet_int&_facet=on_earth"), - database="fixtures", - sql="select * from facetable", - table="facetable", - ) - suggestions = await facet.suggest() - assert [ - { - "name": "created", - "toggle_url": "http://localhost/?_facet=planet_int&_facet=on_earth&_facet=created", - }, - { - "name": "state", - "toggle_url": "http://localhost/?_facet=planet_int&_facet=on_earth&_facet=state", - }, - { - "name": "_city_id", - "toggle_url": "http://localhost/?_facet=planet_int&_facet=on_earth&_facet=_city_id", - }, - { - "name": "_neighborhood", - "toggle_url": "http://localhost/?_facet=planet_int&_facet=on_earth&_facet=_neighborhood", - }, - { - "name": "tags", - "toggle_url": "http://localhost/?_facet=planet_int&_facet=on_earth&_facet=tags", - }, - { - "name": "complex_array", - "toggle_url": "http://localhost/?_facet=planet_int&_facet=on_earth&_facet=complex_array", - }, - ] == suggestions - - -@pytest.mark.asyncio -async def test_column_facet_suggest_skip_if_enabled_by_metadata(ds_client): - facet = ColumnFacet( - ds_client.ds, - Request.fake("/"), - database="fixtures", - sql="select * from facetable", - table="facetable", - table_config={"facets": ["_city_id"]}, - ) - suggestions = [s["name"] for s in await facet.suggest()] - assert [ - "created", - "planet_int", - "on_earth", - "state", - "_neighborhood", - "tags", - "complex_array", - ] == suggestions - - -@pytest.mark.asyncio -async def test_column_facet_results(ds_client): - facet = ColumnFacet( - ds_client.ds, - Request.fake("/?_facet=_city_id"), - database="fixtures", - sql="select * from facetable", - table="facetable", - ) - buckets, timed_out = await facet.facet_results() - assert [] == timed_out - assert [ - { - "name": "_city_id", - "type": "column", - "hideable": True, - "toggle_url": "/", - "results": [ - { - "value": 1, - "label": "San Francisco", - "count": 6, - "toggle_url": "http://localhost/?_facet=_city_id&_city_id__exact=1", - "selected": False, - }, - { - "value": 2, - "label": "Los Angeles", - "count": 4, - "toggle_url": "http://localhost/?_facet=_city_id&_city_id__exact=2", - "selected": False, - }, - { - "value": 3, - "label": "Detroit", - "count": 4, - "toggle_url": "http://localhost/?_facet=_city_id&_city_id__exact=3", - "selected": False, - }, - { - "value": 4, - "label": "Memnonia", - "count": 1, - "toggle_url": "http://localhost/?_facet=_city_id&_city_id__exact=4", - "selected": False, - }, - ], - "truncated": False, - } - ] == buckets - - -@pytest.mark.asyncio -async def test_column_facet_results_column_starts_with_underscore(ds_client): - facet = ColumnFacet( - ds_client.ds, - Request.fake("/?_facet=_neighborhood"), - database="fixtures", - sql="select * from facetable", - table="facetable", - ) - buckets, timed_out = await facet.facet_results() - assert [] == timed_out - assert buckets == [ - { - "name": "_neighborhood", - "type": "column", - "hideable": True, - "toggle_url": "/", - "results": [ - { - "value": "Downtown", - "label": "Downtown", - "count": 2, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Downtown", - "selected": False, - }, - { - "value": "Arcadia Planitia", - "label": "Arcadia Planitia", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Arcadia+Planitia", - "selected": False, - }, - { - "value": "Bernal Heights", - "label": "Bernal Heights", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Bernal+Heights", - "selected": False, - }, - { - "value": "Corktown", - "label": "Corktown", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Corktown", - "selected": False, - }, - { - "value": "Dogpatch", - "label": "Dogpatch", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Dogpatch", - "selected": False, - }, - { - "value": "Greektown", - "label": "Greektown", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Greektown", - "selected": False, - }, - { - "value": "Hayes Valley", - "label": "Hayes Valley", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Hayes+Valley", - "selected": False, - }, - { - "value": "Hollywood", - "label": "Hollywood", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Hollywood", - "selected": False, - }, - { - "value": "Koreatown", - "label": "Koreatown", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Koreatown", - "selected": False, - }, - { - "value": "Los Feliz", - "label": "Los Feliz", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Los+Feliz", - "selected": False, - }, - { - "value": "Mexicantown", - "label": "Mexicantown", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Mexicantown", - "selected": False, - }, - { - "value": "Mission", - "label": "Mission", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Mission", - "selected": False, - }, - { - "value": "SOMA", - "label": "SOMA", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=SOMA", - "selected": False, - }, - { - "value": "Tenderloin", - "label": "Tenderloin", - "count": 1, - "toggle_url": "http://localhost/?_facet=_neighborhood&_neighborhood__exact=Tenderloin", - "selected": False, - }, - ], - "truncated": False, - } - ] - - -@pytest.mark.asyncio -async def test_column_facet_from_metadata_cannot_be_hidden(ds_client): - facet = ColumnFacet( - ds_client.ds, - Request.fake("/"), - database="fixtures", - sql="select * from facetable", - table="facetable", - table_config={"facets": ["_city_id"]}, - ) - buckets, timed_out = await facet.facet_results() - assert [] == timed_out - assert [ - { - "name": "_city_id", - "type": "column", - "hideable": False, - "toggle_url": "/", - "results": [ - { - "value": 1, - "label": "San Francisco", - "count": 6, - "toggle_url": "http://localhost/?_city_id__exact=1", - "selected": False, - }, - { - "value": 2, - "label": "Los Angeles", - "count": 4, - "toggle_url": "http://localhost/?_city_id__exact=2", - "selected": False, - }, - { - "value": 3, - "label": "Detroit", - "count": 4, - "toggle_url": "http://localhost/?_city_id__exact=3", - "selected": False, - }, - { - "value": 4, - "label": "Memnonia", - "count": 1, - "toggle_url": "http://localhost/?_city_id__exact=4", - "selected": False, - }, - ], - "truncated": False, - } - ] == buckets - - -@pytest.mark.asyncio -@pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module") -async def test_array_facet_suggest(ds_client): - facet = ArrayFacet( - ds_client.ds, - Request.fake("/"), - database="fixtures", - sql="select * from facetable", - table="facetable", - ) - suggestions = await facet.suggest() - assert [ - { - "name": "tags", - "type": "array", - "toggle_url": "http://localhost/?_facet_array=tags", - } - ] == suggestions - - -@pytest.mark.asyncio -@pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module") -async def test_array_facet_suggest_not_if_all_empty_arrays(ds_client): - facet = ArrayFacet( - ds_client.ds, - Request.fake("/"), - database="fixtures", - sql="select * from facetable where tags = '[]'", - table="facetable", - ) - suggestions = await facet.suggest() - assert [] == suggestions - - -@pytest.mark.asyncio -@pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module") -async def test_array_facet_results(ds_client): - facet = ArrayFacet( - ds_client.ds, - Request.fake("/?_facet_array=tags"), - database="fixtures", - sql="select * from facetable", - table="facetable", - ) - buckets, timed_out = await facet.facet_results() - assert [] == timed_out - assert [ - { - "name": "tags", - "type": "array", - "results": [ - { - "value": "tag1", - "label": "tag1", - "count": 2, - "toggle_url": "http://localhost/?_facet_array=tags&tags__arraycontains=tag1", - "selected": False, - }, - { - "value": "tag2", - "label": "tag2", - "count": 1, - "toggle_url": "http://localhost/?_facet_array=tags&tags__arraycontains=tag2", - "selected": False, - }, - { - "value": "tag3", - "label": "tag3", - "count": 1, - "toggle_url": "http://localhost/?_facet_array=tags&tags__arraycontains=tag3", - "selected": False, - }, - ], - "hideable": True, - "toggle_url": "/", - "truncated": False, - } - ] == buckets - - -@pytest.mark.asyncio -@pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module") -async def test_array_facet_handle_duplicate_tags(): - ds = Datasette([], memory=True) - db = ds.add_database(Database(ds, memory_name="test_array_facet")) - await db.execute_write("create table otters(name text, tags text)") - for name, tags in ( - ("Charles", ["friendly", "cunning", "friendly"]), - ("Shaun", ["cunning", "empathetic", "friendly"]), - ("Tracy", ["empathetic", "eager"]), - ): - await db.execute_write( - "insert into otters (name, tags) values (?, ?)", [name, json.dumps(tags)] - ) - - response = await ds.client.get("/test_array_facet/otters.json?_facet_array=tags") - assert response.json()["facet_results"]["results"]["tags"] == { - "name": "tags", - "type": "array", - "results": [ - { - "value": "cunning", - "label": "cunning", - "count": 2, - "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=cunning", - "selected": False, - }, - { - "value": "empathetic", - "label": "empathetic", - "count": 2, - "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=empathetic", - "selected": False, - }, - { - "value": "friendly", - "label": "friendly", - "count": 2, - "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=friendly", - "selected": False, - }, - { - "value": "eager", - "label": "eager", - "count": 1, - "toggle_url": "http://localhost/test_array_facet/otters.json?_facet_array=tags&tags__arraycontains=eager", - "selected": False, - }, - ], - "hideable": True, - "toggle_url": "/test_array_facet/otters.json", - "truncated": False, - } - - -@pytest.mark.asyncio -async def test_date_facet_results(ds_client): - facet = DateFacet( - ds_client.ds, - Request.fake("/?_facet_date=created"), - database="fixtures", - sql="select * from facetable", - table="facetable", - ) - buckets, timed_out = await facet.facet_results() - assert [] == timed_out - assert [ - { - "name": "created", - "type": "date", - "results": [ - { - "value": "2019-01-14", - "label": "2019-01-14", - "count": 4, - "toggle_url": "http://localhost/?_facet_date=created&created__date=2019-01-14", - "selected": False, - }, - { - "value": "2019-01-15", - "label": "2019-01-15", - "count": 4, - "toggle_url": "http://localhost/?_facet_date=created&created__date=2019-01-15", - "selected": False, - }, - { - "value": "2019-01-17", - "label": "2019-01-17", - "count": 4, - "toggle_url": "http://localhost/?_facet_date=created&created__date=2019-01-17", - "selected": False, - }, - { - "value": "2019-01-16", - "label": "2019-01-16", - "count": 3, - "toggle_url": "http://localhost/?_facet_date=created&created__date=2019-01-16", - "selected": False, - }, - ], - "hideable": True, - "toggle_url": "/", - "truncated": False, - } - ] == buckets - - -@pytest.mark.asyncio -async def test_json_array_with_blanks_and_nulls(): - ds = Datasette([], memory=True) - db = ds.add_database(Database(ds, memory_name="test_json_array")) - await db.execute_write("create table foo(json_column text)") - for value in ('["a", "b", "c"]', '["a", "b"]', "", None): - await db.execute_write("insert into foo (json_column) values (?)", [value]) - response = await ds.client.get("/test_json_array/foo.json?_extra=suggested_facets") - data = response.json() - assert data["suggested_facets"] == [ - { - "name": "json_column", - "type": "array", - "toggle_url": "http://localhost/test_json_array/foo.json?_extra=suggested_facets&_facet_array=json_column", - } - ] - - -@pytest.mark.asyncio -async def test_facet_size(): - ds = Datasette([], memory=True, settings={"max_returned_rows": 50}) - db = ds.add_database(Database(ds, memory_name="test_facet_size")) - await db.execute_write("create table neighbourhoods(city text, neighbourhood text)") - for i in range(1, 51): - for j in range(1, 4): - await db.execute_write( - "insert into neighbourhoods (city, neighbourhood) values (?, ?)", - ["City {}".format(i), "Neighbourhood {}".format(j)], - ) - response = await ds.client.get( - "/test_facet_size/neighbourhoods.json?_extra=suggested_facets" - ) - data = response.json() - assert data["suggested_facets"] == [ - { - "name": "neighbourhood", - "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_extra=suggested_facets&_facet=neighbourhood", - } - ] - # Bump up _facet_size= to suggest city too - response2 = await ds.client.get( - "/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets" - ) - data2 = response2.json() - assert sorted(data2["suggested_facets"], key=lambda f: f["name"]) == [ - { - "name": "city", - "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=city", - }, - { - "name": "neighbourhood", - "toggle_url": "http://localhost/test_facet_size/neighbourhoods.json?_facet_size=50&_extra=suggested_facets&_facet=neighbourhood", - }, - ] - # Facet by city should return expected number of results - response3 = await ds.client.get( - "/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city" - ) - data3 = response3.json() - assert len(data3["facet_results"]["results"]["city"]["results"]) == 50 - # Reduce max_returned_rows and check that it's respected - ds._settings["max_returned_rows"] = 20 - response4 = await ds.client.get( - "/test_facet_size/neighbourhoods.json?_facet_size=50&_facet=city" - ) - data4 = response4.json() - assert len(data4["facet_results"]["results"]["city"]["results"]) == 20 - # Test _facet_size=max - response5 = await ds.client.get( - "/test_facet_size/neighbourhoods.json?_facet_size=max&_facet=city" - ) - data5 = response5.json() - assert len(data5["facet_results"]["results"]["city"]["results"]) == 20 - # Now try messing with facet_size in the table metadata - orig_config = ds.config - try: - ds.config = { - "databases": { - "test_facet_size": {"tables": {"neighbourhoods": {"facet_size": 6}}} - } - } - response6 = await ds.client.get( - "/test_facet_size/neighbourhoods.json?_facet=city" - ) - data6 = response6.json() - assert len(data6["facet_results"]["results"]["city"]["results"]) == 6 - # Setting it to max bumps it up to 50 again - ds.config["databases"]["test_facet_size"]["tables"]["neighbourhoods"][ - "facet_size" - ] = "max" - data7 = ( - await ds.client.get("/test_facet_size/neighbourhoods.json?_facet=city") - ).json() - assert len(data7["facet_results"]["results"]["city"]["results"]) == 20 - finally: - ds.config = orig_config - - -def test_other_types_of_facet_in_metadata(): - with make_app_client( - metadata={ - "databases": { - "fixtures": { - "tables": { - "facetable": { - "facets": ["state", {"array": "tags"}, {"date": "created"}] - } - } - } - } - } - ) as client: - response = client.get("/fixtures/facetable") - fragments = ( - "state\n", - "tags (array)\n", - "created (date)\n", - ) - for fragment in fragments: - assert fragment in response.text - # Verify they appear in the metadata-defined order - positions = [response.text.index(f) for f in fragments] - assert positions == sorted( - positions - ), "Facets should appear in metadata-defined order" - - -def test_metadata_facet_ordering(): - with make_app_client( - metadata={ - "databases": { - "fixtures": { - "tables": { - "facetable": { - "facets": ["state", {"array": "tags"}, {"date": "created"}] - } - } - } - } - } - ) as client: - # JSON response should have facets in the metadata-defined order - response = client.get("/fixtures/facetable.json?_extra=sorted_facet_results") - data = response.json - facet_names = [f["name"] for f in data["sorted_facet_results"]] - assert facet_names == ["state", "tags", "created"] - - # With an additional request-based facet, metadata facets come first - # in their defined order, followed by request-based facets - response2 = client.get( - "/fixtures/facetable.json?_extra=sorted_facet_results&_facet=_city_id" - ) - data2 = response2.json - facet_names2 = [f["name"] for f in data2["sorted_facet_results"]] - assert facet_names2 == ["state", "tags", "created", "_city_id"] - - -@pytest.mark.asyncio -async def test_conflicting_facet_names_json(ds_client): - response = await ds_client.get( - "/fixtures/facetable.json?_facet=created&_facet_date=created" - "&_facet=tags&_facet_array=tags" - ) - assert set(response.json()["facet_results"]["results"].keys()) == { - "created", - "tags", - "created_2", - "tags_2", - } - - -@pytest.mark.asyncio -async def test_facet_against_in_memory_database(): - ds = Datasette() - db = ds.add_memory_database("mem") - await db.execute_write( - "create table t (id integer primary key, name text, name2 text)" - ) - to_insert = [{"name": "one", "name2": "1"} for _ in range(800)] + [ - {"name": "two", "name2": "2"} for _ in range(300) - ] - await db.execute_write_many( - "insert into t (name, name2) values (:name, :name2)", to_insert - ) - response1 = await ds.client.get("/mem/t") - assert response1.status_code == 200 - response2 = await ds.client.get("/mem/t?_facet=name&_facet=name2") - assert response2.status_code == 200 - - -@pytest.mark.asyncio -async def test_facet_only_considers_first_x_rows(): - # This test works by manually fiddling with Facet.suggest_consider - ds = Datasette() - original_suggest_consider = Facet.suggest_consider - try: - Facet.suggest_consider = 40 - db = ds.add_memory_database("test_facet_only_x_rows") - await db.execute_write("create table t (id integer primary key, col text)") - # First 50 rows make it look like col and col_json should be faceted - to_insert = [{"col": "one" if i % 2 else "two"} for i in range(50)] - await db.execute_write_many("insert into t (col) values (:col)", to_insert) - # Next 50 break that assumption - to_insert2 = [{"col": f"x{i}"} for i in range(50)] - await db.execute_write_many("insert into t (col) values (:col)", to_insert2) - response = await ds.client.get( - "/test_facet_only_x_rows/t.json?_extra=suggested_facets" - ) - data = response.json() - assert data["suggested_facets"] == [ - { - "name": "col", - "toggle_url": "http://localhost/test_facet_only_x_rows/t.json?_extra=suggested_facets&_facet=col", - } - ] - # But if we set suggest_consider to 100 they are not suggested - Facet.suggest_consider = 100 - response2 = await ds.client.get( - "/test_facet_only_x_rows/t.json?_extra=suggested_facets" - ) - data2 = response2.json() - assert data2["suggested_facets"] == [] - finally: - Facet.suggest_consider = original_suggest_consider diff --git a/tests/test_fd_leak.py b/tests/test_fd_leak.py deleted file mode 100644 index 926722a1..00000000 --- a/tests/test_fd_leak.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -Regression test for https://github.com/simonw/datasette/issues/2692 — -confirm that creating and closing Datasette instances in a loop does not -leak open file descriptors. - -Each Datasette() with is_temp_disk internal DB opens a temp file and a -write thread with its own SQLite connection. Without Datasette.close() -nothing unwinds this state, and a large pytest run exhausts the process -FD limit. -""" - -import asyncio -import threading - -import pytest - -try: - import psutil -except ImportError: # pragma: no cover - psutil = None - -from datasette.app import Datasette - - -def _count_open_files(): - return len(psutil.Process().open_files()) - - -def _count_threads(): - return threading.active_count() - - -@pytest.mark.skipif(psutil is None, reason="psutil not installed") -def test_close_releases_file_descriptors(): - # Warm-up so Python/library caches don't skew the baseline - ds = Datasette(memory=True) - asyncio.run(ds.invoke_startup()) - ds.close() - - baseline_fds = _count_open_files() - baseline_threads = _count_threads() - - for _ in range(50): - ds = Datasette(memory=True) - asyncio.run(ds.invoke_startup()) - ds.close() - - after_fds = _count_open_files() - after_threads = _count_threads() - - assert ( - after_fds - baseline_fds <= 2 - ), f"Leaked FDs: baseline={baseline_fds}, after=50 iterations={after_fds}" - assert ( - after_threads - baseline_threads <= 2 - ), f"Leaked threads: baseline={baseline_threads}, after={after_threads}" diff --git a/tests/test_filters.py b/tests/test_filters.py deleted file mode 100644 index eda9e9a1..00000000 --- a/tests/test_filters.py +++ /dev/null @@ -1,137 +0,0 @@ -from datasette.filters import Filters, through_filters, where_filters, search_filters -from datasette.utils.asgi import Request -import pytest - - -@pytest.mark.parametrize( - "args,expected_where,expected_params", - [ - ((("name_english__contains", "foo"),), ['"name_english" like :p0'], ["%foo%"]), - ( - (("name_english__notcontains", "foo"),), - ['"name_english" not like :p0'], - ["%foo%"], - ), - ( - (("foo", "bar"), ("bar__contains", "baz")), - ['"bar" like :p0', '"foo" = :p1'], - ["%baz%", "bar"], - ), - ( - (("foo__startswith", "bar"), ("bar__endswith", "baz")), - ['"bar" like :p0', '"foo" like :p1'], - ["%baz", "bar%"], - ), - ( - (("foo__lt", "1"), ("bar__gt", "2"), ("baz__gte", "3"), ("bax__lte", "4")), - ['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'], - [2, 4, 3, 1], - ), - ( - (("foo__like", "2%2"), ("zax__glob", "3*")), - ['"foo" like :p0', '"zax" glob :p1'], - ["2%2", "3*"], - ), - # Multiple like arguments: - ( - (("foo__like", "2%2"), ("foo__like", "3%3")), - ['"foo" like :p0', '"foo" like :p1'], - ["2%2", "3%3"], - ), - # notlike: - ( - (("foo__notlike", "2%2"),), - ['"foo" not like :p0'], - ["2%2"], - ), - ( - (("foo__isnull", "1"), ("baz__isnull", "1"), ("bar__gt", "10")), - ['"bar" > :p0', '"baz" is null', '"foo" is null'], - [10], - ), - ((("foo__in", "1,2,3"),), ["foo in (:p0, :p1, :p2)"], ["1", "2", "3"]), - # date - ((("foo__date", "1988-01-01"),), ['date("foo") = :p0'], ["1988-01-01"]), - # JSON array variants of __in (useful for unexpected characters) - ((("foo__in", "[1,2,3]"),), ["foo in (:p0, :p1, :p2)"], [1, 2, 3]), - ( - (("foo__in", '["dog,cat", "cat[dog]"]'),), - ["foo in (:p0, :p1)"], - ["dog,cat", "cat[dog]"], - ), - # Not in, and JSON array not in - ((("foo__notin", "1,2,3"),), ["foo not in (:p0, :p1, :p2)"], ["1", "2", "3"]), - ((("foo__notin", "[1,2,3]"),), ["foo not in (:p0, :p1, :p2)"], [1, 2, 3]), - # JSON arraycontains, arraynotcontains - ( - (("Availability+Info__arraycontains", "yes"),), - [":p0 in (select value from json_each([table].[Availability+Info]))"], - ["yes"], - ), - ( - (("Availability+Info__arraynotcontains", "yes"),), - [":p0 not in (select value from json_each([table].[Availability+Info]))"], - ["yes"], - ), - ], -) -def test_build_where(args, expected_where, expected_params): - f = Filters(sorted(args)) - sql_bits, actual_params = f.build_where_clauses("table") - assert expected_where == sql_bits - assert {f"p{i}": param for i, param in enumerate(expected_params)} == actual_params - - -@pytest.mark.asyncio -async def test_through_filters_from_request(ds_client): - request = Request.fake( - '/?_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' - ) - filter_args = await through_filters( - request=request, - datasette=ds_client.ds, - table="roadside_attractions", - database="fixtures", - )() - assert filter_args.where_clauses == [ - "pk in (select attraction_id from roadside_attraction_characteristics where characteristic_id = :p0)" - ] - assert filter_args.params == {"p0": "1"} - assert filter_args.human_descriptions == [ - 'roadside_attraction_characteristics.characteristic_id = "1"' - ] - assert filter_args.extra_context == {} - - -@pytest.mark.asyncio -async def test_where_filters_from_request(ds_client): - await ds_client.ds.invoke_startup() - request = Request.fake("/?_where=pk+>+3") - filter_args = await where_filters( - request=request, - datasette=ds_client.ds, - database="fixtures", - )() - assert filter_args.where_clauses == ["pk > 3"] - assert filter_args.params == {} - assert filter_args.human_descriptions == [] - assert filter_args.extra_context == { - "extra_wheres_for_ui": [{"text": "pk > 3", "remove_url": "/"}] - } - - -@pytest.mark.asyncio -async def test_search_filters_from_request(ds_client): - request = Request.fake("/?_search=bobcat") - filter_args = await search_filters( - request=request, - datasette=ds_client.ds, - database="fixtures", - table="searchable", - )() - assert filter_args.where_clauses == [ - "rowid in (select rowid from searchable_fts where searchable_fts match escape_fts(:search))" - ] - assert filter_args.params == {"search": "bobcat"} - assert filter_args.human_descriptions == ['search matches "bobcat"'] - assert filter_args.extra_context == {"supports_search": True, "search": "bobcat"} diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py deleted file mode 100644 index 45f9854b..00000000 --- a/tests/test_fixtures.py +++ /dev/null @@ -1,49 +0,0 @@ -from datasette.fixtures import ( - populate_extra_database, - populate_fixture_database, - write_extra_database, - write_fixture_database, -) -from datasette.utils.sqlite import sqlite3 - - -def count(conn, table): - return conn.execute(f"select count(*) from [{table}]").fetchone()[0] - - -def test_populate_fixture_database(): - conn = sqlite3.connect(":memory:") - try: - populate_fixture_database(conn) - assert count(conn, "facetable") == 15 - assert count(conn, "compound_three_primary_keys") == 1001 - assert count(conn, "binary_data") == 3 - finally: - conn.close() - - -def test_write_fixture_database(tmp_path): - db_path = tmp_path / "fixtures.db" - write_fixture_database(db_path) - conn = sqlite3.connect(db_path) - try: - assert count(conn, "sortable") == 201 - finally: - conn.close() - - -def test_extra_database_helpers(tmp_path): - conn = sqlite3.connect(":memory:") - try: - populate_extra_database(conn) - assert count(conn, "searchable") == 2 - finally: - conn.close() - - db_path = tmp_path / "extra.db" - write_extra_database(db_path) - conn = sqlite3.connect(db_path) - try: - assert count(conn, "searchable") == 2 - finally: - conn.close() diff --git a/tests/test_html.py b/tests/test_html.py index a9de5e79..3f5815d4 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -1,1320 +1,561 @@ from bs4 import BeautifulSoup as Soup -from datasette.app import Datasette -from datasette.utils import allowed_pragmas -from .fixtures import make_app_client -from .utils import assert_footer_links, inner_html -import copy -import json -import pathlib +from .fixtures import ( # noqa + app_client, + app_client_shorter_time_limit, +) import pytest import re import urllib.parse -def test_homepage(app_client_two_attached_databases): - response = app_client_two_attached_databases.get("/") - assert response.status_code == 200 - assert "text/html; charset=utf-8" == response.headers["content-type"] - # Should have a html lang="en" attribute - assert '' in response.text - soup = Soup(response.content, "html.parser") - assert "Datasette Fixtures" == soup.find("h1").text - assert ( - "An example SQLite database demonstrating Datasette. Sign in as root user" - == soup.select(".metadata-description")[0].text.strip() +def test_homepage(app_client): + response = app_client.get('/') + assert response.status == 200 + assert 'test_tables' in response.text + + +def test_database_page(app_client): + response = app_client.get('/test_tables', allow_redirects=False) + assert response.status == 302 + response = app_client.get('/test_tables') + assert 'test_tables' in response.text + + +def test_invalid_custom_sql(app_client): + response = app_client.get( + '/test_tables?sql=.schema' ) - # Should be two attached databases - assert [ - {"href": "/extra+database", "text": "extra database"}, - {"href": "/fixtures", "text": "fixtures"}, - ] == [{"href": a["href"], "text": a.text.strip()} for a in soup.select("h2 a")] - # Database should show count text and attached tables - h2 = soup.select("h2")[0] - assert "extra database" == h2.text.strip() - counts_p, links_p = h2.find_all_next("p")[:2] - assert ( - "2 rows in 1 table, 5 rows in 4 hidden tables, 1 view" == counts_p.text.strip() - ) - # We should only show visible, not hidden tables here: - table_links = [ - {"href": a["href"], "text": a.text.strip()} for a in links_p.find_all("a") - ] - assert [ - {"href": r"/extra+database/searchable", "text": "searchable"}, - {"href": r"/extra+database/searchable_view", "text": "searchable_view"}, - ] == table_links + assert response.status == 400 + assert 'Statement must be a SELECT' in response.text -@pytest.mark.asyncio -@pytest.mark.parametrize("path", ("/", "/-/")) -async def test_homepage_alternative_location(path, tmp_path_factory): - template_dir = tmp_path_factory.mktemp("templates") - (template_dir / "index.html").write_text("Custom homepage", "utf-8") - datasette = Datasette(template_dir=str(template_dir)) - response = await datasette.client.get(path) - assert response.status_code == 200 - html = response.text - if path == "/": - assert html == "Custom homepage" - else: - assert '' in html - - -@pytest.mark.asyncio -async def test_homepage_alternative_redirect(ds_client): - response = await ds_client.get("/-") - assert response.status_code == 301 - - -@pytest.mark.asyncio -async def test_http_head(ds_client): - response = await ds_client.head("/") - assert response.status_code == 200 - - -@pytest.mark.asyncio -async def test_homepage_options(ds_client): - response = await ds_client.options("/") - assert response.status_code == 200 - assert response.text == "ok" - - -@pytest.mark.asyncio -async def test_favicon(ds_client): - response = await ds_client.get("/favicon.ico") - assert response.status_code == 200 - assert response.headers["cache-control"] == "max-age=3600, immutable, public" - assert int(response.headers["content-length"]) > 100 - assert response.headers["content-type"] == "image/png" - - -@pytest.mark.asyncio -async def test_static(ds_client): - response = await ds_client.get("/-/static/app2.css") - assert response.status_code == 404 - response = await ds_client.get("/-/static/app.css") - assert response.status_code == 200 - assert "text/css" == response.headers["content-type"] - assert "etag" in response.headers - etag = response.headers.get("etag") - response = await ds_client.get("/-/static/app.css", headers={"if-none-match": etag}) - assert response.status_code == 304 - - -def test_static_mounts(): - with make_app_client( - static_mounts=[("custom-static", str(pathlib.Path(__file__).parent))] - ) as client: - response = client.get("/custom-static/test_html.py") - assert response.status_code == 200 - response = client.get("/custom-static/not_exists.py") - assert response.status_code == 404 - response = client.get("/custom-static/../LICENSE") - assert response.status_code == 404 - - -def test_memory_database_page(): - with make_app_client(memory=True) as client: - response = client.get("/_memory") - assert response.status_code == 200 - - -def test_not_allowed_methods(): - with make_app_client(memory=True) as client: - for method in ("post", "put", "patch", "delete"): - response = client.request(path="/_memory", method=method.upper()) - assert response.status_code == 405 - - -@pytest.mark.asyncio -async def test_database_page(ds_client): - response = await ds_client.get("/fixtures") - soup = Soup(response.text, "html.parser") - # Should have a ', + expected_html_fragment = """ + sql_time_limit_ms + """.strip() + assert expected_html_fragment in response.text + + +def test_view(app_client): + response = app_client.get('/test_tables/simple_view') + assert response.status == 200 + + +def test_row(app_client): + response = app_client.get( + '/test_tables/simple_primary_key/1', + allow_redirects=False + ) + assert response.status == 302 + assert response.headers['Location'].endswith('/1') + response = app_client.get('/test_tables/simple_primary_key/1') + assert response.status == 200 + + +def test_add_filter_redirects(app_client): + filter_args = urllib.parse.urlencode({ + '_filter_column': 'content', + '_filter_op': 'startswith', + '_filter_value': 'x' + }) + # First we need to resolve the correct path before testing more redirects + path_base = app_client.get( + '/test_tables/simple_primary_key', allow_redirects=False + ).headers['Location'] + path = path_base + '?' + filter_args + response = app_client.get(path, allow_redirects=False) + assert response.status == 302 + assert response.headers['Location'].endswith('?content__startswith=x') + + # Adding a redirect to an existing querystring: + path = path_base + '?foo=bar&' + filter_args + response = app_client.get(path, allow_redirects=False) + assert response.status == 302 + assert response.headers['Location'].endswith('?foo=bar&content__startswith=x') + + # Test that op with a __x suffix overrides the filter value + path = path_base + '?' + urllib.parse.urlencode({ + '_filter_column': 'content', + '_filter_op': 'isnull__5', + '_filter_value': 'x' + }) + response = app_client.get(path, allow_redirects=False) + assert response.status == 302 + assert response.headers['Location'].endswith('?content__isnull=5') + + +def test_existing_filter_redirects(app_client): + filter_args = { + '_filter_column_1': 'name', + '_filter_op_1': 'contains', + '_filter_value_1': 'hello', + '_filter_column_2': 'age', + '_filter_op_2': 'gte', + '_filter_value_2': '22', + '_filter_column_3': 'age', + '_filter_op_3': 'lt', + '_filter_value_3': '30', + '_filter_column_4': 'name', + '_filter_op_4': 'contains', + '_filter_value_4': 'world', + } + path_base = app_client.get( + '/test_tables/simple_primary_key', allow_redirects=False + ).headers['Location'] + path = path_base + '?' + urllib.parse.urlencode(filter_args) + response = app_client.get(path, allow_redirects=False) + assert response.status == 302 + assert_querystring_equal( + 'name__contains=hello&age__gte=22&age__lt=30&name__contains=world', + response.headers['Location'].split('?')[1], + ) + + # Setting _filter_column_3 to empty string should remove *_3 entirely + filter_args['_filter_column_3'] = '' + path = path_base + '?' + urllib.parse.urlencode(filter_args) + response = app_client.get(path, allow_redirects=False) + assert response.status == 302 + assert_querystring_equal( + 'name__contains=hello&age__gte=22&name__contains=world', + response.headers['Location'].split('?')[1], + ) + + # ?_filter_op=exact should be removed if unaccompanied by _fiter_column + response = app_client.get(path_base + '?_filter_op=exact', allow_redirects=False) + assert response.status == 302 + assert '?' not in response.headers['Location'] + + +def test_empty_search_parameter_gets_removed(app_client): + path_base = app_client.get( + '/test_tables/simple_primary_key', allow_redirects=False + ).headers['Location'] + path = path_base + '?' + urllib.parse.urlencode({ + '_search': '', + '_filter_column': 'name', + '_filter_op': 'exact', + '_filter_value': 'chidi', + }) + response = app_client.get(path, allow_redirects=False) + assert response.status == 302 + assert response.headers['Location'].endswith( + '?name__exact=chidi' + ) + + +def test_sort_by_desc_redirects(app_client): + path_base = app_client.get( + '/test_tables/sortable', allow_redirects=False + ).headers['Location'] + path = path_base + '?' + urllib.parse.urlencode({ + '_sort': 'sortable', + '_sort_by_desc': '1', + }) + response = app_client.get(path, allow_redirects=False) + assert response.status == 302 + assert response.headers['Location'].endswith('?_sort_desc=sortable') + + +def test_sort_links(app_client): + response = app_client.get( + '/test_tables/sortable?_sort=sortable' + + ) + assert response.status == 200 + ths = Soup(response.body, 'html.parser').findAll('th') + attrs_and_link_attrs = [{ + 'attrs': th.attrs, + 'a_href': ( + th.find('a')['href'].split('/')[-1] + if th.find('a') + else None + ), + } for th in ths] + assert [ + { + "attrs": {"class": ["col-Link"], "scope": "col"}, + "a_href": None + }, + { + "attrs": {"class": ["col-pk1"], "scope": "col"}, + "a_href": None + }, + { + "attrs": {"class": ["col-pk2"], "scope": "col"}, + "a_href": None + }, + { + "attrs": {"class": ["col-content"], "scope": "col"}, + "a_href": None + }, + { + "attrs": {"class": ["col-sortable"], "scope": "col"}, + "a_href": "sortable?_sort_desc=sortable", + }, + { + "attrs": {"class": ["col-sortable_with_nulls"], "scope": "col"}, + "a_href": "sortable?_sort=sortable_with_nulls", + }, + { + "attrs": {"class": ["col-sortable_with_nulls_2"], "scope": "col"}, + "a_href": "sortable?_sort=sortable_with_nulls_2", + }, + { + "attrs": {"class": ["col-text"], "scope": "col"}, + "a_href": "sortable?_sort=text", + }, + ] == attrs_and_link_attrs + + +def test_facets_persist_through_filter_form(app_client): + response = app_client.get( + '/test_tables/facetable?_facet=planet_int&_facet=city_id' + ) + assert response.status == 200 + inputs = Soup(response.body, 'html.parser').find('form').findAll('input') + hiddens = [i for i in inputs if i['type'] == 'hidden'] + assert [ + ('_facet', 'city_id'), + ('_facet', 'planet_int'), + ] == [ + (hidden['name'], hidden['value']) for hidden in hiddens ] - for expected_html_fragment in expected_html_fragments: - assert expected_html_fragment in response.text -def test_row_page_does_not_truncate(): - with make_app_client(settings={"truncate_cells_html": 5}) as client: - response = client.get("/fixtures/facetable/1") - assert response.status_code == 200 - table = Soup(response.content, "html.parser").find("table") - assert table["class"] == ["rows-and-columns"] - assert ["Mission"] == [ - td.string - for td in table.find_all("td", {"class": "col-neighborhood-b352a7"}) - ] - - -def test_query_page_truncates(): - with make_app_client(settings={"truncate_cells_html": 5}) as client: - response = client.get( - "/fixtures/-/query?" - + urllib.parse.urlencode( - { - "sql": "select 'this is longer than 5' as a, 'https://example.com/' as b" - } - ) - ) - assert response.status_code == 200 - table = Soup(response.content, "html.parser").find("table") - tds = table.find_all("td") - assert [str(td) for td in tds] == [ - '

', - '', - ] - - -@pytest.mark.asyncio -async def test_query_page_with_no_sql(ds_client): - # https://github.com/simonw/datasette/issues/2743 - response = await ds_client.get("/fixtures/-/query") - assert response.status_code == 200 - assert '" in html - assert "0 results" not in html - - -def test_config_template_debug_on(): - with make_app_client(settings={"template_debug": True}) as client: - response = client.get("/fixtures/facetable?_context=1") - assert response.status_code == 200 - assert response.text.startswith("
{")
-
-
-@pytest.mark.asyncio
-async def test_config_template_debug_off(ds_client):
-    response = await ds_client.get("/fixtures/facetable?_context=1")
-    assert response.status_code == 200
-    assert not response.text.startswith("
{")
-
-
-def test_debug_context_includes_extra_template_vars():
-    # https://github.com/simonw/datasette/issues/693
-    with make_app_client(settings={"template_debug": True}) as client:
-        response = client.get("/fixtures/facetable?_context=1")
-        # scope_path is added by PLUGIN1
-        assert "scope_path" in response.text
-
-
-@pytest.mark.parametrize(
-    "path",
-    [
-        "/",
-        "/fixtures",
-        "/fixtures/compound_three_primary_keys",
-        "/fixtures/compound_three_primary_keys/a,a,a",
-        "/fixtures/paginated_view",
-        "/fixtures/facetable",
-        "/fixtures/facetable?_facet=state",
-        "/fixtures/-/query?sql=select+1",
-    ],
-)
-@pytest.mark.parametrize("use_prefix", (True, False))
-def test_base_url_config(app_client_base_url_prefix, path, use_prefix):
-    client = app_client_base_url_prefix
-    path_to_get = path
-    if use_prefix:
-        path_to_get = "/prefix/" + path.lstrip("/")
-    response = client.get(path_to_get)
-    soup = Soup(response.content, "html.parser")
-    for form in soup.select("form"):
-        action = form.get("action")
-        if action is None:
-            assert form.get("method") == "dialog", json.dumps(
-                {
-                    "path": path,
-                    "path_to_get": path_to_get,
-                    "form": str(form),
-                },
-                indent=4,
-                default=repr,
-            )
-            continue
-        assert action.startswith("/prefix"), json.dumps(
-            {
-                "path": path,
-                "path_to_get": path_to_get,
-                "action": action,
-                "form": str(form),
-            },
-            indent=4,
-            default=repr,
-        )
-    for el in soup.find_all(["a", "link", "script"]):
-        if "href" in el.attrs:
-            href = el["href"]
-        elif "src" in el.attrs:
-            href = el["src"]
-        else:
-            continue  # Could be a 
-        if (
-            not href.startswith("#")
-            and href
-            not in {
-                "https://datasette.io/",
-                "https://github.com/simonw/datasette",
-                "https://github.com/simonw/datasette/blob/main/LICENSE",
-                "https://github.com/simonw/datasette/blob/main/tests/fixtures.py",
-                "/login-as-root",  # Only used for the latest.datasette.io demo
-            }
-            and not href.startswith("https://plugin-example.datasette.io/")
-        ):
-            # If this has been made absolute it may start http://localhost/
-            if href.startswith("http://localhost/"):
-                href = href[len("http://localost/") :]
-            assert href.startswith("/prefix/"), json.dumps(
-                {
-                    "path": path,
-                    "path_to_get": path_to_get,
-                    "href_or_src": href,
-                    "element_parent": str(el.parent),
-                },
-                indent=4,
-                default=repr,
-            )
-
-
-def test_base_url_affects_filter_redirects(app_client_base_url_prefix):
-    path = "/fixtures/binary_data?_filter_column=rowid&_filter_op=exact&_filter_value=1&_sort=rowid"
-    response = app_client_base_url_prefix.get(path)
-    assert response.status_code == 302
-    assert (
-        response.headers["location"]
-        == "/prefix/fixtures/binary_data?_sort=rowid&rowid__exact=1"
-    )
-
-
-def test_base_url_affects_metadata_extra_css_urls(app_client_base_url_prefix):
-    html = app_client_base_url_prefix.get("/").text
-    assert '' in html
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "path,expected",
-    [
-        (
-            "/fixtures/neighborhood_search",
-            "/fixtures/-/query?sql=%0Aselect+_neighborhood%2C+facet_cities.name%2C+state%0Afrom+facetable%0A++++join+facet_cities%0A++++++++on+facetable._city_id+%3D+facet_cities.id%0Awhere+_neighborhood+like+%27%25%27+%7C%7C+%3Atext+%7C%7C+%27%25%27%0Aorder+by+_neighborhood%3B%0A&text=",
-        ),
-        (
-            "/fixtures/neighborhood_search?text=ber",
-            "/fixtures/-/query?sql=%0Aselect+_neighborhood%2C+facet_cities.name%2C+state%0Afrom+facetable%0A++++join+facet_cities%0A++++++++on+facetable._city_id+%3D+facet_cities.id%0Awhere+_neighborhood+like+%27%25%27+%7C%7C+%3Atext+%7C%7C+%27%25%27%0Aorder+by+_neighborhood%3B%0A&text=ber",
-        ),
-        ("/fixtures/pragma_cache_size", None),
-        (
-            # /fixtures/𝐜𝐢𝐭𝐢𝐞𝐬
-            "/fixtures/~F0~9D~90~9C~F0~9D~90~A2~F0~9D~90~AD~F0~9D~90~A2~F0~9D~90~9E~F0~9D~90~AC",
-            "/fixtures/-/query?sql=select+id%2C+name+from+facet_cities+order+by+id+limit+1%3B",
-        ),
-        ("/fixtures/magic_parameters", None),
-    ],
-)
-async def test_edit_sql_link_on_stored_queries(ds_client, path, expected):
-    response = await ds_client.get(path)
-    assert response.status_code == 200
-    expected_link = f'Edit SQL'
-    if expected:
-        assert expected_link in response.text
-    else:
-        assert "Edit SQL" not in response.text
-
-
-@pytest.mark.parametrize(
-    "has_permission",
-    [
-        pytest.param(
-            True,
-        ),
-        False,
-    ],
-)
-def test_edit_sql_link_not_shown_if_user_lacks_permission(has_permission):
-    with make_app_client(
-        config={
-            "allow_sql": None if has_permission else {"id": "not-you"},
-            "databases": {"fixtures": {"queries": {"simple": "select 1 + 1"}}},
-        }
-    ) as client:
-        response = client.get("/fixtures/simple")
-        if has_permission:
-            assert "Edit SQL" in response.text
-        else:
-            assert "Edit SQL" not in response.text
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "actor_id,should_have_links,should_not_have_links",
-    [
-        (None, None, None),
-        ("test", None, ["/-/permissions"]),
-        ("root", None, ["/-/permissions", "/-/allow-debug"]),
-    ],
-)
-async def test_navigation_menu_links(
-    ds_client, actor_id, should_have_links, should_not_have_links
-):
-    # Enable root user if testing with root actor
-    if actor_id == "root":
-        ds_client.ds.root_enabled = True
-    kwargs = {}
-    if actor_id:
-        kwargs["actor"] = {"id": actor_id}
-    html = (await ds_client.get("/", **kwargs)).text
-    soup = Soup(html, "html.parser")
-    details = soup.find("nav").find("details", {"class": "nav-menu"})
-    assert details is not None
-    search_button = details.find("button", {"data-navigation-search-open": True})
-    assert search_button is not None
-    assert search_button.text.strip() == "Jump to... /"
-    assert search_button.find("kbd", {"class": "keyboard-shortcut"}).text == "/"
-    assert search_button.find("kbd")["aria-hidden"] == "true"
-    assert (
-        search_button.find("kbd")["title"]
-        == "Keyboard shortcut: press / to open Jump to"
-    )
-    navigation_search_script = soup.find(
-        "script", {"src": re.compile(r"navigation-search\.js")}
-    )
-    assert navigation_search_script["src"] == "/-/static/navigation-search.js"
-    assert details.find("li").find("button") == search_button
-    if not actor_id:
-        # The app menu is always visible, but anonymous users do not see logout
-        # or debug links.
-        assert details.find("form") is None
-        return
-    # They are logged in: should show a menu
-    assert details is not None
-    # And a logout form
-    assert details.find("form") is not None
-    if should_have_links:
-        for link in should_have_links:
-            assert (
-                details.find("a", {"href": link}) is not None
-            ), f"{link} expected but missing from nav menu"
-
-    if should_not_have_links:
-        for link in should_not_have_links:
-            assert (
-                details.find("a", {"href": link}) is None
-            ), f"{link} found but should not have been in nav menu"
-
-
-@pytest.mark.asyncio
-async def test_trace_correctly_escaped(ds_client):
-    response = await ds_client.get("/fixtures/-/query?sql=select+'

Hello'&_trace=1") - assert "select '

Hello" not in response.text - assert "select '<h1>Hello" in response.text - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected", - ( - # Instance index page - ("/", "http://localhost/.json"), - # Table page - ("/fixtures/facetable", "http://localhost/fixtures/facetable.json"), - ( - "/fixtures/table~2Fwith~2Fslashes~2Ecsv", - "http://localhost/fixtures/table~2Fwith~2Fslashes~2Ecsv.json", - ), - # Row page - ( - "/fixtures/no_primary_key/1", - "http://localhost/fixtures/no_primary_key/1.json", - ), - # Database index page - ( - "/fixtures", - "http://localhost/fixtures.json", - ), - # Custom query page - ( - "/fixtures/-/query?sql=select+*+from+facetable", - "http://localhost/fixtures/-/query.json?sql=select+*+from+facetable", - ), - # Stored query page - ( - "/fixtures/neighborhood_search?text=town", - "http://localhost/fixtures/neighborhood_search.json?text=town", - ), - # /-/ pages - ( - "/-/plugins", - "http://localhost/-/plugins.json", - ), - ), -) -async def test_alternate_url_json(ds_client, path, expected): - response = await ds_client.get(path) - assert response.status_code == 200 - link = response.headers["link"] - assert link == '<{}>; rel="alternate"; type="application/json+datasette"'.format( - expected - ) - assert ( - ''.format( - expected - ) - in response.text - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path", - ("/-/patterns", "/-/messages", "/-/allow-debug", "/fixtures.db"), -) -async def test_no_alternate_url_json(ds_client, path): - response = await ds_client.get(path) - assert "link" not in response.headers - assert ( - 'Name" in response.text - assert "view-instance" in response.text - assert "view-database" in response.text - finally: - ds_client.ds.root_enabled = original_root_enabled - - -@pytest.mark.asyncio -async def test_actions_page_does_not_display_none_string(ds_client): - """Ensure the Resource column doesn't display the string 'None' for null values.""" - # https://github.com/simonw/datasette/issues/2599 - original_root_enabled = ds_client.ds.root_enabled - try: - ds_client.ds.root_enabled = True - response = await ds_client.get("/-/actions", actor={"id": "root"}) - assert response.status_code == 200 - assert "None" not in response.text - finally: - ds_client.ds.root_enabled = original_root_enabled - - -@pytest.mark.asyncio -async def test_permission_debug_tabs_with_query_string(ds_client): - """Test that navigation tabs persist query strings across Check, Allowed, and Rules pages""" - original_root_enabled = ds_client.ds.root_enabled - try: - ds_client.ds.root_enabled = True - actor = {"id": "root"} - - # Test /-/allowed with query string - response = await ds_client.get( - "/-/allowed?action=view-table&page_size=50", actor=actor - ) - assert response.status_code == 200 - # Check that Rules and Check tabs have the query string - assert 'href="/-/rules?action=view-table&page_size=50"' in response.text - assert 'href="/-/check?action=view-table&page_size=50"' in response.text - # Playground and Actions should not have query string - assert 'href="/-/permissions"' in response.text - assert 'href="/-/actions"' in response.text - - # Test /-/rules with query string - response = await ds_client.get( - "/-/rules?action=view-database&parent=test", actor=actor - ) - assert response.status_code == 200 - # Check that Allowed and Check tabs have the query string - assert 'href="/-/allowed?action=view-database&parent=test"' in response.text - assert 'href="/-/check?action=view-database&parent=test"' in response.text - - # Test /-/check with query string - response = await ds_client.get("/-/check?action=execute-sql", actor=actor) - assert response.status_code == 200 - # Check that Allowed and Rules tabs have the query string - assert 'href="/-/allowed?action=execute-sql"' in response.text - assert 'href="/-/rules?action=execute-sql"' in response.text - finally: - ds_client.ds.root_enabled = original_root_enabled +def inner_html(soup): + html = str(soup) + # This includes the parent tag - so remove that + inner_html = html.split('>', 1)[1].rsplit('<', 1)[0] + return inner_html.strip() diff --git a/tests/test_inspect.py b/tests/test_inspect.py new file mode 100644 index 00000000..954dcff7 --- /dev/null +++ b/tests/test_inspect.py @@ -0,0 +1,103 @@ +from datasette.app import Datasette +import os +import pytest +import sqlite3 +import tempfile + + +TABLES = ''' +CREATE TABLE "election_results" ( + "county" INTEGER, + "party" INTEGER, + "office" INTEGER, + "votes" INTEGER, + FOREIGN KEY (county) REFERENCES county(id), + FOREIGN KEY (party) REFERENCES party(id), + FOREIGN KEY (office) REFERENCES office(id) + ); + +CREATE VIRTUAL TABLE "election_results_fts" USING FTS4 ("county", "party"); + +CREATE TABLE "county" ( + "id" INTEGER PRIMARY KEY , + "name" TEXT +); + +CREATE TABLE "party" ( + "id" INTEGER PRIMARY KEY , + "name" TEXT +); + +CREATE TABLE "office" ( + "id" INTEGER PRIMARY KEY , + "name" TEXT +); +''' + + +@pytest.fixture(scope='session') +def ds_instance(): + with tempfile.TemporaryDirectory() as tmpdir: + filepath = os.path.join(tmpdir, 'test_tables.db') + conn = sqlite3.connect(filepath) + conn.executescript(TABLES) + yield Datasette([filepath]) + + +def test_inspect_hidden_tables(ds_instance): + info = ds_instance.inspect() + tables = info['test_tables']['tables'] + expected_hidden = ( + 'election_results_fts', + 'election_results_fts_content', + 'election_results_fts_docsize', + 'election_results_fts_segdir', + 'election_results_fts_segments', + 'election_results_fts_stat', + ) + expected_visible = ( + 'election_results', + 'county', + 'party', + 'office', + ) + assert sorted(expected_hidden) == sorted( + [table for table in tables if tables[table]['hidden']] + ) + assert sorted(expected_visible) == sorted( + [table for table in tables if not tables[table]['hidden']] + ) + + +def test_inspect_foreign_keys(ds_instance): + info = ds_instance.inspect() + tables = info['test_tables']['tables'] + for table_name in ('county', 'party', 'office'): + assert 0 == tables[table_name]['count'] + foreign_keys = tables[table_name]['foreign_keys'] + assert [] == foreign_keys['outgoing'] + assert [{ + 'column': 'id', + 'other_column': table_name, + 'other_table': 'election_results' + }] == foreign_keys['incoming'] + + election_results = tables['election_results'] + assert 0 == election_results['count'] + assert sorted([{ + 'column': 'county', + 'other_column': 'id', + 'other_table': 'county' + }, { + 'column': 'party', + 'other_column': 'id', + 'other_table': 'party' + }, { + 'column': 'office', + 'other_column': 'id', + 'other_table': 'office' + }], key=lambda d: d['column']) == sorted( + election_results['foreign_keys']['outgoing'], + key=lambda d: d['column'] + ) + assert [] == election_results['foreign_keys']['incoming'] diff --git a/tests/test_internal_db.py b/tests/test_internal_db.py deleted file mode 100644 index 26d63a92..00000000 --- a/tests/test_internal_db.py +++ /dev/null @@ -1,234 +0,0 @@ -import pytest -import sqlite_utils - - -# ensure refresh_schemas() gets called before interacting with internal_db -async def ensure_internal(ds_client): - await ds_client.get("/fixtures.json?sql=select+1") - return ds_client.ds.get_internal_database() - - -@pytest.mark.asyncio -async def test_internal_databases(ds_client): - internal_db = await ensure_internal(ds_client) - databases = await internal_db.execute("select * from catalog_databases") - assert len(databases) == 1 - assert databases.rows[0]["database_name"] == "fixtures" - - -@pytest.mark.asyncio -async def test_internal_tables(ds_client): - internal_db = await ensure_internal(ds_client) - tables = await internal_db.execute("select * from catalog_tables") - assert len(tables) > 5 - table = tables.rows[0] - assert set(table.keys()) == {"rootpage", "table_name", "database_name", "sql"} - - -@pytest.mark.asyncio -async def test_internal_views(ds_client): - internal_db = await ensure_internal(ds_client) - views = await internal_db.execute("select * from catalog_views") - assert len(views) >= 4 - view = views.rows[0] - assert set(view.keys()) == {"rootpage", "view_name", "database_name", "sql"} - - -@pytest.mark.asyncio -async def test_internal_indexes(ds_client): - internal_db = await ensure_internal(ds_client) - indexes = await internal_db.execute("select * from catalog_indexes") - assert len(indexes) > 5 - index = indexes.rows[0] - assert set(index.keys()) == { - "partial", - "name", - "table_name", - "unique", - "seq", - "database_name", - "origin", - } - - -@pytest.mark.asyncio -async def test_internal_foreign_keys(ds_client): - internal_db = await ensure_internal(ds_client) - foreign_keys = await internal_db.execute("select * from catalog_foreign_keys") - assert len(foreign_keys) > 5 - foreign_key = foreign_keys.rows[0] - assert set(foreign_key.keys()) == { - "table", - "seq", - "on_update", - "on_delete", - "to", - "id", - "match", - "database_name", - "table_name", - "from", - } - - -@pytest.mark.asyncio -async def test_internal_foreign_key_references(ds_client): - internal_db = await ensure_internal(ds_client) - - def inner(conn): - db = sqlite_utils.Database(conn) - table_names = db.table_names() - for table in db.tables: - for fk in table.foreign_keys: - other_table = fk.other_table - other_column = fk.other_column - message = 'Column "{}.{}" references other column "{}.{}" which does not exist'.format( - table.name, fk.column, other_table, other_column - ) - assert other_table in table_names, message + " (bad table)" - assert other_column in db[other_table].columns_dict, ( - message + " (bad column)" - ) - - await internal_db.execute_fn(inner) - - -@pytest.mark.asyncio -async def test_stale_catalog_entry_database_fix(tmp_path): - """ - Test for https://github.com/simonw/datasette/issues/2605 - - When the internal database persists across restarts and has entries in - catalog_databases for databases that no longer exist, accessing the - index page should not cause a 500 error (KeyError). - """ - from datasette.app import Datasette - - internal_db_path = str(tmp_path / "internal.db") - data_db_path = str(tmp_path / "data.db") - - # Create a data database file - import sqlite3 - - conn = sqlite3.connect(data_db_path) - conn.execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY)") - conn.close() - - # First Datasette instance: with the data database and persistent internal db - ds1 = Datasette(files=[data_db_path], internal=internal_db_path) - await ds1.invoke_startup() - - # Access the index page to populate the internal catalog - response = await ds1.client.get("/") - assert "data" in ds1.databases - assert response.status_code == 200 - - # Second Datasette instance: reusing internal.db but WITHOUT the data database - # This simulates restarting Datasette after removing a database - ds2 = Datasette(internal=internal_db_path) - await ds2.invoke_startup() - - # The database is not in ds2.databases - assert "data" not in ds2.databases - - # Accessing the index page should NOT cause a 500 error - # This is the bug: it currently raises KeyError when trying to - # access ds.databases["data"] for the stale catalog entry - response = await ds2.client.get("/") - assert response.status_code == 200, ( - f"Index page should return 200, not {response.status_code}. " - "This fails due to stale catalog entries causing KeyError." - ) - - -@pytest.mark.asyncio -async def test_stale_catalog_child_entries_removed_for_missing_database(tmp_path): - from datasette.app import Datasette - - import sqlite3 - - internal_db_path = str(tmp_path / "internal.db") - alpha_db_path = str(tmp_path / "alpha.db") - bravo_db_path = str(tmp_path / "bravo.db") - - for db_path, table_name in ( - (alpha_db_path, "alpha_table"), - (bravo_db_path, "bravo_table"), - (bravo_db_path, "bravo_table_2"), - ): - conn = sqlite3.connect(db_path) - conn.execute(f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY)") - conn.close() - - ds1 = Datasette(files=[alpha_db_path, bravo_db_path], internal=internal_db_path) - await ds1.invoke_startup() - - catalog_tables = await ds1.get_internal_database().execute(""" - SELECT database_name, table_name - FROM catalog_tables - ORDER BY database_name, table_name - """) - assert [tuple(row) for row in catalog_tables.rows] == [ - ("alpha", "alpha_table"), - ("bravo", "bravo_table"), - ("bravo", "bravo_table_2"), - ] - - ds1.close() - - ds2 = Datasette(files=[alpha_db_path], internal=internal_db_path) - await ds2.invoke_startup() - - catalog_tables = await ds2.get_internal_database().execute(""" - SELECT database_name, table_name - FROM catalog_tables - ORDER BY database_name, table_name - """) - assert [tuple(row) for row in catalog_tables.rows] == [("alpha", "alpha_table")] - - ds2.close() - - -@pytest.mark.asyncio -async def test_orphan_stale_catalog_child_entries_removed(tmp_path): - from datasette.app import Datasette - - import sqlite3 - - internal_db_path = str(tmp_path / "internal.db") - alpha_db_path = str(tmp_path / "alpha.db") - - conn = sqlite3.connect(alpha_db_path) - conn.execute("CREATE TABLE alpha_table (id INTEGER PRIMARY KEY)") - conn.close() - - ds1 = Datasette(files=[alpha_db_path], internal=internal_db_path) - await ds1.invoke_startup() - ds1.close() - - # Simulate the state left behind by old cleanup code: the parent database - # row was deleted, but child catalog rows survived because foreign key - # enforcement is not enabled for these internal catalog writes. - conn = sqlite3.connect(internal_db_path) - conn.execute("DELETE FROM catalog_databases WHERE database_name = 'fixtures'") - conn.execute(""" - INSERT INTO catalog_tables (database_name, table_name, rootpage, sql) - VALUES ('fixtures', 'stale_table', 1, 'CREATE TABLE stale_table (id INTEGER)') - """) - conn.commit() - conn.close() - - ds2 = Datasette(files=[alpha_db_path], internal=internal_db_path) - await ds2.invoke_startup() - - catalog_tables = await ds2.get_internal_database().execute(""" - SELECT database_name, table_name - FROM catalog_tables - ORDER BY database_name, table_name - """) - assert [tuple(row) for row in catalog_tables.rows] == [("alpha", "alpha_table")] - - response = await ds2.client.get("/-/jump.json") - assert response.status_code == 200 - - ds2.close() diff --git a/tests/test_internals_database.py b/tests/test_internals_database.py deleted file mode 100644 index 5481a398..00000000 --- a/tests/test_internals_database.py +++ /dev/null @@ -1,973 +0,0 @@ -""" -Tests for the datasette.database.Database class -""" - -import asyncio -from types import SimpleNamespace -from datasette.app import Datasette -from datasette.database import Database, Results, MultipleValues -from datasette.database import DatasetteClosedError -from datasette.database import _deliver_write_result -from datasette.utils.sqlite import sqlite3, sqlite_version -from datasette.utils import Column -import pytest -import time -import uuid - - -@pytest.fixture -def db(app_client): - return app_client.ds.get_database("fixtures") - - -@pytest.mark.asyncio -async def test_execute(db): - results = await db.execute("select * from facetable") - assert isinstance(results, Results) - assert 15 == len(results) - - -@pytest.mark.asyncio -async def test_results_first(db): - assert None is (await db.execute("select * from facetable where pk > 100")).first() - results = await db.execute("select * from facetable") - row = results.first() - assert isinstance(row, sqlite3.Row) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("expected", (True, False)) -async def test_results_bool(db, expected): - where = "" if expected else "where pk = 0" - results = await db.execute("select * from facetable {}".format(where)) - assert bool(results) is expected - - -@pytest.mark.asyncio -async def test_results_dicts(db): - results = await db.execute("select pk, name from roadside_attractions") - assert results.dicts() == [ - {"pk": 1, "name": "The Mystery Spot"}, - {"pk": 2, "name": "Winchester Mystery House"}, - {"pk": 3, "name": "Burlingame Museum of PEZ Memorabilia"}, - {"pk": 4, "name": "Bigfoot Discovery Museum"}, - ] - - -@pytest.mark.parametrize( - "query,expected", - [ - ("select 1", 1), - ("select 1, 2", None), - ("select 1 as num union select 2 as num", None), - ], -) -@pytest.mark.asyncio -async def test_results_single_value(db, query, expected): - results = await db.execute(query) - if expected: - assert expected == results.single_value() - else: - with pytest.raises(MultipleValues): - results.single_value() - - -@pytest.mark.asyncio -async def test_execute_fn(db): - def get_1_plus_1(conn): - return conn.execute("select 1 + 1").fetchall()[0][0] - - assert 2 == await db.execute_fn(get_1_plus_1) - - -@pytest.mark.asyncio -async def test_execute_fn_transaction_false(): - datasette = Datasette(memory=True) - db = datasette.add_memory_database("test_execute_fn_transaction_false") - - def run(conn): - try: - with conn: - conn.execute("create table foo (id integer primary key)") - conn.execute("insert into foo (id) values (44)") - # Table should exist - assert ( - conn.execute( - 'select count(*) from sqlite_master where name = "foo"' - ).fetchone()[0] - == 1 - ) - assert conn.execute("select id from foo").fetchall()[0][0] == 44 - raise ValueError("Cancel commit") - except ValueError: - pass - # Row should NOT exist - assert conn.execute("select count(*) from foo").fetchone()[0] == 0 - - await db.execute_write_fn(run, transaction=False) - - -@pytest.mark.parametrize( - "tables,exists", - ( - (["facetable", "searchable", "tags", "searchable_tags"], True), - (["foo", "bar", "baz"], False), - ), -) -@pytest.mark.asyncio -async def test_table_exists(db, tables, exists): - for table in tables: - actual = await db.table_exists(table) - assert exists == actual - - -@pytest.mark.parametrize( - "view,expected", - ( - ("not_a_view", False), - ("paginated_view", True), - ), -) -@pytest.mark.asyncio -async def test_view_exists(db, view, expected): - actual = await db.view_exists(view) - assert actual == expected - - -@pytest.mark.parametrize( - "table,expected", - ( - ( - "facetable", - [ - "pk", - "created", - "planet_int", - "on_earth", - "state", - "_city_id", - "_neighborhood", - "tags", - "complex_array", - "distinct_some_null", - "n", - ], - ), - ( - "sortable", - [ - "pk1", - "pk2", - "content", - "sortable", - "sortable_with_nulls", - "sortable_with_nulls_2", - "text", - ], - ), - ), -) -@pytest.mark.asyncio -async def test_table_columns(db, table, expected): - columns = await db.table_columns(table) - assert columns == expected - - -@pytest.mark.parametrize( - "table,expected", - ( - ( - "facetable", - [ - Column( - cid=0, - name="pk", - type="integer", - notnull=0, - default_value=None, - is_pk=1, - hidden=0, - ), - Column( - cid=1, - name="created", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=2, - name="planet_int", - type="integer", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=3, - name="on_earth", - type="integer", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=4, - name="state", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=5, - name="_city_id", - type="integer", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=6, - name="_neighborhood", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=7, - name="tags", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=8, - name="complex_array", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=9, - name="distinct_some_null", - type="", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=10, - name="n", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - ], - ), - ( - "sortable", - [ - Column( - cid=0, - name="pk1", - type="varchar(30)", - notnull=0, - default_value=None, - is_pk=1, - hidden=0, - ), - Column( - cid=1, - name="pk2", - type="varchar(30)", - notnull=0, - default_value=None, - is_pk=2, - hidden=0, - ), - Column( - cid=2, - name="content", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=3, - name="sortable", - type="integer", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=4, - name="sortable_with_nulls", - type="real", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=5, - name="sortable_with_nulls_2", - type="real", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - Column( - cid=6, - name="text", - type="text", - notnull=0, - default_value=None, - is_pk=0, - hidden=0, - ), - ], - ), - ), -) -@pytest.mark.asyncio -async def test_table_column_details(db, table, expected): - columns = await db.table_column_details(table) - # Convert "type" to lowercase before comparison - # https://github.com/simonw/datasette/issues/1647 - compare_columns = [ - Column( - c.cid, c.name, c.type.lower(), c.notnull, c.default_value, c.is_pk, c.hidden - ) - for c in columns - ] - assert compare_columns == expected - - -@pytest.mark.asyncio -async def test_get_all_foreign_keys(db): - all_foreign_keys = await db.get_all_foreign_keys() - assert all_foreign_keys["roadside_attraction_characteristics"] == { - "incoming": [], - "outgoing": [ - { - "other_table": "attraction_characteristic", - "column": "characteristic_id", - "other_column": "pk", - }, - { - "other_table": "roadside_attractions", - "column": "attraction_id", - "other_column": "pk", - }, - ], - } - assert all_foreign_keys["attraction_characteristic"] == { - "incoming": [ - { - "other_table": "roadside_attraction_characteristics", - "column": "pk", - "other_column": "characteristic_id", - } - ], - "outgoing": [], - } - assert all_foreign_keys["compound_primary_key"] == { - # No incoming because these are compound foreign keys, which we currently ignore - "incoming": [], - "outgoing": [], - } - assert all_foreign_keys["foreign_key_references"] == { - "incoming": [], - "outgoing": [ - { - "other_table": "primary_key_multiple_columns", - "column": "foreign_key_with_no_label", - "other_column": "id", - }, - { - "other_table": "simple_primary_key", - "column": "foreign_key_with_blank_label", - "other_column": "id", - }, - { - "other_table": "simple_primary_key", - "column": "foreign_key_with_label", - "other_column": "id", - }, - ], - } - - -@pytest.mark.asyncio -async def test_table_names(db): - table_names = await db.table_names() - # Tables are sorted alphabetically by name - assert table_names == [ - "123_starts_with_digits", - "Table With Space In Name", - "attraction_characteristic", - "binary_data", - "complex_foreign_keys", - "compound_primary_key", - "compound_three_primary_keys", - "custom_foreign_key_label", - "facet_cities", - "facetable", - "foreign_key_references", - "infinity", - "no_primary_key", - "primary_key_multiple_columns", - "primary_key_multiple_columns_explicit_label", - "roadside_attraction_characteristics", - "roadside_attractions", - "searchable", - "searchable_fts", - "searchable_fts_config", - "searchable_fts_data", - "searchable_fts_docsize", - "searchable_fts_idx", - "searchable_tags", - "select", - "simple_primary_key", - "sortable", - "table/with/slashes.csv", - "tags", - ] - - -@pytest.mark.asyncio -async def test_view_names(db): - view_names = await db.view_names() - assert view_names == [ - "paginated_view", - "simple_view", - "searchable_view", - "searchable_view_configured_by_metadata", - ] - - -@pytest.mark.asyncio -async def test_execute_write_block_true(db): - await db.execute_write( - "update roadside_attractions set name = ? where pk = ?", ["Mystery!", 1] - ) - rows = await db.execute("select name from roadside_attractions where pk = 1") - assert "Mystery!" == rows.rows[0][0] - - -@pytest.mark.asyncio -async def test_execute_write_block_false(db): - await db.execute_write( - "update roadside_attractions set name = ? where pk = ?", - ["Mystery!", 1], - ) - time.sleep(0.1) - rows = await db.execute("select name from roadside_attractions where pk = 1") - assert "Mystery!" == rows.rows[0][0] - - -@pytest.mark.asyncio -async def test_execute_write_script(db): - await db.execute_write_script( - "create table foo (id integer primary key); create table bar (id integer primary key);" - ) - table_names = await db.table_names() - assert {"foo", "bar"}.issubset(table_names) - - -@pytest.mark.asyncio -async def test_execute_write_many(db): - await db.execute_write_script("create table foomany (id integer primary key)") - await db.execute_write_many( - "insert into foomany (id) values (?)", [(1,), (10,), (100,)] - ) - result = await db.execute("select * from foomany") - assert [r[0] for r in result.rows] == [1, 10, 100] - - -@pytest.mark.asyncio -async def test_execute_write_has_correctly_prepared_connection(db): - # The sleep() function is only available if ds._prepare_connection() was called - await db.execute_write("select sleep(0.01)") - - -@pytest.mark.asyncio -async def test_execute_write_fn_block_false(db): - def write_fn(conn): - conn.execute("delete from roadside_attractions where pk = 1;") - row = conn.execute("select count(*) from roadside_attractions").fetchone() - return row[0] - - task_id = await db.execute_write_fn(write_fn, block=False) - assert isinstance(task_id, uuid.UUID) - - -@pytest.mark.asyncio -async def test_execute_write_fn_block_true(db): - def write_fn(conn): - conn.execute("delete from roadside_attractions where pk = 1;") - row = conn.execute("select count(*) from roadside_attractions").fetchone() - return row[0] - - new_count = await db.execute_write_fn(write_fn) - assert 3 == new_count - - -@pytest.mark.asyncio -async def test_execute_write_fn_exception(db): - def write_fn(conn): - assert False - - with pytest.raises(AssertionError): - await db.execute_write_fn(write_fn) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("param_name", ["conn", "connection", "db", "c"]) -async def test_execute_write_fn_accepts_any_single_param_name(db, param_name): - # Plugins historically relied on the fact that the callback was invoked - # positionally, so any parameter name worked. Preserve that contract. - scope = {} - exec( - "def write_fn({0}):\n" - " return {0}.execute('select 1 + 1').fetchone()[0]".format(param_name), - scope, - ) - write_fn = scope["write_fn"] - result = await db.execute_write_fn(write_fn) - assert result == 2 - - -@pytest.mark.asyncio -async def test_execute_write_fn_with_track_event(db): - # When the callback declares track_event it still receives both args - # via dependency injection. - seen = [] - - def write_fn(conn, track_event): - seen.append(track_event) - return conn.execute("select 1 + 1").fetchone()[0] - - result = await db.execute_write_fn(write_fn) - assert result == 2 - assert len(seen) == 1 and callable(seen[0]) - - -@pytest.mark.asyncio -@pytest.mark.timeout(1) -async def test_execute_write_fn_connection_exception(tmpdir, app_client): - path = str(tmpdir / "immutable.db") - conn = sqlite3.connect(path) - conn.execute("vacuum") - conn.close() - db = Database(app_client.ds, path=path, is_mutable=False) - app_client.ds.add_database(db, name="immutable-db") - - def write_fn(conn): - assert False - - with pytest.raises(AssertionError): - await db.execute_write_fn(write_fn) - - app_client.ds.remove_database("immutable-db") - - -@pytest.mark.asyncio -async def test_deliver_write_result_leaves_done_future_alone(): - loop = asyncio.get_running_loop() - reply_future = loop.create_future() - reply_future.set_result("original") - task = SimpleNamespace(loop=loop, reply_future=reply_future) - - # The write thread can finish after the caller has stopped waiting for the - # result. Delivery should notice that the future is already resolved and - # leave the caller's outcome alone instead of raising InvalidStateError. - _deliver_write_result(task, "replacement", None) - await asyncio.sleep(0) - - assert reply_future.result() == "original" - - -@pytest.mark.asyncio -async def test_deliver_write_result_ignores_closed_loop(): - closed_loop = asyncio.new_event_loop() - closed_loop.close() - reply_future = asyncio.get_running_loop().create_future() - task = SimpleNamespace(loop=closed_loop, reply_future=reply_future) - - # If the event loop that submitted the write has gone away, the write - # thread should drop the result rather than crash while reporting back to - # that closed loop. - _deliver_write_result(task, "result", None) - - assert not reply_future.done() - - -def table_exists(conn, name): - return bool( - conn.execute( - """ - with all_tables as ( - select name from sqlite_master where type = 'table' - union all - select name from temp.sqlite_master where type = 'table' - ) - select 1 from all_tables where name = ? - """, - (name,), - ).fetchall(), - ) - - -def table_exists_checker(name): - def inner(conn): - return table_exists(conn, name) - - return inner - - -@pytest.mark.asyncio -@pytest.mark.parametrize("disable_threads", (False, True)) -async def test_execute_isolated(db, disable_threads): - if disable_threads: - ds = Datasette(memory=True, settings={"num_sql_threads": 0}) - db = ds.add_database(Database(ds, memory_name="test_num_sql_threads_zero")) - - # Create temporary table in write - await db.execute_write( - "create temporary table created_by_write (id integer primary key)" - ) - # Should stay visible to write connection - assert await db.execute_write_fn(table_exists_checker("created_by_write")) - - def create_shared_table(conn): - conn.execute("create table shared (id integer primary key)") - # And a temporary table that should not continue to exist - conn.execute( - "create temporary table created_by_isolated (id integer primary key)" - ) - assert table_exists(conn, "created_by_isolated") - # Also confirm that created_by_write does not exist - return table_exists(conn, "created_by_write") - - # shared should not exist - assert not await db.execute_fn(table_exists_checker("shared")) - - # Create it using isolated - created_by_write_exists = await db.execute_isolated_fn(create_shared_table) - assert not created_by_write_exists - - # shared SHOULD exist now - assert await db.execute_fn(table_exists_checker("shared")) - - # created_by_isolated should not exist, even in write connection - assert not await db.execute_write_fn(table_exists_checker("created_by_isolated")) - - # ... and a second call to isolated should not see that connection either - assert not await db.execute_isolated_fn(table_exists_checker("created_by_isolated")) - - -@pytest.mark.asyncio -async def test_analyze_sql(): - ds = Datasette(memory=True) - db = ds.add_memory_database("test_analyze_sql", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - - analysis = await db.analyze_sql("select name from dogs where id = ?", (1,)) - - assert [ - ( - access.operation, - access.database, - access.sqlite_schema, - access.table, - access.columns, - access.source, - ) - for access in analysis.table_accesses - ] == [ - ("read", "data", "main", "dogs", ("id", "name"), None), - ] - - -@pytest.mark.asyncio -async def test_analyze_sql_insert_select(): - ds = Datasette(memory=True) - db = ds.add_memory_database("test_analyze_sql_insert_select", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await db.execute_write("create table cats (id integer primary key, name text)") - - analysis = await db.analyze_sql("insert into dogs (name) select name from cats") - - assert { - ( - access.operation, - access.database, - access.sqlite_schema, - access.table, - access.columns, - access.source, - ) - for access in analysis.table_accesses - } == { - ("insert", "data", "main", "dogs", (), None), - ("read", "data", "main", "cats", ("name",), None), - } - - -@pytest.mark.asyncio -async def test_mtime_ns(db): - assert isinstance(db.mtime_ns, int) - - -def test_mtime_ns_is_none_for_memory(app_client): - memory_db = Database(app_client.ds, is_memory=True) - assert memory_db.is_memory is True - assert None is memory_db.mtime_ns - - -def test_is_mutable(app_client): - assert Database(app_client.ds, is_memory=True).is_mutable is True - assert Database(app_client.ds, is_memory=True, is_mutable=True).is_mutable is True - assert Database(app_client.ds, is_memory=True, is_mutable=False).is_mutable is False - - -@pytest.mark.asyncio -async def test_attached_databases(app_client_two_attached_databases_crossdb_enabled): - database = app_client_two_attached_databases_crossdb_enabled.ds.get_database( - "_memory" - ) - attached = await database.attached_databases() - assert {a.name for a in attached} == {"extra database", "fixtures"} - - -@pytest.mark.asyncio -async def test_database_memory_name(app_client): - ds = app_client.ds - foo1 = ds.add_database(Database(ds, memory_name="foo")) - foo2 = ds.add_memory_database("foo") - bar1 = ds.add_database(Database(ds, memory_name="bar")) - bar2 = ds.add_memory_database("bar") - for db in (foo1, foo2, bar1, bar2): - table_names = await db.table_names() - assert table_names == [] - # Now create a table in foo - await foo1.execute_write("create table foo (t text)") - assert await foo1.table_names() == ["foo"] - assert await foo2.table_names() == ["foo"] - assert await bar1.table_names() == [] - assert await bar2.table_names() == [] - - -@pytest.mark.asyncio -async def test_in_memory_databases_forbid_writes(app_client): - ds = app_client.ds - db = ds.add_database(Database(ds, memory_name="test")) - with pytest.raises(sqlite3.OperationalError): - await db.execute("create table foo (t text)") - assert await db.table_names() == [] - # Using db.execute_write() should work: - await db.execute_write("create table foo (t text)") - assert await db.table_names() == ["foo"] - - -def pragma_table_list_supported(): - return sqlite_version()[1] >= 37 - - -@pytest.mark.asyncio -@pytest.mark.skipif( - not pragma_table_list_supported(), reason="Requires PRAGMA table_list support" -) -async def test_hidden_tables(app_client): - ds = app_client.ds - db = ds.add_database(Database(ds, is_memory=True, is_mutable=True)) - assert await db.hidden_table_names() == [] - await db.execute("create virtual table f using fts5(a)") - assert await db.hidden_table_names() == [ - "f_config", - "f_content", - "f_data", - "f_docsize", - "f_idx", - ] - - await db.execute("create virtual table r using rtree(id, amin, amax)") - assert await db.hidden_table_names() == [ - "f_config", - "f_content", - "f_data", - "f_docsize", - "f_idx", - "r_node", - "r_parent", - "r_rowid", - ] - - await db.execute("create table _hideme(_)") - assert await db.hidden_table_names() == [ - "_hideme", - "f_config", - "f_content", - "f_data", - "f_docsize", - "f_idx", - "r_node", - "r_parent", - "r_rowid", - ] - - # A fts virtual table with a content table should be hidden too - await db.execute("create virtual table f2_fts using fts5(a, content='f')") - assert await db.hidden_table_names() == [ - "_hideme", - "f2_fts_config", - "f2_fts_data", - "f2_fts_docsize", - "f2_fts_idx", - "f_config", - "f_content", - "f_data", - "f_docsize", - "f_idx", - "r_node", - "r_parent", - "r_rowid", - "f2_fts", - ] - - -@pytest.mark.asyncio -async def test_replace_database(tmpdir): - path1 = str(tmpdir / "data1.db") - (tmpdir / "two").mkdir() - path2 = str(tmpdir / "two" / "data1.db") - conn1 = sqlite3.connect(path1) - conn1.executescript(""" - create table t (id integer primary key); - insert into t (id) values (1); - insert into t (id) values (2); - """) - conn1.close() - conn2 = sqlite3.connect(path2) - conn2.executescript(""" - create table t (id integer primary key); - insert into t (id) values (1); - """) - conn2.close() - datasette = Datasette([path1]) - db = datasette.get_database("data1") - count = (await db.execute("select count(*) from t")).first()[0] - assert count == 2 - # Now replace that database - datasette.get_database("data1").close() - datasette.remove_database("data1") - datasette.add_database(Database(datasette, path2), "data1") - db2 = datasette.get_database("data1") - count = (await db2.execute("select count(*) from t")).first()[0] - assert count == 1 - - -@pytest.mark.parametrize( - "kwargs,expected_repr", - [ - ({"is_memory": True}, ""), - ({"memory_name": "my_mem"}, ""), - ( - {"is_memory": True, "is_mutable": False}, - "", - ), - ], - ids=["memory", "named_memory", "immutable_memory"], -) -def test_repr(app_client, kwargs, expected_repr): - db = Database(app_client.ds, **kwargs) - db.name = "test_db" - assert repr(db) == expected_repr - - -def test_repr_temp_disk(app_client): - db = Database(app_client.ds, is_temp_disk=True) - db.name = "test_db" - r = repr(db) - assert r.startswith("") - assert isinstance(db.size, int) - assert isinstance(db.mtime_ns, int) - db.close() - - -@pytest.mark.asyncio -async def test_database_close_shuts_down_write_thread(tmpdir): - path = str(tmpdir / "dbclose.db") - conn = sqlite3.connect(path) - conn.execute("create table t (id integer primary key)") - conn.close() - ds = Datasette([path]) - db = ds.get_database("dbclose") - # Trigger write thread creation - await db.execute_write("insert into t (id) values (1)") - assert db._write_thread is not None - assert db._write_thread.is_alive() - db.close() - # Wait briefly for the thread to exit — the sentinel should cause it to return. - db._write_thread.join(timeout=5) - assert not db._write_thread.is_alive() - ds._internal_database.close() - - -@pytest.mark.asyncio -async def test_database_close_raises_on_further_use(tmpdir): - path = str(tmpdir / "closed.db") - conn = sqlite3.connect(path) - conn.execute("create table t (id integer primary key)") - conn.close() - ds = Datasette([path]) - db = ds.get_database("closed") - await db.execute("select 1") - db.close() - with pytest.raises(DatasetteClosedError): - await db.execute("select 1") - with pytest.raises(DatasetteClosedError): - await db.execute_write("insert into t (id) values (1)") - with pytest.raises(DatasetteClosedError): - await db.execute_fn(lambda conn: conn.execute("select 1").fetchone()) - with pytest.raises(DatasetteClosedError): - await db.execute_write_fn(lambda conn: conn.execute("select 1")) - ds._internal_database.close() - - -@pytest.mark.asyncio -async def test_database_close_is_idempotent(tmpdir): - path = str(tmpdir / "idemp.db") - conn = sqlite3.connect(path) - conn.execute("create table t (id integer primary key)") - conn.close() - ds = Datasette([path]) - db = ds.get_database("idemp") - await db.execute_write("insert into t (id) values (1)") - db.close() - # Second call should be a no-op, not raise - db.close() - ds._internal_database.close() diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py deleted file mode 100644 index 3f867eb0..00000000 --- a/tests/test_internals_datasette.py +++ /dev/null @@ -1,346 +0,0 @@ -""" -Tests for the datasette.app.Datasette class -""" - -import asyncio -import dataclasses -import os -import sqlite3 -import time -from datasette import Context -from datasette.app import Datasette, Database, ResourcesSQL -from datasette.database import DatasetteClosedError -from datasette.resources import DatabaseResource -from itsdangerous import BadSignature -import pytest - - -@pytest.fixture -def datasette(ds_client): - return ds_client.ds - - -def test_get_database(datasette): - db = datasette.get_database("fixtures") - assert "fixtures" == db.name - with pytest.raises(KeyError): - datasette.get_database("missing") - - -def test_get_database_no_argument(datasette): - # Returns the first available database: - db = datasette.get_database() - assert "fixtures" == db.name - - -@pytest.mark.parametrize("value", ["hello", 123, {"key": "value"}]) -@pytest.mark.parametrize("namespace", [None, "two"]) -def test_sign_unsign(datasette, value, namespace): - extra_args = [namespace] if namespace else [] - signed = datasette.sign(value, *extra_args) - assert value != signed - assert value == datasette.unsign(signed, *extra_args) - with pytest.raises(BadSignature): - datasette.unsign(signed[:-1] + ("!" if signed[-1] != "!" else ":")) - - -@pytest.mark.parametrize( - "setting,expected", - ( - ("base_url", "/"), - ("max_csv_mb", 100), - ("allow_csv_stream", True), - ), -) -def test_datasette_setting(datasette, setting, expected): - assert datasette.setting(setting) == expected - - -@pytest.mark.asyncio -async def test_datasette_constructor(): - ds = Datasette() - databases = (await ds.client.get("/-/databases.json")).json() - assert databases == [ - { - "name": "_memory", - "route": "_memory", - "path": None, - "size": 0, - "is_mutable": False, - "is_memory": True, - "hash": None, - } - ] - - -@pytest.mark.asyncio -async def test_num_sql_threads_zero(): - ds = Datasette([], memory=True, settings={"num_sql_threads": 0}) - db = ds.add_database(Database(ds, memory_name="test_num_sql_threads_zero")) - await db.execute_write("create table t(id integer primary key)") - await db.execute_write("insert into t (id) values (1)") - response = await ds.client.get("/-/threads.json") - assert response.json() == {"num_threads": 0, "threads": []} - response2 = await ds.client.get("/test_num_sql_threads_zero/t.json?_shape=array") - assert response2.json() == [{"id": 1}] - - -ROOT = {"id": "root"} -ALLOW_ROOT = {"allow": {"id": "root"}} - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "actor,config,action,resource,should_allow,expected_private", - ( - (None, ALLOW_ROOT, "view-instance", None, False, False), - (ROOT, ALLOW_ROOT, "view-instance", None, True, True), - ( - None, - {"databases": {"_memory": ALLOW_ROOT}}, - "view-database", - DatabaseResource(database="_memory"), - False, - False, - ), - ( - ROOT, - {"databases": {"_memory": ALLOW_ROOT}}, - "view-database", - DatabaseResource(database="_memory"), - True, - True, - ), - # Check private is false for non-protected instance check - ( - ROOT, - {"allow": True}, - "view-instance", - None, - True, - False, - ), - ), -) -async def test_datasette_check_visibility( - actor, config, action, resource, should_allow, expected_private -): - ds = Datasette([], memory=True, config=config) - await ds.invoke_startup() - visible, private = await ds.check_visibility( - actor, action=action, resource=resource - ) - assert visible == should_allow - assert private == expected_private - - -@pytest.mark.asyncio -async def test_datasette_render_template_no_request(): - # https://github.com/simonw/datasette/issues/1849 - ds = Datasette(memory=True) - await ds.invoke_startup() - rendered = await ds.render_template("error.html") - assert "Error " in rendered - - -@pytest.mark.asyncio -async def test_datasette_render_template_with_dataclass(): - @dataclasses.dataclass - class ExampleContext(Context): - title: str - status: int - error: str - - context = ExampleContext(title="Hello", status=200, error="Error message") - ds = Datasette(memory=True) - await ds.invoke_startup() - rendered = await ds.render_template("error.html", context) - assert "

Hello

" in rendered - assert "Error message" in rendered - - -def test_datasette_error_if_string_not_list(tmpdir): - # https://github.com/simonw/datasette/issues/1985 - db_path = str(tmpdir / "data.db") - with pytest.raises(ValueError): - Datasette(db_path) - - -@pytest.mark.asyncio -async def test_get_action(ds_client): - ds = ds_client.ds - for name_or_abbr in ( - "vi", - "view-instance", - "vt", - "view-table", - "sct", - "set-column-type", - ): - action = ds.get_action(name_or_abbr) - if "-" in name_or_abbr: - assert action.name == name_or_abbr - else: - assert action.abbr == name_or_abbr - # And test None return for missing action - assert ds.get_action("missing-permission") is None - - -@pytest.mark.asyncio -async def test_apply_metadata_json(): - ds = Datasette( - metadata={ - "databases": { - "legislators": { - "tables": {"offices": {"summary": "office address or sumtin"}}, - "queries": { - "millennial_representatives": { - "summary": "Social media accounts for current legislators" - } - }, - } - }, - "weird_instance_value": {"nested": [1, 2, 3]}, - }, - ) - await ds.invoke_startup() - assert (await ds.client.get("/")).status_code == 200 - value = (await ds.get_instance_metadata()).get("weird_instance_value") - assert value == '{"nested": [1, 2, 3]}' - - -@pytest.mark.asyncio -async def test_allowed_resources_sql(datasette): - result = await datasette.allowed_resources_sql( - action="view-table", - actor=None, - ) - assert isinstance(result, ResourcesSQL) - assert "all_rules AS" in result.sql - assert result.params["action"] == "view-table" - - -@pytest.mark.asyncio -async def test_datasette_close_closes_all_databases_and_executor(): - ds = Datasette(memory=True) - await ds.invoke_startup() - # Confirm internal DB has write machinery running - assert ds._internal_database._write_thread is not None - assert ds._internal_database._write_thread.is_alive() - temp_path = ds._internal_database.path - assert os.path.exists(temp_path) - executor = ds.executor - ds.close() - # Executor is shut down - assert executor._shutdown - # All attached Database instances are closed - for db in ds.databases.values(): - assert db._closed - assert ds._internal_database._closed - # Temp internal DB file is unlinked - assert not os.path.exists(temp_path) - - -@pytest.mark.asyncio -async def test_datasette_close_is_idempotent(): - ds = Datasette(memory=True) - await ds.invoke_startup() - ds.close() - # Second call should be a no-op - ds.close() - - -@pytest.mark.asyncio -async def test_datasette_close_raises_on_use(): - ds = Datasette(memory=True) - await ds.invoke_startup() - ds.close() - with pytest.raises(DatasetteClosedError): - await ds.get_internal_database().execute("select 1") - - -async def _datasette_with_sleeping_execute(tmp_path, sleep_ms=200): - db_path = tmp_path / "data.db" - internal_path = tmp_path / "internal.db" - sqlite3.connect(db_path).close() - ds = Datasette([str(db_path)], internal=str(internal_path)) - loop = asyncio.get_running_loop() - sql_started = asyncio.Event() - original_prepare_connection = ds._prepare_connection - - def prepare_connection(conn, name): - original_prepare_connection(conn, name) - - def sleep_ms(ms): - loop.call_soon_threadsafe(sql_started.set) - time.sleep(ms / 1000) - return ms - - conn.create_function("sleep_ms", 1, sleep_ms) - - ds._prepare_connection = prepare_connection - task = asyncio.create_task( - ds.get_database().execute( - f"select sleep_ms({sleep_ms})", custom_time_limit=1000 - ) - ) - await asyncio.wait_for(sql_started.wait(), timeout=5) - return ds, task - - -@pytest.mark.asyncio -async def test_datasette_close_waits_for_in_flight_execute(tmp_path): - ds, task = await _datasette_with_sleeping_execute(tmp_path) - ds.close() - results = await task - assert [tuple(row) for row in results.rows] == [(200,)] - - -@pytest.mark.asyncio -async def test_datasette_close_waits_for_cancelled_in_flight_execute(tmp_path): - ds, task = await _datasette_with_sleeping_execute(tmp_path) - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - ds.close() - - -@pytest.mark.asyncio -async def test_asgi_lifespan_shutdown_closes_datasette(): - ds = Datasette(memory=True) - app = ds.app() - # Drive an ASGI lifespan: startup, then shutdown. - messages_sent = [] - inbox = [ - {"type": "lifespan.startup"}, - {"type": "lifespan.shutdown"}, - ] - - async def receive(): - return inbox.pop(0) - - async def send(message): - messages_sent.append(message) - - await app({"type": "lifespan"}, receive, send) - assert {"type": "lifespan.startup.complete"} in messages_sent - assert {"type": "lifespan.shutdown.complete"} in messages_sent - assert ds._closed - - -@pytest.mark.asyncio -async def test_datasette_close_continues_past_db_error(): - # If one Database raises during close(), the others still get closed. - ds = Datasette(memory=True) - await ds.invoke_startup() - - class Boom(Database): - def close(self): - raise RuntimeError("boom") - - ds.add_database(Boom(ds, is_memory=True), name="bad") - good = ds.add_database(Database(ds, is_memory=True), name="good") - with pytest.raises(RuntimeError, match="boom"): - ds.close() - assert good._closed - assert ds._internal_database._closed diff --git a/tests/test_internals_datasette_client.py b/tests/test_internals_datasette_client.py deleted file mode 100644 index 543077a5..00000000 --- a/tests/test_internals_datasette_client.py +++ /dev/null @@ -1,386 +0,0 @@ -import httpx -import pytest -import pytest_asyncio -from datasette.app import Datasette - - -@pytest_asyncio.fixture -async def datasette(ds_client): - await ds_client.ds.invoke_startup() - return ds_client.ds - - -@pytest_asyncio.fixture -async def datasette_with_permissions(): - """A datasette instance with permission restrictions for testing""" - ds = Datasette(config={"databases": {"test_db": {"allow": {"id": "admin"}}}}) - await ds.invoke_startup() - db = ds.add_memory_database("test_datasette_with_permissions", name="test_db") - await db.execute_write( - "create table if not exists test_table (id integer primary key, name text)" - ) - await db.execute_write( - "insert or ignore into test_table (id, name) values (1, 'Alice')" - ) - # Trigger catalog refresh - await ds.client.get("/") - return ds - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "method,path,expected_status", - [ - ("get", "/", 200), - ("options", "/", 200), - ("head", "/", 200), - ("put", "/", 405), - ("patch", "/", 405), - ("delete", "/", 405), - ], -) -async def test_client_methods(datasette, method, path, expected_status): - client_method = getattr(datasette.client, method) - response = await client_method(path) - assert isinstance(response, httpx.Response) - assert response.status_code == expected_status - # Try that again using datasette.client.request - response2 = await datasette.client.request(method, path) - assert response2.status_code == expected_status - - -@pytest.mark.asyncio -@pytest.mark.parametrize("prefix", [None, "/prefix/"]) -async def test_client_post(datasette, prefix): - original_base_url = datasette._settings["base_url"] - try: - if prefix is not None: - datasette._settings["base_url"] = prefix - response = await datasette.client.post( - "/-/messages", - data={ - "message": "A message", - }, - ) - assert isinstance(response, httpx.Response) - assert response.status_code == 302 - assert "ds_messages" in response.cookies - finally: - datasette._settings["base_url"] = original_base_url - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "prefix,expected_path", [(None, "/asgi-scope"), ("/prefix/", "/prefix/asgi-scope")] -) -async def test_client_path(datasette, prefix, expected_path): - original_base_url = datasette._settings["base_url"] - try: - if prefix is not None: - datasette._settings["base_url"] = prefix - response = await datasette.client.get("/asgi-scope") - path = response.json()["path"] - assert path == expected_path - finally: - datasette._settings["base_url"] = original_base_url - - -@pytest.mark.asyncio -async def test_skip_permission_checks_allows_forbidden_access( - datasette_with_permissions, -): - """Test that skip_permission_checks=True bypasses permission checks""" - ds = datasette_with_permissions - - # Without skip_permission_checks, anonymous user should get 403 for protected database - response = await ds.client.get("/test_db.json") - assert response.status_code == 403 - - # With skip_permission_checks=True, should get 200 - response = await ds.client.get("/test_db.json", skip_permission_checks=True) - assert response.status_code == 200 - data = response.json() - assert data["database"] == "test_db" - - -@pytest.mark.asyncio -async def test_skip_permission_checks_on_table(datasette_with_permissions): - """Test skip_permission_checks works for table access""" - ds = datasette_with_permissions - - # Without skip_permission_checks, should get 403 - response = await ds.client.get("/test_db/test_table.json") - assert response.status_code == 403 - - # With skip_permission_checks=True, should get table data - response = await ds.client.get( - "/test_db/test_table.json", skip_permission_checks=True - ) - assert response.status_code == 200 - data = response.json() - assert data["rows"] == [{"id": 1, "name": "Alice"}] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "method", ["get", "post", "put", "patch", "delete", "options", "head"] -) -async def test_skip_permission_checks_all_methods(datasette_with_permissions, method): - """Test that skip_permission_checks works with all HTTP methods""" - ds = datasette_with_permissions - - # All methods should work with skip_permission_checks=True - client_method = getattr(ds.client, method) - response = await client_method("/test_db.json", skip_permission_checks=True) - # We don't check status code since some methods might not be allowed, - # but we verify the request doesn't fail due to permissions - assert isinstance(response, httpx.Response) - - -@pytest.mark.asyncio -async def test_skip_permission_checks_request_method(datasette_with_permissions): - """Test that skip_permission_checks works with client.request()""" - ds = datasette_with_permissions - - # Without skip_permission_checks - response = await ds.client.request("GET", "/test_db.json") - assert response.status_code == 403 - - # With skip_permission_checks=True - response = await ds.client.request( - "GET", "/test_db.json", skip_permission_checks=True - ) - assert response.status_code == 200 - - -@pytest.mark.asyncio -async def test_skip_permission_checks_isolated_to_request(datasette_with_permissions): - """Test that skip_permission_checks doesn't affect other concurrent requests""" - ds = datasette_with_permissions - - # First request with skip_permission_checks=True should succeed - response1 = await ds.client.get("/test_db.json", skip_permission_checks=True) - assert response1.status_code == 200 - - # Subsequent request without it should still get 403 - response2 = await ds.client.get("/test_db.json") - assert response2.status_code == 403 - - # And another with skip should succeed again - response3 = await ds.client.get("/test_db.json", skip_permission_checks=True) - assert response3.status_code == 200 - - -@pytest.mark.asyncio -async def test_skip_permission_checks_with_admin_actor(datasette_with_permissions): - """Test that skip_permission_checks works even when actor is provided""" - ds = datasette_with_permissions - - # Admin actor should normally have access - admin_cookies = {"ds_actor": ds.client.actor_cookie({"id": "admin"})} - response = await ds.client.get("/test_db.json", cookies=admin_cookies) - assert response.status_code == 200 - - # Non-admin actor should get 403 - user_cookies = {"ds_actor": ds.client.actor_cookie({"id": "user"})} - response = await ds.client.get("/test_db.json", cookies=user_cookies) - assert response.status_code == 403 - - # Non-admin actor with skip_permission_checks=True should get 200 - response = await ds.client.get( - "/test_db.json", cookies=user_cookies, skip_permission_checks=True - ) - assert response.status_code == 200 - - -@pytest.mark.asyncio -async def test_skip_permission_checks_shows_denied_tables(): - """Test that skip_permission_checks=True shows tables from denied databases in /-/jump.json""" - ds = Datasette( - config={ - "databases": { - "fixtures": {"allow": False} # Deny all access to this database - } - } - ) - await ds.invoke_startup() - db = ds.add_memory_database("fixtures") - await db.execute_write( - "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)" - ) - await db.execute_write("INSERT INTO test_table (id, name) VALUES (1, 'Alice')") - await ds._refresh_schemas() - - # Without skip_permission_checks, tables from denied database should not appear in /-/jump.json - response = await ds.client.get("/-/jump.json") - assert response.status_code == 200 - data = response.json() - table_names = [match["name"] for match in data["matches"]] - # Should not see any fixtures tables since access is denied - fixtures_tables = [name for name in table_names if name.startswith("fixtures:")] - assert len(fixtures_tables) == 0 - - # With skip_permission_checks=True, tables from denied database SHOULD appear - response = await ds.client.get("/-/jump.json", skip_permission_checks=True) - assert response.status_code == 200 - data = response.json() - table_names = [match["name"] for match in data["matches"]] - # Should see fixtures tables when permission checks are skipped - assert "fixtures: test_table" in table_names - - -@pytest.mark.asyncio -async def test_in_client_returns_false_outside_request(datasette): - """Test that datasette.in_client() returns False outside of a client request""" - assert datasette.in_client() is False - - -@pytest.mark.asyncio -async def test_in_client_returns_true_inside_request(): - """Test that datasette.in_client() returns True inside a client request""" - from datasette import hookimpl, Response - - class TestPlugin: - __name__ = "test_in_client_plugin" - - @hookimpl - def register_routes(self): - async def test_view(datasette): - # Assert in_client() returns True within the view - assert datasette.in_client() is True - return Response.json({"in_client": datasette.in_client()}) - - return [ - (r"^/-/test-in-client$", test_view), - ] - - ds = Datasette() - await ds.invoke_startup() - ds.pm.register(TestPlugin(), name="test_in_client_plugin") - try: - - # Outside of a client request, should be False - assert ds.in_client() is False - - # Make a request via datasette.client - response = await ds.client.get("/-/test-in-client") - assert response.status_code == 200 - assert response.json()["in_client"] is True - - # After the request, should be False again - assert ds.in_client() is False - finally: - ds.pm.unregister(name="test_in_client_plugin") - - -@pytest.mark.asyncio -async def test_in_client_with_skip_permission_checks(): - """Test that in_client() works regardless of skip_permission_checks value""" - from datasette import hookimpl - from datasette.utils.asgi import Response - - in_client_values = [] - - class TestPlugin: - __name__ = "test_in_client_skip_plugin" - - @hookimpl - def register_routes(self): - async def test_view(datasette): - in_client_values.append(datasette.in_client()) - return Response.json({"in_client": datasette.in_client()}) - - return [ - (r"^/-/test-in-client$", test_view), - ] - - ds = Datasette(config={"databases": {"test_db": {"allow": {"id": "admin"}}}}) - await ds.invoke_startup() - ds.pm.register(TestPlugin(), name="test_in_client_skip_plugin") - try: - - # Request without skip_permission_checks - await ds.client.get("/-/test-in-client") - # Request with skip_permission_checks=True - await ds.client.get("/-/test-in-client", skip_permission_checks=True) - - # Both should have detected in_client as True - assert ( - len(in_client_values) == 2 - ), f"Expected 2 values, got {len(in_client_values)}" - assert all(in_client_values), f"Expected all True, got {in_client_values}" - finally: - ds.pm.unregister(name="test_in_client_skip_plugin") - - -@pytest.mark.asyncio -async def test_actor_parameter_sets_cookie(datasette): - """Passing actor= should sign a ds_actor cookie and authenticate the request.""" - response = await datasette.client.get("/-/actor.json", actor={"id": "root"}) - assert response.status_code == 200 - assert response.json() == {"actor": {"id": "root"}} - - -@pytest.mark.asyncio -async def test_actor_parameter_works_with_request_method(datasette): - response = await datasette.client.request( - "GET", "/-/actor.json", actor={"id": "root"} - ) - assert response.status_code == 200 - assert response.json() == {"actor": {"id": "root"}} - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "method", ["get", "post", "options", "head", "put", "patch", "delete"] -) -async def test_actor_parameter_all_http_methods(datasette, method): - """actor= should not cause errors on any HTTP verb wrapper.""" - client_method = getattr(datasette.client, method) - # Just verify no TypeError about unexpected 'actor' kwarg - response = await client_method("/", actor={"id": "root"}) - assert isinstance(response, httpx.Response) - - -@pytest.mark.asyncio -async def test_actor_parameter_conflicts_with_ds_actor_cookie(datasette): - """Passing both actor= and a ds_actor cookie should raise TypeError.""" - with pytest.raises(TypeError, match="actor"): - await datasette.client.get( - "/-/actor.json", - actor={"id": "root"}, - cookies={"ds_actor": datasette.client.actor_cookie({"id": "other"})}, - ) - - -@pytest.mark.asyncio -async def test_actor_parameter_merges_with_other_cookies(datasette): - """actor= should coexist with unrelated cookies.""" - response = await datasette.client.get( - "/-/actor.json", - actor={"id": "root"}, - cookies={"unrelated": "value"}, - ) - assert response.status_code == 200 - assert response.json() == {"actor": {"id": "root"}} - - -@pytest.mark.asyncio -async def test_actor_parameter_with_skip_permission_checks( - datasette_with_permissions, -): - """actor= should be compatible with skip_permission_checks.""" - ds = datasette_with_permissions - # Non-admin actor with skip_permission_checks=True should get 200 - response = await ds.client.get( - "/test_db.json", - actor={"id": "user"}, - skip_permission_checks=True, - ) - assert response.status_code == 200 - # Admin actor on its own should also get 200 - response = await ds.client.get("/test_db.json", actor={"id": "admin"}) - assert response.status_code == 200 - # Non-admin actor should get 403 - response = await ds.client.get("/test_db.json", actor={"id": "user"}) - assert response.status_code == 403 diff --git a/tests/test_internals_request.py b/tests/test_internals_request.py deleted file mode 100644 index d1ca1f46..00000000 --- a/tests/test_internals_request.py +++ /dev/null @@ -1,148 +0,0 @@ -from datasette.utils.asgi import Request -import json -import pytest - - -@pytest.mark.asyncio -async def test_request_post_vars(): - scope = { - "http_version": "1.1", - "method": "POST", - "path": "/", - "raw_path": b"/", - "query_string": b"", - "scheme": "http", - "type": "http", - "headers": [[b"content-type", b"application/x-www-form-urlencoded"]], - } - - async def receive(): - return { - "type": "http.request", - "body": b"foo=bar&baz=1&empty=", - "more_body": False, - } - - request = Request(scope, receive) - assert {"foo": "bar", "baz": "1", "empty": ""} == await request.post_vars() - - -@pytest.mark.asyncio -async def test_request_post_body(): - scope = { - "http_version": "1.1", - "method": "POST", - "path": "/", - "raw_path": b"/", - "query_string": b"", - "scheme": "http", - "type": "http", - "headers": [[b"content-type", b"application/json"]], - } - - data = {"hello": "world"} - - async def receive(): - return { - "type": "http.request", - "body": json.dumps(data, indent=4).encode("utf-8"), - "more_body": False, - } - - request = Request(scope, receive) - body = await request.post_body() - assert isinstance(body, bytes) - assert data == json.loads(body) - - -def test_request_args(): - request = Request.fake("/foo?multi=1&multi=2&single=3") - assert "1" == request.args.get("multi") - assert "3" == request.args.get("single") - assert "1" == request.args["multi"] - assert "3" == request.args["single"] - assert ["1", "2"] == request.args.getlist("multi") - assert [] == request.args.getlist("missing") - assert "multi" in request.args - assert "single" in request.args - assert "missing" not in request.args - expected = ["multi", "single"] - assert expected == list(request.args.keys()) - for i, key in enumerate(request.args): - assert expected[i] == key - assert 2 == len(request.args) - with pytest.raises(KeyError): - request.args["missing"] - - -def test_request_fake_url_vars(): - request = Request.fake("/") - assert request.url_vars == {} - request = Request.fake("/", url_vars={"database": "fixtures"}) - assert request.url_vars == {"database": "fixtures"} - - -def test_request_repr(): - request = Request.fake("/foo?multi=1&multi=2&single=3") - assert ( - repr(request) - == '' - ) - - -def test_request_url_vars(): - scope = { - "http_version": "1.1", - "method": "POST", - "path": "/", - "raw_path": b"/", - "query_string": b"", - "scheme": "http", - "type": "http", - "headers": [[b"content-type", b"application/x-www-form-urlencoded"]], - } - assert {} == Request(scope, None).url_vars - assert {"name": "cleo"} == Request( - dict(scope, url_route={"kwargs": {"name": "cleo"}}), None - ).url_vars - - -@pytest.mark.parametrize( - "path,query_string,expected_full_path", - [("/", "", "/"), ("/", "foo=bar", "/?foo=bar"), ("/foo", "bar", "/foo?bar")], -) -def test_request_properties(path, query_string, expected_full_path): - path_with_query_string = path - if query_string: - path_with_query_string += "?" + query_string - scope = { - "http_version": "1.1", - "method": "POST", - "path": path, - "raw_path": path_with_query_string.encode("latin-1"), - "query_string": query_string.encode("latin-1"), - "scheme": "http", - "type": "http", - } - request = Request(scope, None) - assert request.path == path - assert request.query_string == query_string - assert request.full_path == expected_full_path - - -def test_request_blank_values(): - request = Request.fake("/?a=b&foo=bar&foo=bar2&baz=") - assert request.args._data == {"a": ["b"], "foo": ["bar", "bar2"], "baz": [""]} - - -def test_json_in_query_string_name(): - query_string = ( - '?_through.["roadside_attraction_characteristics"%2C"characteristic_id"]=1' - ) - request = Request.fake("/" + query_string) - assert ( - request.args[ - '_through.["roadside_attraction_characteristics","characteristic_id"]' - ] - == "1" - ) diff --git a/tests/test_internals_response.py b/tests/test_internals_response.py deleted file mode 100644 index 820b20b2..00000000 --- a/tests/test_internals_response.py +++ /dev/null @@ -1,54 +0,0 @@ -from datasette.utils.asgi import Response -import pytest - - -def test_response_html(): - response = Response.html("Hello from HTML") - assert 200 == response.status - assert "Hello from HTML" == response.body - assert "text/html; charset=utf-8" == response.content_type - - -def test_response_text(): - response = Response.text("Hello from text") - assert 200 == response.status - assert "Hello from text" == response.body - assert "text/plain; charset=utf-8" == response.content_type - - -def test_response_json(): - response = Response.json({"this_is": "json"}) - assert 200 == response.status - assert '{"this_is": "json"}' == response.body - assert "application/json; charset=utf-8" == response.content_type - - -def test_response_redirect(): - response = Response.redirect("/foo") - assert 302 == response.status - assert "/foo" == response.headers["Location"] - - -@pytest.mark.asyncio -async def test_response_set_cookie(): - events = [] - - async def send(event): - events.append(event) - - response = Response.redirect("/foo") - response.set_cookie("foo", "bar", max_age=10, httponly=True) - await response.asgi_send(send) - - assert [ - { - "type": "http.response.start", - "status": 302, - "headers": [ - [b"Location", b"/foo"], - [b"content-type", b"text/plain"], - [b"set-cookie", b"foo=bar; HttpOnly; Max-Age=10; Path=/; SameSite=lax"], - ], - }, - {"type": "http.response.body", "body": b""}, - ] == events diff --git a/tests/test_internals_urls.py b/tests/test_internals_urls.py deleted file mode 100644 index d60aafcf..00000000 --- a/tests/test_internals_urls.py +++ /dev/null @@ -1,148 +0,0 @@ -from datasette.app import Datasette -from datasette.utils import PrefixedUrlString -import pytest - - -@pytest.fixture(scope="module") -def ds(): - return Datasette([], memory=True) - - -@pytest.mark.parametrize( - "base_url,path,expected", - [ - ("/", "/", "/"), - ("/", "/foo", "/foo"), - ("/prefix/", "/", "/prefix/"), - ("/prefix/", "/foo", "/prefix/foo"), - ("/prefix/", "foo", "/prefix/foo"), - ], -) -def test_path(ds, base_url, path, expected): - ds._settings["base_url"] = base_url - actual = ds.urls.path(path) - assert actual == expected - assert isinstance(actual, PrefixedUrlString) - - -def test_path_applied_twice_does_not_double_prefix(ds): - ds._settings["base_url"] = "/prefix/" - path = ds.urls.path("/") - assert path == "/prefix/" - path = ds.urls.path(path) - assert path == "/prefix/" - - -@pytest.mark.parametrize( - "base_url,expected", - [ - ("/", "/"), - ("/prefix/", "/prefix/"), - ], -) -def test_instance(ds, base_url, expected): - ds._settings["base_url"] = base_url - actual = ds.urls.instance() - assert actual == expected - assert isinstance(actual, PrefixedUrlString) - - -@pytest.mark.parametrize( - "base_url,file,expected", - [ - ("/", "foo.js", "/-/static/foo.js"), - ("/prefix/", "foo.js", "/prefix/-/static/foo.js"), - ], -) -def test_static(ds, base_url, file, expected): - ds._settings["base_url"] = base_url - actual = ds.urls.static(file) - assert actual == expected - assert isinstance(actual, PrefixedUrlString) - - -@pytest.mark.parametrize( - "base_url,plugin,file,expected", - [ - ( - "/", - "datasette_cluster_map", - "datasette-cluster-map.js", - "/-/static-plugins/datasette_cluster_map/datasette-cluster-map.js", - ), - ( - "/prefix/", - "datasette_cluster_map", - "datasette-cluster-map.js", - "/prefix/-/static-plugins/datasette_cluster_map/datasette-cluster-map.js", - ), - ], -) -def test_static_plugins(ds, base_url, plugin, file, expected): - ds._settings["base_url"] = base_url - actual = ds.urls.static_plugins(plugin, file) - assert actual == expected - assert isinstance(actual, PrefixedUrlString) - - -@pytest.mark.parametrize( - "base_url,expected", - [ - ("/", "/-/logout"), - ("/prefix/", "/prefix/-/logout"), - ], -) -def test_logout(ds, base_url, expected): - ds._settings["base_url"] = base_url - actual = ds.urls.logout() - assert actual == expected - assert isinstance(actual, PrefixedUrlString) - - -@pytest.mark.parametrize( - "base_url,format,expected", - [ - ("/", None, "/_memory"), - ("/prefix/", None, "/prefix/_memory"), - ("/", "json", "/_memory.json"), - ], -) -def test_database(ds, base_url, format, expected): - ds._settings["base_url"] = base_url - actual = ds.urls.database("_memory", format=format) - assert actual == expected - assert isinstance(actual, PrefixedUrlString) - - -@pytest.mark.parametrize( - "base_url,name,format,expected", - [ - ("/", "name", None, "/_memory/name"), - ("/prefix/", "name", None, "/prefix/_memory/name"), - ("/", "name", "json", "/_memory/name.json"), - ("/", "name.json", "json", "/_memory/name~2Ejson.json"), - ], -) -def test_table_and_query(ds, base_url, name, format, expected): - ds._settings["base_url"] = base_url - actual1 = ds.urls.table("_memory", name, format=format) - assert actual1 == expected - assert isinstance(actual1, PrefixedUrlString) - actual2 = ds.urls.query("_memory", name, format=format) - assert actual2 == expected - assert isinstance(actual2, PrefixedUrlString) - - -@pytest.mark.parametrize( - "base_url,format,expected", - [ - ("/", None, "/_memory/facetable/1"), - ("/prefix/", None, "/prefix/_memory/facetable/1"), - ("/", "json", "/_memory/facetable/1.json"), - ], -) -def test_row(ds, base_url, format, expected): - ds._settings["base_url"] = base_url - actual = ds.urls.row("_memory", "facetable", "1", format=format) - assert actual == expected - assert isinstance(actual, PrefixedUrlString) diff --git a/tests/test_jump.py b/tests/test_jump.py deleted file mode 100644 index 513a809f..00000000 --- a/tests/test_jump.py +++ /dev/null @@ -1,465 +0,0 @@ -import pytest -import pytest_asyncio - -from datasette import hookimpl -from datasette.app import Datasette -from datasette.jump import JumpSQL -from datasette.plugins import pm -from datasette.views.special import JumpView - - -@pytest_asyncio.fixture -async def ds_for_jump(): - ds = Datasette( - config={ - "databases": { - "content": { - "allow": {"id": "*"}, - "tables": { - "articles": {"allow": {"id": "editor"}}, - "comments": {"allow": True}, - }, - "queries": { - "recent_comments": { - "sql": "select * from comments", - "allow": {"id": "*"}, - "title": "Recent comments", - }, - "release_notes": { - "sql": "select 1", - "allow": {"id": "*"}, - "title": "Recent Datasette releases", - }, - "editor_report": { - "sql": "select * from articles", - "allow": {"id": "editor"}, - }, - }, - }, - "private": { - "allow": False, - "queries": { - "private_report": "select 1", - }, - }, - } - } - ) - await ds.invoke_startup() - - content_db = ds.add_memory_database("jump_test_content", name="content") - await content_db.execute_write( - "CREATE TABLE IF NOT EXISTS articles (id INTEGER PRIMARY KEY, title TEXT)" - ) - await content_db.execute_write( - "CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, body TEXT)" - ) - await content_db.execute_write( - "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT)" - ) - await content_db.execute_write( - "CREATE VIEW IF NOT EXISTS comment_summary AS SELECT body FROM comments" - ) - - private_db = ds.add_memory_database("jump_test_private", name="private") - await private_db.execute_write( - "CREATE TABLE IF NOT EXISTS secrets (id INTEGER PRIMARY KEY, data TEXT)" - ) - - public_db = ds.add_memory_database("jump_test_public", name="public") - await public_db.execute_write( - "CREATE TABLE IF NOT EXISTS articles (id INTEGER PRIMARY KEY, content TEXT)" - ) - - await ds._refresh_schemas() - return ds - - -@pytest.mark.asyncio -async def test_jump_searches_tables_databases_views_and_stored_queries(ds_for_jump): - response = await ds_for_jump.client.get( - "/-/jump.json?q=content", actor={"id": "user"} - ) - assert response.status_code == 200 - data = response.json() - - matches_by_type_and_name = { - (match["type"], match["name"]): match for match in data["matches"] - } - assert ("database", "content") in matches_by_type_and_name - assert ("table", "content: comments") in matches_by_type_and_name - assert ("view", "content: comment_summary") in matches_by_type_and_name - assert ("query", "content: recent_comments") in matches_by_type_and_name - assert matches_by_type_and_name[("database", "content")]["url"] == "/content" - assert ( - matches_by_type_and_name[("query", "content: recent_comments")]["url"] - == "/content/recent_comments" - ) - - -@pytest.mark.asyncio -async def test_jump_uses_stored_query_names_not_titles(ds_for_jump): - response = await ds_for_jump.client.get( - "/-/jump.json?q=datasette", actor={"id": "user"} - ) - assert response.status_code == 200 - assert response.json()["matches"] == [] - - response = await ds_for_jump.client.get( - "/-/jump.json?q=release", actor={"id": "user"} - ) - assert response.status_code == 200 - assert response.json()["matches"] == [ - { - "name": "content: release_notes", - "url": "/content/release_notes", - "type": "query", - "description": None, - } - ] - - -@pytest.mark.asyncio -async def test_jump_respects_resource_permissions(ds_for_jump): - regular = await ds_for_jump.client.get( - "/-/jump.json?q=articles", actor={"id": "regular"} - ) - editor = await ds_for_jump.client.get( - "/-/jump.json?q=articles", actor={"id": "editor"} - ) - private = await ds_for_jump.client.get( - "/-/jump.json?q=secrets", actor={"id": "editor"} - ) - - assert {match["name"] for match in regular.json()["matches"]} == { - "public: articles" - } - assert {match["name"] for match in editor.json()["matches"]} == { - "content: articles", - "public: articles", - } - assert private.json()["matches"] == [] - - -@pytest.mark.asyncio -async def test_jump_sql_menu_item_helper(ds_for_jump): - assert JumpSQL("SELECT 1").database is None - assert JumpSQL("SELECT 1", database="content").database == "content" - assert JumpSQL("SELECT 1", None, "content").database == "content" - - fragment = JumpSQL.menu_item( - label="Plugin dashboard", - url="/-/plugin-dashboard", - description="Plugin tool", - search_text="dashboard plugin", - display_name="Plugin Dashboard", - item_type="plugin", - ) - result = await ds_for_jump.get_internal_database().execute( - fragment.sql, fragment.params - ) - assert dict(result.first()) == { - "type": "plugin", - "label": "Plugin dashboard", - "description": "Plugin tool", - "url": "/-/plugin-dashboard", - "search_text": "dashboard plugin", - "display_name": "Plugin Dashboard", - } - - -@pytest.mark.asyncio -async def test_debug_menu_items_are_in_jump_for_debug_menu_permission(): - ds = Datasette( - config={ - "permissions": { - "debug-menu": {"id": "debugger"}, - } - } - ) - await ds.invoke_startup() - response = await ds.client.get("/-/jump.json?q=debug", actor={"id": "debugger"}) - assert response.status_code == 200 - debug_matches = [ - match for match in response.json()["matches"] if match["type"] == "debug" - ] - assert {match["name"]: match["url"] for match in debug_matches} == { - "Databases": "/-/databases", - "Installed plugins": "/-/plugins", - "Version info": "/-/versions", - "Settings": "/-/settings", - "Debug permissions": "/-/permissions", - "Debug messages": "/-/messages", - "Debug allow rules": "/-/allow-debug", - "Debug threads": "/-/threads", - "Debug actor": "/-/actor", - "Pattern portfolio": "/-/patterns", - } - descriptions_by_name = { - match["name"]: match["description"] for match in debug_matches - } - assert all(descriptions_by_name.values()) - assert descriptions_by_name["Databases"] == ( - "List of databases known to this Datasette instance." - ) - - -@pytest.mark.asyncio -async def test_debug_menu_items_are_hidden_without_debug_menu_permission(): - ds = Datasette() - await ds.invoke_startup() - response = await ds.client.get("/-/jump.json?q=debug", actor={"id": "regular"}) - assert response.status_code == 200 - assert [ - match for match in response.json()["matches"] if match["type"] == "debug" - ] == [] - - -@pytest.mark.asyncio -async def test_jump_uses_plugin_sql_with_namespaced_parameters(ds_for_jump): - class JumpPlugin: - @hookimpl - def jump_items_sql(self, datasette, actor, request): - return JumpSQL( - sql=""" - SELECT - 'plugin' AS type, - 'plugin-dashboard: ' || :actor_id AS label, - 'Plugin supplied item' AS description, - '/-/plugin-dashboard' AS url, - 'plugin dashboard ' || :actor_id AS search_text, - 'Plugin dashboard for ' || :actor_id AS display_name - """, - params={"actor_id": actor["id"] if actor else "anonymous"}, - ) - - plugin = JumpPlugin() - pm.register(plugin, name="test-jump-plugin") - try: - response = await ds_for_jump.client.get( - "/-/jump.json?q=dashboard", actor={"id": "alice"} - ) - finally: - pm.unregister(name="test-jump-plugin") - - assert response.status_code == 200 - plugin_matches = [ - match for match in response.json()["matches"] if match["type"] == "plugin" - ] - assert plugin_matches == [ - { - "name": "plugin-dashboard: alice", - "display_name": "Plugin dashboard for alice", - "url": "/-/plugin-dashboard", - "type": "plugin", - "description": "Plugin supplied item", - } - ] - - -@pytest.mark.asyncio -async def test_jump_sql_unions_fragments_by_database(ds_for_jump, monkeypatch): - class JumpPlugin: - @hookimpl - def jump_items_sql(self, datasette, actor, request): - return [ - JumpSQL(sql=""" - SELECT - 'plugin' AS type, - 'first-unioned-item' AS label, - NULL AS description, - '/-/first-unioned-item' AS url, - 'unioned item' AS search_text, - NULL AS display_name - """), - JumpSQL(sql=""" - SELECT - 'plugin' AS type, - 'second-unioned-item' AS label, - NULL AS description, - '/-/second-unioned-item' AS url, - 'unioned item' AS search_text, - NULL AS display_name - """), - JumpSQL( - """ - SELECT - 'plugin' AS type, - 'content-first-unioned-item' AS label, - NULL AS description, - '/-/content-first-unioned-item' AS url, - 'unioned item' AS search_text, - NULL AS display_name - """, - None, - "content", - ), - JumpSQL( - database="content", - sql=""" - SELECT - 'plugin' AS type, - 'content-second-unioned-item' AS label, - NULL AS description, - '/-/content-second-unioned-item' AS url, - 'unioned item' AS search_text, - NULL AS display_name - """, - ), - ] - - internal_db = ds_for_jump.get_internal_database() - original_execute = internal_db.execute - internal_jump_query_sql = [] - - async def internal_execute_with_recording(sql, *args, **kwargs): - if "unioned-item" in sql: - internal_jump_query_sql.append(sql) - return await original_execute(sql, *args, **kwargs) - - monkeypatch.setattr(internal_db, "execute", internal_execute_with_recording) - - content_db = ds_for_jump.get_database("content") - original_content_execute = content_db.execute - content_jump_query_sql = [] - - async def content_execute_with_recording(sql, *args, **kwargs): - if "unioned-item" in sql: - content_jump_query_sql.append(sql) - return await original_content_execute(sql, *args, **kwargs) - - monkeypatch.setattr(content_db, "execute", content_execute_with_recording) - - plugin = JumpPlugin() - pm.register(plugin, name="test-jump-union-plugin") - try: - response = await ds_for_jump.client.get( - "/-/jump.json?q=unioned", actor={"id": "alice"} - ) - finally: - pm.unregister(name="test-jump-union-plugin") - - assert response.status_code == 200 - assert len(internal_jump_query_sql) == 1 - assert " UNION ALL " in internal_jump_query_sql[0] - assert len(content_jump_query_sql) == 1 - assert " UNION ALL " in content_jump_query_sql[0] - assert {match["name"] for match in response.json()["matches"]} == { - "content-first-unioned-item", - "content-second-unioned-item", - "first-unioned-item", - "second-unioned-item", - } - - -@pytest.mark.asyncio -async def test_jump_sql_can_query_named_database(ds_for_jump): - content_db = ds_for_jump.get_database("content") - await content_db.execute_write( - "INSERT INTO comments (id, body) VALUES (1001, 'Named database jump target')" - ) - - class JumpPlugin: - @hookimpl - def jump_items_sql(self, datasette, actor, request): - return JumpSQL( - database="content", - sql=""" - SELECT - 'comment' AS type, - body AS label, - 'Comment from content database' AS description, - json_object( - 'method', 'table', - 'database', 'content', - 'table', 'comments' - ) AS url, - body AS search_text, - body AS display_name - FROM comments - WHERE id = :comment_id - """, - params={"comment_id": 1001}, - ) - - plugin = JumpPlugin() - pm.register(plugin, name="test-jump-content-db-plugin") - try: - response = await ds_for_jump.client.get( - "/-/jump.json?q=named+database", actor={"id": "alice"} - ) - finally: - pm.unregister(name="test-jump-content-db-plugin") - - assert response.status_code == 200 - plugin_matches = [ - match for match in response.json()["matches"] if match["type"] == "comment" - ] - assert plugin_matches == [ - { - "name": "Named database jump target", - "display_name": "Named database jump target", - "url": "/content/comments", - "type": "comment", - "description": "Comment from content database", - } - ] - - -@pytest.mark.asyncio -async def test_jump_resolves_url_descriptors_from_sql(ds_for_jump): - class JumpPlugin: - @hookimpl - def jump_items_sql(self, datasette, actor, request): - return JumpSQL(sql=""" - SELECT - 'plugin' AS type, - 'Table descriptor' AS label, - NULL AS description, - json_object( - 'method', 'table', - 'database', 'content', - 'table', 'comments' - ) AS url, - 'table descriptor comments' AS search_text, - NULL AS display_name - """) - - plugin = JumpPlugin() - pm.register(plugin, name="test-jump-url-descriptor-plugin") - try: - response = await ds_for_jump.client.get( - "/-/jump.json?q=descriptor", actor={"id": "alice"} - ) - finally: - pm.unregister(name="test-jump-url-descriptor-plugin") - - assert response.status_code == 200 - plugin_matches = [ - match for match in response.json()["matches"] if match["type"] == "plugin" - ] - assert plugin_matches == [ - { - "name": "Table descriptor", - "url": "/content/comments", - "type": "plugin", - "description": None, - } - ] - - -@pytest.mark.asyncio -async def test_jump_url_descriptor_errors(ds_for_jump): - view = JumpView(ds_for_jump) - with pytest.raises(AttributeError): - view._resolve_url('{"method": "not_a_url_method"}') - with pytest.raises(TypeError): - view._resolve_url( - '{"method": "table", "database_name": "content", "table_name": "comments"}' - ) - - -@pytest.mark.asyncio -async def test_tables_endpoint_removed(ds_for_jump): - response = await ds_for_jump.client.get("/-/tables.json") - assert response.status_code == 404 diff --git a/tests/test_label_column_for_table.py b/tests/test_label_column_for_table.py deleted file mode 100644 index 7667b595..00000000 --- a/tests/test_label_column_for_table.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest -from datasette.database import Database -from datasette.app import Datasette - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "create_sql,table_name,config,expected_label_column", - [ - # Explicit label_column - ( - "create table t1 (id integer primary key, name text, title text);", - "t1", - {"t1": {"label_column": "title"}}, - "title", - ), - # Single unique text column - ( - "create table t2 (id integer primary key, name2 text unique, title text);", - "t2", - {}, - "name2", - ), - ( - "create table t3 (id integer primary key, title2 text unique, name text);", - "t3", - {}, - "title2", - ), - # Two unique text columns means it cannot decide on one - ( - "create table t3x (id integer primary key, name2 text unique, title2 text unique);", - "t3x", - {}, - None, - ), - # Name or title column - ( - "create table t4 (id integer primary key, name text);", - "t4", - {}, - "name", - ), - ( - "create table t5 (id integer primary key, title text);", - "t5", - {}, - "title", - ), - # But not if there are multiple non-unique text that are not called title - ( - "create table t5x (id integer primary key, other1 text, other2 text);", - "t5x", - {}, - None, - ), - ( - "create table t6 (id integer primary key, Name text);", - "t6", - {}, - "Name", - ), - ( - "create table t7 (id integer primary key, Title text);", - "t7", - {}, - "Title", - ), - # Two columns, one of which is id - ( - "create table t8 (id integer primary key, content text);", - "t8", - {}, - "content", - ), - ( - "create table t9 (pk integer primary key, content text);", - "t9", - {}, - "content", - ), - ], -) -async def test_label_column_for_table( - create_sql, table_name, config, expected_label_column -): - """Test cases for label_column_for_table method""" - ds = Datasette() - db = ds.add_database(Database(ds, memory_name="test_label_column_for_table")) - await db.execute_write_script(create_sql) - if config: - ds.config = {"databases": {"test_label_column_for_table": {"tables": config}}} - actual_label_column = await db.label_column_for_table(table_name) - if expected_label_column is None: - assert actual_label_column is None - else: - assert actual_label_column == expected_label_column diff --git a/tests/test_load_extensions.py b/tests/test_load_extensions.py deleted file mode 100644 index cdadb091..00000000 --- a/tests/test_load_extensions.py +++ /dev/null @@ -1,64 +0,0 @@ -from datasette.app import Datasette -import pytest -from pathlib import Path - -# not necessarily a full path - the full compiled path looks like "ext.dylib" -# or another suffix, but sqlite will, under the hood, decide which file -# extension to use based on the operating system (apple=dylib, windows=dll etc) -# this resolves to "./ext", which is enough for SQLite to calculate the rest -COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext") - - -# See if ext.c has been compiled, based off the different possible suffixes. -def has_compiled_ext(): - for ext in ["dylib", "so", "dll"]: - path = Path(__file__).parent / f"ext.{ext}" - if path.is_file(): - return True - return False - - -@pytest.mark.asyncio -@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c") -async def test_load_extension_default_entrypoint(): - # The default entrypoint only loads a() and NOT b() or c(), so those - # should fail. - ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH]) - - response = await ds.client.get("/_memory/-/query.json?_shape=arrays&sql=select+a()") - assert response.status_code == 200 - assert response.json()["rows"][0][0] == "a" - - response = await ds.client.get("/_memory/-/query.json?_shape=arrays&sql=select+b()") - assert response.status_code == 400 - assert response.json()["error"] == "no such function: b" - - response = await ds.client.get("/_memory/-/query.json?_shape=arrays&sql=select+c()") - assert response.status_code == 400 - assert response.json()["error"] == "no such function: c" - - -@pytest.mark.asyncio -@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c") -async def test_load_extension_multiple_entrypoints(): - # Load in the default entrypoint and the other 2 custom entrypoints, now - # all a(), b(), and c() should run successfully. - ds = Datasette( - sqlite_extensions=[ - COMPILED_EXTENSION_PATH, - (COMPILED_EXTENSION_PATH, "sqlite3_ext_b_init"), - (COMPILED_EXTENSION_PATH, "sqlite3_ext_c_init"), - ] - ) - - response = await ds.client.get("/_memory/-/query.json?_shape=arrays&sql=select+a()") - assert response.status_code == 200 - assert response.json()["rows"][0][0] == "a" - - response = await ds.client.get("/_memory/-/query.json?_shape=arrays&sql=select+b()") - assert response.status_code == 200 - assert response.json()["rows"][0][0] == "b" - - response = await ds.client.get("/_memory/-/query.json?_shape=arrays&sql=select+c()") - assert response.status_code == 200 - assert response.json()["rows"][0][0] == "c" diff --git a/tests/test_messages.py b/tests/test_messages.py deleted file mode 100644 index 62d9f647..00000000 --- a/tests/test_messages.py +++ /dev/null @@ -1,32 +0,0 @@ -from .utils import cookie_was_deleted -import pytest - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "qs,expected", - [ - ("add_msg=added-message", [["added-message", 1]]), - ("add_msg=added-warning&type=WARNING", [["added-warning", 2]]), - ("add_msg=added-error&type=ERROR", [["added-error", 3]]), - ], -) -async def test_add_message_sets_cookie(ds_client, qs, expected): - response = await ds_client.get(f"/fixtures/-/query.message?sql=select+1&{qs}") - signed = response.cookies["ds_messages"] - decoded = ds_client.ds.unsign(signed, "messages") - assert expected == decoded - - -@pytest.mark.asyncio -async def test_messages_are_displayed_and_cleared(ds_client): - # First set the message cookie - set_msg_response = await ds_client.get( - "/fixtures/-/query.message?sql=select+1&add_msg=xmessagex" - ) - # Now access a page that displays messages - response = await ds_client.get("/", cookies=set_msg_response.cookies) - # Messages should be in that HTML - assert "xmessagex" in response.text - # Cookie should have been set that clears messages - assert cookie_was_deleted(response, "ds_messages") diff --git a/tests/test_multipart.py b/tests/test_multipart.py deleted file mode 100644 index 0dc3ecd7..00000000 --- a/tests/test_multipart.py +++ /dev/null @@ -1,1152 +0,0 @@ -""" -Tests for request.form() multipart form data parsing. - -Uses TDD approach - these tests are written first, then implementation follows. -""" - -import base64 -import json -import pytest -from collections import namedtuple - -from multipart_form_data_conformance import get_tests_dir - -from datasette.utils.asgi import Request, BadRequest - - -def make_receive(body: bytes): - """Create an async receive callable that yields body in chunks.""" - consumed = False - - async def receive(): - nonlocal consumed - if consumed: - return {"type": "http.request", "body": b"", "more_body": False} - consumed = True - return {"type": "http.request", "body": body, "more_body": False} - - return receive - - -def make_chunked_receive(body: bytes, chunk_size: int = 64): - """Create an async receive callable that yields body in small chunks.""" - offset = 0 - - async def receive(): - nonlocal offset - chunk = body[offset : offset + chunk_size] - offset += chunk_size - more_body = offset < len(body) - return {"type": "http.request", "body": chunk, "more_body": more_body} - - return receive - - -def make_receive_with_noise(body: bytes): - """ - Create an async receive callable that includes an unexpected ASGI message. - - The parser should ignore the unknown message type and continue. - """ - messages = [ - {"type": "http.response.start", "status": 200, "headers": []}, - {"type": "http.request", "body": body, "more_body": False}, - ] - index = 0 - - async def receive(): - nonlocal index - if index >= len(messages): - return {"type": "http.request", "body": b"", "more_body": False} - message = messages[index] - index += 1 - return message - - return receive - - -def make_disconnect_receive(body: bytes, chunk_size: int = 64): - """ - Create an async receive callable that disconnects mid-request. - - The parser should raise on the disconnect. - """ - offset = 0 - disconnected = False - - async def receive(): - nonlocal offset, disconnected - if disconnected: - return {"type": "http.disconnect"} - chunk = body[offset : offset + chunk_size] - offset += chunk_size - more_body = offset < len(body) - if more_body: - disconnected = True - return {"type": "http.request", "body": chunk, "more_body": more_body} - - return receive - - -class TestFormUrlEncoded: - """Test request.form() with application/x-www-form-urlencoded data.""" - - @pytest.mark.asyncio - async def test_basic_form_fields(self): - """Basic URL-encoded form should be parseable via request.form().""" - body = b"username=john&password=secret" - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", b"application/x-www-form-urlencoded"), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() - - assert form["username"] == "john" - assert form["password"] == "secret" - - @pytest.mark.asyncio - async def test_form_with_multiple_values(self): - """Multiple values for same key should be accessible via getlist().""" - body = b"tag=python&tag=web&tag=api" - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", b"application/x-www-form-urlencoded"), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() - - assert form["tag"] == "python" # First value - assert form.getlist("tag") == ["python", "web", "api"] - - @pytest.mark.asyncio - async def test_empty_form(self): - """Empty form should return empty FormData.""" - body = b"" - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", b"application/x-www-form-urlencoded"), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() - - assert len(form) == 0 - - @pytest.mark.asyncio - async def test_form_with_special_characters(self): - """URL-encoded special characters should be decoded properly.""" - body = b"message=hello%20world&emoji=%F0%9F%91%8B" - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", b"application/x-www-form-urlencoded"), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() - - assert form["message"] == "hello world" - assert form["emoji"] == "👋" - - -class TestMultipartBasic: - """Test request.form() with multipart/form-data (fields only, no files).""" - - @pytest.mark.asyncio - async def test_single_text_field(self): - """Single text field in multipart should be parseable.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="username"\r\n' - b"\r\n" - b"john_doe\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() - - assert form["username"] == "john_doe" - - @pytest.mark.asyncio - async def test_multiple_text_fields(self): - """Multiple text fields in multipart should all be accessible.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="first_name"\r\n' - b"\r\n" - b"John\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="last_name"\r\n' - b"\r\n" - b"Doe\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() - - assert form["first_name"] == "John" - assert form["last_name"] == "Doe" - - @pytest.mark.asyncio - async def test_file_discarded_when_files_false(self): - """File content should be discarded when files=False (default).""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="title"\r\n' - b"\r\n" - b"My Document\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="doc.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"File content here\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="description"\r\n' - b"\r\n" - b"A sample document\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() # files=False is default - - # Text fields should be present - assert form["title"] == "My Document" - assert form["description"] == "A sample document" - # File should NOT be present - assert "file" not in form - - @pytest.mark.asyncio - async def test_chunked_body_parsing(self): - """Multipart should work when body arrives in small chunks.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="username"\r\n' - b"\r\n" - b"john_doe\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - # Use small chunks to test streaming parser - request = Request(scope, make_chunked_receive(body, chunk_size=16)) - - form = await request.form() - - assert form["username"] == "john_doe" - - -class TestMultipartWithFiles: - """Test request.form(files=True) for file uploads.""" - - @pytest.mark.asyncio - async def test_single_file_upload(self): - """Single file upload should create UploadedFile object.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="document"; filename="test.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"Hello, World!\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form(files=True) - - uploaded_file = form["document"] - assert uploaded_file.filename == "test.txt" - assert uploaded_file.content_type == "text/plain" - assert await uploaded_file.read() == b"Hello, World!" - assert uploaded_file.size == 13 - - @pytest.mark.asyncio - async def test_mixed_fields_and_files(self): - """Mixed form fields and files should all be accessible.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="title"\r\n' - b"\r\n" - b"My Document\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="doc.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"Document content\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="description"\r\n' - b"\r\n" - b"A sample\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form(files=True) - - # Text fields - assert form["title"] == "My Document" - assert form["description"] == "A sample" - # File - uploaded_file = form["file"] - assert uploaded_file.filename == "doc.txt" - assert await uploaded_file.read() == b"Document content" - - @pytest.mark.asyncio - async def test_multiple_files_same_name(self): - """Multiple files with same name should be accessible via getlist().""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="files"; filename="a.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"File A\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="files"; filename="b.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"File B\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form(files=True) - - files = form.getlist("files") - assert len(files) == 2 - assert files[0].filename == "a.txt" - assert files[1].filename == "b.txt" - - @pytest.mark.asyncio - async def test_large_file_spills_to_disk(self): - """Files larger than threshold should spill to temp file.""" - boundary = "----TestBoundary123" - # Create a body larger than the in-memory threshold (1MB) - large_content = b"x" * (2 * 1024 * 1024) # 2MB - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="bigfile"; filename="large.bin"\r\n' - b"Content-Type: application/octet-stream\r\n" - b"\r\n" + large_content + b"\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form(files=True) - - uploaded_file = form["bigfile"] - assert uploaded_file.size == len(large_content) - # Content should still be readable - content = await uploaded_file.read() - assert content == large_content - - @pytest.mark.asyncio - async def test_uploaded_file_seek_and_read(self): - """UploadedFile should support seek and multiple reads.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"Hello, World!\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form(files=True) - uploaded_file = form["file"] - - # First read - content1 = await uploaded_file.read() - assert content1 == b"Hello, World!" - - # Seek back to start - await uploaded_file.seek(0) - - # Second read - content2 = await uploaded_file.read() - assert content2 == b"Hello, World!" - - -class TestMultipartCleanup: - """Test deterministic cleanup of uploaded files.""" - - @pytest.mark.asyncio - async def test_formdata_close_closes_uploaded_files(self): - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"Hello\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - form = await request.form(files=True) - uploaded_file = form["file"] - - form.close() - - with pytest.raises(ValueError): - await uploaded_file.read() - - @pytest.mark.asyncio - async def test_formdata_async_context_manager_closes_files(self): - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"Hello\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - form = await request.form(files=True) - uploaded_file = form["file"] - - async with form: - pass - - with pytest.raises(ValueError): - await uploaded_file.read() - - -class TestMultipartEdgeCases: - """Test edge cases in multipart parsing.""" - - @pytest.mark.asyncio - async def test_empty_file_upload(self): - """Empty file (filename but no content) should be handled.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="empty.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form(files=True) - - uploaded_file = form["file"] - assert uploaded_file.filename == "empty.txt" - assert uploaded_file.size == 0 - assert await uploaded_file.read() == b"" - - @pytest.mark.asyncio - async def test_filename_with_path(self): - """Filename containing path should extract just the filename.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="C:\\Users\\test\\doc.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"content\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form(files=True) - - # Should extract just the filename, not the full path - uploaded_file = form["file"] - assert uploaded_file.filename == "doc.txt" - - @pytest.mark.asyncio - async def test_missing_content_type_header(self): - """Missing content-type in request should raise BadRequest.""" - body = b"some body" - scope = { - "type": "http", - "method": "POST", - "headers": [], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest): - await request.form() - - @pytest.mark.asyncio - async def test_invalid_content_type(self): - """Non-form content-type should raise BadRequest.""" - body = b'{"key": "value"}' - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", b"application/json"), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest): - await request.form() - - @pytest.mark.asyncio - async def test_missing_boundary(self): - """Multipart without boundary should raise BadRequest.""" - body = b"some body" - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", b"multipart/form-data"), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest): - await request.form() - - -class TestSecurityLimits: - """Test security limits on form parsing.""" - - @pytest.mark.asyncio - async def test_max_fields_limit(self): - """Should reject requests with too many fields.""" - boundary = "----TestBoundary123" - # Create body with many fields - parts = [] - for i in range(1001): # Default max is 1000 - parts.append( - f"------TestBoundary123\r\n" - f'Content-Disposition: form-data; name="field{i}"\r\n' - f"\r\n" - f"value{i}\r\n" - ) - parts.append("------TestBoundary123--\r\n") - body = "".join(parts).encode() - - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="(?i)too many"): - await request.form(max_fields=1000) - - @pytest.mark.asyncio - async def test_max_file_size_limit(self): - """Should reject files exceeding size limit.""" - boundary = "----TestBoundary123" - large_content = b"x" * (11 * 1024 * 1024) # 11MB - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="big.bin"\r\n' - b"Content-Type: application/octet-stream\r\n" - b"\r\n" + large_content + b"\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="(?i)file.*too large|too large"): - await request.form(files=True, max_file_size=10 * 1024 * 1024) - - @pytest.mark.asyncio - async def test_max_request_size_limit(self): - """Should reject requests exceeding total size limit.""" - boundary = "----TestBoundary123" - large_content = b"x" * (6 * 1024 * 1024) # 6MB - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="big.bin"\r\n' - b"Content-Type: application/octet-stream\r\n" - b"\r\n" + large_content + b"\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="(?i)too large|request.*too large"): - await request.form(files=True, max_request_size=5 * 1024 * 1024) - - -class TestMultipartStrictnessAndLimits: - """Tests that enforce stricter ASGI and multipart behaviors.""" - - @pytest.mark.asyncio - async def test_multipart_truncated_body_is_error(self): - """Truncated multipart without closing boundary should raise.""" - boundary = "----TestBoundary123" - # Missing the final closing boundary line - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="field"\r\n' - b"\r\n" - b"value\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="Truncated multipart body"): - await request.form() - - @pytest.mark.asyncio - async def test_disconnect_mid_body_is_error(self): - """Client disconnect during body streaming should raise.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="field"\r\n' - b"\r\n" - b"value\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_disconnect_receive(body, chunk_size=16)) - - with pytest.raises(BadRequest, match="disconnected"): - await request.form() - - @pytest.mark.asyncio - async def test_unknown_asgi_message_type_is_ignored(self): - """Unexpected ASGI message types should be ignored.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="field"\r\n' - b"\r\n" - b"value\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive_with_noise(body)) - - form = await request.form() - assert form["field"] == "value" - - @pytest.mark.asyncio - async def test_max_files_enforced_even_when_files_false(self): - """File count limits should apply even when file handling is disabled.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="f1"; filename="a.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"a\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="f2"; filename="b.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"b\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="Too many files"): - await request.form(files=False, max_files=1) - - @pytest.mark.asyncio - async def test_max_parts_limit(self): - """Total part count should be bounded.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="a"\r\n' - b"\r\n" - b"1\r\n" - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="b"\r\n' - b"\r\n" - b"2\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="Too many parts"): - await request.form(max_parts=1) - - @pytest.mark.asyncio - async def test_max_file_size_enforced_even_when_files_false(self): - """File size limits should apply even when file handling is disabled.""" - boundary = "----TestBoundary123" - big_content = b"x" * 2048 - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="big.bin"\r\n' - b"Content-Type: application/octet-stream\r\n" - b"\r\n" + big_content + b"\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="File too large"): - await request.form(files=False, max_file_size=1024) - - @pytest.mark.asyncio - async def test_part_header_limits(self): - """Overly large part headers should be rejected.""" - boundary = "----TestBoundary123" - huge_header_value = "x" * 5000 - body = ( - b"------TestBoundary123\r\n" - + f'Content-Disposition: form-data; name="field"; foo="{huge_header_value}"\r\n'.encode() - + b"\r\n" - + b"value\r\n" - + b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="headers too large"): - await request.form(max_part_header_bytes=1024) - - @pytest.mark.asyncio - async def test_insufficient_disk_space_rejects_upload(self, monkeypatch): - """Uploads should be rejected when free disk is below the floor.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n' - b"Content-Type: text/plain\r\n" - b"\r\n" - b"Hello\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - - DiskUsage = namedtuple("DiskUsage", ("total", "used", "free")) - monkeypatch.setattr( - "datasette.utils.multipart.shutil.disk_usage", - lambda path: DiskUsage(total=100, used=95, free=5), - ) - - request = Request(scope, make_receive(body)) - with pytest.raises(BadRequest, match="Insufficient disk space"): - await request.form(files=True, min_free_disk_bytes=50) - - @pytest.mark.asyncio - async def test_low_disk_space_does_not_block_field_only_forms(self, monkeypatch): - """Low disk space should not reject multipart forms with no file parts.""" - boundary = "----TestBoundary123" - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="field"\r\n' - b"\r\n" - b"value\r\n" - b"------TestBoundary123--\r\n" - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - - DiskUsage = namedtuple("DiskUsage", ("total", "used", "free")) - monkeypatch.setattr( - "datasette.utils.multipart.shutil.disk_usage", - lambda path: DiskUsage(total=100, used=99, free=1), - ) - - request = Request(scope, make_receive(body)) - form = await request.form(files=True, min_free_disk_bytes=50) - assert form["field"] == "value" - - @pytest.mark.asyncio - async def test_headers_without_newline_hit_header_byte_limit(self): - """Headers that never terminate should still hit the header byte limit.""" - boundary = "----TestBoundary123" - huge = b"x" * 5000 - # No CRLF is included after the header line - body = ( - b"------TestBoundary123\r\n" - b'Content-Disposition: form-data; name="field"; foo="' + huge + b'"' - ) - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", f"multipart/form-data; boundary={boundary}".encode()), - ], - } - request = Request(scope, make_receive(body)) - - with pytest.raises(BadRequest, match="headers too large"): - await request.form(max_part_header_bytes=1024) - - -class TestFormDataLenSemantics: - """Test that FormData.__len__ reflects number of items, not unique keys.""" - - @pytest.mark.asyncio - async def test_len_counts_items(self): - body = b"tag=python&tag=web&tag=api" - scope = { - "type": "http", - "method": "POST", - "headers": [ - (b"content-type", b"application/x-www-form-urlencoded"), - ], - } - request = Request(scope, make_receive(body)) - - form = await request.form() - assert len(form) == 3 - - -# Conformance test suite using multipart-form-data-conformance - -# Tests where our parser intentionally differs from strict spec for security/practicality -# Our parser sanitizes filenames (strips paths) while the conformance suite expects raw -FILENAME_SANITIZATION_TESTS = { - "026-filename-with-backslash", # We preserve backslashes but they test expects raw - "029-filename-path-traversal", # We strip path components for security -} - -# Tests for optional/lenient features we don't implement -OPTIONAL_TESTS = { - "085-header-folding", # Obsolete header folding feature -} - -# Tests for malformed input where we're lenient instead of erroring -LENIENT_PARSING_TESTS = { - "203-missing-content-disposition", - "204-invalid-content-disposition", -} - - -def load_conformance_test_cases(): - """Load all test cases from multipart-form-data-conformance.""" - tests_dir = get_tests_dir() - test_cases = [] - - for category_dir in sorted(tests_dir.iterdir()): - if not category_dir.is_dir(): - continue - for test_dir in sorted(category_dir.iterdir()): - if not test_dir.is_dir(): - continue - test_json = test_dir / "test.json" - headers_json = test_dir / "headers.json" - input_raw = test_dir / "input.raw" - - if not all(f.exists() for f in [test_json, headers_json, input_raw]): - continue - - with open(test_json) as f: - test_spec = json.load(f) - with open(headers_json) as f: - headers = json.load(f) - with open(input_raw, "rb") as f: - body = f.read() - - test_id = test_spec["id"] - - # Add marks for tests we handle differently - marks = [] - if test_id in FILENAME_SANITIZATION_TESTS: - marks.append( - pytest.mark.xfail(reason="Parser sanitizes filenames for security") - ) - elif test_id in OPTIONAL_TESTS: - marks.append( - pytest.mark.xfail(reason="Optional feature not implemented") - ) - elif test_id in LENIENT_PARSING_TESTS: - marks.append( - pytest.mark.xfail(reason="Parser is lenient with malformed input") - ) - - test_cases.append( - pytest.param( - test_spec, - headers, - body, - id=test_id, - marks=marks, - ) - ) - - return test_cases - - -CONFORMANCE_TEST_CASES = load_conformance_test_cases() - - -@pytest.mark.parametrize("test_spec,headers,body", CONFORMANCE_TEST_CASES) -@pytest.mark.asyncio -async def test_conformance(test_spec, headers, body): - """ - Run conformance test cases from multipart-form-data-conformance. - - Each test case specifies: - - headers: HTTP headers including Content-Type with boundary - - body: Raw multipart body bytes - - expected: Expected parse result (valid/invalid, parts list) - """ - scope = { - "type": "http", - "method": "POST", - "headers": [(k.encode(), v.encode()) for k, v in headers.items()], - } - request = Request(scope, make_receive(body)) - - expected = test_spec["expected"] - - if not expected["valid"]: - # Should raise an error for invalid input - with pytest.raises((BadRequest, ValueError)): - await request.form(files=True) - return - - # Parse form data - form = await request.form(files=True) - - # Verify each expected part - for i, expected_part in enumerate(expected["parts"]): - name = expected_part["name"] - - # Get value(s) for this name - values = form.getlist(name) - - # Find the value at the correct index for this name - # (handles multiple values with same name) - same_name_count = sum(1 for p in expected["parts"][:i] if p["name"] == name) - - if same_name_count >= len(values): - pytest.fail( - f"Expected part {name} at index {same_name_count} but only {len(values)} found" - ) - - value = values[same_name_count] - - # Determine expected content - if "body_base64" in expected_part: - expected_content = base64.b64decode(expected_part["body_base64"]) - elif "body_text" in expected_part: - expected_content = expected_part["body_text"].encode("utf-8") - else: - expected_content = None - - # Check for file vs field - # A part is a file if it has a filename OR filename_star - is_file = ( - expected_part.get("filename") is not None - or expected_part.get("filename_star") is not None - ) - - if is_file: - # It's a file - assert hasattr(value, "filename"), f"Expected file for {name}" - - # Check filename - use filename_star if present, else filename - expected_filename = expected_part.get("filename_star") or expected_part.get( - "filename" - ) - if expected_filename: - assert ( - value.filename == expected_filename - ), f"Filename mismatch: expected {expected_filename!r}, got {value.filename!r}" - - if expected_part.get("content_type"): - assert value.content_type == expected_part["content_type"] - - content = await value.read() - assert ( - len(content) == expected_part["body_size"] - ), f"Size mismatch: expected {expected_part['body_size']}, got {len(content)}" - if expected_content is not None: - assert content == expected_content - else: - # It's a text field - if hasattr(value, "filename"): - pytest.fail(f"Expected text field for {name}, got file") - - if expected_content is not None: - # For text fields, value is a string - try: - expected_text = expected_content.decode("utf-8") - except UnicodeDecodeError: - expected_text = expected_content.decode("latin-1") - assert ( - value == expected_text - ), f"Value mismatch: expected {expected_text!r}, got {value!r}" diff --git a/tests/test_navigation_search_js.py b/tests/test_navigation_search_js.py deleted file mode 100644 index b487357d..00000000 --- a/tests/test_navigation_search_js.py +++ /dev/null @@ -1,394 +0,0 @@ -import json -from pathlib import Path -import subprocess -import textwrap - -REPO_ROOT = Path(__file__).resolve().parents[1] -STATIC_DIR = REPO_ROOT / "datasette" / "static" - - -def test_navigation_search_tracks_and_renders_recent_items(): - script = textwrap.dedent(""" - const fs = require("fs"); - const vm = require("vm"); - const navigationSearchJs = __NAVIGATION_SEARCH_JS__; - - class FakeElement { - constructor() { - this.innerHTML = ""; - this.value = ""; - this.dataset = {}; - this.open = false; - } - addEventListener() {} - close() { this.open = false; } - focus() {} - querySelector() { - return { scrollIntoView() {} }; - } - showModal() { this.open = true; } - } - - class FakeShadowRoot { - constructor() { - this.innerHTML = ""; - this.dialog = new FakeElement(); - this.input = new FakeElement(); - this.results = new FakeElement(); - } - querySelector(selector) { - if (selector == "dialog") return this.dialog; - if (selector == ".search-input") return this.input; - if (selector == ".results-container") return this.results; - return new FakeElement(); - } - } - - global.HTMLElement = class { - constructor() { - this.attributes = {}; - } - attachShadow() { - this.shadowRoot = new FakeShadowRoot(); - return this.shadowRoot; - } - dispatchEvent() {} - getAttribute(name) { - return this.attributes[name] || null; - } - querySelector() { - return null; - } - setAttribute(name, value) { - this.attributes[name] = value; - } - }; - global.CustomEvent = class { - constructor(name, options) { - this.name = name; - this.options = options; - } - }; - global.customElements = { - registry: new Map(), - define(name, cls) { - this.registry.set(name, cls); - }, - }; - global.document = { - addEventListener() {}, - activeElement: null, - createElement() { - return { - set textContent(value) { - this.innerHTML = String(value) - .replace(/&/g, "&") - .replace(//g, ">") - .replace(/"/g, """); - }, - }; - }, - }; - global.localStorage = { - store: {}, - getItem(key) { - return Object.prototype.hasOwnProperty.call(this.store, key) - ? this.store[key] - : null; - }, - setItem(key, value) { - this.store[key] = String(value); - }, - removeItem(key) { - delete this.store[key]; - }, - }; - global.window = { location: { href: "" } }; - - vm.runInThisContext( - fs.readFileSync(navigationSearchJs, "utf8"), - { filename: "navigation-search.js" } - ); - - const Component = customElements.registry.get("navigation-search"); - const element = new Component(); - const items = Array.from({ length: 6 }, (_, index) => ({ - name: `Item ${index + 1}`, - url: `/item-${index + 1}`, - type: "table", - description: "Table", - })); - items[5].name = "content: recent_datasette_releases"; - items[5].display_name = "Recent Datasette releases"; - - for (const item of items) { - element.matches = [item]; - element.renderedMatches = [item]; - element.selectedIndex = 0; - element.selectCurrentItem(); - } - - const stored = JSON.parse( - Object.values(localStorage.store).find((value) => value.includes("/item-6")) - ); - if (stored.length !== 5) { - throw new Error(`Expected 5 recent items, got ${stored.length}`); - } - if (stored[0].url !== "/item-6" || stored[4].url !== "/item-2") { - throw new Error(`Unexpected recent order: ${JSON.stringify(stored)}`); - } - if (stored[0].display_name !== "Recent Datasette releases") { - throw new Error(`Missing display_name in recent item: ${JSON.stringify(stored[0])}`); - } - - element.matches = [ - items[5], - items[4], - { - name: "Other", - url: "/other", - type: "database", - description: "Database", - }, - ]; - element.shadowRoot.input.value = ""; - element.renderResults(); - - const html = element.shadowRoot.results.innerHTML; - if (!html.includes("Recent")) { - throw new Error(`Missing Recent heading: ${html}`); - } - if (!html.includes("Recent Datasette releases") || !html.includes("Item 5")) { - throw new Error(`Missing recent items: ${html}`); - } - if (!html.includes("content: recent_datasette_releases")) { - throw new Error(`Missing canonical item name for display_name item: ${html}`); - } - if (!html.includes("Item 4") || !html.includes("Item 2")) { - throw new Error(`Expected all stored recent items in empty state: ${html}`); - } - if (html.includes("Other")) { - throw new Error(`Rendered non-recent item in empty state: ${html}`); - } - if (!html.includes("Clear recent")) { - throw new Error(`Missing Clear recent control: ${html}`); - } - - element.clearRecentItems(); - if (localStorage.getItem(element.recentItemsStorageKey()) !== null) { - throw new Error("Expected recent items to be cleared"); - } - element.renderResults(); - if (element.shadowRoot.results.innerHTML.includes("Clear recent")) { - throw new Error("Clear recent should disappear after clearing"); - } - - process.stdout.write(JSON.stringify(stored)); - """).replace( - "__NAVIGATION_SEARCH_JS__", - json.dumps(str(STATIC_DIR / "navigation-search.js")), - ) - result = subprocess.run( - ["node", "-e", script], - cwd=REPO_ROOT, - text=True, - capture_output=True, - check=False, - ) - assert result.returncode == 0, result.stderr - assert [item["url"] for item in json.loads(result.stdout)] == [ - "/item-6", - "/item-5", - "/item-4", - "/item-3", - "/item-2", - ] - assert json.loads(result.stdout)[0]["display_name"] == "Recent Datasette releases" - - -def test_navigation_search_renders_jump_sections_from_javascript_plugins(): - script = ( - textwrap.dedent(""" - const fs = require("fs"); - const vm = require("vm"); - const datasetteManagerJs = __DATASETTE_MANAGER_JS__; - const navigationSearchJs = __NAVIGATION_SEARCH_JS__; - - const documentListeners = {}; - - class FakeElement { - constructor(tagName = "div", parent = null) { - this._innerHTML = ""; - this.value = ""; - this.dataset = {}; - this.open = false; - this.parent = parent; - this.tagName = tagName.toUpperCase(); - } - set textContent(value) { - this.innerHTML = String(value) - .replace(/&/g, "&") - .replace(//g, ">") - .replace(/"/g, """); - } - get innerHTML() { - return this._innerHTML; - } - set innerHTML(value) { - this._innerHTML = String(value); - if (this.parent) { - this.parent._innerHTML += this._innerHTML; - } - } - addEventListener() {} - appendChild(child) { - this._innerHTML += child.innerHTML || ""; - return child; - } - close() { this.open = false; } - focus() {} - querySelector(selector) { - if (selector.startsWith("[data-jump-section-index=")) { - return new FakeElement("div", this); - } - return { scrollIntoView() {} }; - } - showModal() { this.open = true; } - } - - class FakeShadowRoot { - constructor() { - this.innerHTML = ""; - this.dialog = new FakeElement("dialog"); - this.input = new FakeElement("input"); - this.results = new FakeElement("div"); - } - querySelector(selector) { - if (selector == "dialog") return this.dialog; - if (selector == ".search-input") return this.input; - if (selector == ".results-container") return this.results; - return new FakeElement(); - } - } - - global.HTMLElement = class { - constructor() { - this.attributes = {}; - } - attachShadow() { - this.shadowRoot = new FakeShadowRoot(); - return this.shadowRoot; - } - dispatchEvent() {} - getAttribute(name) { - return this.attributes[name] || null; - } - querySelector() { - return null; - } - setAttribute(name, value) { - this.attributes[name] = value; - } - }; - global.CustomEvent = class { - constructor(name, options) { - this.name = name; - this.type = name; - this.detail = options ? options.detail : undefined; - } - }; - global.customElements = { - registry: new Map(), - define(name, cls) { - this.registry.set(name, cls); - }, - }; - global.document = { - addEventListener(name, callback) { - documentListeners[name] = documentListeners[name] || []; - documentListeners[name].push(callback); - }, - activeElement: null, - createElement(tagName) { - return new FakeElement(tagName); - }, - dispatchEvent(event) { - for (const callback of documentListeners[event.type] || []) { - callback(event); - } - }, - querySelectorAll() { - return []; - }, - }; - global.localStorage = { - getItem() { return null; }, - setItem() {}, - removeItem() {}, - }; - global.window = { datasetteVersion: "test", location: { href: "" } }; - - vm.runInThisContext( - fs.readFileSync(datasetteManagerJs, "utf8"), - { filename: "datasette-manager.js" } - ); - for (const callback of documentListeners.DOMContentLoaded || []) { - callback(); - } - window.__DATASETTE__.registerPlugin("agent", { - version: "0.1", - makeJumpSections() { - return [ - { - id: "agent-chat", - render(node, context) { - if (!context.navigationSearch) { - throw new Error("Expected navigationSearch in render context"); - } - node.innerHTML = [ - '
', - '', - '
', - ].join(''); - }, - }, - ]; - }, - }); - - vm.runInThisContext( - fs.readFileSync(navigationSearchJs, "utf8"), - { filename: "navigation-search.js" } - ); - - const Component = customElements.registry.get("navigation-search"); - const element = new Component(); - element.shadowRoot.input.value = ""; - element.renderResults(); - - const html = element.shadowRoot.results.innerHTML; - if (!html.includes("Start a new agent chat")) { - throw new Error(`Missing jump section content: ${html}`); - } - process.stdout.write("ok"); - """) - .replace( - "__DATASETTE_MANAGER_JS__", - json.dumps(str(STATIC_DIR / "datasette-manager.js")), - ) - .replace( - "__NAVIGATION_SEARCH_JS__", - json.dumps(str(STATIC_DIR / "navigation-search.js")), - ) - ) - result = subprocess.run( - ["node", "-e", script], - cwd=REPO_ROOT, - text=True, - capture_output=True, - check=False, - ) - assert result.returncode == 0, result.stderr - assert result.stdout.endswith("ok") diff --git a/tests/test_package.py b/tests/test_package.py deleted file mode 100644 index f05f3ece..00000000 --- a/tests/test_package.py +++ /dev/null @@ -1,59 +0,0 @@ -from click.testing import CliRunner -from datasette import cli -from unittest import mock -import os -import pathlib -import pytest - - -class CaptureDockerfile: - def __call__(self, _): - self.captured = (pathlib.Path() / "Dockerfile").read_text() - - -EXPECTED_DOCKERFILE = """ -FROM python:3.11.0-slim-bullseye -COPY . /app -WORKDIR /app - -ENV DATASETTE_SECRET 'sekrit' -RUN pip install -U datasette -RUN datasette inspect test.db --inspect-file inspect-data.json -ENV PORT {port} -EXPOSE {port} -CMD datasette serve --host 0.0.0.0 -i test.db --cors --inspect-file inspect-data.json --port $PORT -""".strip() - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.cli.call") -def test_package(mock_call, mock_which, tmp_path_factory): - mock_which.return_value = True - runner = CliRunner() - capture = CaptureDockerfile() - mock_call.side_effect = capture - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke(cli.cli, ["package", "test.db", "--secret", "sekrit"]) - assert 0 == result.exit_code - mock_call.assert_has_calls([mock.call(["docker", "build", "."])]) - assert EXPECTED_DOCKERFILE.format(port=8001) == capture.captured - - -@mock.patch("shutil.which") -@mock.patch("datasette.cli.call") -def test_package_with_port(mock_call, mock_which, tmp_path_factory): - mock_which.return_value = True - capture = CaptureDockerfile() - mock_call.side_effect = capture - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke( - cli.cli, ["package", "test.db", "-p", "8080", "--secret", "sekrit"] - ) - assert 0 == result.exit_code - assert EXPECTED_DOCKERFILE.format(port=8080) == capture.captured diff --git a/tests/test_permission_endpoints.py b/tests/test_permission_endpoints.py deleted file mode 100644 index e25be23e..00000000 --- a/tests/test_permission_endpoints.py +++ /dev/null @@ -1,499 +0,0 @@ -""" -Tests for permission endpoints: -- /-/allowed.json -- /-/rules.json -""" - -import pytest -import pytest_asyncio -from datasette.app import Datasette - - -@pytest_asyncio.fixture -async def ds_with_permissions(): - """Create a Datasette instance with test data and permissions.""" - ds = Datasette() - ds.root_enabled = True - await ds.invoke_startup() - - # Add some test databases and tables - db = ds.add_memory_database("analytics") - await db.execute_write( - "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)" - ) - await db.execute_write( - "CREATE TABLE IF NOT EXISTS events (id INTEGER PRIMARY KEY, event_type TEXT, user_id INTEGER)" - ) - - db2 = ds.add_memory_database("production") - await db2.execute_write( - "CREATE TABLE IF NOT EXISTS orders (id INTEGER PRIMARY KEY, total REAL)" - ) - await db2.execute_write( - "CREATE TABLE IF NOT EXISTS customers (id INTEGER PRIMARY KEY, name TEXT)" - ) - - await ds.refresh_schemas() - - return ds - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_status,expected_keys", - [ - # Instance level permission - ( - "/-/allowed.json?action=view-instance", - 200, - {"action", "items", "total", "page"}, - ), - # Database level permission - ( - "/-/allowed.json?action=view-database", - 200, - {"action", "items", "total", "page"}, - ), - # Table level permission - ( - "/-/allowed.json?action=view-table", - 200, - {"action", "items", "total", "page"}, - ), - ( - "/-/allowed.json?action=execute-sql", - 200, - {"action", "items", "total", "page"}, - ), - # Missing action parameter - ("/-/allowed.json", 400, {"error"}), - # Invalid action - ("/-/allowed.json?action=nonexistent", 404, {"error"}), - # Any valid action works, even if no permission rules exist for it - ( - "/-/allowed.json?action=insert-row", - 200, - {"action", "items", "total", "page"}, - ), - ], -) -async def test_allowed_json_basic( - ds_with_permissions, path, expected_status, expected_keys -): - response = await ds_with_permissions.client.get(path) - assert response.status_code == expected_status - data = response.json() - assert expected_keys.issubset(data.keys()) - - -@pytest.mark.asyncio -async def test_allowed_json_response_structure(ds_with_permissions): - """Test that /-/allowed.json returns the expected structure.""" - response = await ds_with_permissions.client.get( - "/-/allowed.json?action=view-instance" - ) - assert response.status_code == 200 - data = response.json() - - # Check required fields - assert "action" in data - assert "actor_id" in data - assert "page" in data - assert "page_size" in data - assert "total" in data - assert "items" in data - - # Check items structure - assert isinstance(data["items"], list) - if data["items"]: - item = data["items"][0] - assert "parent" in item - assert "child" in item - assert "resource" in item - - -@pytest.mark.asyncio -async def test_allowed_json_with_actor(ds_with_permissions): - """Test /-/allowed.json includes actor information.""" - response = await ds_with_permissions.client.get( - "/-/allowed.json?action=view-table", - actor={"id": "test_user"}, - ) - assert response.status_code == 200 - data = response.json() - assert data["actor_id"] == "test_user" - - -@pytest.mark.asyncio -async def test_allowed_json_pagination(): - """Test that /-/allowed.json pagination works.""" - ds = Datasette() - await ds.invoke_startup() - - # Create many tables to test pagination - db = ds.add_memory_database("test") - for i in range(30): - await db.execute_write(f"CREATE TABLE table{i:02d} (id INTEGER PRIMARY KEY)") - await ds.refresh_schemas() - - # Test page 1 - response = await ds.client.get( - "/-/allowed.json?action=view-table&page_size=10&page=1" - ) - assert response.status_code == 200 - data = response.json() - assert data["page"] == 1 - assert data["page_size"] == 10 - assert len(data["items"]) == 10 - - # Test page 2 - response = await ds.client.get( - "/-/allowed.json?action=view-table&page_size=10&page=2" - ) - assert response.status_code == 200 - data = response.json() - assert data["page"] == 2 - assert len(data["items"]) == 10 - - # Verify items are different between pages - response1 = await ds.client.get( - "/-/allowed.json?action=view-table&page_size=10&page=1" - ) - response2 = await ds.client.get( - "/-/allowed.json?action=view-table&page_size=10&page=2" - ) - items1 = {(item["parent"], item["child"]) for item in response1.json()["items"]} - items2 = {(item["parent"], item["child"]) for item in response2.json()["items"]} - assert items1 != items2 - - -@pytest.mark.asyncio -async def test_allowed_json_total_count(tmp_path_factory): - """Test that /-/allowed.json returns correct total count.""" - from datasette.database import Database - - # Use temporary file databases to avoid leakage from other tests - tmp_dir = tmp_path_factory.mktemp("test_allowed_json_total_count") - - ds = Datasette() - await ds.invoke_startup() - - # Create test databases with tables - analytics_db = ds.add_database( - Database(ds, path=str(tmp_dir / "analytics.db")), name="analytics" - ) - await analytics_db.execute_write( - "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)" - ) - await analytics_db.execute_write( - "CREATE TABLE IF NOT EXISTS events (id INTEGER PRIMARY KEY, event_type TEXT, user_id INTEGER)" - ) - - production_db = ds.add_database( - Database(ds, path=str(tmp_dir / "production.db")), name="production" - ) - await production_db.execute_write( - "CREATE TABLE IF NOT EXISTS orders (id INTEGER PRIMARY KEY, total REAL)" - ) - await production_db.execute_write( - "CREATE TABLE IF NOT EXISTS customers (id INTEGER PRIMARY KEY, name TEXT)" - ) - - await ds.refresh_schemas() - - response = await ds.client.get("/-/allowed.json?action=view-table") - assert response.status_code == 200 - data = response.json() - - # We created 4 tables total (2 in analytics, 2 in production) - import json - - assert ( - data["total"] == 4 - ), f"Expected total=4, got: {json.dumps(data, separators=(',', ':'))}" - - -# /-/rules.json tests - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_status,expected_keys", - [ - # Instance level rules - ( - "/-/rules.json?action=view-instance", - 200, - {"action", "items", "total", "page"}, - ), - # Database level rules - ( - "/-/rules.json?action=view-database", - 200, - {"action", "items", "total", "page"}, - ), - # Table level rules - ( - "/-/rules.json?action=view-table", - 200, - {"action", "items", "total", "page"}, - ), - # Missing action parameter - ("/-/rules.json", 400, {"error"}), - # Invalid action - ("/-/rules.json?action=nonexistent", 404, {"error"}), - ], -) -async def test_rules_json_basic( - ds_with_permissions, path, expected_status, expected_keys -): - # Use root actor for rules endpoint (requires permissions-debug) - response = await ds_with_permissions.client.get( - path, - actor={"id": "root"}, - ) - assert response.status_code == expected_status - data = response.json() - assert expected_keys.issubset(data.keys()) - - -@pytest.mark.asyncio -async def test_rules_json_response_structure(ds_with_permissions): - """Test that /-/rules.json returns the expected structure.""" - response = await ds_with_permissions.client.get( - "/-/rules.json?action=view-instance", - actor={"id": "root"}, - ) - assert response.status_code == 200 - data = response.json() - - # Check required fields - assert "action" in data - assert "actor_id" in data - assert "page" in data - assert "page_size" in data - assert "total" in data - assert "items" in data - - # Check items structure - assert isinstance(data["items"], list) - if data["items"]: - item = data["items"][0] - assert "parent" in item - assert "child" in item - assert "resource" in item - assert "allow" in item - assert "reason" in item - - -@pytest.mark.asyncio -async def test_rules_json_includes_all_rules(ds_with_permissions): - """Test that /-/rules.json includes both allowed and denied resources.""" - # Root user should see rules for everything - response = await ds_with_permissions.client.get( - "/-/rules.json?action=view-table", - actor={"id": "root"}, - ) - assert response.status_code == 200 - data = response.json() - - # Should have items (root has global allow) - assert len(data["items"]) > 0 - - # Each item should have allow field (0 or 1) - for item in data["items"]: - assert "allow" in item - assert item["allow"] in [0, 1] - - -@pytest.mark.asyncio -async def test_rules_json_pagination(): - """Test that /-/rules.json pagination works.""" - ds = Datasette() - ds.root_enabled = True - await ds.invoke_startup() - - # Create some tables - db = ds.add_memory_database("test") - for i in range(5): - await db.execute_write( - f"CREATE TABLE IF NOT EXISTS table{i:02d} (id INTEGER PRIMARY KEY)" - ) - await ds.refresh_schemas() - - # Test basic pagination structure - just verify it returns paginated results - response = await ds.client.get( - "/-/rules.json?action=view-table&page_size=2&page=1", - actor={"id": "root"}, - ) - assert response.status_code == 200 - data = response.json() - assert data["page"] == 1 - assert data["page_size"] == 2 - # Verify items is a list (may have fewer items than page_size if there aren't many rules) - assert isinstance(data["items"], list) - assert "total" in data - - -@pytest.mark.asyncio -async def test_rules_json_with_actor(ds_with_permissions): - """Test /-/rules.json includes actor information.""" - # Use root actor (rules endpoint requires permissions-debug) - response = await ds_with_permissions.client.get( - "/-/rules.json?action=view-table", - actor={"id": "root"}, - ) - assert response.status_code == 200 - data = response.json() - assert data["actor_id"] == "root" - - -@pytest.mark.asyncio -async def test_root_user_respects_settings_deny(): - """ - Test for issue #2509: Settings-based deny rules should override root user privileges. - - When a database has `allow: false` in settings, the root user should NOT see - that database in /-/allowed.json?action=view-database. - """ - ds = Datasette( - config={ - "databases": { - "content": { - "allow": False, # Deny everyone, including root - } - } - } - ) - ds.root_enabled = True - await ds.invoke_startup() - ds.add_memory_database("content") - - # Root user should NOT see the denied database - response = await ds.client.get( - "/-/allowed.json?action=view-database", - actor={"id": "root"}, - ) - assert response.status_code == 200 - data = response.json() - - # Check that content database is NOT in the allowed list - allowed_databases = [item["parent"] for item in data["items"]] - assert "content" not in allowed_databases, ( - f"Root user should not see 'content' database when settings deny it, " - f"but found it in: {allowed_databases}" - ) - - -@pytest.mark.asyncio -async def test_root_user_respects_settings_deny_tables(): - """ - Test for issue #2509: Settings-based deny rules should override root for tables too. - - When a database has `allow: false` in settings, the root user should NOT see - tables from that database in /-/allowed.json?action=view-table. - """ - ds = Datasette( - config={ - "databases": { - "content": { - "allow": False, # Deny everyone, including root - } - } - } - ) - ds.root_enabled = True - await ds.invoke_startup() - - # Add a database with a table - db = ds.add_memory_database("content") - await db.execute_write("CREATE TABLE repos (id INTEGER PRIMARY KEY, name TEXT)") - await ds.refresh_schemas() - - # Root user should NOT see tables from the content database - response = await ds.client.get( - "/-/allowed.json?action=view-table", - actor={"id": "root"}, - ) - assert response.status_code == 200 - data = response.json() - - # Check that content.repos table is NOT in the allowed list - content_tables = [ - item["child"] for item in data["items"] if item["parent"] == "content" - ] - assert "repos" not in content_tables, ( - f"Root user should not see tables from 'content' database when settings deny it, " - f"but found: {content_tables}" - ) - - -@pytest.mark.asyncio -async def test_execute_sql_requires_view_database(): - """ - Test for issue #2527: execute-sql permission should require view-database permission. - - A user who has execute-sql permission but not view-database permission should not - be able to execute SQL on that database. - """ - from datasette.permissions import PermissionSQL - from datasette import hookimpl - - class TestPermissionPlugin: - __name__ = "TestPermissionPlugin" - - @hookimpl - def permission_resources_sql(self, datasette, actor, action): - if actor is None or actor.get("id") != "test_user": - return [] - - if action == "execute-sql": - # Grant execute-sql on the "secret" database - return PermissionSQL( - sql="SELECT 'secret' AS parent, NULL AS child, 1 AS allow, 'can execute sql' AS reason", - ) - elif action == "view-database": - # Deny view-database on the "secret" database - return PermissionSQL( - sql="SELECT 'secret' AS parent, NULL AS child, 0 AS allow, 'cannot view db' AS reason", - ) - - return [] - - plugin = TestPermissionPlugin() - - ds = Datasette() - await ds.invoke_startup() - ds.pm.register(plugin, name="test_plugin") - - try: - ds.add_memory_database("secret") - await ds.refresh_schemas() - - # User should NOT have execute-sql permission because view-database is denied - response = await ds.client.get( - "/-/allowed.json?action=execute-sql", - actor={"id": "test_user"}, - ) - assert response.status_code == 200 - data = response.json() - - # The "secret" database should NOT be in the allowed list for execute-sql - allowed_databases = [item["parent"] for item in data["items"]] - assert "secret" not in allowed_databases, ( - f"User should not have execute-sql permission without view-database, " - f"but found 'secret' in: {allowed_databases}" - ) - - # Also verify that attempting to execute SQL on the database is denied - # (may be 403 or 302 redirect to login/error page depending on middleware) - response = await ds.client.get( - "/secret?sql=SELECT+1", - actor={"id": "test_user"}, - ) - assert response.status_code in (302, 403), ( - f"Expected 302 or 403 when trying to execute SQL without view-database permission, " - f"but got {response.status_code}" - ) - finally: - ds.pm.unregister(plugin) diff --git a/tests/test_permissions.py b/tests/test_permissions.py deleted file mode 100644 index e5e75432..00000000 --- a/tests/test_permissions.py +++ /dev/null @@ -1,1778 +0,0 @@ -import collections -from asgiref.sync import async_to_sync -from datasette.app import Datasette -from datasette.cli import cli -from datasette.default_permissions import restrictions_allow_action -from .fixtures import assert_permissions_checked, make_app_client -from click.testing import CliRunner -from bs4 import BeautifulSoup as Soup -import copy -import json -from pprint import pprint -import pytest_asyncio -import pytest -import re -import time -import urllib - - -@pytest.fixture(scope="module") -def padlock_client(): - with make_app_client( - config={ - "databases": { - "fixtures": { - "queries": {"two": {"sql": "select 1 + 1"}}, - } - } - } - ) as client: - yield client - - -@pytest_asyncio.fixture -async def perms_ds(): - ds = Datasette() - await ds.invoke_startup() - one = ds.add_memory_database("perms_ds_one") - two = ds.add_memory_database("perms_ds_two") - await one.execute_write("create table if not exists t1 (id integer primary key)") - await one.execute_write("insert or ignore into t1 (id) values (1)") - await one.execute_write("create view if not exists v1 as select * from t1") - await one.execute_write("create table if not exists t2 (id integer primary key)") - await two.execute_write("create table if not exists t1 (id integer primary key)") - # Trigger catalog refresh so allowed_resources() can be called - await ds.client.get("/") - return ds - - -@pytest.mark.parametrize( - "allow,expected_anon,expected_auth", - [ - (None, 200, 200), - ({}, 403, 403), - ({"id": "root"}, 403, 200), - ], -) -@pytest.mark.parametrize( - "path", - ( - "/", - "/fixtures", - "/-/api", - "/fixtures/compound_three_primary_keys", - "/fixtures/compound_three_primary_keys/a,a,a", - pytest.param( - "/fixtures/two", - marks=pytest.mark.xfail( - reason="view-query not yet migrated to new permission system" - ), - ), # Query - ), -) -def test_view_padlock(allow, expected_anon, expected_auth, path, padlock_client): - padlock_client.ds.config["allow"] = allow - fragment = "🔒

" - anon_response = padlock_client.get(path) - assert expected_anon == anon_response.status - if allow and anon_response.status == 200: - # Should be no padlock - assert fragment not in anon_response.text - auth_response = padlock_client.get( - path, - cookies={"ds_actor": padlock_client.actor_cookie({"id": "root"})}, - ) - assert expected_auth == auth_response.status - # Check for the padlock - if allow and expected_anon == 403 and expected_auth == 200: - assert fragment in auth_response.text - del padlock_client.ds.config["allow"] - - -@pytest.mark.parametrize( - "allow,expected_anon,expected_auth", - [ - (None, 200, 200), - ({}, 403, 403), - ({"id": "root"}, 403, 200), - ], -) -@pytest.mark.parametrize("use_metadata", (True, False)) -def test_view_database(allow, expected_anon, expected_auth, use_metadata): - key = "metadata" if use_metadata else "config" - kwargs = {key: {"databases": {"fixtures": {"allow": allow}}}} - with make_app_client(**kwargs) as client: - for path in ( - "/fixtures", - "/fixtures/compound_three_primary_keys", - "/fixtures/compound_three_primary_keys/a,a,a", - ): - anon_response = client.get(path) - assert expected_anon == anon_response.status, path - if allow and path == "/fixtures" and anon_response.status == 200: - # Should be no padlock - assert ">fixtures 🔒" not in anon_response.text - auth_response = client.get( - path, - cookies={"ds_actor": client.actor_cookie({"id": "root"})}, - ) - assert expected_auth == auth_response.status - if ( - allow - and path == "/fixtures" - and expected_anon == 403 - and expected_auth == 200 - ): - assert ">fixtures 🔒" in auth_response.text - - -def test_database_list_respects_view_database(): - with make_app_client( - config={"databases": {"fixtures": {"allow": {"id": "root"}}}}, - extra_databases={"data.db": "create table names (name text)"}, - ) as client: - anon_response = client.get("/") - assert 'data' in anon_response.text - assert 'fixtures' not in anon_response.text - auth_response = client.get( - "/", - cookies={"ds_actor": client.actor_cookie({"id": "root"})}, - ) - assert 'data' in auth_response.text - assert 'fixtures 🔒' in auth_response.text - - -def test_database_list_respects_view_table(): - with make_app_client( - config={ - "databases": { - "data": { - "tables": { - "names": {"allow": {"id": "root"}}, - "v": {"allow": {"id": "root"}}, - } - } - } - }, - extra_databases={ - "data.db": "create table names (name text); create view v as select * from names" - }, - ) as client: - html_fragments = [ - ">names 🔒", - ">v 🔒", - ] - anon_response_text = client.get("/").text - assert "0 rows in 0 tables" in anon_response_text - for html_fragment in html_fragments: - assert html_fragment not in anon_response_text - auth_response_text = client.get( - "/", - cookies={"ds_actor": client.actor_cookie({"id": "root"})}, - ).text - for html_fragment in html_fragments: - assert html_fragment in auth_response_text - - -@pytest.mark.parametrize( - "allow,expected_anon,expected_auth", - [ - (None, 200, 200), - ({}, 403, 403), - ({"id": "root"}, 403, 200), - ], -) -@pytest.mark.parametrize("use_metadata", (True, False)) -def test_view_table(allow, expected_anon, expected_auth, use_metadata): - key = "metadata" if use_metadata else "config" - kwargs = { - key: { - "databases": { - "fixtures": { - "tables": {"compound_three_primary_keys": {"allow": allow}} - } - } - } - } - with make_app_client(**kwargs) as client: - anon_response = client.get("/fixtures/compound_three_primary_keys") - assert expected_anon == anon_response.status - if allow and anon_response.status == 200: - # Should be no padlock - assert ">compound_three_primary_keys 🔒" not in anon_response.text - auth_response = client.get( - "/fixtures/compound_three_primary_keys", - cookies={"ds_actor": client.actor_cookie({"id": "root"})}, - ) - assert expected_auth == auth_response.status - if allow and expected_anon == 403 and expected_auth == 200: - assert ">compound_three_primary_keys 🔒" in auth_response.text - - -def test_table_list_respects_view_table(): - with make_app_client( - config={ - "databases": { - "fixtures": { - "tables": { - "compound_three_primary_keys": {"allow": {"id": "root"}}, - # And a SQL view too: - "paginated_view": {"allow": {"id": "root"}}, - } - } - } - } - ) as client: - html_fragments = [ - ">compound_three_primary_keys 🔒", - ">paginated_view 🔒", - ] - anon_response = client.get("/fixtures") - for html_fragment in html_fragments: - assert html_fragment not in anon_response.text - auth_response = client.get( - "/fixtures", cookies={"ds_actor": client.actor_cookie({"id": "root"})} - ) - for html_fragment in html_fragments: - assert html_fragment in auth_response.text - - -@pytest.mark.parametrize( - "allow,expected_anon,expected_auth", - [ - (None, 200, 200), - ({}, 403, 403), - ({"id": "root"}, 403, 200), - ], -) -def test_view_query(allow, expected_anon, expected_auth): - with make_app_client( - config={ - "databases": { - "fixtures": {"queries": {"q": {"sql": "select 1 + 1", "allow": allow}}} - } - } - ) as client: - anon_response = client.get("/fixtures/q") - assert expected_anon == anon_response.status - if allow and anon_response.status == 200: - # Should be no padlock - assert "🔒" not in anon_response.text - auth_response = client.get( - "/fixtures/q", cookies={"ds_actor": client.actor_cookie({"id": "root"})} - ) - assert expected_auth == auth_response.status - if allow and expected_anon == 403 and expected_auth == 200: - assert ">fixtures: q 🔒" in auth_response.text - - -@pytest.mark.parametrize( - "config", - [ - {"allow_sql": {"id": "root"}}, - {"databases": {"fixtures": {"allow_sql": {"id": "root"}}}}, - ], -) -def test_execute_sql(config): - schema_re = re.compile("const schema = ({.*?});", re.DOTALL) - with make_app_client(config=config) as client: - form_fragment = '', - '', - '', - ): - assert fragment in response.text - # Should show one failure and one success - soup = Soup(response.text, "html.parser") - table = soup.find("table", {"id": "permission-checks-table"}) - rows = table.find("tbody").find_all("tr") - checks = [] - for row in rows: - cells = row.find_all("td") - result_cell = cells[5] - if result_cell.select_one(".check-result-true"): - result = True - elif result_cell.select_one(".check-result-false"): - result = False - else: - result = None - actor_code = cells[4].find("code") - actor = json.loads(actor_code.text) if actor_code else None - checks.append( - { - "action": cells[1].text.strip(), - "result": result, - "actor": actor, - } - ) - expected_checks = [ - { - "action": "permissions-debug", - "result": True, - "actor": {"id": "root"}, - }, - { - "action": "view-instance", - "result": True, - "actor": {"id": "root"}, - }, - { - "action": "view-instance", - "result": True, - "actor": None, - }, - { - "action": "permissions-debug", - "result": False, - "actor": None, - }, - { - "action": "view-instance", - "result": True, - "actor": None, - }, - ] - if filter_ == "only-yours": - expected_checks = [ - check for check in expected_checks if check["actor"] is not None - ] - elif filter_ == "exclude-yours": - expected_checks = [check for check in expected_checks if check["actor"] is None] - assert checks == expected_checks - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "actor,allow,expected_fragment", - [ - ('{"id":"root"}', "{}", "Result: deny"), - ('{"id":"root"}', '{"id": "*"}', "Result: allow"), - ('{"', '{"id": "*"}', "Actor JSON error"), - ('{"id":"root"}', '"*"}', "Allow JSON error"), - ], -) -async def test_allow_debug(ds_client, actor, allow, expected_fragment): - response = await ds_client.get( - "/-/allow-debug?" + urllib.parse.urlencode({"actor": actor, "allow": allow}) - ) - assert response.status_code == 200 - assert expected_fragment in response.text - - -@pytest.mark.parametrize( - "allow,expected", - [ - ({"id": "root"}, 403), - ({"id": "root", "unauthenticated": True}, 200), - ], -) -def test_allow_unauthenticated(allow, expected): - with make_app_client(config={"allow": allow}) as client: - assert expected == client.get("/").status - - -@pytest.fixture(scope="session") -def view_instance_client(): - with make_app_client(config={"allow": {}}) as client: - yield client - - -@pytest.mark.parametrize( - "path", - [ - "/", - "/fixtures", - "/fixtures/facetable", - "/-/versions", - "/-/plugins", - "/-/settings", - "/-/threads", - "/-/databases", - "/-/permissions", - "/-/messages", - "/-/patterns", - ], -) -def test_view_instance(path, view_instance_client): - assert 403 == view_instance_client.get(path).status - if path not in ("/-/permissions", "/-/messages", "/-/patterns"): - assert 403 == view_instance_client.get(path + ".json").status - - -@pytest.fixture(scope="session") -def cascade_app_client(): - with make_app_client(is_immutable=True) as client: - yield client - - -@pytest.mark.parametrize( - "path,permissions,expected_status", - [ - ("/", [], 403), - ("/", ["instance"], 200), - # Can view table even if not allowed database or instance - ("/fixtures/binary_data", [], 403), - ("/fixtures/binary_data", ["database"], 403), - ("/fixtures/binary_data", ["instance"], 403), - ("/fixtures/binary_data", ["table"], 200), - ("/fixtures/binary_data", ["table", "database"], 200), - ("/fixtures/binary_data", ["table", "database", "instance"], 200), - # ... same for row - ("/fixtures/binary_data/1", [], 403), - ("/fixtures/binary_data/1", ["database"], 403), - ("/fixtures/binary_data/1", ["instance"], 403), - ("/fixtures/binary_data/1", ["table"], 200), - ("/fixtures/binary_data/1", ["table", "database"], 200), - ("/fixtures/binary_data/1", ["table", "database", "instance"], 200), - # Can view query even if not allowed database or instance - ("/fixtures/magic_parameters", [], 403), - ("/fixtures/magic_parameters", ["database"], 403), - ("/fixtures/magic_parameters", ["instance"], 403), - ("/fixtures/magic_parameters", ["query"], 200), - ("/fixtures/magic_parameters", ["query", "database"], 200), - ("/fixtures/magic_parameters", ["query", "database", "instance"], 200), - # Can view database even if not allowed instance - ("/fixtures", [], 403), - ("/fixtures", ["instance"], 403), - ("/fixtures", ["database"], 200), - # Downloading the fixtures.db file - ("/fixtures.db", [], 403), - ("/fixtures.db", ["instance"], 403), - ("/fixtures.db", ["database"], 200), - ("/fixtures.db", ["download"], 200), - ], -) -def test_permissions_cascade(cascade_app_client, path, permissions, expected_status): - """Test that e.g. having view-table but NOT view-database lets you view table page, etc""" - allow = {"id": "*"} - deny = {} - previous_config = cascade_app_client.ds.config - updated_config = copy.deepcopy(previous_config) - actor = {"id": "test"} - if "download" in permissions: - actor["can_download"] = 1 - try: - # Set up the different allow blocks - updated_config["allow"] = allow if "instance" in permissions else deny - # Note: download permission also needs database access (via plugin granting both) - # so we don't set a deny rule when download is in permissions - updated_config["databases"]["fixtures"]["allow"] = ( - allow if ("database" in permissions or "download" in permissions) else deny - ) - updated_config["databases"]["fixtures"]["tables"]["binary_data"] = { - "allow": (allow if "table" in permissions else deny) - } - updated_config["databases"]["fixtures"]["queries"]["magic_parameters"][ - "allow" - ] = (allow if "query" in permissions else deny) - cascade_app_client.ds.config = updated_config - response = cascade_app_client.get( - path, - cookies={"ds_actor": cascade_app_client.actor_cookie(actor)}, - ) - assert ( - response.status == expected_status - ), "path: {}, permissions: {}, expected_status: {}, status: {}".format( - path, permissions, expected_status, response.status - ) - finally: - cascade_app_client.ds.config = previous_config - - -def test_padlocks_on_database_page(cascade_app_client): - config = { - "databases": { - "fixtures": { - "allow": {"id": "test"}, - "tables": { - "123_starts_with_digits": {"allow": True}, - "simple_view": {"allow": True}, - }, - "queries": {"query_two": {"allow": True, "sql": "select 2"}}, - } - } - } - previous_config = cascade_app_client.ds.config - try: - cascade_app_client.ds.config = config - async_to_sync(cascade_app_client.ds.invoke_startup)() - async_to_sync(cascade_app_client.ds.add_query)( - "fixtures", "query_two", "select 2", source="config" - ) - response = cascade_app_client.get( - "/fixtures", - cookies={"ds_actor": cascade_app_client.actor_cookie({"id": "test"})}, - ) - # Tables - assert ">123_starts_with_digits" in response.text - assert ">Table With Space In Name 🔒" in response.text - # Queries - assert ">query_two" in response.text - # Views - assert ">paginated_view 🔒" in response.text - assert ">simple_view" in response.text - finally: - cascade_app_client.ds.config = previous_config - async_to_sync(cascade_app_client.ds.remove_query)("fixtures", "query_two") - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "actor,permission,resource_1,resource_2,expected_result", - ( - # Without restrictions the defaults apply - ({"id": "t"}, "view-instance", None, None, True), - ({"id": "t"}, "view-database", "one", None, True), - ({"id": "t"}, "view-table", "one", "t1", True), - # If there is an _r block, everything gets denied unless explicitly allowed - ({"id": "t", "_r": {}}, "view-instance", None, None, False), - ({"id": "t", "_r": {}}, "view-database", "one", None, False), - ({"id": "t", "_r": {}}, "view-table", "one", "t1", False), - # Explicit allowing works at the "a" for all level: - ({"id": "t", "_r": {"a": ["vi"]}}, "view-instance", None, None, True), - ({"id": "t", "_r": {"a": ["vd"]}}, "view-database", "one", None, True), - ({"id": "t", "_r": {"a": ["vt"]}}, "view-table", "one", "t1", True), - # But not if it's the wrong permission - ({"id": "t", "_r": {"a": ["vi"]}}, "view-database", "one", None, False), - ({"id": "t", "_r": {"a": ["vd"]}}, "view-table", "one", "t1", False), - # Works at the "d" for database level: - ({"id": "t", "_r": {"d": {"one": ["vd"]}}}, "view-database", "one", None, True), - ( - # view-database-download requires view-database too (also_requires) - {"id": "t", "_r": {"d": {"one": ["vdd", "vd"]}}}, - "view-database-download", - "one", - None, - True, - ), - ( - # execute-sql requires view-database too (also_requires) - {"id": "t", "_r": {"d": {"one": ["es", "vd"]}}}, - "execute-sql", - "one", - None, - True, - ), - # Works at the "r" for table level: - ( - {"id": "t", "_r": {"r": {"one": {"t1": ["vt"]}}}}, - "view-table", - "one", - "t1", - True, - ), - ( - {"id": "t", "_r": {"r": {"one": {"t1": ["vt"]}}}}, - "view-table", - "one", - "t2", - False, - ), - # non-abbreviations should work too - ( - {"id": "t", "_r": {"a": ["view-instance"]}}, - "view-instance", - None, - None, - True, - ), - ( - {"id": "t", "_r": {"d": {"one": ["view-database"]}}}, - "view-database", - "one", - None, - True, - ), - ( - {"id": "t", "_r": {"r": {"one": {"t1": ["view-table"]}}}}, - "view-table", - "one", - "t1", - True, - ), - # view-database does NOT grant view-instance (no upward cascading) - ({"id": "t", "_r": {"a": ["vd"]}}, "view-instance", None, None, False), - ), -) -async def test_actor_restricted_permissions( - perms_ds, actor, permission, resource_1, resource_2, expected_result -): - perms_ds.pdb = True - perms_ds.root_enabled = True # Allow root actor to access /-/permissions - cookies = {"ds_actor": perms_ds.sign({"a": {"id": "root"}}, "actor")} - response = await perms_ds.client.post( - "/-/permissions", - data={ - "actor": json.dumps(actor), - "permission": permission, - "resource_1": resource_1, - "resource_2": resource_2, - }, - cookies=cookies, - ) - # Response mirrors /-/check JSON structure - if resource_1 is None: - expected_path = "/" - elif resource_2 is None: - expected_path = f"/{resource_1}" - else: - expected_path = f"/{resource_1}/{resource_2}" - - expected_resource = { - "parent": resource_1, - "child": resource_2, - "path": expected_path, - } - expected = { - "action": permission, - "allowed": expected_result, - "resource": expected_resource, - } - if actor.get("id"): - expected["actor_id"] = actor["id"] - assert response.json() == expected - - -PermConfigTestCase = collections.namedtuple( - "PermConfigTestCase", - "config,actor,action,resource,expected_result", -) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "config,actor,action,resource,expected_result", - ( - # Simple view-instance default=True example - PermConfigTestCase( - config={}, - actor=None, - action="view-instance", - resource=None, - expected_result=True, - ), - # debug-menu on root - PermConfigTestCase( - config={"permissions": {"debug-menu": {"id": "user"}}}, - actor={"id": "user"}, - action="debug-menu", - resource=None, - expected_result=True, - ), - # debug-menu on root, wrong actor - PermConfigTestCase( - config={"permissions": {"debug-menu": {"id": "user"}}}, - actor={"id": "user2"}, - action="debug-menu", - resource=None, - expected_result=False, - ), - # create-table on root - PermConfigTestCase( - config={"permissions": {"create-table": {"id": "user"}}}, - actor={"id": "user"}, - action="create-table", - resource=None, - expected_result=True, - ), - # create-table on database - no resource specified - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": {"permissions": {"create-table": {"id": "user"}}} - } - }, - actor={"id": "user"}, - action="create-table", - resource=None, - expected_result=False, - ), - # create-table on database - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": {"permissions": {"create-table": {"id": "user"}}} - } - }, - actor={"id": "user"}, - action="create-table", - resource="perms_ds_one", - expected_result=True, - ), - # insert-row on root, wrong actor - PermConfigTestCase( - config={"permissions": {"insert-row": {"id": "user"}}}, - actor={"id": "user2"}, - action="insert-row", - resource=("perms_ds_one", "t1"), - expected_result=False, - ), - # insert-row on root, right actor - PermConfigTestCase( - config={"permissions": {"insert-row": {"id": "user"}}}, - actor={"id": "user"}, - action="insert-row", - resource=("perms_ds_one", "t1"), - expected_result=True, - ), - # set-column-type on specific table - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": { - "tables": { - "t1": {"permissions": {"set-column-type": {"id": "user"}}} - } - } - } - }, - actor={"id": "user"}, - action="set-column-type", - resource=("perms_ds_one", "t1"), - expected_result=True, - ), - # insert-row on database - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": {"permissions": {"insert-row": {"id": "user"}}} - } - }, - actor={"id": "user"}, - action="insert-row", - resource="perms_ds_one", - expected_result=True, - ), - # insert-row on table, wrong table - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": { - "tables": { - "t1": {"permissions": {"insert-row": {"id": "user"}}} - } - } - } - }, - actor={"id": "user"}, - action="insert-row", - resource=("perms_ds_one", "t2"), - expected_result=False, - ), - # insert-row on table, right table - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": { - "tables": { - "t1": {"permissions": {"insert-row": {"id": "user"}}} - } - } - } - }, - actor={"id": "user"}, - action="insert-row", - resource=("perms_ds_one", "t1"), - expected_result=True, - ), - # view-query on stored query, wrong actor - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": { - "queries": { - "q1": { - "sql": "select 1 + 1", - "permissions": {"view-query": {"id": "user"}}, - } - } - } - } - }, - actor={"id": "user2"}, - action="view-query", - resource=("perms_ds_one", "q1"), - expected_result=False, - ), - # view-query on stored query, right actor - PermConfigTestCase( - config={ - "databases": { - "perms_ds_one": { - "queries": { - "q1": { - "sql": "select 1 + 1", - "permissions": {"view-query": {"id": "user"}}, - } - } - } - } - }, - actor={"id": "user"}, - action="view-query", - resource=("perms_ds_one", "q1"), - expected_result=True, - ), - ), -) -async def test_permissions_in_config( - perms_ds, config, actor, action, resource, expected_result -): - previous_config = perms_ds.config - updated_config = copy.deepcopy(previous_config) - updated_config.update(config) - perms_ds.config = updated_config - await perms_ds._save_queries_from_config() - try: - # Convert old-style resource to Resource object - from datasette.resources import DatabaseResource, QueryResource, TableResource - - resource_obj = None - if resource: - if isinstance(resource, str): - resource_obj = DatabaseResource(database=resource) - elif isinstance(resource, tuple) and len(resource) == 2: - if action == "view-query": - resource_obj = QueryResource( - database=resource[0], query=resource[1] - ) - else: - resource_obj = TableResource( - database=resource[0], table=resource[1] - ) - - result = await perms_ds.allowed( - action=action, resource=resource_obj, actor=actor - ) - if result != expected_result: - pprint(perms_ds._permission_checks) - assert result == expected_result - finally: - perms_ds.config = previous_config - await perms_ds._save_queries_from_config() - - -@pytest.mark.asyncio -async def test_allowed_resources_view_query_includes_actor_specific_query_permissions(): - from datasette import hookimpl - from datasette.permissions import PermissionSQL - from datasette.resources import QueryResource - - class ActorSpecificQueryPermissionPlugin: - __name__ = "ActorSpecificQueryPermissionPlugin" - - @hookimpl - def permission_resources_sql(self, datasette, actor, action): - if action == "view-query" and actor and actor.get("id") == "alice": - return PermissionSQL(sql=""" - SELECT 'testdb' AS parent, 'user_only' AS child, 1 AS allow, - 'alice can view this query' AS reason - """) - return None - - ds = Datasette(default_deny=True) - await ds.invoke_startup() - ds.add_memory_database("testdb") - await ds._refresh_schemas() - await ds.add_query("testdb", "user_only", "select 1 as n") - - plugin = ActorSpecificQueryPermissionPlugin() - ds.pm.register(plugin, name="actor_specific_query_permission_plugin") - - try: - actor = {"id": "alice"} - - assert await ds.allowed( - action="view-query", - resource=QueryResource("testdb", "user_only"), - actor=actor, - ) - - page = await ds.allowed_resources("view-query", actor) - assert [(resource.parent, resource.child) for resource in page.resources] == [ - ("testdb", "user_only") - ] - finally: - ds.pm.unregister(name="actor_specific_query_permission_plugin") - - -@pytest.mark.asyncio -async def test_actor_endpoint_allows_any_token(): - ds = Datasette() - token = ds.sign( - { - "a": "root", - "token": "dstok", - "t": int(time.time()), - "_r": {"a": ["debug-menu"]}, - }, - namespace="token", - ) - response = await ds.client.get( - "/-/actor.json", headers={"Authorization": f"Bearer dstok_{token}"} - ) - assert response.status_code == 200 - assert response.json()["actor"] == { - "id": "root", - "token": "dstok", - "_r": {"a": ["debug-menu"]}, - } - - -@pytest.mark.serial -@pytest.mark.parametrize( - "options,expected", - ( - ([], {"id": "root", "token": "dstok"}), - ( - ["--all", "debug-menu"], - {"_r": {"a": ["dm"]}, "id": "root", "token": "dstok"}, - ), - ( - ["-a", "debug-menu", "--all", "create-table"], - {"_r": {"a": ["dm", "ct"]}, "id": "root", "token": "dstok"}, - ), - ( - ["-r", "db1", "t1", "insert-row"], - {"_r": {"r": {"db1": {"t1": ["ir"]}}}, "id": "root", "token": "dstok"}, - ), - ( - ["-d", "db1", "create-table"], - {"_r": {"d": {"db1": ["ct"]}}, "id": "root", "token": "dstok"}, - ), - # And one with all of them multiple times using all the names - ( - [ - "-a", - "debug-menu", - "--all", - "create-table", - "-r", - "db1", - "t1", - "insert-row", - "--resource", - "db1", - "t2", - "update-row", - "-d", - "db1", - "create-table", - "--database", - "db2", - "drop-table", - ], - { - "_r": { - "a": ["dm", "ct"], - "d": {"db1": ["ct"], "db2": ["dt"]}, - "r": {"db1": {"t1": ["ir"], "t2": ["ur"]}}, - }, - "id": "root", - "token": "dstok", - }, - ), - ), -) -def test_cli_create_token(options, expected): - runner = CliRunner() - result1 = runner.invoke( - cli, - [ - "create-token", - "--secret", - "sekrit", - "root", - ] - + options, - ) - token = result1.output.strip() - result2 = runner.invoke( - cli, - [ - "serve", - "--secret", - "sekrit", - "--get", - "/-/actor.json", - "--token", - token, - ], - ) - assert 0 == result2.exit_code, result2.output - assert json.loads(result2.output) == {"actor": expected} - - -_visible_tables_re = re.compile(r">\/((\w+)\/(\w+))\.json<\/a> - Get rows for") - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "is_logged_in,config,expected_visible_tables", - ( - # Unprotected instance logged out user sees everything: - ( - False, - None, - ["perms_ds_one/t1", "perms_ds_one/t2", "perms_ds_two/t1"], - ), - # Fully protected instance logged out user sees nothing - (False, {"allow": {"id": "user"}}, None), - # User with visibility of just perms_ds_one sees both tables there - ( - True, - { - "databases": { - "perms_ds_one": {"allow": {"id": "user"}}, - "perms_ds_two": {"allow": False}, - } - }, - ["perms_ds_one/t1", "perms_ds_one/t2"], - ), - # User with visibility of only table perms_ds_one/t1 sees just that one - ( - True, - { - "databases": { - "perms_ds_one": { - "allow": {"id": "user"}, - "tables": {"t2": {"allow": False}}, - }, - "perms_ds_two": {"allow": False}, - } - }, - ["perms_ds_one/t1"], - ), - ), -) -async def test_api_explorer_visibility( - perms_ds, is_logged_in, config, expected_visible_tables -): - try: - prev_config = perms_ds.config - perms_ds.config = config or {} - kwargs = {} - if is_logged_in: - kwargs["actor"] = {"id": "user"} - response = await perms_ds.client.get("/-/api", **kwargs) - if expected_visible_tables: - assert response.status_code == 200 - # Search HTML for stuff matching: - # '>/perms_ds_one/t2.json - Get rows for' - visible_tables = [ - match[0] for match in _visible_tables_re.findall(response.text) - ] - assert visible_tables == expected_visible_tables - else: - assert response.status_code == 403 - finally: - perms_ds.config = prev_config - - -@pytest.mark.asyncio -async def test_view_table_token_cannot_gain_access_without_base_permission(perms_ds): - # Only allow a different actor to view this table - previous_config = perms_ds.config - perms_ds.config = { - "databases": { - "perms_ds_two": { - # Only someone-else can see anything in this database - "allow": {"id": "someone-else"}, - } - } - } - try: - actor = { - "id": "restricted-token", - "token": "dstok", - # Restricted token claims access to perms_ds_two/t1 only - "_r": {"r": {"perms_ds_two": {"t1": ["vt"]}}}, - } - response = await perms_ds.client.get("/perms_ds_two/t1.json", actor=actor) - assert response.status_code == 403 - finally: - perms_ds.config = previous_config - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "restrictions,verb,path,body,expected_status", - ( - # No restrictions - (None, "get", "/.json", None, 200), - (None, "get", "/perms_ds_one.json", None, 200), - (None, "get", "/perms_ds_one/t1.json", None, 200), - (None, "get", "/perms_ds_one/t1/1.json", None, 200), - (None, "get", "/perms_ds_one/v1.json", None, 200), - # Restricted to just view-instance - ({"a": ["vi"]}, "get", "/.json", None, 200), - ({"a": ["vi"]}, "get", "/perms_ds_one.json", None, 403), - ({"a": ["vi"]}, "get", "/perms_ds_one/t1.json", None, 403), - ({"a": ["vi"]}, "get", "/perms_ds_one/t1/1.json", None, 403), - ({"a": ["vi"]}, "get", "/perms_ds_one/v1.json", None, 403), - # Restricted to just view-database - ( - {"a": ["vd"]}, - "get", - "/.json", - None, - 403, - ), # Cannot see instance (no upward cascading) - ({"a": ["vd"]}, "get", "/perms_ds_one.json", None, 200), - ({"a": ["vd"]}, "get", "/perms_ds_one/t1.json", None, 403), - ({"a": ["vd"]}, "get", "/perms_ds_one/t1/1.json", None, 403), - ({"a": ["vd"]}, "get", "/perms_ds_one/v1.json", None, 403), - # Restricted to just view-table for specific database - ( - {"d": {"perms_ds_one": ["vt"]}}, - "get", - "/.json", - None, - 403, - ), # Cannot see instance (no upward cascading) - ( - {"d": {"perms_ds_one": ["vt"]}}, - "get", - "/perms_ds_one.json", - None, - 403, - ), # Cannot see database page (no upward cascading) - ( - {"d": {"perms_ds_one": ["vt"]}}, - "get", - "/perms_ds_two.json", - None, - 403, - ), # But not this one - ( - # Can see the table - {"d": {"perms_ds_one": ["vt"]}}, - "get", - "/perms_ds_one/t1.json", - None, - 200, - ), - ( - # And the view - {"d": {"perms_ds_one": ["vt"]}}, - "get", - "/perms_ds_one/v1.json", - None, - 200, - ), - # view-table access to a specific table - ( - {"r": {"perms_ds_one": {"t1": ["vt"]}}}, - "get", - "/.json", - None, - 403, - ), # Cannot see instance (no upward cascading) - ( - {"r": {"perms_ds_one": {"t1": ["vt"]}}}, - "get", - "/perms_ds_one.json", - None, - 403, - ), # Cannot see database page (no upward cascading) - ( - {"r": {"perms_ds_one": {"t1": ["vt"]}}}, - "get", - "/perms_ds_one/t1.json", - None, - 200, - ), - # But cannot see the other table - ( - {"r": {"perms_ds_one": {"t1": ["vt"]}}}, - "get", - "/perms_ds_one/t2.json", - None, - 403, - ), - # Or the view - ( - {"r": {"perms_ds_one": {"t1": ["vt"]}}}, - "get", - "/perms_ds_one/v1.json", - None, - 403, - ), - ), -) -async def test_actor_restrictions( - perms_ds, restrictions, verb, path, body, expected_status -): - actor = {"id": "user"} - if restrictions: - actor["_r"] = restrictions - method = getattr(perms_ds.client, verb) - kwargs = {"actor": actor} - if body: - kwargs["json"] = body - perms_ds._permission_checks.clear() - response = await method(path, **kwargs) - assert response.status_code == expected_status, json.dumps( - { - "verb": verb, - "path": path, - "body": body, - "restrictions": restrictions, - "expected_status": expected_status, - "response_status": response.status_code, - "checks": [ - { - "action": check.action, - "parent": check.parent, - "child": check.child, - "result": check.result, - } - for check in perms_ds._permission_checks - ], - }, - indent=2, - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "restrictions,action,resource,expected", - ( - # Exact match: view-instance restriction allows view-instance action - ({"a": ["view-instance"]}, "view-instance", None, True), - # No implication: view-table does NOT imply view-instance - ({"a": ["view-table"]}, "view-instance", None, False), - ({"a": ["view-database"]}, "view-instance", None, False), - # update-row does not imply view-instance - ({"a": ["update-row"]}, "view-instance", None, False), - # view-table on a resource does NOT imply view-instance - ({"r": {"db1": {"t1": ["view-table"]}}}, "view-instance", None, False), - # execute-sql on a database does NOT imply view-instance or view-database - ({"d": {"db1": ["es"]}}, "view-instance", None, False), - ({"d": {"db1": ["es"]}}, "view-database", "db1", False), - ({"d": {"db1": ["es"]}}, "view-database", "db2", False), - # But execute-sql abbreviation DOES allow execute-sql action on that database - ({"d": {"db1": ["es"]}}, "execute-sql", "db1", True), - # update-row on a resource does not imply view-instance - ({"r": {"db1": {"t1": ["update-row"]}}}, "view-instance", None, False), - # view-database on a database does NOT imply view-instance - ({"d": {"db1": ["view-database"]}}, "view-instance", None, False), - # But it DOES allow view-database on that specific database - ({"d": {"db1": ["view-database"]}}, "view-database", "db1", True), - # Having view-table on "a" allows access to any specific table - ({"a": ["view-table"]}, "view-table", ("dbname", "tablename"), True), - # Having view-table on a database allows access to tables in that database - ( - {"d": {"dbname": ["view-table"]}}, - "view-table", - ("dbname", "tablename"), - True, - ), - # But not if it's allowed on a different database - ( - {"d": {"dbname": ["view-table"]}}, - "view-table", - ("dbname2", "tablename"), - False, - ), - # Table-level restriction allows access to that specific table - ( - {"r": {"dbname": {"tablename": ["view-table"]}}}, - "view-table", - ("dbname", "tablename"), - True, - ), - # But not to a different table in the same database - ( - {"r": {"dbname": {"tablename": ["view-table"]}}}, - "view-table", - ("dbname", "other_table"), - False, - ), - ), -) -async def test_restrictions_allow_action(restrictions, action, resource, expected): - ds = Datasette() - await ds.invoke_startup() - actual = restrictions_allow_action(ds, restrictions, action, resource) - assert actual == expected - - -@pytest.mark.asyncio -async def test_actor_restrictions_filters_allowed_resources(perms_ds): - """Test that allowed_resources() respects actor restrictions - issue #2534""" - - # Actor restricted to just perms_ds_one/t1 - actor = {"id": "user", "_r": {"r": {"perms_ds_one": {"t1": ["vt"]}}}} - - # Should only return t1 - page = await perms_ds.allowed_resources("view-table", actor) - assert len(page.resources) == 1 - assert page.resources[0].parent == "perms_ds_one" - assert page.resources[0].child == "t1" - - # Database listing should be empty (no view-database permission) - db_page = await perms_ds.allowed_resources("view-database", actor) - assert len(db_page.resources) == 0 - - -@pytest.mark.asyncio -async def test_actor_restrictions_do_not_expand_allowed_resources(perms_ds): - """Restrictions cannot grant access not already allowed to the actor.""" - - previous_config = perms_ds.config - perms_ds.config = { - "databases": { - "perms_ds_one": { - "allow": {"id": "someone-else"}, - } - } - } - try: - actor = {"id": "user", "_r": {"r": {"perms_ds_one": {"t1": ["vt"]}}}} - - # Base actor is not allowed to see t1, so restrictions should not change that - page = await perms_ds.allowed_resources("view-table", actor) - assert len(page.resources) == 0 - - # And explicit permission checks should still deny - response = await perms_ds.client.get( - "/perms_ds_one/t1.json", - actor=actor, - ) - assert response.status_code == 403 - finally: - perms_ds.config = previous_config - - -@pytest.mark.asyncio -async def test_actor_restrictions_database_level(perms_ds): - """Test database-level restrictions allow all tables in database - issue #2534""" - - actor = {"id": "user", "_r": {"d": {"perms_ds_one": ["vt"]}}} - - page = await perms_ds.allowed_resources("view-table", actor, parent="perms_ds_one") - - # Should return all tables in perms_ds_one - table_names = {r.child for r in page.resources} - assert "t1" in table_names - assert "t2" in table_names - assert "v1" in table_names # views too - - -@pytest.mark.asyncio -async def test_actor_restrictions_global_level(perms_ds): - """Test global-level restrictions allow all resources - issue #2534""" - - actor = {"id": "user", "_r": {"a": ["vt"]}} - - page = await perms_ds.allowed_resources("view-table", actor) - - # Should return all tables in all databases - assert len(page.resources) > 0 - dbs = {r.parent for r in page.resources} - assert "perms_ds_one" in dbs - assert "perms_ds_two" in dbs - - -@pytest.mark.asyncio -async def test_restrictions_gate_before_config(perms_ds): - """Test that restrictions act as gating filter before config permissions - issue #2534""" - from datasette.resources import TableResource - - # Actor restricted to just t1 (not t2) - actor = {"id": "user", "_r": {"r": {"perms_ds_one": {"t1": ["vt"]}}}} - - # Config doesn't matter - restrictions gate what's checked - # t2 is not in restriction allowlist, so should be DENIED - result = await perms_ds.allowed( - action="view-table", - resource=TableResource("perms_ds_one", "t2"), - actor=actor, - ) - assert result is False - - # t1 is in restrictions AND passes normal permission check - should be ALLOWED - result = await perms_ds.allowed( - action="view-table", - resource=TableResource("perms_ds_one", "t1"), - actor=actor, - ) - assert result is True - - -@pytest.mark.asyncio -async def test_actor_restrictions_json_endpoints_show_filtered_listings(perms_ds): - """Test that /.json and /db.json show correct filtered listings - issue #2534""" - - actor = {"id": "user", "_r": {"r": {"perms_ds_one": {"t1": ["vt"]}}}} - - # /.json should be 403 (no view-instance permission) - response = await perms_ds.client.get("/.json", actor=actor) - assert response.status_code == 403 - - # /perms_ds_one.json should be 403 (no view-database permission) - response = await perms_ds.client.get("/perms_ds_one.json", actor=actor) - assert response.status_code == 403 - - # /perms_ds_one/t1.json should be 200 - response = await perms_ds.client.get("/perms_ds_one/t1.json", actor=actor) - assert response.status_code == 200 - - -@pytest.mark.asyncio -async def test_actor_restrictions_view_instance_only(perms_ds): - """Test actor restricted to view-instance only - issue #2534""" - - actor = {"id": "user", "_r": {"a": ["vi"]}} - - # /.json should be 200 (has view-instance permission) - response = await perms_ds.client.get("/.json", actor=actor) - assert response.status_code == 200 - - # But no databases should be visible (no view-database permission) - # The instance is visible but databases list should be empty or minimal - # Actually, let's check via allowed_resources - page = await perms_ds.allowed_resources("view-database", actor) - assert len(page.resources) == 0 - - -@pytest.mark.asyncio -async def test_actor_restrictions_empty_allowlist(perms_ds): - """Test actor with empty restrictions allowlist denies everything - issue #2534""" - - actor = {"id": "user", "_r": {}} - - # No actions in allowlist, so everything should be denied - page1 = await perms_ds.allowed_resources("view-table", actor) - assert len(page1.resources) == 0 - - page2 = await perms_ds.allowed_resources("view-database", actor) - assert len(page2.resources) == 0 - - result = await perms_ds.allowed(action="view-instance", actor=actor) - assert result is False - - -@pytest.mark.asyncio -async def test_actor_restrictions_cannot_be_overridden_by_config(): - """Test that config permissions cannot override actor restrictions - issue #2534""" - from datasette.app import Datasette - from datasette.resources import TableResource - - # Create datasette with config that allows user to access both t1 AND t2 - config = { - "databases": { - "test_db": { - "tables": { - "t1": {"allow": {"id": "user"}}, - "t2": {"allow": {"id": "user"}}, - } - } - } - } - - ds = Datasette(config=config) - await ds.invoke_startup() - db = ds.add_memory_database("test_db") - await db.execute_write("create table t1 (id integer primary key)") - await db.execute_write("create table t2 (id integer primary key)") - - # Actor restricted to ONLY t1 (not t2) - # Even though config allows t2, restrictions should deny it - actor = {"id": "user", "_r": {"r": {"test_db": {"t1": ["vt"]}}}} - - # t1 should be allowed (in restrictions AND config allows) - result = await ds.allowed( - action="view-table", resource=TableResource("test_db", "t1"), actor=actor - ) - assert result is True, "t1 should be allowed - in restriction allowlist" - - # t2 should be DENIED (not in restrictions, even though config allows) - result = await ds.allowed( - action="view-table", resource=TableResource("test_db", "t2"), actor=actor - ) - assert ( - result is False - ), "t2 should be denied - NOT in restriction allowlist, config cannot override" - - -@pytest.mark.asyncio -async def test_actor_restrictions_with_database_level_config(perms_ds): - """Test database-level restrictions with table-level config - issue #2534""" - from datasette.resources import TableResource - - # Config allows specific tables only - perms_ds._config = { - "databases": { - "perms_ds_one": { - "tables": { - "t1": {"allow": {"id": "user"}}, - "t2": {"allow": {"id": "user"}}, - } - } - } - } - - # Actor has database-level restriction (all tables in perms_ds_one) - # Should only access tables that pass BOTH restrictions AND config - actor = {"id": "user", "_r": {"d": {"perms_ds_one": ["vt"]}}} - - # t1 - in restrictions (all tables) AND config allows - result = await perms_ds.allowed( - action="view-table", resource=TableResource("perms_ds_one", "t1"), actor=actor - ) - assert result is True - - # t2 - in restrictions (all tables) AND config allows - result = await perms_ds.allowed( - action="view-table", resource=TableResource("perms_ds_one", "t2"), actor=actor - ) - assert result is True - - # v1 (view) - in restrictions (all tables) AND config doesn't mention it - # Since actor has database-level restriction allowing all tables, v1 is allowed - # Config is additive, not restrictive - it doesn't create implicit denies - result = await perms_ds.allowed( - action="view-table", resource=TableResource("perms_ds_one", "v1"), actor=actor - ) - assert result is True, "v1 should be allowed - actor has db-level restriction" - - # Clean up - perms_ds._config = None - - -@pytest.mark.asyncio -async def test_actor_restrictions_parent_deny_blocks_config_child_allow(perms_ds): - """ - Test that table-level restrictions add parent-level deny to block - other tables in the same database, even if config allows them - """ - from datasette.resources import TableResource - - # Config allows both t1 and t2 - perms_ds._config = { - "databases": { - "perms_ds_one": { - "tables": { - "t1": {"allow": {"id": "user"}}, - "t2": {"allow": {"id": "user"}}, - } - } - } - } - - # Restriction allows ONLY t1 in perms_ds_one - # This should add: - # - parent-level DENY for perms_ds_one (to block other tables) - # - child-level ALLOW for t1 - actor = {"id": "user", "_r": {"r": {"perms_ds_one": {"t1": ["vt"]}}}} - - # t1 should work (child-level allow beats parent-level deny) - result = await perms_ds.allowed( - action="view-table", resource=TableResource("perms_ds_one", "t1"), actor=actor - ) - assert result is True - - # t2 should be DENIED by parent-level deny from restrictions - # even though config has child-level allow - # Because restrictions should run first - result = await perms_ds.allowed( - action="view-table", resource=TableResource("perms_ds_one", "t2"), actor=actor - ) - assert ( - result is False - ), "t2 should be denied - restriction parent deny should beat config child allow" - - # Clean up - perms_ds._config = None - - -@pytest.mark.asyncio -async def test_permission_check_view_requires_debug_permission(): - """Test that /-/check requires permissions-debug permission""" - # Anonymous user should be denied - ds = Datasette() - response = await ds.client.get("/-/check.json?action=view-instance") - assert response.status_code == 403 - assert "permissions-debug" in response.text - - # User without permissions-debug should be denied - response = await ds.client.get( - "/-/check.json?action=view-instance", - cookies={"ds_actor": ds.sign({"id": "user"}, "actor")}, - ) - assert response.status_code == 403 - - # Root user should have access (root has all permissions) - ds_with_root = Datasette() - ds_with_root.root_enabled = True - root_token = await ds_with_root.create_token("root", handler="signed") - response = await ds_with_root.client.get( - "/-/check.json?action=view-instance", - headers={"Authorization": f"Bearer {root_token}"}, - ) - assert response.status_code == 200 - data = response.json() - assert data["action"] == "view-instance" - assert data["allowed"] is True - - -@pytest.mark.asyncio -async def test_root_allow_block_with_table_restricted_actor(): - """ - Test that root-level allow: blocks are processed for actors with - table-level restrictions. - - This covers the case in config.py is_in_restriction_allowlist() where - parent=None, child=None and actor has table restrictions but not global. - """ - from datasette.resources import TableResource - - # Config with root-level allow block that denies non-admin users - ds = Datasette( - config={ - "allow": {"id": "admin"}, # Root-level allow block - } - ) - await ds.invoke_startup() - db = ds.add_memory_database("mydb") - await db.execute_write("create table t1 (id integer primary key)") - await ds.client.get("/") # Trigger catalog refresh - - # Actor with table-level restrictions only (not global) - actor = {"id": "user", "_r": {"r": {"mydb": {"t1": ["view-table"]}}}} - - # The root-level allow: {id: admin} should be processed and deny this user - # because they're not "admin", even though they have table restrictions - result = await ds.allowed( - action="view-table", - resource=TableResource("mydb", "t1"), - actor=actor, - ) - # Should be False because root allow: {id: admin} denies non-admin users - assert result is False - - # But admin with same restrictions should be allowed - admin_actor = {"id": "admin", "_r": {"r": {"mydb": {"t1": ["view-table"]}}}} - result = await ds.allowed( - action="view-table", - resource=TableResource("mydb", "t1"), - actor=admin_actor, - ) - assert result is True diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 32276437..96ffc8d9 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,1970 +1,69 @@ from bs4 import BeautifulSoup as Soup -from .fixtures import ( - make_app_client, - TEMP_PLUGIN_SECRET_FILE, - PLUGINS_DIR, - TestClient as _TestClient, -) # noqa -from click.testing import CliRunner -from datasette.app import Datasette -from datasette import cli, hookimpl -from datasette.fixtures import TABLES -from datasette.filters import FilterArguments -from datasette.plugins import get_plugins, DEFAULT_PLUGINS, pm -from datasette.permissions import PermissionSQL, Action -from datasette.resources import DatabaseResource -from datasette.utils.sqlite import sqlite3 -from datasette.utils import StartupError, await_me_maybe -from jinja2 import ChoiceLoader, FileSystemLoader -import base64 -import datetime -import importlib -import json -import os -import pathlib -import re -import textwrap +from .fixtures import ( # noqa + app_client, +) import pytest -import urllib - -at_memory_re = re.compile(r" at 0x\w+") -@pytest.mark.parametrize( - "plugin_hook", [name for name in dir(pm.hook) if not name.startswith("_")] -) -def test_plugin_hooks_have_tests(plugin_hook): - """Every plugin hook should be referenced in this test module""" - tests_in_this_module = [t for t in globals().keys() if t.startswith("test_hook_")] - ok = False - for test in tests_in_this_module: - if plugin_hook in test: - ok = True - assert ok, f"Plugin hook is missing tests: {plugin_hook}" - - -def test_hook_jump_items_sql(): - # Detailed behavior is covered in tests/test_jump.py. - assert "jump_items_sql" in dir(pm.hook) - - -@pytest.mark.asyncio -async def test_hook_plugins_dir_plugin_prepare_connection(ds_client): - response = await ds_client.get( - "/fixtures/-/query.json?_shape=arrayfirst&sql=select+convert_units(100%2C+'m'%2C+'ft')" +def test_plugins_dir_plugin(app_client): + response = app_client.get( + "/test_tables.json?sql=select+convert_units(100%2C+'m'%2C+'ft')" ) - assert response.json()[0] == pytest.approx(328.0839) + assert pytest.approx(328.0839) == response.json['rows'][0][0] -@pytest.mark.asyncio -async def test_hook_plugin_prepare_connection_arguments(ds_client): - response = await ds_client.get( - "/fixtures/-/query.json?sql=select+prepare_connection_args()&_shape=arrayfirst" - ) +def test_plugin_extra_css_urls(app_client): + response = app_client.get('/') + links = Soup(response.body, 'html.parser').findAll('link') assert [ - "database=fixtures, datasette.plugin_config(\"name-of-plugin\")={'depth': 'root'}" - ] == response.json() - - # Function should not be available on the internal database - db = ds_client.ds.get_internal_database() - with pytest.raises(sqlite3.OperationalError): - await db.execute("select prepare_connection_args()") + l for l in links + if l.attrs == { + 'rel': ['stylesheet'], + 'href': 'https://example.com/app.css' + } + ] -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_decoded_object", - [ - ( - "/", - { - "template": "index.html", - "database": None, - "table": None, - "view_name": "index", - "request_path": "/", - "added": 15, - "columns": None, - }, - ), - ( - "/fixtures", - { - "template": "database.html", - "database": "fixtures", - "table": None, - "view_name": "database", - "request_path": "/fixtures", - "added": 15, - "columns": None, - }, - ), - ( - "/fixtures/sortable", - { - "template": "table.html", - "database": "fixtures", - "table": "sortable", - "view_name": "table", - "request_path": "/fixtures/sortable", - "added": 15, - "columns": [ - "pk1", - "pk2", - "content", - "sortable", - "sortable_with_nulls", - "sortable_with_nulls_2", - "text", - ], - }, - ), - ], -) -async def test_hook_extra_css_urls(ds_client, path, expected_decoded_object): - response = await ds_client.get(path) - assert response.status_code == 200 - links = Soup(response.text, "html.parser").find_all("link") - special_href = [ - link - for link in links - if link.attrs["href"].endswith("/extra-css-urls-demo.css") - ][0]["href"] - # This link has a base64-encoded JSON blob in it - encoded = special_href.split("/")[3] - actual_decoded_object = json.loads(base64.b64decode(encoded).decode("utf8")) - assert expected_decoded_object == actual_decoded_object +def test_plugin_extra_js_urls(app_client): + response = app_client.get('/') + scripts = Soup(response.body, 'html.parser').findAll('script') + assert [ + s for s in scripts + if s.attrs == { + 'integrity': 'SRIHASH', + 'crossorigin': 'anonymous', + 'src': 'https://example.com/jquery.js' + } + ] -@pytest.mark.asyncio -async def test_hook_extra_js_urls(ds_client): - response = await ds_client.get("/") - scripts = Soup(response.text, "html.parser").find_all("script") - script_attrs = [s.attrs for s in scripts] - for attrs in [ - { - "integrity": "SRIHASH", - "crossorigin": "anonymous", - "src": "https://plugin-example.datasette.io/jquery.js", - }, - { - "src": "https://plugin-example.datasette.io/plugin.module.js", - "type": "module", - }, - ]: - assert any(s == attrs for s in script_attrs), "Expected: {}".format(attrs) - - -@pytest.mark.asyncio -async def test_plugins_with_duplicate_js_urls(ds_client): +def test_plugins_with_duplicate_js_urls(app_client): # If two plugins both require jQuery, jQuery should be loaded only once - response = await ds_client.get("/fixtures") + response = app_client.get( + "/test_tables" + ) # This test is a little tricky, as if the user has any other plugins in # their current virtual environment those may affect what comes back too. - # What matters is that https://plugin-example.datasette.io/jquery.js is only there once + # What matters is that https://example.com/jquery.js is only there once # and it comes before plugin1.js and plugin2.js which could be in either # order - scripts = Soup(response.text, "html.parser").find_all("script") - srcs = [s["src"] for s in scripts if s.get("src")] + scripts = Soup(response.body, 'html.parser').findAll('script') + srcs = [s['src'] for s in scripts if s.get('src')] # No duplicates allowed: assert len(srcs) == len(set(srcs)) # jquery.js loaded once: - assert 1 == srcs.count("https://plugin-example.datasette.io/jquery.js") + assert 1 == srcs.count('https://example.com/jquery.js') # plugin1.js and plugin2.js are both there: - assert 1 == srcs.count("https://plugin-example.datasette.io/plugin1.js") - assert 1 == srcs.count("https://plugin-example.datasette.io/plugin2.js") + assert 1 == srcs.count('https://example.com/plugin1.js') + assert 1 == srcs.count('https://example.com/plugin2.js') # jquery comes before them both - assert srcs.index("https://plugin-example.datasette.io/jquery.js") < srcs.index( - "https://plugin-example.datasette.io/plugin1.js" + assert srcs.index( + 'https://example.com/jquery.js' + ) < srcs.index( + 'https://example.com/plugin1.js' ) - assert srcs.index("https://plugin-example.datasette.io/jquery.js") < srcs.index( - "https://plugin-example.datasette.io/plugin2.js" + assert srcs.index( + 'https://example.com/jquery.js' + ) < srcs.index( + 'https://example.com/plugin2.js' ) - - -@pytest.mark.asyncio -async def test_hook_render_cell_link_from_json(ds_client): - sql = """ - select '{"href": "http://example.com/", "label":"Example"}' - """.strip() - path = "/fixtures/-/query?" + urllib.parse.urlencode({"sql": sql}) - response = await ds_client.get(path) - td = Soup(response.text, "html.parser").find("table").find("tbody").find("td") - a = td.find("a") - assert a is not None, str(a) - assert a.attrs["href"] == "http://example.com/" - assert a.attrs["data-database"] == "fixtures" - assert a.text == "Example" - - -@pytest.mark.asyncio -async def test_hook_render_cell_demo(ds_client): - response = await ds_client.get( - "/fixtures/simple_primary_key?id=4&_render_cell_extra=1" - ) - soup = Soup(response.text, "html.parser") - td = soup.find("td", {"class": "col-content"}) - assert json.loads(td.string) == { - "row": {"id": 4, "content": "RENDER_CELL_DEMO"}, - "column": "content", - "table": "simple_primary_key", - "database": "fixtures", - "pks": ["id"], - "config": {"depth": "table", "special": "this-is-simple_primary_key"}, - "render_cell_extra": 1, - } - - -@pytest.mark.asyncio -async def test_hook_render_cell_pks_single_pk(ds_client): - """pks should be ["id"] for a table with a single primary key""" - response = await ds_client.get("/fixtures/simple_primary_key?id=4") - soup = Soup(response.text, "html.parser") - td = soup.find("td", {"class": "col-content"}) - data = json.loads(td.string) - assert data["pks"] == ["id"] - - -@pytest.mark.asyncio -async def test_hook_render_cell_pks_compound_pk(ds_client): - """pks should list all primary key columns for a compound pk table""" - response = await ds_client.get("/fixtures/compound_primary_key?pk1=d&pk2=e") - soup = Soup(response.text, "html.parser") - td = soup.find("td", {"class": "col-content"}) - data = json.loads(td.string) - assert data["pks"] == ["pk1", "pk2"] - - -@pytest.mark.asyncio -async def test_hook_render_cell_pks_rowid_table(ds_client): - """pks should be ["rowid"] for a table with no explicit primary key""" - response = await ds_client.get("/fixtures/no_primary_key?content=RENDER_CELL_DEMO") - soup = Soup(response.text, "html.parser") - td = soup.find("td", {"class": "col-content"}) - data = json.loads(td.string) - assert data["pks"] == ["rowid"] - - -@pytest.mark.asyncio -async def test_hook_render_cell_pks_custom_sql(ds_client): - """pks should be [] for custom SQL queries""" - response = await ds_client.get( - "/fixtures/-/query?sql=select+'RENDER_CELL_DEMO'+as+content" - ) - soup = Soup(response.text, "html.parser") - td = soup.find("td", {"class": "col-content"}) - data = json.loads(td.string) - assert data["pks"] == [] - assert data["table"] is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path", - ( - "/fixtures/-/query?sql=select+'RENDER_CELL_ASYNC'", - "/fixtures/simple_primary_key", - ), -) -async def test_hook_render_cell_async(ds_client, path): - response = await ds_client.get(path) - assert b"RENDER_CELL_ASYNC_RESULT" in response.content - - -@pytest.mark.asyncio -async def test_plugin_config(ds_client): - assert {"depth": "table"} == ds_client.ds.plugin_config( - "name-of-plugin", database="fixtures", table="sortable" - ) - assert {"depth": "database"} == ds_client.ds.plugin_config( - "name-of-plugin", database="fixtures", table="unknown_table" - ) - assert {"depth": "database"} == ds_client.ds.plugin_config( - "name-of-plugin", database="fixtures" - ) - assert {"depth": "root"} == ds_client.ds.plugin_config( - "name-of-plugin", database="unknown_database" - ) - assert {"depth": "root"} == ds_client.ds.plugin_config("name-of-plugin") - assert None is ds_client.ds.plugin_config("unknown-plugin") - - -@pytest.mark.asyncio -async def test_plugin_config_env(ds_client, monkeypatch): - monkeypatch.setenv("FOO_ENV", "FROM_ENVIRONMENT") - assert ds_client.ds.plugin_config("env-plugin") == {"foo": "FROM_ENVIRONMENT"} - - -@pytest.mark.asyncio -async def test_plugin_config_env_from_config(monkeypatch): - monkeypatch.setenv("FOO_ENV", "FROM_ENVIRONMENT_2") - datasette = Datasette( - config={"plugins": {"env-plugin": {"setting": {"$env": "FOO_ENV"}}}} - ) - assert datasette.plugin_config("env-plugin") == {"setting": "FROM_ENVIRONMENT_2"} - - -@pytest.mark.asyncio -async def test_plugin_config_env_from_list(ds_client): - os.environ["FOO_ENV"] = "FROM_ENVIRONMENT" - assert [{"in_a_list": "FROM_ENVIRONMENT"}] == ds_client.ds.plugin_config( - "env-plugin-list" - ) - del os.environ["FOO_ENV"] - - -@pytest.mark.asyncio -async def test_plugin_config_file(ds_client): - with open(TEMP_PLUGIN_SECRET_FILE, "w") as fp: - fp.write("FROM_FILE") - assert {"foo": "FROM_FILE"} == ds_client.ds.plugin_config("file-plugin") - os.remove(TEMP_PLUGIN_SECRET_FILE) - - -@pytest.mark.parametrize( - "path,expected_extra_body_script", - [ - ( - "/", - { - "template": "index.html", - "database": None, - "table": None, - "config": {"depth": "root"}, - "view_name": "index", - "request_path": "/", - "added": 15, - "columns": None, - }, - ), - ( - "/fixtures", - { - "template": "database.html", - "database": "fixtures", - "table": None, - "config": {"depth": "database"}, - "view_name": "database", - "request_path": "/fixtures", - "added": 15, - "columns": None, - }, - ), - ( - "/fixtures/sortable", - { - "template": "table.html", - "database": "fixtures", - "table": "sortable", - "config": {"depth": "table"}, - "view_name": "table", - "request_path": "/fixtures/sortable", - "added": 15, - "columns": [ - "pk1", - "pk2", - "content", - "sortable", - "sortable_with_nulls", - "sortable_with_nulls_2", - "text", - ], - }, - ), - ], -) -def test_hook_extra_body_script(app_client, path, expected_extra_body_script): - r = re.compile(r"") - response = app_client.get(path) - assert response.status_code == 200, response.text - match = r.search(response.text) - assert match is not None, "No extra_body_script found in HTML" - json_data = match.group(1) - actual_data = json.loads(json_data) - assert expected_extra_body_script == actual_data - - -@pytest.mark.asyncio -async def test_hook_asgi_wrapper(ds_client): - response = await ds_client.get("/fixtures") - assert "fixtures" == response.headers["x-databases"] - - -def test_hook_extra_template_vars(restore_working_directory): - with make_app_client( - template_dir=str(pathlib.Path(__file__).parent / "test_templates") - ) as client: - response = client.get("/-/versions") - assert response.status_code == 200 - extra_template_vars = json.loads( - Soup(response.text, "html.parser").select("pre.extra_template_vars")[0].text - ) - assert { - "template": "show_json.html", - "scope_path": "/-/versions", - "columns": None, - } == extra_template_vars - extra_template_vars_from_awaitable = json.loads( - Soup(response.text, "html.parser") - .select("pre.extra_template_vars_from_awaitable")[0] - .text - ) - assert { - "template": "show_json.html", - "awaitable": True, - "scope_path": "/-/versions", - } == extra_template_vars_from_awaitable - - -def test_plugins_async_template_function(restore_working_directory): - with make_app_client( - template_dir=str(pathlib.Path(__file__).parent / "test_templates") - ) as client: - response = client.get("/-/versions") - assert response.status_code == 200 - extra_from_awaitable_function = ( - Soup(response.text, "html.parser") - .select("pre.extra_from_awaitable_function")[0] - .text - ) - conn = sqlite3.connect(":memory:") - expected = conn.execute("select sqlite_version()").fetchone()[0] - conn.close() - assert expected == extra_from_awaitable_function - - -def test_default_plugins_have_no_templates_path_or_static_path(): - # The default plugins that ship with Datasette should have their static_path and - # templates_path all set to None - plugins = get_plugins() - for plugin in plugins: - if plugin["name"] in DEFAULT_PLUGINS: - assert None is plugin["static_path"] - assert None is plugin["templates_path"] - - -@pytest.fixture(scope="session") -def view_names_client(tmp_path_factory): - tmpdir = tmp_path_factory.mktemp("test-view-names") - templates = tmpdir / "templates" - templates.mkdir() - plugins = tmpdir / "plugins" - plugins.mkdir() - for template in ( - "index.html", - "database.html", - "table.html", - "row.html", - "show_json.html", - "query.html", - ): - (templates / template).write_text("view_name:{{ view_name }}", "utf-8") - (plugins / "extra_vars.py").write_text( - textwrap.dedent(""" - from datasette import hookimpl - @hookimpl - def extra_template_vars(view_name): - return {"view_name": view_name} - """), - "utf-8", - ) - db_path = str(tmpdir / "fixtures.db") - conn = sqlite3.connect(db_path) - conn.executescript(TABLES) - conn.close() - return _TestClient( - Datasette([db_path], template_dir=str(templates), plugins_dir=str(plugins)) - ) - - -@pytest.mark.parametrize( - "path,view_name", - ( - ("/", "index"), - ("/fixtures", "database"), - ("/fixtures/facetable", "table"), - ("/fixtures/facetable/1", "row"), - ("/-/versions", "json_data"), - ("/fixtures/-/query?sql=select+1", "database"), - ), -) -def test_view_names(view_names_client, path, view_name): - response = view_names_client.get(path) - assert response.status_code == 200 - assert f"view_name:{view_name}" == response.text - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_no_parameters(ds_client): - response = await ds_client.get("/fixtures/facetable.testnone") - assert response.status_code == 200 - assert b"Hello" == response.content - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_all_parameters(ds_client): - response = await ds_client.get("/fixtures/facetable.testall") - assert response.status_code == 200 - # Lots of 'at 0x103a4a690' in here - replace those so we can do - # an easy comparison - body = at_memory_re.sub(" at 0xXXX", response.text) - assert json.loads(body) == { - "datasette": "", - "columns": [ - "pk", - "created", - "planet_int", - "on_earth", - "state", - "_city_id", - "_neighborhood", - "tags", - "complex_array", - "distinct_some_null", - "n", - ], - "rows": [ - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - ], - "sql": "select pk, created, planet_int, on_earth, state, _city_id, _neighborhood, tags, complex_array, distinct_some_null, n from facetable order by pk limit 51", - "query_name": None, - "database": "fixtures", - "table": "facetable", - "request": '', - "view_name": "table", - "1+1": 2, - } - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_custom_status_code(ds_client): - response = await ds_client.get( - "/fixtures/pragma_cache_size.testall?status_code=202" - ) - assert response.status_code == 202 - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_custom_content_type(ds_client): - response = await ds_client.get( - "/fixtures/pragma_cache_size.testall?content_type=text/blah" - ) - assert "text/blah" == response.headers["content-type"] - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_custom_headers(ds_client): - response = await ds_client.get( - "/fixtures/pragma_cache_size.testall?header=x-wow:1&header=x-gosh:2" - ) - assert "1" == response.headers["x-wow"] - assert "2" == response.headers["x-gosh"] - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_returning_response(ds_client): - response = await ds_client.get("/fixtures/facetable.testresponse") - assert response.status_code == 200 - assert response.json() == {"this_is": "json"} - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_returning_broken_value(ds_client): - response = await ds_client.get("/fixtures/facetable.testresponse?_broken=1") - assert response.status_code == 500 - assert "this should break should be dict or Response" in response.text - - -@pytest.mark.asyncio -async def test_hook_register_output_renderer_can_render(ds_client): - response = await ds_client.get("/fixtures/facetable?_no_can_render=1") - assert response.status_code == 200 - links = ( - Soup(response.text, "html.parser") - .find("p", {"class": "export-links"}) - .find_all("a") - ) - actual = [link["href"] for link in links] - # Should not be present because we sent ?_no_can_render=1 - assert "/fixtures/facetable.testall?_labels=on" not in actual - # Check that it was passed the values we expected - assert hasattr(ds_client.ds, "_can_render_saw") - assert { - "datasette": ds_client.ds, - "columns": [ - "pk", - "created", - "planet_int", - "on_earth", - "state", - "_city_id", - "_neighborhood", - "tags", - "complex_array", - "distinct_some_null", - "n", - ], - "sql": "select pk, created, planet_int, on_earth, state, _city_id, _neighborhood, tags, complex_array, distinct_some_null, n from facetable order by pk limit 51", - "query_name": None, - "database": "fixtures", - "table": "facetable", - "view_name": "table", - }.items() <= ds_client.ds._can_render_saw.items() - - -@pytest.mark.asyncio -async def test_hook_prepare_jinja2_environment(ds_client): - ds_client.ds._HELLO = "HI" - await ds_client.ds.invoke_startup() - environment = ds_client.ds.get_jinja_environment(None) - template = environment.from_string( - "Hello there, {{ a|format_numeric }}, {{ a|to_hello }}, {{ b|select_times_three }}", - {"a": 3412341, "b": 5}, - ) - rendered = await ds_client.ds.render_template(template) - assert "Hello there, 3,412,341, HI, 15" == rendered - - -def test_hook_publish_subcommand(): - # This is hard to test properly, because publish subcommand plugins - # cannot be loaded using the --plugins-dir mechanism - they need - # to be installed using "pip install". So I'm cheating and taking - # advantage of the fact that cloudrun/heroku use the plugin hook - # to register themselves as default plugins. - assert ["cloudrun", "heroku"] == cli.publish.list_commands({}) - - -@pytest.mark.asyncio -async def test_hook_register_facet_classes(ds_client): - response = await ds_client.get( - "/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets" - ) - assert response.json()["suggested_facets"] == [ - { - "name": "pk1", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=pk1", - "type": "dummy", - }, - { - "name": "pk2", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=pk2", - "type": "dummy", - }, - { - "name": "pk3", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=pk3", - "type": "dummy", - }, - { - "name": "content", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet_dummy=content", - "type": "dummy", - }, - { - "name": "pk1", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet=pk1", - }, - { - "name": "pk2", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet=pk2", - }, - { - "name": "pk3", - "toggle_url": "http://localhost/fixtures/compound_three_primary_keys.json?_dummy_facet=1&_extra=suggested_facets&_facet=pk3", - }, - ] - - -@pytest.mark.asyncio -async def test_hook_actor_from_request(ds_client): - await ds_client.get("/") - # Should have no actor - assert ds_client.ds._last_request.scope["actor"] is None - await ds_client.get("/?_bot=1") - # Should have bot actor - assert ds_client.ds._last_request.scope["actor"] == {"id": "bot"} - - -@pytest.mark.asyncio -async def test_hook_actor_from_request_async(ds_client): - await ds_client.get("/") - # Should have no actor - assert ds_client.ds._last_request.scope["actor"] is None - await ds_client.get("/?_bot2=1") - # Should have bot2 actor - assert ds_client.ds._last_request.scope["actor"] == {"id": "bot2", "1+1": 2} - - -@pytest.mark.asyncio -async def test_existing_scope_actor_respected(ds_client): - await ds_client.get("/?_actor_in_scope=1") - assert ds_client.ds._last_request.scope["actor"] == {"id": "from-scope"} - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "action,expected", - [ - ("this_is_allowed", True), - ("this_is_denied", False), - ("this_is_allowed_async", True), - ("this_is_denied_async", False), - ], -) -async def test_hook_custom_allowed(action, expected): - # Test actions and permission logic are defined in tests/plugins/my_plugin.py - ds = Datasette(plugins_dir=PLUGINS_DIR) - await ds.invoke_startup() - actual = await ds.allowed(action=action, actor={"id": "actor"}) - assert expected == actual - - -@pytest.mark.asyncio -async def test_hook_permission_resources_sql(): - ds = Datasette() - await ds.invoke_startup() - - collected = [] - for block in ds.pm.hook.permission_resources_sql( - datasette=ds, - actor={"id": "alice"}, - action="view-table", - ): - block = await await_me_maybe(block) - if block is None: - continue - if isinstance(block, (list, tuple)): - collected.extend(block) - else: - collected.append(block) - - assert collected - assert all(isinstance(item, PermissionSQL) for item in collected) - - -@pytest.mark.asyncio -async def test_actor_json(ds_client): - assert (await ds_client.get("/-/actor.json")).json() == {"actor": None} - assert (await ds_client.get("/-/actor.json?_bot2=1")).json() == { - "actor": {"id": "bot2", "1+1": 2} - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,body", - [ - ("/one/", "2"), - ("/two/Ray?greeting=Hail", "Hail Ray"), - ("/not-async/", "This was not async"), - ], -) -async def test_hook_register_routes(ds_client, path, body): - response = await ds_client.get(path) - assert response.status_code == 200 - assert response.text == body - - -@pytest.mark.parametrize("configured_path", ("path1", "path2")) -def test_hook_register_routes_with_datasette(configured_path): - with make_app_client( - config={ - "plugins": { - "register-route-demo": { - "path": configured_path, - } - } - } - ) as client: - response = client.get(f"/{configured_path}/") - assert response.status_code == 200 - assert configured_path.upper() == response.text - # Other one should 404 - other_path = [p for p in ("path1", "path2") if configured_path != p][0] - assert client.get(f"/{other_path}/", follow_redirects=True).status_code == 404 - - -def test_hook_register_routes_override(): - "Plugins can over-ride default paths such as /db/table" - with make_app_client( - config={ - "plugins": { - "register-route-demo": { - "path": "blah", - } - } - } - ) as client: - response = client.get("/db/table") - assert response.status_code == 200 - assert ( - response.text - == "/db/table: [('db_name', 'db'), ('table_and_format', 'table')]" - ) - - -def test_hook_register_routes_post(app_client): - response = app_client.post("/post/", {"this is": "post data"}) - assert response.status_code == 200 - assert response.json["this is"] == "post data" - - -def test_hook_register_routes_csrftoken(restore_working_directory, tmpdir_factory): - # csrftoken() is a legacy compatibility shim that returns a - # per-request random value - it is no longer used for CSRF enforcement. - templates = tmpdir_factory.mktemp("templates") - (templates / "csrftoken_form.html").write_text( - "CSRFTOKEN:{{ csrftoken() }}:END", "utf-8" - ) - with make_app_client(template_dir=templates) as client: - response = client.get("/csrftoken-form/") - assert response.text.startswith("CSRFTOKEN:") - assert response.text.endswith(":END") - token = response.text[len("CSRFTOKEN:") : -len(":END")] - assert len(token) >= 20 - assert "ds_csrftoken" not in response.cookies - - -@pytest.mark.asyncio -async def test_hook_register_routes_asgi(ds_client): - response = await ds_client.get("/three/") - assert {"hello": "world"} == response.json() - assert "1" == response.headers["x-three"] - - -@pytest.mark.asyncio -async def test_hook_register_routes_add_message(ds_client): - response = await ds_client.get("/add-message/") - assert response.status_code == 200 - assert response.text == "Added message" - decoded = ds_client.ds.unsign(response.cookies["ds_messages"], "messages") - assert decoded == [["Hello from messages", 1]] - - -def test_hook_register_routes_render_message(restore_working_directory, tmpdir_factory): - templates = tmpdir_factory.mktemp("templates") - (templates / "render_message.html").write_text('{% extends "base.html" %}', "utf-8") - with make_app_client(template_dir=templates) as client: - response1 = client.get("/add-message/") - response2 = client.get("/render-message/", cookies=response1.cookies) - assert 200 == response2.status - assert "Hello from messages" in response2.text - - -@pytest.mark.asyncio -async def test_hook_startup(ds_client): - await ds_client.ds.invoke_startup() - assert ds_client.ds._startup_hook_fired - assert 2 == ds_client.ds._startup_hook_calculation - - -@pytest.mark.asyncio -async def test_hook_startup_metadata_available(ds_client): - # Metadata from metadata.yaml should be populated before startup() fires - assert "title" in ds_client.ds._startup_metadata_keys - - -@pytest.mark.asyncio -async def test_hook_startup_catalog_populated(ds_client): - # Internal catalog tables should be populated before startup() fires - assert "fixtures" in ds_client.ds._startup_catalog_databases - - -@pytest.mark.asyncio -async def test_plugin_startup_can_add_queries(): - ds = Datasette(memory=True) - ds.add_memory_database("plugin_startup_queries", name="data") - - class AddQueriesPlugin: - __name__ = "AddQueriesPlugin" - - @hookimpl - def startup(self, datasette): - async def inner(): - result = await datasette.get_database("data").execute("select 1 + 1") - await datasette.add_query( - "data", - "from_startup", - "select {}".format(result.first()[0]), - source="plugin", - ) - - return inner - - ds.pm.register(AddQueriesPlugin(), name="add_queries_plugin") - try: - response = await ds.client.get("/data.json") - finally: - ds.pm.unregister(name="add_queries_plugin") - - queries = response.json()["queries"] - queries_by_name = {q["name"]: q for q in queries} - assert queries_by_name["from_startup"]["sql"] == "select 2" - assert queries_by_name["from_startup"]["private"] is False - - -@pytest.mark.asyncio -async def test_plugin_startup_query_can_execute(): - ds = Datasette(memory=True) - ds.add_memory_database("plugin_startup_query_execute", name="data") - - class AddQueryPlugin: - __name__ = "AddQueryPlugin" - - @hookimpl - def startup(self, datasette): - async def inner(): - await datasette.add_query( - "data", "from_startup", "select 2", source="plugin" - ) - - return inner - - ds.pm.register(AddQueryPlugin(), name="add_query_plugin") - try: - response = await ds.client.get("/data/from_startup.json?_shape=array") - finally: - ds.pm.unregister(name="add_query_plugin") - - assert [{"2": 2}] == response.json() - - -def test_hook_register_magic_parameters(restore_working_directory): - with make_app_client( - extra_databases={"data.db": "create table logs (line text)"}, - config={ - "databases": { - "data": { - "queries": { - "runme": { - "sql": "insert into logs (line) values (:_request_http_version)", - "write": True, - }, - "get_uuid": { - "sql": "select :_uuid_new", - }, - "asyncrequest": { - "sql": "select :_asyncrequest_key", - }, - } - } - } - }, - ) as client: - response = client.post("/data/runme", {}, csrftoken_from=True) - assert response.status_code == 302 - actual = client.get("/data/logs.json?_sort_desc=rowid&_shape=array").json - assert [{"rowid": 1, "line": "1.1"}] == actual - # Now try the GET request against get_uuid - response_get = client.get("/data/get_uuid.json?_shape=array") - assert 200 == response_get.status - new_uuid = response_get.json[0][":_uuid_new"] - assert 4 == new_uuid.count("-") - # And test the async one - response_async = client.get("/data/asyncrequest.json?_shape=array") - assert 200 == response_async.status - assert response_async.json[0][":_asyncrequest_key"] == "key" - - -def test_hook_forbidden(restore_working_directory): - with make_app_client( - extra_databases={"data2.db": "create table logs (line text)"}, - config={"allow": {}}, - ) as client: - response = client.get("/") - assert response.status_code == 403 - response2 = client.get("/data2") - assert 302 == response2.status - assert ( - response2.headers["Location"] - == "/login?message=You do not have permission to view this database" - ) - assert ( - client.ds._last_forbidden_message - == "You do not have permission to view this database" - ) - - -@pytest.mark.asyncio -async def test_hook_handle_exception(ds_client): - await ds_client.get("/trigger-error?x=123") - assert hasattr(ds_client.ds, "_exception_hook_fired") - request, exception = ds_client.ds._exception_hook_fired - assert request.url == "http://localhost/trigger-error?x=123" - assert isinstance(exception, ZeroDivisionError) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("param", ("_custom_error", "_custom_error_async")) -async def test_hook_handle_exception_custom_response(ds_client, param): - response = await ds_client.get("/trigger-error?{}=1".format(param)) - assert response.text == param - - -@pytest.mark.asyncio -async def test_hook_menu_links(ds_client): - def get_menu_links(html): - soup = Soup(html, "html.parser") - return [ - {"label": a.text, "href": a["href"]} for a in soup.select(".nav-menu a") - ] - - response = await ds_client.get("/") - assert get_menu_links(response.text) == [] - - response_2 = await ds_client.get("/?_bot=1&_hello=BOB") - assert get_menu_links(response_2.text) == [ - {"label": "Hello, BOB", "href": "/"}, - {"label": "Hello 2", "href": "/"}, - ] - - -@pytest.mark.asyncio -async def test_hook_table_actions(ds_client): - response = await ds_client.get("/fixtures/facetable") - assert get_actions_links(response.text) == [] - response_2 = await ds_client.get("/fixtures/facetable?_bot=1&_hello=BOB") - assert ">Table actions<" in response_2.text - assert sorted( - get_actions_links(response_2.text), key=lambda link: link["label"] - ) == [ - {"label": "Database: fixtures", "href": "/", "description": None}, - {"label": "From async BOB", "href": "/", "description": None}, - {"label": "Table: facetable", "href": "/", "description": None}, - ] - - -@pytest.mark.asyncio -async def test_hook_view_actions(ds_client): - response = await ds_client.get("/fixtures/simple_view") - assert get_actions_links(response.text) == [] - response_2 = await ds_client.get( - "/fixtures/simple_view", - actor={"id": "bob"}, - ) - assert ">View actions<" in response_2.text - assert sorted( - get_actions_links(response_2.text), key=lambda link: link["label"] - ) == [ - {"label": "Database: fixtures", "href": "/", "description": None}, - {"label": "View: simple_view", "href": "/", "description": None}, - ] - - -def get_actions_links(html): - soup = Soup(html, "html.parser") - details = soup.find("details", {"class": "actions-menu-links"}) - if details is None: - return [] - links = [] - for a_el in details.select("a"): - description = None - if a_el.find("p") is not None: - description = a_el.find("p").text.strip() - a_el.find("p").extract() - label = a_el.text.strip() - href = a_el["href"] - links.append({"label": label, "href": href, "description": description}) - return links - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_url", - ( - ("/fixtures/-/query?sql=select+1", "/fixtures/-/query?sql=explain+select+1"), - pytest.param( - "/fixtures/pragma_cache_size", - "/fixtures/-/query?sql=explain+PRAGMA+cache_size%3B", - ), - # Don't attempt to explain an explain - ("/fixtures/-/query?sql=explain+select+1", None), - ), -) -async def test_hook_query_actions(ds_client, path, expected_url): - response = await ds_client.get(path) - assert response.status_code == 200 - links = get_actions_links(response.text) - if expected_url is None: - assert links == [] - else: - assert links == [ - { - "label": "Explain this query", - "href": expected_url, - "description": "Runs a SQLite explain", - } - ] - - -@pytest.mark.asyncio -async def test_hook_row_actions(ds_client): - response = await ds_client.get("/fixtures/facet_cities/1") - assert get_actions_links(response.text) == [] - - response_2 = await ds_client.get( - "/fixtures/facet_cities/1", - actor={"id": "sam"}, - ) - assert get_actions_links(response_2.text) == [ - { - "label": "Row details for sam", - "href": "/", - "description": '{"id": 1, "name": "San Francisco"}', - } - ] - - -@pytest.mark.asyncio -async def test_hook_database_actions(ds_client): - response = await ds_client.get("/fixtures") - assert get_actions_links(response.text) == [] - - response_2 = await ds_client.get("/fixtures?_bot=1&_hello=BOB") - assert get_actions_links(response_2.text) == [ - {"label": "Database: fixtures - BOB", "href": "/", "description": None}, - ] - - -@pytest.mark.asyncio -async def test_hook_homepage_actions(ds_client): - response = await ds_client.get("/") - # No button for anonymous users - assert "Homepage actions" not in response.text - # Signed in user gets an action - response2 = await ds_client.get("/", actor={"id": "troy"}) - assert "Homepage actions" in response2.text - assert get_actions_links(response2.text) == [ - { - "label": "Custom homepage for: troy", - "href": "/-/custom-homepage", - "description": None, - }, - ] - - -def _extract_commands(output): - lines = output.split("Commands:\n", 1)[1].split("\n") - return {line.split()[0].replace("*", "") for line in lines if line.strip()} - - -def test_hook_register_commands(): - # Without the plugin should have seven commands - runner = CliRunner() - result = runner.invoke(cli.cli, "--help") - commands = _extract_commands(result.output) - assert commands == { - "serve", - "inspect", - "install", - "package", - "plugins", - "publish", - "uninstall", - "create-token", - } - - # Now install a plugin - class VerifyPlugin: - __name__ = "VerifyPlugin" - - @hookimpl - def register_commands(self, cli): - @cli.command() - def verify(): - pass - - @cli.command() - def unverify(): - pass - - pm.register(VerifyPlugin(), name="verify") - importlib.reload(cli) - result2 = runner.invoke(cli.cli, "--help") - commands2 = _extract_commands(result2.output) - assert commands2 == { - "serve", - "inspect", - "install", - "package", - "plugins", - "publish", - "uninstall", - "verify", - "unverify", - "create-token", - } - pm.unregister(name="verify") - importlib.reload(cli) - - -@pytest.mark.asyncio -async def test_hook_filters_from_request(ds_client): - class ReturnNothingPlugin: - __name__ = "ReturnNothingPlugin" - - @hookimpl - def filters_from_request(self, request): - if request.args.get("_nothing"): - return FilterArguments(["1 = 0"], human_descriptions=["NOTHING"]) - - ds_client.ds.pm.register(ReturnNothingPlugin(), name="ReturnNothingPlugin") - response = await ds_client.get("/fixtures/facetable?_nothing=1") - assert "0 rows\n where NOTHING" in response.text - json_response = await ds_client.get("/fixtures/facetable.json?_nothing=1") - assert json_response.json()["rows"] == [] - ds_client.ds.pm.unregister(name="ReturnNothingPlugin") - - -@pytest.mark.asyncio -@pytest.mark.parametrize("extra_metadata", (False, True)) -async def test_hook_register_actions(extra_metadata): - - ds = Datasette( - config=( - { - "plugins": { - "datasette-register-actions": { - "actions": [ - { - "name": "extra-from-metadata", - "abbr": "efm", - "description": "Extra from metadata", - } - ] - } - } - } - if extra_metadata - else None - ), - plugins_dir=PLUGINS_DIR, - ) - await ds.invoke_startup() - assert ds.actions["action-from-plugin"] == Action( - name="action-from-plugin", - abbr="ap", - description="New action added by a plugin", - resource_class=DatabaseResource, - ) - if extra_metadata: - assert ds.actions["extra-from-metadata"] == Action( - name="extra-from-metadata", - abbr="efm", - description="Extra from metadata", - ) - else: - assert "extra-from-metadata" not in ds.actions - - -@pytest.mark.asyncio -@pytest.mark.parametrize("duplicate", ("name", "abbr")) -async def test_hook_register_actions_no_duplicates(duplicate): - name1, name2 = "name1", "name2" - abbr1, abbr2 = "abbr1", "abbr2" - if duplicate == "name": - name2 = "name1" - if duplicate == "abbr": - abbr2 = "abbr1" - ds = Datasette( - config={ - "plugins": { - "datasette-register-actions": { - "actions": [ - { - "name": name1, - "abbr": abbr1, - "description": None, - }, - { - "name": name2, - "abbr": abbr2, - "description": None, - }, - ] - } - } - }, - plugins_dir=PLUGINS_DIR, - ) - # This should error: - with pytest.raises(StartupError) as ex: - await ds.invoke_startup() - assert "Duplicate action {}".format(duplicate) in str(ex.value) - - -@pytest.mark.asyncio -async def test_hook_register_actions_allows_identical_duplicates(): - ds = Datasette( - config={ - "plugins": { - "datasette-register-actions": { - "actions": [ - { - "name": "name1", - "abbr": "abbr1", - "description": None, - }, - { - "name": "name1", - "abbr": "abbr1", - "description": None, - }, - ] - } - } - }, - plugins_dir=PLUGINS_DIR, - ) - await ds.invoke_startup() - # Check that ds.actions has only one of each - assert len([p for p in ds.actions.values() if p.abbr == "abbr1"]) == 1 - - -@pytest.mark.asyncio -async def test_hook_actors_from_ids(): - # Without the hook should return default {"id": id} list - ds = Datasette() - await ds.invoke_startup() - db = ds.add_memory_database("actors_from_ids") - await db.execute_write( - "create table actors (id text primary key, name text, age int)" - ) - await db.execute_write( - "insert into actors (id, name, age) values ('3', 'Cate Blanchett', 52)" - ) - await db.execute_write( - "insert into actors (id, name, age) values ('5', 'Rooney Mara', 36)" - ) - await db.execute_write( - "insert into actors (id, name, age) values ('7', 'Sarah Paulson', 46)" - ) - await db.execute_write( - "insert into actors (id, name, age) values ('9', 'Helena Bonham Carter', 55)" - ) - table_names = await db.table_names() - assert table_names == ["actors"] - actors1 = await ds.actors_from_ids(["3", "5", "7"]) - assert actors1 == { - "3": {"id": "3"}, - "5": {"id": "5"}, - "7": {"id": "7"}, - } - - class ActorsFromIdsPlugin: - __name__ = "ActorsFromIdsPlugin" - - @hookimpl - def actors_from_ids(self, datasette, actor_ids): - db = datasette.get_database("actors_from_ids") - - async def inner(): - sql = "select id, name from actors where id in ({})".format( - ", ".join("?" for _ in actor_ids) - ) - actors = {} - result = await db.execute(sql, actor_ids) - for row in result.rows: - actor = dict(row) - actors[actor["id"]] = actor - return actors - - return inner - - try: - ds.pm.register(ActorsFromIdsPlugin(), name="ActorsFromIdsPlugin") - actors2 = await ds.actors_from_ids(["3", "5", "7"]) - assert actors2 == { - "3": {"id": "3", "name": "Cate Blanchett"}, - "5": {"id": "5", "name": "Rooney Mara"}, - "7": {"id": "7", "name": "Sarah Paulson"}, - } - finally: - ds.pm.unregister(name="ReturnNothingPlugin") - - -@pytest.mark.asyncio -async def test_plugin_is_installed(): - datasette = Datasette(memory=True) - - class DummyPlugin: - __name__ = "DummyPlugin" - - @hookimpl - def actors_from_ids(self, datasette, actor_ids): - return {} - - try: - datasette.pm.register(DummyPlugin(), name="DummyPlugin") - response = await datasette.client.get("/-/plugins.json") - assert response.status_code == 200 - installed_plugins = {p["name"] for p in response.json()} - assert "DummyPlugin" in installed_plugins - - finally: - datasette.pm.unregister(name="DummyPlugin") - - -@pytest.mark.asyncio -async def test_hook_jinja2_environment_from_request(tmpdir): - templates = pathlib.Path(tmpdir / "templates") - templates.mkdir() - (templates / "index.html").write_text("Hello museums!", "utf-8") - - class EnvironmentPlugin: - @hookimpl - def jinja2_environment_from_request(self, request, env): - if request and request.host == "www.niche-museums.com": - return env.overlay( - loader=ChoiceLoader( - [ - FileSystemLoader(str(templates)), - env.loader, - ] - ), - enable_async=True, - ) - return env - - datasette = Datasette(memory=True) - - try: - datasette.pm.register(EnvironmentPlugin(), name="EnvironmentPlugin") - response = await datasette.client.get("/") - assert response.status_code == 200 - assert "Hello museums!" not in response.text - # Try again with the hostname - response2 = await datasette.client.get( - "/", headers={"host": "www.niche-museums.com"} - ) - assert response2.status_code == 200 - assert "Hello museums!" in response2.text - finally: - datasette.pm.unregister(name="EnvironmentPlugin") - - -class SlotPlugin: - __name__ = "SlotPlugin" - - @hookimpl - def top_homepage(self, request): - return "Xtop_homepage:" + request.args["z"] - - @hookimpl - def top_database(self, request, database): - async def inner(): - return "Xtop_database:{}:{}".format(database, request.args["z"]) - - return inner - - @hookimpl - def top_table(self, request, database, table): - return "Xtop_table:{}:{}:{}".format(database, table, request.args["z"]) - - @hookimpl - def top_row(self, request, database, table, row): - return "Xtop_row:{}:{}:{}:{}".format( - database, table, row["name"], request.args["z"] - ) - - @hookimpl - def top_query(self, request, database, sql): - return "Xtop_query:{}:{}:{}".format(database, sql, request.args["z"]) - - @hookimpl - def top_stored_query(self, request, database, query_name): - return "Xtop_stored_query:{}:{}:{}".format( - database, query_name, request.args["z"] - ) - - -@pytest.mark.asyncio -async def test_hook_top_homepage(): - datasette = Datasette(memory=True) - try: - datasette.pm.register(SlotPlugin(), name="SlotPlugin") - response = await datasette.client.get("/?z=foo") - assert response.status_code == 200 - assert "Xtop_homepage:foo" in response.text - finally: - datasette.pm.unregister(name="SlotPlugin") - - -@pytest.mark.asyncio -async def test_hook_top_database(): - datasette = Datasette(memory=True) - try: - datasette.pm.register(SlotPlugin(), name="SlotPlugin") - response = await datasette.client.get("/_memory?z=bar") - assert response.status_code == 200 - assert "Xtop_database:_memory:bar" in response.text - finally: - datasette.pm.unregister(name="SlotPlugin") - - -@pytest.mark.asyncio -async def test_hook_top_table(ds_client): - try: - ds_client.ds.pm.register(SlotPlugin(), name="SlotPlugin") - response = await ds_client.get("/fixtures/facetable?z=baz") - assert response.status_code == 200 - assert "Xtop_table:fixtures:facetable:baz" in response.text - finally: - ds_client.ds.pm.unregister(name="SlotPlugin") - - -@pytest.mark.asyncio -async def test_hook_top_row(ds_client): - try: - ds_client.ds.pm.register(SlotPlugin(), name="SlotPlugin") - response = await ds_client.get("/fixtures/facet_cities/1?z=bax") - assert response.status_code == 200 - assert "Xtop_row:fixtures:facet_cities:San Francisco:bax" in response.text - finally: - ds_client.ds.pm.unregister(name="SlotPlugin") - - -@pytest.mark.asyncio -async def test_hook_top_query(ds_client): - try: - pm.register(SlotPlugin(), name="SlotPlugin") - response = await ds_client.get("/fixtures/-/query?sql=select+1&z=x") - assert response.status_code == 200 - assert "Xtop_query:fixtures:select 1:x" in response.text - finally: - pm.unregister(name="SlotPlugin") - - -@pytest.mark.asyncio -async def test_hook_top_stored_query(ds_client): - try: - pm.register(SlotPlugin(), name="SlotPlugin") - response = await ds_client.get("/fixtures/magic_parameters?z=xyz") - assert response.status_code == 200 - assert "Xtop_stored_query:fixtures:magic_parameters:xyz" in response.text - finally: - pm.unregister(name="SlotPlugin") - - -@pytest.mark.asyncio -async def test_hook_track_event(): - datasette = Datasette(memory=True) - from .conftest import TrackEventPlugin - - await datasette.invoke_startup() - await datasette.track_event( - TrackEventPlugin.OneEvent(actor=None, extra="extra extra") - ) - assert len(datasette._tracked_events) == 1 - assert isinstance(datasette._tracked_events[0], TrackEventPlugin.OneEvent) - event = datasette._tracked_events[0] - assert event.name == "one" - assert event.properties() == {"extra": "extra extra"} - # Should have a recent created as well - created = event.created - assert isinstance(created, datetime.datetime) - assert created.tzinfo == datetime.timezone.utc - - -@pytest.mark.asyncio -async def test_hook_register_events(): - datasette = Datasette(memory=True) - await datasette.invoke_startup() - assert any(k.__name__ == "OneEvent" for k in datasette.event_classes) - - -@pytest.mark.asyncio -async def test_hook_register_token_handler(ds_client): - handlers = ds_client.ds._token_handlers() - handler_names = [h.name for h in handlers] - # Both the default signed handler and the test hardcoded handler - assert "signed" in handler_names - assert "hardcoded" in handler_names - - # Create a token using the hardcoded handler (first registered from plugins dir) - token = await ds_client.ds.create_token("test-user") - assert token.startswith("dstok_hardcoded_token_") - - # Verify it - actor = await ds_client.ds.verify_token(token) - assert actor["id"] == "hardcoded-actor" - assert actor["token"] == "hardcoded" - - # Create a token by explicitly requesting the hardcoded handler by name - token2 = await ds_client.ds.create_token("test-user", handler="hardcoded") - assert token2.startswith("dstok_hardcoded_token_") - actor2 = await ds_client.ds.verify_token(token2) - assert actor2["id"] == "hardcoded-actor" - - # Create a token by explicitly requesting the signed handler by name - signed_token = await ds_client.ds.create_token("test-user", handler="signed") - assert signed_token.startswith("dstok_") - assert not signed_token.startswith("dstok_hardcoded_token_") - signed_actor = await ds_client.ds.verify_token(signed_token) - assert signed_actor["id"] == "test-user" - assert signed_actor["token"] == "dstok" - - -@pytest.mark.asyncio -async def test_hook_write_wrapper(): - datasette = Datasette(memory=True) - log = [] - - class WrapWritePlugin: - __name__ = "WrapWritePlugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - if database != "_memory": - return None - - def wrapper(conn): - log.append("before") - yield - log.append("after") - - return wrapper - - pm.register(WrapWritePlugin(), name="WrapWritePluginTest") - try: - db = datasette.get_database("_memory") - await db.execute_write("create table t (id integer primary key)") - assert log == ["before", "after"] - finally: - pm.unregister(name="WrapWritePluginTest") - - -@pytest.mark.asyncio -async def test_hook_register_actions_view_collection(): - datasette = Datasette(memory=True, plugins_dir=PLUGINS_DIR) - await datasette.invoke_startup() - # Check that the custom action from my_plugin.py is registered - assert "view-collection" in datasette.actions - action = datasette.actions["view-collection"] - assert action.abbr == "vc" - assert action.description == "View a collection" - - -@pytest.mark.asyncio -async def test_hook_register_actions_with_custom_resources(): - """ - Test registering actions with custom Resource classes: - - A global action (no resource) - - A parent-level action (DocumentCollectionResource) - - A child-level action (DocumentResource) - """ - from datasette.permissions import Resource - - # Define custom Resource classes - class DocumentCollectionResource(Resource): - """A collection of documents.""" - - name = "document_collection" - parent_class = None # Top-level resource - - def __init__(self, collection: str): - super().__init__(parent=collection, child=None) - - @classmethod - async def resources_sql(cls, datasette, actor=None) -> str: - return """ - SELECT 'collection1' AS parent, NULL AS child - UNION ALL - SELECT 'collection2' AS parent, NULL AS child - """ - - class DocumentResource(Resource): - """A document in a collection.""" - - name = "document" - parent_class = DocumentCollectionResource # Child of DocumentCollectionResource - - def __init__(self, collection: str, document: str): - super().__init__(parent=collection, child=document) - - @classmethod - async def resources_sql(cls, datasette, actor=None) -> str: - return """ - SELECT 'collection1' AS parent, 'doc1' AS child - UNION ALL - SELECT 'collection1' AS parent, 'doc2' AS child - UNION ALL - SELECT 'collection2' AS parent, 'doc3' AS child - """ - - # Define a test plugin that registers these actions - class TestPlugin: - __name__ = "test_custom_resources_plugin" - - @hookimpl - def register_actions(self, datasette): - return [ - # Global action - no resource_class - Action( - name="manage-documents", - abbr="md", - description="Manage the document system", - ), - # Parent-level action - collection only - Action( - name="view-document-collection", - description="View a document collection", - resource_class=DocumentCollectionResource, - ), - # Child-level action - collection + document - Action( - name="view-document", - abbr="vdoc", - description="View a document", - resource_class=DocumentResource, - ), - ] - - @hookimpl - def permission_resources_sql(self, datasette, actor, action): - from datasette.permissions import PermissionSQL - - # Grant user2 access to manage-documents globally - if actor and actor.get("id") == "user2" and action == "manage-documents": - return PermissionSQL.allow(reason="user2 granted manage-documents") - - # Grant user2 access to view-document-collection globally - if ( - actor - and actor.get("id") == "user2" - and action == "view-document-collection" - ): - return PermissionSQL.allow( - reason="user2 granted view-document-collection" - ) - - # Default allow for view-document-collection (like other view-* actions) - if action == "view-document-collection": - return PermissionSQL.allow( - reason="default allow for view-document-collection" - ) - - # Default allow for view-document (like other view-* actions) - if action == "view-document": - return PermissionSQL.allow(reason="default allow for view-document") - - # Register the plugin temporarily - plugin = TestPlugin() - pm.register(plugin, name="test_custom_resources_plugin") - - try: - # Create datasette instance and invoke startup - datasette = Datasette(memory=True) - await datasette.invoke_startup() - - # Test global action - manage_docs = datasette.actions["manage-documents"] - assert manage_docs.name == "manage-documents" - assert manage_docs.abbr == "md" - assert manage_docs.resource_class is None - assert manage_docs.takes_parent is False - assert manage_docs.takes_child is False - - # Test parent-level action - view_collection = datasette.actions["view-document-collection"] - assert view_collection.name == "view-document-collection" - assert view_collection.abbr is None - assert view_collection.resource_class is DocumentCollectionResource - assert view_collection.takes_parent is True - assert view_collection.takes_child is False - - # Test child-level action - view_doc = datasette.actions["view-document"] - assert view_doc.name == "view-document" - assert view_doc.abbr == "vdoc" - assert view_doc.resource_class is DocumentResource - assert view_doc.takes_parent is True - assert view_doc.takes_child is True - - # Verify the resource classes have correct hierarchy - assert DocumentCollectionResource.parent_class is None - assert DocumentResource.parent_class is DocumentCollectionResource - - # Test that resources can be instantiated correctly - collection_resource = DocumentCollectionResource(collection="collection1") - assert collection_resource.parent == "collection1" - assert collection_resource.child is None - - doc_resource = DocumentResource(collection="collection1", document="doc1") - assert doc_resource.parent == "collection1" - assert doc_resource.child == "doc1" - - # Test permission checks with restricted actors - - # Test 1: Global action - no restrictions (custom actions default to deny) - unrestricted_actor = {"id": "user1"} - allowed = await datasette.allowed( - action="manage-documents", - actor=unrestricted_actor, - ) - assert allowed is False # Custom actions have no default allow - - # Test 2: Global action - user2 has explicit permission via plugin hook - restricted_global = {"id": "user2", "_r": {"a": ["md"]}} - allowed = await datasette.allowed( - action="manage-documents", - actor=restricted_global, - ) - assert allowed is True # Granted by plugin hook for user2 - - # Test 3: Global action - restricted but not in allowlist - restricted_no_access = {"id": "user3", "_r": {"a": ["vdc"]}} - allowed = await datasette.allowed( - action="manage-documents", - actor=restricted_no_access, - ) - assert allowed is False # Not in allowlist - - # Test 4: Collection-level action - allowed for specific collection - collection_resource = DocumentCollectionResource(collection="collection1") - # This one does not have an abbreviation: - restricted_collection = { - "id": "user4", - "_r": {"d": {"collection1": ["view-document-collection"]}}, - } - allowed = await datasette.allowed( - action="view-document-collection", - resource=collection_resource, - actor=restricted_collection, - ) - assert allowed is True # Allowed for collection1 - - # Test 5: Collection-level action - denied for different collection - collection2_resource = DocumentCollectionResource(collection="collection2") - allowed = await datasette.allowed( - action="view-document-collection", - resource=collection2_resource, - actor=restricted_collection, - ) - assert allowed is False # Not allowed for collection2 - - # Test 6: Document-level action - allowed for specific document - doc1_resource = DocumentResource(collection="collection1", document="doc1") - restricted_document = { - "id": "user5", - "_r": {"r": {"collection1": {"doc1": ["vdoc"]}}}, - } - allowed = await datasette.allowed( - action="view-document", - resource=doc1_resource, - actor=restricted_document, - ) - assert allowed is True # Allowed for collection1/doc1 - - # Test 7: Document-level action - denied for different document - doc2_resource = DocumentResource(collection="collection1", document="doc2") - allowed = await datasette.allowed( - action="view-document", - resource=doc2_resource, - actor=restricted_document, - ) - assert allowed is False # Not allowed for collection1/doc2 - - # Test 8: Document-level action - globally allowed - doc_resource = DocumentResource(collection="collection2", document="doc3") - restricted_all_docs = {"id": "user6", "_r": {"a": ["vdoc"]}} - allowed = await datasette.allowed( - action="view-document", - resource=doc_resource, - actor=restricted_all_docs, - ) - assert allowed is True # Globally allowed for all documents - - # Test 9: Verify hierarchy - collection access doesn't grant document access - collection_only_actor = {"id": "user7", "_r": {"d": {"collection1": ["vdc"]}}} - doc_resource = DocumentResource(collection="collection1", document="doc1") - allowed = await datasette.allowed( - action="view-document", - resource=doc_resource, - actor=collection_only_actor, - ) - assert ( - allowed is False - ) # Collection permission doesn't grant document permission - - finally: - # Unregister the plugin - pm.unregister(plugin) - - -@pytest.mark.skip(reason="TODO") -@pytest.mark.parametrize( - "metadata,config,expected_metadata,expected_config", - ( - ( - # Instance level - {"plugins": {"datasette-foo": "bar"}}, - {}, - {}, - {"plugins": {"datasette-foo": "bar"}}, - ), - ( - # Database level - {"databases": {"foo": {"plugins": {"datasette-foo": "bar"}}}}, - {}, - {}, - {"databases": {"foo": {"plugins": {"datasette-foo": "bar"}}}}, - ), - ( - # Table level - { - "databases": { - "foo": {"tables": {"bar": {"plugins": {"datasette-foo": "bar"}}}} - } - }, - {}, - {}, - { - "databases": { - "foo": {"tables": {"bar": {"plugins": {"datasette-foo": "bar"}}}} - } - }, - ), - ( - # Keep other keys - {"plugins": {"datasette-foo": "bar"}, "other": "key"}, - {"original_config": "original"}, - {"other": "key"}, - {"original_config": "original", "plugins": {"datasette-foo": "bar"}}, - ), - ), -) -def test_metadata_plugin_config_treated_as_config( - metadata, config, expected_metadata, expected_config -): - ds = Datasette(metadata=metadata, config=config) - actual_metadata = ds.metadata() - assert "plugins" not in actual_metadata - assert actual_metadata == expected_metadata - assert ds.config == expected_config - - -@pytest.mark.asyncio -async def test_hook_register_column_types(): - ds = Datasette() - await ds.invoke_startup() - # Built-in column types should be registered - assert "url" in ds._column_types - assert "email" in ds._column_types - assert "json" in ds._column_types - assert "nonexistent" not in ds._column_types diff --git a/tests/test_publish_cloudrun.py b/tests/test_publish_cloudrun.py deleted file mode 100644 index 6617bc77..00000000 --- a/tests/test_publish_cloudrun.py +++ /dev/null @@ -1,402 +0,0 @@ -from click.testing import CliRunner -from datasette import cli -from unittest import mock -import json -import os -import pytest -import textwrap - - -@pytest.mark.serial -@mock.patch("shutil.which") -def test_publish_cloudrun_requires_gcloud(mock_which, tmp_path_factory): - mock_which.return_value = False - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke(cli.cli, ["publish", "cloudrun", "test.db"]) - assert result.exit_code == 1 - assert "Publishing to Google Cloud requires gcloud" in result.output - - -@mock.patch("shutil.which") -def test_publish_cloudrun_invalid_database(mock_which): - mock_which.return_value = True - runner = CliRunner() - result = runner.invoke(cli.cli, ["publish", "cloudrun", "woop.db"]) - assert result.exit_code == 2 - assert "Path 'woop.db' does not exist" in result.output - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.cloudrun.check_output") -@mock.patch("datasette.publish.cloudrun.check_call") -@mock.patch("datasette.publish.cloudrun.get_existing_services") -def test_publish_cloudrun_prompts_for_service( - mock_get_existing_services, mock_call, mock_output, mock_which, tmp_path_factory -): - mock_get_existing_services.return_value = [ - {"name": "existing", "created": "2019-01-01", "url": "http://www.example.com/"} - ] - mock_output.return_value = "myproject" - mock_which.return_value = True - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke( - cli.cli, ["publish", "cloudrun", "test.db"], input="input-service" - ) - assert ( - "Please provide a service name for this deployment\n\n" - "Using an existing service name will over-write it\n\n" - "Your existing services:\n\n" - " existing - created 2019-01-01 - http://www.example.com/\n\n" - "Service name: input-service" - ) == result.output.strip() - assert 0 == result.exit_code - tag = "us-docker.pkg.dev/myproject/datasette/datasette-input-service" - mock_call.assert_has_calls( - [ - mock.call( - "gcloud services enable artifactregistry.googleapis.com --project myproject --quiet", - shell=True, - ), - mock.call( - "gcloud artifacts repositories describe datasette --project myproject --location us --quiet", - shell=True, - ), - mock.call(f"gcloud builds submit --tag {tag}", shell=True), - mock.call( - "gcloud run deploy --allow-unauthenticated --platform=managed --image {} input-service --max-instances 1".format( - tag - ), - shell=True, - ), - ] - ) - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.cloudrun.check_output") -@mock.patch("datasette.publish.cloudrun.check_call") -def test_publish_cloudrun(mock_call, mock_output, mock_which, tmp_path_factory): - mock_output.return_value = "myproject" - mock_which.return_value = True - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke( - cli.cli, ["publish", "cloudrun", "test.db", "--service", "test"] - ) - assert 0 == result.exit_code - tag = f"us-docker.pkg.dev/{mock_output.return_value}/datasette/datasette-test" - mock_call.assert_has_calls( - [ - mock.call( - f"gcloud services enable artifactregistry.googleapis.com --project {mock_output.return_value} --quiet", - shell=True, - ), - mock.call( - f"gcloud artifacts repositories describe datasette --project {mock_output.return_value} --location us --quiet", - shell=True, - ), - mock.call(f"gcloud builds submit --tag {tag}", shell=True), - mock.call( - "gcloud run deploy --allow-unauthenticated --platform=managed --image {} test --max-instances 1".format( - tag - ), - shell=True, - ), - ] - ) - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.cloudrun.check_output") -@mock.patch("datasette.publish.cloudrun.check_call") -@pytest.mark.parametrize( - "memory,cpu,timeout,min_instances,max_instances,expected_gcloud_args", - [ - ["1Gi", None, None, None, None, "--memory 1Gi"], - ["2G", None, None, None, None, "--memory 2G"], - ["256Mi", None, None, None, None, "--memory 256Mi"], - [ - "4", - None, - None, - None, - None, - None, - ], - [ - "GB", - None, - None, - None, - None, - None, - ], - [None, 1, None, None, None, "--cpu 1"], - [None, 2, None, None, None, "--cpu 2"], - [None, 3, None, None, None, None], - [None, 4, None, None, None, "--cpu 4"], - ["2G", 4, None, None, None, "--memory 2G --cpu 4"], - [None, None, 1800, None, None, "--timeout 1800"], - [None, None, None, 2, None, "--min-instances 2"], - [None, None, None, 2, 4, "--min-instances 2 --max-instances 4"], - [None, 2, None, None, 4, "--cpu 2 --max-instances 4"], - ], -) -def test_publish_cloudrun_memory_cpu( - mock_call, - mock_output, - mock_which, - memory, - cpu, - timeout, - min_instances, - max_instances, - expected_gcloud_args, - tmp_path_factory, -): - mock_output.return_value = "myproject" - mock_which.return_value = True - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - args = ["publish", "cloudrun", "test.db", "--service", "test"] - if memory: - args.extend(["--memory", memory]) - if cpu: - args.extend(["--cpu", str(cpu)]) - if timeout: - args.extend(["--timeout", str(timeout)]) - result = runner.invoke(cli.cli, args) - if expected_gcloud_args is None: - assert 2 == result.exit_code - return - assert 0 == result.exit_code - tag = f"us-docker.pkg.dev/{mock_output.return_value}/datasette/datasette-test" - expected_call = ( - "gcloud run deploy --allow-unauthenticated --platform=managed" - " --image {} test".format(tag) - ) - expected_build_call = f"gcloud builds submit --tag {tag}" - if memory: - expected_call += " --memory {}".format(memory) - if cpu: - expected_call += " --cpu {}".format(cpu) - if timeout: - expected_build_call += f" --timeout {timeout}" - # max_instances defaults to 1 - expected_call += " --max-instances 1" - mock_call.assert_has_calls( - [ - mock.call( - f"gcloud services enable artifactregistry.googleapis.com --project {mock_output.return_value} --quiet", - shell=True, - ), - mock.call( - f"gcloud artifacts repositories describe datasette --project {mock_output.return_value} --location us --quiet", - shell=True, - ), - mock.call(expected_build_call, shell=True), - mock.call( - expected_call, - shell=True, - ), - ] - ) - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.cloudrun.check_output") -@mock.patch("datasette.publish.cloudrun.check_call") -def test_publish_cloudrun_plugin_secrets( - mock_call, mock_output, mock_which, tmp_path_factory -): - mock_which.return_value = True - mock_output.return_value = "myproject" - - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - with open("metadata.yml", "w") as fp: - fp.write(textwrap.dedent(""" - title: Hello from metadata YAML - plugins: - datasette-auth-github: - foo: bar - """).strip()) - result = runner.invoke( - cli.cli, - [ - "publish", - "cloudrun", - "test.db", - "--metadata", - "metadata.yml", - "--service", - "datasette", - "--plugin-secret", - "datasette-auth-github", - "client_id", - "x-client-id", - "--show-files", - "--secret", - "x-secret", - ], - ) - assert result.exit_code == 0 - dockerfile = ( - result.output.split("==== Dockerfile ====\n")[1] - .split("\n====================\n")[0] - .strip() - ) - expected = textwrap.dedent( - r""" - FROM python:3.11.0-slim-bullseye - COPY . /app - WORKDIR /app - - ENV DATASETTE_AUTH_GITHUB_CLIENT_ID 'x-client-id' - ENV DATASETTE_SECRET 'x-secret' - RUN pip install -U datasette - RUN datasette inspect test.db --inspect-file inspect-data.json - ENV PORT 8001 - EXPOSE 8001 - CMD datasette serve --host 0.0.0.0 -i test.db --cors --inspect-file inspect-data.json --metadata metadata.json --setting force_https_urls on --port $PORT""" - ).strip() - assert expected == dockerfile - metadata = ( - result.output.split("=== metadata.json ===\n")[1] - .split("\n==== Dockerfile ====\n")[0] - .strip() - ) - assert { - "title": "Hello from metadata YAML", - "plugins": { - "datasette-auth-github": { - "client_id": {"$env": "DATASETTE_AUTH_GITHUB_CLIENT_ID"}, - "foo": "bar", - }, - }, - } == json.loads(metadata) - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.cloudrun.check_output") -@mock.patch("datasette.publish.cloudrun.check_call") -def test_publish_cloudrun_apt_get_install( - mock_call, mock_output, mock_which, tmp_path_factory -): - mock_which.return_value = True - mock_output.return_value = "myproject" - - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke( - cli.cli, - [ - "publish", - "cloudrun", - "test.db", - "--service", - "datasette", - "--show-files", - "--secret", - "x-secret", - "--apt-get-install", - "ripgrep", - "--spatialite", - ], - ) - assert result.exit_code == 0 - dockerfile = ( - result.output.split("==== Dockerfile ====\n")[1] - .split("\n====================\n")[0] - .strip() - ) - expected = textwrap.dedent(r""" - FROM python:3.11.0-slim-bullseye - COPY . /app - WORKDIR /app - - RUN apt-get update && \ - apt-get install -y ripgrep python3-dev gcc libsqlite3-mod-spatialite && \ - rm -rf /var/lib/apt/lists/* - - ENV DATASETTE_SECRET 'x-secret' - ENV SQLITE_EXTENSIONS '/usr/lib/x86_64-linux-gnu/mod_spatialite.so' - RUN pip install -U datasette - RUN datasette inspect test.db --inspect-file inspect-data.json - ENV PORT 8001 - EXPOSE 8001 - CMD datasette serve --host 0.0.0.0 -i test.db --cors --inspect-file inspect-data.json --setting force_https_urls on --port $PORT - """).strip() - assert expected == dockerfile - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.cloudrun.check_output") -@mock.patch("datasette.publish.cloudrun.check_call") -@pytest.mark.parametrize( - "extra_options,expected", - [ - ("", "--setting force_https_urls on"), - ( - "--setting base_url /foo", - "--setting base_url /foo --setting force_https_urls on", - ), - ("--setting force_https_urls off", "--setting force_https_urls off"), - ], -) -def test_publish_cloudrun_extra_options( - mock_call, mock_output, mock_which, extra_options, expected, tmp_path_factory -): - mock_which.return_value = True - mock_output.return_value = "myproject" - - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke( - cli.cli, - [ - "publish", - "cloudrun", - "test.db", - "--service", - "datasette", - "--show-files", - "--extra-options", - extra_options, - ], - ) - assert result.exit_code == 0 - dockerfile = ( - result.output.split("==== Dockerfile ====\n")[1] - .split("\n====================\n")[0] - .strip() - ) - last_line = dockerfile.split("\n")[-1] - extra_options = ( - last_line.split("--inspect-file inspect-data.json")[1] - .split("--port")[0] - .strip() - ) - assert extra_options == expected diff --git a/tests/test_publish_heroku.py b/tests/test_publish_heroku.py deleted file mode 100644 index cab83654..00000000 --- a/tests/test_publish_heroku.py +++ /dev/null @@ -1,183 +0,0 @@ -from click.testing import CliRunner -from datasette import cli -from unittest import mock -import os -import pathlib -import pytest - - -@pytest.mark.serial -@mock.patch("shutil.which") -def test_publish_heroku_requires_heroku(mock_which, tmp_path_factory): - mock_which.return_value = False - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke(cli.cli, ["publish", "heroku", "test.db"]) - assert result.exit_code == 1 - assert "Publishing to Heroku requires heroku" in result.output - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.heroku.check_output") -@mock.patch("datasette.publish.heroku.call") -def test_publish_heroku_installs_plugin( - mock_call, mock_check_output, mock_which, tmp_path_factory -): - mock_which.return_value = True - mock_check_output.side_effect = lambda s: {"['heroku', 'plugins']": b""}[repr(s)] - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("t.db", "w") as fp: - fp.write("data") - result = runner.invoke(cli.cli, ["publish", "heroku", "t.db"], input="y\n") - assert 0 != result.exit_code - mock_check_output.assert_has_calls( - [mock.call(["heroku", "plugins"]), mock.call(["heroku", "apps:list", "--json"])] - ) - mock_call.assert_has_calls( - [mock.call(["heroku", "plugins:install", "heroku-builds"])] - ) - - -@mock.patch("shutil.which") -def test_publish_heroku_invalid_database(mock_which): - mock_which.return_value = True - runner = CliRunner() - result = runner.invoke(cli.cli, ["publish", "heroku", "woop.db"]) - assert result.exit_code == 2 - assert "Path 'woop.db' does not exist" in result.output - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.heroku.check_output") -@mock.patch("datasette.publish.heroku.call") -def test_publish_heroku(mock_call, mock_check_output, mock_which, tmp_path_factory): - mock_which.return_value = True - mock_check_output.side_effect = lambda s: { - "['heroku', 'plugins']": b"heroku-builds", - "['heroku', 'apps:list', '--json']": b"[]", - "['heroku', 'apps:create', 'datasette', '--json']": b'{"name": "f"}', - }[repr(s)] - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke(cli.cli, ["publish", "heroku", "test.db", "--tar", "gtar"]) - assert 0 == result.exit_code, result.output - mock_call.assert_has_calls( - [ - mock.call( - [ - "heroku", - "builds:create", - "-a", - "f", - "--include-vcs-ignore", - "--tar", - "gtar", - ] - ), - ] - ) - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.heroku.check_output") -@mock.patch("datasette.publish.heroku.call") -def test_publish_heroku_plugin_secrets( - mock_call, mock_check_output, mock_which, tmp_path_factory -): - mock_which.return_value = True - mock_check_output.side_effect = lambda s: { - "['heroku', 'plugins']": b"heroku-builds", - "['heroku', 'apps:list', '--json']": b"[]", - "['heroku', 'apps:create', 'datasette', '--json']": b'{"name": "f"}', - }[repr(s)] - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - result = runner.invoke( - cli.cli, - [ - "publish", - "heroku", - "test.db", - "--plugin-secret", - "datasette-auth-github", - "client_id", - "x-client-id", - ], - ) - assert 0 == result.exit_code, result.output - mock_call.assert_has_calls( - [ - mock.call( - [ - "heroku", - "config:set", - "-a", - "f", - "DATASETTE_AUTH_GITHUB_CLIENT_ID=x-client-id", - ] - ), - mock.call(["heroku", "builds:create", "-a", "f", "--include-vcs-ignore"]), - ] - ) - - -@pytest.mark.serial -@mock.patch("shutil.which") -@mock.patch("datasette.publish.heroku.check_output") -@mock.patch("datasette.publish.heroku.call") -def test_publish_heroku_generate_dir( - mock_call, mock_check_output, mock_which, tmp_path_factory -): - mock_which.return_value = True - mock_check_output.side_effect = lambda s: { - "['heroku', 'plugins']": b"heroku-builds", - }[repr(s)] - runner = CliRunner() - os.chdir(tmp_path_factory.mktemp("runner")) - with open("test.db", "w") as fp: - fp.write("data") - output = str(tmp_path_factory.mktemp("generate_dir") / "output") - result = runner.invoke( - cli.cli, - [ - "publish", - "heroku", - "test.db", - "--generate-dir", - output, - ], - ) - assert result.exit_code == 0 - path = pathlib.Path(output) - assert path.exists() - file_names = {str(r.relative_to(path)) for r in path.glob("*")} - assert file_names == { - "requirements.txt", - "bin", - "runtime.txt", - "Procfile", - "test.db", - } - for name, expected in ( - ("requirements.txt", "datasette"), - ("runtime.txt", "python-3.11.0"), - ( - "Procfile", - ( - "web: datasette serve --host 0.0.0.0 -i test.db " - "--cors --port $PORT --inspect-file inspect-data.json" - ), - ), - ): - with open(path / name) as fp: - assert fp.read().strip() == expected diff --git a/tests/test_pytest_autoclose_plugin.py b/tests/test_pytest_autoclose_plugin.py deleted file mode 100644 index 3af1aace..00000000 --- a/tests/test_pytest_autoclose_plugin.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -Tests for datasette._pytest_plugin — the pytest plugin that auto-closes -Datasette instances constructed inside test bodies. - -These tests drive a real pytest session in a subprocess so the plugin -operates exactly as it would for a downstream consumer. -""" - -import subprocess -import sys -import textwrap -from pathlib import Path - -REPO_ROOT = Path(__file__).parent.parent - - -def _run_pytest(tmp_path: Path) -> subprocess.CompletedProcess: - return subprocess.run( - [sys.executable, "-m", "pytest", "-v", str(tmp_path)], - cwd=str(tmp_path), - capture_output=True, - text=True, - ) - - -def test_auto_close_of_instances_made_in_test_body(tmp_path): - # Two ordered tests: - # test_a makes a Datasette() and stashes a hard reference - # test_b asserts that the hard-reffed instance was closed by the plugin - (tmp_path / "test_sample.py").write_text(textwrap.dedent(""" - from datasette.app import Datasette - - _stash = {} - - def test_a(): - ds = Datasette(memory=True) - _stash["ds"] = ds - assert ds._closed is False - - def test_b(): - assert _stash["ds"]._closed is True - """)) - result = _run_pytest(tmp_path) - assert result.returncode == 0, result.stdout + result.stderr - - -def test_fixture_scoped_instance_is_not_closed(tmp_path): - # A module-scoped fixture instance must survive across tests in the module. - (tmp_path / "test_fixture.py").write_text(textwrap.dedent(""" - import pytest - from datasette.app import Datasette - - @pytest.fixture(scope="module") - def ds(): - return Datasette(memory=True) - - def test_first(ds): - assert ds._closed is False - - def test_second(ds): - # Still alive because the plugin only tracks instances - # constructed during pytest_runtest_call, not during fixture - # setup. - assert ds._closed is False - """)) - result = _run_pytest(tmp_path) - assert result.returncode == 0, result.stdout + result.stderr - - -def test_opt_out_via_ini(tmp_path): - # datasette_autoclose = false should leave instances untouched. - (tmp_path / "pytest.ini").write_text(textwrap.dedent(""" - [pytest] - datasette_autoclose = false - """).strip()) - (tmp_path / "test_optout.py").write_text(textwrap.dedent(""" - from datasette.app import Datasette - - _stash = {} - - def test_a(): - ds = Datasette(memory=True) - _stash["ds"] = ds - - def test_b(): - # Opt-out: plugin must not have closed it. - assert _stash["ds"]._closed is False - _stash["ds"].close() - """)) - result = _run_pytest(tmp_path) - assert result.returncode == 0, result.stdout + result.stderr diff --git a/tests/test_queries.py b/tests/test_queries.py deleted file mode 100644 index 59fab8c0..00000000 --- a/tests/test_queries.py +++ /dev/null @@ -1,1874 +0,0 @@ -import json - -import pytest - -from datasette.app import Datasette -from datasette.resources import DatabaseResource, QueryResource -from datasette.stored_queries import StoredQuery, StoredQueryPage -from datasette.utils.asgi import Forbidden - - -async def add_numbered_queries(ds, database, count): - for i in range(1, count + 1): - await ds.add_query( - database, - "demo_query_{:02d}".format(i), - "select {} as query_number".format(i), - title="Demo query {:02d}".format(i), - description="Seeded demo query number {:02d}".format(i), - source="user", - owner_id="root", - ) - - -@pytest.mark.asyncio -async def test_queries_internal_table_schema(): - ds = Datasette(memory=True) - await ds.invoke_startup() - internal_db = ds.get_internal_database() - - columns = [ - row["name"] - for row in ( - await internal_db.execute("select name from pragma_table_info('queries')") - ) - ] - - assert columns == [ - "database_name", - "name", - "sql", - "title", - "description", - "description_html", - "options", - "parameters", - "is_write", - "is_private", - "is_trusted", - "source", - "owner_id", - "created_at", - "updated_at", - ] - - -@pytest.mark.asyncio -async def test_add_get_and_remove_query(): - ds = Datasette(memory=True) - ds.add_memory_database("query_api", name="data") - await ds.invoke_startup() - - await ds.add_query( - "data", - "top_customers", - "select * from customers where region = :region", - title="Top customers", - description="Customers by region", - hide_sql=True, - fragment="chart", - parameters=["region"], - is_trusted=True, - source="user", - owner_id="alice", - ) - - options_row = ( - await ds.get_internal_database().execute( - """ - SELECT options FROM queries - WHERE database_name = ? AND name = ? - """, - ["data", "top_customers"], - ) - ).first() - assert json.loads(options_row["options"]) == { - "fragment": "chart", - "hide_sql": True, - } - - query = await ds.get_query("data", "top_customers") - assert query == StoredQuery( - database="data", - name="top_customers", - sql="select * from customers where region = :region", - title="Top customers", - description="Customers by region", - description_html=None, - hide_sql=True, - fragment="chart", - parameters=["region"], - is_write=False, - is_private=False, - is_trusted=True, - source="user", - owner_id="alice", - on_success_message=None, - on_success_message_sql=None, - on_success_redirect=None, - on_error_message=None, - on_error_redirect=None, - ) - - queries_page = await ds.list_queries("data", actor=None) - assert queries_page == StoredQueryPage( - queries=[query], - next=None, - has_more=False, - limit=50, - ) - - await ds.remove_query("data", "top_customers") - assert await ds.get_query("data", "top_customers") is None - queries_page = await ds.list_queries("data", actor=None) - assert queries_page.queries == [] - assert queries_page.next is None - - -@pytest.mark.asyncio -async def test_update_query_only_updates_provided_fields(): - ds = Datasette(memory=True) - ds.add_memory_database("query_api_update", name="data") - await ds.invoke_startup() - - await ds.add_query( - "data", - "redirect", - "select 1", - title="Original", - on_success_redirect="/original", - parameters=["one"], - ) - - options_row = ( - await ds.get_internal_database().execute( - """ - SELECT options FROM queries - WHERE database_name = ? AND name = ? - """, - ["data", "redirect"], - ) - ).first() - assert json.loads(options_row["options"]) == {"on_success_redirect": "/original"} - - await ds.update_query( - "data", - "redirect", - title="Updated", - parameters=[], - on_success_redirect=None, - ) - - query = await ds.get_query("data", "redirect") - assert query.title == "Updated" - assert query.parameters == [] - assert query.on_success_redirect is None - assert query.sql == "select 1" - assert query.is_private is False - assert query.is_trusted is False - options_row = ( - await ds.get_internal_database().execute( - """ - SELECT options FROM queries - WHERE database_name = ? AND name = ? - """, - ["data", "redirect"], - ) - ).first() - assert json.loads(options_row["options"]) == {} - - -@pytest.mark.asyncio -async def test_config_queries_imported_to_internal_table(): - ds = Datasette( - memory=True, - config={ - "databases": { - "data": { - "queries": { - "configured": { - "sql": "select :name as name", - "title": "Configured query", - "description_html": "

Configured HTML

", - "params": ["name"], - "on_success_message_sql": "select 'Hello ' || :name", - } - } - } - } - }, - ) - ds.add_memory_database("query_config", name="data") - await ds.invoke_startup() - - assert await ds.get_query("data", "configured") == StoredQuery( - database="data", - name="configured", - sql="select :name as name", - title="Configured query", - description=None, - description_html="

Configured HTML

", - hide_sql=False, - fragment=None, - parameters=["name"], - is_write=False, - is_private=False, - is_trusted=True, - source="config", - owner_id=None, - on_success_message=None, - on_success_message_sql="select 'Hello ' || :name", - on_success_redirect=None, - on_error_message=None, - on_error_redirect=None, - ) - - -@pytest.mark.asyncio -async def test_query_resources_come_from_internal_table(): - ds = Datasette(memory=True) - ds.add_memory_database("query_resources", name="data") - await ds.invoke_startup() - await ds.add_query("data", "internal_query", "select 1", source="user") - - page = await ds.allowed_resources("view-query", actor=None) - - assert [(r.parent, r.child) for r in page.resources] == [("data", "internal_query")] - - -@pytest.mark.asyncio -async def test_default_deny_blocks_view_query_even_for_trusted_query(): - ds = Datasette(memory=True, default_deny=True) - ds.add_memory_database("query_permissions", name="data") - await ds.invoke_startup() - await ds.add_query("data", "trusted", "select 1", is_trusted=True) - - assert not await ds.allowed( - action="view-query", - resource=QueryResource("data", "trusted"), - actor=None, - ) - - -@pytest.mark.asyncio -async def test_view_query_default_allow_still_respects_private_restriction(): - ds = Datasette(memory=True) - ds.add_memory_database("default_view_query_permissions", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "private_report", - "select 1", - is_private=True, - source="user", - owner_id="alice", - ) - await ds.add_query( - "data", - "shared_report", - "select 2", - is_private=False, - source="user", - owner_id="alice", - ) - - assert await ds.allowed( - action="view-query", - resource=QueryResource("data", "shared_report"), - actor=None, - ) - assert await ds.allowed( - action="view-query", - resource=QueryResource("data", "private_report"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action="view-query", - resource=QueryResource("data", "private_report"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_private_query_restriction_blocks_broad_view_query_permission(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-query": {"id": "*"}, - } - } - } - }, - ) - ds.add_memory_database("private_query_permissions", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "private_report", - "select 1", - is_private=True, - source="user", - owner_id="alice", - ) - await ds.add_query( - "data", - "shared_report", - "select 2", - is_private=False, - source="user", - owner_id="alice", - ) - - assert await ds.allowed( - action="view-query", - resource=QueryResource("data", "private_report"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action="view-query", - resource=QueryResource("data", "private_report"), - actor={"id": "bob"}, - ) - assert await ds.allowed( - action="view-query", - resource=QueryResource("data", "shared_report"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_config_query_restriction_does_not_override_private_internal_query(): - ds = Datasette(memory=True, default_deny=True) - ds.add_memory_database("private_query_with_config_name", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "private_report", - "select 1", - is_private=True, - source="user", - owner_id="alice", - ) - ds.config = { - "databases": { - "data": { - "permissions": {"view-query": {"id": "*"}}, - "queries": {"private_report": {"sql": "select 2"}}, - } - } - } - - assert not await ds.allowed( - action="view-query", - resource=QueryResource("data", "private_report"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_untrusted_shared_query_execution_requires_execute_sql(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-database": {"id": "viewer"}, - "view-query": {"id": "viewer"}, - } - } - } - }, - ) - ds.add_memory_database("untrusted_query_execution", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "shared_report", - "select 1 as one", - is_private=False, - is_trusted=False, - source="user", - owner_id="alice", - ) - - denied_get = await ds.client.get("/data/shared_report.json", actor={"id": "viewer"}) - denied_post = await ds.client.post( - "/data/shared_report", - actor={"id": "viewer"}, - data={}, - ) - assert denied_get.status_code == 403 - assert denied_post.status_code == 403 - - ds.config["databases"]["data"]["permissions"]["execute-sql"] = {"id": "viewer"} - allowed = await ds.client.get("/data/shared_report.json", actor={"id": "viewer"}) - assert allowed.status_code == 200 - assert allowed.json()["rows"] == [{"one": 1}] - - -@pytest.mark.asyncio -async def test_config_queries_are_trusted_by_default_but_can_opt_out(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-query": {"id": "viewer"}, - }, - "queries": { - "trusted_report": {"sql": "select 1 as one"}, - "untrusted_report": { - "sql": "select 2 as two", - "is_trusted": False, - }, - }, - } - } - }, - ) - ds.add_memory_database("trusted_query_config", name="data") - await ds.invoke_startup() - - trusted = await ds.client.get("/data/trusted_report.json", actor={"id": "viewer"}) - untrusted = await ds.client.get( - "/data/untrusted_report.json", actor={"id": "viewer"} - ) - - assert trusted.status_code == 200 - assert trusted.json()["rows"] == [{"one": 1}] - assert untrusted.status_code == 403 - - -@pytest.mark.asyncio -async def test_database_page_query_preview_is_limited(): - ds = Datasette(memory=True) - ds.add_memory_database("query_preview", name="data") - await ds.invoke_startup() - await add_numbered_queries(ds, "data", 25) - - html_response = await ds.client.get("/data") - json_response = await ds.client.get("/data.json") - - assert html_response.status_code == 200 - assert "Demo query 05" in html_response.text - assert "Demo query 06" not in html_response.text - assert 'View 25 queries' in html_response.text - assert len(json_response.json()["queries"]) == 5 - assert json_response.json()["queries_more"] is True - assert json_response.json()["queries_count"] == 25 - - -@pytest.mark.asyncio -async def test_query_actions_are_registered(): - ds = Datasette() - await ds.invoke_startup() - - assert ds.get_action("execute-write-sql").resource_class is DatabaseResource - assert ds.get_action("store-query").resource_class is DatabaseResource - assert ds.get_action("update-query").resource_class is QueryResource - assert ds.get_action("delete-query").resource_class is QueryResource - - -@pytest.mark.asyncio -async def test_analyze_write_query_requires_table_permissions(): - ds = Datasette(memory=True, default_deny=True) - db = ds.add_memory_database("query_write_permissions", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - actor = {"id": "writer"} - await ds.add_query( - "data", - "write_dog", - "insert into dogs (name) values (:name)", - is_write=True, - source="user", - owner_id="writer", - ) - - with pytest.raises(Forbidden): - await ds.ensure_query_write_permissions( - "data", - "insert into dogs (name) values (:name)", - actor=actor, - ) - - ds.config = { - "databases": { - "data": { - "tables": { - "dogs": { - "permissions": { - "insert-row": {"id": "writer"}, - } - } - } - } - } - } - - await ds.ensure_query_write_permissions( - "data", - "insert into dogs (name) values (:name)", - actor=actor, - ) - - -@pytest.mark.asyncio -async def test_analyze_write_query_rejects_writes_to_attached_databases(): - ds = Datasette(memory=True, default_deny=True) - db = ds.add_memory_database("query_attached_writes", name="data") - await db.execute_write("attach database ':memory:' as extra") - await db.execute_write("create table extra.cats (id integer primary key)") - await ds.invoke_startup() - - with pytest.raises(Forbidden): - await ds.ensure_query_write_permissions( - "data", - "insert into extra.cats (id) values (1)", - actor={"id": "writer"}, - ) - - -@pytest.mark.asyncio -async def test_query_store_api_creates_read_only_query(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("query_store_api", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - response = await ds.client.post( - "/data/-/queries/store", - actor={"id": "root"}, - json={ - "query": { - "name": "by_name", - "sql": "select * from dogs where name = :name", - "title": "By name", - } - }, - ) - - assert response.status_code == 201 - data = response.json() - assert data["ok"] is True - assert data["query"]["name"] == "by_name" - assert data["query"]["parameters"] == ["name"] - assert data["query"]["is_write"] is False - assert data["query"]["source"] == "user" - assert data["query"]["owner_id"] == "root" - - -@pytest.mark.asyncio -async def test_query_list_and_definition_api(): - ds = Datasette(memory=True) - ds.root_enabled = True - ds.add_memory_database("query_list_api", name="data") - await ds.invoke_startup() - await add_numbered_queries(ds, "data", 12) - - list_response = await ds.client.get( - "/data/-/queries.json?_size=5", - actor={"id": "root"}, - ) - next_response = await ds.client.get( - "/data/-/queries.json?_size=5&_next={}".format(list_response.json()["next"]), - actor={"id": "root"}, - ) - definition_response = await ds.client.get( - "/data/demo_query_01/-/definition", - actor={"id": "root"}, - ) - - assert list_response.status_code == 200 - assert [query["name"] for query in list_response.json()["queries"]] == [ - "demo_query_01", - "demo_query_02", - "demo_query_03", - "demo_query_04", - "demo_query_05", - ] - assert list_response.json()["next"] - assert [query["name"] for query in next_response.json()["queries"]] == [ - "demo_query_06", - "demo_query_07", - "demo_query_08", - "demo_query_09", - "demo_query_10", - ] - assert definition_response.status_code == 200 - assert definition_response.json()["query"]["title"] == "Demo query 01" - - -@pytest.mark.asyncio -async def test_query_page_does_not_show_internal_source(): - ds = Datasette(memory=True) - ds.add_memory_database("query_page_source", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "stored_report", - "select 1 as one", - title="Stored report", - source="user", - owner_id="root", - ) - - response = await ds.client.get("/data/stored_report", actor={"id": "root"}) - - assert response.status_code == 200 - assert "Stored report" in response.text - assert "Data source:" not in response.text - - -@pytest.mark.asyncio -async def test_query_list_search_filter_and_html(): - ds = Datasette(memory=True) - ds.root_enabled = True - ds.add_memory_database("query_list_html", name="data") - await ds.invoke_startup() - await add_numbered_queries(ds, "data", 3) - await ds.add_query( - "data", - "private_query", - "select 'private'", - title="Private query", - is_private=True, - source="user", - owner_id="root", - ) - await ds.add_query( - "data", - "trusted_query", - "select 'trusted'", - title="Trusted query", - is_trusted=True, - source="config", - ) - await ds.add_query( - "data", - "writable_query", - "insert into dogs (name) values (:name)", - title="Writable query", - is_write=True, - source="user", - owner_id="root", - ) - - html_response = await ds.client.get( - "/data/-/queries?q=02", - actor={"id": "root"}, - ) - flags_response = await ds.client.get( - "/data/-/queries", - actor={"id": "root"}, - ) - json_response = await ds.client.get( - "/data/-/queries.json?q=02", - actor={"id": "root"}, - ) - filtered_response = await ds.client.get( - "/data/-/queries.json?is_private=1", - actor={"id": "root"}, - ) - filtered_write_response = await ds.client.get( - "/data/-/queries?is_write=1", - actor={"id": "root"}, - ) - filtered_private_response = await ds.client.get( - "/data/-/queries?is_private=1", - actor={"id": "root"}, - ) - - assert html_response.status_code == 200 - assert "Demo query 02" in html_response.text - assert "Demo query 01" not in html_response.text - assert 'class="query-list-results"' in html_response.text - assert 'class="query-list-facets"' in html_response.text - assert 'type="radio"' not in html_response.text - assert "Only the owning actor can view this query." not in html_response.text - assert ( - "Execution skips the usual SQL and write permission checks" - not in html_response.text - ) - assert flags_response.status_code == 200 - assert '
' in flags_response.text - assert '' in flags_response.text - assert '' not in flags_response.text - assert 'class="query-list-owner">root' in flags_response.text - assert 'class="query-list-pill">Read-only' in flags_response.text - assert ( - 'class="query-list-pill query-list-pill-write">Writable' - in flags_response.text - ) - assert ( - 'class="query-list-pill query-list-pill-private">Private' - in flags_response.text - ) - assert ( - 'class="query-list-pill query-list-pill-trusted">Trusted' - in flags_response.text - ) - assert ( - 'href="/data/-/queries?is_write=0">Read-only5' - in flags_response.text - ) - assert ( - 'href="/data/-/queries?is_write=1">Writable1' - in flags_response.text - ) - assert ( - 'href="/data/-/queries?is_private=0">Not private5' - in flags_response.text - ) - assert ( - 'href="/data/-/queries?is_private=1">Private1' - in flags_response.text - ) - assert "Only the owning actor can view this query." in flags_response.text - assert ( - "Execution skips the usual SQL and write permission checks" - in flags_response.text - ) - assert json_response.json()["queries"][0]["name"] == "demo_query_02" - assert [query["name"] for query in filtered_response.json()["queries"]] == [ - "private_query" - ] - assert "Writable query" in filtered_write_response.text - assert "Demo query 01" not in filtered_write_response.text - assert ( - 'query-list-facet-link query-list-facet-link-active" href="/data/-/queries"' - in filtered_write_response.text - ) - assert ( - 'Read-only0' - not in filtered_write_response.text - ) - assert ( - 'href="/data/-/queries?is_write=1&is_private=0">Not private1' - in filtered_write_response.text - ) - assert ( - 'Private0' - not in filtered_write_response.text - ) - assert "Private query" in filtered_private_response.text - assert "Demo query 01" not in filtered_private_response.text - assert ( - 'href="/data/-/queries?is_private=1&is_write=0">Read-only1' - in filtered_private_response.text - ) - assert ( - 'Writable0' - not in filtered_private_response.text - ) - assert ( - 'Not private0' - not in filtered_private_response.text - ) - - -@pytest.mark.asyncio -async def test_query_list_html_defaults_to_twenty_and_shows_pagination(): - ds = Datasette(memory=True) - ds.root_enabled = True - ds.add_memory_database("query_list_html_pagination", name="data") - await ds.invoke_startup() - await add_numbered_queries(ds, "data", 25) - - response = await ds.client.get("/data/-/queries", actor={"id": "root"}) - json_response = await ds.client.get("/data/-/queries.json", actor={"id": "root"}) - - assert response.status_code == 200 - assert response.text.count('aria-label="Query pagination"') == 1 - assert "Demo query 20" in response.text - assert "Demo query 21" not in response.text - assert 'href="/data/-/queries?_next=' in response.text - assert len(json_response.json()["queries"]) == 25 - - -@pytest.mark.asyncio -async def test_global_query_list_api_and_html(): - ds = Datasette(memory=True) - ds.root_enabled = True - ds.add_memory_database("query_list_global_alpha", name="alpha") - ds.add_memory_database("query_list_global_beta", name="beta") - await ds.invoke_startup() - await ds.add_query( - "alpha", - "alpha_first", - "select 1", - title="Alpha first", - source="user", - owner_id="root", - ) - await ds.add_query( - "alpha", - "alpha_second", - "select 2", - title="Alpha second", - source="user", - owner_id="root", - ) - await ds.add_query( - "beta", - "beta_first", - "select 3", - title="Beta first", - source="user", - owner_id="root", - ) - - list_response = await ds.client.get( - "/-/queries.json?_size=2", - actor={"id": "root"}, - ) - next_response = await ds.client.get( - "/-/queries.json?_size=2&_next={}".format(list_response.json()["next"]), - actor={"id": "root"}, - ) - html_response = await ds.client.get( - "/-/queries?q=Beta", - actor={"id": "root"}, - ) - - assert list_response.status_code == 200 - assert [ - (query["database"], query["name"]) for query in list_response.json()["queries"] - ] == [ - ("alpha", "alpha_first"), - ("alpha", "alpha_second"), - ] - assert list_response.json()["next"] - assert [ - (query["database"], query["name"]) for query in next_response.json()["queries"] - ] == [ - ("beta", "beta_first"), - ] - assert html_response.status_code == 200 - assert '' in html_response.text - assert 'class="query-list-database" href="/beta">beta' in html_response.text - assert "Beta first" in html_response.text - assert "Alpha first" not in html_response.text - - -@pytest.mark.asyncio -async def test_query_store_api_rejects_is_trusted(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-database": {"id": "writer"}, - "execute-sql": {"id": "writer"}, - "store-query": {"id": "writer"}, - } - } - } - }, - ) - ds.add_memory_database("query_trusted_api", name="data") - await ds.invoke_startup() - - response = await ds.client.post( - "/data/-/queries/store", - actor={"id": "writer"}, - json={"query": {"name": "trusted", "sql": "select 1", "is_trusted": True}}, - ) - - assert response.status_code == 400 - assert response.json()["errors"] == ["Invalid keys: is_trusted"] - - -@pytest.mark.asyncio -async def test_query_store_rejects_config_only_fields(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - ds.add_memory_database("query_config_only_fields_api", name="data") - await ds.invoke_startup() - - response = await ds.client.post( - "/data/-/queries/store", - actor={"id": "root"}, - json={ - "query": { - "name": "unsafe", - "sql": "select 1", - "description_html": "", - "on_success_message_sql": "select 'secret'", - } - }, - ) - form_response = await ds.client.post( - "/data/-/queries/store", - actor={"id": "root"}, - data={ - "name": "unsafe_form", - "sql": "select 1", - "description_html": "", - }, - ) - - assert response.status_code == 400 - assert response.json()["errors"] == [ - "Invalid keys: description_html, on_success_message_sql" - ] - assert form_response.status_code == 400 - assert "Invalid keys: description_html" in form_response.text - assert await ds.get_query("data", "unsafe") is None - assert await ds.get_query("data", "unsafe_form") is None - - -@pytest.mark.asyncio -async def test_query_store_api_creates_writable_query(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("query_write_api", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - response = await ds.client.post( - "/data/-/queries/store", - actor={"id": "root"}, - json={ - "query": { - "name": "insert_dog", - "sql": "insert into dogs (name) values (:name)", - } - }, - ) - - assert response.status_code == 201 - query = response.json()["query"] - assert query["is_write"] is True - assert query["is_private"] is True - assert query["is_trusted"] is False - assert query["parameters"] == ["name"] - - -@pytest.mark.asyncio -async def test_query_update_and_delete_api(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - ds.add_memory_database("query_update_api", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "editable", - "select 1", - title="Original", - source="user", - owner_id="root", - ) - - update_response = await ds.client.post( - "/data/editable/-/update", - actor={"id": "root"}, - json={ - "update": { - "title": "Updated", - "description": "Fresh", - "on_success_redirect": None, - }, - "return": True, - }, - ) - - assert update_response.status_code == 200 - updated = update_response.json()["query"] - assert updated["title"] == "Updated" - assert updated["description"] == "Fresh" - assert updated["on_success_redirect"] is None - - delete_response = await ds.client.post( - "/data/editable/-/delete", - actor={"id": "root"}, - json={}, - ) - - assert delete_response.status_code == 200 - assert delete_response.json() == {"ok": True} - assert await ds.get_query("data", "editable") is None - - -@pytest.mark.asyncio -async def test_query_update_api_rejects_config_only_fields(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("query_update_config_only_fields", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - await ds.add_query( - "data", - "editable", - "insert into dogs (name) values (:name)", - is_write=True, - source="user", - owner_id="root", - ) - - response = await ds.client.post( - "/data/editable/-/update", - actor={"id": "root"}, - json={ - "update": { - "description_html": "", - "on_success_message_sql": "select 'secret'", - } - }, - ) - - assert response.status_code == 400 - assert response.json()["errors"] == [ - "Invalid keys: description_html, on_success_message_sql" - ] - query = await ds.get_query("data", "editable") - assert query.description_html is None - assert query.on_success_message_sql is None - - -@pytest.mark.asyncio -async def test_query_update_api_rejects_trusted_queries_but_internal_update_allowed(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "execute-sql": {"id": "editor"}, - "update-query": {"id": "editor"}, - }, - "queries": { - "trusted_report": { - "sql": "select 1 as one", - "title": "Original", - }, - }, - } - } - }, - ) - ds.add_memory_database("query_update_trusted_api", name="data") - await ds.invoke_startup() - - response = await ds.client.post( - "/data/trusted_report/-/update", - actor={"id": "editor"}, - json={"update": {"sql": "select 2 as two", "title": "Edited"}}, - ) - - assert response.status_code == 403 - assert response.json()["errors"] == [ - "Trusted queries cannot be updated using the API" - ] - query = await ds.get_query("data", "trusted_report") - assert query.is_trusted is True - assert query.sql == "select 1 as one" - assert query.title == "Original" - - await ds.update_query( - "data", - "trusted_report", - sql="select 3 as three", - title="Internal", - ) - query = await ds.get_query("data", "trusted_report") - assert query.is_trusted is True - assert query.sql == "select 3 as three" - assert query.title == "Internal" - - -@pytest.mark.asyncio -async def test_query_store_api_rejects_magic_parameters(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - ds.add_memory_database("query_magic_api", name="data") - await ds.invoke_startup() - - response = await ds.client.post( - "/data/-/queries/store", - actor={"id": "root"}, - json={"query": {"name": "magic", "sql": "select :_actor_id"}}, - ) - - assert response.status_code == 400 - assert response.json()["errors"] == ["Magic parameters are not allowed"] - - -@pytest.mark.asyncio -async def test_create_query_ui_and_arbitrary_sql_save_link(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("query_create_ui", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - create_response = await ds.client.get( - "/data/-/queries/store?sql=select+*+from+dogs", - actor={"id": "root"}, - ) - write_create_response = await ds.client.get( - "/data/-/queries/store?sql=insert+into+dogs+(name)+values+('Cleo')", - actor={"id": "root"}, - ) - blank_create_response = await ds.client.get( - "/data/-/queries/store", - actor={"id": "root"}, - ) - old_insert_response = await ds.client.get( - "/data/-/queries/insert?sql=select+*+from+dogs", - actor={"id": "root"}, - ) - old_create_response = await ds.client.get( - "/data/-/queries/-/create?sql=select+*+from+dogs", - actor={"id": "root"}, - ) - query_response = await ds.client.get( - "/data/-/query?sql=select+*+from+dogs", - actor={"id": "root"}, - ) - - assert create_response.status_code == 200 - assert "Create query" in create_response.text - assert 'type="radio"' not in create_response.text - assert 'name="parameters"' not in create_response.text - assert 'id="query-parameters"' not in create_response.text - assert 'class="query-create-field"' in create_response.text - assert '' not in create_response.text - assert '' in create_response.text - assert '' in create_response.text - assert '/data/' in create_response.text - assert ( - '' - in create_response.text - ) - assert "function slugify(value)" in create_response.text - assert 'data-analyze-url="/data/-/queries/analyze"' in create_response.text - assert "setupSqlParameterRefresh" in create_response.text - assert "renderParameters: false" in create_response.text - assert "datasetteSqlAnalysis.renderAnalysis" in create_response.text - assert "data-query-create-submit" in create_response.text - assert "data-query-create-writable" not in create_response.text - assert "data-query-create-sql-type" not in create_response.text - assert "data-query-create-analysis-note" in create_response.text - assert "SQL type:" not in create_response.text - assert ( - 'This is a read-only query.' - in create_response.text - ) - assert "disabled> Writable" not in create_response.text - assert ( - "Queries marked private can only be seen by you, their creator." - in create_response.text - ) - assert create_response.text.index( - "This is a read-only query." - ) < create_response.text.index('') - assert "

Query operations

" in create_response.text - assert '
this …http…OwnerFlagsModeDatabase
' in create_response.text - assert '' in create_response.text - assert '' not in create_response.text - assert "" in create_response.text - assert ( - create_response.text.count( - '' - ) - == 2 - ) - assert create_response.text.index( - 'value="Save query"' - ) < create_response.text.index("

Query operations

") - assert blank_create_response.status_code == 200 - assert ( - '
Required permissionSourcereadn/a
' in response.text - assert '' in response.text - assert "" in response.text - assert "" in response.text - assert "" not in response.text - assert 'action="/data/-/execute-write"' in response.text - assert "insert into dogs (name) values ('Cleo')" in response.text - assert (await db.execute("select count(*) from dogs")).first()[0] == 0 - - empty_response = await ds.client.get( - "/data/-/execute-write", - actor={"id": "root"}, - ) - assert '' in empty_response.text - assert 'executeWriteSqlInput.value = "\\n\\n\\n";' in empty_response.text - assert "hidden>Save this query" in empty_response.text - - read_only_response = await ds.client.get( - "/data/-/execute-write?sql=select+*+from+dogs", - actor={"id": "root"}, - ) - assert ( - "Use /-/query for read-only SQL; this endpoint only executes writes" - in read_only_response.text - ) - assert "hidden>Save this query" in read_only_response.text - - -@pytest.mark.asyncio -async def test_execute_write_analyze_endpoint_uses_sql_only(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("execute_write_analyze", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - response = await ds.client.get( - "/data/-/execute-write/analyze", - actor={"id": "root"}, - params={"sql": "insert into dogs (name) values (:name)"}, - ) - read_only_response = await ds.client.get( - "/data/-/execute-write/analyze", - actor={"id": "root"}, - params={"sql": "select * from dogs where name = :name"}, - ) - - assert response.status_code == 200 - data = response.json() - assert data["ok"] is True - assert data["parameters"] == ["name"] - assert data["analysis_error"] is None - assert data["execute_disabled"] is False - assert data["analysis_rows"] == [ - { - "operation": "insert", - "database": "data", - "table": "dogs", - "required_permission": "insert-row", - "source": None, - "allowed": True, - } - ] - assert "params" not in data - - assert read_only_response.status_code == 200 - read_only_data = read_only_response.json() - assert read_only_data["ok"] is False - assert read_only_data["parameters"] == ["name"] - assert read_only_data["analysis_error"] == ( - "Use /-/query for read-only SQL; this endpoint only executes writes" - ) - assert read_only_data["execute_disabled"] is True - - -@pytest.mark.asyncio -async def test_query_parameters_endpoint_uses_get_sql_only(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("query_parameters", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - response = await ds.client.get( - "/data/-/query/parameters", - actor={"id": "root"}, - params={ - "sql": "select * from dogs where name = :name and id = :id", - }, - ) - permission_denied_response = await ds.client.get( - "/data/-/query/parameters", - actor={"id": "not-root"}, - params={"sql": "select * from dogs where name = :name"}, - ) - magic_parameter_response = await ds.client.get( - "/data/-/query/parameters", - actor={"id": "root"}, - params={"sql": "select :_actor_id"}, - ) - - assert response.status_code == 200 - assert response.json() == {"ok": True, "parameters": ["name", "id"]} - assert permission_denied_response.status_code == 403 - assert permission_denied_response.json()["errors"] == [ - "Permission denied: need execute-sql" - ] - assert magic_parameter_response.status_code == 400 - assert magic_parameter_response.json()["errors"] == [ - "Magic parameters are not allowed" - ] - - -@pytest.mark.asyncio -async def test_database_action_menu_links_to_execute_write_for_permitted_actor(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-database": { - "id": ["writer", "viewer"], - }, - "execute-write-sql": {"id": "writer"}, - } - } - } - }, - ) - ds.add_memory_database("execute_write_menu", name="data") - await ds.invoke_startup() - - anonymous_response = await ds.client.get("/data") - viewer_response = await ds.client.get("/data", actor={"id": "viewer"}) - writer_response = await ds.client.get("/data", actor={"id": "writer"}) - - assert anonymous_response.status_code == 403 - assert viewer_response.status_code == 200 - assert "Execute write SQL" not in viewer_response.text - assert writer_response.status_code == 200 - assert "Database actions" in writer_response.text - assert 'href="/data/-/execute-write"' in writer_response.text - assert "Execute write SQL" in writer_response.text - - -@pytest.mark.asyncio -async def test_database_action_menu_hides_execute_write_for_immutable_database(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-database": {"id": "writer"}, - "execute-write-sql": {"id": "writer"}, - } - } - } - }, - ) - db = ds.add_memory_database("execute_write_menu_immutable", name="data") - db.is_mutable = False - await ds.invoke_startup() - - response = await ds.client.get("/data", actor={"id": "writer"}) - - assert response.status_code == 200 - assert "Execute write SQL" not in response.text - assert 'href="/data/-/execute-write"' not in response.text - - -@pytest.mark.asyncio -async def test_execute_write_get_rejects_immutable_database(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("execute_write_get_immutable", name="data") - db.is_mutable = False - await ds.invoke_startup() - - response = await ds.client.get( - "/data/-/execute-write?sql=insert+into+dogs+(name)+values+('Cleo')", - actor={"id": "root"}, - ) - - assert response.status_code == 403 - assert response.json()["errors"] == [ - "Cannot execute write SQL because this database is immutable." - ] - - -@pytest.mark.asyncio -async def test_execute_write_post_requires_database_and_table_permissions(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-database": {"id": "writer"}, - "execute-write-sql": {"id": "writer"}, - } - } - } - }, - ) - db = ds.add_memory_database("execute_write_permissions", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - no_database_permission = await ds.client.post( - "/data/-/execute-write", - actor={"id": "outsider"}, - json={ - "sql": "insert into dogs (name) values (:name)", - "params": {"name": "Cleo"}, - }, - ) - no_table_permission = await ds.client.post( - "/data/-/execute-write", - actor={"id": "writer"}, - json={ - "sql": "insert into dogs (name) values (:name)", - "params": {"name": "Cleo"}, - }, - ) - - assert no_database_permission.status_code == 403 - assert no_database_permission.json()["errors"] == [ - "Permission denied: need execute-write-sql" - ] - assert no_table_permission.status_code == 403 - assert no_table_permission.json()["errors"] == [ - "Permission denied: need insert-row on data/dogs" - ] - - ds.config = { - "databases": { - "data": { - "permissions": { - "view-database": {"id": "writer"}, - "execute-write-sql": {"id": "writer"}, - }, - "tables": { - "dogs": { - "permissions": { - "insert-row": {"id": "writer"}, - } - } - }, - } - } - } - allowed = await ds.client.post( - "/data/-/execute-write", - actor={"id": "writer"}, - json={ - "sql": "insert into dogs (name) values (:name)", - "params": {"name": "Cleo"}, - }, - ) - - assert allowed.status_code == 200 - assert allowed.json()["ok"] is True - assert allowed.json()["rowcount"] == 1 - assert allowed.json()["analysis"][0]["operation"] == "insert" - assert (await db.execute("select name from dogs")).first()[0] == "Cleo" - - -@pytest.mark.asyncio -async def test_execute_write_insert_links_to_inserted_row(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("execute_write_insert_link", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await db.execute_write("create table log (id integer primary key, message text)") - await db.execute_write("insert into log (message) values ('existing')") - await db.execute_write(""" - create trigger dogs_after_insert after insert on dogs begin - insert into log (message) values (new.name); - end - """) - await ds.invoke_startup() - - insert_response = await ds.client.post( - "/data/-/execute-write", - actor={"id": "root"}, - data={ - "sql": "insert into dogs (name) values (:name)", - "name": "Cleo", - }, - ) - update_response = await ds.client.post( - "/data/-/execute-write", - actor={"id": "root"}, - data={ - "sql": "update dogs set name = :name where id = :id", - "name": "Cleo 2", - "id": "1", - }, - ) - - assert insert_response.status_code == 200 - assert "Query executed, 1 row affected" in insert_response.text - assert 'View row' in insert_response.text - assert "/data/log/2" not in insert_response.text - assert update_response.status_code == 200 - assert "Query executed, 1 row affected" in update_response.text - assert "View row" not in update_response.text - - -@pytest.mark.asyncio -async def test_execute_write_post_rejects_read_only_sql(): - ds = Datasette(memory=True, default_deny=True) - ds.root_enabled = True - db = ds.add_memory_database("execute_write_read_only", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - - response = await ds.client.post( - "/data/-/execute-write", - actor={"id": "root"}, - json={"sql": "select * from dogs"}, - ) - - assert response.status_code == 400 - assert response.json()["errors"] == [ - "Use /-/query for read-only SQL; this endpoint only executes writes" - ] - - -@pytest.mark.asyncio -async def test_query_owner_gets_update_delete_and_writable_view_defaults(): - ds = Datasette(memory=True, default_deny=True) - ds.add_memory_database("query_owner_defaults", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "insert_dog", - "insert into dogs (name) values (:name)", - is_write=True, - source="user", - owner_id="alice", - ) - - for action in ("view-query", "update-query", "delete-query"): - assert await ds.allowed( - action=action, - resource=QueryResource("data", "insert_dog"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action=action, - resource=QueryResource("data", "insert_dog"), - actor={"id": "bob"}, - ) - - -@pytest.mark.asyncio -async def test_private_query_restricts_broad_update_delete_permissions(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "update-query": {"id": "bob"}, - "delete-query": {"id": "bob"}, - }, - }, - }, - }, - ) - ds.add_memory_database("query_broad_update_delete", name="data") - await ds.invoke_startup() - await ds.add_query( - "data", - "alice_private", - "select 1", - is_private=True, - source="user", - owner_id="alice", - ) - await ds.add_query( - "data", - "alice_public", - "select 2", - is_private=False, - source="user", - owner_id="alice", - ) - - for action in ("update-query", "delete-query"): - assert await ds.allowed( - action=action, - resource=QueryResource("data", "alice_private"), - actor={"id": "alice"}, - ) - assert not await ds.allowed( - action=action, - resource=QueryResource("data", "alice_private"), - actor={"id": "bob"}, - ) - assert await ds.allowed( - action=action, - resource=QueryResource("data", "alice_public"), - actor={"id": "bob"}, - ) - - private_update_response = await ds.client.post( - "/data/alice_private/-/update", - actor={"id": "bob"}, - json={"update": {"title": "Nope"}}, - ) - private_delete_response = await ds.client.post( - "/data/alice_private/-/delete", - actor={"id": "bob"}, - json={}, - ) - public_update_response = await ds.client.post( - "/data/alice_public/-/update", - actor={"id": "bob"}, - json={"update": {"title": "Bob can edit public queries"}}, - ) - public_delete_response = await ds.client.post( - "/data/alice_public/-/delete", - actor={"id": "bob"}, - json={}, - ) - - assert private_update_response.status_code == 403 - assert private_delete_response.status_code == 403 - assert public_update_response.status_code == 200 - assert public_delete_response.status_code == 200 - assert await ds.get_query("data", "alice_private") is not None - assert await ds.get_query("data", "alice_public") is None - - -@pytest.mark.asyncio -async def test_user_writable_query_execution_rechecks_table_permissions(): - ds = Datasette( - memory=True, - default_deny=True, - config={ - "databases": { - "data": { - "permissions": { - "view-database": {"id": ["alice", "bob"]}, - "execute-write-sql": {"id": ["alice", "bob"]}, - }, - "tables": { - "dogs": { - "permissions": { - "insert-row": {"id": "alice"}, - } - } - }, - } - } - }, - ) - db = ds.add_memory_database("query_write_execution", name="data") - await db.execute_write("create table dogs (id integer primary key, name text)") - await ds.invoke_startup() - await ds.add_query( - "data", - "insert_dog", - "insert into dogs (name) values (:name)", - is_write=True, - source="user", - owner_id="alice", - ) - await ds.add_query( - "data", - "insert_cat", - "insert into dogs (name) values (:name)", - is_write=True, - source="user", - owner_id="bob", - ) - - allowed_response = await ds.client.post( - "/data/insert_dog?_json=1", - actor={"id": "alice"}, - data={"name": "Cleo"}, - ) - denied_response = await ds.client.post( - "/data/insert_cat?_json=1", - actor={"id": "bob"}, - data={"name": "Milo"}, - ) - - assert allowed_response.status_code == 200 - assert allowed_response.json()["ok"] is True - assert denied_response.status_code == 403 - rows = (await db.execute("select name from dogs")).dicts() - assert rows == [{"name": "Cleo"}] diff --git a/tests/test_restriction_sql.py b/tests/test_restriction_sql.py deleted file mode 100644 index df6abd29..00000000 --- a/tests/test_restriction_sql.py +++ /dev/null @@ -1,315 +0,0 @@ -import pytest -from datasette.app import Datasette -from datasette.permissions import PermissionSQL -from datasette.resources import TableResource - - -@pytest.mark.asyncio -async def test_multiple_restriction_sources_intersect(): - """ - Test that when multiple plugins return restriction_sql, they are INTERSECTed. - - This tests the case where both actor _r restrictions AND a plugin - provide restriction_sql - both must pass for access to be granted. - """ - from datasette import hookimpl - - class RestrictivePlugin: - __name__ = "RestrictivePlugin" - - @hookimpl - def permission_resources_sql(self, datasette, actor, action): - # Plugin adds additional restriction: only db1_multi_intersect allowed - if action == "view-table": - return PermissionSQL( - restriction_sql="SELECT 'db1_multi_intersect' AS parent, NULL AS child", - params={}, - ) - return None - - plugin = RestrictivePlugin() - - ds = Datasette() - await ds.invoke_startup() - ds.pm.register(plugin, name="restrictive_plugin") - - try: - db1 = ds.add_memory_database("db1_multi_intersect") - db2 = ds.add_memory_database("db2_multi_intersect") - await db1.execute_write("CREATE TABLE t1 (id INTEGER)") - await db2.execute_write("CREATE TABLE t1 (id INTEGER)") - await ds._refresh_schemas() # Populate catalog tables - - # Actor has restrictions allowing both databases - # But plugin only allows db1_multi_intersect - # INTERSECT means only db1_multi_intersect/t1 should pass - actor = { - "id": "user", - "_r": {"d": {"db1_multi_intersect": ["vt"], "db2_multi_intersect": ["vt"]}}, - } - - page = await ds.allowed_resources("view-table", actor) - resources = {(r.parent, r.child) for r in page.resources} - - # Should only see db1_multi_intersect/t1 (intersection of actor restrictions and plugin restrictions) - assert ("db1_multi_intersect", "t1") in resources - assert ("db2_multi_intersect", "t1") not in resources - finally: - ds.pm.unregister(name="restrictive_plugin") - - -@pytest.mark.asyncio -async def test_restriction_sql_with_overlapping_databases_and_tables(): - """ - Test actor with both database-level and table-level restrictions for same database. - - When actor has: - - Database-level: db1_overlapping allowed (all tables) - - Table-level: db1_overlapping/t1 allowed - - Both entries are UNION'd (OR'ed) within the actor's restrictions. - Database-level restriction allows ALL tables, so table-level is redundant. - """ - ds = Datasette() - await ds.invoke_startup() - db = ds.add_memory_database("db1_overlapping") - await db.execute_write("CREATE TABLE t1 (id INTEGER)") - await db.execute_write("CREATE TABLE t2 (id INTEGER)") - await ds._refresh_schemas() - - # Actor has BOTH database-level (db1_overlapping all tables) AND table-level (db1_overlapping/t1 only) - actor = { - "id": "user", - "_r": { - "d": { - "db1_overlapping": ["vt"] - }, # Database-level: all tables in db1_overlapping - "r": { - "db1_overlapping": {"t1": ["vt"]} - }, # Table-level: only t1 in db1_overlapping - }, - } - - # Within actor restrictions, entries are UNION'd (OR'ed): - # - Database level allows: (db1_overlapping, NULL) → matches all tables via hierarchical matching - # - Table level allows: (db1_overlapping, t1) → redundant, already covered by database level - # Result: Both tables are allowed - page = await ds.allowed_resources("view-table", actor) - resources = {(r.parent, r.child) for r in page.resources} - - assert ("db1_overlapping", "t1") in resources - # Database-level restriction allows all tables, so t2 is also allowed - assert ("db1_overlapping", "t2") in resources - - -@pytest.mark.asyncio -async def test_restriction_sql_empty_allowlist_query(): - """ - Test the specific SQL query generated when action is not in allowlist. - - actor_restrictions_sql() returns "SELECT NULL AS parent, NULL AS child WHERE 0" - Verify this produces an empty result set. - """ - ds = Datasette() - await ds.invoke_startup() - db = ds.add_memory_database("db1_empty_allowlist") - await db.execute_write("CREATE TABLE t1 (id INTEGER)") - await ds._refresh_schemas() - - # Actor has restrictions but action not in allowlist - actor = {"id": "user", "_r": {"r": {"db1_empty_allowlist": {"t1": ["vt"]}}}} - - # Try to view-database (only view-table is in allowlist) - page = await ds.allowed_resources("view-database", actor) - - # Should be empty - assert len(page.resources) == 0 - - -@pytest.mark.asyncio -async def test_restriction_sql_with_pagination(): - """ - Test that restrictions work correctly with keyset pagination. - """ - ds = Datasette() - await ds.invoke_startup() - db = ds.add_memory_database("db1_pagination") - - # Create many tables - for i in range(10): - await db.execute_write(f"CREATE TABLE t{i:02d} (id INTEGER)") - await ds._refresh_schemas() - - # Actor restricted to only odd-numbered tables - restrictions = {"r": {"db1_pagination": {}}} - for i in range(10): - if i % 2 == 1: # Only odd tables - restrictions["r"]["db1_pagination"][f"t{i:02d}"] = ["vt"] - - actor = {"id": "user", "_r": restrictions} - - # Get first page with small limit - page1 = await ds.allowed_resources( - "view-table", actor, parent="db1_pagination", limit=2 - ) - assert len(page1.resources) == 2 - assert page1.next is not None - - # Get second page using next token - page2 = await ds.allowed_resources( - "view-table", actor, parent="db1_pagination", limit=2, next=page1.next - ) - assert len(page2.resources) == 2 - - # Should have no overlap - page1_ids = {r.child for r in page1.resources} - page2_ids = {r.child for r in page2.resources} - assert page1_ids.isdisjoint(page2_ids) - - # All should be odd-numbered tables - all_ids = page1_ids | page2_ids - for table_id in all_ids: - table_num = int(table_id[1:]) # Extract number from "t01", "t03", etc. - assert table_num % 2 == 1, f"Table {table_id} should be odd-numbered" - - -@pytest.mark.asyncio -async def test_also_requires_with_restrictions(): - """ - Test that also_requires actions properly respect restrictions. - - execute-sql requires view-database. With restrictions, both must pass. - """ - ds = Datasette() - await ds.invoke_startup() - ds.add_memory_database("db1_also_requires") - ds.add_memory_database("db2_also_requires") - await ds._refresh_schemas() - - # Actor restricted to only db1_also_requires for view-database - # execute-sql requires view-database, so should only work on db1_also_requires - actor = { - "id": "user", - "_r": { - "d": { - "db1_also_requires": ["vd", "es"], - "db2_also_requires": [ - "es" - ], # They have execute-sql but not view-database - } - }, - } - - # db1_also_requires should allow execute-sql - result = await ds.allowed( - action="execute-sql", - resource=TableResource("db1_also_requires", None), - actor=actor, - ) - assert result is True - - # db2_also_requires should not (they have execute-sql but not view-database) - result = await ds.allowed( - action="execute-sql", - resource=TableResource("db2_also_requires", None), - actor=actor, - ) - assert result is False - - -@pytest.mark.asyncio -async def test_restriction_abbreviations_and_full_names(): - """ - Test that both abbreviations and full action names work in restrictions. - """ - ds = Datasette() - await ds.invoke_startup() - db = ds.add_memory_database("db1_abbrev") - await db.execute_write("CREATE TABLE t1 (id INTEGER)") - await ds._refresh_schemas() - - # Test with abbreviation - actor_abbr = {"id": "user", "_r": {"r": {"db1_abbrev": {"t1": ["vt"]}}}} - result = await ds.allowed( - action="view-table", - resource=TableResource("db1_abbrev", "t1"), - actor=actor_abbr, - ) - assert result is True - - # Test with full name - actor_full = {"id": "user", "_r": {"r": {"db1_abbrev": {"t1": ["view-table"]}}}} - result = await ds.allowed( - action="view-table", - resource=TableResource("db1_abbrev", "t1"), - actor=actor_full, - ) - assert result is True - - # Test with mixed - actor_mixed = {"id": "user", "_r": {"d": {"db1_abbrev": ["view-database", "vt"]}}} - result = await ds.allowed( - action="view-table", - resource=TableResource("db1_abbrev", "t1"), - actor=actor_mixed, - ) - assert result is True - - -@pytest.mark.asyncio -async def test_permission_resources_sql_multiple_restriction_sources_intersect(): - """ - Test that when multiple plugins return restriction_sql, they are INTERSECTed. - - This tests the case where both actor _r restrictions AND a plugin - provide restriction_sql - both must pass for access to be granted. - """ - from datasette import hookimpl - - class RestrictivePlugin: - __name__ = "RestrictivePlugin" - - @hookimpl - def permission_resources_sql(self, datasette, actor, action): - # Plugin adds additional restriction: only db1_multi_restrictions allowed - if action == "view-table": - return PermissionSQL( - restriction_sql="SELECT 'db1_multi_restrictions' AS parent, NULL AS child", - params={}, - ) - return None - - plugin = RestrictivePlugin() - - ds = Datasette() - await ds.invoke_startup() - ds.pm.register(plugin, name="restrictive_plugin") - - try: - db1 = ds.add_memory_database("db1_multi_restrictions") - db2 = ds.add_memory_database("db2_multi_restrictions") - await db1.execute_write("CREATE TABLE t1 (id INTEGER)") - await db2.execute_write("CREATE TABLE t1 (id INTEGER)") - await ds._refresh_schemas() # Populate catalog tables - - # Actor has restrictions allowing both databases - # But plugin only allows db1 - # INTERSECT means only db1/t1 should pass - actor = { - "id": "user", - "_r": { - "d": { - "db1_multi_restrictions": ["vt"], - "db2_multi_restrictions": ["vt"], - } - }, - } - - page = await ds.allowed_resources("view-table", actor) - resources = {(r.parent, r.child) for r in page.resources} - - # Should only see db1/t1 (intersection of actor restrictions and plugin restrictions) - assert ("db1_multi_restrictions", "t1") in resources - assert ("db2_multi_restrictions", "t1") not in resources - finally: - ds.pm.unregister(name="restrictive_plugin") diff --git a/tests/test_routes.py b/tests/test_routes.py deleted file mode 100644 index 24c702fc..00000000 --- a/tests/test_routes.py +++ /dev/null @@ -1,109 +0,0 @@ -from datasette.app import Datasette, Database -from datasette.utils import resolve_routes -import pytest -import pytest_asyncio - - -@pytest.fixture(scope="session") -def routes(): - ds = Datasette() - return ds._routes() - - -@pytest.mark.parametrize( - "path,expected_name,expected_matches", - ( - ("/", "IndexView", {"format": None}), - ("/foo", "DatabaseView", {"format": None, "database": "foo"}), - ("/foo.csv", "DatabaseView", {"format": "csv", "database": "foo"}), - ("/foo.json", "DatabaseView", {"format": "json", "database": "foo"}), - ("/foo.humbug", "DatabaseView", {"format": "humbug", "database": "foo"}), - ( - "/foo/humbug", - "table_view", - {"database": "foo", "table": "humbug", "format": None}, - ), - ( - "/foo/humbug.json", - "table_view", - {"database": "foo", "table": "humbug", "format": "json"}, - ), - ( - "/foo/humbug.blah", - "table_view", - {"database": "foo", "table": "humbug", "format": "blah"}, - ), - ( - "/foo/humbug/1", - "RowView", - {"format": None, "database": "foo", "pks": "1", "table": "humbug"}, - ), - ( - "/foo/humbug/1.json", - "RowView", - {"format": "json", "database": "foo", "pks": "1", "table": "humbug"}, - ), - ), -) -def test_routes(routes, path, expected_name, expected_matches): - match, view = resolve_routes(routes, path) - if expected_name is None: - assert match is None - else: - assert ( - view.__name__ == expected_name or view.view_class.__name__ == expected_name - ) - assert match.groupdict() == expected_matches - - -@pytest_asyncio.fixture -async def ds_with_route(): - ds = Datasette() - await ds.invoke_startup() - ds.remove_database("_memory") - db = Database(ds, is_memory=True, memory_name="route-name-db") - ds.add_database(db, name="original-name", route="custom-route-name") - await db.execute_write_script(""" - create table if not exists t (id integer primary key); - insert or replace into t (id) values (1); - """) - return ds - - -@pytest.mark.asyncio -async def test_db_with_route_databases(ds_with_route): - response = await ds_with_route.client.get("/-/databases.json") - assert response.json()[0] == { - "name": "original-name", - "route": "custom-route-name", - "path": None, - "size": 0, - "is_mutable": True, - "is_memory": True, - "hash": None, - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_status", - ( - ("/", 200), - ("/original-name", 404), - ("/original-name/t", 404), - ("/original-name/t/1", 404), - ("/custom-route-name", 200), - ("/custom-route-name/-/query?sql=select+id+from+t", 200), - ("/custom-route-name/t", 200), - ("/custom-route-name/t/1", 200), - ), -) -async def test_db_with_route_that_does_not_match_name( - ds_with_route, path, expected_status -): - response = await ds_with_route.client.get(path) - assert response.status_code == expected_status - # There should be links to custom-route-name but none to original-name - if response.status_code == 200: - assert "/custom-route-name" in response.text - assert "/original-name" not in response.text diff --git a/tests/test_schema_endpoints.py b/tests/test_schema_endpoints.py deleted file mode 100644 index c95d8614..00000000 --- a/tests/test_schema_endpoints.py +++ /dev/null @@ -1,247 +0,0 @@ -import pytest -import pytest_asyncio -from datasette.app import Datasette - - -@pytest_asyncio.fixture(scope="module") -async def schema_ds(): - """Create a Datasette instance with test databases and permission config.""" - ds = Datasette( - config={ - "databases": { - "schema_private_db": {"allow": {"id": "root"}}, - } - } - ) - - # Create public database with multiple tables - public_db = ds.add_memory_database("schema_public_db") - await public_db.execute_write( - "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT)" - ) - await public_db.execute_write( - "CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, title TEXT)" - ) - await public_db.execute_write( - "CREATE VIEW IF NOT EXISTS recent_posts AS SELECT * FROM posts ORDER BY id DESC" - ) - - # Create a database with restricted access (requires root permission) - private_db = ds.add_memory_database("schema_private_db") - await private_db.execute_write( - "CREATE TABLE IF NOT EXISTS secret_data (id INTEGER PRIMARY KEY, value TEXT)" - ) - - # Create an empty database - ds.add_memory_database("schema_empty_db") - - return ds - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "format_ext,expected_in_content", - [ - ("json", None), - ("md", ["# Schema for", "```sql"]), - ("", ["Schema for", "CREATE TABLE"]), - ], -) -async def test_database_schema_formats(schema_ds, format_ext, expected_in_content): - """Test /database/-/schema endpoint in different formats.""" - url = "/schema_public_db/-/schema" - if format_ext: - url += f".{format_ext}" - response = await schema_ds.client.get(url) - assert response.status_code == 200 - - if format_ext == "json": - data = response.json() - assert "database" in data - assert data["database"] == "schema_public_db" - assert "schema" in data - assert "CREATE TABLE users" in data["schema"] - else: - content = response.text - for expected in expected_in_content: - assert expected in content - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "format_ext,expected_in_content", - [ - ("json", None), - ("md", ["# Schema for", "```sql"]), - ("", ["Schema for all databases"]), - ], -) -async def test_instance_schema_formats(schema_ds, format_ext, expected_in_content): - """Test /-/schema endpoint in different formats.""" - url = "/-/schema" - if format_ext: - url += f".{format_ext}" - response = await schema_ds.client.get(url) - assert response.status_code == 200 - - if format_ext == "json": - data = response.json() - assert "schemas" in data - assert isinstance(data["schemas"], list) - db_names = [item["database"] for item in data["schemas"]] - # Should see schema_public_db and schema_empty_db, but not schema_private_db (anonymous user) - assert "schema_public_db" in db_names - assert "schema_empty_db" in db_names - assert "schema_private_db" not in db_names - # Check schemas are present - for item in data["schemas"]: - if item["database"] == "schema_public_db": - assert "CREATE TABLE users" in item["schema"] - else: - content = response.text - for expected in expected_in_content: - assert expected in content - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "format_ext,expected_in_content", - [ - ("json", None), - ("md", ["# Schema for", "```sql"]), - ("", ["Schema for users"]), - ], -) -async def test_table_schema_formats(schema_ds, format_ext, expected_in_content): - """Test /database/table/-/schema endpoint in different formats.""" - url = "/schema_public_db/users/-/schema" - if format_ext: - url += f".{format_ext}" - response = await schema_ds.client.get(url) - assert response.status_code == 200 - - if format_ext == "json": - data = response.json() - assert "database" in data - assert data["database"] == "schema_public_db" - assert "table" in data - assert data["table"] == "users" - assert "schema" in data - assert "CREATE TABLE users" in data["schema"] - else: - content = response.text - for expected in expected_in_content: - assert expected in content - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "url", - [ - "/schema_private_db/-/schema.json", - "/schema_private_db/secret_data/-/schema.json", - ], -) -async def test_schema_permission_enforcement(schema_ds, url): - """Test that permissions are enforced for schema endpoints.""" - # Anonymous user should get 403 - response = await schema_ds.client.get(url) - assert response.status_code == 403 - - # Authenticated user with permission should succeed - response = await schema_ds.client.get( - url, - actor={"id": "root"}, - ) - assert response.status_code == 200 - - -@pytest.mark.asyncio -async def test_instance_schema_respects_database_permissions(schema_ds): - """Test that /-/schema only shows databases the user can view.""" - # Anonymous user should only see public databases - response = await schema_ds.client.get("/-/schema.json") - assert response.status_code == 200 - data = response.json() - db_names = [item["database"] for item in data["schemas"]] - assert "schema_public_db" in db_names - assert "schema_empty_db" in db_names - assert "schema_private_db" not in db_names - - # Authenticated user should see all databases - response = await schema_ds.client.get( - "/-/schema.json", - actor={"id": "root"}, - ) - assert response.status_code == 200 - data = response.json() - db_names = [item["database"] for item in data["schemas"]] - assert "schema_public_db" in db_names - assert "schema_empty_db" in db_names - assert "schema_private_db" in db_names - - -@pytest.mark.asyncio -async def test_database_schema_with_multiple_tables(schema_ds): - """Test schema with multiple tables in a database.""" - response = await schema_ds.client.get("/schema_public_db/-/schema.json") - assert response.status_code == 200 - data = response.json() - schema = data["schema"] - - # All objects should be in the schema - assert "CREATE TABLE users" in schema - assert "CREATE TABLE posts" in schema - assert "CREATE VIEW recent_posts" in schema - - -@pytest.mark.asyncio -async def test_empty_database_schema(schema_ds): - """Test schema for an empty database.""" - response = await schema_ds.client.get("/schema_empty_db/-/schema.json") - assert response.status_code == 200 - data = response.json() - assert data["database"] == "schema_empty_db" - assert data["schema"] == "" - - -@pytest.mark.asyncio -async def test_database_not_exists(schema_ds): - """Test schema for a non-existent database returns 404.""" - # Test JSON format - response = await schema_ds.client.get("/nonexistent_db/-/schema.json") - assert response.status_code == 404 - data = response.json() - assert data["ok"] is False - assert "not found" in data["error"].lower() - - # Test HTML format (returns text) - response = await schema_ds.client.get("/nonexistent_db/-/schema") - assert response.status_code == 404 - assert "not found" in response.text.lower() - - # Test Markdown format (returns text) - response = await schema_ds.client.get("/nonexistent_db/-/schema.md") - assert response.status_code == 404 - assert "not found" in response.text.lower() - - -@pytest.mark.asyncio -async def test_table_not_exists(schema_ds): - """Test schema for a non-existent table returns 404.""" - # Test JSON format - response = await schema_ds.client.get("/schema_public_db/nonexistent/-/schema.json") - assert response.status_code == 404 - data = response.json() - assert data["ok"] is False - assert "not found" in data["error"].lower() - - # Test HTML format (returns text) - response = await schema_ds.client.get("/schema_public_db/nonexistent/-/schema") - assert response.status_code == 404 - assert "not found" in response.text.lower() - - # Test Markdown format (returns text) - response = await schema_ds.client.get("/schema_public_db/nonexistent/-/schema.md") - assert response.status_code == 404 - assert "not found" in response.text.lower() diff --git a/tests/test_search_tables.py b/tests/test_search_tables.py deleted file mode 100644 index ce774327..00000000 --- a/tests/test_search_tables.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Tests for special endpoints in datasette/views/special.py -""" - -import pytest -import pytest_asyncio -from datasette.app import Datasette - - -@pytest_asyncio.fixture -async def ds_with_tables(): - """Create a Datasette instance with some tables for searching.""" - ds = Datasette( - config={ - "databases": { - "content": { - "allow": {"id": "*"}, # Allow all authenticated users - "tables": { - "articles": { - "allow": {"id": "editor"}, # Only editor can view - }, - "comments": { - "allow": True, # Everyone can view - }, - }, - }, - "private": { - "allow": False, # Deny everyone - }, - } - } - ) - await ds.invoke_startup() - - # Add content database with some tables - content_db = ds.add_memory_database("search_tables_content", name="content") - await content_db.execute_write( - "CREATE TABLE IF NOT EXISTS articles (id INTEGER PRIMARY KEY, title TEXT)" - ) - await content_db.execute_write( - "CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, body TEXT)" - ) - await content_db.execute_write( - "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT)" - ) - - # Add private database with a table - private_db = ds.add_memory_database("search_tables_private", name="private") - await private_db.execute_write( - "CREATE TABLE IF NOT EXISTS secrets (id INTEGER PRIMARY KEY, data TEXT)" - ) - - # Add another public database - public_db = ds.add_memory_database("search_tables_public", name="public") - await public_db.execute_write( - "CREATE TABLE IF NOT EXISTS articles (id INTEGER PRIMARY KEY, content TEXT)" - ) - await ds._refresh_schemas() - - return ds - - -# /-/jump.json table search tests -@pytest.mark.asyncio -async def test_tables_basic_search(ds_with_tables): - """Test basic table search functionality.""" - # Search for "articles" - should find it in both content and public databases - # but only return public.articles for anonymous user (content.articles requires auth) - response = await ds_with_tables.client.get("/-/jump.json?q=articles") - assert response.status_code == 200 - data = response.json() - - # Should only see public.articles (content.articles restricted to authenticated users) - assert "matches" in data - assert len(data["matches"]) == 1 - - match = data["matches"][0] - assert "url" in match - assert "name" in match - assert match["name"] == "public: articles" - assert "/public/articles" in match["url"] - - -@pytest.mark.asyncio -async def test_tables_search_with_auth(ds_with_tables): - """Test that authenticated users see more tables.""" - # Editor user should see content.articles - response = await ds_with_tables.client.get( - "/-/jump.json?q=articles", - actor={"id": "editor"}, - ) - assert response.status_code == 200 - data = response.json() - - # Should see both content.articles and public.articles - assert len(data["matches"]) == 2 - - names = {match["name"] for match in data["matches"]} - assert names == {"content: articles", "public: articles"} - - -@pytest.mark.asyncio -async def test_tables_search_partial_match(ds_with_tables): - """Test that search matches partial table names.""" - # Search for "com" should match "comments" - response = await ds_with_tables.client.get( - "/-/jump.json?q=com", - actor={"id": "user"}, - ) - assert response.status_code == 200 - data = response.json() - - assert len(data["matches"]) == 1 - assert data["matches"][0]["name"] == "content: comments" - - -@pytest.mark.asyncio -async def test_tables_search_respects_database_permissions(ds_with_tables): - """Test that tables from denied databases are not shown.""" - # Search for "secrets" which is in the private database - # Even authenticated users shouldn't see it because database is denied - response = await ds_with_tables.client.get( - "/-/jump.json?q=secrets", - actor={"id": "user"}, - ) - assert response.status_code == 200 - data = response.json() - - # Should not see secrets table from private database - assert len(data["matches"]) == 0 - - -@pytest.mark.asyncio -async def test_tables_search_respects_table_permissions(ds_with_tables): - """Test that tables with specific permissions are filtered correctly.""" - # Regular authenticated user searching for "users" - response = await ds_with_tables.client.get( - "/-/jump.json?q=users", - actor={"id": "regular"}, - ) - assert response.status_code == 200 - data = response.json() - - # Should see content.users (authenticated users can view content database) - assert len(data["matches"]) == 1 - assert data["matches"][0]["name"] == "content: users" - - -@pytest.mark.asyncio -async def test_tables_search_response_structure(ds_with_tables): - """Test that response has correct structure.""" - response = await ds_with_tables.client.get( - "/-/jump.json?q=users", - actor={"id": "user"}, - ) - assert response.status_code == 200 - data = response.json() - - assert "matches" in data - assert isinstance(data["matches"], list) - - if data["matches"]: - match = data["matches"][0] - assert "url" in match - assert "name" in match - assert isinstance(match["url"], str) - assert isinstance(match["name"], str) - # Name should be in format "database: table" - assert ": " in match["name"] diff --git a/tests/test_spatialite.py b/tests/test_spatialite.py deleted file mode 100644 index c07a30e8..00000000 --- a/tests/test_spatialite.py +++ /dev/null @@ -1,23 +0,0 @@ -from datasette.app import Datasette -from datasette.utils import find_spatialite, SpatialiteNotFound, SPATIALITE_FUNCTIONS -from .utils import has_load_extension -import pytest - - -def has_spatialite(): - try: - find_spatialite() - return True - except SpatialiteNotFound: - return False - - -@pytest.mark.asyncio -@pytest.mark.skipif(not has_spatialite(), reason="Requires SpatiaLite") -@pytest.mark.skipif(not has_load_extension(), reason="Requires enable_load_extension") -async def test_spatialite_version_info(): - ds = Datasette(sqlite_extensions=["spatialite"]) - response = await ds.client.get("/-/versions.json") - assert response.status_code == 200 - spatialite = response.json()["sqlite"]["extensions"]["spatialite"] - assert set(SPATIALITE_FUNCTIONS) == set(spatialite) diff --git a/tests/test_stored_queries.py b/tests/test_stored_queries.py deleted file mode 100644 index 2c648d5f..00000000 --- a/tests/test_stored_queries.py +++ /dev/null @@ -1,473 +0,0 @@ -from bs4 import BeautifulSoup as Soup -from asgiref.sync import async_to_sync -import json -import pytest -import re -from .fixtures import make_app_client - - -def update_query(client, name, **kwargs): - async_to_sync(client.ds.invoke_startup)() - async_to_sync(client.ds.update_query)("data", name, **kwargs) - - -@pytest.fixture -def stored_write_client(tmpdir): - template_dir = tmpdir / "stored_write_templates" - template_dir.mkdir() - (template_dir / "query-data-update_name.html").write_text( - """ - {% extends "query.html" %} - {% block content %}!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!{{ super() }}{% endblock %} - """, - "utf-8", - ) - with make_app_client( - extra_databases={"data.db": "create table names (name text)"}, - template_dir=str(template_dir), - config={ - "databases": { - "data": { - "queries": { - "stored_read": {"sql": "select * from names"}, - "add_name": { - "sql": "insert into names (name) values (:name)", - "write": True, - "on_success_redirect": "/data/add_name?success", - }, - "add_name_specify_id": { - "sql": "insert into names (rowid, name) values (:rowid, :name)", - "on_success_message_sql": "select 'Name added: ' || :name || ' with rowid ' || :rowid", - "write": True, - "on_error_redirect": "/data/add_name_specify_id?error", - }, - "add_name_specify_id_with_error_in_on_success_message_sql": { - "sql": "insert into names (rowid, name) values (:rowid, :name)", - "on_success_message_sql": "select this is bad SQL", - "write": True, - }, - "delete_name": { - "sql": "delete from names where rowid = :rowid", - "write": True, - "on_success_message": "Name deleted", - "allow": {"id": "root"}, - }, - "update_name": { - "sql": "update names set name = :name where rowid = :rowid", - "params": ["rowid", "name", "extra"], - "write": True, - }, - } - } - } - }, - ) as client: - yield client - - -@pytest.fixture -def stored_write_immutable_client(): - with make_app_client( - is_immutable=True, - config={ - "databases": { - "fixtures": { - "queries": { - "add": { - "sql": "insert into sortable (text) values (:text)", - "write": True, - }, - } - } - } - }, - ) as client: - yield client - - -@pytest.mark.asyncio -async def test_stored_query_with_named_parameter(ds_client): - response = await ds_client.get( - "/fixtures/neighborhood_search.json?text=town&_shape=arrays" - ) - assert response.json()["rows"] == [ - ["Corktown", "Detroit", "MI"], - ["Downtown", "Los Angeles", "CA"], - ["Downtown", "Detroit", "MI"], - ["Greektown", "Detroit", "MI"], - ["Koreatown", "Los Angeles", "CA"], - ["Mexicantown", "Detroit", "MI"], - ] - - -def test_insert(stored_write_client): - response = stored_write_client.post( - "/data/add_name", - {"name": "Hello"}, - csrftoken_from=True, - cookies={"foo": "bar"}, - ) - messages = stored_write_client.ds.unsign( - response.cookies["ds_messages"], "messages" - ) - assert messages == [["Query executed, 1 row affected", 1]] - assert response.status == 302 - assert response.headers["Location"] == "/data/add_name?success" - - -def test_insert_blocked_cross_site(stored_write_client): - # A cross-site POST (browser-originated) must be blocked - response = stored_write_client.post( - "/data/add_name", - {"name": "Hello"}, - headers={"sec-fetch-site": "cross-site"}, - ) - assert 403 == response.status - - -def test_insert_no_cookies_no_csrf(stored_write_client): - response = stored_write_client.post("/data/add_name", {"name": "Hello"}) - assert 302 == response.status - assert "/data/add_name?success" == response.headers["Location"] - - -def test_custom_success_message(stored_write_client): - response = stored_write_client.post( - "/data/delete_name", - {"rowid": 1}, - cookies={"ds_actor": stored_write_client.actor_cookie({"id": "root"})}, - csrftoken_from=True, - ) - assert 302 == response.status - messages = stored_write_client.ds.unsign( - response.cookies["ds_messages"], "messages" - ) - assert [["Name deleted", 1]] == messages - - -def test_insert_error(stored_write_client): - stored_write_client.post("/data/add_name", {"name": "Hello"}, csrftoken_from=True) - response = stored_write_client.post( - "/data/add_name_specify_id", - {"rowid": 1, "name": "Should fail"}, - csrftoken_from=True, - ) - assert 302 == response.status - assert "/data/add_name_specify_id?error" == response.headers["Location"] - messages = stored_write_client.ds.unsign( - response.cookies["ds_messages"], "messages" - ) - assert [["UNIQUE constraint failed: names.rowid", 3]] == messages - # How about with a custom error message? - update_query(stored_write_client, "add_name_specify_id", on_error_message="ERROR") - response = stored_write_client.post( - "/data/add_name_specify_id", - {"rowid": 1, "name": "Should fail"}, - csrftoken_from=True, - ) - assert [["ERROR", 3]] == stored_write_client.ds.unsign( - response.cookies["ds_messages"], "messages" - ) - - -def test_on_success_message_sql(stored_write_client): - response = stored_write_client.post( - "/data/add_name_specify_id", - {"rowid": 5, "name": "Should be OK"}, - csrftoken_from=True, - ) - assert response.status == 302 - assert response.headers["Location"] == "/data/add_name_specify_id" - messages = stored_write_client.ds.unsign( - response.cookies["ds_messages"], "messages" - ) - assert messages == [["Name added: Should be OK with rowid 5", 1]] - - -def test_error_in_on_success_message_sql(stored_write_client): - response = stored_write_client.post( - "/data/add_name_specify_id_with_error_in_on_success_message_sql", - {"rowid": 1, "name": "Should fail"}, - csrftoken_from=True, - ) - messages = stored_write_client.ds.unsign( - response.cookies["ds_messages"], "messages" - ) - assert messages == [ - ["Error running on_success_message_sql: no such column: bad", 3] - ] - - -def test_custom_params(stored_write_client): - response = stored_write_client.get("/data/update_name?extra=foo") - assert ( - '' - in response.text - ) - - -def test_stored_query_pages_no_vary_header(stored_write_client): - # These pages no longer embed per-cookie CSRF tokens, so they must not - # set Vary: Cookie - they should be cacheable across users. - assert "vary" not in stored_write_client.get("/data").headers - assert "vary" not in stored_write_client.get("/data/update_name").headers - - -def test_json_post_body(stored_write_client): - response = stored_write_client.post( - "/data/add_name", - body=json.dumps({"name": ["Hello", "there"]}), - ) - assert 302 == response.status - assert "/data/add_name?success" == response.headers["Location"] - rows = stored_write_client.get("/data/names.json?_shape=array").json - assert rows == [{"rowid": 1, "name": "['Hello', 'there']"}] - - -@pytest.mark.parametrize( - "headers,body,querystring", - ( - (None, "name=NameGoesHere", "?_json=1"), - ({"Accept": "application/json"}, "name=NameGoesHere", None), - (None, "name=NameGoesHere&_json=1", None), - (None, '{"name": "NameGoesHere", "_json": 1}', None), - ), -) -def test_json_response(stored_write_client, headers, body, querystring): - response = stored_write_client.post( - "/data/add_name" + (querystring or ""), - body=body, - headers=headers, - ) - assert 200 == response.status - assert response.headers["content-type"] == "application/json; charset=utf-8" - assert response.json == { - "ok": True, - "message": "Query executed, 1 row affected", - "redirect": "/data/add_name?success", - } - rows = stored_write_client.get("/data/names.json?_shape=array").json - assert rows == [{"rowid": 1, "name": "NameGoesHere"}] - - -def test_stored_query_permissions_on_database_page(stored_write_client): - # Without auth shows the five public queries - anon_response = stored_write_client.get("/data.json") - query_names = {q["name"] for q in anon_response.json["queries"]} - assert query_names == { - "add_name_specify_id_with_error_in_on_success_message_sql", - "update_name", - "add_name_specify_id", - "stored_read", - "add_name", - } - assert anon_response.json["queries_more"] is False - - # With auth the database page preview shows the first five queries - response = stored_write_client.get( - "/data.json", - cookies={"ds_actor": stored_write_client.actor_cookie({"id": "root"})}, - ) - assert response.status == 200 - query_names_and_private = sorted( - [ - {"name": q["name"], "private": q["private"]} - for q in response.json["queries"] - ], - key=lambda q: q["name"], - ) - assert query_names_and_private == [ - {"name": "add_name", "private": False}, - {"name": "add_name_specify_id", "private": False}, - { - "name": "add_name_specify_id_with_error_in_on_success_message_sql", - "private": False, - }, - {"name": "delete_name", "private": True}, - {"name": "stored_read", "private": False}, - ] - assert response.json["queries_more"] is True - - # The full query list endpoint includes the remaining query - response = stored_write_client.get( - "/data/-/queries.json?_size=10", - cookies={"ds_actor": stored_write_client.actor_cookie({"id": "root"})}, - ) - assert response.status == 200 - query_names_and_private = sorted( - [ - {"name": q["name"], "private": q["private"]} - for q in response.json["queries"] - ], - key=lambda q: q["name"], - ) - assert query_names_and_private == [ - {"name": "add_name", "private": False}, - {"name": "add_name_specify_id", "private": False}, - { - "name": "add_name_specify_id_with_error_in_on_success_message_sql", - "private": False, - }, - {"name": "delete_name", "private": True}, - {"name": "stored_read", "private": False}, - {"name": "update_name", "private": False}, - ] - - -def test_stored_query_permissions(stored_write_client): - assert 403 == stored_write_client.get("/data/delete_name").status - assert 200 == stored_write_client.get("/data/update_name").status - cookies = {"ds_actor": stored_write_client.actor_cookie({"id": "root"})} - assert 200 == stored_write_client.get("/data/delete_name", cookies=cookies).status - assert 200 == stored_write_client.get("/data/update_name", cookies=cookies).status - - -@pytest.fixture(scope="session") -def magic_parameters_client(): - with make_app_client( - extra_databases={"data.db": "create table logs (line text)"}, - config={ - "databases": { - "data": { - "queries": { - "runme_post": {"sql": "", "write": True}, - "runme_get": {"sql": ""}, - } - } - } - }, - ) as client: - yield client - - -@pytest.mark.parametrize( - "magic_parameter,expected_re", - [ - ("_actor_id", "root"), - ("_header_host", "localhost"), - ("_header_not_a_thing", ""), - ("_cookie_foo", "bar"), - ("_now_epoch", r"^\d+$"), - ("_now_date_utc", r"^\d{4}-\d{2}-\d{2}$"), - ("_now_datetime_utc", r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$"), - ("_random_chars_1", r"^\w$"), - ("_random_chars_10", r"^\w{10}$"), - ], -) -def test_magic_parameters(magic_parameters_client, magic_parameter, expected_re): - update_query( - magic_parameters_client, - "runme_post", - sql=f"insert into logs (line) values (:{magic_parameter})", - ) - update_query( - magic_parameters_client, - "runme_get", - sql=f"select :{magic_parameter} as result", - ) - cookies = { - "ds_actor": magic_parameters_client.actor_cookie({"id": "root"}), - "foo": "bar", - } - # Test the GET version - get_response = magic_parameters_client.get( - "/data/runme_get.json?_shape=array", cookies=cookies - ) - get_actual = get_response.json[0]["result"] - assert re.match(expected_re, str(get_actual)) - # Test the form - form_response = magic_parameters_client.get("/data/runme_post") - soup = Soup(form_response.body, "html.parser") - # The magic parameter should not be represented as a form field - assert None is soup.find("input", {"name": magic_parameter}) - # Submit the form to create a log line - response = magic_parameters_client.post( - "/data/runme_post?_json=1", {}, csrftoken_from=True, cookies=cookies - ) - assert response.json == { - "ok": True, - "message": "Query executed, 1 row affected", - "redirect": None, - } - post_actual = magic_parameters_client.get( - "/data/logs.json?_sort_desc=rowid&_shape=array" - ).json[0]["line"] - assert re.match(expected_re, post_actual) - - -@pytest.mark.parametrize("use_csrf", [True, False]) -@pytest.mark.parametrize("return_json", [True, False]) -def test_magic_parameters_csrf_json(magic_parameters_client, use_csrf, return_json): - update_query( - magic_parameters_client, - "runme_post", - sql="insert into logs (line) values (:_header_host)", - ) - qs = "" - if return_json: - qs = "?_json=1" - response = magic_parameters_client.post( - f"/data/runme_post{qs}", - {}, - csrftoken_from=use_csrf or None, - ) - if return_json: - assert response.status == 200 - assert response.json["ok"], response.json - else: - assert response.status == 302 - messages = magic_parameters_client.ds.unsign( - response.cookies["ds_messages"], "messages" - ) - assert [["Query executed, 1 row affected", 1]] == messages - post_actual = magic_parameters_client.get( - "/data/logs.json?_sort_desc=rowid&_shape=array" - ).json[0]["line"] - assert post_actual == "localhost" - - -def test_magic_parameters_cannot_be_used_in_arbitrary_queries(magic_parameters_client): - response = magic_parameters_client.get( - "/data/-/query.json?sql=select+:_header_host&_shape=array" - ) - assert 400 == response.status - assert response.json["error"].startswith("You did not supply a value for binding") - - -def test_stored_write_custom_template(stored_write_client): - response = stored_write_client.get("/data/update_name") - assert response.status == 200 - assert "!!!CUSTOM_UPDATE_NAME_TEMPLATE!!!" in response.text - assert ( - "" - in response.text - ) - # And test for link rel=alternate while we're here: - assert ( - '' - in response.text - ) - assert ( - response.headers["link"] - == '; rel="alternate"; type="application/json+datasette"' - ) - - -def test_stored_write_query_disabled_for_immutable_database( - stored_write_immutable_client, -): - response = stored_write_immutable_client.get("/fixtures/add") - assert response.status == 200 - assert ( - "This query cannot be executed because the database is immutable." - in response.text - ) - assert '' in response.text - # Submitting form should get a forbidden error - response = stored_write_immutable_client.post( - "/fixtures/add", - {"text": "text"}, - csrftoken_from=True, - ) - assert response.status == 403 - assert "Database is immutable" in response.text diff --git a/tests/test_table_api.py b/tests/test_table_api.py deleted file mode 100644 index ceeb646d..00000000 --- a/tests/test_table_api.py +++ /dev/null @@ -1,1441 +0,0 @@ -from datasette.utils import detect_json1 -from datasette.utils.sqlite import sqlite_version -from datasette.fixtures import generate_compound_rows, generate_sortable_rows -from .fixtures import make_app_client -import json -import pytest -import urllib - - -@pytest.mark.asyncio -async def test_table_json(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.json?_extra=query") - assert response.status_code == 200 - data = response.json() - assert ( - data["query"]["sql"] - == "select id, content from simple_primary_key order by id limit 51" - ) - assert data["query"]["params"] == {} - assert data["rows"] == [ - {"id": 1, "content": "hello"}, - {"id": 2, "content": "world"}, - {"id": 3, "content": ""}, - {"id": 4, "content": "RENDER_CELL_DEMO"}, - {"id": 5, "content": "RENDER_CELL_ASYNC"}, - ] - - -@pytest.mark.asyncio -async def test_table_not_exists_json(ds_client): - assert (await ds_client.get("/fixtures/blah.json")).json() == { - "ok": False, - "error": "Table not found", - "status": 404, - "title": None, - } - - -@pytest.mark.asyncio -async def test_table_shape_arrays(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.json?_shape=arrays") - assert response.json()["rows"] == [ - [1, "hello"], - [2, "world"], - [3, ""], - [4, "RENDER_CELL_DEMO"], - [5, "RENDER_CELL_ASYNC"], - ] - - -@pytest.mark.asyncio -async def test_table_shape_arrayfirst(ds_client): - response = await ds_client.get( - "/fixtures/-/query.json?" - + urllib.parse.urlencode( - { - "sql": "select content from simple_primary_key order by id", - "_shape": "arrayfirst", - } - ) - ) - assert response.json() == [ - "hello", - "world", - "", - "RENDER_CELL_DEMO", - "RENDER_CELL_ASYNC", - ] - - -@pytest.mark.asyncio -async def test_table_shape_objects(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.json?_shape=objects") - assert response.json()["rows"] == [ - {"id": 1, "content": "hello"}, - {"id": 2, "content": "world"}, - {"id": 3, "content": ""}, - {"id": 4, "content": "RENDER_CELL_DEMO"}, - {"id": 5, "content": "RENDER_CELL_ASYNC"}, - ] - - -@pytest.mark.asyncio -async def test_table_shape_array(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.json?_shape=array") - assert response.json() == [ - {"id": 1, "content": "hello"}, - {"id": 2, "content": "world"}, - {"id": 3, "content": ""}, - {"id": 4, "content": "RENDER_CELL_DEMO"}, - {"id": 5, "content": "RENDER_CELL_ASYNC"}, - ] - - -@pytest.mark.asyncio -async def test_table_shape_array_nl(ds_client): - response = await ds_client.get( - "/fixtures/simple_primary_key.json?_shape=array&_nl=on" - ) - lines = response.text.split("\n") - results = [json.loads(line) for line in lines] - assert [ - {"id": 1, "content": "hello"}, - {"id": 2, "content": "world"}, - {"id": 3, "content": ""}, - {"id": 4, "content": "RENDER_CELL_DEMO"}, - {"id": 5, "content": "RENDER_CELL_ASYNC"}, - ] == results - - -@pytest.mark.asyncio -async def test_table_shape_invalid(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.json?_shape=invalid") - assert response.json() == { - "ok": False, - "error": "Invalid _shape: invalid", - "status": 400, - "title": None, - } - - -@pytest.mark.asyncio -async def test_table_shape_object(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key.json?_shape=object") - assert response.json() == { - "1": {"id": 1, "content": "hello"}, - "2": {"id": 2, "content": "world"}, - "3": {"id": 3, "content": ""}, - "4": {"id": 4, "content": "RENDER_CELL_DEMO"}, - "5": {"id": 5, "content": "RENDER_CELL_ASYNC"}, - } - - -@pytest.mark.asyncio -async def test_table_shape_object_compound_primary_key(ds_client): - response = await ds_client.get("/fixtures/compound_primary_key.json?_shape=object") - assert response.json() == { - "a,b": {"pk1": "a", "pk2": "b", "content": "c"}, - "a~2Fb,~2Ec-d": {"pk1": "a/b", "pk2": ".c-d", "content": "c"}, - "d,e": {"pk1": "d", "pk2": "e", "content": "RENDER_CELL_DEMO"}, - } - - -@pytest.mark.asyncio -async def test_table_with_slashes_in_name(ds_client): - response = await ds_client.get( - "/fixtures/table~2Fwith~2Fslashes~2Ecsv.json?_shape=objects" - ) - assert response.status_code == 200 - data = response.json() - assert data["rows"] == [{"pk": "3", "content": "hey"}] - - -@pytest.mark.asyncio -async def test_table_with_reserved_word_name(ds_client): - response = await ds_client.get("/fixtures/select.json?_shape=objects") - assert response.status_code == 200 - data = response.json() - assert data["rows"] == [ - { - "rowid": 1, - "group": "group", - "having": "having", - "and": "and", - "json": '{"href": "http://example.com/", "label":"Example"}', - } - ] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_rows,expected_pages", - [ - ("/fixtures/no_primary_key.json", 202, 5), - ("/fixtures/paginated_view.json", 202, 9), - ("/fixtures/no_primary_key.json?_size=25", 202, 9), - ("/fixtures/paginated_view.json?_size=50", 202, 5), - ("/fixtures/paginated_view.json?_size=max", 202, 3), - ("/fixtures/123_starts_with_digits.json", 0, 1), - # Ensure faceting doesn't break pagination: - ("/fixtures/compound_three_primary_keys.json?_facet=pk1", 1001, 21), - # Paginating while sorted by an expanded foreign key should work - ( - "/fixtures/roadside_attraction_characteristics.json?_size=2&_sort=attraction_id&_labels=on", - 5, - 3, - ), - ], -) -async def test_paginate_tables_and_views( - ds_client, path, expected_rows, expected_pages -): - fetched = [] - count = 0 - while path: - if "?" in path: - path += "&_extra=next_url" - else: - path += "?_extra=next_url" - response = await ds_client.get(path) - assert response.status_code == 200 - count += 1 - fetched.extend(response.json()["rows"]) - path = response.json()["next_url"] - if path: - assert urllib.parse.urlencode({"_next": response.json()["next"]}) in path - path = path.replace("http://localhost", "") - assert count < 30, "Possible infinite loop detected" - - assert expected_rows == len(fetched) - assert expected_pages == count - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_error", - [ - ("/fixtures/no_primary_key.json?_size=-4", "_size must be a positive integer"), - ("/fixtures/no_primary_key.json?_size=dog", "_size must be a positive integer"), - ("/fixtures/no_primary_key.json?_size=1001", "_size must be <= 100"), - ], -) -async def test_validate_page_size(ds_client, path, expected_error): - response = await ds_client.get(path) - assert expected_error == response.json()["error"] - assert response.status_code == 400 - - -@pytest.mark.asyncio -async def test_page_size_zero(ds_client): - """For _size=0 we return the counts, empty rows and no continuation token""" - response = await ds_client.get( - "/fixtures/no_primary_key.json?_size=0&_extra=count,next_url" - ) - assert response.status_code == 200 - assert [] == response.json()["rows"] - assert 202 == response.json()["count"] - assert None is response.json()["next"] - assert None is response.json()["next_url"] - - -@pytest.mark.asyncio -async def test_paginate_compound_keys(ds_client): - fetched = [] - path = "/fixtures/compound_three_primary_keys.json?_shape=objects&_extra=next_url" - page = 0 - while path: - page += 1 - response = await ds_client.get(path) - fetched.extend(response.json()["rows"]) - path = response.json()["next_url"] - if path: - path = path.replace("http://localhost", "") - assert page < 100 - assert 1001 == len(fetched) - assert 21 == page - # Should be correctly ordered - contents = [f["content"] for f in fetched] - expected = [r[3] for r in generate_compound_rows(1001)] - assert expected == contents - - -@pytest.mark.asyncio -async def test_paginate_compound_keys_with_extra_filters(ds_client): - fetched = [] - path = "/fixtures/compound_three_primary_keys.json?content__contains=d&_shape=objects&_extra=next_url" - page = 0 - while path: - page += 1 - assert page < 100 - response = await ds_client.get(path) - fetched.extend(response.json()["rows"]) - path = response.json()["next_url"] - if path: - path = path.replace("http://localhost", "") - assert 2 == page - expected = [r[3] for r in generate_compound_rows(1001) if "d" in r[3]] - assert expected == [f["content"] for f in fetched] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "query_string,sort_key,human_description_en", - [ - ("_sort=sortable", lambda row: row["sortable"], "sorted by sortable"), - ( - "_sort_desc=sortable", - lambda row: -row["sortable"], - "sorted by sortable descending", - ), - ( - "_sort=sortable_with_nulls", - lambda row: ( - 1 if row["sortable_with_nulls"] is not None else 0, - row["sortable_with_nulls"], - ), - "sorted by sortable_with_nulls", - ), - ( - "_sort_desc=sortable_with_nulls", - lambda row: ( - 1 if row["sortable_with_nulls"] is None else 0, - ( - -row["sortable_with_nulls"] - if row["sortable_with_nulls"] is not None - else 0 - ), - row["content"], - ), - "sorted by sortable_with_nulls descending", - ), - # text column contains '$null' - ensure it doesn't confuse pagination: - ("_sort=text", lambda row: row["text"], "sorted by text"), - # Still works if sort column removed using _col= - ("_sort=text&_col=content", lambda row: row["text"], "sorted by text"), - ], -) -async def test_sortable(ds_client, query_string, sort_key, human_description_en): - path = f"/fixtures/sortable.json?_shape=objects&_extra=human_description_en,next_url&{query_string}" - fetched = [] - page = 0 - while path: - page += 1 - assert page < 100 - response = await ds_client.get(path) - assert human_description_en == response.json()["human_description_en"] - fetched.extend(response.json()["rows"]) - path = response.json()["next_url"] - if path: - path = path.replace("http://localhost", "") - assert page == 5 - expected = list(generate_sortable_rows(201)) - expected.sort(key=sort_key) - assert [r["content"] for r in expected] == [r["content"] for r in fetched] - - -@pytest.mark.asyncio -async def test_sortable_and_filtered(ds_client): - path = ( - "/fixtures/sortable.json" - "?content__contains=d&_sort_desc=sortable&_shape=objects" - "&_extra=human_description_en,count" - ) - response = await ds_client.get(path) - fetched = response.json()["rows"] - assert ( - 'where content contains "d" sorted by sortable descending' - == response.json()["human_description_en"] - ) - expected = [row for row in generate_sortable_rows(201) if "d" in row["content"]] - assert len(expected) == response.json()["count"] - expected.sort(key=lambda row: -row["sortable"]) - assert [r["content"] for r in expected] == [r["content"] for r in fetched] - - -@pytest.mark.asyncio -async def test_sortable_argument_errors(ds_client): - response = await ds_client.get("/fixtures/sortable.json?_sort=badcolumn") - assert "Cannot sort table by badcolumn" == response.json()["error"] - response = await ds_client.get("/fixtures/sortable.json?_sort_desc=badcolumn2") - assert "Cannot sort table by badcolumn2" == response.json()["error"] - response = await ds_client.get( - "/fixtures/sortable.json?_sort=sortable_with_nulls&_sort_desc=sortable" - ) - assert ( - "Cannot use _sort and _sort_desc at the same time" == response.json()["error"] - ) - - -@pytest.mark.asyncio -async def test_sortable_columns_metadata(ds_client): - response = await ds_client.get("/fixtures/sortable.json?_sort=content") - assert "Cannot sort table by content" == response.json()["error"] - # no_primary_key has ALL sort options disabled - for column in ("content", "a", "b", "c"): - response = await ds_client.get(f"/fixtures/sortable.json?_sort={column}") - assert f"Cannot sort table by {column}" == response.json()["error"] - - -@pytest.mark.asyncio -@pytest.mark.xfail -@pytest.mark.parametrize( - "path,expected_rows", - [ - ( - "/fixtures/searchable.json?_shape=arrays&_search=dog", - [ - [1, "barry cat", "terry dog", "panther"], - [2, "terry dog", "sara weasel", "puma"], - ], - ), - ( - # Special keyword shouldn't break FTS query - "/fixtures/searchable.json?_shape=arrays&_search=AND", - [], - ), - ( - # Without _searchmode=raw this should return no results - "/fixtures/searchable.json?_shape=arrays&_search=te*+AND+do*", - [], - ), - ( - # _searchmode=raw - "/fixtures/searchable.json?_shape=arrays&_search=te*+AND+do*&_searchmode=raw", - [ - [1, "barry cat", "terry dog", "panther"], - [2, "terry dog", "sara weasel", "puma"], - ], - ), - ( - # _searchmode=raw combined with _search_COLUMN - "/fixtures/searchable.json?_shape=arrays&_search_text2=te*&_searchmode=raw", - [ - [1, "barry cat", "terry dog", "panther"], - ], - ), - ( - "/fixtures/searchable.json?_shape=arrays&_search=weasel", - [[2, "terry dog", "sara weasel", "puma"]], - ), - ( - "/fixtures/searchable.json?_shape=arrays&_search_text2=dog", - [[1, "barry cat", "terry dog", "panther"]], - ), - ( - "/fixtures/searchable.json?_shape=arrays&_search_name%20with%20.%20and%20spaces=panther", - [[1, "barry cat", "terry dog", "panther"]], - ), - ], -) -async def test_searchable(ds_client, path, expected_rows): - response = await ds_client.get(path) - assert expected_rows == response.json()["rows"] - - -_SEARCHMODE_RAW_RESULTS = [ - [1, "barry cat", "terry dog", "panther"], - [2, "terry dog", "sara weasel", "puma"], -] - - -@pytest.mark.parametrize( - "table_metadata,querystring,expected_rows", - [ - ( - {}, - "_search=te*+AND+do*", - [], - ), - ( - {"searchmode": "raw"}, - "_search=te*+AND+do*", - _SEARCHMODE_RAW_RESULTS, - ), - ( - {}, - "_search=te*+AND+do*&_searchmode=raw", - _SEARCHMODE_RAW_RESULTS, - ), - # Can be over-ridden with _searchmode=escaped - ( - {"searchmode": "raw"}, - "_search=te*+AND+do*&_searchmode=escaped", - [], - ), - ], -) -def test_searchmode(table_metadata, querystring, expected_rows): - with make_app_client( - metadata={"databases": {"fixtures": {"tables": {"searchable": table_metadata}}}} - ) as client: - response = client.get("/fixtures/searchable.json?_shape=arrays&" + querystring) - assert expected_rows == response.json["rows"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_rows", - [ - ( - "/fixtures/searchable_view_configured_by_metadata.json?_shape=arrays&_search=weasel", - [[2, "terry dog", "sara weasel", "puma"]], - ), - # This should return all results because search is not configured: - ( - "/fixtures/searchable_view.json?_shape=arrays&_search=weasel", - [ - [1, "barry cat", "terry dog", "panther"], - [2, "terry dog", "sara weasel", "puma"], - ], - ), - ( - "/fixtures/searchable_view.json?_shape=arrays&_search=weasel&_fts_table=searchable_fts&_fts_pk=pk", - [[2, "terry dog", "sara weasel", "puma"]], - ), - ], -) -async def test_searchable_views(ds_client, path, expected_rows): - response = await ds_client.get(path) - assert response.json()["rows"] == expected_rows - - -@pytest.mark.asyncio -async def test_searchable_invalid_column(ds_client): - response = await ds_client.get("/fixtures/searchable.json?_search_invalid=x") - assert response.status_code == 400 - assert response.json() == { - "ok": False, - "error": "Cannot search by that column", - "status": 400, - "title": None, - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_rows", - [ - ( - "/fixtures/simple_primary_key.json?_shape=arrays&content=hello", - [[1, "hello"]], - ), - ( - "/fixtures/simple_primary_key.json?_shape=arrays&content__contains=o", - [ - [1, "hello"], - [2, "world"], - [4, "RENDER_CELL_DEMO"], - ], - ), - ( - "/fixtures/simple_primary_key.json?_shape=arrays&content__exact=", - [[3, ""]], - ), - ( - "/fixtures/simple_primary_key.json?_shape=arrays&content__not=world", - [ - [1, "hello"], - [3, ""], - [4, "RENDER_CELL_DEMO"], - [5, "RENDER_CELL_ASYNC"], - ], - ), - ], -) -async def test_table_filter_queries(ds_client, path, expected_rows): - response = await ds_client.get(path) - assert response.json()["rows"] == expected_rows - - -@pytest.mark.asyncio -async def test_table_filter_queries_multiple_of_same_type(ds_client): - response = await ds_client.get( - "/fixtures/simple_primary_key.json?_shape=arrays&content__not=world&content__not=hello" - ) - assert [ - [3, ""], - [4, "RENDER_CELL_DEMO"], - [5, "RENDER_CELL_ASYNC"], - ] == response.json()["rows"] - - -@pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module") -@pytest.mark.asyncio -async def test_table_filter_json_arraycontains(ds_client): - response = await ds_client.get( - "/fixtures/facetable.json?_shape=arrays&tags__arraycontains=tag1" - ) - assert response.json()["rows"] == [ - [ - 1, - "2019-01-14 08:00:00", - 1, - 1, - "CA", - 1, - "Mission", - '["tag1", "tag2"]', - '[{"foo": "bar"}]', - "one", - "n1", - ], - [ - 2, - "2019-01-14 08:00:00", - 1, - 1, - "CA", - 1, - "Dogpatch", - '["tag1", "tag3"]', - "[]", - "two", - "n2", - ], - ] - - -@pytest.mark.skipif(not detect_json1(), reason="Requires the SQLite json1 module") -@pytest.mark.asyncio -async def test_table_filter_json_arraynotcontains(ds_client): - response = await ds_client.get( - "/fixtures/facetable.json?_shape=arrays&tags__arraynotcontains=tag3&tags__not=[]" - ) - assert response.json()["rows"] == [ - [ - 1, - "2019-01-14 08:00:00", - 1, - 1, - "CA", - 1, - "Mission", - '["tag1", "tag2"]', - '[{"foo": "bar"}]', - "one", - "n1", - ] - ] - - -@pytest.mark.asyncio -async def test_table_filter_extra_where(ds_client): - response = await ds_client.get( - "/fixtures/facetable.json?_shape=arrays&_where=_neighborhood='Dogpatch'" - ) - assert [ - [ - 2, - "2019-01-14 08:00:00", - 1, - 1, - "CA", - 1, - "Dogpatch", - '["tag1", "tag3"]', - "[]", - "two", - "n2", - ] - ] == response.json()["rows"] - - -@pytest.mark.asyncio -async def test_table_filter_extra_where_invalid(ds_client): - response = await ds_client.get( - "/fixtures/facetable.json?_where=_neighborhood=Dogpatch'" - ) - assert response.status_code == 400 - assert "Invalid SQL" == response.json()["title"] - - -def test_table_filter_extra_where_disabled_if_no_sql_allowed(): - with make_app_client(config={"allow_sql": {}}) as client: - response = client.get( - "/fixtures/facetable.json?_where=_neighborhood='Dogpatch'" - ) - assert response.status_code == 403 - assert "_where= is not allowed" == response.json["error"] - - -@pytest.mark.asyncio -async def test_table_through(ds_client): - # Just the museums: - response = await ds_client.get( - "/fixtures/roadside_attractions.json?_shape=arrays" - '&_through={"table":"roadside_attraction_characteristics","column":"characteristic_id","value":"1"}' - "&_extra=human_description_en" - ) - assert response.json()["rows"] == [ - [ - 3, - "Burlingame Museum of PEZ Memorabilia", - "214 California Drive, Burlingame, CA 94010", - None, - 37.5793, - -122.3442, - ], - [ - 4, - "Bigfoot Discovery Museum", - "5497 Highway 9, Felton, CA 95018", - "https://www.bigfootdiscoveryproject.com/", - 37.0414, - -122.0725, - ], - ] - - assert ( - response.json()["human_description_en"] - == 'where roadside_attraction_characteristics.characteristic_id = "1"' - ) - - -@pytest.mark.asyncio -async def test_max_returned_rows(ds_client): - response = await ds_client.get( - "/fixtures/-/query.json?sql=select+content+from+no_primary_key" - ) - data = response.json() - assert data["truncated"] - assert 100 == len(data["rows"]) - - -@pytest.mark.asyncio -async def test_view(ds_client): - response = await ds_client.get("/fixtures/simple_view.json?_shape=objects") - assert response.status_code == 200 - data = response.json() - assert data["rows"] == [ - {"upper_content": "HELLO", "content": "hello"}, - {"upper_content": "WORLD", "content": "world"}, - {"upper_content": "", "content": ""}, - {"upper_content": "RENDER_CELL_DEMO", "content": "RENDER_CELL_DEMO"}, - {"upper_content": "RENDER_CELL_ASYNC", "content": "RENDER_CELL_ASYNC"}, - ] - - -def test_page_size_matching_max_returned_rows( - app_client_returned_rows_matches_page_size, -): - fetched = [] - path = "/fixtures/no_primary_key.json?_extra=next_url" - while path: - response = app_client_returned_rows_matches_page_size.get(path) - fetched.extend(response.json["rows"]) - assert len(response.json["rows"]) in (2, 50) - path = response.json["next_url"] - if path: - path = path.replace("http://localhost", "") - assert len(fetched) == 202 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_facet_results", - [ - ( - "/fixtures/facetable.json?_facet=state&_facet=_city_id", - { - "state": { - "name": "state", - "hideable": True, - "type": "column", - "toggle_url": "/fixtures/facetable.json?_facet=_city_id", - "results": [ - { - "value": "CA", - "label": "CA", - "count": 10, - "toggle_url": "_facet=state&_facet=_city_id&state=CA", - "selected": False, - }, - { - "value": "MI", - "label": "MI", - "count": 4, - "toggle_url": "_facet=state&_facet=_city_id&state=MI", - "selected": False, - }, - { - "value": "MC", - "label": "MC", - "count": 1, - "toggle_url": "_facet=state&_facet=_city_id&state=MC", - "selected": False, - }, - ], - "truncated": False, - }, - "_city_id": { - "name": "_city_id", - "hideable": True, - "type": "column", - "toggle_url": "/fixtures/facetable.json?_facet=state", - "results": [ - { - "value": 1, - "label": "San Francisco", - "count": 6, - "toggle_url": "_facet=state&_facet=_city_id&_city_id__exact=1", - "selected": False, - }, - { - "value": 2, - "label": "Los Angeles", - "count": 4, - "toggle_url": "_facet=state&_facet=_city_id&_city_id__exact=2", - "selected": False, - }, - { - "value": 3, - "label": "Detroit", - "count": 4, - "toggle_url": "_facet=state&_facet=_city_id&_city_id__exact=3", - "selected": False, - }, - { - "value": 4, - "label": "Memnonia", - "count": 1, - "toggle_url": "_facet=state&_facet=_city_id&_city_id__exact=4", - "selected": False, - }, - ], - "truncated": False, - }, - }, - ), - ( - "/fixtures/facetable.json?_facet=state&_facet=_city_id&state=MI", - { - "state": { - "name": "state", - "hideable": True, - "type": "column", - "toggle_url": "/fixtures/facetable.json?_facet=_city_id&state=MI", - "results": [ - { - "value": "MI", - "label": "MI", - "count": 4, - "selected": True, - "toggle_url": "_facet=state&_facet=_city_id", - } - ], - "truncated": False, - }, - "_city_id": { - "name": "_city_id", - "hideable": True, - "type": "column", - "toggle_url": "/fixtures/facetable.json?_facet=state&state=MI", - "results": [ - { - "value": 3, - "label": "Detroit", - "count": 4, - "selected": False, - "toggle_url": "_facet=state&_facet=_city_id&state=MI&_city_id__exact=3", - } - ], - "truncated": False, - }, - }, - ), - ( - "/fixtures/facetable.json?_facet=planet_int", - { - "planet_int": { - "name": "planet_int", - "hideable": True, - "type": "column", - "toggle_url": "/fixtures/facetable.json", - "results": [ - { - "value": 1, - "label": 1, - "count": 14, - "selected": False, - "toggle_url": "_facet=planet_int&planet_int=1", - }, - { - "value": 2, - "label": 2, - "count": 1, - "selected": False, - "toggle_url": "_facet=planet_int&planet_int=2", - }, - ], - "truncated": False, - } - }, - ), - ( - # planet_int is an integer field: - "/fixtures/facetable.json?_facet=planet_int&planet_int=1", - { - "planet_int": { - "name": "planet_int", - "hideable": True, - "type": "column", - "toggle_url": "/fixtures/facetable.json?planet_int=1", - "results": [ - { - "value": 1, - "label": 1, - "count": 14, - "selected": True, - "toggle_url": "_facet=planet_int", - } - ], - "truncated": False, - } - }, - ), - ], -) -async def test_facets(ds_client, path, expected_facet_results): - response = await ds_client.get(path) - facet_results = response.json()["facet_results"] - # We only compare the querystring portion of the taggle_url - for facet_name, facet_info in facet_results["results"].items(): - assert facet_name == facet_info["name"] - assert False is facet_info["truncated"] - for facet_value in facet_info["results"]: - facet_value["toggle_url"] = facet_value["toggle_url"].split("?")[1] - assert expected_facet_results == facet_results["results"] - - -@pytest.mark.asyncio -@pytest.mark.skipif(not detect_json1(), reason="requires JSON1 extension") -async def test_facets_array(ds_client): - response = await ds_client.get("/fixtures/facetable.json?_facet_array=tags") - facet_results = response.json()["facet_results"] - assert facet_results["results"]["tags"]["results"] == [ - { - "value": "tag1", - "label": "tag1", - "count": 2, - "toggle_url": "http://localhost/fixtures/facetable.json?_facet_array=tags&tags__arraycontains=tag1", - "selected": False, - }, - { - "value": "tag2", - "label": "tag2", - "count": 1, - "toggle_url": "http://localhost/fixtures/facetable.json?_facet_array=tags&tags__arraycontains=tag2", - "selected": False, - }, - { - "value": "tag3", - "label": "tag3", - "count": 1, - "toggle_url": "http://localhost/fixtures/facetable.json?_facet_array=tags&tags__arraycontains=tag3", - "selected": False, - }, - ] - - -@pytest.mark.asyncio -async def test_suggested_facets(ds_client): - suggestions = [ - { - "name": suggestion["name"], - "querystring": suggestion["toggle_url"].split("?")[-1], - } - for suggestion in ( - await ds_client.get("/fixtures/facetable.json?_extra=suggested_facets") - ).json()["suggested_facets"] - ] - expected = [ - {"name": "created", "querystring": "_extra=suggested_facets&_facet=created"}, - { - "name": "planet_int", - "querystring": "_extra=suggested_facets&_facet=planet_int", - }, - {"name": "on_earth", "querystring": "_extra=suggested_facets&_facet=on_earth"}, - {"name": "state", "querystring": "_extra=suggested_facets&_facet=state"}, - {"name": "_city_id", "querystring": "_extra=suggested_facets&_facet=_city_id"}, - { - "name": "_neighborhood", - "querystring": "_extra=suggested_facets&_facet=_neighborhood", - }, - {"name": "tags", "querystring": "_extra=suggested_facets&_facet=tags"}, - { - "name": "complex_array", - "querystring": "_extra=suggested_facets&_facet=complex_array", - }, - { - "name": "created", - "querystring": "_extra=suggested_facets&_facet_date=created", - }, - ] - if detect_json1(): - expected.append( - {"name": "tags", "querystring": "_extra=suggested_facets&_facet_array=tags"} - ) - assert expected == suggestions - - -def test_allow_facet_off(): - with make_app_client(settings={"allow_facet": False}) as client: - assert ( - client.get( - "/fixtures/facetable.json?_facet=planet_int&_extra=suggested_facets" - ).status - == 400 - ) - data = client.get("/fixtures/facetable.json?_extra=suggested_facets").json - # Should not suggest any facets either: - assert [] == data["suggested_facets"] - - -def test_suggest_facets_off(): - with make_app_client(settings={"suggest_facets": False}) as client: - # Now suggested_facets should be [] - assert ( - [] - == client.get("/fixtures/facetable.json?_extra=suggested_facets").json[ - "suggested_facets" - ] - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("nofacet", (True, False)) -async def test_nofacet(ds_client, nofacet): - path = "/fixtures/facetable.json?_facet=state&_extra=suggested_facets" - if nofacet: - path += "&_nofacet=1" - response = await ds_client.get(path) - if nofacet: - assert response.json()["suggested_facets"] == [] - assert response.json()["facet_results"]["results"] == {} - else: - assert response.json()["suggested_facets"] != [] - assert response.json()["facet_results"]["results"] != {} - - -@pytest.mark.asyncio -@pytest.mark.parametrize("nosuggest", (True, False)) -async def test_nosuggest(ds_client, nosuggest): - path = "/fixtures/facetable.json?_facet=state&_extra=suggested_facets" - if nosuggest: - path += "&_nosuggest=1" - response = await ds_client.get(path) - if nosuggest: - assert response.json()["suggested_facets"] == [] - # But facets should still be returned: - assert response.json()["facet_results"] != {} - else: - assert response.json()["suggested_facets"] != [] - assert response.json()["facet_results"] != {} - - -@pytest.mark.asyncio -@pytest.mark.parametrize("nocount,expected_count", ((True, None), (False, 15))) -async def test_nocount(ds_client, nocount, expected_count): - path = "/fixtures/facetable.json?_extra=count" - if nocount: - path += "&_nocount=1" - response = await ds_client.get(path) - assert response.json()["count"] == expected_count - - -def test_nocount_nofacet_if_shape_is_object(app_client_with_trace): - response = app_client_with_trace.get( - "/fixtures/facetable.json?_trace=1&_shape=object" - ) - assert "count(*)" not in response.text - - -@pytest.mark.asyncio -async def test_expand_labels(ds_client): - response = await ds_client.get( - "/fixtures/facetable.json?_shape=object&_labels=1&_size=2" - "&_neighborhood__contains=c" - ) - assert response.json() == { - "2": { - "pk": 2, - "created": "2019-01-14 08:00:00", - "planet_int": 1, - "on_earth": 1, - "state": "CA", - "_city_id": {"value": 1, "label": "San Francisco"}, - "_neighborhood": "Dogpatch", - "tags": '["tag1", "tag3"]', - "complex_array": "[]", - "distinct_some_null": "two", - "n": "n2", - }, - "13": { - "pk": 13, - "created": "2019-01-17 08:00:00", - "planet_int": 1, - "on_earth": 1, - "state": "MI", - "_city_id": {"value": 3, "label": "Detroit"}, - "_neighborhood": "Corktown", - "tags": "[]", - "complex_array": "[]", - "distinct_some_null": None, - "n": None, - }, - } - - -@pytest.mark.asyncio -async def test_expand_label(ds_client): - response = await ds_client.get( - "/fixtures/foreign_key_references.json?_shape=object" - "&_label=foreign_key_with_label&_size=1" - ) - assert response.json() == { - "1": { - "pk": "1", - "foreign_key_with_label": {"value": 1, "label": "hello"}, - "foreign_key_with_blank_label": 3, - "foreign_key_with_no_label": "1", - "foreign_key_compound_pk1": "a", - "foreign_key_compound_pk2": "b", - } - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_cache_control", - [ - ("/fixtures/facetable.json", "max-age=5"), - ("/fixtures/facetable.json?_ttl=invalid", "max-age=5"), - ("/fixtures/facetable.json?_ttl=10", "max-age=10"), - ("/fixtures/facetable.json?_ttl=0", "no-cache"), - ], -) -async def test_ttl_parameter(ds_client, path, expected_cache_control): - response = await ds_client.get(path) - assert response.headers["Cache-Control"] == expected_cache_control - - -@pytest.mark.asyncio -async def test_infinity_returned_as_null(ds_client): - response = await ds_client.get("/fixtures/infinity.json?_shape=array") - assert response.json() == [ - {"rowid": 1, "value": None}, - {"rowid": 2, "value": None}, - {"rowid": 3, "value": 1.5}, - ] - - -@pytest.mark.asyncio -async def test_infinity_returned_as_invalid_json_if_requested(ds_client): - response = await ds_client.get( - "/fixtures/infinity.json?_shape=array&_json_infinity=1" - ) - assert response.json() == [ - {"rowid": 1, "value": float("inf")}, - {"rowid": 2, "value": float("-inf")}, - {"rowid": 3, "value": 1.5}, - ] - - -@pytest.mark.asyncio -async def test_custom_query_with_unicode_characters(ds_client): - # /fixtures/𝐜𝐢𝐭𝐢𝐞𝐬.json - response = await ds_client.get( - "/fixtures/~F0~9D~90~9C~F0~9D~90~A2~F0~9D~90~AD~F0~9D~90~A2~F0~9D~90~9E~F0~9D~90~AC.json?_shape=array" - ) - assert response.json() == [{"id": 1, "name": "San Francisco"}] - - -@pytest.mark.asyncio -async def test_null_and_compound_foreign_keys_are_not_expanded(ds_client): - response = await ds_client.get( - "/fixtures/foreign_key_references.json?_shape=array&_labels=on" - ) - assert response.json() == [ - { - "pk": "1", - "foreign_key_with_label": {"value": 1, "label": "hello"}, - "foreign_key_with_blank_label": {"value": 3, "label": ""}, - "foreign_key_with_no_label": {"value": "1", "label": "1"}, - "foreign_key_compound_pk1": "a", - "foreign_key_compound_pk2": "b", - }, - { - "pk": "2", - "foreign_key_with_label": None, - "foreign_key_with_blank_label": None, - "foreign_key_with_no_label": None, - "foreign_key_compound_pk1": None, - "foreign_key_compound_pk2": None, - }, - ] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_json,expected_text", - [ - ( - "/fixtures/binary_data.json?_shape=array", - [ - {"rowid": 1, "data": {"$base64": True, "encoded": "FRwCx60F/g=="}}, - {"rowid": 2, "data": {"$base64": True, "encoded": "FRwDx60F/g=="}}, - {"rowid": 3, "data": None}, - ], - None, - ), - ( - "/fixtures/binary_data.json?_shape=array&_nl=on", - None, - ( - '{"rowid": 1, "data": {"$base64": true, "encoded": "FRwCx60F/g=="}}\n' - '{"rowid": 2, "data": {"$base64": true, "encoded": "FRwDx60F/g=="}}\n' - '{"rowid": 3, "data": null}' - ), - ), - ], -) -async def test_binary_data_in_json(ds_client, path, expected_json, expected_text): - response = await ds_client.get(path) - if expected_json: - assert response.json() == expected_json - else: - assert response.text == expected_text - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "qs", - [ - "", - "?_shape=arrays", - "?_shape=arrayfirst", - "?_shape=object", - "?_shape=objects", - "?_shape=array", - "?_shape=array&_nl=on", - ], -) -async def test_paginate_using_link_header(ds_client, qs): - path = f"/fixtures/compound_three_primary_keys.json{qs}" - num_pages = 0 - while path: - response = await ds_client.get(path) - assert response.status_code == 200 - num_pages += 1 - link = response.headers.get("link") - if link: - assert link.startswith("<") - assert link.endswith('>; rel="next"') - path = link[1:].split(">")[0] - path = path.replace("http://localhost", "") - else: - path = None - assert num_pages == 21 - - -@pytest.mark.skipif( - sqlite_version() < (3, 31, 0), - reason="generated columns were added in SQLite 3.31.0", -) -def test_generated_columns_are_visible_in_datasette(): - with make_app_client(extra_databases={"generated.db": """ - CREATE TABLE generated_columns ( - body TEXT, - id INT GENERATED ALWAYS AS (json_extract(body, '$.number')) STORED, - consideration INT GENERATED ALWAYS AS (json_extract(body, '$.string')) STORED - ); - INSERT INTO generated_columns (body) VALUES ( - '{"number": 1, "string": "This is a string"}' - );"""}) as client: - response = client.get("/generated/generated_columns.json?_shape=array") - assert response.json == [ - { - "rowid": 1, - "body": '{"number": 1, "string": "This is a string"}', - "id": 1, - "consideration": "This is a string", - } - ] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_columns", - ( - ("/fixtures/facetable.json?_col=created", ["pk", "created"]), - ( - "/fixtures/facetable.json?_nocol=created", - [ - "pk", - "planet_int", - "on_earth", - "state", - "_city_id", - "_neighborhood", - "tags", - "complex_array", - "distinct_some_null", - "n", - ], - ), - ( - "/fixtures/facetable.json?_col=state&_col=created", - ["pk", "state", "created"], - ), - ( - "/fixtures/facetable.json?_col=state&_col=state", - ["pk", "state"], - ), - ( - "/fixtures/facetable.json?_col=state&_col=created&_nocol=created", - ["pk", "state"], - ), - ( - # Ensure faceting doesn't break, https://github.com/simonw/datasette/issues/1345 - "/fixtures/facetable.json?_nocol=state&_facet=state", - [ - "pk", - "created", - "planet_int", - "on_earth", - "_city_id", - "_neighborhood", - "tags", - "complex_array", - "distinct_some_null", - "n", - ], - ), - ( - "/fixtures/simple_view.json?_nocol=content", - ["upper_content"], - ), - ("/fixtures/simple_view.json?_col=content", ["content"]), - ), -) -async def test_col_nocol(ds_client, path, expected_columns): - response = await ds_client.get(path + "&_extra=columns") - assert response.status_code == 200 - columns = response.json()["columns"] - assert columns == expected_columns - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_error", - ( - ("/fixtures/facetable.json?_col=bad", "_col=bad - invalid columns"), - ("/fixtures/facetable.json?_nocol=bad", "_nocol=bad - invalid columns"), - ("/fixtures/facetable.json?_nocol=pk", "_nocol=pk - invalid columns"), - ("/fixtures/simple_view.json?_col=bad", "_col=bad - invalid columns"), - ), -) -async def test_col_nocol_errors(ds_client, path, expected_error): - response = await ds_client.get(path) - assert response.status_code == 400 - assert response.json()["error"] == expected_error - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "extra,expected_json", - ( - ( - "columns", - { - "ok": True, - "next": None, - "columns": ["id", "content", "content2"], - "rows": [{"id": "1", "content": "hey", "content2": "world"}], - "truncated": False, - }, - ), - ( - "count", - { - "ok": True, - "next": None, - "rows": [{"id": "1", "content": "hey", "content2": "world"}], - "truncated": False, - "count": 1, - }, - ), - ), -) -async def test_table_extras(ds_client, extra, expected_json): - response = await ds_client.get( - "/fixtures/primary_key_multiple_columns.json?_extra=" + extra - ) - assert response.status_code == 200 - assert response.json() == expected_json - - -@pytest.mark.asyncio -async def test_extra_render_cell(): - """Test that _extra=render_cell returns rendered HTML from render_cell plugin hook""" - from datasette import hookimpl - from datasette.app import Datasette - - class TestRenderCellPlugin: - __name__ = "TestRenderCellPlugin" - - @hookimpl - def render_cell(self, value, column, table, database): - # Only modify cells in our test table - if table == "test_render" and column == "name": - return f"{value}" - return None - - ds = Datasette(memory=True) - await ds.invoke_startup() - db = ds.add_memory_database("test_table_render") - await db.execute_write( - "create table test_render (id integer primary key, name text)" - ) - await db.execute_write("insert into test_render values (1, 'Alice')") - await db.execute_write("insert into test_render values (2, 'Bob')") - - # Register our test plugin - ds.pm.register(TestRenderCellPlugin(), name="TestRenderCellPlugin") - - try: - # Request with _extra=render_cell - response = await ds.client.get( - "/test_table_render/test_render.json?_extra=render_cell" - ) - assert response.status_code == 200 - data = response.json() - - # Verify the response structure - assert "render_cell" in data - assert "rows" in data - - # render_cell should be a list of rows, each row being a dict of column -> rendered HTML - # Only columns modified by plugins are included (sparse output) - render_cell = data["render_cell"] - assert len(render_cell) == 2 - - # First row: id=1, name='Alice' - # The 'name' column should be rendered by our plugin as Alice - assert render_cell[0]["name"] == "Alice" - # The 'id' column is not included since no plugin modified it - assert "id" not in render_cell[0] - - # Second row: id=2, name='Bob' - assert render_cell[1]["name"] == "Bob" - assert "id" not in render_cell[1] - - # The regular rows should still contain raw values - assert data["rows"] == [ - {"id": 1, "name": "Alice"}, - {"id": 2, "name": "Bob"}, - ] - - finally: - ds.pm.unregister(name="TestRenderCellPlugin") diff --git a/tests/test_table_html.py b/tests/test_table_html.py deleted file mode 100644 index 86b9a4eb..00000000 --- a/tests/test_table_html.py +++ /dev/null @@ -1,1403 +0,0 @@ -from datasette.app import Datasette -from bs4 import BeautifulSoup as Soup -from .fixtures import make_app_client -import pathlib -import pytest -import urllib.parse -from .utils import inner_html - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_definition_sql", - [ - ( - "/fixtures/facet_cities", - """ -CREATE TABLE facet_cities ( - id integer primary key, - name text -); - """.strip(), - ), - ( - "/fixtures/compound_three_primary_keys", - """ -CREATE TABLE compound_three_primary_keys ( - pk1 varchar(30), - pk2 varchar(30), - pk3 varchar(30), - content text, - PRIMARY KEY (pk1, pk2, pk3) -); -CREATE INDEX idx_compound_three_primary_keys_content ON compound_three_primary_keys(content); - """.strip(), - ), - ], -) -async def test_table_definition_sql(path, expected_definition_sql, ds_client): - response = await ds_client.get(path) - pre = Soup(response.text, "html.parser").select_one("pre.wrapped-sql") - assert expected_definition_sql == pre.string - - -def test_table_cell_truncation(): - with make_app_client(settings={"truncate_cells_html": 5}) as client: - response = client.get("/fixtures/facetable") - assert response.status == 200 - table = Soup(response.body, "html.parser").find("table") - assert table["class"] == ["rows-and-columns"] - assert [ - "Missi…", - "Dogpa…", - "SOMA", - "Tende…", - "Berna…", - "Hayes…", - "Holly…", - "Downt…", - "Los F…", - "Korea…", - "Downt…", - "Greek…", - "Corkt…", - "Mexic…", - "Arcad…", - ] == [ - td.string - for td in table.find_all("td", {"class": "col-neighborhood-b352a7"}) - ] - # URLs should be truncated too - response2 = client.get("/fixtures/roadside_attractions") - assert response2.status == 200 - table = Soup(response2.body, "html.parser").find("table") - tds = table.find_all("td", {"class": "col-url"}) - assert [str(td) for td in tds] == [ - '', - '', - '', - '', - ] - - -@pytest.mark.asyncio -async def test_add_filter_redirects(ds_client): - filter_args = urllib.parse.urlencode( - {"_filter_column": "content", "_filter_op": "startswith", "_filter_value": "x"} - ) - path_base = "/fixtures/simple_primary_key" - path = path_base + "?" + filter_args - response = await ds_client.get(path) - assert response.status_code == 302 - assert response.headers["Location"].endswith("?content__startswith=x") - - # Adding a redirect to an existing query string: - path = path_base + "?foo=bar&" + filter_args - response = await ds_client.get(path) - assert response.status_code == 302 - assert response.headers["Location"].endswith("?foo=bar&content__startswith=x") - - # Test that op with a __x suffix overrides the filter value - path = ( - path_base - + "?" - + urllib.parse.urlencode( - { - "_filter_column": "content", - "_filter_op": "isnull__5", - "_filter_value": "x", - } - ) - ) - response = await ds_client.get(path) - assert response.status_code == 302 - assert response.headers["Location"].endswith("?content__isnull=5") - - -@pytest.mark.asyncio -async def test_existing_filter_redirects(ds_client): - filter_args = { - "_filter_column_1": "name", - "_filter_op_1": "contains", - "_filter_value_1": "hello", - "_filter_column_2": "age", - "_filter_op_2": "gte", - "_filter_value_2": "22", - "_filter_column_3": "age", - "_filter_op_3": "lt", - "_filter_value_3": "30", - "_filter_column_4": "name", - "_filter_op_4": "contains", - "_filter_value_4": "world", - } - path_base = "/fixtures/simple_primary_key" - path = path_base + "?" + urllib.parse.urlencode(filter_args) - response = await ds_client.get(path) - assert response.status_code == 302 - assert_querystring_equal( - "name__contains=hello&age__gte=22&age__lt=30&name__contains=world", - response.headers["Location"].split("?")[1], - ) - - # Setting _filter_column_3 to empty string should remove *_3 entirely - filter_args["_filter_column_3"] = "" - path = path_base + "?" + urllib.parse.urlencode(filter_args) - response = await ds_client.get(path) - assert response.status_code == 302 - assert_querystring_equal( - "name__contains=hello&age__gte=22&name__contains=world", - response.headers["Location"].split("?")[1], - ) - - # ?_filter_op=exact should be removed if unaccompanied by _fiter_column - response = await ds_client.get(path_base + "?_filter_op=exact") - assert response.status_code == 302 - assert "?" not in response.headers["Location"] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "qs,expected_hidden", - ( - # Things that should be reflected in hidden form fields: - ("_facet=_neighborhood", {"_facet": "_neighborhood"}), - ("_where=1+=+1&_col=_city_id", {"_where": "1 = 1", "_col": "_city_id"}), - # Things that should NOT be reflected in hidden form fields: - ( - "_facet=_neighborhood&_neighborhood__exact=Downtown", - {"_facet": "_neighborhood"}, - ), - ("_facet=_neighborhood&_city_id__gt=1", {"_facet": "_neighborhood"}), - ), -) -async def test_reflected_hidden_form_fields(ds_client, qs, expected_hidden): - # https://github.com/simonw/datasette/issues/1527 - response = await ds_client.get("/fixtures/facetable?{}".format(qs)) - # In this case we should NOT have a hidden _neighborhood__exact=Downtown field - form = Soup(response.text, "html.parser").find("form") - hidden_inputs = { - input["name"]: input["value"] for input in form.select("input[type=hidden]") - } - assert hidden_inputs == expected_hidden - - -@pytest.mark.asyncio -async def test_empty_search_parameter_gets_removed(ds_client): - path_base = "/fixtures/simple_primary_key" - path = ( - path_base - + "?" - + urllib.parse.urlencode( - { - "_search": "", - "_filter_column": "name", - "_filter_op": "exact", - "_filter_value": "chidi", - } - ) - ) - response = await ds_client.get(path) - assert response.status_code == 302 - assert response.headers["Location"].endswith("?name__exact=chidi") - - -@pytest.mark.asyncio -async def test_searchable_view_persists_fts_table(ds_client): - # The search form should persist ?_fts_table as a hidden field - response = await ds_client.get( - "/fixtures/searchable_view?_fts_table=searchable_fts&_fts_pk=pk" - ) - inputs = Soup(response.text, "html.parser").find("form").find_all("input") - hiddens = [i for i in inputs if i["type"] == "hidden"] - assert [("_fts_table", "searchable_fts"), ("_fts_pk", "pk")] == [ - (hidden["name"], hidden["value"]) for hidden in hiddens - ] - - -@pytest.mark.asyncio -async def test_sort_by_desc_redirects(ds_client): - path_base = "/fixtures/sortable" - path = ( - path_base - + "?" - + urllib.parse.urlencode({"_sort": "sortable", "_sort_by_desc": "1"}) - ) - response = await ds_client.get(path) - assert response.status_code == 302 - assert response.headers["Location"].endswith("?_sort_desc=sortable") - - -@pytest.mark.asyncio -async def test_sort_links(ds_client): - response = await ds_client.get("/fixtures/sortable?_sort=sortable") - assert response.status_code == 200 - ths = Soup(response.text, "html.parser").find_all("th") - attrs_and_link_attrs = [ - { - "attrs": th.attrs, - "a_href": (th.find("a")["href"] if th.find("a") else None), - } - for th in ths - ] - assert attrs_and_link_attrs == [ - { - "attrs": { - "class": ["col-Link"], - "scope": "col", - "data-column": "Link", - "data-column-type": "", - "data-column-not-null": "0", - "data-is-pk": "0", - "data-is-link-column": "1", - }, - "a_href": None, - }, - { - "attrs": { - "class": ["col-pk1"], - "scope": "col", - "data-column": "pk1", - "data-column-type": "varchar(30)", - "data-column-not-null": "0", - "data-is-pk": "1", - }, - "a_href": None, - }, - { - "attrs": { - "class": ["col-pk2"], - "scope": "col", - "data-column": "pk2", - "data-column-type": "varchar(30)", - "data-column-not-null": "0", - "data-is-pk": "1", - }, - "a_href": None, - }, - { - "attrs": { - "class": ["col-content"], - "scope": "col", - "data-column": "content", - "data-column-type": "text", - "data-column-not-null": "0", - "data-is-pk": "0", - }, - "a_href": None, - }, - { - "attrs": { - "class": ["col-sortable"], - "scope": "col", - "data-column": "sortable", - "data-column-type": "integer", - "data-column-not-null": "0", - "data-is-pk": "0", - }, - "a_href": "/fixtures/sortable?_sort_desc=sortable", - }, - { - "attrs": { - "class": ["col-sortable_with_nulls"], - "scope": "col", - "data-column": "sortable_with_nulls", - "data-column-type": "real", - "data-column-not-null": "0", - "data-is-pk": "0", - }, - "a_href": "/fixtures/sortable?_sort=sortable_with_nulls", - }, - { - "attrs": { - "class": ["col-sortable_with_nulls_2"], - "scope": "col", - "data-column": "sortable_with_nulls_2", - "data-column-type": "real", - "data-column-not-null": "0", - "data-is-pk": "0", - }, - "a_href": "/fixtures/sortable?_sort=sortable_with_nulls_2", - }, - { - "attrs": { - "class": ["col-text"], - "scope": "col", - "data-column": "text", - "data-column-type": "text", - "data-column-not-null": "0", - "data-is-pk": "0", - }, - "a_href": "/fixtures/sortable?_sort=text", - }, - ] - - -@pytest.mark.asyncio -async def test_facet_display(ds_client): - response = await ds_client.get( - "/fixtures/facetable?_facet=planet_int&_facet=_city_id&_facet=on_earth" - ) - assert response.status_code == 200 - soup = Soup(response.text, "html.parser") - divs = soup.find("div", {"class": "facet-results"}).find_all("div") - actual = [] - for div in divs: - actual.append( - { - "name": div.find("strong").text.split()[0], - "items": [ - { - "name": a.text, - "qs": a["href"].split("?")[-1], - "count": int(str(a.parent).split("")[1].split("<")[0]), - } - for a in div.find("ul").find_all("a") - ], - } - ) - assert actual == [ - { - "name": "_city_id", - "items": [ - { - "name": "San Francisco", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&_city_id__exact=1", - "count": 6, - }, - { - "name": "Los Angeles", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&_city_id__exact=2", - "count": 4, - }, - { - "name": "Detroit", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&_city_id__exact=3", - "count": 4, - }, - { - "name": "Memnonia", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&_city_id__exact=4", - "count": 1, - }, - ], - }, - { - "name": "planet_int", - "items": [ - { - "name": "1", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&planet_int=1", - "count": 14, - }, - { - "name": "2", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&planet_int=2", - "count": 1, - }, - ], - }, - { - "name": "on_earth", - "items": [ - { - "name": "1", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&on_earth=1", - "count": 14, - }, - { - "name": "0", - "qs": "_facet=planet_int&_facet=_city_id&_facet=on_earth&on_earth=0", - "count": 1, - }, - ], - }, - ] - - -@pytest.mark.asyncio -async def test_facets_persist_through_filter_form(ds_client): - response = await ds_client.get( - "/fixtures/facetable?_facet=planet_int&_facet=_city_id&_facet_array=tags" - ) - assert response.status_code == 200 - inputs = Soup(response.text, "html.parser").find("form").find_all("input") - hiddens = [i for i in inputs if i["type"] == "hidden"] - assert [(hidden["name"], hidden["value"]) for hidden in hiddens] == [ - ("_facet", "planet_int"), - ("_facet", "_city_id"), - ("_facet_array", "tags"), - ] - - -@pytest.mark.asyncio -async def test_next_does_not_persist_in_hidden_field(ds_client): - response = await ds_client.get("/fixtures/searchable?_size=1&_next=1") - assert response.status_code == 200 - inputs = Soup(response.text, "html.parser").find("form").find_all("input") - hiddens = [i for i in inputs if i["type"] == "hidden"] - assert [(hidden["name"], hidden["value"]) for hidden in hiddens] == [ - ("_size", "1"), - ] - - -@pytest.mark.asyncio -async def test_table_html_simple_primary_key(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key?_size=3") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - assert table["class"] == ["rows-and-columns"] - ths = table.find_all("th") - assert "id\xa0▼" == ths[0].find("a").string.strip() - for expected_col, th in zip(("content",), ths[1:]): - a = th.find("a") - assert expected_col == a.string - assert a["href"].endswith(f"/simple_primary_key?_size=3&_sort={expected_col}") - assert ["nofollow"] == a["rel"] - assert [ - [ - '', - '', - ], - [ - '', - '', - ], - [ - '', - '', - ], - ] == [[str(td) for td in tr.select("td")] for tr in table.select("tbody tr")] - - -@pytest.mark.asyncio -async def test_table_csv_json_export_interface(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key?id__gt=2") - assert response.status_code == 200 - # The links at the top of the page - links = ( - Soup(response.text, "html.parser") - .find("p", {"class": "export-links"}) - .find_all("a") - ) - actual = [link["href"] for link in links] - expected = [ - "/fixtures/simple_primary_key.json?id__gt=2", - "/fixtures/simple_primary_key.testall?id__gt=2", - "/fixtures/simple_primary_key.testnone?id__gt=2", - "/fixtures/simple_primary_key.testresponse?id__gt=2", - "/fixtures/simple_primary_key.csv?id__gt=2&_size=max", - "#export", - ] - assert expected == actual - # And the advanced export box at the bottom: - div = Soup(response.text, "html.parser").find("div", {"class": "advanced-export"}) - json_links = [a["href"] for a in div.find("p").find_all("a")] - assert [ - "/fixtures/simple_primary_key.json?id__gt=2", - "/fixtures/simple_primary_key.json?id__gt=2&_shape=array", - "/fixtures/simple_primary_key.json?id__gt=2&_shape=array&_nl=on", - "/fixtures/simple_primary_key.json?id__gt=2&_shape=object", - ] == json_links - # And the CSV form - form = div.find("form") - assert form["action"].endswith("/simple_primary_key.csv") - inputs = [str(input) for input in form.find_all("input")] - assert [ - '', - '', - '', - '', - ] == inputs - - -@pytest.mark.asyncio -async def test_csv_json_export_links_include_labels_if_foreign_keys(ds_client): - response = await ds_client.get("/fixtures/facetable") - assert response.status_code == 200 - links = ( - Soup(response.text, "html.parser") - .find("p", {"class": "export-links"}) - .find_all("a") - ) - actual = [link["href"] for link in links] - expected = [ - "/fixtures/facetable.json?_labels=on", - "/fixtures/facetable.testall?_labels=on", - "/fixtures/facetable.testnone?_labels=on", - "/fixtures/facetable.testresponse?_labels=on", - "/fixtures/facetable.csv?_labels=on&_size=max", - "#export", - ] - assert expected == actual - - -@pytest.mark.asyncio -async def test_table_not_exists(ds_client): - assert "Table not found" in (await ds_client.get("/fixtures/blah")).text - - -@pytest.mark.asyncio -async def test_table_html_no_primary_key(ds_client): - response = await ds_client.get("/fixtures/no_primary_key") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - # We have disabled sorting for this table using metadata.json - assert ["content", "a", "b", "c"] == [ - th.string.strip() for th in table.select("thead th")[2:] - ] - expected = [ - [ - ''.format( - i, i - ), - f'', - f'', - f'', - f'', - f'', - ] - for i in range(1, 51) - ] - assert expected == [ - [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") - ] - - -@pytest.mark.asyncio -async def test_rowid_sortable_no_primary_key(ds_client): - response = await ds_client.get("/fixtures/no_primary_key") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - assert table["class"] == ["rows-and-columns"] - ths = table.find_all("th") - assert "rowid\xa0▼" == ths[1].find("a").string.strip() - - -@pytest.mark.asyncio -async def test_table_html_compound_primary_key(ds_client): - response = await ds_client.get("/fixtures/compound_primary_key") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - ths = table.find_all("th") - assert "Link" == ths[0].string.strip() - for expected_col, th in zip(("pk1", "pk2", "content"), ths[1:]): - a = th.find("a") - assert expected_col == a.string - assert th["class"] == [f"col-{expected_col}"] - assert a["href"].endswith(f"/compound_primary_key?_sort={expected_col}") - expected = [ - [ - '', - '', - '', - '', - ], - [ - '', - '', - '', - '', - ], - [ - '', - '', - '', - '', - ], - ] - assert [ - [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") - ] == expected - - -@pytest.mark.asyncio -async def test_table_html_foreign_key_links(ds_client): - response = await ds_client.get("/fixtures/foreign_key_references") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - actual = [[str(td) for td in tr.select("td")] for tr in table.select("tbody tr")] - assert actual == [ - [ - '', - '', - '', - '', - '', - '', - ], - [ - '', - '', - '', - '', - '', - '', - ], - ] - - -@pytest.mark.asyncio -async def test_table_html_foreign_key_facets(ds_client): - response = await ds_client.get( - "/fixtures/foreign_key_references?_facet=foreign_key_with_blank_label" - ) - assert response.status_code == 200 - assert ( - '
  • - 1
  • ' - ) in response.text - - -@pytest.mark.asyncio -async def test_table_html_disable_foreign_key_links_with_labels(ds_client): - response = await ds_client.get( - "/fixtures/foreign_key_references?_labels=off&_size=1" - ) - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - actual = [[str(td) for td in tr.select("td")] for tr in table.select("tbody tr")] - assert actual == [ - [ - '
    ', - '', - '', - '', - '', - '', - ] - ] - - -@pytest.mark.asyncio -async def test_table_html_foreign_key_custom_label_column(ds_client): - response = await ds_client.get("/fixtures/custom_foreign_key_label") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - expected = [ - [ - '', - '', - ] - ] - assert expected == [ - [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") - ] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_column_options", - [ - ("/fixtures/infinity", ["- column -", "rowid", "value"]), - ( - "/fixtures/primary_key_multiple_columns", - ["- column -", "id", "content", "content2"], - ), - ("/fixtures/compound_primary_key", ["- column -", "pk1", "pk2", "content"]), - ], -) -async def test_table_html_filter_form_column_options( - path, expected_column_options, ds_client -): - response = await ds_client.get(path) - assert response.status_code == 200 - form = Soup(response.text, "html.parser").find("form") - column_options = [ - o.attrs.get("value") or o.string - for o in form.select("select[name=_filter_column] option") - ] - assert expected_column_options == column_options - - -@pytest.mark.asyncio -async def test_table_html_filter_form_still_shows_nocol_columns(ds_client): - # https://github.com/simonw/datasette/issues/1503 - response = await ds_client.get("/fixtures/sortable?_nocol=sortable") - assert response.status_code == 200 - form = Soup(response.text, "html.parser").find("form") - assert [ - o.string - for o in form.select("select[name='_filter_column']")[0].select("option") - ] == [ - "- column -", - "pk1", - "pk2", - "content", - "sortable_with_nulls", - "sortable_with_nulls_2", - "text", - # Moved to the end because it is no longer returned by the query: - "sortable", - ] - - -@pytest.mark.asyncio -async def test_column_chooser_present(ds_client): - response = await ds_client.get("/fixtures/facetable") - assert response.status_code == 200 - soup = Soup(response.text, "html.parser") - # Web component should be present - chooser = soup.find("column-chooser") - assert chooser is not None - # Script block should contain column data as JSON - - scripts = soup.find_all("script") - chooser_script = [s for s in scripts if "_columnChooserData" in (s.string or "")] - assert len(chooser_script) == 1 - script_text = chooser_script[0].string - # Extract the JSON data - assert "allColumns" in script_text - assert "selectedColumns" in script_text - assert "primaryKeys" in script_text - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path", ["/fixtures/facetable", "/fixtures/123_starts_with_digits"] -) -async def test_mobile_column_actions_present(ds_client, path): - response = await ds_client.get(path) - assert response.status_code == 200 - soup = Soup(response.text, "html.parser") - button = soup.select_one("button.column-actions-mobile.small-screen-only") - assert button is not None - assert button.text.strip() == "Column actions" - assert button.find("svg") is not None - assert any( - "mobile-column-actions.js" in (script.get("src") or "") - for script in soup.find_all("script") - ) - # mobile-column-actions.js builds its dialog from ', - '', - '', - ], - [ - '', - '', - '', - ], - ] - assert expected == [ - [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") - ] - - -@pytest.mark.asyncio -async def test_view_html(ds_client): - response = await ds_client.get("/fixtures/simple_view?_size=3") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - ths = table.select("thead th") - assert 2 == len(ths) - assert ths[0].find("a") is not None - assert ths[0].find("a")["href"].endswith("/simple_view?_size=3&_sort=content") - assert ths[0].find("a").string.strip() == "content" - assert ths[1].find("a") is None - assert ths[1].string.strip() == "upper_content" - expected = [ - [ - '', - '', - ], - [ - '', - '', - ], - [ - '', - '', - ], - ] - assert expected == [ - [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") - ] - - -@pytest.mark.asyncio -async def test_table_metadata(ds_client): - response = await ds_client.get("/fixtures/simple_primary_key") - assert response.status_code == 200 - soup = Soup(response.text, "html.parser") - # Page title should be custom and should be HTML escaped - assert "This <em>HTML</em> is escaped" == inner_html(soup.find("h1")) - # Description should be custom and NOT escaped (we used description_html) - assert "Simple primary key" == inner_html( - soup.find("div", {"class": "metadata-description"}) - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,has_object,has_stream,has_expand", - [ - ("/fixtures/no_primary_key", False, True, False), - ("/fixtures/complex_foreign_keys", True, False, True), - ], -) -async def test_advanced_export_box(ds_client, path, has_object, has_stream, has_expand): - response = await ds_client.get(path) - assert response.status_code == 200 - soup = Soup(response.text, "html.parser") - # JSON shape options - expected_json_shapes = ["default", "array", "newline-delimited"] - if has_object: - expected_json_shapes.append("object") - div = soup.find("div", {"class": "advanced-export"}) - assert expected_json_shapes == [a.text for a in div.find("p").find_all("a")] - # "stream all rows" option - if has_stream: - assert "stream all rows" in str(div) - # "expand labels" option - if has_expand: - assert "expand labels" in str(div) - - -@pytest.mark.asyncio -async def test_extra_where_clauses(ds_client): - response = await ds_client.get( - "/fixtures/facetable?_where=_neighborhood='Dogpatch'&_where=_city_id=1" - ) - soup = Soup(response.text, "html.parser") - div = soup.select(".extra-wheres")[0] - assert "2 extra where clauses" == div.find("h3").text - hrefs = [a["href"] for a in div.find_all("a")] - assert [ - "/fixtures/facetable?_where=_city_id%3D1", - "/fixtures/facetable?_where=_neighborhood%3D%27Dogpatch%27", - ] == hrefs - # These should also be persisted as hidden fields - inputs = soup.find("form").find_all("input") - hiddens = [i for i in inputs if i["type"] == "hidden"] - assert [("_where", "_neighborhood='Dogpatch'"), ("_where", "_city_id=1")] == [ - (hidden["name"], hidden["value"]) for hidden in hiddens - ] - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_hidden", - [ - ("/fixtures/facetable?_size=10", [("_size", "10")]), - ( - "/fixtures/facetable?_size=10&_ignore=1&_ignore=2", - [ - ("_size", "10"), - ("_ignore", "1"), - ("_ignore", "2"), - ], - ), - ], -) -async def test_other_hidden_form_fields(ds_client, path, expected_hidden): - response = await ds_client.get(path) - soup = Soup(response.text, "html.parser") - inputs = soup.find("form").find_all("input") - hiddens = [i for i in inputs if i["type"] == "hidden"] - assert [(hidden["name"], hidden["value"]) for hidden in hiddens] == expected_hidden - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected_hidden", - [ - ("/fixtures/searchable?_search=terry", []), - ("/fixtures/searchable?_sort=text2", []), - ("/fixtures/searchable?_sort_desc=text2", []), - ("/fixtures/searchable?_sort=text2&_where=1", [("_where", "1")]), - ], -) -async def test_search_and_sort_fields_not_duplicated(ds_client, path, expected_hidden): - # https://github.com/simonw/datasette/issues/1214 - response = await ds_client.get(path) - soup = Soup(response.text, "html.parser") - inputs = soup.find("form").find_all("input") - hiddens = [i for i in inputs if i["type"] == "hidden"] - assert [(hidden["name"], hidden["value"]) for hidden in hiddens] == expected_hidden - - -@pytest.mark.asyncio -async def test_binary_data_display_in_table(ds_client): - response = await ds_client.get("/fixtures/binary_data") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - expected_tds = [ - [ - '', - '', - '', - ], - [ - '', - '', - '', - ], - [ - '', - '', - '', - ], - ] - assert expected_tds == [ - [str(td) for td in tr.select("td")] for tr in table.select("tbody tr") - ] - - -def test_custom_table_include(): - with make_app_client( - template_dir=str(pathlib.Path(__file__).parent / "test_templates") - ) as client: - response = client.get("/fixtures/complex_foreign_keys") - assert response.status == 200 - assert ( - '
    ' - '1 - 2 - hello 1' - "
    " - ) == str(Soup(response.text, "html.parser").select_one("div.custom-table-row")) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("json", (True, False)) -@pytest.mark.parametrize( - "params,error", - ( - ("?_sort=bad", "Cannot sort table by bad"), - ("?_sort_desc=bad", "Cannot sort table by bad"), - ( - "?_sort=state&_sort_desc=state", - "Cannot use _sort and _sort_desc at the same time", - ), - ), -) -async def test_sort_errors(ds_client, json, params, error): - path = "/fixtures/facetable{}{}".format( - ".json" if json else "", - params, - ) - response = await ds_client.get(path) - assert response.status_code == 400 - if json: - assert response.json() == { - "ok": False, - "error": error, - "status": 400, - "title": None, - } - else: - assert error in response.text - - -@pytest.mark.asyncio -async def test_metadata_sort(ds_client): - response = await ds_client.get("/fixtures/facet_cities") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - assert table["class"] == ["rows-and-columns"] - ths = table.find_all("th") - assert ["id", "name\xa0▼"] == [th.find("a").string.strip() for th in ths] - rows = [[str(td) for td in tr.select("td")] for tr in table.select("tbody tr")] - expected = [ - [ - '
    ', - '', - ], - [ - '', - '', - ], - [ - '', - '', - ], - [ - '', - '', - ], - ] - assert expected == rows - # Make sure you can reverse that sort order - response = await ds_client.get("/fixtures/facet_cities?_sort_desc=name") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - rows = [[str(td) for td in tr.select("td")] for tr in table.select("tbody tr")] - assert list(reversed(expected)) == rows - - -@pytest.mark.asyncio -async def test_metadata_sort_desc(ds_client): - response = await ds_client.get("/fixtures/attraction_characteristic") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - assert table["class"] == ["rows-and-columns"] - ths = table.find_all("th") - assert ["pk\xa0▲", "name"] == [th.find("a").string.strip() for th in ths] - rows = [[str(td) for td in tr.select("td")] for tr in table.select("tbody tr")] - expected = [ - [ - '', - '', - ], - [ - '', - '', - ], - ] - assert expected == rows - # Make sure you can reverse that sort order - response = await ds_client.get("/fixtures/attraction_characteristic?_sort=pk") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - rows = [[str(td) for td in tr.select("td")] for tr in table.select("tbody tr")] - assert list(reversed(expected)) == rows - - -@pytest.mark.parametrize( - "max_returned_rows,path,expected_num_facets,expected_ellipses,expected_ellipses_url", - ( - ( - 5, - # Default should show 2 facets - "/fixtures/facetable?_facet=_neighborhood", - 2, - True, - "/fixtures/facetable?_facet=_neighborhood&_facet_size=max", - ), - # _facet_size above max_returned_rows should show max_returned_rows (5) - ( - 5, - "/fixtures/facetable?_facet=_neighborhood&_facet_size=50", - 5, - True, - "/fixtures/facetable?_facet=_neighborhood&_facet_size=max", - ), - # If max_returned_rows is high enough, should return all - ( - 20, - "/fixtures/facetable?_facet=_neighborhood&_facet_size=max", - 14, - False, - None, - ), - # If num facets > max_returned_rows, show ... without a link - # _facet_size above max_returned_rows should show max_returned_rows (5) - ( - 5, - "/fixtures/facetable?_facet=_neighborhood&_facet_size=max", - 5, - True, - None, - ), - ), -) -def test_facet_more_links( - max_returned_rows, - path, - expected_num_facets, - expected_ellipses, - expected_ellipses_url, -): - with make_app_client( - settings={"max_returned_rows": max_returned_rows, "default_facet_size": 2} - ) as client: - response = client.get(path) - soup = Soup(response.body, "html.parser") - lis = soup.select("#facet-neighborhood-b352a7 ul li:not(.facet-truncated)") - facet_truncated = soup.select_one(".facet-truncated") - assert len(lis) == expected_num_facets - if not expected_ellipses: - assert facet_truncated is None - else: - if expected_ellipses_url: - assert facet_truncated.find("a")["href"] == expected_ellipses_url - else: - assert facet_truncated.find("a") is None - - -def test_unavailable_table_does_not_break_sort_relationships(): - # https://github.com/simonw/datasette/issues/1305 - with make_app_client( - config={ - "databases": { - "fixtures": {"tables": {"foreign_key_references": {"allow": False}}} - } - } - ) as client: - response = client.get("/?_sort=relationships") - assert response.status == 200 - - -@pytest.mark.asyncio -async def test_column_metadata(ds_client): - response = await ds_client.get("/fixtures/roadside_attractions") - soup = Soup(response.text, "html.parser") - dl = soup.find("dl") - assert [(dt.text, dt.next_sibling.text) for dt in dl.find_all("dt")] == [ - ("address", "The street address for the attraction"), - ("name", "The name of the attraction"), - ] - assert ( - soup.select("th[data-column=name]")[0]["data-column-description"] - == "The name of the attraction" - ) - assert ( - soup.select("th[data-column=address]")[0]["data-column-description"] - == "The street address for the attraction" - ) - - -def test_facet_total(): - # https://github.com/simonw/datasette/issues/1423 - # https://github.com/simonw/datasette/issues/1556 - with make_app_client(settings={"max_returned_rows": 100}) as client: - path = "/fixtures/sortable?_facet=content&_facet=pk1" - response = client.get(path) - assert response.status == 200 - fragments = ( - '>30', - '8', - ) - for fragment in fragments: - assert fragment in response.text - - -@pytest.mark.asyncio -async def test_sort_rowid_with_next(ds_client): - # https://github.com/simonw/datasette/issues/1470 - response = await ds_client.get("/fixtures/binary_data?_size=1&_next=1&_sort=rowid") - assert response.status_code == 200 - - -def assert_querystring_equal(expected, actual): - assert sorted(expected.split("&")) == sorted(actual.split("&")) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path,expected", - ( - ( - "/fixtures/facetable", - "fixtures: facetable: 15 rows", - ), - ( - "/fixtures/facetable?on_earth__exact=1", - "fixtures: facetable: 14 rows where on_earth = 1", - ), - ), -) -async def test_table_page_title(ds_client, path, expected): - response = await ds_client.get(path) - title = Soup(response.text, "html.parser").find("title").text - assert title == expected - - -@pytest.mark.asyncio -async def test_table_post_method_not_allowed(ds_client): - response = await ds_client.post("/fixtures/facetable") - assert response.status_code == 405 - assert "Method not allowed" in response.text - - -@pytest.mark.parametrize("allow_facet", (True, False)) -def test_allow_facet_off(allow_facet): - with make_app_client(settings={"allow_facet": allow_facet}) as client: - response = client.get("/fixtures/facetable") - expected = "DATASETTE_ALLOW_FACET = {};".format( - "true" if allow_facet else "false" - ) - assert expected in response.text - if allow_facet: - assert "Suggested facets" in response.text - else: - assert "Suggested facets" not in response.text - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "size,title,length_bytes", - ( - (2000, ' title="2.0 KB"', "2,000"), - (20000, ' title="19.5 KB"', "20,000"), - (20, "", "20"), - ), -) -async def test_format_of_binary_links(size, title, length_bytes): - ds = Datasette() - db_name = "binary-links-{}".format(size) - db = ds.add_memory_database(db_name) - sql = "select zeroblob({}) as blob".format(size) - await db.execute_write("create table blobs as {}".format(sql)) - response = await ds.client.get("/{}/blobs".format(db_name)) - assert response.status_code == 200 - expected = "{}><Binary: {} bytes>".format(title, length_bytes) - assert expected in response.text - # And test with arbitrary SQL query too - sql_response = await ds.client.get( - "{}/-/query".format(db_name), params={"sql": sql} - ) - assert sql_response.status_code == 200 - assert expected in sql_response.text - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "config", - ( - # Blocked at table level - { - "databases": { - "foreign_key_labels": { - "tables": { - # Table a is only visible to root - "a": {"allow": {"id": "root"}}, - } - } - } - }, - # Blocked at database level - { - "databases": { - "foreign_key_labels": { - # Only root can view this database - "allow": {"id": "root"}, - "tables": { - # But table b is visible to everyone - "b": {"allow": True}, - }, - } - } - }, - # Blocked at the instance level - { - "allow": {"id": "root"}, - "databases": { - "foreign_key_labels": { - "tables": { - # Table b is visible to everyone - "b": {"allow": True}, - } - } - }, - }, - ), -) -async def test_foreign_key_labels_obey_permissions(config): - ds = Datasette(config=config) - db = ds.add_memory_database("foreign_key_labels") - await db.execute_write( - "create table if not exists a(id integer primary key, name text)" - ) - await db.execute_write("insert or replace into a (id, name) values (1, 'hello')") - await db.execute_write( - "create table if not exists b(id integer primary key, name text, a_id integer references a(id))" - ) - await db.execute_write( - "insert or replace into b (id, name, a_id) values (1, 'world', 1)" - ) - # Anonymous user can see table b but not table a - await ds.client.get("/foreign_key_labels.json") - anon_a = await ds.client.get("/foreign_key_labels/a.json?_labels=on") - assert anon_a.status_code == 403 - anon_b = await ds.client.get("/foreign_key_labels/b.json?_labels=on") - assert anon_b.status_code == 200 - # root user can see both - cookies = {"ds_actor": ds.sign({"a": {"id": "root"}}, "actor")} - root_a = await ds.client.get( - "/foreign_key_labels/a.json?_labels=on", cookies=cookies - ) - assert root_a.status_code == 200 - root_b = await ds.client.get( - "/foreign_key_labels/b.json?_labels=on", cookies=cookies - ) - assert root_b.status_code == 200 - # Labels should have been expanded for root - assert root_b.json() == { - "ok": True, - "next": None, - "rows": [{"id": 1, "name": "world", "a_id": {"value": 1, "label": "hello"}}], - "truncated": False, - } - # But not for anon - assert anon_b.json() == { - "ok": True, - "next": None, - "rows": [{"id": 1, "name": "world", "a_id": 1}], - "truncated": False, - } - - -def test_foreign_keys_special_character_in_database_name(app_client_with_dot): - # https://github.com/simonw/datasette/pull/2476 - response = app_client_with_dot.get("/fixtures~2Edot/complex_foreign_keys") - assert 'world' in response.text diff --git a/tests/test_templates/_table.html b/tests/test_templates/_table.html deleted file mode 100644 index 14f635a8..00000000 --- a/tests/test_templates/_table.html +++ /dev/null @@ -1,3 +0,0 @@ -{% for row in display_rows %} -
    {{ row["f1"] }} - {{ row["f2"] }} - {{ row.display("f3") }}
    -{% endfor %} diff --git a/tests/test_templates/pages/202.html b/tests/test_templates/pages/202.html deleted file mode 100644 index 43a313b2..00000000 --- a/tests/test_templates/pages/202.html +++ /dev/null @@ -1 +0,0 @@ -{{ custom_status(202) }}202! \ No newline at end of file diff --git a/tests/test_templates/pages/about.html b/tests/test_templates/pages/about.html deleted file mode 100644 index 11d78862..00000000 --- a/tests/test_templates/pages/about.html +++ /dev/null @@ -1 +0,0 @@ -ABOUT! view_name:{{ view_name }} \ No newline at end of file diff --git a/tests/test_templates/pages/atom.html b/tests/test_templates/pages/atom.html deleted file mode 100644 index 1c7faafd..00000000 --- a/tests/test_templates/pages/atom.html +++ /dev/null @@ -1 +0,0 @@ -{{ custom_header("content-type", "application/xml") }} \ No newline at end of file diff --git a/tests/test_templates/pages/headers.html b/tests/test_templates/pages/headers.html deleted file mode 100644 index 8a59d4aa..00000000 --- a/tests/test_templates/pages/headers.html +++ /dev/null @@ -1 +0,0 @@ -{{ custom_header("x-this-is-foo", "foo") }}FOO{{ custom_header("x-this-is-bar", "bar") }}BAR \ No newline at end of file diff --git a/tests/test_templates/pages/nested/nest.html b/tests/test_templates/pages/nested/nest.html deleted file mode 100644 index 5510f99e..00000000 --- a/tests/test_templates/pages/nested/nest.html +++ /dev/null @@ -1 +0,0 @@ -Nest! \ No newline at end of file diff --git a/tests/test_templates/pages/redirect.html b/tests/test_templates/pages/redirect.html deleted file mode 100644 index 36a71554..00000000 --- a/tests/test_templates/pages/redirect.html +++ /dev/null @@ -1 +0,0 @@ -{{ custom_redirect("/example") }} \ No newline at end of file diff --git a/tests/test_templates/pages/redirect2.html b/tests/test_templates/pages/redirect2.html deleted file mode 100644 index b7ae092a..00000000 --- a/tests/test_templates/pages/redirect2.html +++ /dev/null @@ -1 +0,0 @@ -{{ custom_redirect("/example", 301) }} \ No newline at end of file diff --git a/tests/test_templates/pages/request.html b/tests/test_templates/pages/request.html deleted file mode 100644 index aa8e0b62..00000000 --- a/tests/test_templates/pages/request.html +++ /dev/null @@ -1 +0,0 @@ -path:{{ request.path }} \ No newline at end of file diff --git a/tests/test_templates/pages/route_{name}.html b/tests/test_templates/pages/route_{name}.html deleted file mode 100644 index 42bd1e04..00000000 --- a/tests/test_templates/pages/route_{name}.html +++ /dev/null @@ -1,2 +0,0 @@ -{% if name == "OhNo" %}{{ raise_404("Oh no") }}{% endif %} -

    Hello from {{ name }}

    \ No newline at end of file diff --git a/tests/test_templates/pages/topic_{topic}.html b/tests/test_templates/pages/topic_{topic}.html deleted file mode 100644 index f07b6b07..00000000 --- a/tests/test_templates/pages/topic_{topic}.html +++ /dev/null @@ -1 +0,0 @@ -Topic page for {{ topic }} \ No newline at end of file diff --git a/tests/test_templates/pages/topic_{topic}/{slug}.html b/tests/test_templates/pages/topic_{topic}/{slug}.html deleted file mode 100644 index cbe5344f..00000000 --- a/tests/test_templates/pages/topic_{topic}/{slug}.html +++ /dev/null @@ -1 +0,0 @@ -Slug: {{ slug }}, Topic: {{ topic }} \ No newline at end of file diff --git a/tests/test_templates/show_json.html b/tests/test_templates/show_json.html deleted file mode 100644 index cff04fb4..00000000 --- a/tests/test_templates/show_json.html +++ /dev/null @@ -1,9 +0,0 @@ -{% extends "base.html" %} - -{% block content %} -{{ super() }} -Test data for extra_template_vars: -
    {{ extra_template_vars|safe }}
    -
    {{ extra_template_vars_from_awaitable|safe }}
    -
    {{ query_database("select sqlite_version();") }}
    -{% endblock %} diff --git a/tests/test_token_handler.py b/tests/test_token_handler.py deleted file mode 100644 index 5c87f577..00000000 --- a/tests/test_token_handler.py +++ /dev/null @@ -1,338 +0,0 @@ -""" -Tests for the register_token_handler plugin hook. -""" - -from datasette.app import Datasette -from datasette.hookspecs import hookimpl -from datasette.plugins import pm -from datasette.tokens import TokenHandler, TokenRestrictions, SignedTokenHandler -import pytest - - -@pytest.fixture -def datasette(): - return Datasette() - - -@pytest.mark.asyncio -async def test_default_signed_handler_registered(datasette): - """The default SignedTokenHandler should be registered automatically.""" - handlers = datasette._token_handlers() - assert len(handlers) >= 1 - assert any(isinstance(h, SignedTokenHandler) for h in handlers) - assert any(h.name == "signed" for h in handlers) - - -@pytest.mark.asyncio -async def test_create_token_default(datasette): - """create_token() with handler='signed' should create a signed token.""" - token = await datasette.create_token("test_actor", handler="signed") - assert token.startswith("dstok_") - - -@pytest.mark.asyncio -async def test_create_token_with_restrictions(datasette): - """create_token() should handle restriction parameters.""" - token = await datasette.create_token( - "test_actor", - handler="signed", - expires_after=3600, - restrictions=TokenRestrictions().allow_all("view-instance"), - ) - assert token.startswith("dstok_") - # Verify the token contains the expected data - decoded = datasette.unsign(token[len("dstok_") :], namespace="token") - assert decoded["a"] == "test_actor" - assert decoded["d"] == 3600 - assert "_r" in decoded - assert "a" in decoded["_r"] - - -@pytest.mark.asyncio -async def test_verify_token_default(datasette): - """verify_token() should verify signed tokens.""" - token = await datasette.create_token("test_actor", handler="signed") - actor = await datasette.verify_token(token) - assert actor is not None - assert actor["id"] == "test_actor" - assert actor["token"] == "dstok" - - -@pytest.mark.asyncio -async def test_verify_token_unknown_returns_none(datasette): - """verify_token() should return None for unrecognized tokens.""" - result = await datasette.verify_token("unknown_token_format_xyz") - assert result is None - - -@pytest.mark.asyncio -async def test_verify_token_bad_signature_returns_none(datasette): - """verify_token() should return None for tokens with bad signatures.""" - result = await datasette.verify_token("dstok_tampered_data_here") - assert result is None - - -@pytest.mark.asyncio -async def test_create_token_with_named_handler(datasette): - """create_token(handler='signed') should select the signed handler.""" - token = await datasette.create_token("test_actor", handler="signed") - assert token.startswith("dstok_") - - -@pytest.mark.asyncio -async def test_create_token_unknown_handler_raises(datasette): - """create_token(handler='nonexistent') should raise ValueError.""" - with pytest.raises(ValueError, match="Token handler 'nonexistent' not found"): - await datasette.create_token("test_actor", handler="nonexistent") - - -@pytest.mark.asyncio -async def test_custom_token_handler(datasette): - """A custom token handler should be usable for both create and verify.""" - - class CustomHandler(TokenHandler): - name = "custom" - - async def create_token(self, datasette, actor_id, **kwargs): - return f"custom_{actor_id}" - - async def verify_token(self, datasette, token): - if token.startswith("custom_"): - return {"id": token[len("custom_") :], "token": "custom"} - return None - - class Plugin: - __name__ = "CustomTokenPlugin" - - @staticmethod - @hookimpl - def register_token_handler(datasette): - return CustomHandler() - - pm.register(Plugin(), name="test_custom_handler") - try: - handlers = datasette._token_handlers() - assert any(h.name == "custom" for h in handlers) - - # Create with custom handler - token = await datasette.create_token("alice", handler="custom") - assert token == "custom_alice" - - # Verify custom token - actor = await datasette.verify_token("custom_alice") - assert actor is not None - assert actor["id"] == "alice" - assert actor["token"] == "custom" - - # Signed tokens should still work - signed_token = await datasette.create_token("bob", handler="signed") - assert signed_token.startswith("dstok_") - actor = await datasette.verify_token(signed_token) - assert actor["id"] == "bob" - finally: - pm.unregister(name="test_custom_handler") - - -@pytest.mark.asyncio -async def test_verify_token_tries_all_handlers(datasette): - """verify_token() should try each handler until one matches.""" - - class HandlerA(TokenHandler): - name = "handler_a" - - async def create_token(self, datasette, actor_id, **kwargs): - return f"a_{actor_id}" - - async def verify_token(self, datasette, token): - if token.startswith("a_"): - return {"id": token[2:], "token": "handler_a"} - return None - - class HandlerB(TokenHandler): - name = "handler_b" - - async def create_token(self, datasette, actor_id, **kwargs): - return f"b_{actor_id}" - - async def verify_token(self, datasette, token): - if token.startswith("b_"): - return {"id": token[2:], "token": "handler_b"} - return None - - class PluginA: - __name__ = "PluginA" - - @staticmethod - @hookimpl - def register_token_handler(datasette): - return HandlerA() - - class PluginB: - __name__ = "PluginB" - - @staticmethod - @hookimpl - def register_token_handler(datasette): - return HandlerB() - - pm.register(PluginA(), name="test_handler_a") - pm.register(PluginB(), name="test_handler_b") - try: - # Both handler tokens should verify - actor_a = await datasette.verify_token("a_alice") - assert actor_a is not None - assert actor_a["id"] == "alice" - assert actor_a["token"] == "handler_a" - - actor_b = await datasette.verify_token("b_bob") - assert actor_b is not None - assert actor_b["id"] == "bob" - assert actor_b["token"] == "handler_b" - - # Unknown token should return None - assert await datasette.verify_token("c_charlie") is None - finally: - pm.unregister(name="test_handler_a") - pm.unregister(name="test_handler_b") - - -@pytest.mark.asyncio -async def test_token_handler_via_http(datasette): - """Default signed tokens should work through HTTP auth.""" - token = await datasette.create_token("http_user", handler="signed") - response = await datasette.client.get( - "/-/actor.json", - headers={"Authorization": f"Bearer {token}"}, - ) - assert response.status_code == 200 - actor = response.json()["actor"] - assert actor["id"] == "http_user" - assert actor["token"] == "dstok" - - -@pytest.mark.asyncio -async def test_custom_handler_via_http(datasette): - """Custom handler tokens should work through HTTP auth.""" - - class CustomHandler(TokenHandler): - name = "custom_http" - - async def create_token(self, datasette, actor_id, **kwargs): - return f"chttp_{actor_id}" - - async def verify_token(self, datasette, token): - if token.startswith("chttp_"): - return {"id": token[len("chttp_") :], "token": "custom_http"} - return None - - class Plugin: - __name__ = "CustomHTTPPlugin" - - @staticmethod - @hookimpl - def register_token_handler(datasette): - return CustomHandler() - - pm.register(Plugin(), name="test_custom_http") - try: - token = await datasette.create_token("web_user", handler="custom_http") - response = await datasette.client.get( - "/-/actor.json", - headers={"Authorization": f"Bearer {token}"}, - ) - assert response.status_code == 200 - actor = response.json()["actor"] - assert actor["id"] == "web_user" - assert actor["token"] == "custom_http" - finally: - pm.unregister(name="test_custom_http") - - -@pytest.mark.asyncio -async def test_token_handler_base_class_raises(): - """TokenHandler base class methods should raise NotImplementedError.""" - handler = TokenHandler() - ds = Datasette() - with pytest.raises(NotImplementedError): - await handler.create_token(ds, "test") - with pytest.raises(NotImplementedError): - await handler.verify_token(ds, "test") - - -@pytest.mark.asyncio -async def test_restrictions_round_trip(datasette): - """Tokens with database/resource restrictions should round-trip correctly.""" - restrictions = ( - TokenRestrictions() - .allow_all("view-instance") - .allow_database("docs", "view-query") - .allow_resource("docs", "attachments", "insert-row") - ) - token = await datasette.create_token( - "test_actor", handler="signed", restrictions=restrictions - ) - actor = await datasette.verify_token(token) - assert actor is not None - assert actor["id"] == "test_actor" - assert actor["_r"]["a"] == ["view-instance"] - assert actor["_r"]["d"] == {"docs": ["view-query"]} - assert actor["_r"]["r"] == {"docs": {"attachments": ["insert-row"]}} - - -@pytest.mark.asyncio -async def test_expires_after_round_trip(datasette): - """Tokens with expires_after should include token_expires in the actor.""" - token = await datasette.create_token( - "test_actor", handler="signed", expires_after=3600 - ) - actor = await datasette.verify_token(token) - assert actor is not None - assert actor["id"] == "test_actor" - assert "token_expires" in actor - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "build_restrictions,expected", - [ - (lambda r: r, None), - (lambda r: r.allow_all("view-instance"), {"a": ["vi"]}), - ( - lambda r: r.allow_database("docs", "view-query"), - {"d": {"docs": ["vq"]}}, - ), - ( - lambda r: r.allow_resource("docs", "attachments", "insert-row"), - {"r": {"docs": {"attachments": ["ir"]}}}, - ), - ( - lambda r: r.allow_all("view-instance") - .allow_database("docs", "view-query") - .allow_resource("docs", "attachments", "insert-row"), - { - "a": ["vi"], - "d": {"docs": ["vq"]}, - "r": {"docs": {"attachments": ["ir"]}}, - }, - ), - ( - lambda r: r.allow_all("not-a-real-action"), - {"a": ["not-a-real-action"]}, - ), - ], - ids=["empty", "all", "database", "resource", "combined", "unknown_action"], -) -async def test_token_restrictions_abbreviated(datasette, build_restrictions, expected): - await datasette.invoke_startup() - restrictions = build_restrictions(TokenRestrictions()) - assert restrictions.abbreviated(datasette) == expected - - -@pytest.mark.asyncio -async def test_signed_tokens_disabled(): - """create_token and verify_token should fail/skip when signed tokens are disabled.""" - ds = Datasette(settings={"allow_signed_tokens": False}) - with pytest.raises(ValueError, match="Signed tokens are not enabled"): - await ds.create_token("test_actor", handler="signed") - # verify_token should return None rather than raising - assert await ds.verify_token("dstok_anything") is None diff --git a/tests/test_tracer.py b/tests/test_tracer.py deleted file mode 100644 index 6cc80fc4..00000000 --- a/tests/test_tracer.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest -from .fixtures import make_app_client - - -@pytest.mark.parametrize("trace_debug", (True, False)) -def test_trace(trace_debug): - with make_app_client(settings={"trace_debug": trace_debug}) as client: - response = client.get("/fixtures/simple_primary_key.json?_trace=1") - assert response.status == 200 - - data = response.json - if not trace_debug: - assert "_trace" not in data - return - - assert "_trace" in data - trace_info = data["_trace"] - assert isinstance(trace_info["request_duration_ms"], float) - assert isinstance(trace_info["sum_trace_duration_ms"], float) - assert isinstance(trace_info["num_traces"], int) - assert isinstance(trace_info["traces"], list) - traces = trace_info["traces"] - assert len(traces) == trace_info["num_traces"] - for trace in traces: - assert isinstance(trace["type"], str) - assert isinstance(trace["start"], float) - assert isinstance(trace["end"], float) - assert trace["duration_ms"] == (trace["end"] - trace["start"]) * 1000 - assert isinstance(trace["traceback"], list) - assert isinstance(trace["database"], str) - assert isinstance(trace["sql"], str) - assert isinstance(trace.get("params"), (list, dict, None.__class__)) - - sqls = [trace["sql"] for trace in traces if "sql" in trace] - # There should be SQL statements from request handling in the trace. - # Note: CREATE TABLE, INSERT OR REPLACE, executescript, and executemany - # are not expected here because internal tables are now created and - # populated during invoke_startup(), before the request is traced. - assert any(sql.startswith("select ") for sql in sqls), "No select statements traced" - - -def test_trace_silently_fails_for_large_page(): - # Max HTML size is 256KB - with make_app_client(settings={"trace_debug": True}) as client: - # Small response should have trace - small_response = client.get("/fixtures/simple_primary_key.json?_trace=1") - assert small_response.status == 200 - assert "_trace" in small_response.json - - # Big response should not - big_response = client.get( - "/fixtures/-/query.json", - params={"_trace": 1, "sql": "select zeroblob(1024 * 256)"}, - ) - assert big_response.status == 200 - assert "_trace" not in big_response.json - - -def test_trace_query_errors(): - with make_app_client(settings={"trace_debug": True}) as client: - response = client.get( - "/fixtures/-/query.json", - params={"_trace": 1, "sql": "select * from non_existent_table"}, - ) - assert response.status == 400 - - data = response.json - assert "_trace" in data - trace_info = data["_trace"] - assert trace_info["traces"][-1]["error"] == "no such table: non_existent_table" - - -def test_trace_parallel_queries(): - with make_app_client(settings={"trace_debug": True}) as client: - response = client.get("/parallel-queries?_trace=1") - assert response.status == 200 - - data = response.json - assert data["one"] == 1 - assert data["two"] == 2 - trace_info = data["_trace"] - traces = [trace for trace in trace_info["traces"] if "sql" in trace] - one, two = traces - # "two" should have started before "one" ended - assert two["start"] < one["end"] diff --git a/tests/test_utils.py b/tests/test_utils.py index 3fcb623e..85ee2d84 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,189 +2,194 @@ Tests for various datasette helper functions. """ -from datasette.app import Datasette from datasette import utils -from datasette.utils.asgi import Request -from datasette.utils.sqlite import sqlite3 import json import os -import pathlib import pytest +import sqlite3 import tempfile from unittest.mock import patch -@pytest.mark.parametrize( - "path,expected", - [ - ("foo", ["foo"]), - ("foo,bar", ["foo", "bar"]), - ("123,433,112", ["123", "433", "112"]), - ("123~2C433,112", ["123,433", "112"]), - ("123~2F433~2F112", ["123/433/112"]), - ], -) +@pytest.mark.parametrize('path,expected', [ + ('foo', ['foo']), + ('foo,bar', ['foo', 'bar']), + ('123,433,112', ['123', '433', '112']), + ('123%2C433,112', ['123,433', '112']), + ('123%2F433%2F112', ['123/433/112']), +]) def test_urlsafe_components(path, expected): assert expected == utils.urlsafe_components(path) -@pytest.mark.parametrize( - "path,added_args,expected", - [ - ("/foo", {"bar": 1}, "/foo?bar=1"), - ("/foo?bar=1", {"baz": 2}, "/foo?bar=1&baz=2"), - ("/foo?bar=1&bar=2", {"baz": 3}, "/foo?bar=1&bar=2&baz=3"), - ("/foo?bar=1", {"bar": None}, "/foo"), - # Test order is preserved - ( - "/?_facet=prim_state&_facet=area_name", - (("prim_state", "GA"),), - "/?_facet=prim_state&_facet=area_name&prim_state=GA", - ), - ( - "/?_facet=state&_facet=city&state=MI", - (("city", "Detroit"),), - "/?_facet=state&_facet=city&state=MI&city=Detroit", - ), - ( - "/?_facet=state&_facet=city", - (("_facet", "planet_int"),), - "/?_facet=state&_facet=city&_facet=planet_int", - ), - ], -) +@pytest.mark.parametrize('path,added_args,expected', [ + ('/foo', {'bar': 1}, '/foo?bar=1'), + ('/foo?bar=1', {'baz': 2}, '/foo?bar=1&baz=2'), + ('/foo?bar=1&bar=2', {'baz': 3}, '/foo?bar=1&bar=2&baz=3'), + ('/foo?bar=1', {'bar': None}, '/foo'), + # Test order is preserved + ('/?_facet=prim_state&_facet=area_name', ( + ('prim_state', 'GA'), + ), '/?_facet=prim_state&_facet=area_name&prim_state=GA'), + ('/?_facet=state&_facet=city&state=MI', ( + ('city', 'Detroit'), + ), '/?_facet=state&_facet=city&state=MI&city=Detroit'), + ('/?_facet=state&_facet=city', ( + ('_facet', 'planet_int'), + ), '/?_facet=state&_facet=city&_facet=planet_int'), +]) def test_path_with_added_args(path, added_args, expected): - request = Request.fake(path) - actual = utils.path_with_added_args(request, added_args) + try: + path, qsbits = path.split('?', 1) + except ValueError: + qsbits = '' + qs = utils.Querystring(path, qsbits) + actual = utils.path_with_added_args(qs, added_args) assert expected == actual -@pytest.mark.parametrize( - "path,args,expected", - [ - ("/foo?bar=1", {"bar"}, "/foo"), - ("/foo?bar=1&baz=2", {"bar"}, "/foo?baz=2"), - ("/foo?bar=1&bar=2&bar=3", {"bar": "2"}, "/foo?bar=1&bar=3"), - ], -) +@pytest.mark.parametrize('path,args,expected', [ + ('/foo?bar=1', {'bar'}, '/foo'), + ('/foo?bar=1&baz=2', {'bar'}, '/foo?baz=2'), + ('/foo?bar=1&bar=2&bar=3', {'bar': '2'}, '/foo?bar=1&bar=3'), +]) def test_path_with_removed_args(path, args, expected): - request = Request.fake(path) - actual = utils.path_with_removed_args(request, args) - assert expected == actual - # Run the test again but this time use the path= argument - request = Request.fake("/") - actual = utils.path_with_removed_args(request, args, path=path) + try: + path, qsbits = path.split('?', 1) + except ValueError: + qsbits = '' + qs = utils.Querystring(path, qsbits) + actual = utils.path_with_removed_args(qs, args) assert expected == actual -@pytest.mark.parametrize( - "path,args,expected", - [ - ("/foo?bar=1", {"bar": 2}, "/foo?bar=2"), - ("/foo?bar=1&baz=2", {"bar": None}, "/foo?baz=2"), - ], -) +@pytest.mark.parametrize('path,args,expected', [ + ('/foo?bar=1', {'bar': 2}, '/foo?bar=2'), + ('/foo?bar=1&baz=2', {'bar': None}, '/foo?baz=2'), +]) def test_path_with_replaced_args(path, args, expected): - request = Request.fake(path) - actual = utils.path_with_replaced_args(request, args) + try: + path, qsbits = path.split('?', 1) + except ValueError: + qsbits = '' + qs = utils.Querystring(path, qsbits) + actual = utils.path_with_replaced_args(qs, args) assert expected == actual -@pytest.mark.parametrize( - "row,pks,expected_path", - [ - ({"A": "foo", "B": "bar"}, ["A", "B"], "foo,bar"), - ({"A": "f,o", "B": "bar"}, ["A", "B"], "f~2Co,bar"), - ({"A": 123}, ["A"], "123"), - ( - utils.CustomRow( - ["searchable_id", "tag"], - [ - ("searchable_id", {"value": 1, "label": "1"}), - ("tag", {"value": "feline", "label": "feline"}), - ], - ), - ["searchable_id", "tag"], - "1,feline", - ), - ], -) +@pytest.mark.parametrize('row,pks,expected_path', [ + ({'A': 'foo', 'B': 'bar'}, ['A', 'B'], 'foo,bar'), + ({'A': 'f,o', 'B': 'bar'}, ['A', 'B'], 'f%2Co,bar'), + ({'A': 123}, ['A'], '123'), +]) def test_path_from_row_pks(row, pks, expected_path): actual_path = utils.path_from_row_pks(row, pks, False) assert expected_path == actual_path -@pytest.mark.parametrize( - "obj,expected", - [ - ( - { - "Description": "Soft drinks", - "Picture": b"\x15\x1c\x02\xc7\xad\x05\xfe", - "CategoryID": 1, - }, - """ +@pytest.mark.parametrize('obj,expected', [ + ({ + 'Description': 'Soft drinks', + 'Picture': b"\x15\x1c\x02\xc7\xad\x05\xfe", + 'CategoryID': 1, + }, """ {"CategoryID": 1, "Description": "Soft drinks", "Picture": {"$base64": true, "encoded": "FRwCx60F/g=="}} - """.strip(), - ) - ], -) + """.strip()), +]) def test_custom_json_encoder(obj, expected): - actual = json.dumps(obj, cls=utils.CustomJSONEncoder, sort_keys=True) + actual = json.dumps( + obj, + cls=utils.CustomJSONEncoder, + sort_keys=True + ) assert expected == actual -@pytest.mark.parametrize( - "bad_sql", - [ - "update blah;", - "-- sql comment to skip\nupdate blah;", - "update blah set some_column='# Hello there\n\n* This is a list\n* of items\n--\n[And a link](https://github.com/simonw/datasette-render-markdown).'\nas demo_markdown", - "PRAGMA case_sensitive_like = true", - "SELECT * FROM pragma_not_on_allow_list('idx52')", - "/* This comment is not valid. select 1", - "/**/\nupdate foo set bar = 1\n/* test */ select 1", - ], -) +@pytest.mark.parametrize('args,expected_where,expected_params', [ + ( + { + 'name_english__contains': 'foo', + }, + ['"name_english" like :p0'], + ['%foo%'] + ), + ( + { + 'foo': 'bar', + 'bar__contains': 'baz', + }, + ['"bar" like :p0', '"foo" = :p1'], + ['%baz%', 'bar'] + ), + ( + { + 'foo__startswith': 'bar', + 'bar__endswith': 'baz', + }, + ['"bar" like :p0', '"foo" like :p1'], + ['%baz', 'bar%'] + ), + ( + { + 'foo__lt': '1', + 'bar__gt': '2', + 'baz__gte': '3', + 'bax__lte': '4', + }, + ['"bar" > :p0', '"bax" <= :p1', '"baz" >= :p2', '"foo" < :p3'], + [2, 4, 3, 1] + ), + ( + { + 'foo__like': '2%2', + 'zax__glob': '3*', + }, + ['"foo" like :p0', '"zax" glob :p1'], + ['2%2', '3*'] + ), + ( + { + 'foo__isnull': '1', + 'baz__isnull': '1', + 'bar__gt': '10' + }, + ['"bar" > :p0', '"baz" is null', '"foo" is null'], + [10] + ), +]) +def test_build_where(args, expected_where, expected_params): + f = utils.Filters(sorted(args.items())) + sql_bits, actual_params = f.build_where_clauses() + assert expected_where == sql_bits + assert { + 'p{}'.format(i): param + for i, param in enumerate(expected_params) + } == actual_params + + +@pytest.mark.parametrize('bad_sql', [ + 'update blah;', + 'PRAGMA case_sensitive_like = true' + "SELECT * FROM pragma_index_info('idx52')", +]) def test_validate_sql_select_bad(bad_sql): with pytest.raises(utils.InvalidSql): utils.validate_sql_select(bad_sql) -@pytest.mark.parametrize( - "good_sql", - [ - "select count(*) from airports", - "select foo from bar", - "--sql comment to skip\nselect foo from bar", - "select '# Hello there\n\n* This is a list\n* of items\n--\n[And a link](https://github.com/simonw/datasette-render-markdown).'\nas demo_markdown", - "select 1 + 1", - "explain select 1 + 1", - "explain\nselect 1 + 1", - "explain query plan select 1 + 1", - "explain query plan\nselect 1 + 1", - "SELECT\nblah FROM foo", - "WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;", - "explain WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;", - "explain query plan WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;", - "SELECT * FROM pragma_index_info('idx52')", - "select * from pragma_table_xinfo('table')", - # Various types of comment - "-- comment\nselect 1", - "-- one line\n -- two line\nselect 1", - " /* comment */\nselect 1", - " /* comment */select 1", - "/* comment */\n -- another\n /* one more */ select 1", - "/* This comment \n has multiple lines */\nselect 1", - ], -) +@pytest.mark.parametrize('good_sql', [ + 'select count(*) from airports', + 'select foo from bar', + 'select 1 + 1', + 'SELECT\nblah FROM foo', + 'WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt LIMIT 10) SELECT x FROM cnt;' +]) def test_validate_sql_select_good(good_sql): utils.validate_sql_select(good_sql) -@pytest.mark.parametrize("open_quote,close_quote", [('"', '"'), ("[", "]")]) -def test_detect_fts(open_quote, close_quote): - sql = """ +def test_detect_fts(): + sql = ''' CREATE TABLE "Dumb_Table" ( "TreeID" INTEGER, "qSpecies" TEXT @@ -199,58 +204,36 @@ def test_detect_fts(open_quote, close_quote): "qCaretaker" TEXT ); CREATE VIEW Test_View AS SELECT * FROM Dumb_Table; - CREATE VIRTUAL TABLE {open}Street_Tree_List_fts{close} USING FTS4 ("qAddress", "qCaretaker", "qSpecies", content={open}Street_Tree_List{close}); + CREATE VIRTUAL TABLE "Street_Tree_List_fts" USING FTS4 ("qAddress", "qCaretaker", "qSpecies", content="Street_Tree_List"); CREATE VIRTUAL TABLE r USING rtree(a, b, c); - """.format(open=open_quote, close=close_quote) - conn = utils.sqlite3.connect(":memory:") + ''' + conn = sqlite3.connect(':memory:') conn.executescript(sql) - assert None is utils.detect_fts(conn, "Dumb_Table") - assert None is utils.detect_fts(conn, "Test_View") - assert None is utils.detect_fts(conn, "r") - assert "Street_Tree_List_fts" == utils.detect_fts(conn, "Street_Tree_List") - conn.close() + assert None is utils.detect_fts(conn, 'Dumb_Table') + assert None is utils.detect_fts(conn, 'Test_View') + assert None is utils.detect_fts(conn, 'r') + assert 'Street_Tree_List_fts' == utils.detect_fts(conn, 'Street_Tree_List') -@pytest.mark.parametrize("table", ("regular", "has'single quote")) -def test_detect_fts_different_table_names(table): - sql = """ - CREATE TABLE [{table}] ( - "TreeID" INTEGER, - "qSpecies" TEXT - ); - CREATE VIRTUAL TABLE [{table}_fts] USING FTS4 ("qSpecies", content="{table}"); - """.format(table=table) - conn = utils.sqlite3.connect(":memory:") - conn.executescript(sql) - assert "{table}_fts".format(table=table) == utils.detect_fts(conn, table) - conn.close() - - -@pytest.mark.parametrize( - "url,expected", - [ - ("http://www.google.com/", True), - ("https://example.com/", True), - ("www.google.com", False), - ("http://www.google.com/ is a search engine", False), - ], -) +@pytest.mark.parametrize('url,expected', [ + ('http://www.google.com/', True), + ('https://example.com/', True), + ('www.google.com', False), + ('http://www.google.com/ is a search engine', False), +]) def test_is_url(url, expected): assert expected == utils.is_url(url) -@pytest.mark.parametrize( - "s,expected", - [ - ("simple", "simple"), - ("MixedCase", "MixedCase"), - ("-no-leading-hyphens", "no-leading-hyphens-65bea6"), - ("_no-leading-underscores", "no-leading-underscores-b921bc"), - ("no spaces", "no-spaces-7088d7"), - ("-", "336d5e"), - ("no $ characters", "no--characters-59e024"), - ], -) +@pytest.mark.parametrize('s,expected', [ + ('simple', 'simple'), + ('MixedCase', 'MixedCase'), + ('-no-leading-hyphens', 'no-leading-hyphens-65bea6'), + ('_no-leading-underscores', 'no-leading-underscores-b921bc'), + ('no spaces', 'no-spaces-7088d7'), + ('-', '336d5e'), + ('no $ characters', 'no--characters-59e024'), +]) def test_to_css_class(s, expected): assert expected == utils.to_css_class(s) @@ -258,12 +241,11 @@ def test_to_css_class(s, expected): def test_temporary_docker_directory_uses_hard_link(): with tempfile.TemporaryDirectory() as td: os.chdir(td) - with open("hello", "w") as fp: - fp.write("world") + open('hello', 'w').write('world') # Default usage of this should use symlink with utils.temporary_docker_directory( - files=["hello"], - name="t", + files=['hello'], + name='t', metadata=None, extra_options=None, branch=None, @@ -272,28 +254,24 @@ def test_temporary_docker_directory_uses_hard_link(): static=[], install=[], spatialite=False, - version_note=None, - secret="secret", ) as temp_docker: - hello = os.path.join(temp_docker, "hello") - with open(hello) as fp: - assert "world" == fp.read() + hello = os.path.join(temp_docker, 'hello') + assert 'world' == open(hello).read() # It should be a hard link assert 2 == os.stat(hello).st_nlink -@patch("os.link") +@patch('os.link') def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link): # Copy instead if os.link raises OSError (normally due to different device) mock_link.side_effect = OSError with tempfile.TemporaryDirectory() as td: os.chdir(td) - with open("hello", "w") as fp: - fp.write("world") + open('hello', 'w').write('world') # Default usage of this should use symlink with utils.temporary_docker_directory( - files=["hello"], - name="t", + files=['hello'], + name='t', metadata=None, extra_options=None, branch=None, @@ -302,66 +280,51 @@ def test_temporary_docker_directory_uses_copy_if_hard_link_fails(mock_link): static=[], install=[], spatialite=False, - version_note=None, - secret=None, ) as temp_docker: - hello = os.path.join(temp_docker, "hello") - with open(hello) as fp: - assert "world" == fp.read() + hello = os.path.join(temp_docker, 'hello') + assert 'world' == open(hello).read() # It should be a copy, not a hard link assert 1 == os.stat(hello).st_nlink -def test_temporary_docker_directory_quotes_args(): - with tempfile.TemporaryDirectory() as td: - os.chdir(td) - with open("hello", "w") as fp: - fp.write("world") - with utils.temporary_docker_directory( - files=["hello"], - name="t", - metadata=None, - extra_options="--$HOME", - branch=None, - template_dir=None, - plugins_dir=None, - static=[], - install=[], - spatialite=False, - version_note="$PWD", - secret="secret", - ) as temp_docker: - df = os.path.join(temp_docker, "Dockerfile") - with open(df) as fp: - df_contents = fp.read() - assert "'$PWD'" in df_contents - assert "'--$HOME'" in df_contents - assert "ENV DATASETTE_SECRET 'secret'" in df_contents - - def test_compound_keys_after_sql(): - assert "((a > :p0))" == utils.compound_keys_after_sql(["a"]) - assert """ + assert '((a > :p0))' == utils.compound_keys_after_sql(['a']) + assert ''' ((a > :p0) or (a = :p0 and b > :p1)) - """.strip() == utils.compound_keys_after_sql(["a", "b"]) - assert """ + '''.strip() == utils.compound_keys_after_sql(['a', 'b']) + assert ''' ((a > :p0) or (a = :p0 and b > :p1) or (a = :p0 and b = :p1 and c > :p2)) - """.strip() == utils.compound_keys_after_sql(["a", "b", "c"]) + '''.strip() == utils.compound_keys_after_sql(['a', 'b', 'c']) -def test_table_columns(): - conn = sqlite3.connect(":memory:") - conn.executescript(""" - create table places (id integer primary key, name text, bob integer) - """) - assert ["id", "name", "bob"] == utils.table_columns(conn, "places") - conn.close() +def table_exists(table): + return table == "exists.csv" + + +@pytest.mark.parametrize( + "table_and_format,expected_table,expected_format", + [ + ("blah", "blah", None), + ("blah.csv", "blah", "csv"), + ("blah.json", "blah", "json"), + ("blah.baz", "blah.baz", None), + ("exists.csv", "exists.csv", None), + ], +) +def test_resolve_table_and_format( + table_and_format, expected_table, expected_format +): + actual_table, actual_format = utils.resolve_table_and_format( + table_and_format, table_exists + ) + assert expected_table == actual_table + assert expected_format == actual_format @pytest.mark.parametrize( @@ -371,7 +334,9 @@ def test_table_columns(): ("/foo?sql=select+1", "json", {}, "/foo.json?sql=select+1"), ("/foo/bar", "json", {}, "/foo/bar.json"), ("/foo/bar", "csv", {}, "/foo/bar.csv"), + ("/foo/bar.csv", "json", {}, "/foo/bar.csv?_format=json"), ("/foo/bar", "csv", {"_dl": 1}, "/foo/bar.csv?_dl=1"), + ("/foo/b.csv", "json", {"_dl": 1}, "/foo/b.csv?_dl=1&_format=json"), ( "/sf-trees/Street_Tree_List?_search=cherry&_size=1000", "csv", @@ -381,375 +346,10 @@ def test_table_columns(): ], ) def test_path_with_format(path, format, extra_qs, expected): - request = Request.fake(path) - actual = utils.path_with_format(request=request, format=format, extra_qs=extra_qs) + try: + path, qsbits = path.split('?', 1) + except ValueError: + qsbits = '' + qs = utils.Querystring(path, qsbits) + actual = utils.path_with_format(qs, format, extra_qs) assert expected == actual - - -@pytest.mark.parametrize( - "bytes,expected", - [ - (120, "120 bytes"), - (1024, "1.0 KB"), - (1024 * 1024, "1.0 MB"), - (1024 * 1024 * 1024, "1.0 GB"), - (1024 * 1024 * 1024 * 1.3, "1.3 GB"), - (1024 * 1024 * 1024 * 1024, "1.0 TB"), - ], -) -def test_format_bytes(bytes, expected): - assert expected == utils.format_bytes(bytes) - - -@pytest.mark.parametrize( - "query,expected", - [ - ("dog", '"dog"'), - ("cat,", '"cat,"'), - ("cat dog", '"cat" "dog"'), - # If a phrase is already double quoted, leave it so - ('"cat dog"', '"cat dog"'), - ('"cat dog" fish', '"cat dog" "fish"'), - # Sensibly handle unbalanced double quotes - ('cat"', '"cat"'), - ('"cat dog" "fish', '"cat dog" "fish"'), - ], -) -def test_escape_fts(query, expected): - assert expected == utils.escape_fts(query) - - -@pytest.mark.parametrize( - "input,expected", - [ - ("dog", "dog"), - ('dateutil_parse("1/2/2020")', r"dateutil_parse(\0000221/2/2020\000022)"), - ("this\r\nand\r\nthat", r"this\00000Aand\00000Athat"), - ], -) -def test_escape_css_string(input, expected): - assert expected == utils.escape_css_string(input) - - -def test_check_connection_spatialite_raises(): - path = str(pathlib.Path(__file__).parent / "spatialite.db") - conn = sqlite3.connect(path) - with pytest.raises(utils.SpatialiteConnectionProblem): - utils.check_connection(conn) - conn.close() - - -def test_check_connection_passes(): - conn = sqlite3.connect(":memory:") - utils.check_connection(conn) - conn.close() - - -def test_call_with_supported_arguments(): - def foo(a, b): - return f"{a}+{b}" - - assert "1+2" == utils.call_with_supported_arguments(foo, a=1, b=2) - assert "1+2" == utils.call_with_supported_arguments(foo, a=1, b=2, c=3) - - with pytest.raises(TypeError): - utils.call_with_supported_arguments(foo, a=1) - - -@pytest.mark.parametrize( - "data,should_raise", - [ - ([["foo", "bar"], ["foo", "baz"]], False), - ([("foo", "bar"), ("foo", "baz")], False), - ((["foo", "bar"], ["foo", "baz"]), False), - ([["foo", "bar"], ["foo", "baz", "bax"]], True), - ({"foo": ["bar", "baz"]}, False), - ({"foo": ("bar", "baz")}, False), - ({"foo": "bar"}, True), - ], -) -def test_multi_params(data, should_raise): - if should_raise: - with pytest.raises(AssertionError): - utils.MultiParams(data) - return - p1 = utils.MultiParams(data) - assert "bar" == p1["foo"] - assert ["bar", "baz"] == list(p1.getlist("foo")) - - -@pytest.mark.parametrize( - "actor,allow,expected", - [ - # Default is to allow: - (None, None, True), - # {} means deny-all: - (None, {}, False), - ({"id": "root"}, {}, False), - # true means allow-all - ({"id": "root"}, True, True), - (None, True, True), - # false means deny-all - ({"id": "root"}, False, False), - (None, False, False), - # Special case for "unauthenticated": true - (None, {"unauthenticated": True}, True), - (None, {"unauthenticated": False}, False), - # Match on just one property: - (None, {"id": "root"}, False), - ({"id": "root"}, None, True), - ({"id": "simon", "staff": True}, {"staff": True}, True), - ({"id": "simon", "staff": False}, {"staff": True}, False), - # Special "*" value for any key: - ({"id": "root"}, {"id": "*"}, True), - ({}, {"id": "*"}, False), - ({"name": "root"}, {"id": "*"}, False), - # Supports single strings or list of values: - ({"id": "root"}, {"id": "bob"}, False), - ({"id": "root"}, {"id": ["bob"]}, False), - ({"id": "root"}, {"id": "root"}, True), - ({"id": "root"}, {"id": ["root"]}, True), - # Any matching role will work: - ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["staff"]}, True), - ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["dev"]}, True), - ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["otter"]}, False), - ({"id": "garry", "roles": ["staff", "dev"]}, {"roles": ["dev", "otter"]}, True), - ({"id": "garry", "roles": []}, {"roles": ["staff"]}, False), - ({"id": "garry"}, {"roles": ["staff"]}, False), - # Any single matching key works: - ({"id": "root"}, {"bot_id": "my-bot", "id": ["root"]}, True), - ], -) -def test_actor_matches_allow(actor, allow, expected): - assert expected == utils.actor_matches_allow(actor, allow) - - -@pytest.mark.parametrize( - "config,expected", - [ - ({"foo": "bar"}, {"foo": "bar"}), - ({"$env": "FOO"}, "x"), - ({"k": {"$env": "FOO"}}, {"k": "x"}), - ([{"k": {"$env": "FOO"}}, {"z": {"$env": "FOO"}}], [{"k": "x"}, {"z": "x"}]), - ({"k": [{"in_a_list": {"$env": "FOO"}}]}, {"k": [{"in_a_list": "x"}]}), - ], -) -def test_resolve_env_secrets(config, expected): - assert expected == utils.resolve_env_secrets(config, {"FOO": "x"}) - - -@pytest.mark.parametrize( - "actor,expected", - [ - ({"id": "blah"}, "blah"), - ({"id": "blah", "login": "l"}, "l"), - ({"id": "blah", "login": "l"}, "l"), - ({"id": "blah", "login": "l", "username": "u"}, "u"), - ({"login": "l", "name": "n"}, "n"), - ( - {"id": "blah", "login": "l", "username": "u", "name": "n", "display": "d"}, - "d", - ), - ({"weird": "shape"}, "{'weird': 'shape'}"), - ], -) -def test_display_actor(actor, expected): - assert expected == utils.display_actor(actor) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "dbs,expected_path", - [ - (["one_table"], "/one/one"), - (["two_tables"], "/two"), - (["one_table", "two_tables"], "/"), - ], -) -async def test_initial_path_for_datasette(tmp_path_factory, dbs, expected_path): - db_dir = tmp_path_factory.mktemp("dbs") - one_table = str(db_dir / "one.db") - conn1 = sqlite3.connect(one_table) - conn1.execute("create table one (id integer primary key)") - conn1.close() - two_tables = str(db_dir / "two.db") - conn2 = sqlite3.connect(two_tables) - conn2.execute("create table two (id integer primary key)") - conn2.execute("create table three (id integer primary key)") - conn2.close() - datasette = Datasette( - [{"one_table": one_table, "two_tables": two_tables}[db] for db in dbs] - ) - path = await utils.initial_path_for_datasette(datasette) - assert path == expected_path - - -@pytest.mark.parametrize( - "content,expected", - ( - ("title: Hello", {"title": "Hello"}), - ('{"title": "Hello"}', {"title": "Hello"}), - ("{{ this }} is {{ bad }}", None), - ), -) -def test_parse_metadata(content, expected): - if expected is None: - with pytest.raises(utils.BadMetadataError): - utils.parse_metadata(content) - else: - assert utils.parse_metadata(content) == expected - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "sql,expected", - ( - ("select 1", []), - ("select 1 + :one", ["one"]), - ("select 1 + :one + :two", ["one", "two"]), - ("select 'bob' || '0:00' || :cat", ["cat"]), - ("select this is invalid :one, :two, :three", ["one", "two", "three"]), - ), -) -@pytest.mark.parametrize("use_async_version", (False, True)) -async def test_named_parameters(sql, expected, use_async_version): - ds = Datasette([], memory=True) - db = ds.get_database("_memory") - if use_async_version: - params = await utils.derive_named_parameters(db, sql) - else: - params = utils.named_parameters(sql) - assert params == expected - - -@pytest.mark.parametrize( - "original,expected", - ( - ("abc", "abc"), - ("/foo/bar", "~2Ffoo~2Fbar"), - ("/-/bar", "~2F-~2Fbar"), - ("-/db-/table.csv", "-~2Fdb-~2Ftable~2Ecsv"), - (r"%~-/", "~25~7E-~2F"), - ("~25~7E~2D~2F", "~7E25~7E7E~7E2D~7E2F"), - ("with space", "with+space"), - ), -) -def test_tilde_encoding(original, expected): - actual = utils.tilde_encode(original) - assert actual == expected - # And test round-trip - assert original == utils.tilde_decode(actual) - - -@pytest.mark.parametrize( - "url,length,expected", - ( - ("https://example.com/", 5, "http…"), - ("https://example.com/foo/bar", 15, "https://exampl…"), - ("https://example.com/foo/bar/baz.jpg", 30, "https://example.com/foo/ba….jpg"), - # Extensions longer than 4 characters are not treated specially: - ("https://example.com/foo/bar/baz.jpeg2", 30, "https://example.com/foo/bar/b…"), - ( - "https://example.com/foo/bar/baz.jpeg2", - None, - "https://example.com/foo/bar/baz.jpeg2", - ), - ), -) -def test_truncate_url(url, length, expected): - actual = utils.truncate_url(url, length) - assert actual == expected - - -@pytest.mark.parametrize( - "pairs,expected", - ( - # Simple nested objects - ([("a", "b")], {"a": "b"}), - ([("a.b", "c")], {"a": {"b": "c"}}), - # JSON literals - ([("a.b", "true")], {"a": {"b": True}}), - ([("a.b", "false")], {"a": {"b": False}}), - ([("a.b", "null")], {"a": {"b": None}}), - ([("a.b", "1")], {"a": {"b": 1}}), - ([("a.b", "1.1")], {"a": {"b": 1.1}}), - # Nested JSON literals - ([("a.b", '{"foo": "bar"}')], {"a": {"b": {"foo": "bar"}}}), - ([("a.b", "[1, 2, 3]")], {"a": {"b": [1, 2, 3]}}), - # JSON strings are preserved - ([("a.b", '"true"')], {"a": {"b": "true"}}), - ([("a.b", '"[1, 2, 3]"')], {"a": {"b": "[1, 2, 3]"}}), - # Later keys over-ride the previous - ( - [ - ("a", "b"), - ("a.b", "c"), - ], - {"a": {"b": "c"}}, - ), - ( - [ - ("settings.trace_debug", "true"), - ("plugins.datasette-ripgrep.path", "/etc"), - ("settings.trace_debug", "false"), - ], - { - "settings": { - "trace_debug": False, - }, - "plugins": { - "datasette-ripgrep": { - "path": "/etc", - } - }, - }, - ), - ), -) -def test_pairs_to_nested_config(pairs, expected): - actual = utils.pairs_to_nested_config(pairs) - assert actual == expected - - -@pytest.mark.asyncio -async def test_calculate_etag(tmp_path): - path = tmp_path / "test.txt" - path.write_text("hello") - etag = '"5d41402abc4b2a76b9719d911017c592"' - assert etag == await utils.calculate_etag(path) - assert utils._etag_cache[path] == etag - utils._etag_cache[path] = "hash" - assert "hash" == await utils.calculate_etag(path) - utils._etag_cache.clear() - - -@pytest.mark.parametrize( - "dict1,dict2,expected", - [ - # Basic update - ({"a": 1, "b": 2}, {"b": 3, "c": 4}, {"a": 1, "b": 3, "c": 4}), - # Nested dictionary update - ( - {"a": 1, "b": {"x": 10, "y": 20}}, - {"b": {"y": 30, "z": 40}}, - {"a": 1, "b": {"x": 10, "y": 30, "z": 40}}, - ), - # Deep nested update - ( - {"a": {"b": {"c": 1}}}, - {"a": {"b": {"d": 2}}}, - {"a": {"b": {"c": 1, "d": 2}}}, - ), - # Update with mixed types - ( - {"a": 1, "b": {"x": 10}}, - {"b": {"y": 20}, "c": [1, 2, 3]}, - {"a": 1, "b": {"x": 10, "y": 20}, "c": [1, 2, 3]}, - ), - ], -) -def test_deep_dict_update(dict1, dict2, expected): - result = utils.deep_dict_update(dict1, dict2) - assert result == expected - # Check that the original dict1 was modified - assert dict1 == expected diff --git a/tests/test_utils_check_callable.py b/tests/test_utils_check_callable.py deleted file mode 100644 index 4f72f9ff..00000000 --- a/tests/test_utils_check_callable.py +++ /dev/null @@ -1,46 +0,0 @@ -from datasette.utils.check_callable import check_callable -import pytest - - -class AsyncClass: - async def __call__(self): - pass - - -class NotAsyncClass: - def __call__(self): - pass - - -class ClassNoCall: - pass - - -async def async_func(): - pass - - -def non_async_func(): - pass - - -@pytest.mark.parametrize( - "obj,expected_is_callable,expected_is_async_callable", - ( - (async_func, True, True), - (non_async_func, True, False), - (AsyncClass(), True, True), - (NotAsyncClass(), True, False), - (ClassNoCall(), False, False), - (AsyncClass, True, False), - (NotAsyncClass, True, False), - (ClassNoCall, True, False), - ("", False, False), - (1, False, False), - (str, True, False), - ), -) -def test_check_callable(obj, expected_is_callable, expected_is_async_callable): - status = check_callable(obj) - assert status.is_callable == expected_is_callable - assert status.is_async_callable == expected_is_async_callable diff --git a/tests/test_utils_permissions.py b/tests/test_utils_permissions.py deleted file mode 100644 index bc3599c2..00000000 --- a/tests/test_utils_permissions.py +++ /dev/null @@ -1,610 +0,0 @@ -import pytest -from datasette.app import Datasette -from datasette.permissions import PermissionSQL -from datasette.utils.permissions import resolve_permissions_from_catalog -from typing import Callable, List - - -@pytest.fixture -def db(): - ds = Datasette() - import tempfile - from datasette.database import Database - - path = tempfile.mktemp(suffix="demo.db") - db = ds.add_database(Database(ds, path=path)) - return db - - -NO_RULES_SQL = ( - "SELECT NULL AS parent, NULL AS child, NULL AS allow, NULL AS reason WHERE 0" -) - - -def plugin_allow_all_for_user(user: str) -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT NULL AS parent, NULL AS child, 1 AS allow, - 'global allow for ' || :allow_all_user || ' on ' || :allow_all_action AS reason - WHERE :actor_id = :allow_all_user - """, - {"allow_all_user": user, "allow_all_action": action}, - ) - - return provider - - -def plugin_deny_specific_table( - user: str, parent: str, child: str -) -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT :deny_specific_table_parent AS parent, :deny_specific_table_child AS child, 0 AS allow, - 'deny ' || :deny_specific_table_parent || '/' || :deny_specific_table_child || ' for ' || :deny_specific_table_user || ' on ' || :deny_specific_table_action AS reason - WHERE :actor_id = :deny_specific_table_user - """, - { - "deny_specific_table_parent": parent, - "deny_specific_table_child": child, - "deny_specific_table_user": user, - "deny_specific_table_action": action, - }, - ) - - return provider - - -def plugin_org_policy_deny_parent(parent: str) -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT :org_policy_parent_deny_parent AS parent, NULL AS child, 0 AS allow, - 'org policy: parent ' || :org_policy_parent_deny_parent || ' denied on ' || :org_policy_parent_deny_action AS reason - """, - { - "org_policy_parent_deny_parent": parent, - "org_policy_parent_deny_action": action, - }, - ) - - return provider - - -def plugin_allow_parent_for_user( - user: str, parent: str -) -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT :allow_parent_parent AS parent, NULL AS child, 1 AS allow, - 'allow full parent for ' || :allow_parent_user || ' on ' || :allow_parent_action AS reason - WHERE :actor_id = :allow_parent_user - """, - { - "allow_parent_parent": parent, - "allow_parent_user": user, - "allow_parent_action": action, - }, - ) - - return provider - - -def plugin_child_allow_for_user( - user: str, parent: str, child: str -) -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT :allow_child_parent AS parent, :allow_child_child AS child, 1 AS allow, - 'allow child for ' || :allow_child_user || ' on ' || :allow_child_action AS reason - WHERE :actor_id = :allow_child_user - """, - { - "allow_child_parent": parent, - "allow_child_child": child, - "allow_child_user": user, - "allow_child_action": action, - }, - ) - - return provider - - -def plugin_root_deny_for_all() -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT NULL AS parent, NULL AS child, 0 AS allow, 'root deny for all on ' || :root_deny_action AS reason - """, - {"root_deny_action": action}, - ) - - return provider - - -def plugin_conflicting_same_child_rules( - user: str, parent: str, child: str -) -> List[Callable[[str], PermissionSQL]]: - def allow_provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT :conflict_child_allow_parent AS parent, :conflict_child_allow_child AS child, 1 AS allow, - 'team grant at child for ' || :conflict_child_allow_user || ' on ' || :conflict_child_allow_action AS reason - WHERE :actor_id = :conflict_child_allow_user - """, - { - "conflict_child_allow_parent": parent, - "conflict_child_allow_child": child, - "conflict_child_allow_user": user, - "conflict_child_allow_action": action, - }, - ) - - def deny_provider(action: str) -> PermissionSQL: - return PermissionSQL( - """ - SELECT :conflict_child_deny_parent AS parent, :conflict_child_deny_child AS child, 0 AS allow, - 'exception deny at child for ' || :conflict_child_deny_user || ' on ' || :conflict_child_deny_action AS reason - WHERE :actor_id = :conflict_child_deny_user - """, - { - "conflict_child_deny_parent": parent, - "conflict_child_deny_child": child, - "conflict_child_deny_user": user, - "conflict_child_deny_action": action, - }, - ) - - return [allow_provider, deny_provider] - - -def plugin_allow_all_for_action( - user: str, allowed_action: str -) -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - if action != allowed_action: - return PermissionSQL(NO_RULES_SQL) - # Sanitize parameter names by replacing hyphens with underscores - param_prefix = action.replace("-", "_") - return PermissionSQL( - f""" - SELECT NULL AS parent, NULL AS child, 1 AS allow, - 'global allow for ' || :{param_prefix}_user || ' on ' || :{param_prefix}_action AS reason - WHERE :actor_id = :{param_prefix}_user - """, - {f"{param_prefix}_user": user, f"{param_prefix}_action": action}, - ) - - return provider - - -VIEW_TABLE = "view-table" - - -# ---------- Catalog DDL (from your schema) ---------- -CATALOG_DDL = """ -CREATE TABLE IF NOT EXISTS catalog_databases ( - database_name TEXT PRIMARY KEY, - path TEXT, - is_memory INTEGER, - schema_version INTEGER -); -CREATE TABLE IF NOT EXISTS catalog_tables ( - database_name TEXT, - table_name TEXT, - rootpage INTEGER, - sql TEXT, - PRIMARY KEY (database_name, table_name), - FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name) -); -""" - -PARENTS = ["accounting", "hr", "analytics"] -SPECIALS = {"accounting": ["sales"], "analytics": ["secret"], "hr": []} - -TABLE_CANDIDATES_SQL = ( - "SELECT database_name AS parent, table_name AS child FROM catalog_tables" -) -PARENT_CANDIDATES_SQL = ( - "SELECT database_name AS parent, NULL AS child FROM catalog_databases" -) - - -# ---------- Helpers ---------- -async def seed_catalog(db, per_parent: int = 10) -> None: - await db.execute_write_script(CATALOG_DDL) - # databases - db_rows = [(p, f"/{p}.db", 0, 1) for p in PARENTS] - await db.execute_write_many( - "INSERT OR REPLACE INTO catalog_databases(database_name, path, is_memory, schema_version) VALUES (?,?,?,?)", - db_rows, - ) - - # tables - def tables_for(parent: str, n: int): - base = [f"table{i:02d}" for i in range(1, n + 1)] - for s in SPECIALS.get(parent, []): - if s not in base: - base[0] = s - return base - - table_rows = [] - for p in PARENTS: - for t in tables_for(p, per_parent): - table_rows.append((p, t, 0, f"CREATE TABLE {t} (id INTEGER PRIMARY KEY)")) - await db.execute_write_many( - "INSERT OR REPLACE INTO catalog_tables(database_name, table_name, rootpage, sql) VALUES (?,?,?,?)", - table_rows, - ) - - -def res_allowed(rows, parent=None): - return sorted( - r["resource"] - for r in rows - if r["allow"] == 1 and (parent is None or r["parent"] == parent) - ) - - -def res_denied(rows, parent=None): - return sorted( - r["resource"] - for r in rows - if r["allow"] == 0 and (parent is None or r["parent"] == parent) - ) - - -# ---------- Tests ---------- -@pytest.mark.asyncio -async def test_alice_global_allow_with_specific_denies_catalog(db): - await seed_catalog(db) - plugins = [ - plugin_allow_all_for_user("alice"), - plugin_deny_specific_table("alice", "accounting", "sales"), - plugin_org_policy_deny_parent("hr"), - ] - rows = await resolve_permissions_from_catalog( - db, - {"id": "alice"}, - plugins, - VIEW_TABLE, - TABLE_CANDIDATES_SQL, - implicit_deny=True, - ) - # Alice can see everything except accounting/sales and hr/* - assert "/accounting/sales" in res_denied(rows) - for r in rows: - if r["parent"] == "hr": - assert r["allow"] == 0 - elif r["resource"] == "/accounting/sales": - assert r["allow"] == 0 - else: - assert r["allow"] == 1 - - -@pytest.mark.asyncio -async def test_carol_parent_allow_but_child_conflict_deny_wins_catalog(db): - await seed_catalog(db) - plugins = [ - plugin_org_policy_deny_parent("hr"), - plugin_allow_parent_for_user("carol", "analytics"), - *plugin_conflicting_same_child_rules("carol", "analytics", "secret"), - ] - rows = await resolve_permissions_from_catalog( - db, - {"id": "carol"}, - plugins, - VIEW_TABLE, - TABLE_CANDIDATES_SQL, - implicit_deny=True, - ) - allowed_analytics = res_allowed(rows, parent="analytics") - denied_analytics = res_denied(rows, parent="analytics") - - assert "/analytics/secret" in denied_analytics - # 10 analytics children total, 1 denied - assert len(allowed_analytics) == 9 - - -@pytest.mark.asyncio -async def test_specificity_child_allow_overrides_parent_deny_catalog(db): - await seed_catalog(db) - plugins = [ - plugin_allow_all_for_user("alice"), - plugin_org_policy_deny_parent("analytics"), # parent-level deny - plugin_child_allow_for_user( - "alice", "analytics", "table02" - ), # child allow beats parent deny - ] - rows = await resolve_permissions_from_catalog( - db, - {"id": "alice"}, - plugins, - VIEW_TABLE, - TABLE_CANDIDATES_SQL, - implicit_deny=True, - ) - - # table02 allowed, other analytics tables denied - assert any(r["resource"] == "/analytics/table02" and r["allow"] == 1 for r in rows) - assert all( - (r["parent"] != "analytics" or r["child"] == "table02" or r["allow"] == 0) - for r in rows - ) - - -@pytest.mark.asyncio -async def test_root_deny_all_but_parent_allow_rescues_specific_parent_catalog(db): - await seed_catalog(db) - plugins = [ - plugin_root_deny_for_all(), # root deny - plugin_allow_parent_for_user( - "bob", "accounting" - ), # parent allow (more specific) - ] - rows = await resolve_permissions_from_catalog( - db, {"id": "bob"}, plugins, VIEW_TABLE, TABLE_CANDIDATES_SQL, implicit_deny=True - ) - for r in rows: - if r["parent"] == "accounting": - assert r["allow"] == 1 - else: - assert r["allow"] == 0 - - -@pytest.mark.asyncio -async def test_parent_scoped_candidates(db): - await seed_catalog(db) - plugins = [ - plugin_org_policy_deny_parent("hr"), - plugin_allow_parent_for_user("carol", "analytics"), - ] - rows = await resolve_permissions_from_catalog( - db, - {"id": "carol"}, - plugins, - VIEW_TABLE, - PARENT_CANDIDATES_SQL, - implicit_deny=True, - ) - d = {r["resource"]: r["allow"] for r in rows} - assert d["/analytics"] == 1 - assert d["/hr"] == 0 - - -@pytest.mark.asyncio -async def test_implicit_deny_behavior(db): - await seed_catalog(db) - plugins = [] # no rules at all - - # implicit_deny=True -> everything denied with reason 'implicit deny' - rows = await resolve_permissions_from_catalog( - db, - {"id": "erin"}, - plugins, - VIEW_TABLE, - TABLE_CANDIDATES_SQL, - implicit_deny=True, - ) - assert all(r["allow"] == 0 and r["reason"] == "implicit deny" for r in rows) - - # implicit_deny=False -> no winner => allow is None, reason is None - rows2 = await resolve_permissions_from_catalog( - db, - {"id": "erin"}, - plugins, - VIEW_TABLE, - TABLE_CANDIDATES_SQL, - implicit_deny=False, - ) - assert all(r["allow"] is None and r["reason"] is None for r in rows2) - - -@pytest.mark.asyncio -async def test_candidate_filters_via_params(db): - await seed_catalog(db) - # Add some metadata to test filtering - # Mark 'hr' as is_memory=1 and increment analytics schema_version - await db.execute_write( - "UPDATE catalog_databases SET is_memory=1 WHERE database_name='hr'" - ) - await db.execute_write( - "UPDATE catalog_databases SET schema_version=2 WHERE database_name='analytics'" - ) - - # Candidate SQL that filters by db metadata via params - candidate_sql = """ - SELECT t.database_name AS parent, t.table_name AS child - FROM catalog_tables t - JOIN catalog_databases d ON d.database_name = t.database_name - WHERE (:exclude_memory = 1 AND d.is_memory = 1) IS NOT 1 - AND (:min_schema_version IS NULL OR d.schema_version >= :min_schema_version) - """ - - plugins = [ - plugin_root_deny_for_all(), - plugin_allow_parent_for_user( - "dev", "analytics" - ), # analytics rescued if included by candidates - ] - - # Case 1: exclude memory dbs, require schema_version >= 2 -> only analytics appear, and thus are allowed - rows = await resolve_permissions_from_catalog( - db, - {"id": "dev"}, - plugins, - VIEW_TABLE, - candidate_sql, - candidate_params={"exclude_memory": 1, "min_schema_version": 2}, - implicit_deny=True, - ) - assert rows and all(r["parent"] == "analytics" for r in rows) - assert all(r["allow"] == 1 for r in rows) - - # Case 2: include memory dbs, min_schema_version = None -> accounting/hr/analytics appear, - # but root deny wins except where specifically allowed (none except analytics parent allow doesn’t apply to table depth if candidate includes children; still fine—policy is explicit). - rows2 = await resolve_permissions_from_catalog( - db, - {"id": "dev"}, - plugins, - VIEW_TABLE, - candidate_sql, - candidate_params={"exclude_memory": 0, "min_schema_version": None}, - implicit_deny=True, - ) - assert any(r["parent"] == "accounting" for r in rows2) - assert any(r["parent"] == "hr" for r in rows2) - # For table-scoped candidates, the parent-level allow does not override root deny unless you have child-level rules - assert all(r["allow"] in (0, 1) for r in rows2) - - -@pytest.mark.asyncio -async def test_action_specific_rules(db): - await seed_catalog(db) - plugins = [plugin_allow_all_for_action("dana", VIEW_TABLE)] - - view_rows = await resolve_permissions_from_catalog( - db, - {"id": "dana"}, - plugins, - VIEW_TABLE, - TABLE_CANDIDATES_SQL, - implicit_deny=True, - ) - assert view_rows and all(r["allow"] == 1 for r in view_rows) - assert all(r["action"] == VIEW_TABLE for r in view_rows) - - insert_rows = await resolve_permissions_from_catalog( - db, - {"id": "dana"}, - plugins, - "insert-row", - TABLE_CANDIDATES_SQL, - implicit_deny=True, - ) - assert insert_rows and all(r["allow"] == 0 for r in insert_rows) - assert all(r["reason"] == "implicit deny" for r in insert_rows) - assert all(r["action"] == "insert-row" for r in insert_rows) - - -@pytest.mark.asyncio -async def test_actor_actor_id_action_parameters_available(db): - """Test that :actor (JSON), :actor_id, and :action are all available in SQL""" - await seed_catalog(db) - - def plugin_using_all_parameters() -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - return PermissionSQL(""" - SELECT NULL AS parent, NULL AS child, 1 AS allow, - 'Actor ID: ' || COALESCE(:actor_id, 'null') || - ', Actor JSON: ' || COALESCE(:actor, 'null') || - ', Action: ' || :action AS reason - WHERE :actor_id = 'test_user' AND :action = 'view-table' - AND json_extract(:actor, '$.role') = 'admin' - """) - - return provider - - plugins = [plugin_using_all_parameters()] - - # Test with full actor dict - rows = await resolve_permissions_from_catalog( - db, - {"id": "test_user", "role": "admin"}, - plugins, - "view-table", - TABLE_CANDIDATES_SQL, - implicit_deny=True, - ) - - # Should have allowed rows with reason containing all the info - allowed = [r for r in rows if r["allow"] == 1] - assert len(allowed) > 0 - - # Check that the reason string contains evidence of all parameters - reason = allowed[0]["reason"] - assert "test_user" in reason - assert "view-table" in reason - # The :actor parameter should be the JSON string - assert "Actor JSON:" in reason - - -@pytest.mark.asyncio -async def test_multiple_plugins_with_own_parameters(db): - """ - Test that multiple plugins can use their own parameter names without conflict. - - This verifies that the parameter naming convention works: plugins prefix their - parameters (e.g., :plugin1_pattern, :plugin2_message) and both sets of parameters - are successfully bound in the SQL queries. - """ - await seed_catalog(db) - - def plugin_one() -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - if action != "view-table": - return PermissionSQL("plugin_one", "SELECT NULL WHERE 0", {}) - return PermissionSQL( - """ - SELECT database_name AS parent, table_name AS child, - 1 AS allow, 'Plugin one used param: ' || :plugin1_param AS reason - FROM catalog_tables - WHERE database_name = 'accounting' - """, - { - "plugin1_param": "value1", - }, - ) - - return provider - - def plugin_two() -> Callable[[str], PermissionSQL]: - def provider(action: str) -> PermissionSQL: - if action != "view-table": - return PermissionSQL("plugin_two", "SELECT NULL WHERE 0", {}) - return PermissionSQL( - """ - SELECT database_name AS parent, table_name AS child, - 1 AS allow, 'Plugin two used param: ' || :plugin2_param AS reason - FROM catalog_tables - WHERE database_name = 'hr' - """, - { - "plugin2_param": "value2", - }, - ) - - return provider - - plugins = [plugin_one(), plugin_two()] - - rows = await resolve_permissions_from_catalog( - db, - {"id": "test_user"}, - plugins, - "view-table", - TABLE_CANDIDATES_SQL, - implicit_deny=False, - ) - - # Both plugins should contribute results with their parameters successfully bound - plugin_one_rows = [ - r for r in rows if r.get("reason") and "Plugin one" in r["reason"] - ] - plugin_two_rows = [ - r for r in rows if r.get("reason") and "Plugin two" in r["reason"] - ] - - assert len(plugin_one_rows) > 0, "Plugin one should contribute rules" - assert len(plugin_two_rows) > 0, "Plugin two should contribute rules" - - # Verify each plugin's parameters were successfully bound in the SQL - assert any( - "value1" in r.get("reason", "") for r in plugin_one_rows - ), "Plugin one's :plugin1_param should be bound" - assert any( - "value2" in r.get("reason", "") for r in plugin_two_rows - ), "Plugin two's :plugin2_param should be bound" diff --git a/tests/test_utils_sql_analysis.py b/tests/test_utils_sql_analysis.py deleted file mode 100644 index 5730cd0d..00000000 --- a/tests/test_utils_sql_analysis.py +++ /dev/null @@ -1,188 +0,0 @@ -import pytest - -from datasette.utils.sqlite import sqlite3 -from datasette.utils.sql_analysis import analyze_sql_tables - - -@pytest.fixture -def conn(): - conn = sqlite3.connect(":memory:") - conn.executescript(""" - create table dogs (id integer primary key, name text, age integer); - create table cats (id integer primary key, name text); - create table log (message text); - create view dog_names as select id, name from dogs; - create trigger dogs_after_insert after insert on dogs begin - update cats set name = new.name where id = new.id; - insert into log (message) values (new.name); - end; - create trigger dog_names_instead_of_update instead of update on dog_names begin - update dogs set name = new.name where id = old.id; - end; - """) - try: - yield conn - finally: - conn.close() - - -def as_tuples(analysis): - return [ - ( - access.operation, - access.database, - access.sqlite_schema, - access.table, - access.columns, - access.source, - ) - for access in analysis.table_accesses - ] - - -def test_analyze_select_tables(conn): - analysis = analyze_sql_tables( - conn, - "select dogs.name, cats.name from dogs join cats on dogs.id = cats.id where dogs.age > ?", - (2,), - database_name="data", - ) - - assert set(as_tuples(analysis)) == { - ("read", "data", "main", "cats", ("id", "name"), None), - ("read", "data", "main", "dogs", ("age", "id", "name"), None), - } - - -def test_analyze_uses_sqlite_schema_as_default_database(conn): - analysis = analyze_sql_tables(conn, "select name from dogs") - - assert set(as_tuples(analysis)) == { - ("read", "main", "main", "dogs", ("name",), None), - } - - -def test_analyze_insert_tables(conn): - analysis = analyze_sql_tables( - conn, - "insert into dogs (name, age) values (:name, :age)", - {"name": "Cleo", "age": 4}, - database_name="data", - ) - - assert set(as_tuples(analysis)) == { - ("insert", "data", "main", "dogs", (), None), - ("read", "data", "main", "dogs", ("id", "name"), "dogs_after_insert"), - ("update", "data", "main", "cats", ("name",), "dogs_after_insert"), - ("read", "data", "main", "cats", ("id",), "dogs_after_insert"), - ("insert", "data", "main", "log", (), "dogs_after_insert"), - } - - -def test_analyze_update_tables(conn): - analysis = analyze_sql_tables( - conn, - "update dogs set age = age + 1 where name = ?", - ("Cleo",), - database_name="data", - ) - - assert set(as_tuples(analysis)) == { - ("update", "data", "main", "dogs", ("age",), None), - ("read", "data", "main", "dogs", ("age", "name"), None), - } - - -def test_analyze_delete_tables(conn): - analysis = analyze_sql_tables( - conn, - "delete from dogs where name = ?", - ("Cleo",), - database_name="data", - ) - - assert set(as_tuples(analysis)) == { - ("delete", "data", "main", "dogs", (), None), - ("read", "data", "main", "dogs", ("name",), None), - } - - -def test_analyze_insert_select_with_cte(conn): - analysis = analyze_sql_tables( - conn, - """ - with old_dogs as ( - select name from dogs where age > :age - ) - insert into cats (name) - select name from old_dogs - """, - {"age": 10}, - database_name="data", - ) - - assert set(as_tuples(analysis)) == { - ("insert", "data", "main", "cats", (), None), - ("read", "data", "main", "dogs", ("age", "name"), "old_dogs"), - } - - -def test_analyze_view_with_instead_of_trigger(conn): - analysis = analyze_sql_tables( - conn, - "update dog_names set name = :name where id = :id", - {"name": "Zelda", "id": 1}, - database_name="data", - ) - - assert set(as_tuples(analysis)) == { - ("update", "data", "main", "dog_names", ("name",), None), - ("read", "data", "main", "dogs", ("id", "name"), "dog_names"), - ("read", "data", "main", "dog_names", ("id", "name"), "dog_names"), - ( - "read", - "data", - "main", - "dog_names", - ("id", "name"), - "dog_names_instead_of_update", - ), - ("update", "data", "main", "dogs", ("name",), "dog_names_instead_of_update"), - ("read", "data", "main", "dogs", ("id",), "dog_names_instead_of_update"), - } - - -def test_analyze_attached_database_tables(conn): - conn.execute("attach database ':memory:' as extra") - conn.execute("create table extra.people (id integer primary key, name text)") - - analysis = analyze_sql_tables( - conn, - "insert into extra.people (name) select name from dogs", - database_name="data", - schema_to_database={"extra": "extra_db"}, - ) - - assert set(as_tuples(analysis)) == { - ("insert", "extra_db", "extra", "people", (), None), - ("read", "data", "main", "dogs", ("name",), None), - } - - -def test_analyze_clears_authorizer_on_error(): - class FakeConnection: - def __init__(self): - self.authorizers = [] - - def set_authorizer(self, authorizer): - self.authorizers.append(authorizer) - - def execute(self, sql, params): - raise sqlite3.OperationalError("bad SQL") - - conn = FakeConnection() - - with pytest.raises(sqlite3.OperationalError): - analyze_sql_tables(conn, "bad SQL") - - assert conn.authorizers[-1] is None diff --git a/tests/test_write_wrapper.py b/tests/test_write_wrapper.py deleted file mode 100644 index 88ce5520..00000000 --- a/tests/test_write_wrapper.py +++ /dev/null @@ -1,768 +0,0 @@ -""" -Tests for the write_wrapper plugin hook. -""" - -import asyncio -from dataclasses import dataclass -from datasette.app import Datasette -from datasette.events import Event -from datasette.hookspecs import hookimpl -from datasette.plugins import pm -import pytest -import sqlite3 -import time - - -@dataclass -class DummyEvent(Event): - name = "dummy" - message: str - - -@pytest.fixture -def datasette(tmp_path): - db_path = str(tmp_path / "test.db") - ds = Datasette([db_path]) - return ds - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "use_execute_write", - (False, True), - ids=["execute_write_fn", "execute_write"], -) -async def test_write_wrapper_before_and_after(datasette, use_execute_write): - """Test that code before and after yield both execute.""" - log = [] - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn): - log.append("before") - yield - log.append("after") - - return wrapper - - pm.register(Plugin(), name="test_before_after") - try: - db = datasette.get_database("test") - if use_execute_write: - await db.execute_write( - "create table if not exists t (id integer primary key)" - ) - else: - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists t (id integer primary key)" - ) - ) - assert log == ["before", "after"] - finally: - pm.unregister(name="test_before_after") - - -@pytest.mark.asyncio -async def test_write_wrapper_receives_result_via_yield(datasette): - """Test that the result of fn(conn) is sent back through yield.""" - captured = {} - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn): - result = yield - captured["result"] = result - - return wrapper - - pm.register(Plugin(), name="test_result") - try: - db = datasette.get_database("test") - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists t2 (id integer primary key)" - ) - ) - assert "result" in captured - # Should be a sqlite3 Cursor - assert captured["result"] is not None - finally: - pm.unregister(name="test_result") - - -@pytest.mark.asyncio -async def test_write_wrapper_exception_thrown_into_generator(datasette): - """Test that exceptions from fn(conn) are thrown into the generator.""" - caught = {} - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn): - try: - yield - except Exception as e: - caught["error"] = e - - return wrapper - - pm.register(Plugin(), name="test_exception") - try: - db = datasette.get_database("test") - with pytest.raises(Exception, match="deliberate"): - await db.execute_write_fn( - lambda conn: (_ for _ in ()).throw(Exception("deliberate")) - ) - assert "error" in caught - assert str(caught["error"]) == "deliberate" - finally: - pm.unregister(name="test_exception") - - -@pytest.mark.asyncio -async def test_write_wrapper_conn_is_usable(datasette): - """Test that the conn passed to the wrapper can execute SQL.""" - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn): - conn.execute("create table if not exists hook_log (msg text)") - conn.execute("insert into hook_log values ('before')") - yield - conn.execute("insert into hook_log values ('after')") - - return wrapper - - pm.register(Plugin(), name="test_conn") - try: - db = datasette.get_database("test") - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists t3 (id integer primary key)" - ) - ) - result = await db.execute("select msg from hook_log order by rowid") - messages = [row[0] for row in result.rows] - assert messages == ["before", "after"] - finally: - pm.unregister(name="test_conn") - - -@pytest.mark.asyncio -async def test_write_wrapper_multiple_plugins_nest(datasette): - """Test that multiple write_wrapper plugins nest correctly.""" - log = [] - - class PluginA: - __name__ = "PluginA" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn): - log.append("A-before") - yield - log.append("A-after") - - return wrapper - - class PluginB: - __name__ = "PluginB" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn): - log.append("B-before") - yield - log.append("B-after") - - return wrapper - - pm.register(PluginA(), name="PluginA") - pm.register(PluginB(), name="PluginB") - try: - db = datasette.get_database("test") - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists t4 (id integer primary key)" - ) - ) - assert set(log) == {"A-before", "A-after", "B-before", "B-after"} - # Verify proper nesting: each plugin's before/after should be - # symmetric around the write - a_before = log.index("A-before") - a_after = log.index("A-after") - b_before = log.index("B-before") - b_after = log.index("B-after") - if a_before < b_before: - assert a_after > b_after, "A is outer so A-after should come after B-after" - else: - assert b_after > a_after, "B is outer so B-after should come after A-after" - finally: - pm.unregister(name="PluginA") - pm.unregister(name="PluginB") - - -@pytest.mark.asyncio -async def test_write_wrapper_return_none_skips(datasette): - """Test that returning None from write_wrapper means no wrapping.""" - log = [] - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - log.append("hook-called") - return None - - pm.register(Plugin(), name="test_skip") - try: - db = datasette.get_database("test") - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists t5 (id integer primary key)" - ) - ) - assert log == ["hook-called"] - finally: - pm.unregister(name="test_skip") - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "request_value,transaction_value,expected_request,expected_transaction", - ( - ("fake-request", True, "fake-request", True), - (None, True, None, True), - (None, False, None, False), - ), - ids=["with-request", "request-none-by-default", "transaction-false"], -) -async def test_write_wrapper_hook_parameters( - datasette, - request_value, - transaction_value, - expected_request, - expected_transaction, -): - """Test that request and transaction parameters are passed through.""" - captured = {} - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - captured["request"] = request - captured["database"] = database - captured["transaction"] = transaction - - pm.register(Plugin(), name="test_params") - try: - db = datasette.get_database("test") - kwargs = {"transaction": transaction_value} - if request_value is not None: - kwargs["request"] = request_value - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists t6 (id integer primary key)" - ), - **kwargs, - ) - assert captured["request"] == expected_request - assert captured["database"] == "test" - assert captured["transaction"] == expected_transaction - finally: - pm.unregister(name="test_params") - - -@pytest.mark.asyncio -async def test_write_wrapper_via_api(tmp_path): - """Test that write_wrapper fires for API write operations.""" - log = [] - - db_path = str(tmp_path / "test.db") - ds = Datasette([db_path], pdb=False) - ds.root_enabled = True - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - if database != "test": - return None - - def wrapper(conn): - log.append("before") - yield - log.append("after") - - return wrapper - - pm.register(Plugin(), name="test_api") - try: - db = ds.get_database("test") - await db.execute_write( - "create table if not exists api_test (id integer primary key, name text)" - ) - log.clear() - - token = "dstok_{}".format( - ds.sign( - {"a": "root", "token": "dstok", "t": int(time.time())}, - namespace="token", - ) - ) - response = await ds.client.post( - "/test/api_test/-/insert", - json={"row": {"name": "test"}, "return": True}, - headers={ - "Authorization": "Bearer {}".format(token), - "Content-Type": "application/json", - }, - ) - assert response.status_code == 201, response.json() - assert log == ["before", "after"] - finally: - pm.unregister(name="test_api") - - -@pytest.mark.asyncio -async def test_write_wrapper_change_group_pattern(datasette): - """Test the motivating use case: activating a change group around a write.""" - db = datasette.get_database("test") - - await db.execute_write( - "create table if not exists groups (id integer primary key, current integer)" - ) - await db.execute_write( - "create table if not exists data (id integer primary key, value text)" - ) - await db.execute_write("insert into groups (id, current) values (1, null)") - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - if request and getattr(request, "group_id", None): - group_id = request.group_id - - def wrapper(conn): - conn.execute( - "update groups set current = 1 where id = ?", [group_id] - ) - yield - conn.execute("update groups set current = null where current = 1") - - return wrapper - - pm.register(Plugin(), name="test_change_group") - try: - - class FakeRequest: - group_id = 1 - - await db.execute_write_fn( - lambda conn: conn.execute("insert into data (value) values ('test')"), - request=FakeRequest(), - ) - - result = await db.execute("select current from groups where id = 1") - assert result.rows[0][0] is None - finally: - pm.unregister(name="test_change_group") - - -WRITE_ACTIONS = ( - sqlite3.SQLITE_INSERT, - sqlite3.SQLITE_UPDATE, - sqlite3.SQLITE_DELETE, -) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "actor,table,should_deny", - ( - (None, "protected_table", True), - ({"id": "regular"}, "protected_table", True), - ({"id": "admin"}, "protected_table", False), - (None, "other_table", False), - ({"id": "regular"}, "other_table", False), - ), - ids=[ - "no-actor-protected", - "regular-user-protected", - "admin-protected", - "no-actor-other", - "regular-user-other", - ], -) -async def test_write_wrapper_set_authorizer(datasette, actor, table, should_deny): - """Test the docs example that uses set_authorizer to block writes to a protected table.""" - db = datasette.get_database("test") - await db.execute_write( - "create table if not exists protected_table (id integer primary key, value text)" - ) - await db.execute_write( - "create table if not exists other_table (id integer primary key, value text)" - ) - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - actor = None - if request: - actor = request.actor - if actor and actor.get("id") == "admin": - return None - - def wrapper(conn): - def authorizer(action, arg1, arg2, db_name, trigger): - if action in WRITE_ACTIONS and arg1 == "protected_table": - return sqlite3.SQLITE_DENY - return sqlite3.SQLITE_OK - - conn.set_authorizer(authorizer) - try: - yield - finally: - conn.set_authorizer(lambda *args: sqlite3.SQLITE_OK) - - return wrapper - - class FakeRequest: - def __init__(self, actor): - self.actor = actor - - pm.register(Plugin(), name="test_set_authorizer") - try: - request = FakeRequest(actor) - if should_deny: - with pytest.raises(Exception): - await db.execute_write_fn( - lambda conn: conn.execute( - f"insert into {table} (value) values ('test')" - ), - request=request, - ) - else: - await db.execute_write_fn( - lambda conn: conn.execute( - f"insert into {table} (value) values ('test')" - ), - request=request, - ) - result = await db.execute( - f"select value from {table} order by rowid desc limit 1" - ) - assert result.rows[0][0] == "test" - finally: - pm.unregister(name="test_set_authorizer") - - -# --- Tests for track_event callback --- - - -@pytest.fixture -def ds_with_event_tracking(tmp_path): - """Datasette instance that records tracked events and registers DummyEvent.""" - db_path = str(tmp_path / "test.db") - ds = Datasette([db_path]) - ds._tracked_events = [] - # Set event_classes directly to avoid needing invoke_startup - ds.event_classes = (DummyEvent,) - - async def recording_track_event(event): - ds._tracked_events.append(event) - - ds.track_event = recording_track_event - - yield ds - ds.close() - - -@pytest.mark.asyncio -async def test_track_event_in_write_fn(ds_with_event_tracking): - """fn(conn, track_event) can queue events that are dispatched after commit.""" - ds = ds_with_event_tracking - db = ds.get_database("test") - - def my_write(conn, track_event): - conn.execute("create table if not exists te1 (id integer primary key)") - track_event(DummyEvent(actor=None, message="hello")) - - await db.execute_write_fn(my_write) - assert len(ds._tracked_events) == 1 - assert ds._tracked_events[0].message == "hello" - - -@pytest.mark.asyncio -async def test_track_event_discarded_on_exception(ds_with_event_tracking): - """Events are discarded if the write fn raises an exception.""" - ds = ds_with_event_tracking - db = ds.get_database("test") - - def my_write(conn, track_event): - track_event(DummyEvent(actor=None, message="should not fire")) - raise ValueError("deliberate error") - - with pytest.raises(ValueError, match="deliberate"): - await db.execute_write_fn(my_write) - assert len(ds._tracked_events) == 0 - - -@pytest.mark.asyncio -async def test_track_event_existing_fn_signature_still_works(ds_with_event_tracking): - """Existing fn(conn) signatures continue to work without track_event.""" - ds = ds_with_event_tracking - db = ds.get_database("test") - - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists te2 (id integer primary key)" - ) - ) - # No events, no errors - assert len(ds._tracked_events) == 0 - - -@pytest.mark.asyncio -async def test_track_event_in_write_wrapper(ds_with_event_tracking): - """write_wrapper generator with (conn, track_event) can queue events.""" - ds = ds_with_event_tracking - db = ds.get_database("test") - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn, track_event): - track_event(DummyEvent(actor=None, message="from wrapper before")) - yield - track_event(DummyEvent(actor=None, message="from wrapper after")) - - return wrapper - - pm.register(Plugin(), name="test_track_wrapper") - try: - await db.execute_write_fn( - lambda conn: conn.execute( - "create table if not exists te3 (id integer primary key)" - ) - ) - assert len(ds._tracked_events) == 2 - assert ds._tracked_events[0].message == "from wrapper before" - assert ds._tracked_events[1].message == "from wrapper after" - finally: - pm.unregister(name="test_track_wrapper") - - -@pytest.mark.asyncio -async def test_track_event_shared_between_fn_and_wrapper(ds_with_event_tracking): - """Both fn and wrapper can queue events, all dispatched in order.""" - ds = ds_with_event_tracking - db = ds.get_database("test") - - class Plugin: - __name__ = "Plugin" - - @staticmethod - @hookimpl - def write_wrapper(datasette, database, request, transaction): - def wrapper(conn, track_event): - track_event(DummyEvent(actor=None, message="wrapper-before")) - yield - track_event(DummyEvent(actor=None, message="wrapper-after")) - - return wrapper - - pm.register(Plugin(), name="test_track_shared") - try: - - def my_write(conn, track_event): - conn.execute("create table if not exists te4 (id integer primary key)") - track_event(DummyEvent(actor=None, message="from-fn")) - - await db.execute_write_fn(my_write) - messages = [e.message for e in ds._tracked_events] - assert messages == ["wrapper-before", "from-fn", "wrapper-after"] - finally: - pm.unregister(name="test_track_shared") - - -@pytest.mark.asyncio -async def test_track_event_with_block_false(ds_with_event_tracking): - """Events are dispatched even when block=False (non-blocking writes).""" - ds = ds_with_event_tracking - db = ds.get_database("test") - - def my_write(conn, track_event): - conn.execute("create table if not exists te5 (id integer primary key)") - track_event(DummyEvent(actor=None, message="non-blocking")) - - task_id = await db.execute_write_fn(my_write, block=False) - assert task_id is not None - - # Give the background task time to complete - for _ in range(50): - if ds._tracked_events: - break - await asyncio.sleep(0.01) - - assert len(ds._tracked_events) == 1 - assert ds._tracked_events[0].message == "non-blocking" - - -@pytest.mark.asyncio -async def test_track_event_with_block_false_discarded_on_exception( - ds_with_event_tracking, -): - """Events queued by a non-blocking write are discarded if the write fails.""" - ds = ds_with_event_tracking - db = ds.get_database("test") - - def my_write(conn, track_event): - track_event(DummyEvent(actor=None, message="should not fire")) - raise ValueError("deliberate error") - - task_id = await db.execute_write_fn(my_write, block=False) - assert task_id is not None - - # A following blocking write proves the failed non-blocking task has - # completed; one more loop turn lets its event-dispatch task observe the - # exception and exit. - await db.execute_write_fn(lambda conn: conn.execute("select 1")) - await asyncio.sleep(0) - - assert ds._tracked_events == [] - - -# --- Tests for RenameTableEvent detection --- - - -@pytest.fixture -def ds_for_rename(tmp_path): - """Datasette instance that records tracked events for rename detection tests.""" - from datasette.events import RenameTableEvent - - db_path = str(tmp_path / "test.db") - ds = Datasette([db_path]) - ds._tracked_events = [] - ds.event_classes = (RenameTableEvent,) - - async def recording_track_event(event): - ds._tracked_events.append(event) - - ds.track_event = recording_track_event - return ds - - -@pytest.mark.asyncio -async def test_rename_table_fires_event(ds_for_rename): - """Renaming a table via ALTER TABLE fires a RenameTableEvent.""" - from datasette.events import RenameTableEvent - - ds = ds_for_rename - db = ds.get_database("test") - - await db.execute_write("create table old_name (id integer primary key)") - - def rename(conn): - conn.execute("alter table old_name rename to new_name") - - await db.execute_write_fn(rename) - - rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] - assert len(rename_events) == 1 - assert rename_events[0].old_table == "old_name" - assert rename_events[0].new_table == "new_name" - assert rename_events[0].database == "test" - - -@pytest.mark.asyncio -async def test_no_rename_event_for_regular_writes(ds_for_rename): - """Regular writes (CREATE, INSERT) do not fire RenameTableEvent.""" - from datasette.events import RenameTableEvent - - ds = ds_for_rename - db = ds.get_database("test") - - await db.execute_write("create table t (id integer primary key)") - await db.execute_write_fn(lambda conn: conn.execute("insert into t values (1)")) - - rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] - assert len(rename_events) == 0 - - -@pytest.mark.asyncio -async def test_no_rename_event_on_rollback(ds_for_rename): - """RenameTableEvent is not fired if the write raises an exception.""" - from datasette.events import RenameTableEvent - - ds = ds_for_rename - db = ds.get_database("test") - - await db.execute_write("create table rollback_test (id integer primary key)") - - def rename_then_fail(conn): - conn.execute("alter table rollback_test rename to renamed") - raise ValueError("deliberate error") - - with pytest.raises(ValueError, match="deliberate"): - await db.execute_write_fn(rename_then_fail) - - rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] - assert len(rename_events) == 0 - - -@pytest.mark.asyncio -async def test_multiple_renames_in_one_write(ds_for_rename): - """Multiple renames in a single write fire multiple RenameTableEvents.""" - from datasette.events import RenameTableEvent - - ds = ds_for_rename - db = ds.get_database("test") - - await db.execute_write("create table alpha (id integer primary key)") - await db.execute_write("create table beta (id integer primary key)") - - def rename_both(conn): - conn.execute("alter table alpha rename to alpha2") - conn.execute("alter table beta rename to beta2") - - await db.execute_write_fn(rename_both) - - rename_events = [e for e in ds._tracked_events if isinstance(e, RenameTableEvent)] - assert len(rename_events) == 2 - names = {(e.old_table, e.new_table) for e in rename_events} - assert names == {("alpha", "alpha2"), ("beta", "beta2")} diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 808feea7..00000000 --- a/tests/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from datasette.utils.sqlite import sqlite3 - - -def last_event(datasette): - events = getattr(datasette, "_tracked_events", []) - return events[-1] if events else None - - -def assert_footer_links(soup): - footer_links = soup.find("footer").find_all("a") - assert 4 == len(footer_links) - datasette_link, license_link, source_link, about_link = footer_links - assert "Datasette" == datasette_link.text.strip() - assert "tests/fixtures.py" == source_link.text.strip() - assert "Apache License 2.0" == license_link.text.strip() - assert "About Datasette" == about_link.text.strip() - assert "https://datasette.io/" == datasette_link["href"] - assert ( - "https://github.com/simonw/datasette/blob/main/tests/fixtures.py" - == source_link["href"] - ) - assert ( - "https://github.com/simonw/datasette/blob/main/LICENSE" == license_link["href"] - ) - assert "https://github.com/simonw/datasette" == about_link["href"] - - -def inner_html(soup): - html = str(soup) - # This includes the parent tag - so remove that - inner_html = html.split(">", 1)[1].rsplit("<", 1)[0] - return inner_html.strip() - - -def has_load_extension(): - conn = sqlite3.connect(":memory:") - result = hasattr(conn, "enable_load_extension") - conn.close() - return result - - -def cookie_was_deleted(response, cookie): - return any( - h - for h in response.headers.get_list("set-cookie") - if h.startswith(f'{cookie}="";') - ) diff --git a/versioneer.py b/versioneer.py new file mode 100644 index 00000000..64fea1c8 --- /dev/null +++ b/versioneer.py @@ -0,0 +1,1822 @@ + +# Version: 0.18 + +"""The Versioneer - like a rocketeer, but for versions. + +The Versioneer +============== + +* like a rocketeer, but for versions! +* https://github.com/warner/python-versioneer +* Brian Warner +* License: Public Domain +* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy +* [![Latest Version] +(https://pypip.in/version/versioneer/badge.svg?style=flat) +](https://pypi.python.org/pypi/versioneer/) +* [![Build Status] +(https://travis-ci.org/warner/python-versioneer.png?branch=master) +](https://travis-ci.org/warner/python-versioneer) + +This is a tool for managing a recorded version number in distutils-based +python projects. The goal is to remove the tedious and error-prone "update +the embedded version string" step from your release process. Making a new +release should be as easy as recording a new tag in your version-control +system, and maybe making new tarballs. + + +## Quick Install + +* `pip install versioneer` to somewhere to your $PATH +* add a `[versioneer]` section to your setup.cfg (see below) +* run `versioneer install` in your source tree, commit the results + +## Version Identifiers + +Source trees come from a variety of places: + +* a version-control system checkout (mostly used by developers) +* a nightly tarball, produced by build automation +* a snapshot tarball, produced by a web-based VCS browser, like github's + "tarball from tag" feature +* a release tarball, produced by "setup.py sdist", distributed through PyPI + +Within each source tree, the version identifier (either a string or a number, +this tool is format-agnostic) can come from a variety of places: + +* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows + about recent "tags" and an absolute revision-id +* the name of the directory into which the tarball was unpacked +* an expanded VCS keyword ($Id$, etc) +* a `_version.py` created by some earlier build step + +For released software, the version identifier is closely related to a VCS +tag. Some projects use tag names that include more than just the version +string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool +needs to strip the tag prefix to extract the version identifier. For +unreleased software (between tags), the version identifier should provide +enough information to help developers recreate the same tree, while also +giving them an idea of roughly how old the tree is (after version 1.2, before +version 1.3). Many VCS systems can report a description that captures this, +for example `git describe --tags --dirty --always` reports things like +"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the +0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has +uncommitted changes. + +The version identifier is used for multiple purposes: + +* to allow the module to self-identify its version: `myproject.__version__` +* to choose a name and prefix for a 'setup.py sdist' tarball + +## Theory of Operation + +Versioneer works by adding a special `_version.py` file into your source +tree, where your `__init__.py` can import it. This `_version.py` knows how to +dynamically ask the VCS tool for version information at import time. + +`_version.py` also contains `$Revision$` markers, and the installation +process marks `_version.py` to have this marker rewritten with a tag name +during the `git archive` command. As a result, generated tarballs will +contain enough information to get the proper version. + +To allow `setup.py` to compute a version too, a `versioneer.py` is added to +the top level of your source tree, next to `setup.py` and the `setup.cfg` +that configures it. This overrides several distutils/setuptools commands to +compute the version when invoked, and changes `setup.py build` and `setup.py +sdist` to replace `_version.py` with a small static file that contains just +the generated version data. + +## Installation + +See [INSTALL.md](./INSTALL.md) for detailed installation instructions. + +## Version-String Flavors + +Code which uses Versioneer can learn about its version string at runtime by +importing `_version` from your main `__init__.py` file and running the +`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can +import the top-level `versioneer.py` and run `get_versions()`. + +Both functions return a dictionary with different flavors of version +information: + +* `['version']`: A condensed version string, rendered using the selected + style. This is the most commonly used value for the project's version + string. The default "pep440" style yields strings like `0.11`, + `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section + below for alternative styles. + +* `['full-revisionid']`: detailed revision identifier. For Git, this is the + full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". + +* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the + commit date in ISO 8601 format. This will be None if the date is not + available. + +* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that + this is only accurate if run in a VCS checkout, otherwise it is likely to + be False or None + +* `['error']`: if the version string could not be computed, this will be set + to a string describing the problem, otherwise it will be None. It may be + useful to throw an exception in setup.py if this is set, to avoid e.g. + creating tarballs with a version string of "unknown". + +Some variants are more useful than others. Including `full-revisionid` in a +bug report should allow developers to reconstruct the exact code being tested +(or indicate the presence of local changes that should be shared with the +developers). `version` is suitable for display in an "about" box or a CLI +`--version` output: it can be easily compared against release notes and lists +of bugs fixed in various releases. + +The installer adds the following text to your `__init__.py` to place a basic +version in `YOURPROJECT.__version__`: + + from ._version import get_versions + __version__ = get_versions()['version'] + del get_versions + +## Styles + +The setup.cfg `style=` configuration controls how the VCS information is +rendered into a version string. + +The default style, "pep440", produces a PEP440-compliant string, equal to the +un-prefixed tag name for actual releases, and containing an additional "local +version" section with more detail for in-between builds. For Git, this is +TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags +--dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the +tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and +that this commit is two revisions ("+2") beyond the "0.11" tag. For released +software (exactly equal to a known tag), the identifier will only contain the +stripped tag, e.g. "0.11". + +Other styles are available. See [details.md](details.md) in the Versioneer +source tree for descriptions. + +## Debugging + +Versioneer tries to avoid fatal errors: if something goes wrong, it will tend +to return a version of "0+unknown". To investigate the problem, run `setup.py +version`, which will run the version-lookup code in a verbose mode, and will +display the full contents of `get_versions()` (including the `error` string, +which may help identify what went wrong). + +## Known Limitations + +Some situations are known to cause problems for Versioneer. This details the +most significant ones. More can be found on Github +[issues page](https://github.com/warner/python-versioneer/issues). + +### Subprojects + +Versioneer has limited support for source trees in which `setup.py` is not in +the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are +two common reasons why `setup.py` might not be in the root: + +* Source trees which contain multiple subprojects, such as + [Buildbot](https://github.com/buildbot/buildbot), which contains both + "master" and "slave" subprojects, each with their own `setup.py`, + `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI + distributions (and upload multiple independently-installable tarballs). +* Source trees whose main purpose is to contain a C library, but which also + provide bindings to Python (and perhaps other langauges) in subdirectories. + +Versioneer will look for `.git` in parent directories, and most operations +should get the right version string. However `pip` and `setuptools` have bugs +and implementation details which frequently cause `pip install .` from a +subproject directory to fail to find a correct version string (so it usually +defaults to `0+unknown`). + +`pip install --editable .` should work correctly. `setup.py install` might +work too. + +Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in +some later version. + +[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +this issue. The discussion in +[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +issue from the Versioneer side in more detail. +[pip PR#3176](https://github.com/pypa/pip/pull/3176) and +[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve +pip to let Versioneer work correctly. + +Versioneer-0.16 and earlier only looked for a `.git` directory next to the +`setup.cfg`, so subprojects were completely unsupported with those releases. + +### Editable installs with setuptools <= 18.5 + +`setup.py develop` and `pip install --editable .` allow you to install a +project into a virtualenv once, then continue editing the source code (and +test) without re-installing after every change. + +"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a +convenient way to specify executable scripts that should be installed along +with the python package. + +These both work as expected when using modern setuptools. When using +setuptools-18.5 or earlier, however, certain operations will cause +`pkg_resources.DistributionNotFound` errors when running the entrypoint +script, which must be resolved by re-installing the package. This happens +when the install happens with one version, then the egg_info data is +regenerated while a different version is checked out. Many setup.py commands +cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into +a different virtualenv), so this can be surprising. + +[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +this one, but upgrading to a newer version of setuptools should probably +resolve it. + +### Unicode version strings + +While Versioneer works (and is continually tested) with both Python 2 and +Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. +Newer releases probably generate unicode version strings on py2. It's not +clear that this is wrong, but it may be surprising for applications when then +write these strings to a network connection or include them in bytes-oriented +APIs like cryptographic checksums. + +[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates +this question. + + +## Updating Versioneer + +To upgrade your project to a new release of Versioneer, do the following: + +* install the new Versioneer (`pip install -U versioneer` or equivalent) +* edit `setup.cfg`, if necessary, to include any new configuration settings + indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install` in your source tree, to replace + `SRC/_version.py` +* commit any changed files + +## Future Directions + +This tool is designed to make it easily extended to other version-control +systems: all VCS-specific components are in separate directories like +src/git/ . The top-level `versioneer.py` script is assembled from these +components by running make-versioneer.py . In the future, make-versioneer.py +will take a VCS name as an argument, and will construct a version of +`versioneer.py` that is specific to the given VCS. It might also take the +configuration arguments that are currently provided manually during +installation by editing setup.py . Alternatively, it might go the other +direction and include code from all supported VCS systems, reducing the +number of intermediate scripts. + + +## License + +To make Versioneer easier to embed, all its code is dedicated to the public +domain. The `_version.py` that it creates is also in the public domain. +Specifically, both are released under the Creative Commons "Public Domain +Dedication" license (CC0-1.0), as described in +https://creativecommons.org/publicdomain/zero/1.0/ . + +""" + +from __future__ import print_function +try: + import configparser +except ImportError: + import ConfigParser as configparser +import errno +import json +import os +import re +import subprocess +import sys + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_root(): + """Get the project root directory. + + We require that all commands are run from the project root, i.e. the + directory that contains setup.py, setup.cfg, and versioneer.py . + """ + root = os.path.realpath(os.path.abspath(os.getcwd())) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + # allow 'python path/to/setup.py COMMAND' + root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") + raise VersioneerBadRootError(err) + try: + # Certain runtime workflows (setup.py install/develop in a setuptools + # tree) execute all dependencies in a single python process, so + # "versioneer" may be imported multiple times, and python's shared + # module-import table will cache the first one. So we can't use + # os.path.dirname(__file__), as that will find whichever + # versioneer.py was first imported, even in later projects. + me = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(me)[0]) + vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) + if me_dir != vsr_dir: + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py)) + except NameError: + pass + return root + + +def get_config_from_root(root): + """Read the project setup.cfg file to determine Versioneer config.""" + # This might raise EnvironmentError (if setup.cfg is missing), or + # configparser.NoSectionError (if it lacks a [versioneer] section), or + # configparser.NoOptionError (if it lacks "VCS="). See the docstring at + # the top of versioneer.py for instructions on writing your setup.cfg . + setup_cfg = os.path.join(root, "setup.cfg") + parser = configparser.SafeConfigParser() + with open(setup_cfg, "r") as f: + parser.readfp(f) + VCS = parser.get("versioneer", "VCS") # mandatory + + def get(parser, name): + if parser.has_option("versioneer", name): + return parser.get("versioneer", name) + return None + cfg = VersioneerConfig() + cfg.VCS = VCS + cfg.style = get(parser, "style") or "" + cfg.versionfile_source = get(parser, "versionfile_source") + cfg.versionfile_build = get(parser, "versionfile_build") + cfg.tag_prefix = get(parser, "tag_prefix") + if cfg.tag_prefix in ("''", '""'): + cfg.tag_prefix = "" + cfg.parentdir_prefix = get(parser, "parentdir_prefix") + cfg.verbose = get(parser, "verbose") + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +# these dictionaries contain VCS-specific tools +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, p.returncode + return stdout, p.returncode + + +LONG_VERSION_PY['git'] = ''' +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.18 (https://github.com/warner/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" + git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" + git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "%(STYLE)s" + cfg.tag_prefix = "%(TAG_PREFIX)s" + cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" + cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %%s" %% dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %%s" %% (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %%s (error)" %% dispcmd) + print("stdout was %%s" %% stdout) + return None, p.returncode + return stdout, p.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %%s but none started with prefix %%s" %% + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %%d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%%s', no digits" %% ",".join(refs - tags)) + if verbose: + print("likely tags: %%s" %% ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %%s" %% r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %%s not under git control" %% root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%%s*" %% tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%%s'" + %% describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%%s' doesn't start with prefix '%%s'" + print(fmt %% (full_tag, tag_prefix)) + pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" + %% (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%%d" %% pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%%d" %% pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%%s'" %% style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for i in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} +''' + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def do_vcs_install(manifest_in, versionfile_source, ipy): + """Git-specific installation logic for Versioneer. + + For Git, this means creating/changing .gitattributes to mark _version.py + for export-subst keyword substitution. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + files = [manifest_in, versionfile_source] + if ipy: + files.append(ipy) + try: + me = __file__ + if me.endswith(".pyc") or me.endswith(".pyo"): + me = os.path.splitext(me)[0] + ".py" + versioneer_file = os.path.relpath(me) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) + present = False + try: + f = open(".gitattributes", "r") + for line in f.readlines(): + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + f.close() + except EnvironmentError: + pass + if not present: + f = open(".gitattributes", "a+") + f.write("%s export-subst\n" % versionfile_source) + f.close() + files.append(".gitattributes") + run_command(GITS, ["add", "--"] + files) + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +SHORT_VERSION_PY = """ +# This file was generated by 'versioneer.py' (0.18) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. + +import json + +version_json = ''' +%s +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) +""" + + +def versions_from_file(filename): + """Try to determine the version from _version.py if present.""" + try: + with open(filename) as f: + contents = f.read() + except EnvironmentError: + raise NotThisMethod("unable to read _version.py") + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + raise NotThisMethod("no version_json in _version.py") + return json.loads(mo.group(1)) + + +def write_to_version_file(filename, versions): + """Write the given version number to the given _version.py file.""" + os.unlink(filename) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) + with open(filename, "w") as f: + f.write(SHORT_VERSION_PY % contents) + + print("set %s to '%s'" % (filename, versions["version"])) + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +class VersioneerBadRootError(Exception): + """The project root directory is unknown or missing key files.""" + + +def get_versions(verbose=False): + """Get the project version from whatever source is available. + + Returns dict with two keys: 'version' and 'full'. + """ + if "versioneer" in sys.modules: + # see the discussion in cmdclass.py:get_cmdclass() + del sys.modules["versioneer"] + + root = get_root() + cfg = get_config_from_root(root) + + assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" + handlers = HANDLERS.get(cfg.VCS) + assert handlers, "unrecognized VCS '%s'" % cfg.VCS + verbose = verbose or cfg.verbose + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" + assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" + + versionfile_abs = os.path.join(root, cfg.versionfile_source) + + # extract version from first of: _version.py, VCS command (e.g. 'git + # describe'), parentdir. This is meant to work for developers using a + # source checkout, for users of a tarball created by 'setup.py sdist', + # and for users of a tarball/zipball created by 'git archive' or github's + # download-from-tag feature or the equivalent in other VCSes. + + get_keywords_f = handlers.get("get_keywords") + from_keywords_f = handlers.get("keywords") + if get_keywords_f and from_keywords_f: + try: + keywords = get_keywords_f(versionfile_abs) + ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) + if verbose: + print("got version from expanded keyword %s" % ver) + return ver + except NotThisMethod: + pass + + try: + ver = versions_from_file(versionfile_abs) + if verbose: + print("got version from file %s %s" % (versionfile_abs, ver)) + return ver + except NotThisMethod: + pass + + from_vcs_f = handlers.get("pieces_from_vcs") + if from_vcs_f: + try: + pieces = from_vcs_f(cfg.tag_prefix, root, verbose) + ver = render(pieces, cfg.style) + if verbose: + print("got version from VCS %s" % ver) + return ver + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + if verbose: + print("got version from parentdir %s" % ver) + return ver + except NotThisMethod: + pass + + if verbose: + print("unable to compute version") + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, "error": "unable to compute version", + "date": None} + + +def get_version(): + """Get the short version string for this project.""" + return get_versions()["version"] + + +def get_cmdclass(): + """Get the custom setuptools/distutils subclasses used by Versioneer.""" + if "versioneer" in sys.modules: + del sys.modules["versioneer"] + # this fixes the "python setup.py develop" case (also 'install' and + # 'easy_install .'), in which subdependencies of the main project are + # built (using setup.py bdist_egg) in the same python process. Assume + # a main project A and a dependency B, which use different versions + # of Versioneer. A's setup.py imports A's Versioneer, leaving it in + # sys.modules by the time B's setup.py is executed, causing B to run + # with the wrong versioneer. Setuptools wraps the sub-dep builds in a + # sandbox that restores sys.modules to it's pre-build state, so the + # parent is protected against the child's "import versioneer". By + # removing ourselves from sys.modules here, before the child build + # happens, we protect the child from the parent's versioneer too. + # Also see https://github.com/warner/python-versioneer/issues/52 + + cmds = {} + + # we add "version" to both distutils and setuptools + from distutils.core import Command + + class cmd_version(Command): + description = "report generated version string" + user_options = [] + boolean_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + vers = get_versions(verbose=True) + print("Version: %s" % vers["version"]) + print(" full-revisionid: %s" % vers.get("full-revisionid")) + print(" dirty: %s" % vers.get("dirty")) + print(" date: %s" % vers.get("date")) + if vers["error"]: + print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version + + # we override "build_py" in both distutils and setuptools + # + # most invocation pathways end up running build_py: + # distutils/build -> build_py + # distutils/install -> distutils/build ->.. + # setuptools/bdist_wheel -> distutils/install ->.. + # setuptools/bdist_egg -> distutils/install_lib -> build_py + # setuptools/install -> bdist_egg ->.. + # setuptools/develop -> ? + # pip install: + # copies source tree to a tempdir before running egg_info/etc + # if .git isn't copied too, 'git describe' will fail + # then does setup.py bdist_wheel, or sometimes setup.py install + # setup.py egg_info -> ? + + # we override different "build_py" commands for both environments + if "setuptools" in sys.modules: + from setuptools.command.build_py import build_py as _build_py + else: + from distutils.command.build_py import build_py as _build_py + + class cmd_build_py(_build_py): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_py.run(self) + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if cfg.versionfile_build: + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py + + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string + # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. + # setup(console=[{ + # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION + # "product_version": versioneer.get_version(), + # ... + + class cmd_build_exe(_build_exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _build_exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["build_exe"] = cmd_build_exe + del cmds["build_py"] + + if 'py2exe' in sys.modules: # py2exe enabled? + try: + from py2exe.distutils_buildexe import py2exe as _py2exe # py3 + except ImportError: + from py2exe.build_exe import py2exe as _py2exe # py2 + + class cmd_py2exe(_py2exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _py2exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["py2exe"] = cmd_py2exe + + # we override different "sdist" commands for both environments + if "setuptools" in sys.modules: + from setuptools.command.sdist import sdist as _sdist + else: + from distutils.command.sdist import sdist as _sdist + + class cmd_sdist(_sdist): + def run(self): + versions = get_versions() + self._versioneer_generated_versions = versions + # unless we update this, the command will keep using the old + # version + self.distribution.metadata.version = versions["version"] + return _sdist.run(self) + + def make_release_tree(self, base_dir, files): + root = get_root() + cfg = get_config_from_root(root) + _sdist.make_release_tree(self, base_dir, files) + # now locate _version.py in the new base_dir directory + # (remembering that it may be a hardlink) and replace it with an + # updated value + target_versionfile = os.path.join(base_dir, cfg.versionfile_source) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, + self._versioneer_generated_versions) + cmds["sdist"] = cmd_sdist + + return cmds + + +CONFIG_ERROR = """ +setup.cfg is missing the necessary Versioneer configuration. You need +a section like: + + [versioneer] + VCS = git + style = pep440 + versionfile_source = src/myproject/_version.py + versionfile_build = myproject/_version.py + tag_prefix = + parentdir_prefix = myproject- + +You will also need to edit your setup.py to use the results: + + import versioneer + setup(version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), ...) + +Please read the docstring in ./versioneer.py for configuration instructions, +edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. +""" + +SAMPLE_CONFIG = """ +# See the docstring in versioneer.py for instructions. Note that you must +# re-run 'versioneer.py setup' after changing this section, and commit the +# resulting files. + +[versioneer] +#VCS = git +#style = pep440 +#versionfile_source = +#versionfile_build = +#tag_prefix = +#parentdir_prefix = + +""" + +INIT_PY_SNIPPET = """ +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions +""" + + +def do_setup(): + """Main VCS-independent setup function for installing Versioneer.""" + root = get_root() + try: + cfg = get_config_from_root(root) + except (EnvironmentError, configparser.NoSectionError, + configparser.NoOptionError) as e: + if isinstance(e, (EnvironmentError, configparser.NoSectionError)): + print("Adding sample versioneer config to setup.cfg", + file=sys.stderr) + with open(os.path.join(root, "setup.cfg"), "a") as f: + f.write(SAMPLE_CONFIG) + print(CONFIG_ERROR, file=sys.stderr) + return 1 + + print(" creating %s" % cfg.versionfile_source) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), + "__init__.py") + if os.path.exists(ipy): + try: + with open(ipy, "r") as f: + old = f.read() + except EnvironmentError: + old = "" + if INIT_PY_SNIPPET not in old: + print(" appending to %s" % ipy) + with open(ipy, "a") as f: + f.write(INIT_PY_SNIPPET) + else: + print(" %s unmodified" % ipy) + else: + print(" %s doesn't exist, ok" % ipy) + ipy = None + + # Make sure both the top-level "versioneer.py" and versionfile_source + # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so + # they'll be copied into source distributions. Pip won't be able to + # install the package without this. + manifest_in = os.path.join(root, "MANIFEST.in") + simple_includes = set() + try: + with open(manifest_in, "r") as f: + for line in f: + if line.startswith("include "): + for include in line.split()[1:]: + simple_includes.add(include) + except EnvironmentError: + pass + # That doesn't cover everything MANIFEST.in can do + # (http://docs.python.org/2/distutils/sourcedist.html#commands), so + # it might give some false negatives. Appending redundant 'include' + # lines is safe, though. + if "versioneer.py" not in simple_includes: + print(" appending 'versioneer.py' to MANIFEST.in") + with open(manifest_in, "a") as f: + f.write("include versioneer.py\n") + else: + print(" 'versioneer.py' already in MANIFEST.in") + if cfg.versionfile_source not in simple_includes: + print(" appending versionfile_source ('%s') to MANIFEST.in" % + cfg.versionfile_source) + with open(manifest_in, "a") as f: + f.write("include %s\n" % cfg.versionfile_source) + else: + print(" versionfile_source already in MANIFEST.in") + + # Make VCS-specific changes. For git, this means creating/changing + # .gitattributes to mark _version.py for export-subst keyword + # substitution. + do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + return 0 + + +def scan_setup_py(): + """Validate the contents of setup.py against Versioneer's expectations.""" + found = set() + setters = False + errors = 0 + with open("setup.py", "r") as f: + for line in f.readlines(): + if "import versioneer" in line: + found.add("import") + if "versioneer.get_cmdclass()" in line: + found.add("cmdclass") + if "versioneer.get_version()" in line: + found.add("get_version") + if "versioneer.VCS" in line: + setters = True + if "versioneer.versionfile_source" in line: + setters = True + if len(found) != 3: + print("") + print("Your setup.py appears to be missing some important items") + print("(but I might be wrong). Please make sure it has something") + print("roughly like the following:") + print("") + print(" import versioneer") + print(" setup( version=versioneer.get_version(),") + print(" cmdclass=versioneer.get_cmdclass(), ...)") + print("") + errors += 1 + if setters: + print("You should remove lines like 'versioneer.VCS = ' and") + print("'versioneer.versionfile_source = ' . This configuration") + print("now lives in setup.cfg, and should be removed from setup.py") + print("") + errors += 1 + return errors + + +if __name__ == "__main__": + cmd = sys.argv[1] + if cmd == "setup": + errors = do_setup() + errors += scan_setup_py() + if errors: + sys.exit(1)
    Required permissioninsertupdatereadhttp…http…\xa0http…1hello2world3\xa0{i}{i}a{i}b{i}c{i}abca/b.c-dcde{"row": {"pk1": "d", "pk2": "e", "content": "RENDER_CELL_DEMO"}, "column": "content", "table": "compound_primary_key", "database": "fixtures", "pks": ["pk1", "pk2"], "config": {"depth": "database"}}1hello\xa01-\xa031ab2\xa0\xa0\xa0\xa0\xa01131ab1world2\xa01 elements, - # so the thead must render even when the table has no rows. - ths = soup.select("table.rows-and-columns thead th[data-column]") - assert len(ths) >= 1 - - -@pytest.mark.asyncio -async def test_zero_row_table_renders_thead(ds_client): - response = await ds_client.get("/fixtures/123_starts_with_digits") - assert response.status_code == 200 - soup = Soup(response.text, "html.parser") - table = soup.select_one("table.rows-and-columns") - assert table is not None - column_names = [ - th.get("data-column") for th in table.select("thead th[data-column]") - ] - assert "content" in column_names - assert table.select_one("tbody tr") is None - assert soup.select_one("p.zero-results") is not None - - -@pytest.mark.asyncio -async def test_column_chooser_data_reflects_col_filtering(ds_client): - response = await ds_client.get("/fixtures/facetable?_col=state&_col=created") - assert response.status_code == 200 - import json - import re - - soup = Soup(response.text, "html.parser") - chooser = soup.find("column-chooser") - assert chooser is not None - scripts = soup.find_all("script") - chooser_script = [s for s in scripts if "_columnChooserData" in (s.string or "")] - script_text = chooser_script[0].string - # Parse the JSON object from the script - match = re.search( - r"window\._columnChooserData\s*=\s*({.*?});", script_text, re.DOTALL - ) - data = json.loads(match.group(1)) - # All non-PK columns should still be listed in allColumns - assert "state" in data["allColumns"] - assert "created" in data["allColumns"] - assert "planet_int" in data["allColumns"] - # Only state and created should be in selectedColumns (plus pk) - non_pk_selected = [ - c for c in data["selectedColumns"] if c not in data["primaryKeys"] - ] - assert "state" in non_pk_selected - assert "created" in non_pk_selected - assert "planet_int" not in non_pk_selected - - -@pytest.mark.asyncio -async def test_column_chooser_shown_for_views(ds_client): - response = await ds_client.get("/fixtures/simple_view") - assert response.status_code == 200 - soup = Soup(response.text, "html.parser") - chooser = soup.find("column-chooser") - assert chooser is not None - scripts = soup.find_all("script") - chooser_script = [s for s in scripts if "_columnChooserData" in (s.string or "")] - assert len(chooser_script) == 1 - - -@pytest.mark.asyncio -async def test_compound_primary_key_with_foreign_key_references(ds_client): - # e.g. a many-to-many table with a compound primary key on the two columns - response = await ds_client.get("/fixtures/searchable_tags") - assert response.status_code == 200 - table = Soup(response.text, "html.parser").find("table") - expected = [ - [ - '1\xa01feline2\xa02caninehelloHELLOworldWORLD\xa0\xa01<Binary:\xa07\xa0bytes>2<Binary:\xa07\xa0bytes>3\xa03Detroit2Los Angeles4Memnonia1San Francisco2Paranormal1Museum