Skip to content

Commit 30c873c

Browse files
committed
Add script to generate network flow with multiple commodities and levels
1 parent 26303c0 commit 30c873c

File tree

1 file changed

+299
-0
lines changed

1 file changed

+299
-0
lines changed
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import os
2+
import warnings
3+
4+
import graphviz
5+
import pandas as pd
6+
7+
warnings.filterwarnings("ignore", category=RuntimeWarning)
8+
9+
10+
# --------------------------
11+
# Load or extract I/O (caching logic preserved)
12+
# --------------------------
13+
def load_or_extract_io(scen, model, scenario, node, year, commodities):
14+
in_file = f"{model}_{scenario}_inputs.csv"
15+
out_file = f"{model}_{scenario}_outputs.csv"
16+
dem_file = f"{model}_{scenario}_demand.csv"
17+
18+
if (
19+
os.path.exists(in_file)
20+
and os.path.exists(out_file)
21+
and os.path.exists(dem_file)
22+
):
23+
print(f"Reading cached data: {in_file}, {out_file}, {dem_file}")
24+
df_in = pd.read_csv(in_file)
25+
df_out = pd.read_csv(out_file)
26+
df_dem = pd.read_csv(dem_file)
27+
return df_in, df_out, df_dem
28+
29+
print("Extracting from scenario…")
30+
df_in = scen.par("input", {"node_loc": node, "year_act": year}).copy()
31+
df_out = scen.par("output", {"node_loc": node, "year_act": year}).copy()
32+
df_dem = scen.par("demand", {"node": node, "year": year}).copy()
33+
34+
# Normalize column names
35+
df_in = df_in.rename(columns={"commodity": "comm_in", "level": "level_in"})
36+
df_out = df_out.rename(columns={"commodity": "comm_out", "level": "level_out"})
37+
df_dem = df_dem.rename(columns={"commodity": "comm", "level": "level"})
38+
39+
# Filter to commodities of interest
40+
if commodities:
41+
df_in = df_in[df_in["comm_in"].isin(commodities)].copy()
42+
df_out = df_out[df_out["comm_out"].isin(commodities)].copy()
43+
df_dem = df_dem[df_dem["comm"].isin(commodities)].copy()
44+
45+
# Save to CSV
46+
df_in.to_csv(in_file, index=False)
47+
df_out.to_csv(out_file, index=False)
48+
df_dem.to_csv(dem_file, index=False)
49+
print(f"Saved to {in_file}, {out_file}, {dem_file}")
50+
51+
return df_in, df_out, df_dem
52+
53+
54+
# --------------------------
55+
# Small helper: enforce comm_* and level_* names (safe-guard)
56+
# --------------------------
57+
def ensure_comm_level_cols(df_in, df_out, df_dem=None):
58+
if "comm_in" not in df_in.columns:
59+
if "commodity" in df_in.columns:
60+
df_in = df_in.rename(columns={"commodity": "comm_in"})
61+
else:
62+
df_in["comm_in"] = ""
63+
if "level_in" not in df_in.columns:
64+
if "level" in df_in.columns:
65+
df_in = df_in.rename(columns={"level": "level_in"})
66+
else:
67+
df_in["level_in"] = ""
68+
69+
if "comm_out" not in df_out.columns:
70+
if "commodity" in df_out.columns:
71+
df_out = df_out.rename(columns={"commodity": "comm_out"})
72+
else:
73+
df_out["comm_out"] = ""
74+
if "level_out" not in df_out.columns:
75+
if "level" in df_out.columns:
76+
df_out = df_out.rename(columns={"level": "level_out"})
77+
else:
78+
df_out["level_out"] = ""
79+
80+
if df_dem is not None:
81+
if "comm" not in df_dem.columns and "commodity" in df_dem.columns:
82+
df_dem = df_dem.rename(columns={"commodity": "comm"})
83+
if "level" not in df_dem.columns:
84+
df_dem["level"] = ""
85+
86+
return df_in, df_out, df_dem
87+
88+
89+
# --------------------------
90+
# Update Graphviz plotting
91+
# --------------------------
92+
def plot_flows_graphviz(df_in, df_out, df_dem, model, scenario, commodities):
93+
dot = graphviz.Digraph(comment=f"{model} {scenario} flows", format="png")
94+
dot.attr(rankdir="LR", splines="ortho")
95+
96+
# Set of demand commodity.level
97+
demand_comms = (
98+
{f"{r.comm}.{r.level}" for r in df_dem.itertuples()}
99+
if not df_dem.empty
100+
else set()
101+
)
102+
103+
# --- Build sets of techs connected to main commodities ---
104+
techs_in = df_in[df_in["comm_in"].isin(commodities)]["technology"].unique()
105+
techs_out = df_out[df_out["comm_out"].isin(commodities)]["technology"].unique()
106+
techs = set(techs_in).union(techs_out)
107+
108+
# --- Main commodities actually connected ---
109+
main_in = df_in[df_in["technology"].isin(techs)][["comm_in", "level_in"]]
110+
main_out = df_out[df_out["technology"].isin(techs)][["comm_out", "level_out"]]
111+
112+
main_comms = {
113+
f"{r.comm_in}.{r.level_in}"
114+
for r in main_in.itertuples()
115+
if r.comm_in in commodities
116+
} | {
117+
f"{r.comm_out}.{r.level_out}"
118+
for r in main_out.itertuples()
119+
if r.comm_out in commodities
120+
}
121+
122+
# --- Extra commodities connected to those techs ---
123+
extra_in = {
124+
f"{r.comm_in}.{r.level_in}"
125+
for r in main_in.itertuples()
126+
if r.comm_in not in commodities
127+
}
128+
extra_out = {
129+
f"{r.comm_out}.{r.level_out}"
130+
for r in main_out.itertuples()
131+
if r.comm_out not in commodities
132+
}
133+
extra_comms = extra_in | extra_out
134+
135+
# --- Add commodity nodes ---
136+
for c in main_comms:
137+
if c in demand_comms:
138+
dot.node(
139+
c,
140+
label=c,
141+
shape="ellipse",
142+
style="filled",
143+
fillcolor="violet",
144+
color="black",
145+
fontcolor="black",
146+
)
147+
else:
148+
dot.node(
149+
c,
150+
label=c,
151+
shape="ellipse",
152+
style="filled",
153+
fillcolor="lightgray",
154+
color="black",
155+
fontcolor="black",
156+
)
157+
158+
for c in extra_comms:
159+
dot.node(c, label=c, shape="ellipse", color="red", fontcolor="red")
160+
161+
# --- Technology nodes ---
162+
for t in techs:
163+
dot.node(t, shape="box", style="rounded,filled", fillcolor="lightblue")
164+
165+
# --- Deduplicate edges by (src, dst, type) ---
166+
edges_seen = {}
167+
for _, row in df_in[df_in["technology"].isin(techs)].iterrows():
168+
src = f"{row.comm_in}.{row.level_in}"
169+
key = (src, row.technology, "in")
170+
year_vtg = row.get("year_vtg", row.year_act)
171+
edges_seen.setdefault(key, []).append(year_vtg)
172+
173+
for _, row in df_out[df_out["technology"].isin(techs)].iterrows():
174+
dst = f"{row.comm_out}.{row.level_out}"
175+
key = (row.technology, dst, "out")
176+
year_vtg = row.get("year_vtg", row.year_act)
177+
edges_seen.setdefault(key, []).append(year_vtg)
178+
179+
# --- Draw edges ---
180+
for (src, dst, typ), years in edges_seen.items():
181+
latest = max(years)
182+
multiple = len(set(years)) > 1
183+
style = "dashed" if multiple else "solid"
184+
color = "red" if (src in extra_comms or dst in extra_comms) else "black"
185+
dot.edge(src, dst, style=style, color=color)
186+
187+
# --- Legend ---
188+
with dot.subgraph(name="cluster_legend") as c:
189+
c.attr(label="Legend", fontsize="10", rankdir="LR")
190+
c.node("solid_arrow", label="single vintage", shape="plaintext")
191+
c.edge("solid_arrow", "dashed_arrow", style="solid")
192+
c.node("dashed_arrow", label="multiple vintages", shape="plaintext")
193+
c.edge("solid_arrow", "dashed_arrow", style="dashed")
194+
195+
# Region/year info
196+
c.node("region_year", label=f"Region: {NODE}, Year: {YEAR}", shape="plaintext")
197+
198+
# Boxes meaning
199+
c.node(
200+
"tech_box",
201+
label="Technology",
202+
shape="box",
203+
style="rounded,filled",
204+
fillcolor="lightblue",
205+
)
206+
c.node(
207+
"comm_grey",
208+
label="Commodity",
209+
shape="ellipse",
210+
style="filled",
211+
fillcolor="lightgray",
212+
)
213+
c.node(
214+
"comm_violet",
215+
label="Demand commodity",
216+
shape="ellipse",
217+
style="filled",
218+
fillcolor="violet",
219+
)
220+
# c.node(
221+
# "comm_red",
222+
# label="Extra linked commodity",
223+
# shape="ellipse",
224+
# color="red",
225+
# fontcolor="red",
226+
# )
227+
228+
# --- Save ---
229+
png_name = f"{model}_{scenario}_ascii_flows"
230+
svg_name = f"{model}_{scenario}_ascii_flows"
231+
232+
dot.format = "png"
233+
dot.render(filename=png_name, cleanup=True)
234+
dot.format = "svg"
235+
dot.render(filename=svg_name, cleanup=True)
236+
237+
print(f"Saved Graphviz PNG: {png_name}.png, SVG: {svg_name}.svg")
238+
239+
240+
# --------------------------
241+
# Main
242+
# --------------------------
243+
if __name__ == "__main__":
244+
"""
245+
Note: need to delete the csv files if you are re-running with different
246+
commodities
247+
"""
248+
# ==== User-editable configuration ====
249+
PLATFORM_NAME = "<database_name>" # only used if CSV cache missing
250+
MODEL = "<model_name>" # e.g. "MESSAGEix-Nexus"
251+
SCENARIO = "<scenario_name>"
252+
NODE = "R12_CHN"
253+
YEAR = 2050
254+
# Define the commodities you want (only these will be kept)
255+
COMMODITIES = [
256+
"electr",
257+
"coal",
258+
]
259+
# =====================================
260+
261+
in_file = f"{MODEL}_{SCENARIO}_inputs.csv"
262+
out_file = f"{MODEL}_{SCENARIO}_outputs.csv"
263+
dem_file = f"{MODEL}_{SCENARIO}_demand.csv"
264+
265+
# If cached CSVs exist, load without connecting to ixmp
266+
if (
267+
os.path.exists(in_file)
268+
and os.path.exists(out_file)
269+
and os.path.exists(dem_file)
270+
):
271+
print("Loading cached CSVs…")
272+
df_in = pd.read_csv(in_file)
273+
df_out = pd.read_csv(out_file)
274+
df_dem = pd.read_csv(dem_file)
275+
# Normalize column names (old CSVs may be different)
276+
df_in, df_out, df_dem = ensure_comm_level_cols(df_in, df_out, df_dem)
277+
# Apply commodity filter again (defensive)
278+
if COMMODITIES:
279+
df_in = df_in[df_in["comm_in"].isin(COMMODITIES)].copy()
280+
df_out = df_out[df_out["comm_out"].isin(COMMODITIES)].copy()
281+
else:
282+
# Need to connect and extract
283+
import ixmp
284+
285+
import message_ix
286+
287+
mp = ixmp.Platform(PLATFORM_NAME)
288+
scen = message_ix.Scenario(mp, model=MODEL, scenario=SCENARIO)
289+
df_in, df_out, df_dem = load_or_extract_io(
290+
scen, MODEL, SCENARIO, NODE, YEAR, COMMODITIES
291+
)
292+
# Explicit cleanup
293+
mp.close_db()
294+
295+
# Final safety normalization (ensure col names exist)
296+
df_in, df_out, df_dem = ensure_comm_level_cols(df_in, df_out, df_dem)
297+
298+
# Plots
299+
plot_flows_graphviz(df_in, df_out, df_dem, MODEL, SCENARIO, COMMODITIES)

0 commit comments

Comments
 (0)