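"""Plot commodity/technology flows from a MESSAGEix scenario as a Graphviz graph.

Input, output, and demand parameters are read from cached CSV files (or extracted
once from the scenario) for a chosen region, year, and set of commodities, then
rendered as a directed graph in PNG and SVG format.
"""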
import os
import warnings

import graphviz
import pandas as pd

warnings.filterwarnings("ignore", category=RuntimeWarning)


# --------------------------
# Load or extract I/O (caching logic preserved)
# --------------------------
def load_or_extract_io(scen, model, scenario, node, year, commodities):
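    """Return (df_in, df_out, df_dem) for node/year, filtered to commodities.

    Cached CSV files named after model and scenario are read if present;
    otherwise the data are extracted from scen and written to those files.
    """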
    in_file = f"{model}_{scenario}_inputs.csv"
    out_file = f"{model}_{scenario}_outputs.csv"
    dem_file = f"{model}_{scenario}_demand.csv"

    if (
        os.path.exists(in_file)
        and os.path.exists(out_file)
        and os.path.exists(dem_file)
    ):
        print(f"Reading cached data: {in_file}, {out_file}, {dem_file}")
        df_in = pd.read_csv(in_file)
        df_out = pd.read_csv(out_file)
        df_dem = pd.read_csv(dem_file)
        return df_in, df_out, df_dem

    print("Extracting from scenario…")
    df_in = scen.par("input", {"node_loc": node, "year_act": year}).copy()
    df_out = scen.par("output", {"node_loc": node, "year_act": year}).copy()
    df_dem = scen.par("demand", {"node": node, "year": year}).copy()

    # Normalize column names
    df_in = df_in.rename(columns={"commodity": "comm_in", "level": "level_in"})
    df_out = df_out.rename(columns={"commodity": "comm_out", "level": "level_out"})
    df_dem = df_dem.rename(columns={"commodity": "comm"})

    # Filter to commodities of interest
    if commodities:
        df_in = df_in[df_in["comm_in"].isin(commodities)].copy()
        df_out = df_out[df_out["comm_out"].isin(commodities)].copy()
        df_dem = df_dem[df_dem["comm"].isin(commodities)].copy()

    # Save to CSV
    df_in.to_csv(in_file, index=False)
    df_out.to_csv(out_file, index=False)
    df_dem.to_csv(dem_file, index=False)
    print(f"Saved to {in_file}, {out_file}, {dem_file}")

    return df_in, df_out, df_dem


# --------------------------
# Small helper: enforce comm_* and level_* column names (safeguard)
# --------------------------
def ensure_comm_level_cols(df_in, df_out, df_dem=None):
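    """Make sure the comm_*/level_* (and demand comm/level) columns exist.

    Older cached CSVs may still carry the raw commodity/level names; rename or
    create the expected columns so the plotting code does not fail.
    """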
    if "comm_in" not in df_in.columns:
        if "commodity" in df_in.columns:
            df_in = df_in.rename(columns={"commodity": "comm_in"})
        else:
            df_in["comm_in"] = ""
    if "level_in" not in df_in.columns:
        if "level" in df_in.columns:
            df_in = df_in.rename(columns={"level": "level_in"})
        else:
            df_in["level_in"] = ""

    if "comm_out" not in df_out.columns:
        if "commodity" in df_out.columns:
            df_out = df_out.rename(columns={"commodity": "comm_out"})
        else:
            df_out["comm_out"] = ""
    if "level_out" not in df_out.columns:
        if "level" in df_out.columns:
            df_out = df_out.rename(columns={"level": "level_out"})
        else:
            df_out["level_out"] = ""

    if df_dem is not None:
        if "comm" not in df_dem.columns and "commodity" in df_dem.columns:
            df_dem = df_dem.rename(columns={"commodity": "comm"})
        if "level" not in df_dem.columns:
            df_dem["level"] = ""

    return df_in, df_out, df_dem


# --------------------------
# Graphviz plotting
# --------------------------
def plot_flows_graphviz(
    df_in, df_out, df_dem, model, scenario, commodities, node, year
):
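    """Render the commodity/technology flow graph for node and year.

    Demand commodities are drawn in violet, other selected commodities in grey,
    commodities outside the selection in red, and technologies as blue boxes.
    The graph is written to PNG and SVG files named after model and scenario.
    """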
    dot = graphviz.Digraph(comment=f"{model} {scenario} flows", format="png")
    dot.attr(rankdir="LR", splines="ortho")

    # Set of demand commodity.level
    demand_comms = (
        {f"{r.comm}.{r.level}" for r in df_dem.itertuples()}
        if not df_dem.empty
        else set()
    )

    # --- Build sets of techs connected to main commodities ---
    techs_in = df_in[df_in["comm_in"].isin(commodities)]["technology"].unique()
    techs_out = df_out[df_out["comm_out"].isin(commodities)]["technology"].unique()
    techs = set(techs_in).union(techs_out)

    # --- Main commodities actually connected ---
    main_in = df_in[df_in["technology"].isin(techs)][["comm_in", "level_in"]]
    main_out = df_out[df_out["technology"].isin(techs)][["comm_out", "level_out"]]

    main_comms = {
        f"{r.comm_in}.{r.level_in}"
        for r in main_in.itertuples()
        if r.comm_in in commodities
    } | {
        f"{r.comm_out}.{r.level_out}"
        for r in main_out.itertuples()
        if r.comm_out in commodities
    }

    # --- Extra commodities connected to those techs ---
    extra_in = {
        f"{r.comm_in}.{r.level_in}"
        for r in main_in.itertuples()
        if r.comm_in not in commodities
    }
    extra_out = {
        f"{r.comm_out}.{r.level_out}"
        for r in main_out.itertuples()
        if r.comm_out not in commodities
    }
    extra_comms = extra_in | extra_out

    # --- Add commodity nodes ---
    for c in main_comms:
        if c in demand_comms:
            dot.node(
                c,
                label=c,
                shape="ellipse",
                style="filled",
                fillcolor="violet",
                color="black",
                fontcolor="black",
            )
        else:
            dot.node(
                c,
                label=c,
                shape="ellipse",
                style="filled",
                fillcolor="lightgray",
                color="black",
                fontcolor="black",
            )

    for c in extra_comms:
        dot.node(c, label=c, shape="ellipse", color="red", fontcolor="red")

    # --- Technology nodes ---
    for t in techs:
        dot.node(t, shape="box", style="rounded,filled", fillcolor="lightblue")

    # --- Deduplicate edges by (src, dst, type) ---
    edges_seen = {}
    for _, row in df_in[df_in["technology"].isin(techs)].iterrows():
        src = f"{row.comm_in}.{row.level_in}"
        key = (src, row.technology, "in")
        year_vtg = row.get("year_vtg", row.year_act)
        edges_seen.setdefault(key, []).append(year_vtg)

    for _, row in df_out[df_out["technology"].isin(techs)].iterrows():
        dst = f"{row.comm_out}.{row.level_out}"
        key = (row.technology, dst, "out")
        year_vtg = row.get("year_vtg", row.year_act)
        edges_seen.setdefault(key, []).append(year_vtg)

    # --- Draw edges ---
    for (src, dst, typ), years in edges_seen.items():
        multiple = len(set(years)) > 1
        style = "dashed" if multiple else "solid"
        color = "red" if (src in extra_comms or dst in extra_comms) else "black"
        dot.edge(src, dst, style=style, color=color)

    # --- Legend ---
    with dot.subgraph(name="cluster_legend") as c:
        c.attr(label="Legend", fontsize="10", rankdir="LR")
        c.node("solid_arrow", label="single vintage", shape="plaintext")
        c.node("dashed_arrow", label="multiple vintages", shape="plaintext")
        c.edge("solid_arrow", "dashed_arrow", style="solid")
        c.edge("solid_arrow", "dashed_arrow", style="dashed")

        # Region/year info
        c.node("region_year", label=f"Region: {node}, Year: {year}", shape="plaintext")

        # Boxes meaning
        c.node(
            "tech_box",
            label="Technology",
            shape="box",
            style="rounded,filled",
            fillcolor="lightblue",
        )
        c.node(
            "comm_grey",
            label="Commodity",
            shape="ellipse",
            style="filled",
            fillcolor="lightgray",
        )
        c.node(
            "comm_violet",
            label="Demand commodity",
            shape="ellipse",
            style="filled",
            fillcolor="violet",
        )
        # c.node(
        #     "comm_red",
        #     label="Extra linked commodity",
        #     shape="ellipse",
        #     color="red",
        #     fontcolor="red",
        # )

    # --- Save ---
    base_name = f"{model}_{scenario}_ascii_flows"

    dot.format = "png"
    dot.render(filename=base_name, cleanup=True)
    dot.format = "svg"
    dot.render(filename=base_name, cleanup=True)

    print(f"Saved Graphviz PNG: {base_name}.png, SVG: {base_name}.svg")


# --------------------------
# Main
# --------------------------
if __name__ == "__main__":
    # Note: delete the cached CSV files if re-running with different commodities.

    # ==== User-editable configuration ====
    PLATFORM_NAME = "<database_name>"  # only used if CSV cache missing
    MODEL = "<model_name>"  # e.g. "MESSAGEix-Nexus"
    SCENARIO = "<scenario_name>"
    NODE = "R12_CHN"
    YEAR = 2050
    # Define the commodities to keep (only these will be plotted)
    COMMODITIES = [
        "electr",
        "coal",
    ]
    # =====================================

    in_file = f"{MODEL}_{SCENARIO}_inputs.csv"
    out_file = f"{MODEL}_{SCENARIO}_outputs.csv"
    dem_file = f"{MODEL}_{SCENARIO}_demand.csv"

    # If cached CSVs exist, load without connecting to ixmp
    if (
        os.path.exists(in_file)
        and os.path.exists(out_file)
        and os.path.exists(dem_file)
    ):
        print("Loading cached CSVs…")
        df_in = pd.read_csv(in_file)
        df_out = pd.read_csv(out_file)
        df_dem = pd.read_csv(dem_file)
        # Normalize column names (old CSVs may be different)
        df_in, df_out, df_dem = ensure_comm_level_cols(df_in, df_out, df_dem)
        # Apply the commodity filter again (defensive)
        if COMMODITIES:
            df_in = df_in[df_in["comm_in"].isin(COMMODITIES)].copy()
            df_out = df_out[df_out["comm_out"].isin(COMMODITIES)].copy()
            df_dem = df_dem[df_dem["comm"].isin(COMMODITIES)].copy()
    else:
        # Need to connect and extract
        import ixmp
        import message_ix

        mp = ixmp.Platform(PLATFORM_NAME)
        scen = message_ix.Scenario(mp, model=MODEL, scenario=SCENARIO)
        df_in, df_out, df_dem = load_or_extract_io(
            scen, MODEL, SCENARIO, NODE, YEAR, COMMODITIES
        )
        # Explicit cleanup
        mp.close_db()

    # Final safety normalization (ensure expected columns exist)
    df_in, df_out, df_dem = ensure_comm_level_cols(df_in, df_out, df_dem)

    # Plot
    plot_flows_graphviz(
        df_in, df_out, df_dem, MODEL, SCENARIO, COMMODITIES, NODE, YEAR
    )