Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Net-to-CoreML Conversion Script #222

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions tf/net_to_coreml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""Convert an Lc0 net file to a CoreML ``mlprogram`` package.

First runs the net -> TF model conversion (``net_to_model.convert``), then
converts the resulting Keras model with coremltools, renames the model's
inputs/outputs to stable names, and saves the ``.mlpackage`` next to the
checkpoint directory.
"""
import os

import coremltools as ct

from net_to_model import convert

if __name__ == "__main__":
    ##############
    # NET TO MODEL
    # Attention-weight outputs are omitted and rule50 rescaling is disabled
    # so the exported graph matches what the CoreML consumer expects.
    args, root_dir, tfp = convert(include_attn_wts_output=False,
                                  rescale_rule50=False)

    #################
    # MODEL TO COREML
    # Lc0 network input: 112 feature planes of an 8x8 board, batch size 1.
    input_shape = ct.Shape(shape=(1, 112, 8, 8))

    # Set the compute precision. FLOAT16 roughly halves the package size;
    # switch to FLOAT32 if precision problems are observed.
    compute_precision = ct.precision.FLOAT16
    # compute_precision = ct.precision.FLOAT32

    # Convert the model to CoreML.
    coreml_model = ct.convert(
        tfp.model,
        convert_to="mlprogram",
        inputs=[ct.TensorType(shape=input_shape, name="input_1")],
        compute_precision=compute_precision,
    )

    # Get the protobuf spec; rename_feature() below mutates it in place.
    spec = coreml_model._spec

    # Rename the input to a stable, descriptive name.
    ct.utils.rename_feature(spec, "input_1", "input_planes")

    # Print the renamed input names.
    input_names = [inp.name for inp in spec.description.input]
    print(f"Renamed input: {input_names}")

    # Outputs are emitted in this fixed order by TFProcess.construct_net.
    output_names = ["output_policy", "output_value"]
    if tfp.moves_left:
        output_names.append("output_moves_left")

    # Rename the outputs.
    for i, name in enumerate(output_names):
        ct.utils.rename_feature(spec, spec.description.output[i].name, name)

    # Print the output names.
    print(f"Renamed output: {[output_i.name for output_i in spec.description.output]}")

    # Set the model description (written into the spec metadata, so it
    # survives the rebuild below).
    coreml_model.short_description = f"Lc0 converted from {args.net}"

    # Rebuild the model with the updated spec. The original MLModel caches
    # state from before the renames, so the rebuilt model is the one that
    # must be saved (previously the rebuilt model was created but the stale
    # original was saved instead).
    print("Rebuilding model with updated spec ...")
    rebuilt_mlmodel = ct.models.MLModel(
        coreml_model._spec, weights_dir=coreml_model._weights_dir
    )

    # Save the CoreML model.
    print("Saving model ...")
    coreml_model_path = os.path.join(root_dir, f"{args.net}.mlpackage")
    rebuilt_mlmodel.save(coreml_model_path)

    print(f"CoreML model saved at {coreml_model_path}")
61 changes: 33 additions & 28 deletions tf/net_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,38 @@
import yaml
import tfprocess

argparser = argparse.ArgumentParser(description='Convert net to model.')
argparser.add_argument('net',
type=str,
help='Net file to be converted to a model checkpoint.')
argparser.add_argument('--start',
type=int,
default=0,
help='Offset to set global_step to.')
argparser.add_argument('--cfg',
type=argparse.FileType('r'),
help='yaml configuration with training parameters')
argparser.add_argument('-e',
'--ignore-errors',
action='store_true',
help='Ignore missing and wrong sized values.')
args = argparser.parse_args()
cfg = yaml.safe_load(args.cfg.read())
print(yaml.dump(cfg, default_flow_style=False))
START_FROM = args.start
def convert(include_attn_wts_output=True, rescale_rule50=True):
    """Convert an Lc0 net file into a TF model checkpoint.

    Parses command-line arguments (net path, --start, --cfg, -e), builds the
    TFProcess network from the YAML config, loads the net's weights into it,
    and writes a checkpoint under ``<training.path>/<name>``.

    Args:
        include_attn_wts_output: if True, the built model also exposes the
            attention weights as an extra output (ignored by the engine
            backend, used for analysis).
        rescale_rule50: if True, rescale rule50-related input weights for
            older network formats whose clients do not normalize that plane.

    Returns:
        Tuple ``(args, root_dir, tfp)``: the parsed argparse namespace, the
        checkpoint root directory, and the initialized TFProcess instance.
    """
    argparser = argparse.ArgumentParser(description='Convert net to model.')
    argparser.add_argument('net',
                           type=str,
                           help='Net file to be converted to a model checkpoint.')
    argparser.add_argument('--start',
                           type=int,
                           default=0,
                           help='Offset to set global_step to.')
    argparser.add_argument('--cfg',
                           type=argparse.FileType('r'),
                           help='yaml configuration with training parameters')
    argparser.add_argument('-e',
                           '--ignore-errors',
                           action='store_true',
                           help='Ignore missing and wrong sized values.')
    args = argparser.parse_args()
    cfg = yaml.safe_load(args.cfg.read())
    print(yaml.dump(cfg, default_flow_style=False))
    # Local value, not a module constant — lower-case name.
    start_from = args.start

    tfp = tfprocess.TFProcess(cfg)
    tfp.init_net(include_attn_wts_output)
    tfp.replace_weights(args.net, args.ignore_errors, rescale_rule50)
    tfp.global_step.assign(start_from)

    root_dir = os.path.join(cfg['training']['path'], cfg['name'])
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(root_dir, exist_ok=True)
    tfp.manager.save(checkpoint_number=start_from)
    print("Wrote model to {}".format(tfp.manager.latest_checkpoint))
    return args, root_dir, tfp

# Script entry point: run the net -> model conversion with default options
# (attention-weight output included, rule50 rescaling enabled).
if __name__ == "__main__":
    convert()
16 changes: 9 additions & 7 deletions tf/tfprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,10 @@ def init(self, train_dataset, test_dataset, validation_dataset=None):
else:
self.init_net()

def init_net(self):
def init_net(self, include_attn_wts_output=True):
self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))
input_var = tf.keras.Input(shape=(112, 8, 8))
outputs = self.construct_net(input_var)
outputs = self.construct_net(input_var, include_attn_wts_output=include_attn_wts_output)
self.model = tf.keras.Model(inputs=input_var, outputs=outputs)

# swa_count initialized regardless to make checkpoint code simpler.
Expand Down Expand Up @@ -628,7 +628,7 @@ def accuracy(target, output):
keep_checkpoint_every_n_hours=24,
checkpoint_name=self.cfg['name'])

def replace_weights(self, proto_filename, ignore_errors=False):
def replace_weights(self, proto_filename, ignore_errors=False, rescale_rule50=True):
self.net.parse_proto(proto_filename)

filters, blocks = self.net.filters(), self.net.blocks()
Expand Down Expand Up @@ -676,7 +676,7 @@ def replace_weights(self, proto_filename, ignore_errors=False):

if weight.shape.ndims == 4:
# Rescale rule50 related weights as clients do not normalize the input.
if weight.name == 'input/conv2d/kernel:0' and self.net.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:
if rescale_rule50 and weight.name == 'input/conv2d/kernel:0' and self.net.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:
num_inputs = 112
# 50 move rule is the 110th input, or 109 starting from 0.
rule50_input = 109
Expand Down Expand Up @@ -1520,7 +1520,7 @@ def apply_promotion_logits(self, queries, keys, attn_wts):
h_fc1 = ApplyAttentionPolicyMap()(policy_attn_logits, promotion_logits)
return h_fc1

def construct_net(self, inputs, name=''):
def construct_net(self, inputs, name='', include_attn_wts_output=True):

if self.encoder_layers > 0:
flow, attn_wts = self.create_encoder_body(inputs,
Expand Down Expand Up @@ -1665,9 +1665,11 @@ def construct_net(self, inputs, name=''):
# attention weights added as optional output for analysis -- ignored by backend
if self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION:
if self.moves_left:
outputs = [h_fc1, h_fc3, h_fc5, attn_wts]
outputs = [h_fc1, h_fc3, h_fc5]
else:
outputs = [h_fc1, h_fc3, attn_wts]
outputs = [h_fc1, h_fc3]
if include_attn_wts_output:
outputs.append(attn_wts)
elif self.moves_left:
outputs = [h_fc1, h_fc3, h_fc5]
else:
Expand Down