From aba6de692ea9ff5b0990598d11cfa42685669b0f Mon Sep 17 00:00:00 2001 From: Andrew Butt Date: Fri, 7 Nov 2025 15:34:36 -0800 Subject: [PATCH] Automate placement in ArrayBuilder Signed-off-by: Andrew Butt --- .../design/tools/ArrayBuilder.java | 300 ++++++++++++-- .../design/tools/ArrayNetlistGraph.java | 378 ++++++++++++++++++ 2 files changed, 643 insertions(+), 35 deletions(-) create mode 100644 src/com/xilinx/rapidwright/design/tools/ArrayNetlistGraph.java diff --git a/src/com/xilinx/rapidwright/design/tools/ArrayBuilder.java b/src/com/xilinx/rapidwright/design/tools/ArrayBuilder.java index 37758db8c..418a1faa9 100644 --- a/src/com/xilinx/rapidwright/design/tools/ArrayBuilder.java +++ b/src/com/xilinx/rapidwright/design/tools/ArrayBuilder.java @@ -37,18 +37,22 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; import com.xilinx.rapidwright.design.Cell; import com.xilinx.rapidwright.design.ClockTools; import com.xilinx.rapidwright.design.Design; +import com.xilinx.rapidwright.design.DesignTools; import com.xilinx.rapidwright.design.Module; import com.xilinx.rapidwright.design.ModuleInst; import com.xilinx.rapidwright.design.Net; import com.xilinx.rapidwright.design.NetTools; import com.xilinx.rapidwright.design.NetType; +import com.xilinx.rapidwright.design.RelocatableTileRectangle; import com.xilinx.rapidwright.design.SiteInst; import com.xilinx.rapidwright.design.SitePinInst; import com.xilinx.rapidwright.design.Unisim; @@ -62,7 +66,6 @@ import com.xilinx.rapidwright.device.PartNameTools; import com.xilinx.rapidwright.device.Series; import com.xilinx.rapidwright.device.Site; -import com.xilinx.rapidwright.device.SiteTypeEnum; import com.xilinx.rapidwright.device.Tile; import com.xilinx.rapidwright.edif.EDIFCell; import com.xilinx.rapidwright.edif.EDIFCellInst; @@ -75,6 +78,7 @@ import com.xilinx.rapidwright.edif.EDIFPortInst; import com.xilinx.rapidwright.edif.EDIFTools; import com.xilinx.rapidwright.edif.EDIFValueType; +import com.xilinx.rapidwright.rwroute.PartialRouter; import com.xilinx.rapidwright.tests.CodePerfTracker; import com.xilinx.rapidwright.util.FileTools; import com.xilinx.rapidwright.util.MessageGenerator; @@ -110,10 +114,13 @@ public class ArrayBuilder { private static final List SKIP_IMPL_OPTS = Arrays.asList("k", "skip-impl"); private static final List LIMIT_INSTS_OPTS = Arrays.asList("l", "limit-inst-count"); private static final List TOP_LEVEL_DESIGN_OPTS = Arrays.asList("t", "top-design"); + private static final List EXACT_PLACEMENT_OPTS = Collections.singletonList("exact-placement"); private static final List WRITE_PLACEMENT_OPTS = Collections.singletonList("write-placement"); private static final List PLACEMENT_FILE_OPTS = Collections.singletonList("read-placement"); private static final List PLACEMENT_GRID_OPTS = Collections.singletonList("write-placement-grid"); private static final List OUT_OF_CONTEXT_OPTS = Collections.singletonList("out-of-context"); + private static final List UNROUTE_STATIC_NETS_OPTS = Collections.singletonList("unroute-static-nets"); + private static final List ROUTE_CLOCK_OPTS = Collections.singletonList("route-clk"); private Design design; @@ -143,6 +150,14 @@ public class ArrayBuilder { private boolean outOfContext; + private ArrayNetlistGraph condensedGraph; + + private boolean exactPlacement; + + private boolean unrouteStaticNets; + + private boolean routeClock; + public static final double DEFAULT_CLK_PERIOD_TARGET = 2.0; private OptionParser createOptionParser() { @@ -166,7 +181,10 @@ private OptionParser createOptionParser() { acceptsAll(PLACEMENT_FILE_OPTS, "Use placement specified in file").withRequiredArg(); acceptsAll(PLACEMENT_GRID_OPTS, "Write grid of possible placement locations to specified file").withRequiredArg(); acceptsAll(TOP_LEVEL_DESIGN_OPTS, "Top level design with blackboxes/kernel insts").withRequiredArg(); + acceptsAll(EXACT_PLACEMENT_OPTS, "Use exact module overlap calculation instead of the faster bounding-box method"); acceptsAll(OUT_OF_CONTEXT_OPTS, "Specifies that the array will be compiled out of context"); + acceptsAll(UNROUTE_STATIC_NETS_OPTS, "Unroute static (GND/VCC) nets to potentially help with routability"); + acceptsAll(ROUTE_CLOCK_OPTS, "Route clock using RWRoute"); acceptsAll(HELP_OPTS, "Print this help message").forHelp(); } }; @@ -292,7 +310,7 @@ public void setOutputPlacementLocsFileName(String outputPlacementLocsFileName) { this.outputPlacementLocsFileName = outputPlacementLocsFileName; } - public boolean getOutOfContext() { + public boolean isOutOfContext() { return outOfContext; } @@ -300,11 +318,46 @@ public void setOutOfContext(boolean outOfContext) { this.outOfContext = outOfContext; } + public ArrayNetlistGraph getCondensedGraph() { + return condensedGraph; + } + + public void setCondensedGraph(ArrayNetlistGraph condensedGraph) { + this.condensedGraph = condensedGraph; + } + + public boolean isExactPlacement() { + return exactPlacement; + } + + public void setExactPlacement(boolean exactPlacement) { + this.exactPlacement = exactPlacement; + } + + public boolean unrouteStaticNets() { + return unrouteStaticNets; + } + + public void setUnrouteStaticNets(boolean unrouteStaticNets) { + this.unrouteStaticNets = unrouteStaticNets; + } + + public boolean isRouteClock() { + return routeClock; + } + + public void setRouteClock(boolean routeClock) { + this.routeClock = routeClock; + } + private void initializeArrayBuilder(OptionSet options) { Path inputFile = null; setSkipImpl(options.has(SKIP_IMPL_OPTS.get(0))); setOutOfContext(options.has(OUT_OF_CONTEXT_OPTS.get(0))); + setExactPlacement(options.has(EXACT_PLACEMENT_OPTS.get(0))); + setUnrouteStaticNets(options.has(UNROUTE_STATIC_NETS_OPTS.get(0))); + setRouteClock(options.has(ROUTE_CLOCK_OPTS.get(0))); if (options.has(KERNEL_DESIGN_OPTS.get(0))) { inputFile = Paths.get((String) options.valueOf(KERNEL_DESIGN_OPTS.get(0))); @@ -314,6 +367,7 @@ private void initializeArrayBuilder(OptionSet options) { setKernelDesign(Design.readCheckpoint(inputFile, companionEDIF, CodePerfTracker.SILENT)); } else { setKernelDesign(Design.readCheckpoint(inputFile)); + EDIFTools.removeVivadoBusPreventionAnnotations(getKernelDesign().getNetlist()); if (!design.getNetlist().getEncryptedCells().isEmpty()) { System.out.println("Design has encrypted cells"); } else { @@ -403,6 +457,7 @@ private void initializeArrayBuilder(OptionSet options) { if (options.has(TOP_LEVEL_DESIGN_OPTS.get(0))) { Design d = Design.readCheckpoint((String) options.valueOf(TOP_LEVEL_DESIGN_OPTS.get(0))); setTopDesign(d); + EDIFTools.removeVivadoBusPreventionAnnotations(getTopDesign().getNetlist()); } if (options.has(TOP_CLK_NAME_OPTS.get(0))) { @@ -529,7 +584,8 @@ public static void writePlacementLocsToFile(List modules, String fileNam }; for (Module module : modules) { lines.add(module.getName() + ":"); - List validPlacements = module.getAllValidPlacements().stream().sorted(comparator).collect(Collectors.toList()); + List validPlacements = module.getAllValidPlacements().stream().sorted(comparator) + .collect(Collectors.toList()); for (Site anchor : validPlacements) { lines.add(anchor.getName()); } @@ -537,6 +593,29 @@ public static void writePlacementLocsToFile(List modules, String fileNam FileTools.writeLinesToTextFile(lines, fileName); } + public static List> getValidPlacementGrid(Module module) { + List> placementGrid = new ArrayList<>(); + // Sort by descending Y coordinate, then ascending X coordinate + List sortedValidPlacements = module.getAllValidPlacements().stream().sorted((s1, s2) -> { + if (s1.getInstanceY() == s2.getInstanceY()) { + return s1.getInstanceX() - s2.getInstanceX(); + } + return s2.getInstanceY() - s1.getInstanceY(); + }).collect(Collectors.toList()); + int currentYCoordinate = sortedValidPlacements.get(0).getInstanceY(); + int i = 0; + placementGrid.add(new ArrayList<>()); + for (Site anchor : sortedValidPlacements) { + if (anchor.getInstanceY() < currentYCoordinate) { + i++; + placementGrid.add(new ArrayList<>()); + } + placementGrid.get(i).add(anchor); + currentYCoordinate = anchor.getInstanceY(); + } + return placementGrid; + } + public static void main(String[] args) { CodePerfTracker t = new CodePerfTracker(ArrayBuilder.class.getName()); t.start("Init"); @@ -561,7 +640,6 @@ public static void main(String[] args) { } List modules = new ArrayList<>(); - boolean unrouteStaticNets = false; if (!ab.isSkipImpl()) { t.stop().start("Implement Kernel"); FileTools.makeDirs(workDir.toString()); @@ -589,7 +667,7 @@ public static void main(String[] args) { System.out.println("Reading... " + dcpPath); Design d = Design.readCheckpoint(dcpPath); d.setName(d.getName() + "_" + i); - Module m = new Module(d, unrouteStaticNets); + Module m = new Module(d, ab.unrouteStaticNets()); modules.add(m); m.setPBlock(pe.getPBlock(i)); m.calculateAllValidPlacements(d.getDevice()); @@ -602,8 +680,35 @@ public static void main(String[] args) { // Just use the design we loaded and replicate it t.stop().start("Calculate Valid Placements"); removeBUFGs(ab.getKernelDesign()); - Module m = new Module(ab.getKernelDesign(), unrouteStaticNets); + if (ab.unrouteStaticNets()) { + Net gndNet = ab.getKernelDesign().getNet(Net.GND_NET); + if (gndNet != null) { + gndNet.unroute(); + List staticSourcePins = new ArrayList<>(); + Set staticSourceSites = new HashSet<>(); + for (SitePinInst pin : gndNet.getPins()) { + if (pin.isOutPin() && pin.getSiteInst().getName().startsWith(SiteInst.STATIC_SOURCE)) { + staticSourcePins.add(pin); + staticSourceSites.add(pin.getSiteInst()); + } + } + for (SitePinInst pin : staticSourcePins) { + gndNet.removePin(pin); + pin.getSiteInst().removePin(pin); + } + for (SiteInst siteInst : staticSourceSites) { + siteInst.setDesign(null); + siteInst.unPlace(); + } + } + Net vccNet = ab.getKernelDesign().getNet(Net.VCC_NET); + if (vccNet != null) { + vccNet.unroute(); + } + } + Module m = new Module(ab.getKernelDesign(), ab.unrouteStaticNets()); m.getNet(ab.getKernelClockName()).unroute(); + if (ab.getInputPlacementFileName() == null) { m.calculateAllValidPlacements(ab.getDevice()); } @@ -612,19 +717,46 @@ public static void main(String[] args) { } modules.add(m); } - t.stop().start("Place Instances"); Design array = null; List modInstNames = null; + List, String>> idealPlacementList = null; if (ab.getTopDesign() == null) { array = new Design("array", ab.getKernelDesign().getPartName()); } else { array = ab.getTopDesign(); + t.stop().start("Calculate ideal array placement"); // Find instances in existing design modInstNames = getMatchingModuleInstanceNames(modules.get(0), array); + if (modInstNames.isEmpty()) { + throw new RuntimeException("Failed to find module instances in top design that match kernel interface"); + } ab.setInstCountLimit(modInstNames.size()); + ab.setCondensedGraph(new ArrayNetlistGraph(array, modInstNames)); + Map, String> idealPlacement = + ab.getCondensedGraph().getGreedyPlacementGrid(); +// Map foldingMap = new HashMap<>(); +// foldingMap.put(17, 21); +// foldingMap.put(18, 20); +// foldingMap.put(20, 18); +// foldingMap.put(21, 17); +// idealPlacement = foldIdealPlacement(idealPlacement, foldingMap); +// ab.getCondensedGraph().getOptimalPlacementGrid(ab.getInstCountLimit(), ab.getInstCountLimit()); + idealPlacementList = idealPlacement.entrySet().stream() + .map((e) -> new Pair<>(e.getKey(), e.getValue())) + .sorted((p1, p2) -> { + Pair pa = p1.getFirst(); + Pair pb = p2.getFirst(); + if (!Objects.equals(pa.getSecond(), pb.getSecond())) { + return pa.getSecond().compareTo(pb.getSecond()); + } + + return pa.getFirst().compareTo(pb.getFirst()); + }) + .collect(Collectors.toList()); } + t.stop().start("Place Instances"); if (ab.getOutputPlacementLocsFileName() != null) { writePlacementLocsToFile(modules, ab.getOutputPlacementLocsFileName()); } @@ -676,34 +808,60 @@ public static void main(String[] args) { } else { ModuleInst curr = null; int i = 0; - outer: for (Module module : modules) { - for (Site anchor : module.getAllValidPlacements()) { - if (curr == null) { - String instName = modInstNames == null ? ("inst_" + i) : modInstNames.get(i); - curr = array.createModuleInst(instName, module); - i++; + + // TODO: Figure out how to handle placement for multiple modules + Module module = modules.get(0); + RelocatableTileRectangle boundingBox = module.getBoundingBox(); + List boundingBoxes = new ArrayList<>(); + List> validPlacementGrid = getValidPlacementGrid(module); + int gridX = 0; + int gridY = 5; + int lastYCoordinate = 0; + boolean searchDown = true; + while (placed < ab.getInstCountLimit()) { + if (curr == null) { + String instName = modInstNames == null ? ("inst_" + i) : idealPlacementList.get(i).getSecond(); + int yCoordinate = idealPlacementList.get(i).getFirst().getSecond(); + if (yCoordinate > lastYCoordinate) { + gridX = 0; + searchDown = true; } + lastYCoordinate = yCoordinate; + curr = array.createModuleInst(instName, module); + i++; + } + if (gridY >= validPlacementGrid.size()) { + throw new RuntimeException("Optimal placement is too tall for device"); + } + if (gridX >= validPlacementGrid.get(gridY).size()) { + throw new RuntimeException("Optimal placement is too wide for device"); + } + Site anchor = validPlacementGrid.get(gridY).get(gridX); + RelocatableTileRectangle newBoundingBox = + boundingBox.getCorresponding(anchor.getTile(), module.getAnchor().getTile()); + boolean noOverlap = boundingBoxes.stream().noneMatch((b) -> b.overlaps(newBoundingBox)); + if (ab.isExactPlacement() || (noOverlap && !boundingBoxStraddlesClockRegion(newBoundingBox))) { if (curr.place(anchor, true, false)) { - if (straddlesClockRegion(curr)) { - curr.unplace(); - continue; - } - - List overlapping = NetTools.getNetsWithOverlappingNodes(array); - if (!overlapping.isEmpty()) { + if (ab.isExactPlacement() && (straddlesClockRegion(curr) + || !NetTools.getNetsWithOverlappingNodes(array).isEmpty()) + ) { curr.unplace(); - continue; - } - - placed++; - newPlacementMap.put(curr, anchor); - System.out.println(" ** PLACED: " + placed + " " + anchor + " " + curr.getName()); - curr = null; - if (placed >= ab.getInstCountLimit()) { - break outer; + } else { + boundingBoxes.add(newBoundingBox); + placed++; + newPlacementMap.put(curr, anchor); + System.out.println(" ** PLACED: " + placed + " " + anchor + " " + curr.getName() + + " " + curr.getAnchor().getTile().getSLR()); + curr = null; + searchDown = false; } } } + if (!searchDown) { + gridX++; + } else { + gridY++; + } } } @@ -766,16 +924,26 @@ public static void main(String[] args) { array.setAutoIOBuffers(false); } - Net gndNet = array.getNet(Net.GND_NET); - gndNet.unroute(); - Net vccNet = array.getNet(Net.VCC_NET); - vccNet.unroute(); + if (ab.unrouteStaticNets()) { + Net gndNet = array.getNet(Net.GND_NET); + if (gndNet != null) { + gndNet.unroute(); + } + Net vccNet = array.getNet(Net.VCC_NET); + if (vccNet != null) { + vccNet.unroute(); + } + } array.getNetlist().consolidateAllToWorkLibrary(); + array.flattenDesign(); - if (ab.getOutOfContext()) { + if (ab.isOutOfContext()) { // Automatically find bounding PBlock based on used Slices, DSPs, and BRAMs Set usedSites = new HashSet<>(); for (SiteInst siteInst : array.getSiteInsts()) { + if (siteInst.getName().contains("STATIC_SOURCE_SLICE")) { + continue; + } if (isSLICE(siteInst) || isBRAM(siteInst) || isDSP(siteInst)) { usedSites.add(siteInst.getSite()); } @@ -784,11 +952,64 @@ public static void main(String[] args) { InlineFlopTools.createAndPlaceFlopsInlineOnTopPortsNearPins(array, ab.getTopClockName(), pBlock); } + if (ab.isRouteClock()) { + t.stop().start("Route clock"); + Net clockNet = array.getNet(ab.getTopClockName()); + DesignTools.makePhysNetNamesConsistent(array); + DesignTools.createPossiblePinsToStaticNets(array); + DesignTools.createMissingSitePinInsts(array, clockNet); + List pinsToRoute = clockNet.getPins(); + + PartialRouter.routeDesignPartialNonTimingDriven(array, pinsToRoute); + } + t.stop().start("Write DCP"); array.writeCheckpoint(ab.getOutputName()); t.stop().printSummary(); } + private static Map, String> foldIdealPlacement(Map, String> placement, + Map newRowMap) { + if (newRowMap.isEmpty()) { + return placement; + } + Map, String> newPlacement = new HashMap<>(placement); + + // Check if row updates are unique + Set fromSet = new HashSet<>(); + Set toSet = new HashSet<>(); + for (Map.Entry rowUpdate : newRowMap.entrySet()) { + if (fromSet.contains(rowUpdate.getKey())) { + throw new RuntimeException("Non-unique source row when folding placement"); + } + if (toSet.contains(rowUpdate.getValue())) { + throw new RuntimeException("Non-unique destination row when folding placement"); + } + fromSet.add(rowUpdate.getKey()); + toSet.add(rowUpdate.getValue()); + } + if (!fromSet.containsAll(toSet)) { + throw new RuntimeException("Ideal placement folding provided with a non one-to-one mapping"); + } + + for (Map.Entry rowUpdate : newRowMap.entrySet()) { + int fromRow = rowUpdate.getKey(); + int toRow = rowUpdate.getValue(); + int currColumn = 0; + while (placement.containsKey(new Pair<>(currColumn, fromRow))) { + String cell = placement.get(new Pair<>(currColumn, fromRow)); + newPlacement.put(new Pair<>(currColumn, toRow), cell); + currColumn++; + } + // Remove rest of destination row if rows are not equal length + while (placement.containsKey(new Pair<>(currColumn, toRow))) { + newPlacement.remove(new Pair<>(currColumn, toRow)); + currColumn++; + } + } + return newPlacement; + } + public static Cell createBUFGCE(Design design, EDIFCell parent, String name, Site location) { Cell bufgce = design.createAndPlaceCell(parent, name, Unisim.BUFGCE, location, location.getBEL("BUFCE")); @@ -814,6 +1035,14 @@ public static Cell createBUFGCE(Design design, EDIFCell parent, String name, Sit return bufgce; } + private static boolean boundingBoxStraddlesClockRegion(RelocatableTileRectangle boundingBox) { + ClockRegion cr0 = boundingBox.getMaxColumnTile().getClockRegion(); + ClockRegion cr1 = boundingBox.getMinColumnTile().getClockRegion(); + ClockRegion cr2 = boundingBox.getMaxRowTile().getClockRegion(); + ClockRegion cr3 = boundingBox.getMinRowTile().getClockRegion(); + return !Stream.of(cr0, cr1, cr2, cr3).allMatch(cr0::equals); + } + private static boolean straddlesClockRegion(ModuleInst mi) { ClockRegion cr = mi.getAnchor().getSite().getClockRegion(); for (SiteInst si : mi.getSiteInsts()) { @@ -842,7 +1071,8 @@ private static boolean straddlesClockRegionOrRCLK(ModuleInst mi) { private static int getRCLKRowIndex(ClockRegion cr) { Tile center = cr.getApproximateCenter(); int searchGridDim = 0; - outer: while (!center.getName().startsWith("RCLK_")) { + outer: + while (!center.getName().startsWith("RCLK_")) { searchGridDim++; for (int row = -searchGridDim; row < searchGridDim; row++) { for (int col = -searchGridDim; col < searchGridDim; col++) { diff --git a/src/com/xilinx/rapidwright/design/tools/ArrayNetlistGraph.java b/src/com/xilinx/rapidwright/design/tools/ArrayNetlistGraph.java new file mode 100644 index 000000000..47555e40f --- /dev/null +++ b/src/com/xilinx/rapidwright/design/tools/ArrayNetlistGraph.java @@ -0,0 +1,378 @@ +/* + * + * Copyright (c) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Author: Andrew Butt, AMD Advanced Research and Development. + * + * This file is part of RapidWright. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.xilinx.rapidwright.design.tools; + +import com.google.ortools.Loader; +import com.google.ortools.sat.CpModel; +import com.google.ortools.sat.CpSolver; +import com.google.ortools.sat.CpSolverStatus; +import com.google.ortools.sat.IntVar; +import com.google.ortools.sat.LinearExpr; +import com.google.ortools.sat.LinearExprBuilder; +import com.google.ortools.sat.Literal; +import com.xilinx.rapidwright.design.Design; +import com.xilinx.rapidwright.edif.EDIFHierCellInst; +import com.xilinx.rapidwright.edif.EDIFHierPortInst; +import com.xilinx.rapidwright.util.Pair; +import org.jgrapht.Graph; +import org.jgrapht.GraphPath; +import org.jgrapht.alg.cycle.CycleDetector; +import org.jgrapht.alg.shortestpath.DijkstraShortestPath; +import org.jgrapht.graph.DefaultDirectedGraph; +import org.jgrapht.graph.DefaultEdge; +import org.jgrapht.traverse.TopologicalOrderIterator; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class ArrayNetlistGraph { + Graph graph; + + public ArrayNetlistGraph() { + graph = new DefaultDirectedGraph<>(DefaultEdge.class); + } + + public ArrayNetlistGraph(Design array, List modules) { + this(); + EDIFHierCellInst top = array.getNetlist().getTopHierCellInst(); + for (String module : modules) { + addVertex(module); + } + + for (String module : modules) { + EDIFHierCellInst cellInst = array.getNetlist().getHierCellInstFromName(module); + for (EDIFHierPortInst portInst : cellInst.getHierPortInsts()) { + if (portInst.isOutput()) { + for (EDIFHierPortInst netPortInst : portInst.getHierarchicalNet().getPortInsts()) { + if (!netPortInst.equals(portInst) && netPortInst.getCellType() != null) { + EDIFHierCellInst destCellInst = netPortInst.getFullHierarchicalInst(); + if (destCellInst != null && containsNode(destCellInst.getFullHierarchicalInstName())) { + addEdge(cellInst.getFullHierarchicalInstName(), + destCellInst.getFullHierarchicalInstName()); + } + } + } + } + } + } + } + + public void addVertex(String name) { + graph.addVertex(name); + } + + public boolean containsNode(String name) { + return graph.containsVertex(name); + } + + public void addEdge(String from, String to) { + graph.addEdge(from, to); + } + + public boolean isAcyclic() { + CycleDetector cycleDetector = new CycleDetector<>(graph); + return !cycleDetector.detectCycles(); + } + + public Iterator getTopologicalOrderIterator() { + return new TopologicalOrderIterator<>(graph); + } + + public Map, String> getGreedyPlacementGrid() { + Map, String> placementMap = new HashMap<>(); + Map> reversePlacementMap = new HashMap<>(); + Map candidateMap = new HashMap<>(); + Iterator iterator = getTopologicalOrderIterator(); + DijkstraShortestPath dsp = new DijkstraShortestPath<>(graph); + String topLeftNode = iterator.next(); + placementMap.put(new Pair<>(0, 0), topLeftNode); + reversePlacementMap.put(topLeftNode, new Pair<>(0, 0)); + for (DefaultEdge edge : graph.outgoingEdgesOf(topLeftNode)) { + String node = graph.getEdgeTarget(edge); + candidateMap.put(node, 1); + } + // TODO: Generalize + String extraConstraintNode = "u_systolic_array/x[1].y[0].u_tile"; + candidateMap.remove(extraConstraintNode); + placementMap.put(new Pair<>(1, 0), extraConstraintNode); + reversePlacementMap.put(extraConstraintNode, new Pair<>(1, 0)); + for (DefaultEdge edge : graph.outgoingEdgesOf(extraConstraintNode)) { + String targetNode = graph.getEdgeTarget(edge); + int count = candidateMap.computeIfAbsent(targetNode, (n) -> 0); + candidateMap.put(targetNode, count + 1); + } + // END TODO + while (!candidateMap.isEmpty()) { + List sortedCandidates = candidateMap.entrySet().stream() + .sorted((e1, e2) -> { + if (e1.getValue() == e2.getValue()) { + // Tie-break of shorted path distance + GraphPath shortestPathE1 = dsp.getPath(topLeftNode, e1.getKey()); + GraphPath shortestPathE2 = dsp.getPath(topLeftNode, e2.getKey()); + return shortestPathE1.getLength() - shortestPathE2.getLength(); + } + return e2.getValue().compareTo(e1.getValue()); + }) + .map(Map.Entry::getKey).collect(Collectors.toList()); + String node = sortedCandidates.get(0); + candidateMap.remove(node); + for (DefaultEdge edge : graph.outgoingEdgesOf(node)) { + String targetNode = graph.getEdgeTarget(edge); + int count = candidateMap.computeIfAbsent(targetNode, (n) -> 0); + candidateMap.put(targetNode, count + 1); + } + Set inEdges = graph.incomingEdgesOf(node); + List inNeighbors = new ArrayList<>(); + for (DefaultEdge e : inEdges) { + inNeighbors.add(graph.getEdgeSource(e)); + } + if (inNeighbors.size() > 3) { + throw new RuntimeException("Greedy placement does not work for given netlist"); + } + List> inNeighborPlacements = new ArrayList<>(); + for (String inNeighbor : inNeighbors) { + inNeighborPlacements.add(reversePlacementMap.get(inNeighbor)); + } + inNeighborPlacements = inNeighborPlacements.stream().sorted( + (p1, p2) -> { + if (p1.getSecond().equals(p2.getSecond())) { + return p1.getFirst() - p2.getFirst(); + } + return p1.getSecond() - p2.getSecond(); + }).collect(Collectors.toList()); + List> validPlacements = new ArrayList<>(); + if (inNeighbors.size() == 1) { + Pair neighborPlacement = inNeighborPlacements.get(0); + validPlacements.add(new Pair<>(neighborPlacement.getFirst() + 1, neighborPlacement.getSecond())); + validPlacements.add(new Pair<>(neighborPlacement.getFirst(), neighborPlacement.getSecond() + 1)); + } else if (inNeighbors.size() == 2) { + int x = inNeighborPlacements.get(0).getFirst(); + int y = inNeighborPlacements.get(1).getSecond(); + validPlacements.add(new Pair<>(x, y)); + } else { + throw new RuntimeException("Not yet implemented, try using OR-tools based placement"); + } + Pair placement = null; + for (Pair location : validPlacements) { + if (!placementMap.containsKey(location)) { + placement = location; + } + } + if (placement == null) { + throw new RuntimeException("Could not find valid greedy placement for cell: " + node); + } + placementMap.put(placement, node); + reversePlacementMap.put(node, placement); + } + + for (int y = 0; y < graph.vertexSet().size(); y++) { + for (int x = 0; x < graph.vertexSet().size(); x++) { + if (placementMap.containsKey(new Pair<>(x, y))) { + System.out.println("Placed " + placementMap.get(new Pair<>(x, y)) + " at (" + x + ", " + y + ")"); + } + } + } + + return placementMap; + } + + public Map, String> getOptimalPlacementGrid(int width, int height) { + Map, String> placementMap = new HashMap<>(); + int numNodes = graph.vertexSet().size(); + Map numToNameMap = new HashMap<>(); + Map nameToNumMap = new HashMap<>(); + + int i = 0; + for (String v : graph.vertexSet()) { + numToNameMap.put(i, v); + nameToNumMap.put(v, i); + i++; + } + + Loader.loadNativeLibraries(); + CpModel model = new CpModel(); + Literal[][][] placements = new Literal[numNodes][width][height]; + for (int n = 0; n < numNodes; n++) { + for (int x = 0; x < width; x++) { + for (int y = 0; y < height; y++) { + placements[n][x][y] = model.newBoolVar("placement_n" + n + "x" + x + "y" + y); + } + } + } + + // At most one node can be placed at each grid location + for (int x = 0; x < width; x++) { + for (int y = 0; y < height; y++) { + List nodes = new ArrayList<>(); + for (int n = 0; n < numNodes; n++) { + nodes.add(placements[n][x][y]); + } + model.addAtMostOne(nodes); + } + } + + // Every node must be placed at exactly one location + for (int n = 0; n < numNodes; n++) { + List locations = new ArrayList<>(); + for (int x = 0; x < width; x++) { + locations.addAll(Arrays.asList(placements[n][x]).subList(0, height)); + } + model.addExactlyOne(locations); + } + + // Add auxiliary variables for x and y placement + IntVar[] xPlacement = new IntVar[numNodes]; + IntVar[] yPlacement = new IntVar[numNodes]; + for (int n = 0; n < numNodes; n++) { + xPlacement[n] = model.newIntVar(0, width, "x_loc_n" + n); + yPlacement[n] = model.newIntVar(0, height, "y_loc_n" + n); + LinearExprBuilder xExpr = LinearExpr.newBuilder(); + LinearExprBuilder yExpr = LinearExpr.newBuilder(); + for (int x = 0; x < width; x++) { + for (int y = 0; y < height; y++) { + xExpr.addTerm(placements[n][x][y], x); + yExpr.addTerm(placements[n][x][y], y); + } + } + model.addEquality(xPlacement[n], xExpr); + model.addEquality(yPlacement[n], yExpr); + } + + // Add auxiliary variables for x and y distance between connected nodes + List xDistVars = new ArrayList<>(); + List yDistVars = new ArrayList<>(); + for (String v : graph.vertexSet()) { + Set outEdges = graph.outgoingEdgesOf(v); + int sourceNum = nameToNumMap.get(v); + for (DefaultEdge e : outEdges) { + String edgeTarget = graph.getEdgeTarget(e); + int targetNum = nameToNumMap.get(edgeTarget); + + // x distance variable + IntVar xDistVar = model.newIntVar(0, width, "x_dist_" + v + "_n" + sourceNum + "_to_" + edgeTarget + "_n" + targetNum); + xDistVars.add(xDistVar); + IntVar sourceXVar = xPlacement[sourceNum]; + IntVar targetXVar = xPlacement[targetNum]; + + // Adding both of these constraints is equivalent to xDistVar = abs(sourceX - targetX) + LinearExprBuilder sourceMinusTargetX = LinearExpr.newBuilder(); + sourceMinusTargetX.addTerm(sourceXVar, 1); + sourceMinusTargetX.addTerm(targetXVar, -1); + model.addGreaterOrEqual(xDistVar, sourceMinusTargetX); + + LinearExprBuilder targetMinusSourceX = LinearExpr.newBuilder(); + targetMinusSourceX.addTerm(targetXVar, 1); + targetMinusSourceX.addTerm(sourceXVar, -1); + model.addGreaterOrEqual(xDistVar, targetMinusSourceX); + + // y distance variable + IntVar yDistVar = model.newIntVar(0, width, "y_dist_" + v + "_n" + sourceNum + "_to_" + edgeTarget + "_n" + targetNum); + yDistVars.add(yDistVar); + IntVar sourceYVar = yPlacement[sourceNum]; + IntVar targetYVar = yPlacement[targetNum]; + + // Adding both of these constraints is equivalent to xDistVar = abs(sourceX - targetX) + LinearExprBuilder sourceMinusTargetY = LinearExpr.newBuilder(); + sourceMinusTargetY.addTerm(sourceYVar, 1); + sourceMinusTargetY.addTerm(targetYVar, -1); + model.addGreaterOrEqual(yDistVar, sourceMinusTargetY); + + LinearExprBuilder targetMinusSourceY = LinearExpr.newBuilder(); + targetMinusSourceY.addTerm(targetYVar, 1); + targetMinusSourceY.addTerm(sourceYVar, -1); + model.addGreaterOrEqual(yDistVar, targetMinusSourceY); + + // Neighbors must be adjacent + LinearExprBuilder xDistPlusYDist = LinearExpr.newBuilder(); + xDistPlusYDist.addSum(new IntVar[]{xDistVar, yDistVar}); + model.addLessOrEqual(xDistPlusYDist, 1); +// model.addLessOrEqual(xDistVar, 3); +// model.addLessOrEqual(yDistVar, 3); + } + } + + // Place the anchor in the top left corner + String anchor = getTopologicalOrderIterator().next(); + int anchorNum = nameToNumMap.get(anchor); + model.addAssumption(placements[anchorNum][0][0]); + + IntVar maxXDistVar = model.newIntVar(0, width, "max_x_dist"); + for (IntVar xDistVar : xDistVars) { + model.addGreaterOrEqual(maxXDistVar, xDistVar); + } + IntVar maxYDistVar = model.newIntVar(0, width, "max_y_dist"); + for (IntVar yDistVar : yDistVars) { + model.addGreaterOrEqual(maxYDistVar, yDistVar); + } + LinearExprBuilder obj = LinearExpr.newBuilder(); + obj.add(maxXDistVar); + obj.add(maxYDistVar); + model.minimize(obj); + + // Add objective to minimize manhattan distance +// LinearExprBuilder obj = LinearExpr.newBuilder(); +// for (IntVar xDistVar : xDistVars) { +// obj.addTerm(xDistVar, 1); +// } +// for (IntVar yDistVar : yDistVars) { +// obj.add(yDistVar); +// } +// model.minimize(obj); + + CpSolver solver = new CpSolver(); +// solver.getParameters().setMaxTimeInSeconds(10.0); + CpSolverStatus status = solver.solve(model); + + if (status == CpSolverStatus.FEASIBLE || status == CpSolverStatus.OPTIMAL) { + System.out.println("Solution: " + status); + for (int x = 0; x < width; x++) { + for (int y = 0; y < height; y++) { + for (int n = 0; n < numNodes; n++) { + if (solver.booleanValue(placements[n][x][y])) { + System.out.println("Placed " + numToNameMap.get(n) + " at (" + x + ", " + y + ")"); + placementMap.put(new Pair<>(x, y), numToNameMap.get(n)); + break; + } + } + } + } + } else { + throw new RuntimeException("Failed to find optimal placement grid, solver returned status: " + status); + } + + return placementMap; + } + + @Override + public String toString() { + return graph.toString(); + } +}