Skip to content

Commit

Permalink
fix: output all areas in count even with zeros for compare
Browse files: browse the repository at this point in the history
  • Loading branch information
asoronow committed Nov 4, 2024
1 parent db5130a commit f74971d
Showing 1 changed file with 110 additions and 73 deletions.
183 changes: 110 additions & 73 deletions py/count.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -120,7 +120,6 @@ def percent_colocalized(boxes1, boxes2, threshold=0.5):
default=False,
)
args = parser.parse_args()

prediction_path = Path(args.predictions.strip())
annotation_path = Path(args.annotations.strip())
output_path = Path(args.output.strip())
Expand All @@ -145,7 +144,6 @@ def percent_colocalized(boxes1, boxes2, threshold=0.5):
colocalized = {}
region_areas = {}
for i, pName in enumerate(predictionFiles):
# divide up the results file by section as well
sums[pName] = {}
region_areas[pName] = {}
with open(prediction_path / pName, "rb") as predictionPkl, open(
Expand All @@ -162,8 +160,6 @@ def percent_colocalized(boxes1, boxes2, threshold=0.5):
annotation, predicted_size
)



unique_ids, counts = np.unique(annotation_rescaled, return_counts=True)
for unique_id, count in zip(unique_ids, counts):
name = regions[unique_id]["acronym"]
Expand All @@ -181,106 +177,146 @@ def percent_colocalized(boxes1, boxes2, threshold=0.5):
else:
region_areas[pName][parent_name] += count

all_boxes = {c: [] for c in range(len(predictions))}
for c, detection in enumerate(predictions):
sums[pName][c] = {}
counted_boxes = 0
for box in detection.boxes:
counted_boxes += 1
all_boxes[c] += [box]
x, y, mX, mY = box[0], box[1], box[2], box[3]
xPos = int((mX - (mX - x) // 2))
yPos = int((mY - (mY - y) // 2))
try:
atlas_id = annotation_rescaled[yPos, xPos]
except IndexError:
# resize was in the wrong order
annotation_rescaled = resize_image_nearest_neighbor(
annotation, predicted_size[::-1]
)
atlas_id = annotation_rescaled[yPos, xPos]

acronym = regions[atlas_id]["acronym"]
if args.layers:
if sums[pName][c].get(acronym, False):
sums[pName][c][acronym] += 1
else:
sums[pName][c][acronym] = 1
# Initialize counts based on args.layers
sums[pName] = {}
for c in range(len(predictions)):
sums[pName][c] = {}
if args.layers:
# Include all regions (layers included)
region_acronyms = set()
for region_id in regions.keys():
region_acronyms.add(regions[region_id]["acronym"])
else:
# Exclude layers; use parent regions
region_acronyms = set()
for region_id, region_info in regions.items():
area_name = region_info["name"]
if "layer" not in area_name.lower():
region_acronyms.add(region_info["acronym"])
else:
id_path = regions[atlas_id]["id_path"].split("/")
# check if layer in name
area_name = regions[atlas_id]["name"]
if "layer" in area_name.lower():
if len(id_path) >= 2:
parent_id = np.uint32(id_path[-2])
else:
parent_id = atlas_id
# Get parent acronym
id_path = region_info["id_path"].split("/")
if len(id_path) >= 2:
parent_id = np.uint32(id_path[-2])
parent_acronym = regions[parent_id]["acronym"]
if sums[pName][c].get(parent_acronym, False):
sums[pName][c][parent_acronym] += 1
else:
sums[pName][c][parent_acronym] = 1
region_acronyms.add(parent_acronym)
else:
if sums[pName][c].get(acronym, False):
sums[pName][c][acronym] += 1
else:
sums[pName][c][acronym] = 1

# Compute colocalization
colocalized[pName] = {}
local_colocalized = colocalized[pName]
for c, boxes in all_boxes.items():
local_colocalized[c] = {}
for c2, boxes2 in all_boxes.items():
local_colocalized[c][c2] = percent_colocalized(boxes, boxes2)
region_acronyms.add(region_info["acronym"])
# Initialize counts to zero
for acronym in region_acronyms:
sums[pName][c][acronym] = 0

all_boxes = {c: [] for c in range(len(predictions))}
for c, detection in enumerate(predictions):
counted_boxes = 0
for box in detection.boxes:
counted_boxes += 1
all_boxes[c] += [box]
x, y, mX, mY = box[0], box[1], box[2], box[3]
xPos = int((mX - (mX - x) // 2))
yPos = int((mY - (mY - y) // 2))
try:
atlas_id = annotation_rescaled[yPos, xPos]
except IndexError:
# Resize was in the wrong order
annotation_rescaled = resize_image_nearest_neighbor(
annotation, predicted_size[::-1]
)
atlas_id = annotation_rescaled[yPos, xPos]

region_info = regions[atlas_id]
acronym = region_info["acronym"]
if args.layers:
# Count the region as is
sums[pName][c][acronym] += 1
else:
# Exclude layers
area_name = region_info["name"]
if "layer" in area_name.lower():
id_path = region_info["id_path"].split("/")
if len(id_path) >= 2:
parent_id = np.uint32(id_path[-2])
parent_acronym = regions[parent_id]["acronym"]
sums[pName][c][parent_acronym] += 1
else:
sums[pName][c][acronym] += 1
else:
sums[pName][c][acronym] += 1

# Compute colocalization
colocalized[pName] = {}
local_colocalized = colocalized[pName]
for c, boxes in all_boxes.items():
local_colocalized[c] = {}
for c2, boxes2 in all_boxes.items():
local_colocalized[c][c2] = percent_colocalized(boxes, boxes2)

with open(output_path / "count_results.csv", "w", newline="") as resultFile:
print("Writing output...", flush=True)
lines = []
running_counts = {}
running_areas = {}
# Process the sums dictionary to create a unified structure per file
for file, channels in sums.items():
lines.append([file])
all_channel_regions = [channels[channel].keys() for channel in channels]
all_channel_regions = [
item for sublist in all_channel_regions for item in sublist
]
# Collect region acronyms based on args.layers
if args.layers:
all_region_acronyms = set()
for channel_counts in channels.values():
all_region_acronyms.update(channel_counts.keys())
else:
all_region_acronyms = set()
for region_id, region_info in regions.items():
area_name = region_info["name"]
if "layer" not in area_name.lower():
all_region_acronyms.add(region_info["acronym"])
else:
id_path = region_info["id_path"].split("/")
if len(id_path) >= 2:
parent_id = np.uint32(id_path[-2])
parent_acronym = regions[parent_id]["acronym"]
all_region_acronyms.add(parent_acronym)
else:
all_region_acronyms.add(region_info["acronym"])

lines.append(
["Region Acronym", "Region Name", "Area (px)"]
+ [f"Channel #{c}" for c in range(len(channels))]
)
for region in sorted(all_channel_regions):
for region in sorted(all_region_acronyms):
per_channel_counts = []
for channel in channels:
if channels[channel].get(region, False):
per_channel_counts.append(channels[channel][region])
else:
per_channel_counts.append(0)

per_channel_counts.append(channels[channel].get(region, 0))
if running_counts.get(region, False):
running_counts[region] += per_channel_counts[-1]
else:
running_counts[region] = per_channel_counts[-1]

# find name from acronym
region_id = acronym_to_region[region]
region_name = regions[region_id]["name"]
# Find name from acronym
region_id = acronym_to_region.get(region)
if region_id is None:
region_name = "Unknown"
else:
region_name = regions[region_id]["name"]
region_area = region_areas[file].get(region, 0)
lines.append(
[
region,
region_name,
region_areas[file][region],
region_area,
]
+ per_channel_counts
)
lines.append([])

lines.append(["Totals"])
lines.append(["Region Acronym", "Region Name", "Count"])
for region, count in sorted(running_counts.items()):
region_id = acronym_to_region[region]
region_name = regions[region_id]["name"]
for region in sorted(running_counts.keys()):
count = running_counts.get(region, 0)
region_id = acronym_to_region.get(region)
if region_id is None:
region_name = "Unknown"
else:
region_name = regions[region_id]["name"]
lines.append([region, region_name, count])

lines.append([])
Expand All @@ -290,10 +326,11 @@ def percent_colocalized(boxes1, boxes2, threshold=0.5):
lines.append([s] + [f"Channel #{c}" for c in range(len(colocal))])
for c, colocal2 in colocal.items():
line = [f"Channel #{c}"]
for c2, percent in colocal2.items():
for c2 in range(len(colocal2)):
percent = colocal2.get(c2, 0)
line.append(percent)
lines.append(line)

writer = csv.writer(resultFile)
writer.writerows(lines)
print("Done!", flush=True)
print("Done!", flush=True)

0 comments on commit f74971d

Please sign in to comment.