Skip to content

Commit 34cbba2

Browse files
committed
Fix bug in gene filtering logic
1 parent 0ccb39a commit 34cbba2

File tree

1 file changed

+46
-35
lines changed

1 file changed

+46
-35
lines changed

src/allium_prepro/gex_preprocessor.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,13 @@ def _preprocess_genes(self):
9595
translated_genes = self._gt.translate_genes(data.index.values,
9696
source='ensembl_id',
9797
target='symbol')
98-
data['gene_name_std'] = data.index.map(translated_genes).fillna("")
98+
99+
def _standardize_ensembl_name(ensembl_name):
100+
if ensembl_name in translated_genes:
101+
return translated_genes[ensembl_name]
102+
return ensembl_name
103+
104+
data['gene_name_std'] = data.index.map(_standardize_ensembl_name)
99105
else:
100106
# Code to use if index is in symbol format
101107
# Get all gene name values to update to the latest standard
@@ -115,48 +121,54 @@ def _preprocess_genes(self):
115121
ref['gene_name_std'] = ref['name'].map(
116122
updated_genes).fillna(ref['name'])
117123

118-
# Drop all rows in data where gene_name_std
119-
# does not appear in ref['gene_name_std']
120-
data = data[data['gene_name_std'].isin(ref['gene_name_std'])]
124+
filtered_rows = {}
125+
case_columns = [col for col in data.columns if
126+
re.match(self._sample_col_regex, col)]
121127

122-
# Join the two dataframes on the gene_name_std column
123-
data = data.join(ref.set_index('gene_name_std'), on='gene_name_std')
128+
def _find_ref_key(row):
129+
if row.name in ref['id'].values:
130+
return ref[ref['id'] == row.name]['id'].values[0]
131+
if row['gene_name_std'] in ref['gene_name_std'].values:
132+
ref_rows = ref[ref['gene_name_std'] == row['gene_name_std']]
124133

125-
# Get duplicates
126-
duplicates = data[data.duplicated(subset='gene_name_std', keep=False)]
134+
if len(ref_rows) == 1:
135+
return ref_rows['id'].values[0]
127136

128-
# Get all unique indices from duplicates
129-
duplicate_gene_names = duplicates['gene_name_std'].unique()
137+
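# See if one of the rows also matches a ref id. NOTE(review): ref_rows is a subset of ref, so this test looks always true and the first row wins - confirm intent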
138+
for i, r in ref_rows.iterrows():
139+
if r['id'] in ref['id'].values:
140+
return r['id']
130141

131-
# For each row in duplicates,
132-
# sum the values of the rows with the same name
133-
# and update the row with the sum. Then drop the duplicates.
134-
case_columns = [col for col in duplicates.columns if
135-
re.match(self._sample_col_regex, col)]
142+
# Fall back to the id of the first matching ref row
143+
return ref_rows['id'].values[0]
136144

137-
for duplicate_gene in duplicate_gene_names:
138-
# Get all rows with the same id
139-
dupes = data[data['gene_name_std'] == duplicate_gene]
145+
for i, row in data.iterrows():
146+
# Look up the reference id this row should be aggregated under
147+
key = _find_ref_key(row)
140148

141-
# Sum the rows
142-
new_counts = dupes[case_columns].sum()
149+
# When a key was found, strip surrounding whitespace; abort if it is not a string
150+
if key is not None:
151+
try:
152+
key = key.strip()
153+
except Exception:
154+
print(f'Error: {key}')
155+
exit()
143156

144-
# Update the corresponding data rows with the new counts
145-
data.loc[
146-
data['gene_name_std'] == duplicate_gene, case_columns
147-
] = new_counts.values
157+
# Just keep case cols in row
158+
row = row[case_columns]
148159

149-
# Keep only the first record of all duplicates
150-
data = data.drop_duplicates(subset='gene_name_std')
160+
# Is there already a row in filtered_rows?
161+
if key in filtered_rows:
162+
filtered_rows[key] += row
163+
else:
164+
filtered_rows[key] = row
165+
# Name the row after its reference key so the rebuilt frame is indexed by it
166+
filtered_rows[key].name = key
151167

152-
# Print records in ref.id that are not in data.id
153-
missing = ref[~ref['id'].isin(data['id'])]
154-
155-
# Change index to id column
156-
data = data.set_index('id')
168+
data = pd.DataFrame(filtered_rows.values())
157169

158-
# Drop all columns except for the case columns
159-
data = data[case_columns]
170+
# Collect records in ref.id that are not in the data index
171+
missing = ref[~ref['id'].isin(data.index)]
160172

161173
# Create records for all missing genes in data, filled with 0s
162174
# The missing$id is the index value, and all the case columns are 0
@@ -165,8 +177,7 @@ def _preprocess_genes(self):
165177
data=0)
166178

167179
# Dump missing data to file
168-
if not missing_data.empty:
169-
missing_data.to_csv(self._missing_genes_path)
180+
missing_data.to_csv(self._missing_genes_path)
170181

171182
# Remove index name
172183
missing_data.index.name = None

0 commit comments

Comments
 (0)