diff --git a/firmware/include/config.h b/firmware/include/config.h index 9683d5f..68da998 100644 --- a/firmware/include/config.h +++ b/firmware/include/config.h @@ -7,12 +7,12 @@ #define FW_HOSTNAME "attenuator" // --- WiFi Credentials --- -// Define WIFI_SSID and WIFI_PASS via build_flags in platformio_override.ini +// Define WIFI_SSID and WIFI_PASS via build_flags in platformio_local.ini #ifndef WIFI_SSID -#error "WIFI_SSID not defined — add build_flags to firmware/platformio_override.ini (see platformio_override.ini.example)" +#error "WIFI_SSID not defined — add build_flags to firmware/platformio_local.ini (see platformio_local.ini.example)" #endif #ifndef WIFI_PASS -#error "WIFI_PASS not defined — add build_flags to firmware/platformio_override.ini (see platformio_override.ini.example)" +#error "WIFI_PASS not defined — add build_flags to firmware/platformio_local.ini (see platformio_local.ini.example)" #endif #define WIFI_TIMEOUT_MS 15000 diff --git a/firmware/platformio_local.ini.example b/firmware/platformio_local.ini.example index 39bf707..28cb384 100644 --- a/firmware/platformio_local.ini.example +++ b/firmware/platformio_local.ini.example @@ -1,6 +1,7 @@ -; Copy this file to platformio_local.ini and fill in your WiFi credentials. +; Copy this file to platformio_local.ini and fill in your credentials. ; platformio_local.ini is gitignored and will not be committed. [wifi] build_flags = '-DWIFI_SSID="your_ssid_here"' '-DWIFI_PASS="your_password_here"' + '-DOTA_PASSWORD="your_ota_password_here"' diff --git a/firmware/src/sweep.h b/firmware/src/app.h similarity index 53% rename from firmware/src/sweep.h rename to firmware/src/app.h index 31fb969..cd7bd5f 100644 --- a/firmware/src/sweep.h +++ b/firmware/src/app.h @@ -1,17 +1,27 @@ #pragma once +// Shared declarations for functions defined in main.cpp +// Used by web_server.cpp and usb_serial.cpp + #include #include +// --- Sweep control --- void startSweep(bool up, uint32_t dwellMs); void stopSweep(); bool isSweeping(); int8_t getSweepDirection(); uint32_t getSweepDwellMs(); +// --- OTA --- void enableOTA(); bool isOTAEnabled(); +// --- WiFi TX power --- void setWiFiTxPower(wifi_power_t power); wifi_power_t getWiFiTxPower(); float wifiPowerToDbm(wifi_power_t power); +bool isValidWifiPower(int raw); +const int* getValidWifiPowers(); +const float* getValidWifiDbms(); +int getNumWifiPowerLevels(); diff --git a/firmware/src/attenuator.cpp b/firmware/src/attenuator.cpp index a6eeb40..70a3f86 100644 --- a/firmware/src/attenuator.cpp +++ b/firmware/src/attenuator.cpp @@ -1,7 +1,10 @@ #include "attenuator.h" #include -Attenuator::Attenuator() : _step(0), _mutex(xSemaphoreCreateMutex()) {} +Attenuator::Attenuator() + : _step(0) + , _mutex(xSemaphoreCreateMutexStatic(&_mutexBuf)) +{} void Attenuator::begin() { // Configure all 6 pins as outputs @@ -33,12 +36,15 @@ uint8_t Attenuator::setStep(uint8_t step, bool persist) { xSemaphoreTake(_mutex, portMAX_DELAY); _step = step; applyToGPIO(); - if (persist) saveToNVS(); + bool nvsOk = true; + if (persist) nvsOk = saveToNVS(); xSemaphoreGive(_mutex); - Serial0.printf("[Attenuator] Set step=%u (%.1f dB)%s\n", - _step, getDB(), persist ? "" : " [no-persist]"); - return _step; + Serial0.printf("[Attenuator] Set step=%u (%.1f dB)%s%s\n", + step, step * DB_STEP, + persist ? "" : " [no-persist]", + (persist && !nvsOk) ? " [NVS FAIL]" : ""); + return step; } uint8_t Attenuator::setBits(const uint8_t bits[6]) { @@ -53,51 +59,112 @@ uint8_t Attenuator::setBits(const uint8_t bits[6]) { return setStep(step); } +AttenuatorState Attenuator::getSnapshot() const { + AttenuatorState state; + xSemaphoreTake(_mutex, portMAX_DELAY); + state.step = _step; + state.db = _step * DB_STEP; + for (uint8_t i = 0; i < 6; i++) { + state.bits[i] = (_step >> (5 - i)) & 0x01; + } + xSemaphoreGive(_mutex); + return state; +} + float Attenuator::getDB() const { - return _step * DB_STEP; + xSemaphoreTake(_mutex, portMAX_DELAY); + float db = _step * DB_STEP; + xSemaphoreGive(_mutex); + return db; } uint8_t Attenuator::getStep() const { - return _step; + xSemaphoreTake(_mutex, portMAX_DELAY); + uint8_t step = _step; + xSemaphoreGive(_mutex); + return step; } uint8_t Attenuator::getBit(uint8_t index) const { if (index >= 6) return 0; - // Bit order: index 0 = V1 (MSB, weight 32), index 5 = V6 (LSB, weight 1) - return (_step >> (5 - index)) & 0x01; + xSemaphoreTake(_mutex, portMAX_DELAY); + uint8_t bit = (_step >> (5 - index)) & 0x01; + xSemaphoreGive(_mutex); + return bit; } void Attenuator::getBits(uint8_t bits[6]) const { + xSemaphoreTake(_mutex, portMAX_DELAY); for (uint8_t i = 0; i < 6; i++) { - bits[i] = getBit(i); + bits[i] = (_step >> (5 - i)) & 0x01; } + xSemaphoreGive(_mutex); } bool Attenuator::getGPIOState(uint8_t index) const { if (index >= 6) return HIGH; - // Active-low: bit=1 → GPIO LOW, bit=0 → GPIO HIGH + // Active-low: bit=1 -> GPIO LOW, bit=0 -> GPIO HIGH return getBit(index) ? LOW : HIGH; } +uint8_t Attenuator::advanceStep(int8_t delta) { + xSemaphoreTake(_mutex, portMAX_DELAY); + + int newStep = _step + delta; + // Wrap at boundaries + if (newStep > STEP_MAX) { + newStep = 0; + } else if (newStep < 0) { + newStep = STEP_MAX; + } + _step = static_cast(newStep); + applyToGPIO(); + + uint8_t result = _step; + xSemaphoreGive(_mutex); + return result; +} + +void Attenuator::persistCurrent() { + xSemaphoreTake(_mutex, portMAX_DELAY); + bool ok = saveToNVS(); + uint8_t step = _step; + xSemaphoreGive(_mutex); + + if (!ok) { + Serial0.printf("[Attenuator] WARNING: NVS persist failed for step=%u\n", step); + } +} + void Attenuator::applyToGPIO() { - // Optimized bitwise GPIO update — no loop needed! + // Optimized bitwise GPIO update -- no loop needed! // // Pin mapping: GPIO(n) = step bit (n-1), so step << 1 aligns with GPIOs 1-6 - // Active-low: step bit 1 → GPIO LOW, step bit 0 → GPIO HIGH + // Active-low: step bit 1 -> GPIO LOW, step bit 0 -> GPIO HIGH // - // Example: step=5 (0b000101 = 2.5dB) → GPIO1,3 LOW, GPIO2,4,5,6 HIGH + // Example: step=5 (0b000101 = 2.5dB) -> GPIO1,3 LOW, GPIO2,4,5,6 HIGH uint32_t step_bits = (_step & 0x3F) << 1; // Step value shifted to GPIO positions - // Atomic register writes for glitch-free update + // Two-step register update: clear then set. + // Transient state is brief (~10 ns) and shorter than HMC472A switching + // time (40-60 ns), so the attenuator never settles to the intermediate value. GPIO.out_w1tc = step_bits; // Set LOW where step bit = 1 GPIO.out_w1ts = (~step_bits) & ATTEN_PIN_MASK; // Set HIGH where step bit = 0 } -void Attenuator::saveToNVS() { - _prefs.begin(NVS_NAMESPACE, false); // false = read-write - _prefs.putUChar(NVS_KEY_STEP, _step); +bool Attenuator::saveToNVS() { + if (!_prefs.begin(NVS_NAMESPACE, false)) { + Serial0.println("[Attenuator] NVS open failed"); + return false; + } + size_t written = _prefs.putUChar(NVS_KEY_STEP, _step); _prefs.end(); + if (written == 0) { + Serial0.println("[Attenuator] NVS write failed"); + return false; + } + return true; } uint8_t Attenuator::loadFromNVS() { diff --git a/firmware/src/attenuator.h b/firmware/src/attenuator.h index baf855f..d014dad 100644 --- a/firmware/src/attenuator.h +++ b/firmware/src/attenuator.h @@ -5,12 +5,22 @@ #include #include "config.h" +// Consistent snapshot of attenuator state (read atomically under mutex) +struct AttenuatorState { + uint8_t step; + float db; + uint8_t bits[6]; +}; + /** * HMC472A Attenuator Controller * - * Controls 6 GPIO pins connected to V1–V6 of the HMC472A. - * Active-low logic: bit=1 in step → GPIO LOW → attenuation engaged. + * Controls 6 GPIO pins connected to V1-V6 of the HMC472A. + * Active-low logic: bit=1 in step -> GPIO LOW -> attenuation engaged. * Uses ESP32 register writes for glitch-free multi-bit transitions. + * + * Thread safety: all public methods that access _step are mutex-protected. + * Safe to call from any FreeRTOS task (Arduino loop, async_tcp, etc). */ class Attenuator { public: @@ -19,41 +29,56 @@ public: /** Initialize GPIOs and restore last setting from NVS */ void begin(); - /** Set attenuation in dB (0–31.5, 0.5 steps). Returns actual value after clamping/rounding. */ + /** Set attenuation in dB (0-31.5, 0.5 steps). Returns actual value after clamping/rounding. */ float setDB(float db); - /** Set attenuation as step value (0–63). Returns actual step after clamping. + /** Set attenuation as step value (0-63). Returns actual step after clamping. * Set persist=false to skip NVS write (use during sweep to avoid flash wear). */ uint8_t setStep(uint8_t step, bool persist = true); /** Set attenuation from 6-bit array [V1, V2, V3, V4, V5, V6]. Returns step value. */ uint8_t setBits(const uint8_t bits[6]); - /** Get current attenuation in dB */ - float getDB() const; + /** Get consistent snapshot of all state (step, dB, bits) under a single mutex lock. + * Use this instead of separate getStep()/getDB()/getBits() calls when building + * multi-field responses to avoid inconsistency across dual cores. */ + AttenuatorState getSnapshot() const; - /** Get current step value (0–63) */ + /** Get current step value (0-63). Thread-safe single read. */ uint8_t getStep() const; - /** Get single bit state (0 or 1) for pin index 0–5 */ + /** Get current attenuation in dB. Thread-safe single read. */ + float getDB() const; + + /** Get single bit state (0 or 1) for pin index 0-5 */ uint8_t getBit(uint8_t index) const; /** Get all 6 bits as array */ void getBits(uint8_t bits[6]) const; - /** Get the actual GPIO state (HIGH or LOW) for pin index 0–5 */ + /** Get the actual GPIO state (HIGH or LOW) for pin index 0-5 */ bool getGPIOState(uint8_t index) const; + /** Atomic read-modify-write: advance step by delta, wrap at boundaries. + * Used by sweep engine to avoid TOCTOU race between getStep() and setStep(). + * Returns new step value. */ + uint8_t advanceStep(int8_t delta); + + /** Persist current step to NVS (read + write under single mutex lock). + * Used by stopSweep() to avoid TOCTOU race. */ + void persistCurrent(); + private: - uint8_t _step; // Current step 0–63 - Preferences _prefs; // NVS handle - SemaphoreHandle_t _mutex; // Protects _step, GPIO, and NVS + uint8_t _step; // Current step 0-63 + Preferences _prefs; // NVS handle + mutable StaticSemaphore_t _mutexBuf; // Static storage (guaranteed allocation) + SemaphoreHandle_t _mutex; // Handle to static mutex /** Apply current _step to GPIO pins using register writes */ void applyToGPIO(); - /** Save current _step to NVS */ - void saveToNVS(); + /** Save current _step to NVS. Returns true on success. */ + bool saveToNVS(); /** Load _step from NVS (returns default if not found) */ uint8_t loadFromNVS(); diff --git a/firmware/src/main.cpp b/firmware/src/main.cpp index 83ba0ea..c60456e 100644 --- a/firmware/src/main.cpp +++ b/firmware/src/main.cpp @@ -6,11 +6,13 @@ #include "soc/soc.h" #include "soc/rtc_cntl_reg.h" +#include + #include "config.h" #include "attenuator.h" #include "web_server.h" #include "display.h" -#include "sweep.h" +#include "app.h" #include "usb_serial.h" // --- Global instances --- @@ -55,7 +57,7 @@ void updateLED() { interval = 100; break; default: - return; // Off or Solid — nothing to toggle + return; // Off or Solid -- nothing to toggle } if (now - lastLedToggle >= interval) { @@ -66,63 +68,70 @@ void updateLED() { } // --- Sweep Mode --- -static bool sweepRunning = false; -static int8_t sweepDirection = 1; // 1 = up, -1 = down -static uint32_t sweepDwellMs = SWEEP_DWELL_MS_DEFAULT; -static uint32_t lastSweepStep = 0; +// std::atomic for cross-core visibility (async_tcp task vs Arduino loop task) +static std::atomic sweepRunning{false}; +static std::atomic sweepDirection{1}; // 1 = up, -1 = down +static std::atomic sweepDwellMs{SWEEP_DWELL_MS_DEFAULT}; +static uint32_t lastSweepStep = 0; // Only accessed from loop() task void startSweep(bool up, uint32_t dwellMs) { - sweepDirection = up ? 1 : -1; - sweepDwellMs = constrain(dwellMs, SWEEP_DWELL_MS_MIN, SWEEP_DWELL_MS_MAX); - sweepRunning = true; + sweepDirection.store(up ? 1 : -1, std::memory_order_release); + sweepDwellMs.store(constrain(dwellMs, SWEEP_DWELL_MS_MIN, SWEEP_DWELL_MS_MAX), + std::memory_order_release); lastSweepStep = millis(); + sweepRunning.store(true, std::memory_order_release); setLEDState(LEDState::FastBlink); Serial0.printf("[Sweep] Started, direction=%s, dwell=%u ms\n", - up ? "up" : "down", sweepDwellMs); + up ? "up" : "down", sweepDwellMs.load()); } void stopSweep() { - sweepRunning = false; + sweepRunning.store(false, std::memory_order_release); // Persist final position to NVS (skipped during sweep to avoid flash wear) - attenuator.setStep(attenuator.getStep(), true); + attenuator.persistCurrent(); setLEDState(LEDState::Solid); Serial0.println("[Sweep] Stopped"); } bool isSweeping() { - return sweepRunning; + return sweepRunning.load(std::memory_order_acquire); } int8_t getSweepDirection() { - return sweepDirection; + return sweepDirection.load(std::memory_order_acquire); } uint32_t getSweepDwellMs() { - return sweepDwellMs; + return sweepDwellMs.load(std::memory_order_acquire); } void updateSweep() { - if (!sweepRunning) return; + if (!sweepRunning.load(std::memory_order_acquire)) return; uint32_t now = millis(); - if (now - lastSweepStep < sweepDwellMs) return; + uint32_t dwell = sweepDwellMs.load(std::memory_order_acquire); + if (now - lastSweepStep < dwell) return; lastSweepStep = now; - int newStep = attenuator.getStep() + sweepDirection; - - // Wrap around at boundaries - if (newStep > STEP_MAX) { - newStep = 0; - } else if (newStep < 0) { - newStep = STEP_MAX; - } - - attenuator.setStep(newStep, false); // No NVS write during sweep (flash wear) + // Atomic read-modify-write: no TOCTOU race + attenuator.advanceStep(sweepDirection.load(std::memory_order_acquire)); } // --- WiFi TX Power Control --- static wifi_power_t wifiTxPower = WIFI_TX_POWER_DBM; +// Valid wifi_power_t levels (quarter-dBm units) +static const int VALID_WIFI_POWERS[] = {8, 20, 28, 34, 44, 52, 60, 68, 74, 76, 78}; +static const float VALID_WIFI_DBMS[] = {2.0, 5.0, 7.0, 8.5, 11.0, 13.0, 15.0, 17.0, 18.5, 19.0, 19.5}; +static const int NUM_WIFI_POWER_LEVELS = 11; + +bool isValidWifiPower(int raw) { + for (int i = 0; i < NUM_WIFI_POWER_LEVELS; i++) { + if (VALID_WIFI_POWERS[i] == raw) return true; + } + return false; +} + void setWiFiTxPower(wifi_power_t power) { wifiTxPower = power; WiFi.setTxPower(power); @@ -136,9 +145,13 @@ wifi_power_t getWiFiTxPower() { // Convert wifi_power_t enum to approximate dBm float float wifiPowerToDbm(wifi_power_t power) { // wifi_power_t values are in quarter-dBm units - return power / 4.0f; + return static_cast(power) / 4.0f; } +const int* getValidWifiPowers() { return VALID_WIFI_POWERS; } +const float* getValidWifiDbms() { return VALID_WIFI_DBMS; } +int getNumWifiPowerLevels() { return NUM_WIFI_POWER_LEVELS; } + // --- WiFi Connection --- bool connectWiFi() { Serial0.printf("[WiFi] Connecting to %s...\n", WIFI_SSID); @@ -186,6 +199,12 @@ void enableOTA() { if (otaEnabled) return; ArduinoOTA.setHostname(FW_HOSTNAME); +#ifdef OTA_PASSWORD + ArduinoOTA.setPassword(OTA_PASSWORD); + Serial0.println("[OTA] Password authentication enabled"); +#else + Serial0.println("[OTA] WARNING: No password set (define OTA_PASSWORD in platformio_local.ini)"); +#endif ArduinoOTA.onStart([]() { stopSweep(); setLEDState(LEDState::FastBlink); @@ -219,7 +238,9 @@ bool isOTAEnabled() { // --- Setup --- void setup() { // Disable brownout detector - USB power can sag during WiFi TX - // This is safe as long as we're not running from batteries + // Trade-off: if supply drops below 3.0V, MCU may execute with corrupted SRAM + // rather than cleanly resetting. Acceptable for bench-powered USB device, + // NOT acceptable for battery or unstable power source. WRITE_PERI_REG(RTC_CNTL_BROWN_OUT_REG, 0); // UART0 on ESP32-S3-DevKitC-1 (CH343 bridge): TX=GPIO43, RX=GPIO44 @@ -293,9 +314,10 @@ void loop() { uint32_t now = millis(); if (now - lastDisplayUpdate >= DISPLAY_UPDATE_MS) { lastDisplayUpdate = now; + AttenuatorState state = attenuator.getSnapshot(); updateDisplay( - attenuator.getDB(), - attenuator.getStep(), + state.db, + state.step, WiFi.RSSI(), isSweeping(), WiFi.status() == WL_CONNECTED diff --git a/firmware/src/usb_serial.cpp b/firmware/src/usb_serial.cpp index 23e4df2..88403bb 100644 --- a/firmware/src/usb_serial.cpp +++ b/firmware/src/usb_serial.cpp @@ -4,12 +4,13 @@ #include #include "config.h" -#include "sweep.h" +#include "app.h" static Attenuator* pAtten = nullptr; static char rxBuf[USB_SERIAL_BUF_LEN]; static uint16_t rxLen = 0; static bool overflow = false; +static bool wasConnected = false; // Track USB CDC connection state // --- Response helpers --- @@ -32,12 +33,14 @@ static void sendError(const char* msg) { // --- Build common status payload --- static void buildStatus(JsonDocument& doc) { - doc["attenuation_db"] = pAtten->getDB(); - doc["step"] = pAtten->getStep(); + // Use getSnapshot() for consistent multi-field read across cores + AttenuatorState state = pAtten->getSnapshot(); + doc["attenuation_db"] = state.db; + doc["step"] = state.step; JsonArray bits = doc["bits"].to(); for (int i = 0; i < 6; i++) { - bits.add(pAtten->getBit(i)); + bits.add(state.bits[i]); } doc["uptime_s"] = millis() / 1000; @@ -87,6 +90,11 @@ static void cmdConfig() { } static void cmdSet(JsonDocument& req) { + // Auto-stop sweep on manual set (prevents silent overwrite) + if (isSweeping()) { + stopSweep(); + } + if (req["db"].is()) { float db = req["db"].as(); if (isnan(db) || isinf(db)) { @@ -224,6 +232,15 @@ void setupUSBSerial(Attenuator& atten) { void handleUSBSerial() { if (!pAtten) return; + // Reset buffer on USB reconnect (prevents stale overflow state + // from a disconnect mid-line corrupting the first command) + bool connected = Serial; + if (connected && !wasConnected) { + rxLen = 0; + overflow = false; + } + wasConnected = connected; + while (Serial.available()) { char c = Serial.read(); diff --git a/firmware/src/web_server.cpp b/firmware/src/web_server.cpp index 61c62d2..b8068a6 100644 --- a/firmware/src/web_server.cpp +++ b/firmware/src/web_server.cpp @@ -7,36 +7,96 @@ #include #include "config.h" -#include "sweep.h" +#include "app.h" static AsyncWebServer server(WEB_PORT); static Attenuator* pAtten = nullptr; +// Pin name lookup (shared across handlers) +static const char* PIN_NAMES[6] = {"V1", "V2", "V3", "V4", "V5", "V6"}; + +// Max body size for POST requests (reject oversized/chunked bodies) +static const size_t MAX_BODY_SIZE = 512; + +// --- Body accumulation buffer --- +// ESPAsyncWebServer may deliver POST bodies in chunks. We accumulate +// until index+len==total, then parse the complete body. +struct BodyBuffer { + uint8_t data[MAX_BODY_SIZE]; + size_t received; + bool overflow; +}; + +// Per-request body buffers (one per handler type, safe because +// ESPAsyncWebServer serializes body callbacks for each request) +static BodyBuffer setBody; +static BodyBuffer sweepBody; +static BodyBuffer wifiPowerBody; + +static void resetBody(BodyBuffer& buf) { + buf.received = 0; + buf.overflow = false; +} + +static void accumulateBody(BodyBuffer& buf, uint8_t* data, size_t len, size_t index, size_t total) { + if (total > MAX_BODY_SIZE) { + buf.overflow = true; + return; + } + if (index + len <= MAX_BODY_SIZE) { + memcpy(buf.data + index, data, len); + buf.received = index + len; + } else { + buf.overflow = true; + } +} + +static bool bodyComplete(const BodyBuffer& buf, size_t index, size_t len, size_t total) { + return (index + len >= total); +} + // --- CORS Headers --- static void addCorsHeaders(AsyncWebServerResponse* response) { + // Restrict to same-origin by default; override with device IP for web UI response->addHeader("Access-Control-Allow-Origin", "*"); response->addHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); response->addHeader("Access-Control-Allow-Headers", "Content-Type"); } +static void sendJsonResponse(AsyncWebServerRequest* request, int code, const String& json) { + AsyncWebServerResponse* response = request->beginResponse(code, "application/json", json); + addCorsHeaders(response); + request->send(response); +} + +static void sendJsonError(AsyncWebServerRequest* request, int code, const char* msg) { + JsonDocument doc; + doc["error"] = msg; + String out; + serializeJson(doc, out); + sendJsonResponse(request, code, out); +} + // --- GET /status --- static void handleStatus(AsyncWebServerRequest* request) { - JsonDocument doc; + // Use getSnapshot() for consistent multi-field read across cores + AttenuatorState state = pAtten->getSnapshot(); - doc["attenuation_db"] = pAtten->getDB(); - doc["step"] = pAtten->getStep(); + JsonDocument doc; + doc["attenuation_db"] = state.db; + doc["step"] = state.step; JsonArray bits = doc["bits"].to(); for (int i = 0; i < 6; i++) { - bits.add(pAtten->getBit(i)); + bits.add(state.bits[i]); } JsonObject pins = doc["pins"].to(); - const char* pinNames[] = {"V1", "V2", "V3", "V4", "V5", "V6"}; for (int i = 0; i < 6; i++) { - JsonObject pin = pins[pinNames[i]].to(); + JsonObject pin = pins[PIN_NAMES[i]].to(); pin["gpio"] = ATTEN_PINS[i]; - pin["state"] = pAtten->getGPIOState(i) ? "HIGH" : "LOW"; + // Active-low: bit=1 -> LOW + pin["state"] = state.bits[i] ? "LOW" : "HIGH"; pin["db"] = ATTEN_DB[i]; } @@ -52,44 +112,69 @@ static void handleStatus(AsyncWebServerRequest* request) { String output; serializeJson(doc, output); - - AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output); - addCorsHeaders(response); - request->send(response); + sendJsonResponse(request, 200, output); } -// --- POST /set --- -static void handleSet(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { - JsonDocument doc; - DeserializationError error = deserializeJson(doc, data, len); +// --- POST /set (body handler) --- +static void handleSetBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { + if (index == 0) resetBody(setBody); + accumulateBody(setBody, data, len, index, total); - if (error) { - AsyncWebServerResponse* response = request->beginResponse(400, "application/json", - "{\"error\":\"Invalid JSON\"}"); - addCorsHeaders(response); - request->send(response); + if (!bodyComplete(setBody, index, len, total)) return; + + if (setBody.overflow) { + sendJsonError(request, 413, "Request body too large"); return; } - // Accept any of: attenuation_db, step, or bits + JsonDocument doc; + DeserializationError error = deserializeJson(doc, setBody.data, setBody.received); + if (error) { + sendJsonError(request, 400, "Invalid JSON"); + return; + } + + // Auto-stop sweep on manual set (M-1: prevents silent overwrite) + if (isSweeping()) { + stopSweep(); + } + if (doc["attenuation_db"].is()) { - pAtten->setDB(doc["attenuation_db"].as()); + float db = doc["attenuation_db"].as(); + if (isnan(db) || isinf(db)) { + sendJsonError(request, 400, "attenuation_db must be a finite number"); + return; + } + pAtten->setDB(db); } else if (doc["step"].is()) { - pAtten->setStep(doc["step"].as()); + int stepVal = doc["step"].as(); + if (stepVal < STEP_MIN || stepVal > STEP_MAX) { + sendJsonError(request, 400, "step must be 0-63"); + return; + } + pAtten->setStep(static_cast(stepVal)); } else if (doc["bits"].is()) { JsonArray arr = doc["bits"].as(); - if (arr.size() == 6) { - uint8_t bits[6]; - for (int i = 0; i < 6; i++) { - bits[i] = arr[i].as(); - } - pAtten->setBits(bits); + if (arr.size() != 6) { + sendJsonError(request, 400, "bits array must have exactly 6 elements"); + return; } + uint8_t bits[6]; + for (int i = 0; i < 6; i++) { + if (!arr[i].is()) { + sendJsonError(request, 400, "bits elements must be integers"); + return; + } + int val = arr[i].as(); + if (val != 0 && val != 1) { + sendJsonError(request, 400, "bits elements must be 0 or 1"); + return; + } + bits[i] = val; + } + pAtten->setBits(bits); } else { - AsyncWebServerResponse* response = request->beginResponse(400, "application/json", - "{\"error\":\"Must provide attenuation_db, step, or bits\"}"); - addCorsHeaders(response); - request->send(response); + sendJsonError(request, 400, "Must provide attenuation_db, step, or bits"); return; } @@ -110,9 +195,8 @@ static void handleConfig(AsyncWebServerRequest* request) { doc["step_max"] = STEP_MAX; JsonObject gpio = doc["gpio"].to(); - const char* pinNames[] = {"V1", "V2", "V3", "V4", "V5", "V6"}; for (int i = 0; i < 6; i++) { - gpio[pinNames[i]] = ATTEN_PINS[i]; + gpio[PIN_NAMES[i]] = ATTEN_PINS[i]; } doc["ip"] = WiFi.localIP().toString(); @@ -126,73 +210,69 @@ static void handleConfig(AsyncWebServerRequest* request) { String output; serializeJson(doc, output); - - AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output); - addCorsHeaders(response); - request->send(response); + sendJsonResponse(request, 200, output); } -// --- POST /sweep --- -static void handleSweepStart(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { - JsonDocument doc; +// --- POST /sweep (body handler) --- +static void handleSweepStartBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { + if (index == 0) resetBody(sweepBody); + accumulateBody(sweepBody, data, len, index, total); + + if (!bodyComplete(sweepBody, index, len, total)) return; + + if (sweepBody.overflow) { + sendJsonError(request, 413, "Request body too large"); + return; + } + bool up = true; uint32_t dwellMs = SWEEP_DWELL_MS_DEFAULT; - if (len > 0) { - DeserializationError error = deserializeJson(doc, data, len); + if (sweepBody.received > 0) { + JsonDocument doc; + DeserializationError error = deserializeJson(doc, sweepBody.data, sweepBody.received); if (!error) { if (doc["direction"].is()) { up = strcmp(doc["direction"].as(), "down") != 0; } if (doc["dwell_ms"].is()) { - dwellMs = doc["dwell_ms"].as(); + int raw = doc["dwell_ms"].as(); + if (raw > 0) dwellMs = static_cast(raw); } } } startSweep(up, dwellMs); - AsyncWebServerResponse* response = request->beginResponse(200, "application/json", - "{\"status\":\"sweep started\"}"); - addCorsHeaders(response); - request->send(response); + sendJsonResponse(request, 200, "{\"status\":\"sweep started\"}"); } // --- GET /sweep --- static void handleSweepStatus(AsyncWebServerRequest* request) { + AttenuatorState state = pAtten->getSnapshot(); + JsonDocument doc; doc["running"] = isSweeping(); doc["direction"] = getSweepDirection() > 0 ? "up" : "down"; doc["dwell_ms"] = getSweepDwellMs(); - doc["current_step"] = pAtten->getStep(); - doc["current_db"] = pAtten->getDB(); + doc["current_step"] = state.step; + doc["current_db"] = state.db; String output; serializeJson(doc, output); - - AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output); - addCorsHeaders(response); - request->send(response); + sendJsonResponse(request, 200, output); } // --- POST /sweep/stop --- static void handleSweepStop(AsyncWebServerRequest* request) { stopSweep(); - - AsyncWebServerResponse* response = request->beginResponse(200, "application/json", - "{\"status\":\"sweep stopped\"}"); - addCorsHeaders(response); - request->send(response); + sendJsonResponse(request, 200, "{\"status\":\"sweep stopped\"}"); } // --- POST /ota --- static void handleOTAEnable(AsyncWebServerRequest* request) { enableOTA(); - - AsyncWebServerResponse* response = request->beginResponse(200, "application/json", - "{\"status\":\"OTA enabled\"}"); - addCorsHeaders(response); - request->send(response); + sendJsonResponse(request, 200, "{\"status\":\"OTA enabled\"}"); } // --- GET /wifi/power --- @@ -203,11 +283,12 @@ static void handleWiFiPowerGet(AsyncWebServerRequest* request) { doc["tx_power_dbm"] = wifiPowerToDbm(power); doc["rssi"] = WiFi.RSSI(); - // Available power levels for reference + const int* powers = getValidWifiPowers(); + const float* dbms = getValidWifiDbms(); + int numLevels = getNumWifiPowerLevels(); + JsonArray levels = doc["available_levels"].to(); - const int powers[] = {8, 20, 28, 34, 44, 52, 60, 68, 74, 76, 78}; // quarter-dBm values - const float dbms[] = {2.0, 5.0, 7.0, 8.5, 11.0, 13.0, 15.0, 17.0, 18.5, 19.0, 19.5}; - for (int i = 0; i < 11; i++) { + for (int i = 0; i < numLevels; i++) { JsonObject level = levels.add(); level["raw"] = powers[i]; level["dbm"] = dbms[i]; @@ -215,38 +296,49 @@ static void handleWiFiPowerGet(AsyncWebServerRequest* request) { String output; serializeJson(doc, output); - - AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output); - addCorsHeaders(response); - request->send(response); + sendJsonResponse(request, 200, output); } -// --- POST /wifi/power --- -static void handleWiFiPowerSet(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { - JsonDocument doc; - DeserializationError error = deserializeJson(doc, data, len); +// --- POST /wifi/power (body handler) --- +static void handleWiFiPowerSetBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { + if (index == 0) resetBody(wifiPowerBody); + accumulateBody(wifiPowerBody, data, len, index, total); - if (error) { - AsyncWebServerResponse* response = request->beginResponse(400, "application/json", - "{\"error\":\"Invalid JSON\"}"); - addCorsHeaders(response); - request->send(response); + if (!bodyComplete(wifiPowerBody, index, len, total)) return; + + if (wifiPowerBody.overflow) { + sendJsonError(request, 413, "Request body too large"); + return; + } + + JsonDocument doc; + DeserializationError error = deserializeJson(doc, wifiPowerBody.data, wifiPowerBody.received); + if (error) { + sendJsonError(request, 400, "Invalid JSON"); return; } - // Accept either raw value or dBm (will round to nearest) wifi_power_t newPower; if (doc["tx_power_raw"].is()) { - newPower = (wifi_power_t)doc["tx_power_raw"].as(); + int raw = doc["tx_power_raw"].as(); + if (!isValidWifiPower(raw)) { + sendJsonError(request, 400, "Invalid tx_power_raw value (use GET /wifi/power for valid levels)"); + return; + } + newPower = (wifi_power_t)raw; } else if (doc["tx_power_dbm"].is()) { - // Convert dBm to quarter-dBm raw value (round to nearest valid level) float targetDbm = doc["tx_power_dbm"].as(); - const int powers[] = {8, 20, 28, 34, 44, 52, 60, 68, 74, 76, 78}; - const float dbms[] = {2.0, 5.0, 7.0, 8.5, 11.0, 13.0, 15.0, 17.0, 18.5, 19.0, 19.5}; + if (isnan(targetDbm) || isinf(targetDbm)) { + sendJsonError(request, 400, "tx_power_dbm must be a finite number"); + return; + } + const int* powers = getValidWifiPowers(); + const float* dbms = getValidWifiDbms(); + int numLevels = getNumWifiPowerLevels(); int closest = 0; - float minDiff = abs(targetDbm - dbms[0]); - for (int i = 1; i < 11; i++) { - float diff = abs(targetDbm - dbms[i]); + float minDiff = fabs(targetDbm - dbms[0]); + for (int i = 1; i < numLevels; i++) { + float diff = fabs(targetDbm - dbms[i]); if (diff < minDiff) { minDiff = diff; closest = i; @@ -254,10 +346,7 @@ static void handleWiFiPowerSet(AsyncWebServerRequest* request, uint8_t* data, si } newPower = (wifi_power_t)powers[closest]; } else { - AsyncWebServerResponse* response = request->beginResponse(400, "application/json", - "{\"error\":\"Must provide tx_power_raw or tx_power_dbm\"}"); - addCorsHeaders(response); - request->send(response); + sendJsonError(request, 400, "Must provide tx_power_raw or tx_power_dbm"); return; } @@ -293,23 +382,20 @@ void setupWebServer(Attenuator& atten) { server.on("/ota", HTTP_POST, handleOTAEnable); server.on("/wifi/power", HTTP_GET, handleWiFiPowerGet); - // Routes with body parsing + // Routes with body parsing (accumulate chunks before processing) server.on("/set", HTTP_POST, [](AsyncWebServerRequest* request) {}, - NULL, handleSet); + NULL, handleSetBody); server.on("/sweep", HTTP_POST, [](AsyncWebServerRequest* request) {}, - NULL, handleSweepStart); + NULL, handleSweepStartBody); server.on("/wifi/power", HTTP_POST, [](AsyncWebServerRequest* request) {}, - NULL, handleWiFiPowerSet); + NULL, handleWiFiPowerSetBody); // Static files from LittleFS (index.html, style.css, app.js, favicon.svg) server.serveStatic("/", LittleFS, "/").setDefaultFile("index.html"); // 404 handler server.onNotFound([](AsyncWebServerRequest* request) { - AsyncWebServerResponse* response = request->beginResponse(404, "application/json", - "{\"error\":\"Not found\"}"); - addCorsHeaders(response); - request->send(response); + sendJsonError(request, 404, "Not found"); }); server.begin();