2021-11-30 14:51:24 +01:00

314 lines
8.9 KiB
Python

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License
import abc
import asyncio
import dataclasses
import datetime
import decimal
import logging
import os
import platform
import signal
import subprocess
import sys
import lib.litani
class _MemoryProfiler:
@abc.abstractmethod
async def snapshot(self, root_pid):
raise NotImplementedError
class _UnixMemoryProfiler(_MemoryProfiler):
def __init__(self):
self.pids = set()
self.ps_cmd = "ps"
self.ps_args = ["-x", "-o", "pid,ppid,rss,vsz"]
async def snapshot(self, root_pid):
"""Return a dict containing memory usage of process and children"""
ret = {
"time":
datetime.datetime.now(datetime.timezone.utc).strftime(
lib.litani.TIME_FORMAT_R),
"rss": 0,
}
ps_output = await self._get_ps_output()
self.pids.add(root_pid)
current_pids = set(self.pids)
seen_pids = set()
while current_pids:
new_pids = set()
for pid in current_pids:
if pid in seen_pids:
continue
seen_pids.add(pid)
self._add_usage(pid, ps_output, ret)
for child in self._children_of(pid, ps_output):
if child not in seen_pids:
new_pids.add(child)
current_pids = new_pids
return ret
@staticmethod
def human_readable(memory):
units = ["B", "KiB", "MiB", "GiB", "TiB"]
idx = 0
memory = decimal.Decimal(memory)
while memory > 1023:
idx += 1
memory /= 1024
memory_str = memory.quantize(
decimal.Decimal("0.1"), rounding=decimal.ROUND_HALF_UP)
return f"{memory_str} {units[idx]}"
def compute_peak(self, trace):
peak = {}
for item in trace:
for k, v in item.items():
if k in ["time"]:
continue
try:
peak[k] = max(peak[k], v)
except KeyError:
peak[k] = v
human_readable = {}
for k, v in peak.items():
human_readable[f"human_readable_{k}"] = self.human_readable(v)
return {**peak, **human_readable}
async def _get_ps_output(self):
""" Format: {
"fields": ["pid", "ppid", "rss", "vsz"],
"processes": {
123: {
"pid": 123,
"ppid": 457,
"rss": 435116,
"vsz": 5761524
}
}
}"""
proc = await asyncio.create_subprocess_exec(
self.ps_cmd, *self.ps_args, stdout=subprocess.PIPE)
stdout, _ = await proc.communicate()
if proc.returncode:
logging.error(
"%s exited (return code %d)", " ".join(self.ps_cmd),
proc.returncode)
sys.exit(1)
ret = {
"fields": [],
"processes": {}
}
for line in stdout.decode("utf-8").splitlines():
if not ret["fields"]:
ret["fields"] = [f.lower() for f in line.split()]
continue
for idx, value in enumerate(line.split()):
if not idx:
current_pid = int(value)
ret["processes"][current_pid] = {"pid": current_pid}
continue
field = ret["fields"][idx]
value = int(value) * 1024 if field in ["rss", "vsz"] else int(value)
ret["processes"][current_pid][field] = value
return ret
@staticmethod
def _children_of(pid, ps_output):
for child_pid, child_process in ps_output["processes"].items():
if child_process["ppid"] == pid:
yield child_pid
@staticmethod
def _add_usage(pid, ps_output, datum):
try:
process_record = ps_output["processes"][pid]
except KeyError:
return
for field, value in process_record.items():
if field in ["pid", "ppid"]:
continue
try:
datum[field] += value
except KeyError:
datum[field] = value
@dataclasses.dataclass
class _MemoryProfileAccumulator:
profiler: _MemoryProfiler
profile_interval: int
pid: int = None
trace: dict = dataclasses.field(default_factory=dict)
def set_pid(self, pid):
self.pid = pid
async def __call__(self):
try:
while True:
if not self.pid:
await asyncio.sleep(1)
continue
result = await self.profiler.snapshot(self.pid)
if result:
try:
self.trace["trace"].append(result)
except KeyError:
self.trace["trace"] = [result]
await asyncio.sleep(self.profile_interval)
except asyncio.CancelledError:
if "trace" in self.trace:
self.trace["peak"] = self.profiler.compute_peak(
self.trace["trace"])
@dataclasses.dataclass
class _Process:
command: str
interleave_stdout_stderr: bool
timeout: int
cwd: str
job_id: str
proc: subprocess.CompletedProcess = None
stdout: str = None
stderr: str = None
timeout_reached: bool = None
async def __call__(self):
if self.interleave_stdout_stderr:
pipe = asyncio.subprocess.STDOUT
else:
pipe = asyncio.subprocess.PIPE
env = dict(os.environ)
env[lib.litani.ENV_VAR_JOB_ID] = self.job_id
proc = await asyncio.create_subprocess_shell(
self.command, stdout=asyncio.subprocess.PIPE, stderr=pipe,
cwd=self.cwd, env=env, start_new_session=True)
self.proc = proc
timeout_reached = False
try:
out, err = await asyncio.wait_for(
proc.communicate(), timeout=self.timeout)
except asyncio.TimeoutError:
pgid = os.getpgid(proc.pid)
os.killpg(pgid, signal.SIGTERM)
await asyncio.sleep(1)
try:
os.killpg(pgid, signal.SIGKILL)
except ProcessLookupError:
pass
out, err = await proc.communicate()
timeout_reached = True
self.stdout = out
self.stderr = err
self.timeout_reached = timeout_reached
class Runner:
@staticmethod
def get_memory_profiler(profile_memory, system):
if not profile_memory:
return None
return {
"Linux": _UnixMemoryProfiler(),
"Darwin": _UnixMemoryProfiler(),
}.get(system, None)
def __init__(
self, command, interleave_stdout_stderr, cwd, timeout,
profile_memory, profile_interval, job_id):
self.tasks = []
self.runner = _Process(
command=command, interleave_stdout_stderr=interleave_stdout_stderr,
cwd=cwd, timeout=timeout, job_id=job_id)
self.tasks.append(self.runner)
self.profiler = None
profiler = self.get_memory_profiler(profile_memory, platform.system())
if profiler:
self.profiler = _MemoryProfileAccumulator(profiler, profile_interval)
self.tasks.append(self.profiler)
async def __call__(self):
tasks = []
for task in self.tasks:
tasks.append(asyncio.create_task(task()))
if self.profiler:
for _ in range(10):
if self.runner.proc:
self.profiler.set_pid(self.runner.proc.pid)
break
await asyncio.sleep(1)
_, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
await task
def get_return_code(self):
return self.runner.proc.returncode
def get_stdout(self):
if self.runner.stdout:
return self.runner.stdout.decode("utf-8")
return None
def get_stderr(self):
if self.runner.stderr:
return self.runner.stderr.decode("utf-8")
return None
def reached_timeout(self):
return self.runner.timeout_reached
def get_memory_trace(self):
return self.profiler.trace if self.profiler else {}