#!/usr/bin/python
# @lint-avoid-python-3-compatibility-imports
#
# alisysdelay Catch tasks hung in kernel mode and get their stacks
#             For Linux, uses BCC, BPF, perf_events. Embedded C.
#
# In kernels without CONFIG_PREEMPT enabled, scheduling latency sometimes
# may be unacceptable, because some tasks may run too long in kernel mode
# without calling cond_resched(). These tasks may be kernel threads, or user
# processes that enter kernel mode by syscall.
#
# alisysdelay can catch tasks that have been running continuously for a long time
# without scheduling, and the kernel/user stacks of these tasks will be printed.
# The threshold can be user defined.
#
# By default CPU idle stacks are excluded by simply excluding PID 0.
#
# REQUIRES: Linux 4.9+ (BPF_PROG_TYPE_PERF_EVENT support).
#
# USAGE: alisysdelay [-h] [-d] [-a] [-f frequency] [-t threshold] [--stack-storage-size size]
#
# Copyright (c) 2016 Brendan Gregg, Netflix, Inc.
# Copyright (c) 2019 Jeffle Xu, Alibaba, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")

from __future__ import print_function
from bcc import BPF, PerfType, PerfSWConfig
import argparse
import errno
import ctypes as ct


# arg validation
def positive_int(val):
    """argparse ``type=`` callback accepting integers that are not negative.

    Args:
        val: the raw option text handed over by argparse.

    Returns:
        The parsed integer (zero is accepted here).

    Raises:
        argparse.ArgumentTypeError: if ``val`` is not an integer literal or
            if the parsed value is below zero.
    """
    try:
        number = int(val)
    except ValueError:
        raise argparse.ArgumentTypeError("must be an integer")
    else:
        if number < 0:
            raise argparse.ArgumentTypeError("must be positive")
        return number

def positive_nonzero_int(val):
    """argparse ``type=`` callback accepting strictly positive integers.

    Delegates parsing and the sign check to positive_int(), then rejects
    zero on top of that.

    Raises:
        argparse.ArgumentTypeError: if positive_int() rejects ``val`` or the
            parsed value is zero.
    """
    number = positive_int(val)
    if not number:
        raise argparse.ArgumentTypeError("must be nonzero")
    return number

def stack_id_err(stack_id):
    """Tell whether a stack id returned by get_stackid is a real failure.

    Non-negative ids are valid. -EFAULT normally only means that the stack
    trace is not available (for example asking for a kernel stack while
    running userspace code), so it is not counted as an error either. Every
    other negative value is.

    Returns:
        True if ``stack_id`` denotes an error, False otherwise.
    """
    if stack_id >= 0:
        return False
    return stack_id != -errno.EFAULT

# command-line interface
examples = """examples:
    ./alisysdelay             # catch tasks hung in kernel mode
    ./alisysdelay -f 99       # catch tasks hung in kernel mode at 99 Hertz
    ./alisysdelay -t 100      # catch tasks that hung in kernel mode for more than 100us
"""
parser = argparse.ArgumentParser(
    description="Catch tasks hung in kernel mode and get its stack",
    epilog=examples,
    # keep the hand-formatted examples block exactly as written above
    formatter_class=argparse.RawDescriptionHelpFormatter)

# sampling rate of the perf event that drives the delay check
parser.add_argument(
    "-f", "--frequency",
    type=positive_nonzero_int,
    default=100,
    help="sample frequency in Hertz (default %(default)s)")

# output formatting switches
parser.add_argument(
    "-d", "--delimited",
    action="store_true",
    help="insert delimiter between kernel/user stacks")
parser.add_argument(
    "-a", "--annotations",
    action="store_true",
    help="add _[k] annotations to kernel frames")

# how long (in microseconds) a task may stay in kernel mode before reporting
parser.add_argument(
    "-t", "--threshold",
    type=positive_nonzero_int,
    default=50,
    help="set the threshold of delay in kernel mode, in us (default %(default)s)")

# capacity of the BPF stack-trace map
parser.add_argument(
    "--stack-storage-size",
    type=positive_nonzero_int,
    default=16384,
    help="the number of unique stack traces that can be stored and displayed (default %(default)s)")

# undocumented: dump the generated BPF program and exit
parser.add_argument(
    "--ebpf",
    action="store_true",
    help=argparse.SUPPRESS)

# parse the command line; `debug` is a source-level switch (non-zero makes
# the script print the generated BPF program before loading it)
args = parser.parse_args()
debug = 0


# define BPF program
#
# Overview of the embedded C below (the string itself is compiled by BCC):
#   * `start` is a one-slot per-CPU array holding a bpf_ktime_get_ns()
#     timestamp. update_start_time() rewrites it, and is called from the
#     raw_syscalls:sys_enter and sched:sched_switch tracepoints and from
#     kprobe___cond_resched.
#   * sys_timer() is the handler attached to the sampling perf event. When
#     need_check_delay() says the sample is worth checking (current
#     pid_tgid is non-zero and the sampled registers are not user mode), it
#     compares "now - start" with DELAY_THRESHOLD and, if the threshold is
#     reached, calls get_stack_id() to submit an event_t through the
#     `events` perf buffer. Otherwise it just refreshes the timestamp.
#   * get_stack_id() records the user and kernel stack ids from
#     `stack_traces`, and stores the sampled instruction pointer in
#     kernel_ip only when it lies above the architecture's page offset.
#   * STACK_STORAGE_SIZE and DELAY_THRESHOLD are placeholders; they are
#     textually substituted with the command-line values before compiling.
#   * The Python-side ctypes structure must mirror event_t's layout.
# NOTE(review): event_t.pid is the u32 truncation of
# bpf_get_current_pid_tgid(); presumably that is the thread id rather than
# the tgid -- confirm against the BPF helper documentation.
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <uapi/linux/bpf_perf_event.h>
#include <asm/segment.h>
#include <asm/processor-flags.h>
#include <linux/sched.h>

typedef struct event {
    u32 pid;
    u64 kernel_ip;
    int user_stack_id;
    int kernel_stack_id;
    char comm[TASK_COMM_LEN];
} event_t;

BPF_PERCPU_ARRAY(start, u64, 1);
BPF_STACK_TRACE(stack_traces, STACK_STORAGE_SIZE);
BPF_PERF_OUTPUT(events);

static inline void update_start_time()
{
    u32 idx = 0;
    u64 ts = bpf_ktime_get_ns();
    start.update(&idx, &ts);
}

TRACEPOINT_PROBE(raw_syscalls, sys_enter)
{
    update_start_time();
    return 0;
}

TRACEPOINT_PROBE(sched, sched_switch)
{
    update_start_time();
    return 0;
}

int kprobe___cond_resched(struct pt_regs *ctx)
{
    update_start_time();
    return 0;
}

// This is copy from arch/x86/asm/ptrace.h
static inline int user_mode_dump(struct pt_regs *regs)
{
#ifdef CONFIG_X86_32
    u64 cs, flags;
    bpf_probe_read(&cs, sizeof(u64), &regs->cs);
    bpf_probe_read(&flags, sizeof(u64), &regs->flags);
    return ((cs & SEGMENT_RPL_MASK) | (flags & X86_VM_MASK)) >= USER_RPL;
#else
    u64 cs;
    bpf_probe_read(&cs, sizeof(u64), &regs->cs);
    return !!( cs & 3);
#endif
}

static bool need_check_delay(struct pt_regs *regs)
{
    // ignore process 0
    if (bpf_get_current_pid_tgid() == 0) {
        return 0;
    }

    // skip when currently cpu is in user mode
    if (user_mode_dump(regs)) {
        return 0;
    }

    return 1;
}

static inline void get_stack_id(struct bpf_perf_event_data *ctx)
{
    event_t key = {.pid = bpf_get_current_pid_tgid()};
    bpf_get_current_comm(&key.comm, sizeof(key.comm));

    // get stacks
    key.user_stack_id = stack_traces.get_stackid(&ctx->regs, BPF_F_USER_STACK);
    key.kernel_stack_id = stack_traces.get_stackid(&ctx->regs, 0);

    if (key.kernel_stack_id >= 0) {
        // populate extras to fix the kernel stack
        u64 ip = PT_REGS_IP(&ctx->regs);
        u64 page_offset;

        // if ip isn't sane, leave key ips as zero for later checking
#if defined(CONFIG_X86_64) && defined(__PAGE_OFFSET_BASE)
        // x64, 4.16, ..., 4.11, etc., but some earlier kernel didn't have it
        page_offset = __PAGE_OFFSET_BASE;
#elif defined(CONFIG_X86_64) && defined(__PAGE_OFFSET_BASE_L4)
        // x64, 4.17, and later
#if defined(CONFIG_DYNAMIC_MEMORY_LAYOUT) && defined(CONFIG_X86_5LEVEL)
        page_offset = __PAGE_OFFSET_BASE_L5;
#else
        page_offset = __PAGE_OFFSET_BASE_L4;
#endif
#else
        // earlier x86_64 kernels, e.g., 4.6, comes here
        // arm64, s390, powerpc, x86_32
        page_offset = PAGE_OFFSET;
#endif

        if (ip > page_offset) {
            key.kernel_ip = ip;
        }
    }

    events.perf_submit(ctx, &key, sizeof(event_t));
}

int sys_timer(struct bpf_perf_event_data *ctx) {

    if (need_check_delay(&ctx->regs)) {
        u32 idx = 0;
        u64 *tsp = start.lookup(&idx);
        if (!tsp || *tsp == 0) {
            return 0;   // missed start
        }

        if ((bpf_ktime_get_ns() - *tsp) >= (DELAY_THRESHOLD)) {
            get_stack_id(ctx);
        }
    }
    else {
        update_start_time();
    }

    return 0;
}
"""

# fill in the tunables of the BPF source: the stack map capacity, and the
# delay threshold converted from microseconds (--threshold) to nanoseconds
# (the unit of the bpf_ktime_get_ns() timestamps it is compared against)
bpf_text = (bpf_text
            .replace('STACK_STORAGE_SIZE', str(args.stack_storage_size))
            .replace('DELAY_THRESHOLD', str(args.threshold * 1000)))

# either switch dumps the generated program; only --ebpf stops afterwards
if args.ebpf or debug:
    print(bpf_text)
if args.ebpf:
    exit()

# compile/load the program and drive sys_timer from a CPU-clock software
# perf event sampled at the requested frequency (pid=-1, cpu=-1)
b = BPF(text=bpf_text)
b.attach_perf_event(
    fn_name="sys_timer",
    ev_type=PerfType.SOFTWARE,
    ev_config=PerfSWConfig.CPU_CLOCK,
    sample_freq=args.frequency,
    pid=-1,
    cpu=-1,
    group_fd=-1)


def aksym(addr):
    """Resolve a kernel address to its symbol name via ``b.ksym``.

    When the -a/--annotations option is set, the suffix ``_[k]`` is appended
    so kernel frames can be told apart from user frames. The result is
    whatever ``b.ksym`` returns (bytes), optionally with the encoded suffix.
    """
    if not args.annotations:
        return b.ksym(addr)
    suffix = "_[k]".encode()
    return b.ksym(addr) + suffix

# Length of the comm[] array in the BPF-side event_t; the embedded C takes
# TASK_COMM_LEN from the kernel headers, so this value must agree with it.
TASK_COMM_LEN = 16


class Record(ct.Structure):
    """ctypes mirror of the ``event_t`` struct emitted by the BPF program.

    The field types must match the C declaration exactly:

        u32 pid; u64 kernel_ip; int user_stack_id; int kernel_stack_id;
        char comm[TASK_COMM_LEN];

    ``pid`` is a 32-bit value. Declaring it with a platform-sized ``long``
    would make Python read the 4 padding bytes that follow it in the C
    struct as part of the pid on LP64 systems. With a 4-byte ``pid``, ctypes
    inserts the same alignment padding before ``kernel_ip`` as the C
    compiler does.
    """
    _fields_ = [("pid", ct.c_uint32),
                ("kernel_ip", ct.c_uint64),
                ("user_stack_id", ct.c_int),
                ("kernel_stack_id", ct.c_int),
                ("comm", ct.c_char * TASK_COMM_LEN)]


def print_stack(cpu, data, size):
    """Perf-buffer callback: print one reported sample.

    Output is the kernel frames, an optional ``--`` delimiter, the user
    frames, then a trailer line with the task's comm and pid, followed by a
    blank line.

    Args:
        cpu: CPU the event came from (unused).
        data: pointer to the raw event, laid out as ``Record``.
        size: size of the raw event in bytes (unused).
    """
    k = ct.cast(data, ct.POINTER(Record)).contents

    # a negative stack id means no stack was captured; walk nothing then
    user_stack = [] if k.user_stack_id < 0 else stack_traces.walk(k.user_stack_id)
    kernel_stack = [] if k.kernel_stack_id < 0 else list(stack_traces.walk(k.kernel_stack_id))

    # fix kernel stack: the BPF side only fills kernel_ip when the sampled
    # instruction pointer looked like a kernel address; put it on top
    if k.kernel_stack_id >= 0 and k.kernel_ip:
        kernel_stack.insert(0, k.kernel_ip)

    # print default multi-line stack output
    if stack_id_err(k.kernel_stack_id):
        print("    [Missed Kernel Stack]")
    else:
        for addr in kernel_stack:
            # aksym() yields bytes; decode like the user frames below so that
            # Python 3 prints the symbol itself rather than its b'...' repr
            print("    %s" % aksym(addr).decode('utf-8', 'replace'))
    if args.delimited and k.user_stack_id >= 0 and k.kernel_stack_id >= 0:
        print("    --")
    if stack_id_err(k.user_stack_id):
        print("    [Missed User Stack]")
    else:
        for addr in user_stack:
            print("    %s" % b.sym(addr, k.pid).decode('utf-8', 'replace'))
    print("    %-16s %s (%d)" % ("-", k.comm.decode('utf-8', 'replace'), k.pid))
    print()


# startup banner (emitted as a single output line)
print("Sampling at %s Hz frequency" % args.frequency, end="")
print("... Hit Ctrl-C to end.")

# print_stack() looks stacks up in this map, so bind it before polling
stack_traces = b["stack_traces"]
b["events"].open_perf_buffer(print_stack)

# deliver events to print_stack() until the user interrupts with Ctrl-C
try:
    while True:
        b.perf_buffer_poll()
except KeyboardInterrupt:
    exit()
