Skip to content

Commit

Permalink
Merge pull request #180 from asmacdo/refactor-sample-updates
Browse files Browse the repository at this point in the history
BF: Fix sample aggregation
  • Loading branch information
asmacdo authored Sep 20, 2024
2 parents 2b12679 + 08dcaeb commit 43fa87c
Show file tree
Hide file tree
Showing 3 changed files with 390 additions and 31 deletions.
65 changes: 39 additions & 26 deletions src/con_duct/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class ProcessStats:
etime: str
cmd: str

def max(self, other: ProcessStats) -> ProcessStats:
def aggregate(self, other: ProcessStats) -> ProcessStats:
cmd = self.cmd
if self.cmd != other.cmd:
lgr.debug(
Expand Down Expand Up @@ -267,19 +267,24 @@ class Sample:
timestamp: str = "" # TS of last sample collected

def add_pid(self, pid: int, stats: ProcessStats) -> None:
# We do not calculate averages when we add a pid because we require all pids first
assert (
self.stats.get(pid) is None
) # add_pid should only be called when pid not in Sample
self.total_rss = (self.total_rss or 0) + stats.rss
self.total_vsz = (self.total_vsz or 0) + stats.vsz
self.total_pmem = (self.total_pmem or 0.0) + stats.pmem
self.total_pcpu = (self.total_pcpu or 0.0) + stats.pcpu
self.stats[pid] = stats
self.timestamp = max(self.timestamp, stats.timestamp)
self.stats[pid] = stats

def max(self: Sample, other: Sample) -> Sample:
def aggregate(self: Sample, other: Sample) -> Sample:
output = Sample()
for pid in self.stats.keys() | other.stats.keys():
if (mine := self.stats.get(pid)) is not None:
if (theirs := other.stats.get(pid)) is not None:
output.add_pid(pid, mine.max(theirs))
output.add_pid(pid, mine.aggregate(theirs))
else:
output.add_pid(pid, mine)
else:
Expand All @@ -292,6 +297,8 @@ def max(self: Sample, other: Sample) -> Sample:
output.total_pcpu = max(self.total_pcpu or 0.0, other.total_pcpu)
output.total_rss = max(self.total_rss or 0, other.total_rss)
output.total_vsz = max(self.total_vsz or 0, other.total_vsz)
output.averages = self.averages
output.averages.update(other)
return output

def for_json(self) -> dict[str, Any]:
Expand Down Expand Up @@ -333,11 +340,10 @@ def __init__(
self.session_id: int | None = None
self.gpus: list[dict[str, str]] | None = None
self.env: dict[str, str] | None = None
self.number = 0
self.number = 1
self.system_info: SystemInfo | None = None
self.max_values = Sample()
self.averages: Averages = Averages()
self.current_sample: Sample | None = None
self.full_run_stats = Sample()
self.current_sample: Optional[Sample] = None
self.end_time: float | None = None
self.run_time_seconds: str | None = None

Expand Down Expand Up @@ -443,8 +449,19 @@ def collect_sample(self) -> Optional[Sample]:
except subprocess.CalledProcessError as exc: # when session_id has no processes
lgr.debug("Error collecting sample: %s", str(exc))
return None

sample.averages = Averages.from_sample(sample)
return sample

def update_from_sample(self, sample: Sample) -> None:
self.full_run_stats = self.full_run_stats.aggregate(sample)
if self.current_sample is None:
self.current_sample = Sample().aggregate(sample)
else:
assert self.current_sample.averages is not None
self.current_sample = self.current_sample.aggregate(sample)
assert self.current_sample is not None

def write_subreport(self) -> None:
assert self.current_sample is not None
with open(self.log_paths.usage, "a") as resource_statistics_log:
Expand All @@ -464,15 +481,15 @@ def execution_summary(self) -> dict[str, Any]:
"command": self.command,
"logs_prefix": self.log_paths.prefix if self.log_paths else "",
"wall_clock_time": self.wall_clock_time,
"peak_rss": self.max_values.total_rss,
"average_rss": self.averages.rss,
"peak_vsz": self.max_values.total_vsz,
"average_vsz": self.averages.vsz,
"peak_pmem": self.max_values.total_pmem,
"average_pmem": self.averages.pmem,
"peak_pcpu": self.max_values.total_pcpu,
"average_pcpu": self.averages.pcpu,
"num_samples": self.averages.num_samples,
"peak_rss": self.full_run_stats.total_rss,
"average_rss": self.full_run_stats.averages.rss,
"peak_vsz": self.full_run_stats.total_vsz,
"average_vsz": self.full_run_stats.averages.vsz,
"peak_pmem": self.full_run_stats.total_pmem,
"average_pmem": self.full_run_stats.averages.pmem,
"peak_pcpu": self.full_run_stats.total_pcpu,
"average_pcpu": self.full_run_stats.averages.pcpu,
"num_samples": self.full_run_stats.averages.num_samples,
"num_reports": self.number,
}

Expand Down Expand Up @@ -672,15 +689,11 @@ def monitor_process(
break
# process is still running, but we could not collect sample
continue
report.averages.update(sample)
if report.current_sample is None:
sample.averages = Averages.from_sample(sample)
report.current_sample = sample
else:
assert report.current_sample.averages is not None
report.current_sample.averages.update(sample)
report.max_values = report.max_values.max(sample)
if report.start_time and report.elapsed_time >= report.number * report_interval:
report.update_from_sample(sample)
if (
report.start_time
and report.elapsed_time >= (report.number - 1) * report_interval
):
report.write_subreport()
report.current_sample = None
report.number += 1
Expand Down Expand Up @@ -887,7 +900,7 @@ def execute(args: Arguments) -> int:

# If we have any extra samples that haven't been written yet, do it now
if report.current_sample is not None:
report.max_values = report.max_values.max(report.current_sample)
report.full_run_stats = report.full_run_stats.aggregate(report.current_sample)
report.write_subreport()

report.process = process
Expand Down
Loading

0 comments on commit 43fa87c

Please sign in to comment.