Skip to content

Commit 6b25420

Browse files
committed
update SNP2GEX dataloader
1 parent 83bdd04 commit 6b25420

File tree

2 files changed

+120
-2
lines changed

2 files changed

+120
-2
lines changed

task1_SNP2GEX/Dataset.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from enformer_pytorch import Enformer,seq_indices_to_one_hot,str_to_one_hot
2+
import os
3+
import torch
4+
from tqdm import tqdm
5+
import pyfaidx
6+
import pandas as pd
7+
8+
9+
class SampleGeneExpressionDataset(torch.utils.data.Dataset):
10+
def __init__(self,split,csv_file,ref_path=ref_PATH,\
11+
consensus_root=fasta_ROOT,seq_len=32000,device='cuda',target_name='log_TPM',\
12+
selected_gene=[],selected_samples=[],return_ref=False):
13+
cur_split_file = pd.read_csv(csv_file)
14+
f1 = cur_split_file['split']==split
15+
f2 = cur_split_file['sample']=='ref'
16+
self.consensus_root = consensus_root
17+
self.target_name = target_name
18+
self.device = device
19+
self.dictionary = {'A': 0, 'T': 1, 'C': 2,'G':3}
20+
self.seq_len = 32000
21+
self.return_ref = return_ref
22+
23+
if(len(selected_samples)>0):
24+
f3 = cur_split_file['sample'].isin(selected_samples)
25+
self.sample_list = selected_samples
26+
else:
27+
print("use all samples for the current dataset")
28+
f3 = ~f2
29+
self.sample_list = list(cur_split_file['sample'].unique())
30+
31+
self.sample_df = cur_split_file[f1&f3].reset_index()
32+
if(len(selected_gene)!=0):
33+
self.gene_list = sorted(selected_gene)
34+
self.sample_df = self.sample_df[self.sample_df['gene'].isin(self.gene_list)].reset_index()
35+
else:
36+
self.gene_list = sorted(list(self.sample_df['gene'].unique()))
37+
38+
self.sample_fasta = self.loadSampleFasta()
39+
40+
41+
def loadSampleFasta(self):
42+
43+
'''
44+
45+
return sample_chr fasta dictionary
46+
47+
'''
48+
non_ref_sample_data =self.sample_df
49+
sample_fasta = dict()
50+
not_existed_samples = []
51+
52+
for index, cur_sample in tqdm(enumerate(self.sample_list), total=len(self.sample_list),desc="Loading sample fasta files"):
53+
cur_index= cur_sample
54+
55+
if(cur_index in sample_fasta):
56+
continue
57+
58+
cur_fn = [cur_index+'_allele_1.fasta',cur_index+'_allele_2.fasta']
59+
cur_fasta_data_list = []
60+
61+
for temp_fn in cur_fn:
62+
cur_fasta_path = os.path.join(self.consensus_root,temp_fn)
63+
if(not os.path.exists(cur_fasta_path)):
64+
not_existed_samples.append(cur_index)
65+
continue
66+
curFastaData = pyfaidx.Fasta(cur_fasta_path)
67+
cur_fasta_data_list.append(curFastaData)
68+
sample_fasta[cur_index] = cur_fasta_data_list
69+
70+
self.sample_df = self.sample_df[~self.sample_df['sample'].isin(not_existed_samples)]
71+
print("sample file not found",not_existed_samples)
72+
return sample_fasta
73+
74+
def __len__(self):
75+
return self.sample_df.shape[0]
76+
77+
def __getitem__(self,idx):
78+
79+
row = self.sample_df.iloc[idx]
80+
#print("idx",idx,'row',row)
81+
cur_sample = row['sample']
82+
sample_target = row[self.target_name]
83+
cur_sample_fasta_file = self.sample_fasta[cur_sample]
84+
cur_gene_name = row["gene"]
85+
86+
# allele 1 information
87+
cur_sample_fasta_file_1 = cur_sample_fasta_file[0]
88+
original_cur_sample_seq_1 = str(cur_sample_fasta_file_1[cur_gene_name])
89+
90+
# allele 2 information
91+
cur_sample_fasta_file_2 = cur_sample_fasta_file[1]
92+
original_cur_sample_seq_2 = str(cur_sample_fasta_file_2[cur_gene_name])
93+
if(not self.return_ref):
94+
return original_cur_sample_seq_1, original_cur_sample_seq_2, sample_target,cur_sample,cur_gene_name
95+
96+
97+

task1_SNP2GEX/README.md

+23-2
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,30 @@ The dataset contains one subfoler and one CSV file:
4646

4747

4848
### Demo data load
49-
We provide a dataloader for loading the raw sequences or the one-hot encoded matrix from the fasta file, and the target from the partition CSV file as the batch.
49+
We provide a dataloader for loading the raw sequences and the truth labels from the fasta file and the partition CSV file. The example code shown below:
5050

51-
TO BE UPDATED SOON..
51+
```python
52+
import Dataset
53+
from torch.utils.data import DataLoader
54+
import os
55+
DATA_ROOT = "/home/smli/ssd/miniHackathon/"
56+
57+
# set the `split` as "test" when you are evaluating.
58+
# sample_list = [], or gene_list = [] means all individuals and all genes of the specific split category will be used. if you want to select subset of them, pass the list to the corresponding argument.
59+
# you can also set target_name="raw count"/"log_count"/"rpkm" as you want
60+
61+
trainDataset = Dataset.SampleGeneExpressionDataset(split="train",csv_file=os.path.join(DATA_ROOT,"partitions.csv"),consensus_root=os.path.join(DATA_ROOT,"fasta"),sample_list=[],gene_list=[],target_name="log_TPM")
62+
# set the batch size or the sampler as you want.
63+
trainDataLoader = DataLoader(trainDataset, batch_size=1,shuffle=True)
64+
# iterate the dataset, modify it as needed to utilize gLM embeddings!
65+
for cur_seq_1, cur_seq_2, sample_target,sample_name,gene_name in trainDataLoader:
66+
print(f"cur_seq_1: {cur_seq_1}")
67+
print(f"cur_seq_2: {cur_seq_2}")
68+
print(f"sample_target:{sample_target}")
69+
print(f"sample_name:{sample_name}")
70+
print(f"gene_name: {gene_name}")
71+
break
72+
```
5273
5374
Alternatively, you can also implement your own dataloader.
5475

0 commit comments

Comments
 (0)