Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/api_docs/modules/debug_config.html
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ <h2>DebugConfig<a class="headerlink" href="#debugconfig" title="Link to this hea
<span class="gp">&gt;&gt;&gt; </span><span class="n">core_config</span> <span class="o">=</span> <span class="n">mct</span><span class="o">.</span><span class="n">core</span><span class="o">.</span><span class="n">CoreConfig</span><span class="p">(</span><span class="n">debug_config</span><span class="o">=</span><span class="n">debug_config</span><span class="p">)</span>
</pre></div>
</div>
<div class="admonition important">
<p class="admonition-title">Important</p>
<p>If a callback function is configured, the GPTQ data iteration progress bar is disabled and not displayed.</p>
</div>
</dd></dl>

</section>
Expand Down
2 changes: 1 addition & 1 deletion docs/searchindex.js

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@
# Key used to look up the user-supplied progress callback in DebugConfig.
PROGRESS_INFO_CALLBACK = 'progress_info_callback'
# Key for the total number of progress steps reported to the callback.
TOTAL_STEP = 'total_step'

# tqdm `position` for the GPTQ progress bar (row index on screen).
# NOTE: the pasted diff showed both the old value (2) and the new value (1);
# only the final binding is kept — the earlier duplicate assignment was dead.
PROGRESS_BAR_POSITION = 1
# Default number of top-level progress steps when none is computed.
DEFAULT_TOTAL_STEP = 4
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def set_description(self, description: str):
self.close()
raise

self.pbar.set_description(formatted_description, refresh=False)
self.pbar.update()
self.pbar.n += 1
self.pbar.set_description(formatted_description, refresh=True)

progress_info = {
COMPLETED_COMPONENTS: description,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ class DebugConfig:
>>> import model_compression_toolkit as mct
>>> debug_config = mct.core.DebugConfig(progress_info_callback=progress_info_callback)
>>> core_config = mct.core.CoreConfig(debug_config=debug_config)

.. important::
If a callback function is configured, the GPTQ data iteration progress bar is disabled and not displayed.

"""

analyze_similarity: bool = False
Expand Down
1 change: 1 addition & 0 deletions model_compression_toolkit/gptq/common/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self,
self.fw_info = fw_info
self.representative_data_gen_fn = representative_data_gen_fn
self.progress_info_controller = progress_info_controller
self.disable_data_pbar = progress_info_controller is not None

def _get_total_grad_steps():
return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/gptq/keras/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def micro_training_loop(self,
"""
with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
for _ in epochs_pbar:
with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
with tqdm(self.train_dataloader, position=1, leave=False, disable=self.disable_data_pbar) as data_pbar:
for data in data_pbar:

input_data, distill_loss_weights, reg_weight = data
Expand Down
155 changes: 79 additions & 76 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Callable, Tuple, Union, Optional
from packaging import version
from tqdm.contrib.logging import logging_redirect_tqdm

from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
Expand Down Expand Up @@ -232,82 +233,84 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da

"""

if core_config.debug_config.bypass:
return in_model, None

KerasModelValidation(model=in_model,
fw_info=DEFAULT_KERAS_INFO).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
"Ensure usage of the correct API for keras_post_training_quantization "
"or provide a valid mixed-precision configuration.") # pragma: no cover

tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

fw_impl = GPTQKerasImplemantation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
framework_platform_capabilities = attach2keras.attach(
target_platform_capabilities,
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)

progress_info_controller = ProgressInfoController(
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
description="MCT Keras GPTQ Progress",
progress_info_callback=core_config.debug_config.progress_info_callback
)

tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=DEFAULT_KERAS_INFO,
fw_impl=fw_impl,
fqc=framework_platform_capabilities,
target_resource_utilization=target_resource_utilization,
tb_w=tb_w,
running_gptq=True,
progress_info_controller=progress_info_controller)

float_graph = copy.deepcopy(tg)

tg_gptq = gptq_runner(tg,
core_config,
gptq_config,
representative_data_gen,
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
DEFAULT_KERAS_INFO,
fw_impl,
tb_w,
hessian_info_service=hessian_info_service,
progress_info_controller=progress_info_controller)

del hessian_info_service

if progress_info_controller is not None:
progress_info_controller.set_description("MCT Graph Finalization")

if core_config.debug_config.analyze_similarity:
analyzer_model_quantization(representative_data_gen,
tb_w,
float_graph,
tg_gptq,
fw_impl,
DEFAULT_KERAS_INFO)

exportable_model, user_info = get_exportable_keras_model(tg_gptq)
if framework_platform_capabilities.tpc.add_metadata:
exportable_model = add_metadata(exportable_model,
create_model_metadata(fqc=framework_platform_capabilities,
scheduling_info=scheduling_info))

if progress_info_controller is not None:
progress_info_controller.close()

return exportable_model, user_info
with logging_redirect_tqdm():

if core_config.debug_config.bypass:
return in_model, None

KerasModelValidation(model=in_model,
fw_info=DEFAULT_KERAS_INFO).validate()

if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
"Ensure usage of the correct API for keras_post_training_quantization "
"or provide a valid mixed-precision configuration.") # pragma: no cover

tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

fw_impl = GPTQKerasImplemantation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
framework_platform_capabilities = attach2keras.attach(
target_platform_capabilities,
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)

progress_info_controller = ProgressInfoController(
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
description="MCT Keras GPTQ Progress",
progress_info_callback=core_config.debug_config.progress_info_callback
)

tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
representative_data_gen=representative_data_gen,
core_config=core_config,
fw_info=DEFAULT_KERAS_INFO,
fw_impl=fw_impl,
fqc=framework_platform_capabilities,
target_resource_utilization=target_resource_utilization,
tb_w=tb_w,
running_gptq=True,
progress_info_controller=progress_info_controller)

float_graph = copy.deepcopy(tg)

tg_gptq = gptq_runner(tg,
core_config,
gptq_config,
representative_data_gen,
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
DEFAULT_KERAS_INFO,
fw_impl,
tb_w,
hessian_info_service=hessian_info_service,
progress_info_controller=progress_info_controller)

del hessian_info_service

if progress_info_controller is not None:
progress_info_controller.set_description("MCT Graph Finalization")

if core_config.debug_config.analyze_similarity:
analyzer_model_quantization(representative_data_gen,
tb_w,
float_graph,
tg_gptq,
fw_impl,
DEFAULT_KERAS_INFO)

exportable_model, user_info = get_exportable_keras_model(tg_gptq)
if framework_platform_capabilities.tpc.add_metadata:
exportable_model = add_metadata(exportable_model,
create_model_metadata(fqc=framework_platform_capabilities,
scheduling_info=scheduling_info))

if progress_info_controller is not None:
progress_info_controller.close()

return exportable_model, user_info

else:
# If tensorflow is not installed,
Expand Down
2 changes: 1 addition & 1 deletion model_compression_toolkit/gptq/pytorch/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def micro_training_loop(self,
"""
with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
for _ in epochs_pbar:
with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
with tqdm(self.train_dataloader, position=1, leave=False, disable=self.disable_data_pbar) as data_pbar:
for sample in data_pbar:
data, loss_weight, reg_weight = to_torch_tensor(sample)
input_data = [d * self.input_scale for d in data]
Expand Down
Loading
Loading