-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmy_pfn_load.py
72 lines (48 loc) · 2.32 KB
/
my_pfn_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os

import numpy as np
import torch
import uproot4 as uproot
from torch.utils import data

import my_tools
__all__ = ['load']


# Per-particle branches read from each tree, in the order they are stacked
# along the feature axis (axis 1) of the returned feature tensor.
_FEATURE_BRANCHES = (
    'pfc_truth_PID',
    'pfc_E',
    'pfc_D0',
    'pfc_DZ',
    'pfc_CosTheta',
    'pfc_Phi',
)


def _read_file(cache_dir, fn, ic, num_data, step_size):
    """Read up to ``num_data`` entries of tree ``t1`` from ``cache_dir/fn``.

    Returns ``(features, labels)``: lists holding one array per chunk of at
    most ``step_size`` events. Each feature array has shape
    ``(n_events, len(_FEATURE_BRANCHES), n_particles)``; each label array
    repeats ``ic`` once per event.
    """
    features, labels = [], []
    num = 0
    # The context manager closes the ROOT file once all chunks are read.
    with uproot.open(os.path.join(cache_dir, fn)) as root_file:
        tree = root_file['t1']
        # Request only the branches that are used, not every branch in the tree.
        for array in tree.iterate(list(_FEATURE_BRANCHES),
                                  step_size=step_size, entry_stop=num_data):
            event_number = len(array['pfc_truth_PID'])
            if event_number == 0:
                continue
            num += event_number
            # Reshape with the chunk's real event count, not step_size: the
            # final chunk is usually shorter, and forcing it to step_size rows
            # either raises or silently regroups particles across events.
            columns = [
                np.array(array[branch]).reshape(event_number, 1, -1)
                for branch in _FEATURE_BRANCHES
            ]
            # feature[num_of_event, feature, particle]
            features.append(np.concatenate(columns, axis=1))
            labels.append(np.full(event_number, ic, dtype=np.int64))
            if num % 10000 == 0:
                print('{:*^80}'.format(f'load tuples successfully ----->{ic}, {fn}, {num}'))
    return features, labels


def load(filename=('Hbb.root', 'Hcc.root'), num_data=300_000,
         cache_dir=r"/hpcfs/cepc/higgsgpu/wuzuofei/HiggsHadron-PFNs-gpu/root/sample_train",
         step_size=1_000):
    """Build a shuffled ``TensorDataset`` of particle features and file labels.

    Each file in ``filename`` is one class: events read from the i-th file
    get label ``i``.

    Args:
        filename: ROOT file names relative to ``cache_dir``; each must contain
            a tree named ``t1``.
        num_data: Maximum number of entries read from each file
            (passed to uproot as ``entry_stop``).
        cache_dir: Directory holding the ROOT files.
        step_size: Number of entries uproot reads per chunk.

    Returns:
        ``torch.utils.data.TensorDataset`` of ``(features, labels)``.
        ``features`` is float32 with shape ``(n_events, 6, n_particles)``, in
        the feature order of ``_FEATURE_BRANCHES``. ``labels`` is int64 with
        shape ``(n_events,)``. Events are shuffled with the global NumPy RNG.

    Raises:
        ValueError: if no events were read (e.g. ``filename`` is empty).
    """
    dataset_feature, dataset_label = [], []
    for ic, fn in enumerate(filename):
        features, labels = _read_file(cache_dir, fn, ic, num_data, step_size)
        dataset_feature.extend(features)
        dataset_label.extend(labels)
    if not dataset_feature:
        raise ValueError(f'no events were read from {cache_dir!r}; '
                         'check filename and num_data')

    # Concatenating across chunks/files requires every event to have the same
    # particle count. NOTE(review): assumes the ntuples are padded to a fixed
    # size — confirm.
    dataset_feature = np.concatenate(dataset_feature, axis=0)
    dataset_label = np.concatenate(dataset_label, axis=0)

    # Just shuffle, so classes are interleaved rather than grouped by file.
    idx = np.random.permutation(len(dataset_label))
    dataset_feature = dataset_feature[idx]
    dataset_label = dataset_label[idx]

    # Optional: remap PIDs to float numbers, in place.
    # NOTE(review): features are laid out (event, feature, particle) with the
    # PID at feature index 0. Confirm that my_tools.remap_pids' ``pid_i``
    # indexes that axis and not the last (particle) axis.
    my_tools.remap_pids(dataset_feature, pid_i=0, error_on_unknown=False)

    # Explicit casts are essential: float32 features and int64 (long) labels,
    # presumably what the downstream model and loss expect.
    features_t = torch.from_numpy(dataset_feature).type(torch.float32)
    labels_t = torch.from_numpy(dataset_label).type(torch.long)
    return data.TensorDataset(features_t, labels_t)
if __name__ == '__main__':
    # Script entry point: build the dataset with the default arguments and
    # print its repr as a quick smoke check.
    print(load())