Skip to content
15 changes: 6 additions & 9 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ def initialize(
:return: List of data returned from initialization of modifiers
:rtype: List[Any]
"""
self.state.update(**kwargs)
if self.initialized_: # TODO: do not initialize twice
return
if self.initialized_:
raise ValueError(
"Initialize was called twice. To update state values prior to "
"initialization, please use `active_session().state.update()`"
)

self.state.update(**kwargs)
logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
Expand Down Expand Up @@ -215,12 +218,6 @@ def _check_setup_event_lifecycle(self, event_type: EventType):
"Cannot invoke event before recipe, model, and start are set"
)

if not self.state.compression_ready:
logger.error("Cannot invoke event before recipe, model, and start are set")
raise ValueError(
"Cannot invoke event before recipe, model, and start are set"
)

logger.debug("Setting up event lifecycle for event type: {}", event_type)

for mod in self.modifiers:
Expand Down
12 changes: 0 additions & 12 deletions src/llmcompressor/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,6 @@ class State:
model_log_cadence: Optional[float] = None
_last_log_step: Union[float, int, None] = None

@property
def compression_ready(self) -> bool:
    """
    Report whether this state holds both a model and an optimizer.

    The result is also written to the debug log on every access.

    :return: True when neither ``self.model`` nor ``self.optimizer`` is
        None, False otherwise
    :rtype: bool
    """
    # De Morgan form of "model is set and optimizer is set"; the optimizer
    # is only inspected when the model is present.
    something_missing = self.model is None or self.optimizer is None
    is_ready = not something_missing
    logger.debug("Compression ready: {}", is_ready)
    return is_ready

def update(
self,
model: Any = None,
Expand Down
43 changes: 20 additions & 23 deletions src/llmcompressor/transformers/finetune/session_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,25 @@ def initialize_session(

train_data = self.get_train_dataloader()

# calculate total_steps_per_epoch
# n_gpu handled internally by dataloader
total_batch_size = (
self.args.per_device_train_batch_size
* self.args.gradient_accumulation_steps
)
if isinstance(self.train_dataset, IterableDataset):
logger.warning(
"Training is being run with a streamed dataset, "
"steps_per_epoch cannot be determined and will default to "
"1. LLM Compressor modifiers utilizing this statistic may not "
"behave as expected. "
)
self.total_steps_per_epoch = 1
else:
self.total_steps_per_epoch = math.ceil(
len(self.train_dataset) / total_batch_size
)

self.accelerator.wait_for_everyone()
with summon_full_params_context(self.model, offload_to_cpu=True):
active_session().initialize(
Expand All @@ -156,6 +175,7 @@ def initialize_session(
copy_data=False,
attach_optim_callbacks=True,
fsdp_active=self.is_fsdp_enabled,
steps_per_epoch=self.total_steps_per_epoch,
metadata=self.metadata,
)

Expand Down Expand Up @@ -199,29 +219,6 @@ def create_optimizer(self):
self._check_super_defined("create_optimizer")
super().create_optimizer()

# n_gpu handled internally by dataloader
total_batch_size = (
self.args.per_device_train_batch_size
* self.args.gradient_accumulation_steps
)

if isinstance(self.train_dataset, IterableDataset):
logger.warning(
"Training is being run with a streamed dataset, "
"steps_per_epoch cannot be determined and will default to "
"1. LLM Compressor modifiers utilizing this statistic may not "
"behave as expected. "
)
self.total_steps_per_epoch = 1
else:
self.total_steps_per_epoch = math.ceil(
len(self.train_dataset) / total_batch_size
)

active_session().initialize(
optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
)

return self.optimizer

def create_scheduler(
Expand Down
10 changes: 0 additions & 10 deletions tests/unit/core/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,6 @@ def test_state_update():
assert state.model_log_cadence == 2


@pytest.mark.regression
def test_state_sparsification_ready():
    """
    Check that ``State.compression_ready`` turns truthy once both the model
    and the optimizer attributes have been assigned.
    """
    fresh_state = State()

    # Nothing has been assigned yet, so the state must not report ready.
    assert not fresh_state.compression_ready

    # Placeholder strings are enough here: only non-None values are assigned.
    fresh_state.model = "model"
    fresh_state.optimizer = "optimizer"

    assert fresh_state.compression_ready


@pytest.mark.regression
def test_state_update_loggers():
state = State()
Expand Down