NeuralNetworkClassifier callback function not being called during training #994

@MohammedAlmajhali

Description

Environment

  • Qiskit Machine Learning version: 0.8.4
  • Python version: 3.12.11
  • Operating system: Google Colab (Ubuntu) Linux-6.6.97+-x86_64-with-glibc2.35

What is happening?

The callback function passed to NeuralNetworkClassifier is not executed during training: it produces no output at all, and none of the workarounds listed under "Any suggestions?" changed this.

How can we reproduce the issue?

import numpy as np
from qiskit.circuit.library import ZZFeatureMap, RealAmplitudes
from qiskit.primitives import Estimator
from qiskit.quantum_info import SparsePauliOp
from qiskit_algorithms.optimizers import ADAM
from qiskit_machine_learning.neural_networks import EstimatorQNN
from qiskit_machine_learning.algorithms.classifiers import NeuralNetworkClassifier

# Generate sample data
np.random.seed(42)
X = np.random.randn(50, 4)  # 50 samples, 4 features
y = np.random.randint(0, 3, 50)  # 3 classes

# Create quantum circuit
n_qubits = 4
feature_map = ZZFeatureMap(feature_dimension=4, reps=2)
ansatz = RealAmplitudes(num_qubits=n_qubits, reps=3)
qc = feature_map.compose(ansatz)

# Create observables for 3-class classification
observables = []
for idx in range(3):
    pauli = ['I'] * n_qubits
    pauli[-1 - idx] = 'Z'
    observables.append(SparsePauliOp(''.join(pauli)))

# Create callback function
def callback_function(weights, obj_func_eval):
    print(f"Callback called! Loss: {obj_func_eval}", flush=True)

# Create QNN
qnn = EstimatorQNN(
    estimator=Estimator(),
    circuit=qc,
    observables=observables,
    input_params=list(feature_map.parameters),
    weight_params=list(ansatz.parameters),
)

# Create classifier with callback
optimizer = ADAM(maxiter=10)
vqc = NeuralNetworkClassifier(
    neural_network=qnn,
    loss='cross_entropy',
    one_hot=True,
    optimizer=optimizer,
    callback=callback_function
)

# Train - callback should print but doesn't
print("Starting training...")
vqc.fit(X, y)
print("Training complete")
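
For reference, the observable loop in the reproduction builds one single-qubit Z label per class. Qiskit Pauli labels are little-endian (the rightmost character is qubit 0), so the three observables measure Z on qubits 0, 1 and 2. The snippet below repeats only the string-building step in plain Python, without Qiskit, so the labels can be checked directly:

```python
# Repeats only the label-building loop from the reproduction (no Qiskit needed).
n_qubits = 4
num_classes = 3  # one observable per class, as in the reproduction

labels = []
for idx in range(num_classes):
    pauli = ["I"] * n_qubits
    # In Qiskit's little-endian labels the last character is qubit 0,
    # so index -1 - idx places Z on qubit idx.
    pauli[-1 - idx] = "Z"
    labels.append("".join(pauli))

print(labels)  # ['IIIZ', 'IIZI', 'IZII']
```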

What should happen?

The callback function should be called at each iteration and print:

Starting training...
Callback called! Loss: 0.xxx
Callback called! Loss: 0.xxx
...
Training complete
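
The expected output above corresponds to a callback that receives the current weights and loss once per iteration. One common way for a trainer to provide this is to wrap the objective function it hands to the optimizer; the sketch below assumes that design and is written for illustration only. `wrap_objective` and `gradient_descent` are hypothetical helpers, not Qiskit code, and the quadratic `loss` is a stand-in for the classifier's cross-entropy. An optimizer that evaluates the objective on every iteration then triggers the callback once per iteration:

```python
from typing import Callable, List

Weights = List[float]


def wrap_objective(
    objective: Callable[[Weights], float],
    callback: Callable[[Weights, float], None],
) -> Callable[[Weights], float]:
    """Hypothetical helper: report (weights, loss) after every objective evaluation."""
    def wrapped(weights: Weights) -> float:
        value = objective(weights)
        callback(weights, value)
        return value
    return wrapped


def gradient_descent(fun, grad, x0: Weights, maxiter: int = 10, lr: float = 0.1) -> Weights:
    """Hypothetical toy optimizer that evaluates the objective once per iteration."""
    x = list(x0)
    for _ in range(maxiter):
        fun(x)  # this evaluation is what fires the wrapped callback
        g = grad(x)
        x = [xi - lr * gi for xi, gi in zip(x, g)]
    return x


def loss(x: Weights) -> float:
    """Stand-in loss f(x) = sum(x_i^2)."""
    return sum(xi * xi for xi in x)


def loss_grad(x: Weights) -> Weights:
    """Gradient of the stand-in loss: 2x."""
    return [2.0 * xi for xi in x]


history: List[float] = []


def callback_function(weights: Weights, obj_func_eval: float) -> None:
    # Same signature as the callback in the reproduction.
    history.append(obj_func_eval)
    print(f"Callback called! Loss: {obj_func_eval:.4f}", flush=True)


gradient_descent(wrap_objective(loss, callback_function), loss_grad, [1.0, -2.0])
print(len(history))  # 10
```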

Any suggestions?

I've tried:

  1. Different callback signatures: (weights, obj_func_eval), (*args), etc.
  2. Using sys.stdout.flush() and print(..., flush=True)
  3. With ADAM from qiskit_algorithms: from qiskit_algorithms.optimizers import ADAM
  4. With a SciPy-based optimizer: from qiskit_algorithms.optimizers import L_BFGS_B
  5. Using the logging module instead of print

None of these produced any callback output.
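
Since both ADAM and L_BFGS_B were tried, it may help to count how often an optimizer evaluates the objective versus its gradient, on the unverified assumption that the callback is tied to objective evaluations. The sketch below does this for `toy_adam`, a hypothetical Adam loop written for illustration (not the `qiskit_algorithms` ADAM implementation) that uses only the gradient inside its loop and evaluates the loss once at the end. Under that design, a callback attached to the objective would fire exactly once, however large `maxiter` is:

```python
import math
from typing import Callable, Dict, List

Weights = List[float]
calls: Dict[str, int] = {"objective": 0, "gradient": 0}


def loss(x: Weights) -> float:
    calls["objective"] += 1  # a callback tied to the objective would fire here
    return sum(xi * xi for xi in x)


def loss_grad(x: Weights) -> Weights:
    calls["gradient"] += 1
    return [2.0 * xi for xi in x]


def toy_adam(
    fun: Callable[[Weights], float],
    jac: Callable[[Weights], Weights],
    x0: Weights,
    maxiter: int = 10,
    lr: float = 0.1,
    beta1: float = 0.9,
    beta2: float = 0.99,
    eps: float = 1e-8,
) -> float:
    """Hypothetical Adam loop: gradient-only iterations, one objective call at the end."""
    x = list(x0)
    m = [0.0] * len(x)
    v = [0.0] * len(x)
    for t in range(1, maxiter + 1):
        g = jac(x)
        m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
        v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
        m_hat = [mi / (1 - beta1 ** t) for mi in m]  # bias-corrected first moment
        v_hat = [vi / (1 - beta2 ** t) for vi in v]  # bias-corrected second moment
        x = [xi - lr * mh / (math.sqrt(vh) + eps) for xi, mh, vh in zip(x, m_hat, v_hat)]
    return fun(x)  # the only objective evaluation


final_loss = toy_adam(loss, loss_grad, [1.0, -2.0], maxiter=10)
print(calls)  # {'objective': 1, 'gradient': 10}
```

Wrapping the real objective and gradient with similar counters would show whether the objective is evaluated during `fit` at all.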
