Modifications for Classifier Pipeline #189

Open · wants to merge 16 commits into base: master
Changes from 14 commits
111 changes: 110 additions & 1 deletion adsmp/app.py
@@ -3,7 +3,7 @@
from past.builtins import basestring
from . import exceptions
from adsmp.models import ChangeLog, IdentifierMapping, MetricsBase, MetricsModel, Records
from adsmsg import OrcidClaims, DenormalizedRecord, FulltextUpdate, MetricsRecord, NonBibRecord, NonBibRecordList, MetricsRecordList, AugmentAffiliationResponseRecord, AugmentAffiliationRequestRecord
from adsmsg import OrcidClaims, DenormalizedRecord, FulltextUpdate, MetricsRecord, NonBibRecord, NonBibRecordList, MetricsRecordList, AugmentAffiliationResponseRecord, AugmentAffiliationRequestRecord, ClassifyRequestRecord, ClassifyRequestRecordList, ClassifyResponseRecord, ClassifyResponseRecordList
from adsmsg.msg import Msg
from adsputils import ADSCelery, create_engine, sessionmaker, scoped_session, contextmanager
from sqlalchemy.orm import load_only as _load_only
@@ -19,6 +19,7 @@
from copy import deepcopy
import sys
from sqlalchemy.dialects.postgresql import insert
import csv
from google.protobuf.json_format import MessageToJson  # assumed import: MessageToJson is used in request_classify below but is not visible in this hunk


class ADSMasterPipelineCelery(ADSCelery):
@@ -114,6 +115,12 @@ def update_storage(self, bibcode, type, payload):
oldval = 'not-stored'
r.augments = payload
r.augments_updated = now
elif type == 'classify':
# payload contains the new value for the collections field
# r.classifications holds a list, save it in the database
oldval = 'not-stored'
r.classifications = payload
r.classifications_updated = now
else:
raise Exception('Unknown type: %s' % type)
session.add(ChangeLog(key=bibcode, type=type, oldvalue=oldval))
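
For illustration, a minimal sketch of how the new 'classify' branch might be exercised (the bibcode and collection names are invented, and app is assumed to be an ADSMasterPipelineCelery instance):

# store classifier output for one record; the list lands in Records.classifications
app.update_storage('2020ApJ...900...42X', 'classify', ['Astronomy', 'Physics'])
# the branch above also stamps classifications_updated with the current time
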
@@ -216,6 +223,8 @@ def get_msg_type(self, msg):
return 'metrics_records'
elif isinstance(msg, AugmentAffiliationResponseRecord):
return 'augment'
elif isinstance(msg, ClassifyResponseRecord):
return 'classify'

else:
raise exceptions.IgnorableException('Unkwnown type {0} submitted for update'.format(repr(msg)))
@@ -510,6 +519,106 @@ def request_aff_augment(self, bibcode, data=None):
else:
self.logger.debug('request_aff_augment called but bibcode {} has no aff data'.format(bibcode))

def prepare_bibcode(self, bibcode):
"""prepare data for classifier pipeline

Parameters
----------
bibcode : reference ID for record (needs to include SciXID)

"""
rec = self.get_record(bibcode)  # assumed lookup: 'rec' is used below but never assigned in this hunk
if rec is None:
self.logger.warning('request_classify called but no data at all for bibcode {}'.format(bibcode))
return
bib_data = rec.get('bib_data', None)
if bib_data is None:
self.logger.warning('request_classify called but no bib data for bibcode {}'.format(bibcode))
return
title = bib_data.get('title', '')
abstract = bib_data.get('abstract', '')
data = {
'bibcode': bibcode,
'title': title,
'abstract': abstract,
}
return data

def request_classify(self, bibcode=None, scix_id=None, filename=None, mode='auto', batch_size=500, data=None, check_boolean=False, operation_step=None):
""" send classifier request for bibcode to classifier pipeline

set data parameter to provide test data

Parameters
----------
bibcode : reference ID for record (needs to include SciXID)
scix_id : reference ID for record
filename : filename of input file with list of records to classify
mode : 'auto' (default) assumes single record input from master, 'manual' assumes multiple records input at command line
batch_size : size of batch for large input files
check_boolean : used for testing - writes the message to file instead of sending it
operation_step : string - defines mode of operation: classify, classify_verify, or verify

"""
self.logger.info('request_classify called with bibcode={}, filename={}, mode={}, batch_size={}, data={}, check_boolean={}'.format(bibcode, filename, mode, batch_size, data, check_boolean))

if not self._config.get('OUTPUT_TASKNAME_CLASSIFIER'):
self.logger.warning('request_classify called but no classifier taskname in config')
return
if not self._config.get('OUTPUT_CELERY_BROKER_CLASSIFIER'):
self.logger.warning('request_classify called but no classifier broker in config')
return

if bibcode is not None and mode == 'auto':
if data is None:
data = self.prepare_bibcode(bibcode)
if data and data.get('title'):
data['operation_step'] = operation_step
message = ClassifyRequestRecord(**data) # Maybe make as one element list check protobuf
self.forward_message(message, pipeline='classifier')
self.logger.debug('sent classifier request for bibcode {}'.format(bibcode))
else:
self.logger.debug('request_classify called but bibcode {} has no title data'.format(bibcode))
if filename is not None and mode == 'manual':
batch_idx = 0
batch_list = []
self.logger.info('request_classify called with filename {}'.format(filename))
with open(filename, 'r') as f:
reader = csv.DictReader(f)

Contributor: It might make more sense to move this to run.py so that this code isn't being pulled into the celery workers.

bibcodes = [row for row in reader]
while batch_idx < len(bibcodes):
bibcodes_batch = bibcodes[batch_idx:batch_idx+batch_size]
for record in bibcodes_batch:
if record.get('title') or record.get('abstract'):
data = record
else:
data = self.prepare_bibcode(record.get('bibcode'))  # prepare_bibcode expects a bibcode string, not the CSV row
if data and data.get('title'):
batch_list.append(data)
if len(batch_list) > 0:
message = ClassifyRequestRecordList()
for item in batch_list:
entry = message.classify_requests.add()
entry.bibcode = item.get('bibcode')
entry.title = item.get('title')
entry.abstract = item.get('abstract')
output_taskname=self._config.get('OUTPUT_TASKNAME_CLASSIFIER')
output_broker=self._config.get('OUTPUT_CELERY_BROKER_CLASSIFIER')
if check_boolean is True:
# Save message to file
# with open('classifier_request.json', 'w') as f:
# f.write(str(message))
json_message = MessageToJson(message)
with open('classifier_request.json', 'w') as f:
f.write(json_message)
else:
self.logger.info('Sending message for batch')
self.logger.info('sending message {}'.format(message))
self.forward_message(message, pipeline='classifier')
self.logger.debug('sent classifier request for batch {}'.format(batch_idx))

batch_idx += batch_size
batch_list = []
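
For illustration only, a rough sketch of how this could be driven from run.py, as the reviewer suggests above; the file name, bibcode, and instantiation here are assumptions, not part of this diff:

app = ADSMasterPipelineCelery('master-pipeline')
# single record pushed from master
app.request_classify(bibcode='2020ApJ...900...42X', mode='auto', operation_step='classify')
# batch run over a CSV with bibcode, title and abstract columns
app.request_classify(filename='bibcodes_to_classify.csv', mode='manual', batch_size=500, operation_step='classify_verify')
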

def generate_links_for_resolver(self, record):
"""use nonbib or bib elements of database record and return links for resolver and checksum"""
# nonbib data has something like
5 changes: 4 additions & 1 deletion adsmp/models.py
@@ -57,6 +57,7 @@ class Records(Base):
# currently only supported key is 'affiliations'
# with the value an array holding affiliation strings and '-' placeholders
augments = Column(Text)
classifications = Column(Text)

# when data is received we set the updated timestamp
bib_data_updated = Column(UTCDateTime, default=None)
@@ -65,6 +66,7 @@
fulltext_updated = Column(UTCDateTime, default=None)
metrics_updated = Column(UTCDateTime, default=None)
augments_updated = Column(UTCDateTime, default=None)
classifications_updated = Column(UTCDateTime, default=None)

created = Column(UTCDateTime, default=get_date)
updated = Column(UTCDateTime, default=get_date)
@@ -83,9 +85,10 @@ class Records(Base):
_date_fields = ['created', 'updated', 'processed', # dates
'bib_data_updated', 'orcid_claims_updated', 'nonbib_data_updated',
'fulltext_updated', 'metrics_updated', 'augments_updated',
'classifications_updated',
'datalinks_processed', 'solr_processed', 'metrics_processed']
_text_fields = ['id', 'bibcode', 'status', 'solr_checksum', 'metrics_checksum', 'datalinks_checksum']
_json_fields = ['bib_data', 'orcid_claims', 'nonbib_data', 'metrics', 'fulltext', 'augments']
_json_fields = ['bib_data', 'orcid_claims', 'nonbib_data', 'metrics', 'fulltext', 'augments', 'classifications']

def toJSON(self, for_solr=False, load_only=None):
if for_solr:
14 changes: 14 additions & 0 deletions adsmp/solr_updater.py
@@ -162,6 +162,18 @@ def extract_augments_pipeline(db_augments, solrdoc):
"institution": db_augments.get("institution", None),
}

def extract_classifications_pipeline(db_classifications, solrdoc):
"""retrieve expected classifier collections

classifications is a solr virtual field so it should never be set"""
if db_classifications is None or len(db_classifications) == 0:
return {"database" : solrdoc.get("database", None)}

# Append classifier results to classic collections
return {
"database" : list(set(db_classifications + solrdoc.get("database", [])))
}
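
To make the merge concrete, a small worked example (collection names are invented):

# classifier result stored on the database record
db_classifications = ['Astronomy']
# collections already on the Solr document from bib_data
solrdoc = {'database': ['Astronomy', 'Physics']}
extract_classifications_pipeline(db_classifications, solrdoc)
# -> {'database': ['Astronomy', 'Physics']}  (deduplicated; order not guaranteed because a set is used)
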


def extract_fulltext(data, solrdoc):
out = {}
@@ -311,6 +323,7 @@ def get_timestamps(db_record, out):
("fulltext", extract_fulltext),
("#timestamps", get_timestamps), # use 'id' to be always called
("augments", extract_augments_pipeline), # over aff field, adds aff_*
("classifications", extract_classifications_pipeline), # overwrites databse field in bib_data
]


@@ -467,6 +480,7 @@ def transform_json_record(db_record):
)
)


# Compute doctype scores on the fly
out["doctype_boost"] = None

14 changes: 12 additions & 2 deletions adsmp/tasks.py
@@ -35,9 +35,12 @@ def task_update_record(msg):
- bibcode
- and specific payload
"""
logger.debug('Updating record: %s', msg)
# logger.debug('Updating record: %s', msg)
logger.info('Updating record: %s', msg)
status = app.get_msg_status(msg)
logger.info(f'Message status: {status}')

Contributor: These could probably become debug statements long term so we aren't flooding the logs.

Author: Changes made and committed.

type = app.get_msg_type(msg)
logger.info(f'Message type: {type}')
bibcodes = []

if status == 'deleted':
@@ -84,7 +87,14 @@ def task_update_record(msg):
msg.toJSON(including_default_value_fields=True))
if record:
logger.debug('Saved augment message: %s', msg)

elif type == 'classify':
bibcodes.append(msg.bibcode)
logger.info(f'message to JSON: {msg.toJSON(including_default_value_fields=True)}')
payload = msg.toJSON(including_default_value_fields=True)
payload = payload['collections']
record = app.update_storage(msg.bibcode, 'classify', payload)
if record:
logger.debug('Saved classify message: %s', msg)
else:
# here when record has a single bibcode
bibcodes.append(msg.bibcode)
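
For orientation, a rough sketch of the classify response this branch consumes; the field names are inferred from the code above, and the values are invented:

# inferred shape of msg.toJSON(including_default_value_fields=True) for a ClassifyResponseRecord
payload = {'bibcode': '2020ApJ...900...42X', 'collections': ['Astronomy', 'Physics']}
# task_update_record keeps only the collections list and stores it:
app.update_storage(payload['bibcode'], 'classify', payload['collections'])
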
74 changes: 39 additions & 35 deletions alembic/versions/2d2af8a9c996_upgrade_change_log.py

Contributor: Do you know why this file changed? I am just a bit concerned because this alembic upgrade not matching the one that was used to upgrade the DB previously could pose an issue.

Author: Yeah, I was having an issue at one point so I added the if statement to check the database. I can revert it so it matches.

Author: Reverted file committed.

@@ -17,42 +17,46 @@


def upgrade():
op.execute('ALTER TABLE change_log ADD COLUMN big_id BIGINT;')
op.execute('CREATE OR REPLACE FUNCTION set_new_id() RETURNS TRIGGER AS\n'\
'$BODY$\n'\
'BEGIN\n'\
'\t NEW.big_id := NEW.id;\n'\
'\t RETURN NEW;\n'\
'END\n'\
'$BODY$ LANGUAGE PLPGSQL;\n'\
'CREATE TRIGGER set_new_id_trigger BEFORE INSERT OR UPDATE ON {}\n'\
'FOR EACH ROW EXECUTE PROCEDURE set_new_id();\n'.format('change_log'))
op.execute('UPDATE change_log SET big_id=id')
op.execute('CREATE UNIQUE INDEX IF NOT EXISTS big_id_unique ON change_log(big_id);')
op.execute('ALTER TABLE change_log ADD CONSTRAINT big_id_not_null CHECK (big_id IS NOT NULL) NOT VALID;')
op.execute('ALTER TABLE change_log VALIDATE CONSTRAINT big_id_not_null;')
op.execute('ALTER TABLE change_log DROP CONSTRAINT change_log_pkey, ADD CONSTRAINT change_log_pkey PRIMARY KEY USING INDEX big_id_unique;')
op.execute('ALTER SEQUENCE change_log_id_seq OWNED BY change_log.big_id;')
op.execute("ALTER TABLE change_log ALTER COLUMN big_id SET DEFAULT nextval('change_log_id_seq');")
op.execute("ALTER TABLE change_log RENAME COLUMN id TO old_id;")
op.execute("ALTER TABLE change_log RENAME COLUMN big_id TO id;")
op.drop_column('change_log', 'old_id')
op.execute('ALTER SEQUENCE change_log_id_seq as bigint MAXVALUE 9223372036854775807')
op.execute('DROP TRIGGER IF EXISTS set_new_id_trigger ON change_log')
# ### end Alembic commands ###
conn = op.get_bind()
if conn.dialect.name == 'postgresql':
op.execute('ALTER TABLE change_log ADD COLUMN big_id BIGINT;')
op.execute('CREATE OR REPLACE FUNCTION set_new_id() RETURNS TRIGGER AS\n'\
'$BODY$\n'\
'BEGIN\n'\
'\t NEW.big_id := NEW.id;\n'\
'\t RETURN NEW;\n'\
'END\n'\
'$BODY$ LANGUAGE PLPGSQL;\n'\
'CREATE TRIGGER set_new_id_trigger BEFORE INSERT OR UPDATE ON {}\n'\
'FOR EACH ROW EXECUTE PROCEDURE set_new_id();\n'.format('change_log'))
op.execute('UPDATE change_log SET big_id=id')
op.execute('CREATE UNIQUE INDEX IF NOT EXISTS big_id_unique ON change_log(big_id);')
op.execute('ALTER TABLE change_log ADD CONSTRAINT big_id_not_null CHECK (big_id IS NOT NULL) NOT VALID;')
op.execute('ALTER TABLE change_log VALIDATE CONSTRAINT big_id_not_null;')
op.execute('ALTER TABLE change_log DROP CONSTRAINT change_log_pkey, ADD CONSTRAINT change_log_pkey PRIMARY KEY USING INDEX big_id_unique;')
op.execute('ALTER SEQUENCE change_log_id_seq OWNED BY change_log.big_id;')
op.execute("ALTER TABLE change_log ALTER COLUMN big_id SET DEFAULT nextval('change_log_id_seq');")
op.execute("ALTER TABLE change_log RENAME COLUMN id TO old_id;")
op.execute("ALTER TABLE change_log RENAME COLUMN big_id TO id;")
op.drop_column('change_log', 'old_id')
op.execute('ALTER SEQUENCE change_log_id_seq as bigint MAXVALUE 9223372036854775807')
op.execute('DROP TRIGGER IF EXISTS set_new_id_trigger ON change_log')
# ### end Alembic commands ###


def downgrade():
op.add_column('change_log', sa.Column('small_id', sa.Integer(), unique=True))
op.execute('DELETE FROM change_log WHERE id > 2147483647')
op.execute('UPDATE change_log SET small_id=id')
op.alter_column('change_log', 'small_id', nullable=False)
op.drop_constraint('change_log_pkey', 'change_log', type_='primary')
op.create_primary_key("change_log_pkey", "change_log", ["small_id", ])
op.execute('ALTER SEQUENCE change_log_id_seq OWNED BY change_log.small_id;')
op.execute("ALTER TABLE change_log ALTER COLUMN small_id SET DEFAULT nextval('change_log_id_seq');")
op.alter_column('change_log', 'id', nullable=False, new_column_name='old_id')
op.alter_column('change_log', 'small_id', nullable=False, new_column_name='id')
op.drop_column('change_log', 'old_id')
op.execute('ALTER SEQUENCE change_log_id_seq as int MAXVALUE 2147483647')
conn = op.get_bind()
if conn.dialect.name == 'postgresql':
op.add_column('change_log', sa.Column('small_id', sa.Integer(), unique=True))
op.execute('DELETE FROM change_log WHERE id > 2147483647')
op.execute('UPDATE change_log SET small_id=id')
op.alter_column('change_log', 'small_id', nullable=False)
op.drop_constraint('change_log_pkey', 'change_log', type_='primary')
op.create_primary_key("change_log_pkey", "change_log", ["small_id", ])
op.execute('ALTER SEQUENCE change_log_id_seq OWNED BY change_log.small_id;')
op.execute("ALTER TABLE change_log ALTER COLUMN small_id SET DEFAULT nextval('change_log_id_seq');")
op.alter_column('change_log', 'id', nullable=False, new_column_name='old_id')
op.alter_column('change_log', 'small_id', nullable=False, new_column_name='id')
op.drop_column('change_log', 'old_id')
op.execute('ALTER SEQUENCE change_log_id_seq as int MAXVALUE 2147483647')

39 changes: 39 additions & 0 deletions alembic/versions/6e98dcc397e6_add_classifications_column.py

Contributor: What is the difference between the classifications column and the collections column?

Author: We decided to use the name classifications so it would not be confused with the existing Solr collections field. The later commit with classifications was meant to fix the earlier one with collections.

@@ -0,0 +1,39 @@
"""add_classifications_column

Revision ID: 6e98dcc397e6
Revises: 2d2af8a9c996
Create Date: 2025-02-28 08:52:00.341542

"""

# revision identifiers, used by Alembic.
revision = '6e98dcc397e6'
down_revision = '2d2af8a9c996'

from alembic import op
import sqlalchemy as sa



def upgrade():
# sqlite has only limited ALTER TABLE support, so use batch mode
cx = op.get_context()
if 'sqlite' in cx.connection.engine.name:
with op.batch_alter_table("records") as batch_op:
batch_op.add_column(sa.Column('classifications', sa.Text))
batch_op.add_column(sa.Column('classifications_updated', sa.TIMESTAMP))
else:
op.add_column('records', sa.Column('classifications', sa.Text))
op.add_column('records', sa.Column('classifications_updated', sa.TIMESTAMP))


def downgrade():
cx = op.get_context()
if 'sqlite' in cx.connection.engine.name:
with op.batch_alter_table("records") as batch_op:
batch_op.drop_column('classifications')
batch_op.drop_column('classifications_updated')
else:
op.drop_column('records', 'classifications')
op.drop_column('records', 'classifications_updated')
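
After running the migration, a quick sanity check that the new columns exist (sketch only; the database URL is a placeholder):

from sqlalchemy import create_engine, inspect

engine = create_engine('sqlite:///adsmp.sqlite')  # placeholder; point at the pipeline's configured database
cols = {c['name'] for c in inspect(engine).get_columns('records')}
assert {'classifications', 'classifications_updated'} <= cols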
