# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # # http://www.apache.org/licenses/LICENSE-2.0 # # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License import abc import asyncio import dataclasses import datetime import decimal import logging import os import platform import signal import subprocess import sys import lib.litani class _MemoryProfiler: @abc.abstractmethod async def snapshot(self, root_pid): raise NotImplementedError class _UnixMemoryProfiler(_MemoryProfiler): def __init__(self): self.pids = set() self.ps_cmd = "ps" self.ps_args = ["-x", "-o", "pid,ppid,rss,vsz"] async def snapshot(self, root_pid): """Return a dict containing memory usage of process and children""" ret = { "time": datetime.datetime.now(datetime.timezone.utc).strftime( lib.litani.TIME_FORMAT_R), "rss": 0, } ps_output = await self._get_ps_output() self.pids.add(root_pid) current_pids = set(self.pids) seen_pids = set() while current_pids: new_pids = set() for pid in current_pids: if pid in seen_pids: continue seen_pids.add(pid) self._add_usage(pid, ps_output, ret) for child in self._children_of(pid, ps_output): if child not in seen_pids: new_pids.add(child) current_pids = new_pids return ret @staticmethod def human_readable(memory): units = ["B", "KiB", "MiB", "GiB", "TiB"] idx = 0 memory = decimal.Decimal(memory) while memory > 1023: idx += 1 memory /= 1024 memory_str = memory.quantize( decimal.Decimal("0.1"), rounding=decimal.ROUND_HALF_UP) return f"{memory_str} {units[idx]}" def compute_peak(self, trace): peak = {} for item in trace: for k, v in item.items(): if k in ["time"]: continue try: peak[k] = max(peak[k], v) except KeyError: peak[k] = v human_readable = {} for k, v in peak.items(): human_readable[f"human_readable_{k}"] = self.human_readable(v) return {**peak, **human_readable} async def _get_ps_output(self): """ Format: { "fields": ["pid", "ppid", "rss", "vsz"], "processes": { 123: { "pid": 123, "ppid": 457, "rss": 435116, "vsz": 5761524 } } }""" proc = await asyncio.create_subprocess_exec( self.ps_cmd, *self.ps_args, stdout=subprocess.PIPE) stdout, _ = await proc.communicate() if proc.returncode: logging.error( "%s exited (return code %d)", " ".join(self.ps_cmd), proc.returncode) sys.exit(1) ret = { "fields": [], "processes": {} } for line in stdout.decode("utf-8").splitlines(): if not ret["fields"]: ret["fields"] = [f.lower() for f in line.split()] continue for idx, value in enumerate(line.split()): if not idx: current_pid = int(value) ret["processes"][current_pid] = {"pid": current_pid} continue field = ret["fields"][idx] value = int(value) * 1024 if field in ["rss", "vsz"] else int(value) ret["processes"][current_pid][field] = value return ret @staticmethod def _children_of(pid, ps_output): for child_pid, child_process in ps_output["processes"].items(): if child_process["ppid"] == pid: yield child_pid @staticmethod def _add_usage(pid, ps_output, datum): try: process_record = ps_output["processes"][pid] except KeyError: return for field, value in process_record.items(): if field in ["pid", "ppid"]: continue try: datum[field] += value except KeyError: datum[field] = value @dataclasses.dataclass class _MemoryProfileAccumulator: profiler: _MemoryProfiler profile_interval: int pid: int = None trace: dict = dataclasses.field(default_factory=dict) def set_pid(self, pid): self.pid = pid async def __call__(self): try: while True: if not self.pid: await asyncio.sleep(1) continue result = await self.profiler.snapshot(self.pid) if result: try: self.trace["trace"].append(result) except KeyError: self.trace["trace"] = [result] await asyncio.sleep(self.profile_interval) except asyncio.CancelledError: if "trace" in self.trace: self.trace["peak"] = self.profiler.compute_peak( self.trace["trace"]) @dataclasses.dataclass class _Process: command: str interleave_stdout_stderr: bool timeout: int cwd: str job_id: str proc: subprocess.CompletedProcess = None stdout: str = None stderr: str = None timeout_reached: bool = None async def __call__(self): if self.interleave_stdout_stderr: pipe = asyncio.subprocess.STDOUT else: pipe = asyncio.subprocess.PIPE env = dict(os.environ) env[lib.litani.ENV_VAR_JOB_ID] = self.job_id proc = await asyncio.create_subprocess_shell( self.command, stdout=asyncio.subprocess.PIPE, stderr=pipe, cwd=self.cwd, env=env, start_new_session=True) self.proc = proc timeout_reached = False try: out, err = await asyncio.wait_for( proc.communicate(), timeout=self.timeout) except asyncio.TimeoutError: pgid = os.getpgid(proc.pid) os.killpg(pgid, signal.SIGTERM) await asyncio.sleep(1) try: os.killpg(pgid, signal.SIGKILL) except ProcessLookupError: pass out, err = await proc.communicate() timeout_reached = True self.stdout = out self.stderr = err self.timeout_reached = timeout_reached class Runner: @staticmethod def get_memory_profiler(profile_memory, system): if not profile_memory: return None return { "Linux": _UnixMemoryProfiler(), "Darwin": _UnixMemoryProfiler(), }.get(system, None) def __init__( self, command, interleave_stdout_stderr, cwd, timeout, profile_memory, profile_interval, job_id): self.tasks = [] self.runner = _Process( command=command, interleave_stdout_stderr=interleave_stdout_stderr, cwd=cwd, timeout=timeout, job_id=job_id) self.tasks.append(self.runner) self.profiler = None profiler = self.get_memory_profiler(profile_memory, platform.system()) if profiler: self.profiler = _MemoryProfileAccumulator(profiler, profile_interval) self.tasks.append(self.profiler) async def __call__(self): tasks = [] for task in self.tasks: tasks.append(asyncio.create_task(task())) if self.profiler: for _ in range(10): if self.runner.proc: self.profiler.set_pid(self.runner.proc.pid) break await asyncio.sleep(1) _, pending = await asyncio.wait( tasks, return_when=asyncio.FIRST_COMPLETED) for task in pending: task.cancel() await task def get_return_code(self): return self.runner.proc.returncode def get_stdout(self): if self.runner.stdout: return self.runner.stdout.decode("utf-8") return None def get_stderr(self): if self.runner.stderr: return self.runner.stderr.decode("utf-8") return None def reached_timeout(self): return self.runner.timeout_reached def get_memory_trace(self): return self.profiler.trace if self.profiler else {}