diff --git a/backend/src/spicebook/engine/schematic.py b/backend/src/spicebook/engine/schematic.py index bcf0fb8..2f41ff6 100644 --- a/backend/src/spicebook/engine/schematic.py +++ b/backend/src/spicebook/engine/schematic.py @@ -56,6 +56,16 @@ class ActiveLayout: unplaced: list[SpiceComponent] +@dataclass +class PlacedComponent: + """A component placed in the grid with recorded terminal positions.""" + + comp: SpiceComponent + element: object # SchemDraw placed element + column: int + terminal_positions: dict[int, tuple[float, float]] = field(default_factory=dict) + + # ── Ground / Supply Detection ─────────────────────────────────── GROUND_NAMES = {"0", "gnd"} @@ -678,50 +688,396 @@ def _render_loop(parsed: ParsedNetlist, loop: list[SpiceComponent]) -> str: return d.get_imagedata("svg").decode() -# ── Grid Layout Renderer ────────────────────────────────────── +# ── Grid Layout: Constants & Helpers ───────────────────────── + +# Vertical tier positions (schemdraw y-axis: up is positive) +_GRID_SUPPLY_Y = 0.0 +_GRID_UPPER_Y = -3.0 +_GRID_MID_Y = -5.5 +_GRID_LOWER_Y = -8.0 +_GRID_GROUND_Y = -11.0 +_GRID_X_SPACING = 6.0 +_GRID_MIN_COMP_LEN = 2.0 # minimum component length to avoid zero-height placement + + +def _classify_node_tiers( + node_map: dict[str, list[tuple[SpiceComponent, int]]], +) -> dict[str, float]: + """Assign vertical Y positions to nodes via BFS distance from ground/supply. + + Ground and supply nodes are pinned to fixed tiers. Signal nodes get + Y-positions blended between supply (top) and ground (bottom) based on + their BFS hop distance from each rail, then snapped to the nearest + discrete tier for clean vertical alignment. + """ + from collections import deque + + tiers: dict[str, float] = {} + snap_levels = [_GRID_SUPPLY_Y, _GRID_UPPER_Y, _GRID_MID_Y, _GRID_LOWER_Y, _GRID_GROUND_Y] + + # Pin ground and supply nodes + for node in node_map: + if _is_ground(node): + tiers[node] = _GRID_GROUND_Y + elif _is_supply(node): + tiers[node] = _GRID_SUPPLY_Y + + # BFS from ground nodes + dist_from_gnd: dict[str, int] = {} + queue: deque[tuple[str, int]] = deque() + for node in node_map: + if _is_ground(node): + dist_from_gnd[node] = 0 + queue.append((node, 0)) + while queue: + current, dist = queue.popleft() + for comp, idx in node_map.get(current, []): + for i, other_node in enumerate(comp.nodes): + if i != idx and other_node not in dist_from_gnd: + dist_from_gnd[other_node] = dist + 1 + queue.append((other_node, dist + 1)) + + # BFS from supply nodes + dist_from_sup: dict[str, int] = {} + queue = deque() + for node in node_map: + if _is_supply(node): + dist_from_sup[node] = 0 + queue.append((node, 0)) + while queue: + current, dist = queue.popleft() + for comp, idx in node_map.get(current, []): + for i, other_node in enumerate(comp.nodes): + if i != idx and other_node not in dist_from_sup: + dist_from_sup[other_node] = dist + 1 + queue.append((other_node, dist + 1)) + + # Assign signal nodes: blend by relative BFS distance, snap to tier + for node in node_map: + if node in tiers: + continue + dg = dist_from_gnd.get(node, 999) + ds = dist_from_sup.get(node, 999) + if ds == 999 and dg == 999: + continuous_y = _GRID_MID_Y + elif ds == 999: + continuous_y = _GRID_LOWER_Y + elif dg == 999: + continuous_y = _GRID_UPPER_Y + else: + ratio = ds / (ds + dg) + continuous_y = _GRID_SUPPLY_Y + ratio * (_GRID_GROUND_Y - _GRID_SUPPLY_Y) + tiers[node] = min(snap_levels, key=lambda t: abs(t - continuous_y)) + + return tiers + + +def _assign_columns( + components: list[SpiceComponent], + node_map: dict[str, list[tuple[SpiceComponent, int]]], +) -> list[SpiceComponent]: + """Order placeable components left-to-right via BFS through signal nodes. + + Supply sources (DC V-source between supply rail and ground) are filtered + out — they become rail symbols instead of drawn components. + """ + from collections import deque + + placeable = [c for c in components if not _is_supply_source(c)] + if not placeable: + return [] + + # Build component adjacency through shared signal nodes + comp_names = {c.name for c in placeable} + comp_lookup = {c.name: c for c in placeable} + adj: dict[str, set[str]] = {c.name: set() for c in placeable} + + for node, connections in node_map.items(): + if _is_ground(node) or _is_supply(node): + continue + node_comps = [c.name for c, _ in connections if c.name in comp_names] + for i, a in enumerate(node_comps): + for b in node_comps[i + 1:]: + adj[a].add(b) + adj[b].add(a) + + # BFS from first voltage source (or first component if none) + start = next((c for c in placeable if c.prefix == "V"), None) + if start is None: + start = placeable[0] + + ordered: list[str] = [] + visited: set[str] = set() + queue: deque[str] = deque([start.name]) + visited.add(start.name) + + while queue: + name = queue.popleft() + ordered.append(name) + for neighbor in sorted(adj.get(name, [])): + if neighbor not in visited: + visited.add(neighbor) + queue.append(neighbor) + + # Append any disconnected components not reached by BFS + for c in placeable: + if c.name not in visited: + ordered.append(c.name) + + return [comp_lookup[name] for name in ordered] + + +def _get_terminal_positions( + placed_elem: object, + comp: SpiceComponent, +) -> dict[int, tuple[float, float]]: + """Extract (x, y) terminal positions from a placed SchemDraw element.""" + positions: dict[int, tuple[float, float]] = {} + + if comp.prefix in ("Q", "M") and len(comp.nodes) >= 3: + anchors = ( + ["collector", "base", "emitter"] + if comp.prefix == "Q" + else ["drain", "gate", "source"] + ) + for i, anchor_name in enumerate(anchors): + if hasattr(placed_elem, anchor_name): + pos = getattr(placed_elem, anchor_name) + positions[i] = (float(pos[0]), float(pos[1])) + else: + if hasattr(placed_elem, "start"): + positions[0] = (float(placed_elem.start[0]), float(placed_elem.start[1])) + if hasattr(placed_elem, "end"): + idx = min(1, len(comp.nodes) - 1) + positions[idx] = (float(placed_elem.end[0]), float(placed_elem.end[1])) + + return positions + + +def _find_supply_value( + components: list[SpiceComponent], + supply_node: str, +) -> tuple[str, str]: + """Find the DC voltage value for a supply rail. Returns (source_name, formatted_value).""" + for comp in components: + if comp.prefix == "V" and len(comp.nodes) == 2: + n0, n1 = comp.nodes + if (n0.lower() == supply_node.lower() and _is_ground(n1)) or \ + (n1.lower() == supply_node.lower() and _is_ground(n0)): + return (comp.name, _format_value(comp.value)) + return ("", "") + + +def _draw_ground_rail(d, positions: list[tuple[float, float]]) -> None: + """Draw horizontal ground bus with vertical stubs and Ground symbol.""" + import schemdraw.elements as elm + + if not positions: + return + + sorted_pos = sorted(positions, key=lambda p: p[0]) + + # Horizontal bus line across all ground-connected terminals + if len(sorted_pos) > 1: + d.add(elm.Line().at(sorted_pos[0]).to(sorted_pos[-1])) + + # Vertical stubs from each terminal down to the bus + for px, py in sorted_pos: + if abs(py - _GRID_GROUND_Y) > 0.1: + d.add(elm.Line().at((px, py)).to((px, _GRID_GROUND_Y))) + + # Ground symbol at the bus midpoint + mid_x = sum(p[0] for p in sorted_pos) / len(sorted_pos) + d.add(elm.Ground().at((mid_x, _GRID_GROUND_Y))) + + +def _draw_supply_rail( + d, + positions: list[tuple[float, float]], + label: str, + value: str, +) -> None: + """Draw horizontal supply bus with vertical stubs and Vdd symbol.""" + import schemdraw.elements as elm + + if not positions: + return + + sorted_pos = sorted(positions, key=lambda p: p[0]) + + if len(sorted_pos) > 1: + d.add(elm.Line().at(sorted_pos[0]).to(sorted_pos[-1])) + + for px, py in sorted_pos: + if abs(py - _GRID_SUPPLY_Y) > 0.1: + d.add(elm.Line().at((px, py)).to((px, _GRID_SUPPLY_Y))) + + mid_x = sum(p[0] for p in sorted_pos) / len(sorted_pos) + rail_label = label.upper() + if value: + rail_label += f" {value}" + d.add(elm.Vdd().at((mid_x, _GRID_SUPPLY_Y)).label(rail_label)) + + +def _draw_signal_wire( + d, + pos_a: tuple[float, float], + pos_b: tuple[float, float], +) -> None: + """Draw an L-shaped or straight wire between two terminal positions.""" + import schemdraw.elements as elm + + x1, y1 = pos_a + x2, y2 = pos_b + + # Skip near-zero-length wires + if abs(x1 - x2) < 0.1 and abs(y1 - y2) < 0.1: + return + + if abs(x1 - x2) < 0.1 or abs(y1 - y2) < 0.1: + # Straight wire (aligned on one axis) + d.add(elm.Line().at((x1, y1)).to((x2, y2))) + else: + # L-shaped: horizontal first, then vertical + d.add(elm.Line().at((x1, y1)).to((x2, y1))) + d.add(elm.Line().at((x2, y1)).to((x2, y2))) + + +def _draw_star_wires( + d, + positions: list[tuple[float, float]], +) -> None: + """Draw star-topology wiring with junction dot for 3+ connections.""" + import schemdraw.elements as elm + + if len(positions) < 2: + return + + # Junction at the median position (reduces total wire length) + xs = sorted(p[0] for p in positions) + ys = sorted(p[1] for p in positions) + jx = xs[len(xs) // 2] + jy = ys[len(ys) // 2] + + for px, py in positions: + if abs(px - jx) < 0.1 and abs(py - jy) < 0.1: + continue + _draw_signal_wire(d, (px, py), (jx, jy)) + + d.add(elm.Dot().at((jx, jy))) + + +# ── Grid Layout Renderer ──────────────────────────────────── def _render_grid(parsed: ParsedNetlist) -> str: - """Render components in a labeled grid layout. + """Render components in a wire-connected grid layout. - Used for complex circuits where topological layout isn't feasible. - Components are arranged in columns by type with terminal node labels. + Places components vertically between BFS-classified node tiers, then + routes wires between terminals sharing the same net: power/ground get + horizontal bus rails with symbols, 2-connection signals get L-shaped + wires, and 3+ connections get star wiring with junction dots. """ import schemdraw + import schemdraw.elements as elm + + node_map = _build_node_map(parsed.components) + tiers = _classify_node_tiers(node_map) + ordered = _assign_columns(parsed.components, node_map) + + if not ordered: + d = schemdraw.Drawing(fontsize=11) + d.add(elm.Label().label("(empty circuit)")) + return d.get_imagedata("svg").decode() d = schemdraw.Drawing(fontsize=11) - # Group by role for logical ordering - sources = [c for c in parsed.components if c.prefix in ("V", "I")] - passives = [c for c in parsed.components if c.prefix in ("R", "C", "L")] - active = [c for c in parsed.components if c.prefix in ("D", "Q", "M")] - other = [ - c for c in parsed.components - if c.prefix not in ("V", "I", "R", "C", "L", "D", "Q", "M") - ] - ordered = sources + passives + active + other - - cols = min(4, max(2, len(ordered))) - x_spacing = 8 - y_spacing = 5 - - for i, comp in enumerate(ordered): - row = i // cols - col = i % cols - x = col * x_spacing - y = -row * y_spacing + # ── Phase 1: Place components vertically between node tiers ── + placed_list: list[PlacedComponent] = [] + for col, comp in enumerate(ordered): + x = col * _GRID_X_SPACING elem = _get_element(comp, parsed.models) - # 3+ terminal devices (BJT, MOSFET) need special handling if comp.prefix in ("Q", "M") and len(comp.nodes) >= 3: - placed = d.add(elem.at((x, y)).label(_component_label(comp))) - _label_multiterminal(d, placed, comp) - else: - placed = d.add( - elem.at((x, y)).right().label(_component_label(comp)) + # Multi-terminal device: center between top/bottom node tiers + node_ys = [tiers.get(n, _GRID_MID_Y) for n in comp.nodes[:3]] + center_y = (max(node_ys) + min(node_ys)) / 2 + placed_elem = d.add( + elem.at((x, center_y)).label(_component_label(comp)) + ) + elif len(comp.nodes) >= 2: + # 2-terminal: orient vertically between the two node tiers + y0 = tiers.get(comp.nodes[0], _GRID_UPPER_Y) + y1 = tiers.get(comp.nodes[1], _GRID_LOWER_Y) + length = max(abs(y0 - y1), _GRID_MIN_COMP_LEN) + + if y0 >= y1: + # node[0] higher → go down so start=node[0], end=node[1] + placed_elem = d.add( + elem.at((x, y0)).down().length(length) + .label(_component_label(comp), loc="right") + ) + else: + # node[0] lower → go up so start=node[0], end=node[1] + placed_elem = d.add( + elem.at((x, y0)).up().length(length) + .label(_component_label(comp), loc="right") + ) + else: + # Fallback: single-node or unusual component + placed_elem = d.add( + elem.at((x, _GRID_MID_Y)).down().length(_GRID_MIN_COMP_LEN) + .label(_component_label(comp), loc="right") + ) + + positions = _get_terminal_positions(placed_elem, comp) + placed_list.append(PlacedComponent( + comp=comp, element=placed_elem, column=col, + terminal_positions=positions, + )) + + # ── Phase 2: Collect terminal positions grouped by node name ── + node_positions: dict[str, list[tuple[float, float]]] = {} + for pc in placed_list: + for i, node in enumerate(pc.comp.nodes): + if i in pc.terminal_positions: + node_positions.setdefault(node, []).append( + pc.terminal_positions[i] + ) + + # ── Phase 3: Route wires by node type ──────────────────────── + for node, positions in node_positions.items(): + if len(positions) < 2: + # Single connection: open dot with node label + if positions and not _is_ground(node) and not _is_supply(node): + px, py = positions[0] + d.add( + elm.Dot(open=True).at((px, py)) + .label(node, loc="right", fontsize=9) + ) + continue + + if _is_ground(node): + _draw_ground_rail(d, positions) + elif _is_supply(node): + _, src_val = _find_supply_value(parsed.components, node) + _draw_supply_rail(d, positions, node, src_val) + elif len(positions) == 2: + _draw_signal_wire(d, positions[0], positions[1]) + # Label at the midpoint of the wire + mx = (positions[0][0] + positions[1][0]) / 2 + my = (positions[0][1] + positions[1][1]) / 2 + d.add(elm.Label().at((mx, my)).label(node, loc="right", fontsize=9)) + else: + # 3+ connections: star wiring with junction dot + _draw_star_wires(d, positions) + xs = sorted(p[0] for p in positions) + ys = sorted(p[1] for p in positions) + d.add( + elm.Label().at((xs[len(xs) // 2], ys[len(ys) // 2])) + .label(node, loc="right", fontsize=9) ) - _label_two_terminal(d, placed, comp) return d.get_imagedata("svg").decode() diff --git a/frontend/src/components/notebook/cells/SpiceCell.tsx b/frontend/src/components/notebook/cells/SpiceCell.tsx index c62ce1e..7538ee0 100644 --- a/frontend/src/components/notebook/cells/SpiceCell.tsx +++ b/frontend/src/components/notebook/cells/SpiceCell.tsx @@ -72,6 +72,15 @@ export function SpiceCell({ cell, isFirst, isLast }: SpiceCellProps) { [cell.id, cell.source, updateCellSource], ); + // Auto-generate schematic on first render if none exists + const hasAutoGenerated = useRef(false); + useEffect(() => { + if (schematicSvg || hasAutoGenerated.current) return; + if (!cell.source.trim() || cell.source.trim() === '* SPICE Netlist\n\nR1 in out 1k\nV1 in 0 DC 5\n\n.op\n.end') return; + hasAutoGenerated.current = true; + generateSchematic(cell.id); + }, [cell.id, cell.source, schematicSvg, generateSchematic]); + // Debounced auto-redraw: regenerate schematic 800ms after source changes. // Track the source at last generation to avoid retriggering from our own SVG updates. const lastGeneratedSource = useRef(null);