diff --git a/src/mcarchive_org/client.py b/src/mcarchive_org/client.py index 6aa8b6e..febfbde 100644 --- a/src/mcarchive_org/client.py +++ b/src/mcarchive_org/client.py @@ -329,11 +329,21 @@ class ArchiveClient: verify_md5: str | None = None, chunk_cb=None, ) -> dict[str, Any]: - """Download with resume support. Returns stats + md5 verification result. + """Download with resume support, atomically promoted on success. - Caller is responsible for ensuring `dest` is confined to a safe directory - (use mcarchive_org.server's path validation, or do your own). This method - adds defense-in-depth via O_NOFOLLOW but does not validate path confinement. + Writes go to a `.part` staging file. On success, that staging file + is atomically renamed to `dest` (POSIX rename is observed all-or-nothing). + On any failure, the `.part` file remains and a follow-up call resumes + from it — meaning the user's directory only ever contains complete files + or `.part` files (which clearly signal incomplete state). + + - `dest` already complete (no .part) and overwrite implied false → + early-return with `already_complete: True` (still verifies MD5 if asked). + - `.part` exists → resume from its size, append, then promote on success. + - Neither exists → fresh download to `.part`, promote on success. + + Caller is responsible for path confinement; this method adds O_NOFOLLOW + defense-in-depth but does not enforce a download root. """ validate_identifier(identifier) validate_filename(filename) @@ -343,19 +353,50 @@ class ArchiveClient: # check should already have caught this, but redundant defense is cheap. if dest.parent.is_symlink(): raise ArchiveError(f"refusing to write into symlinked directory: {dest.parent}") + # Refuse to "download over" a symlink at dest. Even though the staging + # rename would replace (not follow) the symlink, predictable refusal is + # better than silent symlink removal — that surprises the user and could + # destroy a symlink farm they intentionally set up. 
+ if dest.is_symlink(): + raise ArchiveError(f"refusing to write through symlink at {dest}") - resume_from = dest.stat().st_size if dest.exists() and not dest.is_symlink() else 0 - # is_symlink() detection: if dest is a symlink, treat as fresh (open with - # O_NOFOLLOW will refuse anyway, so we won't actually corrupt the target). + part = dest.with_name(dest.name + ".part") + # Same protection on the staging path. + if part.is_symlink(): + raise ArchiveError(f"refusing to write through symlink at {part}") + + # Short-circuit: dest already exists complete and no .part to resume. + if dest.exists() and not dest.is_symlink() and not part.exists(): + size = dest.stat().st_size + result: dict[str, Any] = { + "path": str(dest), + "bytes_written": size, + "resumed_from": size, + "already_complete": True, + } + if verify_md5: + # Verify the existing file matches the expected hash. + hasher = hashlib.md5() + with dest.open("rb") as fh: + while chunk := fh.read(1 << 16): + hasher.update(chunk) + actual = hasher.hexdigest() + result["md5_actual"] = actual + result["md5_expected"] = verify_md5 + result["md5_ok"] = actual.lower() == verify_md5.lower() + return result + + # Resume from .part if present; otherwise fresh. 
+ resume_from = part.stat().st_size if part.exists() and not part.is_symlink() else 0 hasher = hashlib.md5() if verify_md5 else None if hasher and resume_from: - with dest.open("rb") as f: - while chunk := f.read(1 << 16): + with part.open("rb") as fh: + while chunk := fh.read(1 << 16): hasher.update(chunk) bytes_written = resume_from - f = _safe_open_for_write(dest, append=resume_from > 0) + f = _safe_open_for_write(part, append=resume_from > 0) try: async for chunk in self.stream_file(identifier, filename, resume_from=resume_from): f.write(chunk) @@ -365,22 +406,24 @@ class ArchiveClient: if chunk_cb: chunk_cb(bytes_written) except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadTimeout) as e: - # H1: surface partial-state context so the caller can decide whether - # to resume or restart. The bytes already on disk are valid (we only - # write whole chunks), so a follow-up call with overwrite=False will - # resume cleanly from `bytes_written`. + # H1: surface partial-state context. The .part file stays on disk + # so a follow-up call with overwrite=False resumes from bytes_written. raise ArchiveError( f"download interrupted after {bytes_written - resume_from} new bytes " - f"({bytes_written} total on disk, resumed from {resume_from}). " - f"File at {dest} — call download_file again to resume. Cause: {e!r}" + f"({bytes_written} total in {part.name}, resumed from {resume_from}). " + f"Call download_file again to resume. Cause: {e!r}" ) from e finally: f.close() - result: dict[str, Any] = { + # Atomic promotion — only after the stream completed cleanly. 
+ os.replace(part, dest) + + result = { "path": str(dest), "bytes_written": bytes_written, "resumed_from": resume_from, + "already_complete": False, } if verify_md5 and hasher: actual = hasher.hexdigest() diff --git a/src/mcarchive_org/server.py b/src/mcarchive_org/server.py index f6a3e56..eb01619 100644 --- a/src/mcarchive_org/server.py +++ b/src/mcarchive_org/server.py @@ -342,13 +342,18 @@ async def download_file( so two parallel tool invocations can't race on the same destination file. """ dest = _confine_dest(identifier, filename, dest_dir) + part = dest.with_name(dest.name + ".part") lock = _download_lock_for(identifier, filename) async with lock: - if overwrite and (dest.exists() or dest.is_symlink()): - # is_symlink() before exists() — a dangling symlink reports exists()=False - # but we still want to remove the link itself rather than follow it. - dest.unlink() + if overwrite: + # Remove both the final file AND any leftover .part — we want to start + # truly fresh, not resume an old partial. is_symlink check first since + # a dangling symlink reports exists()=False. + if dest.exists() or dest.is_symlink(): + dest.unlink() + if part.exists() or part.is_symlink(): + part.unlink() try: c = await get_shared_client() @@ -363,6 +368,76 @@ async def download_file( return result +# ---------- runtime configuration ---------- + +# Paths we refuse to use as a download root no matter what — these are system +# directories where writing junk is genuinely harmful. The user's MCP client +# can usually re-launch with MCARCHIVE_DOWNLOAD_ROOT pointing anywhere they +# want at startup; this guard is just for the LLM-driven set_download_root tool. 
_FORBIDDEN_ROOTS = frozenset({
    "/", "/etc", "/usr", "/bin", "/sbin", "/var", "/sys", "/proc", "/dev", "/boot", "/root",
})


def _check_root_safety(p: Path) -> None:
    """Raise ValueError if `p` is, or lies inside, a forbidden system directory.

    The comparison is purely lexical: `p` is made absolute and `.` / `..`
    segments are collapsed before matching, so `/home/../etc` is judged as
    `/etc`. Symlinks are NOT followed here — callers that want the link target
    checked must pass an already-resolved path (set_download_root checks both).

    Raises:
        ValueError: `p` equals a forbidden root, or is nested under one.
    """
    s = os.path.abspath(p)
    # POSIX normalisation deliberately preserves a leading "//"; fold it to a
    # single slash so "//etc" cannot dodge the comparison below.
    if s.startswith("//"):
        s = "/" + s.lstrip("/")

    if s in _FORBIDDEN_ROOTS:
        raise ValueError(f"refusing to use system directory as download root: {s}")

    # "/" is an ancestor of every path, so it is only refused as an exact match
    # (handled above); every other forbidden root also covers its descendants.
    nested_prefixes = tuple(root + "/" for root in _FORBIDDEN_ROOTS if root != "/")
    if s.startswith(nested_prefixes):
        raise ValueError(f"refusing to use system directory as download root: {s}")


@mcp.tool
def get_download_root() -> dict[str, Any]:
    """Report the directory where download_file writes by default.

    Useful at the start of a session to confirm where files will land. The
    `source` field tells you whether the value came from the MCARCHIVE_DOWNLOAD_ROOT
    env var or from the built-in default of `./downloads` under the server's CWD.
    `writable` is None when the directory does not exist yet.
    """
    raw_env = os.environ.get("MCARCHIVE_DOWNLOAD_ROOT")
    root = _resolve_download_root().resolve()
    # Probe existence once so "exists" and "writable" describe the same moment.
    exists = root.exists()
    return {
        "download_root": str(root),
        "exists": exists,
        # Creating a file inside a directory needs write AND search permission.
        "writable": os.access(root, os.W_OK | os.X_OK) if exists else None,
        "source": "MCARCHIVE_DOWNLOAD_ROOT env var" if raw_env else "default (./downloads under server CWD)",
        "raw_env_value": raw_env,
    }


@mcp.tool
def set_download_root(
    path: Annotated[
        str,
        Field(description="New download root path. '~' is expanded; the directory is created if missing."),
    ],
) -> dict[str, Any]:
    """Change the download root for the rest of this MCP server session.

    Useful when running as a stdio MCP server, since you can't otherwise
    re-export environment variables to a running child process. The change
    persists until the server process exits or set_download_root is called
    again. System directories (/etc, /usr, /var, /sys, /proc, /dev, /boot,
    /root, /bin, /sbin, /) are refused.

    Returns the new root, the previous root, and whether they differ.

    Raises:
        ValueError: `path` is empty, names a refused system directory (either
            as written or after symlink resolution), or is not writable.
        OSError: the directory could not be created.
    """
    # An empty string would otherwise turn into "." and silently select the CWD.
    if not path or not path.strip():
        raise ValueError("download root path must not be empty")

    requested = Path(path).expanduser()
    expanded = requested.resolve()
    # Check the spelling the caller used as well as the resolved target: when a
    # system directory is itself a symlink, the resolved form alone no longer
    # matches the forbidden list; when the request is a symlink pointing INTO a
    # system directory, only the resolved form reveals it.
    _check_root_safety(requested)
    _check_root_safety(expanded)

    # Capture the old root before the environment is modified below.
    previous = _resolve_download_root().resolve()
    expanded.mkdir(parents=True, exist_ok=True)
    # Creating files inside the directory needs write AND search permission.
    if not os.access(expanded, os.W_OK | os.X_OK):
        raise ValueError(f"download root not writable: {expanded}")

    os.environ["MCARCHIVE_DOWNLOAD_ROOT"] = str(expanded)
    return {
        "download_root": str(expanded),
        "previous": str(previous),
        "changed": str(expanded) != str(previous),
    }


# ---------- resources ----------
+ assert part.read_bytes() == b"X" * 100 + assert not dest.exists() async def test_resume_with_correct_206_succeeds(tmp_path): full_body = b"0123456789ABCDEF" * 16 # 256 bytes dest = tmp_path / "resume.bin" - dest.write_bytes(full_body[:64]) # we already have first 64 bytes + part = tmp_path / "resume.bin.part" + part.write_bytes(full_body[:64]) # we already have first 64 bytes in staging def handler(req: httpx.Request) -> httpx.Response: assert req.headers.get("Range") == "bytes=64-" @@ -139,12 +142,15 @@ async def test_resume_with_correct_206_succeeds(tmp_path): assert result["bytes_written"] == len(full_body) assert result["resumed_from"] == 64 assert result["md5_ok"] is True + # On success the .part is atomically renamed to dest. assert dest.read_bytes() == full_body + assert not part.exists() async def test_resume_with_wrong_content_range_start_raises(tmp_path): dest = tmp_path / "off.bin" - dest.write_bytes(b"X" * 100) + part = tmp_path / "off.bin.part" + part.write_bytes(b"X" * 100) def handler(req: httpx.Request) -> httpx.Response: # Server returns 206 but with WRONG starting offset @@ -158,7 +164,9 @@ async def test_resume_with_wrong_content_range_start_raises(tmp_path): with pytest.raises(ArchiveError, match="Content-Range start"): await c.download_to_file("nasa", "off.bin", dest) - assert dest.read_bytes() == b"X" * 100 # unchanged + # .part unchanged, dest never created. 
+ assert part.read_bytes() == b"X" * 100 + assert not dest.exists() # ---------- H4: error body surfacing ---------- @@ -328,6 +336,7 @@ async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path): return httpx.Response(200, content=evil_body()) dest = tmp_path / "interrupted.bin" + part = tmp_path / "interrupted.bin.part" async with _client_with(handler) as c: with pytest.raises(ArchiveError) as exc_info: await c.download_to_file("nasa", "interrupted.bin", dest) @@ -335,8 +344,9 @@ async def test_stream_abort_raises_archive_error_with_byte_count(tmp_path): msg = str(exc_info.value) assert "interrupted after" in msg assert "ReadError" in msg - # Partial bytes ARE on disk — at least the first delivered chunk. - on_disk = dest.read_bytes() + # Partial bytes go to .part, NOT dest. dest stays absent until success. + assert not dest.exists() + on_disk = part.read_bytes() assert len(on_disk) > 0 assert on_disk == chunk_payload[: len(on_disk)] @@ -360,4 +370,82 @@ async def test_fresh_download_writes_full_body(tmp_path): assert result["bytes_written"] == len(body) assert result["resumed_from"] == 0 assert result["md5_ok"] is True + assert result["already_complete"] is False + # The atomic-rename pattern leaves no .part artifact after success. 
assert dest.read_bytes() == body + assert not (tmp_path / "new.bin.part").exists() + + +# ---------- Atomic .part staging ---------- + + +async def test_failed_download_leaves_no_dest_file(tmp_path): + """A failed fresh download must NOT leave the final dest file as zero bytes — + it should leave only the .part staging file (or nothing if no bytes arrived).""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(500, text="upstream cdn miss") + + dest = tmp_path / "shouldfail.bin" + async with _client_with(handler) as c: + with pytest.raises(ArchiveError, match="HTTP 500"): + await c.download_to_file("nasa", "shouldfail.bin", dest) + + # Critical: dest must NOT exist as an empty file misleading the user. + assert not dest.exists() + + +async def test_already_complete_short_circuits_without_network(tmp_path): + """If dest exists and no .part, a follow-up download must not hit the + network — the file is already complete.""" + dest = tmp_path / "done.bin" + dest.write_bytes(b"already-here") + + calls = {"n": 0} + + def handler(req: httpx.Request) -> httpx.Response: + calls["n"] += 1 + return httpx.Response(500, text="should never fire") + + async with _client_with(handler) as c: + result = await c.download_to_file("nasa", "done.bin", dest) + + assert calls["n"] == 0 # no network at all + assert result["already_complete"] is True + assert result["bytes_written"] == len(b"already-here") + assert dest.read_bytes() == b"already-here" + + +async def test_already_complete_verifies_md5_against_existing_file(tmp_path): + """If verify_md5 is passed and dest is complete, we re-hash to confirm.""" + body = b"on-disk-content" + dest = tmp_path / "done.bin" + dest.write_bytes(body) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(500, text="should never fire") + + async with _client_with(handler) as c: + result = await c.download_to_file( + "nasa", "done.bin", dest, verify_md5=hashlib.md5(body).hexdigest() + ) + + assert 
result["already_complete"] is True + assert result["md5_ok"] is True + + +async def test_already_complete_md5_mismatch_caught(tmp_path): + """If the existing file's MD5 doesn't match expected, surface md5_ok=False.""" + dest = tmp_path / "wrong.bin" + dest.write_bytes(b"actual-content") + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(500, text="should never fire") + + async with _client_with(handler) as c: + result = await c.download_to_file( + "nasa", "wrong.bin", dest, verify_md5="0" * 32 + ) + assert result["already_complete"] is True + assert result["md5_ok"] is False + assert result["md5_expected"] == "0" * 32 diff --git a/tests/test_server_mocked.py b/tests/test_server_mocked.py index 4c48328..8ef1d06 100644 --- a/tests/test_server_mocked.py +++ b/tests/test_server_mocked.py @@ -10,6 +10,7 @@ These exercise the MCP tool functions directly and verify: from __future__ import annotations import asyncio +import os from contextlib import asynccontextmanager import httpx @@ -21,8 +22,10 @@ from mcarchive_org.server import ( _enrich_doc, _normalize_collection, download_file, + get_download_root, get_item_metadata, search_items, + set_download_root, ) @@ -170,6 +173,75 @@ async def test_concurrent_downloads_same_file_are_serialized(tmp_path, monkeypat assert state["max_active"] == 1 +# ---------- runtime download root management ---------- + + +def test_get_download_root_reports_env_value(tmp_path, monkeypatch): + monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path)) + info = get_download_root() + assert info["download_root"] == str(tmp_path.resolve()) + assert info["source"] == "MCARCHIVE_DOWNLOAD_ROOT env var" + assert info["raw_env_value"] == str(tmp_path) + + +def test_get_download_root_reports_default_when_no_env(monkeypatch): + monkeypatch.delenv("MCARCHIVE_DOWNLOAD_ROOT", raising=False) + info = get_download_root() + assert info["source"] == "default (./downloads under server CWD)" + assert info["raw_env_value"] is None + 
_ROOT_ENV = "MCARCHIVE_DOWNLOAD_ROOT"


def _start_with_root_unset(monkeypatch) -> None:
    """Unset the download-root env var and guarantee teardown restores it.

    set_download_root writes os.environ directly, bypassing monkeypatch.
    `delenv(..., raising=False)` records nothing to undo when the variable is
    already absent, so the value written by the tool would leak into later
    tests. Setting a placeholder first always registers an undo entry; the
    following delenv then gives the test the "unset" starting state it wants.
    """
    monkeypatch.setenv(_ROOT_ENV, "placeholder-restored-on-teardown")
    monkeypatch.delenv(_ROOT_ENV)


def test_set_download_root_changes_env_and_creates_dir(tmp_path, monkeypatch):
    """A new root is created on disk and exported through the env var."""
    _start_with_root_unset(monkeypatch)
    target = tmp_path / "new" / "spot"
    assert not target.exists()

    info = set_download_root(path=str(target))

    assert info["download_root"] == str(target.resolve())
    assert info["changed"] is True
    assert target.exists() and target.is_dir()
    assert os.environ["MCARCHIVE_DOWNLOAD_ROOT"] == str(target.resolve())


def test_set_download_root_expands_tilde(tmp_path, monkeypatch):
    """A leading '~' is expanded against HOME before the root is created."""
    _start_with_root_unset(monkeypatch)
    monkeypatch.setenv("HOME", str(tmp_path))

    info = set_download_root(path="~/dl")

    assert info["download_root"] == str((tmp_path / "dl").resolve())
    assert (tmp_path / "dl").exists()


@pytest.mark.parametrize("forbidden", ["/etc", "/usr/local", "/var/log", "/", "/sys"])
def test_set_download_root_refuses_system_dirs(forbidden, monkeypatch):
    """System directories and their descendants are rejected."""
    # Isolate the env var too: if the guard ever regresses, the tool would
    # export the system path and poison every later test in the session.
    _start_with_root_unset(monkeypatch)
    with pytest.raises(ValueError, match="system directory"):
        set_download_root(path=forbidden)


async def test_set_download_root_takes_effect_for_next_download(tmp_path, monkeypatch):
    """The lazy-resolved root means a runtime change is honored by download_file
    on the very next call without restarting."""
    _start_with_root_unset(monkeypatch)
    set_download_root(path=str(tmp_path / "first"))

    def handler(req):
        return httpx.Response(200, content=b"data")

    async with swap_shared_client(handler):
        await download_file(identifier="nasa", filename="a.bin", overwrite=True)
        # Now move the root to a different directory mid-session.
        set_download_root(path=str(tmp_path / "second"))
        await download_file(identifier="nasa", filename="b.bin", overwrite=True)

    assert (tmp_path / "first" / "nasa" / "a.bin").exists()
    assert (tmp_path / "second" / "nasa" / "b.bin").exists()


# ---------- M2 (continued): cross-file parallelism ----------


async def test_concurrent_downloads_different_files_run_in_parallel(tmp_path, monkeypatch):
    """Different filenames get different locks — they should run concurrently."""
    monkeypatch.setenv("MCARCHIVE_DOWNLOAD_ROOT", str(tmp_path))