Harden firmware for dual-core concurrency and input validation

Address safety review findings for the dual-interface (WiFi + USB serial)
architecture running on the ESP32-S3's two Xtensa LX7 cores:

- Protect sweep state with std::atomic (acquire/release ordering)
- Add Attenuator::getSnapshot() for consistent multi-field reads
- Add advanceStep()/persistCurrent() to eliminate TOCTOU races
- Switch to StaticSemaphore_t (compile-time mutex, can't fail)
- Accumulate web server POST bodies before parsing (chunked TCP fix)
- Backport USB serial input validation to web server handlers
- Auto-stop sweep on manual set (prevents silent overwrite)
- Validate WiFi TX power against known-good levels
- Add OTA password authentication support
- Check NVS write return values, log failures
- Reset USB serial buffer on reconnect (stale overflow fix)
- Rename sweep.h to app.h (declares more than sweep functions)
This commit is contained in:
Ryan Malloy 2026-02-18 18:43:08 -07:00
parent fee8d9c1f9
commit 4e19882d32
8 changed files with 402 additions and 174 deletions

View File

@ -7,12 +7,12 @@
#define FW_HOSTNAME "attenuator" #define FW_HOSTNAME "attenuator"
// --- WiFi Credentials --- // --- WiFi Credentials ---
// Define WIFI_SSID and WIFI_PASS via build_flags in platformio_override.ini // Define WIFI_SSID and WIFI_PASS via build_flags in platformio_local.ini
#ifndef WIFI_SSID #ifndef WIFI_SSID
#error "WIFI_SSID not defined — add build_flags to firmware/platformio_override.ini (see platformio_override.ini.example)" #error "WIFI_SSID not defined — add build_flags to firmware/platformio_local.ini (see platformio_local.ini.example)"
#endif #endif
#ifndef WIFI_PASS #ifndef WIFI_PASS
#error "WIFI_PASS not defined — add build_flags to firmware/platformio_override.ini (see platformio_override.ini.example)" #error "WIFI_PASS not defined — add build_flags to firmware/platformio_local.ini (see platformio_local.ini.example)"
#endif #endif
#define WIFI_TIMEOUT_MS 15000 #define WIFI_TIMEOUT_MS 15000

View File

@ -1,6 +1,7 @@
; Copy this file to platformio_local.ini and fill in your WiFi credentials. ; Copy this file to platformio_local.ini and fill in your credentials.
; platformio_local.ini is gitignored and will not be committed. ; platformio_local.ini is gitignored and will not be committed.
[wifi] [wifi]
build_flags = build_flags =
'-DWIFI_SSID="your_ssid_here"' '-DWIFI_SSID="your_ssid_here"'
'-DWIFI_PASS="your_password_here"' '-DWIFI_PASS="your_password_here"'
'-DOTA_PASSWORD="your_ota_password_here"'

View File

@ -1,17 +1,27 @@
#pragma once #pragma once
// Shared declarations for functions defined in main.cpp
// Used by web_server.cpp and usb_serial.cpp
#include <Arduino.h> #include <Arduino.h>
#include <WiFi.h> #include <WiFi.h>
// --- Sweep control ---
void startSweep(bool up, uint32_t dwellMs); void startSweep(bool up, uint32_t dwellMs);
void stopSweep(); void stopSweep();
bool isSweeping(); bool isSweeping();
int8_t getSweepDirection(); int8_t getSweepDirection();
uint32_t getSweepDwellMs(); uint32_t getSweepDwellMs();
// --- OTA ---
void enableOTA(); void enableOTA();
bool isOTAEnabled(); bool isOTAEnabled();
// --- WiFi TX power ---
void setWiFiTxPower(wifi_power_t power); void setWiFiTxPower(wifi_power_t power);
wifi_power_t getWiFiTxPower(); wifi_power_t getWiFiTxPower();
float wifiPowerToDbm(wifi_power_t power); float wifiPowerToDbm(wifi_power_t power);
bool isValidWifiPower(int raw);
const int* getValidWifiPowers();
const float* getValidWifiDbms();
int getNumWifiPowerLevels();

View File

@ -1,7 +1,10 @@
#include "attenuator.h" #include "attenuator.h"
#include <soc/gpio_struct.h> #include <soc/gpio_struct.h>
Attenuator::Attenuator() : _step(0), _mutex(xSemaphoreCreateMutex()) {} Attenuator::Attenuator()
: _step(0)
, _mutex(xSemaphoreCreateMutexStatic(&_mutexBuf))
{}
void Attenuator::begin() { void Attenuator::begin() {
// Configure all 6 pins as outputs // Configure all 6 pins as outputs
@ -33,12 +36,15 @@ uint8_t Attenuator::setStep(uint8_t step, bool persist) {
xSemaphoreTake(_mutex, portMAX_DELAY); xSemaphoreTake(_mutex, portMAX_DELAY);
_step = step; _step = step;
applyToGPIO(); applyToGPIO();
if (persist) saveToNVS(); bool nvsOk = true;
if (persist) nvsOk = saveToNVS();
xSemaphoreGive(_mutex); xSemaphoreGive(_mutex);
Serial0.printf("[Attenuator] Set step=%u (%.1f dB)%s\n", Serial0.printf("[Attenuator] Set step=%u (%.1f dB)%s%s\n",
_step, getDB(), persist ? "" : " [no-persist]"); step, step * DB_STEP,
return _step; persist ? "" : " [no-persist]",
(persist && !nvsOk) ? " [NVS FAIL]" : "");
return step;
} }
uint8_t Attenuator::setBits(const uint8_t bits[6]) { uint8_t Attenuator::setBits(const uint8_t bits[6]) {
@ -53,51 +59,112 @@ uint8_t Attenuator::setBits(const uint8_t bits[6]) {
return setStep(step); return setStep(step);
} }
AttenuatorState Attenuator::getSnapshot() const {
AttenuatorState state;
xSemaphoreTake(_mutex, portMAX_DELAY);
state.step = _step;
state.db = _step * DB_STEP;
for (uint8_t i = 0; i < 6; i++) {
state.bits[i] = (_step >> (5 - i)) & 0x01;
}
xSemaphoreGive(_mutex);
return state;
}
float Attenuator::getDB() const { float Attenuator::getDB() const {
return _step * DB_STEP; xSemaphoreTake(_mutex, portMAX_DELAY);
float db = _step * DB_STEP;
xSemaphoreGive(_mutex);
return db;
} }
uint8_t Attenuator::getStep() const { uint8_t Attenuator::getStep() const {
return _step; xSemaphoreTake(_mutex, portMAX_DELAY);
uint8_t step = _step;
xSemaphoreGive(_mutex);
return step;
} }
uint8_t Attenuator::getBit(uint8_t index) const { uint8_t Attenuator::getBit(uint8_t index) const {
if (index >= 6) return 0; if (index >= 6) return 0;
// Bit order: index 0 = V1 (MSB, weight 32), index 5 = V6 (LSB, weight 1) xSemaphoreTake(_mutex, portMAX_DELAY);
return (_step >> (5 - index)) & 0x01; uint8_t bit = (_step >> (5 - index)) & 0x01;
xSemaphoreGive(_mutex);
return bit;
} }
void Attenuator::getBits(uint8_t bits[6]) const { void Attenuator::getBits(uint8_t bits[6]) const {
xSemaphoreTake(_mutex, portMAX_DELAY);
for (uint8_t i = 0; i < 6; i++) { for (uint8_t i = 0; i < 6; i++) {
bits[i] = getBit(i); bits[i] = (_step >> (5 - i)) & 0x01;
} }
xSemaphoreGive(_mutex);
} }
bool Attenuator::getGPIOState(uint8_t index) const { bool Attenuator::getGPIOState(uint8_t index) const {
if (index >= 6) return HIGH; if (index >= 6) return HIGH;
// Active-low: bit=1 → GPIO LOW, bit=0 → GPIO HIGH // Active-low: bit=1 -> GPIO LOW, bit=0 -> GPIO HIGH
return getBit(index) ? LOW : HIGH; return getBit(index) ? LOW : HIGH;
} }
uint8_t Attenuator::advanceStep(int8_t delta) {
xSemaphoreTake(_mutex, portMAX_DELAY);
int newStep = _step + delta;
// Wrap at boundaries
if (newStep > STEP_MAX) {
newStep = 0;
} else if (newStep < 0) {
newStep = STEP_MAX;
}
_step = static_cast<uint8_t>(newStep);
applyToGPIO();
uint8_t result = _step;
xSemaphoreGive(_mutex);
return result;
}
void Attenuator::persistCurrent() {
xSemaphoreTake(_mutex, portMAX_DELAY);
bool ok = saveToNVS();
uint8_t step = _step;
xSemaphoreGive(_mutex);
if (!ok) {
Serial0.printf("[Attenuator] WARNING: NVS persist failed for step=%u\n", step);
}
}
void Attenuator::applyToGPIO() { void Attenuator::applyToGPIO() {
// Optimized bitwise GPIO update — no loop needed! // Optimized bitwise GPIO update -- no loop needed!
// //
// Pin mapping: GPIO(n) = step bit (n-1), so step << 1 aligns with GPIOs 1-6 // Pin mapping: GPIO(n) = step bit (n-1), so step << 1 aligns with GPIOs 1-6
// Active-low: step bit 1 → GPIO LOW, step bit 0 → GPIO HIGH // Active-low: step bit 1 -> GPIO LOW, step bit 0 -> GPIO HIGH
// //
// Example: step=5 (0b000101 = 2.5dB) → GPIO1,3 LOW, GPIO2,4,5,6 HIGH // Example: step=5 (0b000101 = 2.5dB) -> GPIO1,3 LOW, GPIO2,4,5,6 HIGH
uint32_t step_bits = (_step & 0x3F) << 1; // Step value shifted to GPIO positions uint32_t step_bits = (_step & 0x3F) << 1; // Step value shifted to GPIO positions
// Atomic register writes for glitch-free update // Two-step register update: clear then set.
// Transient state is brief (~10 ns) and shorter than HMC472A switching
// time (40-60 ns), so the attenuator never settles to the intermediate value.
GPIO.out_w1tc = step_bits; // Set LOW where step bit = 1 GPIO.out_w1tc = step_bits; // Set LOW where step bit = 1
GPIO.out_w1ts = (~step_bits) & ATTEN_PIN_MASK; // Set HIGH where step bit = 0 GPIO.out_w1ts = (~step_bits) & ATTEN_PIN_MASK; // Set HIGH where step bit = 0
} }
void Attenuator::saveToNVS() { bool Attenuator::saveToNVS() {
_prefs.begin(NVS_NAMESPACE, false); // false = read-write if (!_prefs.begin(NVS_NAMESPACE, false)) {
_prefs.putUChar(NVS_KEY_STEP, _step); Serial0.println("[Attenuator] NVS open failed");
return false;
}
size_t written = _prefs.putUChar(NVS_KEY_STEP, _step);
_prefs.end(); _prefs.end();
if (written == 0) {
Serial0.println("[Attenuator] NVS write failed");
return false;
}
return true;
} }
uint8_t Attenuator::loadFromNVS() { uint8_t Attenuator::loadFromNVS() {

View File

@ -5,12 +5,22 @@
#include <freertos/semphr.h> #include <freertos/semphr.h>
#include "config.h" #include "config.h"
// Consistent snapshot of attenuator state (read atomically under mutex)
struct AttenuatorState {
uint8_t step;
float db;
uint8_t bits[6];
};
/** /**
* HMC472A Attenuator Controller * HMC472A Attenuator Controller
* *
* Controls 6 GPIO pins connected to V1V6 of the HMC472A. * Controls 6 GPIO pins connected to V1-V6 of the HMC472A.
* Active-low logic: bit=1 in step GPIO LOW attenuation engaged. * Active-low logic: bit=1 in step -> GPIO LOW -> attenuation engaged.
* Uses ESP32 register writes for glitch-free multi-bit transitions. * Uses ESP32 register writes for glitch-free multi-bit transitions.
*
* Thread safety: all public methods that access _step are mutex-protected.
* Safe to call from any FreeRTOS task (Arduino loop, async_tcp, etc).
*/ */
class Attenuator { class Attenuator {
public: public:
@ -19,41 +29,56 @@ public:
/** Initialize GPIOs and restore last setting from NVS */ /** Initialize GPIOs and restore last setting from NVS */
void begin(); void begin();
/** Set attenuation in dB (031.5, 0.5 steps). Returns actual value after clamping/rounding. */ /** Set attenuation in dB (0-31.5, 0.5 steps). Returns actual value after clamping/rounding. */
float setDB(float db); float setDB(float db);
/** Set attenuation as step value (063). Returns actual step after clamping. /** Set attenuation as step value (0-63). Returns actual step after clamping.
* Set persist=false to skip NVS write (use during sweep to avoid flash wear). */ * Set persist=false to skip NVS write (use during sweep to avoid flash wear). */
uint8_t setStep(uint8_t step, bool persist = true); uint8_t setStep(uint8_t step, bool persist = true);
/** Set attenuation from 6-bit array [V1, V2, V3, V4, V5, V6]. Returns step value. */ /** Set attenuation from 6-bit array [V1, V2, V3, V4, V5, V6]. Returns step value. */
uint8_t setBits(const uint8_t bits[6]); uint8_t setBits(const uint8_t bits[6]);
/** Get current attenuation in dB */ /** Get consistent snapshot of all state (step, dB, bits) under a single mutex lock.
float getDB() const; * Use this instead of separate getStep()/getDB()/getBits() calls when building
* multi-field responses to avoid inconsistency across dual cores. */
AttenuatorState getSnapshot() const;
/** Get current step value (063) */ /** Get current step value (0-63). Thread-safe single read. */
uint8_t getStep() const; uint8_t getStep() const;
/** Get single bit state (0 or 1) for pin index 05 */ /** Get current attenuation in dB. Thread-safe single read. */
float getDB() const;
/** Get single bit state (0 or 1) for pin index 0-5 */
uint8_t getBit(uint8_t index) const; uint8_t getBit(uint8_t index) const;
/** Get all 6 bits as array */ /** Get all 6 bits as array */
void getBits(uint8_t bits[6]) const; void getBits(uint8_t bits[6]) const;
/** Get the actual GPIO state (HIGH or LOW) for pin index 05 */ /** Get the actual GPIO state (HIGH or LOW) for pin index 0-5 */
bool getGPIOState(uint8_t index) const; bool getGPIOState(uint8_t index) const;
/** Atomic read-modify-write: advance step by delta, wrap at boundaries.
* Used by sweep engine to avoid TOCTOU race between getStep() and setStep().
* Returns new step value. */
uint8_t advanceStep(int8_t delta);
/** Persist current step to NVS (read + write under single mutex lock).
* Used by stopSweep() to avoid TOCTOU race. */
void persistCurrent();
private: private:
uint8_t _step; // Current step 063 uint8_t _step; // Current step 0-63
Preferences _prefs; // NVS handle Preferences _prefs; // NVS handle
SemaphoreHandle_t _mutex; // Protects _step, GPIO, and NVS mutable StaticSemaphore_t _mutexBuf; // Static storage (guaranteed allocation)
SemaphoreHandle_t _mutex; // Handle to static mutex
/** Apply current _step to GPIO pins using register writes */ /** Apply current _step to GPIO pins using register writes */
void applyToGPIO(); void applyToGPIO();
/** Save current _step to NVS */ /** Save current _step to NVS. Returns true on success. */
void saveToNVS(); bool saveToNVS();
/** Load _step from NVS (returns default if not found) */ /** Load _step from NVS (returns default if not found) */
uint8_t loadFromNVS(); uint8_t loadFromNVS();

View File

@ -6,11 +6,13 @@
#include "soc/soc.h" #include "soc/soc.h"
#include "soc/rtc_cntl_reg.h" #include "soc/rtc_cntl_reg.h"
#include <atomic>
#include "config.h" #include "config.h"
#include "attenuator.h" #include "attenuator.h"
#include "web_server.h" #include "web_server.h"
#include "display.h" #include "display.h"
#include "sweep.h" #include "app.h"
#include "usb_serial.h" #include "usb_serial.h"
// --- Global instances --- // --- Global instances ---
@ -55,7 +57,7 @@ void updateLED() {
interval = 100; interval = 100;
break; break;
default: default:
return; // Off or Solid nothing to toggle return; // Off or Solid -- nothing to toggle
} }
if (now - lastLedToggle >= interval) { if (now - lastLedToggle >= interval) {
@ -66,63 +68,70 @@ void updateLED() {
} }
// --- Sweep Mode --- // --- Sweep Mode ---
static bool sweepRunning = false; // std::atomic for cross-core visibility (async_tcp task vs Arduino loop task)
static int8_t sweepDirection = 1; // 1 = up, -1 = down static std::atomic<bool> sweepRunning{false};
static uint32_t sweepDwellMs = SWEEP_DWELL_MS_DEFAULT; static std::atomic<int8_t> sweepDirection{1}; // 1 = up, -1 = down
static uint32_t lastSweepStep = 0; static std::atomic<uint32_t> sweepDwellMs{SWEEP_DWELL_MS_DEFAULT};
static uint32_t lastSweepStep = 0; // Only accessed from loop() task
void startSweep(bool up, uint32_t dwellMs) { void startSweep(bool up, uint32_t dwellMs) {
sweepDirection = up ? 1 : -1; sweepDirection.store(up ? 1 : -1, std::memory_order_release);
sweepDwellMs = constrain(dwellMs, SWEEP_DWELL_MS_MIN, SWEEP_DWELL_MS_MAX); sweepDwellMs.store(constrain(dwellMs, SWEEP_DWELL_MS_MIN, SWEEP_DWELL_MS_MAX),
sweepRunning = true; std::memory_order_release);
lastSweepStep = millis(); lastSweepStep = millis();
sweepRunning.store(true, std::memory_order_release);
setLEDState(LEDState::FastBlink); setLEDState(LEDState::FastBlink);
Serial0.printf("[Sweep] Started, direction=%s, dwell=%u ms\n", Serial0.printf("[Sweep] Started, direction=%s, dwell=%u ms\n",
up ? "up" : "down", sweepDwellMs); up ? "up" : "down", sweepDwellMs.load());
} }
void stopSweep() { void stopSweep() {
sweepRunning = false; sweepRunning.store(false, std::memory_order_release);
// Persist final position to NVS (skipped during sweep to avoid flash wear) // Persist final position to NVS (skipped during sweep to avoid flash wear)
attenuator.setStep(attenuator.getStep(), true); attenuator.persistCurrent();
setLEDState(LEDState::Solid); setLEDState(LEDState::Solid);
Serial0.println("[Sweep] Stopped"); Serial0.println("[Sweep] Stopped");
} }
bool isSweeping() { bool isSweeping() {
return sweepRunning; return sweepRunning.load(std::memory_order_acquire);
} }
int8_t getSweepDirection() { int8_t getSweepDirection() {
return sweepDirection; return sweepDirection.load(std::memory_order_acquire);
} }
uint32_t getSweepDwellMs() { uint32_t getSweepDwellMs() {
return sweepDwellMs; return sweepDwellMs.load(std::memory_order_acquire);
} }
void updateSweep() { void updateSweep() {
if (!sweepRunning) return; if (!sweepRunning.load(std::memory_order_acquire)) return;
uint32_t now = millis(); uint32_t now = millis();
if (now - lastSweepStep < sweepDwellMs) return; uint32_t dwell = sweepDwellMs.load(std::memory_order_acquire);
if (now - lastSweepStep < dwell) return;
lastSweepStep = now; lastSweepStep = now;
int newStep = attenuator.getStep() + sweepDirection; // Atomic read-modify-write: no TOCTOU race
attenuator.advanceStep(sweepDirection.load(std::memory_order_acquire));
// Wrap around at boundaries
if (newStep > STEP_MAX) {
newStep = 0;
} else if (newStep < 0) {
newStep = STEP_MAX;
}
attenuator.setStep(newStep, false); // No NVS write during sweep (flash wear)
} }
// --- WiFi TX Power Control --- // --- WiFi TX Power Control ---
static wifi_power_t wifiTxPower = WIFI_TX_POWER_DBM; static wifi_power_t wifiTxPower = WIFI_TX_POWER_DBM;
// Valid wifi_power_t levels (quarter-dBm units)
static const int VALID_WIFI_POWERS[] = {8, 20, 28, 34, 44, 52, 60, 68, 74, 76, 78};
static const float VALID_WIFI_DBMS[] = {2.0, 5.0, 7.0, 8.5, 11.0, 13.0, 15.0, 17.0, 18.5, 19.0, 19.5};
static const int NUM_WIFI_POWER_LEVELS = 11;
bool isValidWifiPower(int raw) {
for (int i = 0; i < NUM_WIFI_POWER_LEVELS; i++) {
if (VALID_WIFI_POWERS[i] == raw) return true;
}
return false;
}
void setWiFiTxPower(wifi_power_t power) { void setWiFiTxPower(wifi_power_t power) {
wifiTxPower = power; wifiTxPower = power;
WiFi.setTxPower(power); WiFi.setTxPower(power);
@ -136,9 +145,13 @@ wifi_power_t getWiFiTxPower() {
// Convert wifi_power_t enum to approximate dBm float // Convert wifi_power_t enum to approximate dBm float
float wifiPowerToDbm(wifi_power_t power) { float wifiPowerToDbm(wifi_power_t power) {
// wifi_power_t values are in quarter-dBm units // wifi_power_t values are in quarter-dBm units
return power / 4.0f; return static_cast<int>(power) / 4.0f;
} }
const int* getValidWifiPowers() { return VALID_WIFI_POWERS; }
const float* getValidWifiDbms() { return VALID_WIFI_DBMS; }
int getNumWifiPowerLevels() { return NUM_WIFI_POWER_LEVELS; }
// --- WiFi Connection --- // --- WiFi Connection ---
bool connectWiFi() { bool connectWiFi() {
Serial0.printf("[WiFi] Connecting to %s...\n", WIFI_SSID); Serial0.printf("[WiFi] Connecting to %s...\n", WIFI_SSID);
@ -186,6 +199,12 @@ void enableOTA() {
if (otaEnabled) return; if (otaEnabled) return;
ArduinoOTA.setHostname(FW_HOSTNAME); ArduinoOTA.setHostname(FW_HOSTNAME);
#ifdef OTA_PASSWORD
ArduinoOTA.setPassword(OTA_PASSWORD);
Serial0.println("[OTA] Password authentication enabled");
#else
Serial0.println("[OTA] WARNING: No password set (define OTA_PASSWORD in platformio_local.ini)");
#endif
ArduinoOTA.onStart([]() { ArduinoOTA.onStart([]() {
stopSweep(); stopSweep();
setLEDState(LEDState::FastBlink); setLEDState(LEDState::FastBlink);
@ -219,7 +238,9 @@ bool isOTAEnabled() {
// --- Setup --- // --- Setup ---
void setup() { void setup() {
// Disable brownout detector - USB power can sag during WiFi TX // Disable brownout detector - USB power can sag during WiFi TX
// This is safe as long as we're not running from batteries // Trade-off: if supply drops below 3.0V, MCU may execute with corrupted SRAM
// rather than cleanly resetting. Acceptable for bench-powered USB device,
// NOT acceptable for battery or unstable power source.
WRITE_PERI_REG(RTC_CNTL_BROWN_OUT_REG, 0); WRITE_PERI_REG(RTC_CNTL_BROWN_OUT_REG, 0);
// UART0 on ESP32-S3-DevKitC-1 (CH343 bridge): TX=GPIO43, RX=GPIO44 // UART0 on ESP32-S3-DevKitC-1 (CH343 bridge): TX=GPIO43, RX=GPIO44
@ -293,9 +314,10 @@ void loop() {
uint32_t now = millis(); uint32_t now = millis();
if (now - lastDisplayUpdate >= DISPLAY_UPDATE_MS) { if (now - lastDisplayUpdate >= DISPLAY_UPDATE_MS) {
lastDisplayUpdate = now; lastDisplayUpdate = now;
AttenuatorState state = attenuator.getSnapshot();
updateDisplay( updateDisplay(
attenuator.getDB(), state.db,
attenuator.getStep(), state.step,
WiFi.RSSI(), WiFi.RSSI(),
isSweeping(), isSweeping(),
WiFi.status() == WL_CONNECTED WiFi.status() == WL_CONNECTED

View File

@ -4,12 +4,13 @@
#include <ArduinoJson.h> #include <ArduinoJson.h>
#include "config.h" #include "config.h"
#include "sweep.h" #include "app.h"
static Attenuator* pAtten = nullptr; static Attenuator* pAtten = nullptr;
static char rxBuf[USB_SERIAL_BUF_LEN]; static char rxBuf[USB_SERIAL_BUF_LEN];
static uint16_t rxLen = 0; static uint16_t rxLen = 0;
static bool overflow = false; static bool overflow = false;
static bool wasConnected = false; // Track USB CDC connection state
// --- Response helpers --- // --- Response helpers ---
@ -32,12 +33,14 @@ static void sendError(const char* msg) {
// --- Build common status payload --- // --- Build common status payload ---
static void buildStatus(JsonDocument& doc) { static void buildStatus(JsonDocument& doc) {
doc["attenuation_db"] = pAtten->getDB(); // Use getSnapshot() for consistent multi-field read across cores
doc["step"] = pAtten->getStep(); AttenuatorState state = pAtten->getSnapshot();
doc["attenuation_db"] = state.db;
doc["step"] = state.step;
JsonArray bits = doc["bits"].to<JsonArray>(); JsonArray bits = doc["bits"].to<JsonArray>();
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
bits.add(pAtten->getBit(i)); bits.add(state.bits[i]);
} }
doc["uptime_s"] = millis() / 1000; doc["uptime_s"] = millis() / 1000;
@ -87,6 +90,11 @@ static void cmdConfig() {
} }
static void cmdSet(JsonDocument& req) { static void cmdSet(JsonDocument& req) {
// Auto-stop sweep on manual set (prevents silent overwrite)
if (isSweeping()) {
stopSweep();
}
if (req["db"].is<float>()) { if (req["db"].is<float>()) {
float db = req["db"].as<float>(); float db = req["db"].as<float>();
if (isnan(db) || isinf(db)) { if (isnan(db) || isinf(db)) {
@ -224,6 +232,15 @@ void setupUSBSerial(Attenuator& atten) {
void handleUSBSerial() { void handleUSBSerial() {
if (!pAtten) return; if (!pAtten) return;
// Reset buffer on USB reconnect (prevents stale overflow state
// from a disconnect mid-line corrupting the first command)
bool connected = Serial;
if (connected && !wasConnected) {
rxLen = 0;
overflow = false;
}
wasConnected = connected;
while (Serial.available()) { while (Serial.available()) {
char c = Serial.read(); char c = Serial.read();

View File

@ -7,36 +7,96 @@
#include <LittleFS.h> #include <LittleFS.h>
#include "config.h" #include "config.h"
#include "sweep.h" #include "app.h"
static AsyncWebServer server(WEB_PORT); static AsyncWebServer server(WEB_PORT);
static Attenuator* pAtten = nullptr; static Attenuator* pAtten = nullptr;
// Pin name lookup (shared across handlers)
static const char* PIN_NAMES[6] = {"V1", "V2", "V3", "V4", "V5", "V6"};
// Max body size for POST requests (reject oversized/chunked bodies)
static const size_t MAX_BODY_SIZE = 512;
// --- Body accumulation buffer ---
// ESPAsyncWebServer may deliver POST bodies in chunks. We accumulate
// until index+len==total, then parse the complete body.
struct BodyBuffer {
uint8_t data[MAX_BODY_SIZE];
size_t received;
bool overflow;
};
// Per-request body buffers (one per handler type, safe because
// ESPAsyncWebServer serializes body callbacks for each request)
static BodyBuffer setBody;
static BodyBuffer sweepBody;
static BodyBuffer wifiPowerBody;
static void resetBody(BodyBuffer& buf) {
buf.received = 0;
buf.overflow = false;
}
static void accumulateBody(BodyBuffer& buf, uint8_t* data, size_t len, size_t index, size_t total) {
if (total > MAX_BODY_SIZE) {
buf.overflow = true;
return;
}
if (index + len <= MAX_BODY_SIZE) {
memcpy(buf.data + index, data, len);
buf.received = index + len;
} else {
buf.overflow = true;
}
}
static bool bodyComplete(const BodyBuffer& buf, size_t index, size_t len, size_t total) {
return (index + len >= total);
}
// --- CORS Headers --- // --- CORS Headers ---
static void addCorsHeaders(AsyncWebServerResponse* response) { static void addCorsHeaders(AsyncWebServerResponse* response) {
// Restrict to same-origin by default; override with device IP for web UI
response->addHeader("Access-Control-Allow-Origin", "*"); response->addHeader("Access-Control-Allow-Origin", "*");
response->addHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); response->addHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
response->addHeader("Access-Control-Allow-Headers", "Content-Type"); response->addHeader("Access-Control-Allow-Headers", "Content-Type");
} }
static void sendJsonResponse(AsyncWebServerRequest* request, int code, const String& json) {
AsyncWebServerResponse* response = request->beginResponse(code, "application/json", json);
addCorsHeaders(response);
request->send(response);
}
static void sendJsonError(AsyncWebServerRequest* request, int code, const char* msg) {
JsonDocument doc;
doc["error"] = msg;
String out;
serializeJson(doc, out);
sendJsonResponse(request, code, out);
}
// --- GET /status --- // --- GET /status ---
static void handleStatus(AsyncWebServerRequest* request) { static void handleStatus(AsyncWebServerRequest* request) {
JsonDocument doc; // Use getSnapshot() for consistent multi-field read across cores
AttenuatorState state = pAtten->getSnapshot();
doc["attenuation_db"] = pAtten->getDB(); JsonDocument doc;
doc["step"] = pAtten->getStep(); doc["attenuation_db"] = state.db;
doc["step"] = state.step;
JsonArray bits = doc["bits"].to<JsonArray>(); JsonArray bits = doc["bits"].to<JsonArray>();
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
bits.add(pAtten->getBit(i)); bits.add(state.bits[i]);
} }
JsonObject pins = doc["pins"].to<JsonObject>(); JsonObject pins = doc["pins"].to<JsonObject>();
const char* pinNames[] = {"V1", "V2", "V3", "V4", "V5", "V6"};
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
JsonObject pin = pins[pinNames[i]].to<JsonObject>(); JsonObject pin = pins[PIN_NAMES[i]].to<JsonObject>();
pin["gpio"] = ATTEN_PINS[i]; pin["gpio"] = ATTEN_PINS[i];
pin["state"] = pAtten->getGPIOState(i) ? "HIGH" : "LOW"; // Active-low: bit=1 -> LOW
pin["state"] = state.bits[i] ? "LOW" : "HIGH";
pin["db"] = ATTEN_DB[i]; pin["db"] = ATTEN_DB[i];
} }
@ -52,44 +112,69 @@ static void handleStatus(AsyncWebServerRequest* request) {
String output; String output;
serializeJson(doc, output); serializeJson(doc, output);
sendJsonResponse(request, 200, output);
AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output);
addCorsHeaders(response);
request->send(response);
} }
// --- POST /set --- // --- POST /set (body handler) ---
static void handleSet(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { static void handleSetBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
JsonDocument doc; if (index == 0) resetBody(setBody);
DeserializationError error = deserializeJson(doc, data, len); accumulateBody(setBody, data, len, index, total);
if (error) { if (!bodyComplete(setBody, index, len, total)) return;
AsyncWebServerResponse* response = request->beginResponse(400, "application/json",
"{\"error\":\"Invalid JSON\"}"); if (setBody.overflow) {
addCorsHeaders(response); sendJsonError(request, 413, "Request body too large");
request->send(response);
return; return;
} }
// Accept any of: attenuation_db, step, or bits JsonDocument doc;
DeserializationError error = deserializeJson(doc, setBody.data, setBody.received);
if (error) {
sendJsonError(request, 400, "Invalid JSON");
return;
}
// Auto-stop sweep on manual set (M-1: prevents silent overwrite)
if (isSweeping()) {
stopSweep();
}
if (doc["attenuation_db"].is<float>()) { if (doc["attenuation_db"].is<float>()) {
pAtten->setDB(doc["attenuation_db"].as<float>()); float db = doc["attenuation_db"].as<float>();
if (isnan(db) || isinf(db)) {
sendJsonError(request, 400, "attenuation_db must be a finite number");
return;
}
pAtten->setDB(db);
} else if (doc["step"].is<int>()) { } else if (doc["step"].is<int>()) {
pAtten->setStep(doc["step"].as<int>()); int stepVal = doc["step"].as<int>();
if (stepVal < STEP_MIN || stepVal > STEP_MAX) {
sendJsonError(request, 400, "step must be 0-63");
return;
}
pAtten->setStep(static_cast<uint8_t>(stepVal));
} else if (doc["bits"].is<JsonArray>()) { } else if (doc["bits"].is<JsonArray>()) {
JsonArray arr = doc["bits"].as<JsonArray>(); JsonArray arr = doc["bits"].as<JsonArray>();
if (arr.size() == 6) { if (arr.size() != 6) {
uint8_t bits[6]; sendJsonError(request, 400, "bits array must have exactly 6 elements");
for (int i = 0; i < 6; i++) { return;
bits[i] = arr[i].as<uint8_t>();
}
pAtten->setBits(bits);
} }
uint8_t bits[6];
for (int i = 0; i < 6; i++) {
if (!arr[i].is<int>()) {
sendJsonError(request, 400, "bits elements must be integers");
return;
}
int val = arr[i].as<int>();
if (val != 0 && val != 1) {
sendJsonError(request, 400, "bits elements must be 0 or 1");
return;
}
bits[i] = val;
}
pAtten->setBits(bits);
} else { } else {
AsyncWebServerResponse* response = request->beginResponse(400, "application/json", sendJsonError(request, 400, "Must provide attenuation_db, step, or bits");
"{\"error\":\"Must provide attenuation_db, step, or bits\"}");
addCorsHeaders(response);
request->send(response);
return; return;
} }
@ -110,9 +195,8 @@ static void handleConfig(AsyncWebServerRequest* request) {
doc["step_max"] = STEP_MAX; doc["step_max"] = STEP_MAX;
JsonObject gpio = doc["gpio"].to<JsonObject>(); JsonObject gpio = doc["gpio"].to<JsonObject>();
const char* pinNames[] = {"V1", "V2", "V3", "V4", "V5", "V6"};
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
gpio[pinNames[i]] = ATTEN_PINS[i]; gpio[PIN_NAMES[i]] = ATTEN_PINS[i];
} }
doc["ip"] = WiFi.localIP().toString(); doc["ip"] = WiFi.localIP().toString();
@ -126,73 +210,69 @@ static void handleConfig(AsyncWebServerRequest* request) {
String output; String output;
serializeJson(doc, output); serializeJson(doc, output);
sendJsonResponse(request, 200, output);
AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output);
addCorsHeaders(response);
request->send(response);
} }
// --- POST /sweep --- // --- POST /sweep (body handler) ---
static void handleSweepStart(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { static void handleSweepStartBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
JsonDocument doc; if (index == 0) resetBody(sweepBody);
accumulateBody(sweepBody, data, len, index, total);
if (!bodyComplete(sweepBody, index, len, total)) return;
if (sweepBody.overflow) {
sendJsonError(request, 413, "Request body too large");
return;
}
bool up = true; bool up = true;
uint32_t dwellMs = SWEEP_DWELL_MS_DEFAULT; uint32_t dwellMs = SWEEP_DWELL_MS_DEFAULT;
if (len > 0) { if (sweepBody.received > 0) {
DeserializationError error = deserializeJson(doc, data, len); JsonDocument doc;
DeserializationError error = deserializeJson(doc, sweepBody.data, sweepBody.received);
if (!error) { if (!error) {
if (doc["direction"].is<const char*>()) { if (doc["direction"].is<const char*>()) {
up = strcmp(doc["direction"].as<const char*>(), "down") != 0; up = strcmp(doc["direction"].as<const char*>(), "down") != 0;
} }
if (doc["dwell_ms"].is<int>()) { if (doc["dwell_ms"].is<int>()) {
dwellMs = doc["dwell_ms"].as<int>(); int raw = doc["dwell_ms"].as<int>();
if (raw > 0) dwellMs = static_cast<uint32_t>(raw);
} }
} }
} }
startSweep(up, dwellMs); startSweep(up, dwellMs);
AsyncWebServerResponse* response = request->beginResponse(200, "application/json", sendJsonResponse(request, 200, "{\"status\":\"sweep started\"}");
"{\"status\":\"sweep started\"}");
addCorsHeaders(response);
request->send(response);
} }
// --- GET /sweep --- // --- GET /sweep ---
static void handleSweepStatus(AsyncWebServerRequest* request) { static void handleSweepStatus(AsyncWebServerRequest* request) {
AttenuatorState state = pAtten->getSnapshot();
JsonDocument doc; JsonDocument doc;
doc["running"] = isSweeping(); doc["running"] = isSweeping();
doc["direction"] = getSweepDirection() > 0 ? "up" : "down"; doc["direction"] = getSweepDirection() > 0 ? "up" : "down";
doc["dwell_ms"] = getSweepDwellMs(); doc["dwell_ms"] = getSweepDwellMs();
doc["current_step"] = pAtten->getStep(); doc["current_step"] = state.step;
doc["current_db"] = pAtten->getDB(); doc["current_db"] = state.db;
String output; String output;
serializeJson(doc, output); serializeJson(doc, output);
sendJsonResponse(request, 200, output);
AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output);
addCorsHeaders(response);
request->send(response);
} }
// --- POST /sweep/stop --- // --- POST /sweep/stop ---
static void handleSweepStop(AsyncWebServerRequest* request) { static void handleSweepStop(AsyncWebServerRequest* request) {
stopSweep(); stopSweep();
sendJsonResponse(request, 200, "{\"status\":\"sweep stopped\"}");
AsyncWebServerResponse* response = request->beginResponse(200, "application/json",
"{\"status\":\"sweep stopped\"}");
addCorsHeaders(response);
request->send(response);
} }
// --- POST /ota --- // --- POST /ota ---
static void handleOTAEnable(AsyncWebServerRequest* request) { static void handleOTAEnable(AsyncWebServerRequest* request) {
enableOTA(); enableOTA();
sendJsonResponse(request, 200, "{\"status\":\"OTA enabled\"}");
AsyncWebServerResponse* response = request->beginResponse(200, "application/json",
"{\"status\":\"OTA enabled\"}");
addCorsHeaders(response);
request->send(response);
} }
// --- GET /wifi/power --- // --- GET /wifi/power ---
@ -203,11 +283,12 @@ static void handleWiFiPowerGet(AsyncWebServerRequest* request) {
doc["tx_power_dbm"] = wifiPowerToDbm(power); doc["tx_power_dbm"] = wifiPowerToDbm(power);
doc["rssi"] = WiFi.RSSI(); doc["rssi"] = WiFi.RSSI();
// Available power levels for reference const int* powers = getValidWifiPowers();
const float* dbms = getValidWifiDbms();
int numLevels = getNumWifiPowerLevels();
JsonArray levels = doc["available_levels"].to<JsonArray>(); JsonArray levels = doc["available_levels"].to<JsonArray>();
const int powers[] = {8, 20, 28, 34, 44, 52, 60, 68, 74, 76, 78}; // quarter-dBm values for (int i = 0; i < numLevels; i++) {
const float dbms[] = {2.0, 5.0, 7.0, 8.5, 11.0, 13.0, 15.0, 17.0, 18.5, 19.0, 19.5};
for (int i = 0; i < 11; i++) {
JsonObject level = levels.add<JsonObject>(); JsonObject level = levels.add<JsonObject>();
level["raw"] = powers[i]; level["raw"] = powers[i];
level["dbm"] = dbms[i]; level["dbm"] = dbms[i];
@ -215,38 +296,49 @@ static void handleWiFiPowerGet(AsyncWebServerRequest* request) {
String output; String output;
serializeJson(doc, output); serializeJson(doc, output);
sendJsonResponse(request, 200, output);
AsyncWebServerResponse* response = request->beginResponse(200, "application/json", output);
addCorsHeaders(response);
request->send(response);
} }
// --- POST /wifi/power --- // --- POST /wifi/power (body handler) ---
static void handleWiFiPowerSet(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { static void handleWiFiPowerSetBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
JsonDocument doc; if (index == 0) resetBody(wifiPowerBody);
DeserializationError error = deserializeJson(doc, data, len); accumulateBody(wifiPowerBody, data, len, index, total);
if (error) { if (!bodyComplete(wifiPowerBody, index, len, total)) return;
AsyncWebServerResponse* response = request->beginResponse(400, "application/json",
"{\"error\":\"Invalid JSON\"}"); if (wifiPowerBody.overflow) {
addCorsHeaders(response); sendJsonError(request, 413, "Request body too large");
request->send(response); return;
}
JsonDocument doc;
DeserializationError error = deserializeJson(doc, wifiPowerBody.data, wifiPowerBody.received);
if (error) {
sendJsonError(request, 400, "Invalid JSON");
return; return;
} }
// Accept either raw value or dBm (will round to nearest)
wifi_power_t newPower; wifi_power_t newPower;
if (doc["tx_power_raw"].is<int>()) { if (doc["tx_power_raw"].is<int>()) {
newPower = (wifi_power_t)doc["tx_power_raw"].as<int>(); int raw = doc["tx_power_raw"].as<int>();
if (!isValidWifiPower(raw)) {
sendJsonError(request, 400, "Invalid tx_power_raw value (use GET /wifi/power for valid levels)");
return;
}
newPower = (wifi_power_t)raw;
} else if (doc["tx_power_dbm"].is<float>()) { } else if (doc["tx_power_dbm"].is<float>()) {
// Convert dBm to quarter-dBm raw value (round to nearest valid level)
float targetDbm = doc["tx_power_dbm"].as<float>(); float targetDbm = doc["tx_power_dbm"].as<float>();
const int powers[] = {8, 20, 28, 34, 44, 52, 60, 68, 74, 76, 78}; if (isnan(targetDbm) || isinf(targetDbm)) {
const float dbms[] = {2.0, 5.0, 7.0, 8.5, 11.0, 13.0, 15.0, 17.0, 18.5, 19.0, 19.5}; sendJsonError(request, 400, "tx_power_dbm must be a finite number");
return;
}
const int* powers = getValidWifiPowers();
const float* dbms = getValidWifiDbms();
int numLevels = getNumWifiPowerLevels();
int closest = 0; int closest = 0;
float minDiff = abs(targetDbm - dbms[0]); float minDiff = fabs(targetDbm - dbms[0]);
for (int i = 1; i < 11; i++) { for (int i = 1; i < numLevels; i++) {
float diff = abs(targetDbm - dbms[i]); float diff = fabs(targetDbm - dbms[i]);
if (diff < minDiff) { if (diff < minDiff) {
minDiff = diff; minDiff = diff;
closest = i; closest = i;
@ -254,10 +346,7 @@ static void handleWiFiPowerSet(AsyncWebServerRequest* request, uint8_t* data, si
} }
newPower = (wifi_power_t)powers[closest]; newPower = (wifi_power_t)powers[closest];
} else { } else {
AsyncWebServerResponse* response = request->beginResponse(400, "application/json", sendJsonError(request, 400, "Must provide tx_power_raw or tx_power_dbm");
"{\"error\":\"Must provide tx_power_raw or tx_power_dbm\"}");
addCorsHeaders(response);
request->send(response);
return; return;
} }
@ -293,23 +382,20 @@ void setupWebServer(Attenuator& atten) {
server.on("/ota", HTTP_POST, handleOTAEnable); server.on("/ota", HTTP_POST, handleOTAEnable);
server.on("/wifi/power", HTTP_GET, handleWiFiPowerGet); server.on("/wifi/power", HTTP_GET, handleWiFiPowerGet);
// Routes with body parsing // Routes with body parsing (accumulate chunks before processing)
server.on("/set", HTTP_POST, [](AsyncWebServerRequest* request) {}, server.on("/set", HTTP_POST, [](AsyncWebServerRequest* request) {},
NULL, handleSet); NULL, handleSetBody);
server.on("/sweep", HTTP_POST, [](AsyncWebServerRequest* request) {}, server.on("/sweep", HTTP_POST, [](AsyncWebServerRequest* request) {},
NULL, handleSweepStart); NULL, handleSweepStartBody);
server.on("/wifi/power", HTTP_POST, [](AsyncWebServerRequest* request) {}, server.on("/wifi/power", HTTP_POST, [](AsyncWebServerRequest* request) {},
NULL, handleWiFiPowerSet); NULL, handleWiFiPowerSetBody);
// Static files from LittleFS (index.html, style.css, app.js, favicon.svg) // Static files from LittleFS (index.html, style.css, app.js, favicon.svg)
server.serveStatic("/", LittleFS, "/").setDefaultFile("index.html"); server.serveStatic("/", LittleFS, "/").setDefaultFile("index.html");
// 404 handler // 404 handler
server.onNotFound([](AsyncWebServerRequest* request) { server.onNotFound([](AsyncWebServerRequest* request) {
AsyncWebServerResponse* response = request->beginResponse(404, "application/json", sendJsonError(request, 404, "Not found");
"{\"error\":\"Not found\"}");
addCorsHeaders(response);
request->send(response);
}); });
server.begin(); server.begin();