Skip to content

Commit

Permalink
Merge pull request #15 from cms-ml/feature/forward_more_aot_flags
Browse files Browse the repository at this point in the history
Allow passing arbitrary flags to the aot compiler.
  • Loading branch information
valsdav authored Mar 27, 2024
2 parents 63ae87e + 66ed012 commit 0ab4f98
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions cmsml/scripts/compile_tf_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import os

from cmsml.util import colored, interruptable_popen
from cmsml.util import colored, interruptable_popen, make_list
from cmsml.tensorflow.tools import import_tf, load_model


Expand All @@ -20,8 +20,9 @@ def compile_tf_graph(
output_serving_key: str | None = None,
compile_prefix: str | None = None,
compile_class: str | None = None,
xla_flags: list[str] | None = None,
tf_xla_flags: list[str] | None = None,
xla_flags: list[str] | str | None = None,
tf_xla_flags: list[str] | str | None = None,
additional_flags: list[str] | str | None = None,
) -> None:
"""
For AOT compilation a static memory layout at runtime is required. This function prepares the given input SavedModel
Expand All @@ -38,7 +39,7 @@ def compile_tf_graph(
An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case
*compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated files.
*xla_flags* and *tf_xla_flags* are forwarded to :py:func:`aot_compile`.
*xla_flags*, *tf_xla_flags* and *additional_flags* are forwarded to :py:func:`aot_compile`.
"""
tf = import_tf()[0]

Expand Down Expand Up @@ -102,6 +103,7 @@ def compile_tf_graph(
serving_key=output_serving_key,
xla_flags=xla_flags,
tf_xla_flags=tf_xla_flags,
additional_flags=additional_flags,
)


Expand All @@ -112,8 +114,9 @@ def aot_compile(
class_name: str,
batch_sizes: tuple[int] = (1,),
serving_key: str = r"serving_default_bs{}",
xla_flags: list[str] | None = None,
tf_xla_flags: list[str] | None = None,
xla_flags: list[str] | str | None = None,
tf_xla_flags: list[str] | str | None = None,
additional_flags: list[str] | str | None = None,
) -> None:
"""
Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key*
Expand All @@ -123,7 +126,8 @@ def aot_compile(
header access the AOT-compiled network.
When *xla_flags* and *tf_xla_flags* are given, they are forwarded as comma-separated values to the *XLA_FLAGS*
and *TF_XLA_FLAGS* environment variables, respectively.
and *TF_XLA_FLAGS* environment variables, respectively. *additional_flags* are forwarded as is to the underlying
aot compiler invocation.
"""
# prepare model path
model_path = os.path.abspath(os.path.expandvars(os.path.expanduser(str(model_path))))
Expand All @@ -145,16 +149,21 @@ def aot_compile(
# ammend the env when xla flags were passed
env = os.environ.copy()
if xla_flags:
xla_flags = make_list(xla_flags)
xla_flags_orig = env.get("XLA_FLAGS", "")
if xla_flags_orig:
xla_flags = [xla_flags_orig.rstrip(",")] + xla_flags
env["XLA_FLAGS"] = ",".join(map(str, xla_flags))
if tf_xla_flags:
tf_xla_flags = make_list(tf_xla_flags)
tf_xla_flags_orig = env.get("TF_XLA_FLAGS", "")
if tf_xla_flags_orig:
tf_xla_flags = [tf_xla_flags_orig.rstrip(",")] + tf_xla_flags
env["TF_XLA_FLAGS"] = ",".join(map(str, tf_xla_flags))

# prepare additional flags
additional_flags_str = " ".join(make_list(additional_flags)) if additional_flags else ""

# compile for each batch size
for bs in sorted(set(map(int, batch_sizes))):
cmd = (
Expand All @@ -164,7 +173,8 @@ def aot_compile(
f" --output_prefix {prefix.format(bs)}"
f" --cpp_class {class_name.format(bs)}"
" --tag_set serve"
)
f" {additional_flags_str}"
).strip()

print(f"compiling for batch size {colored(bs, 'magenta')}")
code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path, env=env)[0]
Expand Down Expand Up @@ -222,17 +232,23 @@ def main() -> None:
)
parser.add_argument(
"--output-serving-key",
help=r"serving key pattern for concrete models in --output-path, with {} being replaced by "
help=r"serving key pattern for concrete models in --output-path, with {} being replaced by "
r"the batch size; default: <input_serving_key>__bs{}",
)
parser.add_argument(
"--compile",
"-c",
nargs=2,
help=r"file name prefix and class name of the AOT compiled objects; in both values, {} is "
help=r"file name prefix and class name of the AOT compiled objects; in both values, {} is "
"replaced by the batch size; no AOT compilation is triggered when empty; files will be "
"saved at <output_path>/aot/<prefix>{.h,.o,_metadata.o,_makefile.inc}",
)
parser.add_argument(
"--additional-flags",
"-f",
help="additional, space-separated flags to be passed to the underlying aot compiler invocation; "
"for more info, see 'saved_model_cli --helpfull'",
)

args = parser.parse_args()

Expand All @@ -244,6 +260,7 @@ def main() -> None:
output_serving_key=args.output_serving_key,
compile_prefix=args.compile and args.compile[0],
compile_class=args.compile and args.compile[1],
additional_flags=args.additional_flags,
)


Expand Down

0 comments on commit 0ab4f98

Please sign in to comment.