diff --git a/backend/src/spicebook/engine/schematic.py b/backend/src/spicebook/engine/schematic.py index 429da16..418edb6 100644 --- a/backend/src/spicebook/engine/schematic.py +++ b/backend/src/spicebook/engine/schematic.py @@ -4,6 +4,8 @@ Pipeline: netlist text → parse → component list → node graph → layout """ import logging +import re +import xml.etree.ElementTree as ET from dataclasses import dataclass, field logger = logging.getLogger(__name__) @@ -32,6 +34,28 @@ class ParsedNetlist: models: dict[str, str] = field(default_factory=dict) +@dataclass +class TerminalPath: + """A chain of components traced from an active device terminal.""" + + terminal: str # "collector"/"base"/"emitter" or "drain"/"gate"/"source" + components: list[SpiceComponent] + end_node: str # final node name ("0", "vcc", "out", etc.) + end_type: str # "ground" | "supply" | "open" + + +@dataclass +class ActiveLayout: + """Layout plan for a circuit with one active device.""" + + device: SpiceComponent + device_type: str # "bjt_npn" | "bjt_pnp" | "nfet" | "pfet" + paths: dict[str, list[TerminalPath]] + supply_sources: list[SpiceComponent] + signal_sources: list[SpiceComponent] + unplaced: list[SpiceComponent] + + # ── Ground / Supply Detection ─────────────────────────────────── GROUND_NAMES = {"0", "gnd"} @@ -46,6 +70,28 @@ def _is_supply(node: str) -> bool: return node.lower() in SUPPLY_NAMES +_TRANSIENT_PATTERN = re.compile( + r"\b(?:AC|SIN|PULSE|PWL|EXP|SFFM)\b", re.IGNORECASE +) + + +def _is_supply_source(comp: SpiceComponent) -> bool: + """True if DC-only voltage source between a supply rail and ground.""" + if comp.prefix != "V" or len(comp.nodes) != 2: + return False + if _TRANSIENT_PATTERN.search(comp.value): + return False + n0, n1 = comp.nodes + return (_is_supply(n0) and _is_ground(n1)) or (_is_ground(n0) and _is_supply(n1)) + + +def _is_signal_source(comp: SpiceComponent) -> bool: + """True if V/I source with AC or transient specification.""" + if comp.prefix not in ("V", "I"): + return False + return bool(_TRANSIENT_PATTERN.search(comp.value)) + + # ── SPICE Netlist Parser ─────────────────────────────────────── @@ -345,6 +391,141 @@ def _find_main_loop( return path if len(path) > 1 else None +# ── Connected Layout: Active-Device Centered ───────────────────── + + +def _classify_end(node: str) -> str: + """Classify a terminal endpoint node.""" + if _is_ground(node): + return "ground" + if _is_supply(node): + return "supply" + return "open" + + +def _trace_paths_from_terminal( + terminal_node: str, + terminal_name: str, + node_map: dict[str, list[tuple[SpiceComponent, int]]], + exclude: set[str], +) -> list[TerminalPath]: + """Trace 2-terminal component chains from an active device terminal. + + Uses a shared seen set across all paths from this terminal to prevent + duplicate placement when paths share components (diamond topologies). + """ + paths: list[TerminalPath] = [] + seen: set[str] = set() + for comp, idx in node_map.get(terminal_node, []): + if comp.name in exclude or comp.name in seen or len(comp.nodes) != 2: + continue + chain = [comp] + visited = exclude | seen | {comp.name} + current = comp.nodes[1 - idx] + while True: + if _is_ground(current) or _is_supply(current): + break + candidates = [ + (c, i) + for c, i in node_map.get(current, []) + if c.name not in visited and len(c.nodes) == 2 + ] + if len(candidates) != 1: + break + next_comp, next_idx = candidates[0] + visited.add(next_comp.name) + chain.append(next_comp) + current = next_comp.nodes[1 - next_idx] + seen.update(c.name for c in chain) + paths.append( + TerminalPath(terminal_name, chain, current, _classify_end(current)) + ) + return paths + + +def _plan_active_layout( + parsed: ParsedNetlist, + node_map: dict[str, list[tuple[SpiceComponent, int]]], +) -> ActiveLayout | None: + """Build layout plan for circuits with exactly one active device.""" + devices = [c for c in parsed.components if c.prefix in ("Q", "M")] + if len(devices) != 1: + return None + device = devices[0] + if len(device.nodes) < 3: + return None + + model_type = parsed.models.get(device.model, "").upper() + if device.prefix == "Q": + device_type = "bjt_pnp" if "PNP" in model_type else "bjt_npn" + else: + device_type = "pfet" if ("PMOS" in model_type or model_type == "P") else "nfet" + + supply_sources = [c for c in parsed.components if _is_supply_source(c)] + signal_sources = [c for c in parsed.components if _is_signal_source(c)] + + if device.prefix == "Q": + terminals = { + "collector": device.nodes[0], + "base": device.nodes[1], + "emitter": device.nodes[2], + } + else: + terminals = { + "drain": device.nodes[0], + "gate": device.nodes[1], + "source": device.nodes[2], + } + + exclude = {device.name} | {s.name for s in supply_sources} + paths: dict[str, list[TerminalPath]] = {} + placed = set(exclude) + + for tname, tnode in terminals.items(): + tpaths = _trace_paths_from_terminal(tnode, tname, node_map, exclude) + paths[tname] = tpaths + for p in tpaths: + for c in p.components: + placed.add(c.name) + + unplaced = [c for c in parsed.components if c.name not in placed] + if unplaced: + logger.debug( + "Connected layout: %d unplaced components: %s", + len(unplaced), + [c.name for c in unplaced], + ) + + return ActiveLayout( + device=device, + device_type=device_type, + paths=paths, + supply_sources=supply_sources, + signal_sources=signal_sources, + unplaced=unplaced, + ) + + +def _path_style( + term_name: str, + path: TerminalPath, + has_signal: bool, + is_inverted: bool, + supply_term: str, + input_term: str, +) -> str: + """Classify a path's drawing style: up, down, input, or output.""" + if path.end_type == "supply": + return "up" if not is_inverted else "down" + if term_name == input_term and has_signal and len(path.components) > 1: + return "input" + if term_name == supply_term and path.end_type == "ground" and len(path.components) > 1: + return "output" + if path.end_type == "ground": + return "down" if not is_inverted else "up" + return "down" if not is_inverted else "up" + + # ── Value Formatting ─────────────────────────────────────────── @@ -591,18 +772,284 @@ def _label_multiterminal(d, placed, comp: SpiceComponent) -> None: ) +# ── Connected Layout Renderer ──────────────────────────────── + + +def _draw_vert_chain(d, parsed, start, components, going_up, end_type, end_node): + """Draw a chain of components vertically, terminated by Vdd or Ground.""" + import schemdraw.elements as elm + + direction = "up" if going_up else "down" + for i, comp in enumerate(components): + elem = _get_element(comp, parsed.models) + if i == 0 and start is not None: + elem = elem.at(start) + elem = getattr(elem, direction)() + elem = elem.label(_component_label(comp), loc="left") + d.add(elem) + + if end_type == "ground": + d.add(elm.Ground()) + elif end_type == "supply": + d.add(elm.Vdd().label(end_node.upper())) + else: + d.add(elm.Dot(open=True).label(end_node, fontsize=9)) + + +def _draw_horiz_then_down(d, parsed, start, path, going_right): + """Draw components horizontally, with the last turning down to ground.""" + import schemdraw.elements as elm + + h_dir = "right" if going_right else "left" + comps = path.components + + for i, comp in enumerate(comps): + elem = _get_element(comp, parsed.models) + if i == 0 and start is not None: + elem = elem.at(start) + + is_last = i == len(comps) - 1 + if is_last and len(comps) > 1: + # Turn downward at the bend + d.push() + d.add(elem.down().label(_component_label(comp), loc="right")) + if path.end_type == "ground": + d.add(elm.Ground()) + elif path.end_type == "supply": + d.add(elm.Vdd().label(path.end_node.upper())) + else: + d.add(elm.Dot(open=True).label(path.end_node, fontsize=9)) + d.pop() + # Label the junction node at the bend + prev = comps[-2] + shared = set(prev.nodes) & set(comp.nodes) + for node in shared: + if not _is_ground(node) and not _is_supply(node): + loc = "right" if going_right else "left" + d.add( + elm.Dot(open=True).label(node, loc=loc, fontsize=9) + ) + break + else: + d.add(getattr(elem, h_dir)().label(_component_label(comp))) + + # Single-component horizontal path ending at ground/supply/open + if len(comps) == 1: + if path.end_type == "ground": + d.add(elm.Ground()) + elif path.end_type == "supply": + d.add(elm.Vdd().label(path.end_node.upper())) + else: + d.add(elm.Dot(open=True).label(path.end_node, fontsize=9)) + + +def _render_connected(parsed: ParsedNetlist, layout: ActiveLayout) -> str: + """Render a single-active-device circuit with connected wires. + + Places the transistor/FET at the center and draws terminal paths + using push/pop for branching at junction points. + """ + import schemdraw + import schemdraw.elements as elm + + d = schemdraw.Drawing(fontsize=12) + signal_names = {s.name for s in layout.signal_sources} + is_inverted = layout.device_type in ("bjt_pnp", "pfet") + + # Place active device at the center + if layout.device_type.startswith("bjt"): + dev_elem = elm.BjtPnp() if is_inverted else elm.BjtNpn() + q = d.add(dev_elem.label(_component_label(layout.device))) + anchors = { + "collector": q.collector, + "base": q.base, + "emitter": q.emitter, + } + supply_term = "emitter" if is_inverted else "collector" + input_term = "base" + else: + dev_elem = elm.PFet() if is_inverted else elm.NFet() + q = d.add(dev_elem.label(_component_label(layout.device))) + anchors = { + "drain": q.drain, + "gate": q.gate, + "source": q.source, + } + supply_term = "source" if is_inverted else "drain" + input_term = "gate" + + for term_name, term_paths in layout.paths.items(): + if not term_paths: + continue + anchor = anchors[term_name] + + # Classify each path's drawing style + up_paths: list[TerminalPath] = [] + down_paths: list[TerminalPath] = [] + input_paths: list[TerminalPath] = [] + output_paths: list[TerminalPath] = [] + + for p in term_paths: + has_sig = any(c.name in signal_names for c in p.components) + style = _path_style( + term_name, p, has_sig, is_inverted, supply_term, input_term + ) + if style == "up": + up_paths.append(p) + elif style == "down": + down_paths.append(p) + elif style == "input": + input_paths.append(p) + else: + output_paths.append(p) + + total = len(up_paths) + len(down_paths) + len(input_paths) + len(output_paths) + + # Junction wire for input terminal when paths branch + if term_name == input_term and total > 1: + junc = d.add(elm.Line().at(anchor).left(1)) + draw_from = junc.end + else: + draw_from = anchor + + # Vertical-up paths (toward supply rail) + for i, p in enumerate(up_paths): + d.push() + if i > 0: + d.add(elm.Line().at(draw_from).right(1.5 * i)) + _draw_vert_chain( + d, parsed, None, p.components, True, p.end_type, p.end_node + ) + else: + _draw_vert_chain( + d, parsed, draw_from, p.components, True, p.end_type, p.end_node + ) + d.pop() + + # Vertical-down paths (toward ground) + for i, p in enumerate(down_paths): + d.push() + if i > 0: + d.add(elm.Line().at(draw_from).right(1.5 * i)) + _draw_vert_chain( + d, parsed, None, p.components, False, p.end_type, p.end_node + ) + else: + _draw_vert_chain( + d, parsed, draw_from, p.components, False, p.end_type, p.end_node + ) + d.pop() + + # Output paths (right then down) + for p in output_paths: + d.push() + _draw_horiz_then_down(d, parsed, draw_from, p, going_right=True) + d.pop() + + # Input paths (left then down) + for p in input_paths: + d.push() + _draw_horiz_then_down(d, parsed, draw_from, p, going_right=False) + d.pop() + + # Junction dot at branch point + if total > 1: + d.add(elm.Dot().at(draw_from)) + + return d.get_imagedata("svg").decode() + + +# ── SVG Annotation ──────────────────────────────────────────── + +# Prefixes whose values are numeric/SI-unit and make sense to edit inline +_EDITABLE_PREFIXES = {"R", "C", "L", "V", "I", "E", "G", "F", "H", "B", "S"} + +# SVG namespace +_SVG_NS = "http://www.w3.org/2000/svg" + + +def annotate_svg(svg_str: str, parsed: ParsedNetlist) -> str: + """Add data attributes to SVG text elements for interactive editing. + + SchemDraw emits elements with two children: + tspan[0] = component name (e.g., "R1") + tspan[1] = display value (e.g., "1k") + + This function matches those names against parsed components and adds: + - data-component="R1" on the parent + - data-editable="true" + data-raw-value="1k" on the value tspan + (only for components with numeric/editable values) + """ + # Register SVG namespace to avoid ns0: prefix pollution in output + ET.register_namespace("", _SVG_NS) + + try: + root = ET.fromstring(svg_str) + except ET.ParseError: + logger.warning("Failed to parse SVG for annotation") + return svg_str + + # Build lookup: uppercase component name → SpiceComponent + comp_lookup: dict[str, SpiceComponent] = { + c.name.upper(): c for c in parsed.components + } + + ns = {"svg": _SVG_NS} + for text_el in root.findall(".//svg:text", ns): + tspans = text_el.findall("svg:tspan", ns) + if len(tspans) != 2: + continue + + name_text = (tspans[0].text or "").strip() + + if not name_text: + continue + + comp = comp_lookup.get(name_text.upper()) + if comp is None: + continue + + # Annotate the parent with the component name + text_el.set("data-component", comp.name) + + # Only mark as editable if this component type has a tuneable value + # and the value is not just a model name + if comp.prefix in _EDITABLE_PREFIXES and comp.value: + tspans[1].set("data-editable", "true") + tspans[1].set("data-raw-value", comp.value) + + return ET.tostring(root, encoding="unicode") + + +def build_component_map(parsed: ParsedNetlist) -> dict[str, str]: + """Build a name→raw_value map for components with editable values.""" + result: dict[str, str] = {} + for comp in parsed.components: + if comp.prefix in _EDITABLE_PREFIXES and comp.value: + result[comp.name] = comp.value + return result + + # ── Public API ───────────────────────────────────────────────── -def netlist_to_svg(netlist_text: str) -> str: - """Convert a SPICE netlist to an SVG schematic diagram. +@dataclass +class SchematicResult: + """Result of schematic generation with SVG and component metadata.""" + + svg: str + component_map: dict[str, str] + + +def netlist_to_svg(netlist_text: str) -> SchematicResult: + """Convert a SPICE netlist to an annotated SVG schematic diagram. Tries a clean loop layout for simple circuits (<=12 two-terminal components with a clear main path), falls back to a labeled grid for complex circuits with active devices. Returns: - SVG string. + SchematicResult with annotated SVG and component value map. Raises: ValueError: If no components could be parsed from the netlist. @@ -624,13 +1071,30 @@ def netlist_to_svg(netlist_text: str) -> str: node_map = _build_node_map(parsed.components) # Try loop layout for simple circuits with a clear main path + svg: str | None = None two_terminal_only = all(len(c.nodes) == 2 for c in parsed.components) if two_terminal_only and len(parsed.components) <= 12: loop = _find_main_loop(parsed.components, node_map) if loop and len(loop) >= 2: try: - return _render_loop(parsed, loop) + svg = _render_loop(parsed, loop) except Exception as exc: logger.warning("Loop layout failed, using grid: %s", exc) - return _render_grid(parsed) + # Try connected layout for single-active-device circuits + if svg is None: + layout = _plan_active_layout(parsed, node_map) + if layout is not None: + try: + svg = _render_connected(parsed, layout) + except Exception as exc: + logger.warning("Connected layout failed, using grid: %s", exc) + + if svg is None: + svg = _render_grid(parsed) + + # Post-process: annotate SVG with data attributes for interactivity + svg = annotate_svg(svg, parsed) + component_map = build_component_map(parsed) + + return SchematicResult(svg=svg, component_map=component_map)