hmc472/firmware/src/usb_serial.cpp
Ryan Malloy 4e19882d32 Harden firmware for dual-core concurrency and input validation
Address safety review findings for the dual-interface (WiFi + USB serial)
architecture running on the ESP32-S3's two Xtensa LX7 cores:

- Protect sweep state with std::atomic (acquire/release ordering)
- Add Attenuator::getSnapshot() for consistent multi-field reads
- Add advanceStep()/persistCurrent() to eliminate TOCTOU races
- Switch to StaticSemaphore_t (compile-time mutex, can't fail)
- Accumulate web server POST bodies before parsing (chunked TCP fix)
- Backport USB serial input validation to web server handlers
- Auto-stop sweep on manual set (prevents silent overwrite)
- Validate WiFi TX power against known-good levels
- Add OTA password authentication support
- Check NVS write return values, log failures
- Reset USB serial buffer on reconnect (stale overflow fix)
- Rename sweep.h to app.h (declares more than sweep functions)
2026-02-18 18:43:08 -07:00

272 lines
7.1 KiB
C++

#include "usb_serial.h"
#include <cstdio>
#include <Arduino.h>
#include <ArduinoJson.h>
#include "config.h"
#include "app.h"
// Attenuator controlled through this interface. Assigned by setupUSBSerial();
// handleUSBSerial() returns immediately while it is still null.
static Attenuator* pAtten{nullptr};
// Line-assembly buffer for received bytes. One byte is always kept spare for
// the NUL terminator written before a line is dispatched.
static char rxBuf[USB_SERIAL_BUF_LEN];
// Number of bytes currently held in rxBuf.
static uint16_t rxLen{0};
// True once the current line has exceeded rxBuf; cleared at end of line.
static bool overflow{false};
// USB CDC connection state observed on the previous handleUSBSerial() call.
static bool wasConnected{false};
// --- Response helpers ---
// Marks `doc` as a successful reply ("ok": true, appended after any keys the
// caller already set), serializes it as compact JSON and writes it to USB
// serial as a single newline-terminated line.
static void sendOk(JsonDocument& doc) {
    doc["ok"] = true;
    String line;
    serializeJson(doc, line);  // build the whole line first, then emit once
    Serial.println(line);
}
static void sendError(const char* msg) {
JsonDocument doc;
doc["ok"] = false;
doc["error"] = msg;
String out;
serializeJson(doc, out);
Serial.println(out);
}
// --- Build common status payload ---
// Fills `doc` with the status payload shared by the "status" and "set"
// replies. It includes attenuation in dB, step index, the six per-bit
// states, uptime in seconds, firmware version, and a nested "sweep" object.
// The attenuator fields all come from one getSnapshot() call rather than
// separate per-field reads.
static void buildStatus(JsonDocument& doc) {
    const AttenuatorState snap = pAtten->getSnapshot();
    doc["attenuation_db"] = snap.db;
    doc["step"] = snap.step;

    JsonArray bitList = doc["bits"].to<JsonArray>();
    for (int idx = 0; idx < 6; ++idx) {
        bitList.add(snap.bits[idx]);
    }

    doc["uptime_s"] = millis() / 1000;
    doc["version"] = FW_VERSION;

    JsonObject sweepInfo = doc["sweep"].to<JsonObject>();
    sweepInfo["running"] = isSweeping();
    const bool goingUp = getSweepDirection() > 0;
    sweepInfo["direction"] = goingUp ? "up" : "down";
    sweepInfo["dwell_ms"] = getSweepDwellMs();
}
// --- Command handlers ---
static void cmdIdentify() {
JsonDocument doc;
doc["device"] = "hmc472a-attenuator";
doc["protocol"] = "usb-serial-json-v1";
doc["version"] = FW_VERSION;
JsonArray cmds = doc["commands"].to<JsonArray>();
cmds.add("identify");
cmds.add("status");
cmds.add("config");
cmds.add("set");
cmds.add("sweep");
cmds.add("sweep_stop");
sendOk(doc);
}
static void cmdStatus() {
JsonDocument doc;
buildStatus(doc);
sendOk(doc);
}
static void cmdConfig() {
JsonDocument doc;
doc["version"] = FW_VERSION;
doc["hostname"] = FW_HOSTNAME;
doc["db_min"] = DB_MIN;
doc["db_max"] = DB_MAX;
doc["db_step"] = DB_STEP;
doc["step_min"] = STEP_MIN;
doc["step_max"] = STEP_MAX;
sendOk(doc);
}
static void cmdSet(JsonDocument& req) {
// Auto-stop sweep on manual set (prevents silent overwrite)
if (isSweeping()) {
stopSweep();
}
if (req["db"].is<float>()) {
float db = req["db"].as<float>();
if (isnan(db) || isinf(db)) {
sendError("db must be a finite number");
return;
}
pAtten->setDB(db);
} else if (req["step"].is<int>()) {
int stepVal = req["step"].as<int>();
if (stepVal < STEP_MIN || stepVal > STEP_MAX) {
sendError("step must be 0-63");
return;
}
pAtten->setStep(static_cast<uint8_t>(stepVal));
} else if (req["bits"].is<JsonArray>()) {
JsonArray arr = req["bits"].as<JsonArray>();
if (arr.size() != 6) {
sendError("bits array must have exactly 6 elements");
return;
}
uint8_t bits[6];
for (int i = 0; i < 6; i++) {
if (!arr[i].is<int>()) {
sendError("bits elements must be integers");
return;
}
int val = arr[i].as<int>();
if (val != 0 && val != 1) {
sendError("bits elements must be 0 or 1");
return;
}
bits[i] = val;
}
pAtten->setBits(bits);
} else {
sendError("set requires db, step, or bits field");
return;
}
// Return status after set
JsonDocument doc;
buildStatus(doc);
sendOk(doc);
}
// Handles {"cmd":"sweep", ...} and starts a full-range sweep.
// Optional request fields:
//   "direction"  - the string "down" sweeps downward; any other string sweeps up.
//   "dwell_ms"   - positive integer dwell per step; defaults to
//                  SWEEP_DWELL_MS_DEFAULT when absent. A value that is present
//                  but not a positive integer is rejected with an error.
//   "start"/"stop" - accepted for protocol compatibility only. When either is
//                  given without "direction", the direction is inferred from
//                  them (a missing end defaults to DB_MIN / DB_MAX). The sweep
//                  still covers the full range, and the reply carries a "note".
static void cmdSweep(JsonDocument& req) {
    bool up = true;
    if (req["direction"].is<const char*>()) {
        up = strcmp(req["direction"].as<const char*>(), "down") != 0;
    }

    uint32_t dwellMs = SWEEP_DWELL_MS_DEFAULT;
    if (!req["dwell_ms"].isNull()) {
        // Report an explicit bad value instead of silently falling back to
        // the default, in line with the other handlers' input validation.
        int raw = req["dwell_ms"].is<int>() ? req["dwell_ms"].as<int>() : 0;
        if (raw <= 0) {
            sendError("dwell_ms must be a positive integer");
            return;
        }
        dwellMs = static_cast<uint32_t>(raw);
    }

    // Accept start/stop params for protocol compatibility, but current
    // sweep engine only supports full-range direction+dwell
    bool hasRangeParams = !req["start"].isNull() || !req["stop"].isNull();
    if (hasRangeParams && req["direction"].isNull()) {
        // Infer direction from start/stop; missing ends fall back to the
        // configured attenuation limits rather than hard-coded values.
        // The float casts keep operator| from truncating fractional dB
        // values if the limits are defined as integers.
        float start = req["start"] | static_cast<float>(DB_MIN);
        float stop = req["stop"] | static_cast<float>(DB_MAX);
        up = (stop > start);
    }

    startSweep(up, dwellMs);

    JsonDocument doc;
    JsonObject sweep = doc["sweep"].to<JsonObject>();
    sweep["running"] = true;
    sweep["direction"] = up ? "up" : "down";
    sweep["dwell_ms"] = dwellMs;
    if (hasRangeParams) {
        doc["note"] = "start/stop range not yet supported; using full-range sweep with inferred direction";
    }
    sendOk(doc);
}
static void cmdSweepStop() {
stopSweep();
JsonDocument doc;
JsonObject sweep = doc["sweep"].to<JsonObject>();
sweep["running"] = false;
sendOk(doc);
}
// --- Line dispatch ---
static void processLine(const char* line, uint16_t len) {
JsonDocument req;
DeserializationError err = deserializeJson(req, line, len);
if (err) {
sendError("invalid JSON");
return;
}
if (!req["cmd"].is<const char*>()) {
sendError("missing cmd field");
return;
}
const char* cmd = req["cmd"].as<const char*>();
if (strcmp(cmd, "identify") == 0) {
cmdIdentify();
} else if (strcmp(cmd, "status") == 0) {
cmdStatus();
} else if (strcmp(cmd, "config") == 0) {
cmdConfig();
} else if (strcmp(cmd, "set") == 0) {
cmdSet(req);
} else if (strcmp(cmd, "sweep") == 0) {
cmdSweep(req);
} else if (strcmp(cmd, "sweep_stop") == 0) {
cmdSweepStop();
} else {
sendError("unknown command");
}
}
// --- Public API ---
// Binds the USB serial command interface to `atten` and opens the port at
// USB_SERIAL_BAUD. It also discards any partially assembled line and logs
// the initialisation on Serial0. handleUSBSerial() does nothing until this
// has been called.
void setupUSBSerial(Attenuator& atten) {
    pAtten = &atten;
    Serial.begin(USB_SERIAL_BAUD);
    // Start from an empty, non-overflowed line buffer.
    overflow = false;
    rxLen = 0;
    Serial0.println("[USBSerial] Initialized on native USB CDC");
}
// Polls the USB serial port. Call it repeatedly (e.g. from the main loop)
// after setupUSBSerial().
//
// Bytes are collected into rxBuf until '\n' or '\r'. A completed non-empty
// line is NUL-terminated and passed to processLine(). A line too long for the
// buffer is dropped and answered with an error once its terminator arrives.
// Empty lines, such as the second byte of "\r\n", are skipped. At most one
// line is handled per call; any further bytes are left unread for the next
// call.
void handleUSBSerial() {
    if (!pAtten) return;  // setupUSBSerial() has not been called yet
    // Reset buffer on USB reconnect (prevents stale overflow state
    // from a disconnect mid-line corrupting the first command)
    bool connected = Serial;
    if (connected && !wasConnected) {
        rxLen = 0;
        overflow = false;
    }
    wasConnected = connected;
    while (Serial.available()) {
        char c = static_cast<char>(Serial.read());
        if (c == '\n' || c == '\r') {
            if (overflow) {
                // Line was too long, so discard and report it. The limit is
                // derived from the buffer size (one byte is reserved for the
                // NUL) so the message cannot drift from USB_SERIAL_BUF_LEN.
                char msg[48];
                snprintf(msg, sizeof(msg), "line too long (max %u bytes)",
                         static_cast<unsigned>(USB_SERIAL_BUF_LEN - 1));
                sendError(msg);
                overflow = false;
                rxLen = 0;
                return;  // One dispatch per loop iteration
            }
            if (rxLen > 0) {
                rxBuf[rxLen] = '\0';
                processLine(rxBuf, rxLen);
                rxLen = 0;
                return;  // One dispatch per loop iteration
            }
            // Empty line (bare \n or \r\n second byte) — ignore
            continue;
        }
        if (rxLen < USB_SERIAL_BUF_LEN - 1) {
            rxBuf[rxLen++] = c;
        } else {
            // Keep draining bytes until the terminator so the tail of the
            // oversized line is not parsed as a separate command.
            overflow = true;
        }
    }
}