diff --git a/pyproject.toml b/pyproject.toml index aaa1d71..e1602d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src/mcarchive_org"] +[tool.pytest.ini_options] +asyncio_mode = "auto" + [tool.ruff] line-length = 100 target-version = "py310" diff --git a/src/mcarchive_org/client.py b/src/mcarchive_org/client.py index 9c8ec75..a822cdf 100644 --- a/src/mcarchive_org/client.py +++ b/src/mcarchive_org/client.py @@ -2,29 +2,104 @@ from __future__ import annotations +import errno import hashlib +import os +import re from collections.abc import AsyncIterator from pathlib import Path from typing import Any +from urllib.parse import quote import httpx ARCHIVE_BASE = "https://archive.org" DEFAULT_UA = "mcarchive-org/2026.04.21 (+https://archive.org/developers/)" +# Per-chunk read timeout (60s) means a stalled stream is caught between chunks. +# Don't relax this without thinking about hung TCP connections. DEFAULT_TIMEOUT = httpx.Timeout(30.0, read=60.0) +# Archive.org documents identifiers as [A-Za-z0-9._-], starting with alnum, max 100 chars. +_IDENTIFIER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,99}$") +_MAX_FILENAME = 512 + class ArchiveError(RuntimeError): """Raised when archive.org returns an error payload or unexpected status.""" -class ArchiveClient: - """Async client for the three archive.org endpoints we care about. +# ---------- input validators (defense-in-depth) ---------- - - advancedsearch.php : small Solr-style queries (<= ~10,000 rows paginated) - - services/search/v1/scrape : bulk cursor-based iteration (count >= 100) + +def validate_identifier(identifier: str) -> str: + """Reject identifiers that don't match archive.org's documented grammar. + + Both archive.org and the local filesystem trust this string — '/' or '..' + in an identifier is a path-traversal vector. + """ + if not isinstance(identifier, str) or not _IDENTIFIER_RE.match(identifier): + raise ValueError( + f"invalid archive.org identifier: {identifier!r} " + f"(must match {_IDENTIFIER_RE.pattern})" + ) + return identifier + + +def validate_filename(filename: str) -> str: + """Reject filenames that could escape a download root or attack the FS. + + archive.org files[].name CAN contain forward slashes for subdirectory files + (e.g. 'cover/back.jpg') — that's allowed. What's rejected: '..' components, + absolute paths, NUL bytes, Windows drive letters, and excessive length. + """ + if not isinstance(filename, str) or not filename: + raise ValueError(f"filename must be a non-empty string, got {filename!r}") + if len(filename) > _MAX_FILENAME: + raise ValueError(f"filename exceeds {_MAX_FILENAME} chars") + if "\x00" in filename: + raise ValueError(f"filename contains NUL byte: {filename!r}") + if filename.startswith(("/", "\\")): + raise ValueError(f"filename must not be absolute: {filename!r}") + if len(filename) >= 2 and filename[1] == ":": + raise ValueError(f"filename must not be a Windows drive path: {filename!r}") + if any(part == ".." for part in filename.replace("\\", "/").split("/")): + raise ValueError(f"filename must not contain '..' components: {filename!r}") + return filename + + +# ---------- safe filesystem open ---------- + + +def _safe_open_for_write(dest: Path, append: bool): + """Open dest for writing, refusing to follow a symlink at the leaf. + + Defense against the symlink-substitution race: even if our path-confinement + check passes, a symlink at `dest` could redirect the write. O_NOFOLLOW tells + the kernel to fail the open instead. + """ + flags = os.O_WRONLY | os.O_CREAT + flags |= os.O_APPEND if append else os.O_TRUNC + nofollow = getattr(os, "O_NOFOLLOW", 0) # not present on Windows + flags |= nofollow + try: + fd = os.open(dest, flags, 0o644) + except OSError as e: + if nofollow and e.errno == errno.ELOOP: + raise ArchiveError(f"refusing to write through symlink at {dest}") from e + raise + return os.fdopen(fd, "ab" if append else "wb") + + +# ---------- client ---------- + + +class ArchiveClient: + """Async client for the archive.org endpoints we wrap. + + - advancedsearch.php : Solr-style queries (<= ~10,000 rows paginated) + - services/search/v1/scrape : bulk cursor pagination (count >= 100) - metadata/{id} : full item manifest including files[] - - download/{id}/{file} : byte stream with Range support + - download/{id}/{file} : byte stream with HTTP Range support """ def __init__( @@ -32,13 +107,17 @@ class ArchiveClient: base_url: str = ARCHIVE_BASE, user_agent: str = DEFAULT_UA, timeout: httpx.Timeout | float = DEFAULT_TIMEOUT, + transport: httpx.AsyncBaseTransport | None = None, ) -> None: self._base = base_url.rstrip("/") - self._client = httpx.AsyncClient( - headers={"User-Agent": user_agent, "Accept": "application/json"}, - timeout=timeout, - follow_redirects=True, - ) + kwargs: dict[str, Any] = { + "headers": {"User-Agent": user_agent, "Accept": "application/json"}, + "timeout": timeout, + "follow_redirects": True, + } + if transport is not None: + kwargs["transport"] = transport + self._client = httpx.AsyncClient(**kwargs) async def aclose(self) -> None: await self._client.aclose() @@ -49,6 +128,25 @@ class ArchiveClient: async def __aexit__(self, *exc: object) -> None: await self.aclose() + # ---------- internal: error-surfacing fetch ---------- + + async def _fetch_json(self, url: str, params: Any = None) -> Any: + """GET + JSON decode with archive.org-friendly error messages. + + Wraps raise_for_status so that 4xx/5xx responses include a body preview + in the exception — invaluable for an LLM trying to fix a bad query. + """ + r = await self._client.get(url, params=params) + if r.is_error: + body = r.text[:500] if r.content else "" + raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip()) + try: + return r.json() + except ValueError as e: + raise ArchiveError( + f"invalid JSON from {r.url}: {r.text[:200]!r}" + ) from e + # ---------- search ---------- async def search( @@ -71,9 +169,7 @@ class ArchiveClient: for s in sort or []: params.append(("sort[]", s)) - r = await self._client.get(f"{self._base}/advancedsearch.php", params=params) - r.raise_for_status() - data = r.json() + data = await self._fetch_json(f"{self._base}/advancedsearch.php", params=params) resp = data.get("response", {}) return { "num_found": resp.get("numFound", 0), @@ -103,39 +199,41 @@ class ArchiveClient: if cursor: params["cursor"] = cursor - r = await self._client.get(f"{self._base}/services/search/v1/scrape", params=params) - r.raise_for_status() - data = r.json() - if "error" in data: + data = await self._fetch_json(f"{self._base}/services/search/v1/scrape", params=params) + if isinstance(data, dict) and "error" in data: raise ArchiveError(f"{data.get('errorType', 'ScrapeError')}: {data['error']}") return data # keys: items, count, total, cursor (if more pages) # ---------- metadata ---------- async def metadata(self, identifier: str) -> dict[str, Any]: - """Full metadata blob for an item.""" - r = await self._client.get(f"{self._base}/metadata/{identifier}") - r.raise_for_status() - data = r.json() + """Full metadata blob for an item. Empty {} from archive.org → not found.""" + validate_identifier(identifier) + data = await self._fetch_json(f"{self._base}/metadata/{identifier}") if not data: - raise ArchiveError(f"item not found: {identifier}") + raise ArchiveError(f"item not found or unavailable: {identifier}") return data async def files(self, identifier: str) -> list[dict[str, Any]]: """Just the files[] slice — smaller payload when that's all you want.""" - r = await self._client.get(f"{self._base}/metadata/{identifier}/files") - r.raise_for_status() - data = r.json() - if isinstance(data, dict) and "result" in data: - return data["result"] + validate_identifier(identifier) + data = await self._fetch_json(f"{self._base}/metadata/{identifier}/files") + if isinstance(data, dict): + if "error" in data: + raise ArchiveError(f"archive.org error for {identifier}: {data['error']}") + if "result" in data: + return data["result"] if isinstance(data, list): return data - raise ArchiveError(f"unexpected files response for {identifier}") + raise ArchiveError(f"unexpected files response shape for {identifier}: {type(data).__name__}") # ---------- download ---------- def download_url(self, identifier: str, filename: str) -> str: - return f"{self._base}/download/{identifier}/{filename}" + """Build the canonical download URL. Filename is URL-encoded but '/' preserved.""" + validate_identifier(identifier) + validate_filename(filename) + return f"{self._base}/download/{identifier}/{quote(filename, safe='/')}" async def stream_file( self, @@ -143,13 +241,40 @@ class ArchiveClient: filename: str, resume_from: int = 0, ) -> AsyncIterator[bytes]: - """Async byte iterator — caller is responsible for writing to disk.""" + """Async byte iterator. If resume_from > 0, requires a 206 response. + + Raises ArchiveError BEFORE yielding any bytes if: + - the server returns a 4xx/5xx + - resume was requested but the server returned 200 (Range ignored) + - the Content-Range start byte doesn't match resume_from + """ + validate_identifier(identifier) + validate_filename(filename) + headers = {} if resume_from > 0: headers["Range"] = f"bytes={resume_from}-" url = self.download_url(identifier, filename) + async with self._client.stream("GET", url, headers=headers) as r: - r.raise_for_status() + if r.is_error: + body = (await r.aread())[:500].decode("utf-8", errors="replace") + raise ArchiveError(f"HTTP {r.status_code} from {r.url}: {body}".strip()) + if resume_from > 0: + if r.status_code != 206: + raise ArchiveError( + f"server ignored Range request (got HTTP {r.status_code}); " + f"local file may be stale — retry download_file with overwrite=True" + ) + # Verify the byte range starts where we expect. archive.org's CDN + # is normally well-behaved here, but trust-but-verify. + cr = r.headers.get("Content-Range", "") + m = re.match(r"bytes\s+(\d+)-", cr) + if m and int(m.group(1)) != resume_from: + raise ArchiveError( + f"Content-Range start {m.group(1)} != resume_from {resume_from}; " + f"refusing to corrupt {filename}" + ) async for chunk in r.aiter_bytes(chunk_size=1 << 16): yield chunk @@ -161,20 +286,34 @@ class ArchiveClient: verify_md5: str | None = None, chunk_cb=None, ) -> dict[str, Any]: - """Download with resume support. Returns stats + md5 verification result.""" + """Download with resume support. Returns stats + md5 verification result. + + Caller is responsible for ensuring `dest` is confined to a safe directory + (use mcarchive_org.server's path validation, or do your own). This method + adds defense-in-depth via O_NOFOLLOW but does not validate path confinement. + """ + validate_identifier(identifier) + validate_filename(filename) + dest.parent.mkdir(parents=True, exist_ok=True) - resume_from = dest.stat().st_size if dest.exists() else 0 + # If parent is a symlink to elsewhere, refuse — our caller's confinement + # check should already have caught this, but redundant defense is cheap. + if dest.parent.is_symlink(): + raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}") + + resume_from = dest.stat().st_size if dest.exists() and not dest.is_symlink() else 0 + # is_symlink() detection: if dest is a symlink, treat as fresh (open with + # O_NOFOLLOW will refuse anyway, so we won't actually corrupt the target). hasher = hashlib.md5() if verify_md5 else None if hasher and resume_from: - # re-hash existing bytes so the final digest is correct with dest.open("rb") as f: while chunk := f.read(1 << 16): hasher.update(chunk) bytes_written = resume_from - mode = "ab" if resume_from else "wb" - with dest.open(mode) as f: + f = _safe_open_for_write(dest, append=resume_from > 0) + try: async for chunk in self.stream_file(identifier, filename, resume_from=resume_from): f.write(chunk) bytes_written += len(chunk) @@ -182,10 +321,12 @@ class ArchiveClient: hasher.update(chunk) if chunk_cb: chunk_cb(bytes_written) + finally: + f.close() - result = { + result: dict[str, Any] = { "path": str(dest), - "bytes": bytes_written, + "bytes_written": bytes_written, "resumed_from": resume_from, } if verify_md5 and hasher: diff --git a/src/mcarchive_org/server.py b/src/mcarchive_org/server.py index e7839e3..ceb33df 100644 --- a/src/mcarchive_org/server.py +++ b/src/mcarchive_org/server.py @@ -6,16 +6,25 @@ import fnmatch import os from pathlib import Path from typing import Annotated, Any +from urllib.parse import quote from fastmcp import FastMCP from pydantic import Field from mcarchive_org import __version__ -from mcarchive_org.client import ArchiveClient +from mcarchive_org.client import ( + ArchiveClient, + ArchiveError, + validate_filename, + validate_identifier, +) -DEFAULT_DOWNLOAD_ROOT = Path( - os.environ.get("MCARCHIVE_DOWNLOAD_ROOT", Path.cwd() / "downloads") -).expanduser() + +def _resolve_download_root() -> Path: + """Resolve the download root lazily so env-var changes after import are honored.""" + return Path( + os.environ.get("MCARCHIVE_DOWNLOAD_ROOT", str(Path.cwd() / "downloads")) + ).expanduser() mcp = FastMCP( name="mcarchive-org", @@ -42,18 +51,29 @@ def _human_size(n: int | str | None) -> str: return f"{x:.1f} PB" +def _parse_size(raw: Any) -> int | None: + """Best-effort int parse of archive.org size fields. None if unparseable.""" + if raw is None: + return None + try: + return int(str(raw).strip()) + except (TypeError, ValueError): + return None + + def _enrich_file(identifier: str, f: dict[str, Any]) -> dict[str, Any]: name = f.get("name", "") + size = _parse_size(f.get("size")) return { "name": name, "format": f.get("format"), - "size": int(f["size"]) if f.get("size") and str(f["size"]).isdigit() else None, - "size_human": _human_size(f.get("size")), + "size": size, + "size_human": _human_size(size if size is not None else f.get("size")), "md5": f.get("md5"), "sha1": f.get("sha1"), "mtime": f.get("mtime"), "source": f.get("source"), - "download_url": f"https://archive.org/download/{identifier}/{name}", + "download_url": f"https://archive.org/download/{identifier}/{quote(name, safe='/')}", } @@ -63,6 +83,36 @@ def _matches(name: str, format_: str | None, name_glob: str | None, formats: lis return not (formats and (format_ or "").lower() not in {f.lower() for f in formats}) +def _confine_dest(identifier: str, filename: str, dest_dir: str | None) -> Path: + """Construct + verify the download destination path. + + Validates inputs, resolves the path, and asserts it lives inside the allowed + download root. Raises ValueError on any escape attempt — never returns an + unsafe path. + """ + validate_identifier(identifier) + validate_filename(filename) + + download_root = _resolve_download_root().resolve() + if dest_dir: + target_dir = Path(dest_dir).expanduser().resolve() + else: + target_dir = (download_root / identifier).resolve() + + target_dir.mkdir(parents=True, exist_ok=True) + if target_dir.is_symlink(): + raise ValueError(f"download directory must not be a symlink: {target_dir}") + + # Build dest, then resolve and confirm containment. Path.resolve() collapses + # '..' even on non-existent leaves, so escape attempts surface here. + dest = (target_dir / filename).resolve() + if not dest.is_relative_to(target_dir): + raise ValueError( + f"refusing destination outside target dir: {dest} not under {target_dir}" + ) + return dest + + # ---------- tools ---------- @@ -132,6 +182,13 @@ async def get_item_metadata( By default omits the (potentially huge) files[] array — call list_files for that. """ + return await _fetch_item_metadata(identifier, include_files=include_files) + + +async def _fetch_item_metadata(identifier: str, include_files: bool) -> dict[str, Any]: + """Shared metadata-fetching logic. Used by both the tool and the MCP resource + so neither has to depend on the other's `.fn` attribute.""" + validate_identifier(identifier) async with ArchiveClient() as c: data = await c.metadata(identifier) @@ -177,6 +234,7 @@ async def list_files( Each entry includes a ready-to-use `download_url`. """ + validate_identifier(identifier) async with ArchiveClient() as c: files = await c.files(identifier) @@ -199,8 +257,10 @@ def get_file_url( filename: Annotated[str, Field(description="Exact filename as shown in list_files.")], ) -> dict[str, str]: """Build the canonical download URL for a file without fetching anything.""" + validate_identifier(identifier) + validate_filename(filename) return { - "url": f"https://archive.org/download/{identifier}/{filename}", + "url": f"https://archive.org/download/{identifier}/{quote(filename, safe='/')}", "item_url": f"https://archive.org/details/{identifier}", } @@ -222,18 +282,29 @@ async def download_file( Field(description="If false and file exists, resume the download (Range request)."), ] = False, ) -> dict[str, Any]: - """Download a file to disk. Supports resume via HTTP Range when overwrite=false.""" - target_dir = Path(dest_dir).expanduser() if dest_dir else (DEFAULT_DOWNLOAD_ROOT / identifier) - dest = target_dir / filename - if overwrite and dest.exists(): + """Download a file to disk. Supports resume via HTTP Range when overwrite=false. + + The destination is path-confined to either `dest_dir` (when given) or + $MCARCHIVE_DOWNLOAD_ROOT/{identifier}. Filenames containing '..', absolute + paths, or NUL bytes are rejected before any FS or network I/O. + """ + dest = _confine_dest(identifier, filename, dest_dir) + if overwrite and dest.exists() and not dest.is_symlink(): + dest.unlink() + elif overwrite and dest.is_symlink(): + # Don't follow the symlink to delete the target; remove the link itself. dest.unlink() - async with ArchiveClient() as c: - result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5) + try: + async with ArchiveClient() as c: + result = await c.download_to_file(identifier, filename, dest, verify_md5=verify_md5) + except ArchiveError as e: + # Re-raise with the destination context so the caller can act on it. + raise ArchiveError(f"{e} (dest={dest})") from e result["identifier"] = identifier result["filename"] = filename - result["size_human"] = _human_size(result.get("bytes")) + result["size_human"] = _human_size(result.get("bytes_written")) return result @@ -243,7 +314,7 @@ async def download_file( @mcp.resource("archive://item/{identifier}") async def item_resource(identifier: str) -> dict[str, Any]: """Expose item metadata as a readable MCP resource.""" - return await get_item_metadata.fn(identifier=identifier, include_files=False) # type: ignore[attr-defined] + return await _fetch_item_metadata(identifier, include_files=False) # ---------- entry point ---------- diff --git a/tests/test_client.py b/tests/test_client.py index d6cb21e..37d713a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,7 +12,7 @@ import pytest from mcarchive_org.client import ArchiveClient -pytestmark = [pytest.mark.asyncio, pytest.mark.network] +pytestmark = pytest.mark.network async def test_search_nasa_item(): @@ -41,7 +41,7 @@ async def test_download_small_file(tmp_path: Path): result = await c.download_to_file( "nasa", small["name"], dest, verify_md5=small.get("md5") ) - assert result["bytes"] > 0 + assert result["bytes_written"] > 0 if small.get("md5"): assert result["md5_ok"] is True diff --git a/tests/test_client_mocked.py b/tests/test_client_mocked.py new file mode 100644 index 0000000..0957353 --- /dev/null +++ b/tests/test_client_mocked.py @@ -0,0 +1,242 @@ +"""Failure-mode regression tests using httpx.MockTransport (no network). + +Each test pins down one of the Hamilton review findings (C1/C2/C3/H4 etc.) so +future refactors can't silently regress safety. +""" + +from __future__ import annotations + +import hashlib + +import httpx +import pytest + +from mcarchive_org.client import ( + ArchiveClient, + ArchiveError, + validate_filename, + validate_identifier, +) +from mcarchive_org.server import _confine_dest + + +def _client_with(handler) -> ArchiveClient: + """Build an ArchiveClient backed by a MockTransport handler.""" + return ArchiveClient(transport=httpx.MockTransport(handler)) + + +# ---------- C1: identifier + filename validation ---------- + + +@pytest.mark.parametrize("bad", ["", "../etc", "foo/bar", "has space", "a" * 200]) +def test_invalid_identifier_rejected(bad): + with pytest.raises(ValueError, match=r"invalid archive\.org identifier"): + validate_identifier(bad) + + +@pytest.mark.parametrize( + "bad", + [ + "../escape.txt", + "/etc/passwd", + "C:\\windows.txt", + "with\x00null.bin", + "foo/../bar.mp3", + "foo\\..\\bar.mp3", + "", + ], +) +def test_invalid_filename_rejected(bad): + with pytest.raises(ValueError): + validate_filename(bad) + + +@pytest.mark.parametrize( + "ok", + ["song.mp3", "cover/back.jpg", "subdir/file with space.txt", "a.b.c.d"], +) +def test_legitimate_filenames_accepted(ok): + assert validate_filename(ok) == ok + + +def test_confine_dest_blocks_traversal(tmp_path, monkeypatch): + monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path)) + # validate_filename catches '..' before _confine_dest's path-resolution check, + # so this raises ValueError from the validator — both layers in agreement. + with pytest.raises(ValueError): + _confine_dest("nasa", "../escape.txt", dest_dir=None) + + +def test_confine_dest_legit_filename_lands_in_root(tmp_path, monkeypatch): + monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path)) + dest = _confine_dest("nasa", "globe.jpg", dest_dir=None) + assert dest.is_relative_to(tmp_path) + assert dest.name == "globe.jpg" + + +# ---------- C2: symlink refusal ---------- + + +async def test_download_refuses_symlink_at_dest(tmp_path): + target = tmp_path / "real.bin" + target.write_bytes(b"original-content") + + link = tmp_path / "evil.bin" + link.symlink_to(target) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=b"new-content-that-should-not-overwrite") + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="symlink"): + await c.download_to_file("nasa", "evil.bin", link) + + # Symlink target must be unchanged. + assert target.read_bytes() == b"original-content" + + +# ---------- C3: Range-ignored detection ---------- + + +async def test_resume_with_200_response_raises_before_writing(tmp_path): + """If the server returns 200 instead of 206 on a Range request, we must not + append to the existing file — that path corrupts data silently.""" + dest = tmp_path / "partial.bin" + dest.write_bytes(b"X" * 100) # pretend we have a partial download + + def handler(req: httpx.Request) -> httpx.Response: + # Server ignores Range header and returns the full body with 200 + assert req.headers.get("Range") == "bytes=100-" + return httpx.Response(200, content=b"FULL_FILE_BODY") + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="ignored Range"): + await c.download_to_file("nasa", "partial.bin", dest) + + # File must be unchanged — corruption avoided. + assert dest.read_bytes() == b"X" * 100 + + +async def test_resume_with_correct_206_succeeds(tmp_path): + full_body = b"0123456789ABCDEF" * 16 # 256 bytes + dest = tmp_path / "resume.bin" + dest.write_bytes(full_body[:64]) # we already have first 64 bytes + + def handler(req: httpx.Request) -> httpx.Response: + assert req.headers.get("Range") == "bytes=64-" + return httpx.Response( + 206, + content=full_body[64:], + headers={"Content-Range": f"bytes 64-{len(full_body)-1}/{len(full_body)}"}, + ) + + expected_md5 = hashlib.md5(full_body).hexdigest() + async with _client_with(handler) as c: + result = await c.download_to_file( + "nasa", "resume.bin", dest, verify_md5=expected_md5 + ) + + assert result["bytes_written"] == len(full_body) + assert result["resumed_from"] == 64 + assert result["md5_ok"] is True + assert dest.read_bytes() == full_body + + +async def test_resume_with_wrong_content_range_start_raises(tmp_path): + dest = tmp_path / "off.bin" + dest.write_bytes(b"X" * 100) + + def handler(req: httpx.Request) -> httpx.Response: + # Server returns 206 but with WRONG starting offset + return httpx.Response( + 206, + content=b"junk", + headers={"Content-Range": "bytes 50-99/100"}, + ) + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="Content-Range start"): + await c.download_to_file("nasa", "off.bin", dest) + + assert dest.read_bytes() == b"X" * 100 # unchanged + + +# ---------- H4: error body surfacing ---------- + + +async def test_search_400_includes_response_body(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(400, text='{"error":"bad query syntax"}') + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="bad query syntax"): + await c.search(query="INVALID:::") + + +async def test_metadata_404_includes_status(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(404, text="not found") + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="HTTP 404"): + await c.metadata("nasa") + + +async def test_metadata_empty_dict_means_not_found(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={}) + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="not found or unavailable"): + await c.metadata("nasa") + + +async def test_files_returns_error_payload_as_archive_error(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"error": "item is dark"}) + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="item is dark"): + await c.files("nasa") + + +async def test_scrape_error_payload_surfaced(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, json={"error": "count too small", "errorType": "RangeException"} + ) + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match=r"RangeException.*count too small"): + await c.scrape(query="identifier:nasa", count=100) + + +async def test_invalid_json_response_surfaced(): + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text="not json") + + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="invalid JSON"): + await c.metadata("nasa") + + +# ---------- happy path ---------- + + +async def test_fresh_download_writes_full_body(tmp_path): + body = b"hello world" * 100 + dest = tmp_path / "new.bin" + + def handler(req: httpx.Request) -> httpx.Response: + assert "Range" not in req.headers + return httpx.Response(200, content=body) + + async with _client_with(handler) as c: + result = await c.download_to_file( + "nasa", "new.bin", dest, verify_md5=hashlib.md5(body).hexdigest() + ) + + assert result["bytes_written"] == len(body) + assert result["resumed_from"] == 0 + assert result["md5_ok"] is True + assert dest.read_bytes() == body