#!/usr/bin/python
# @lint-avoid-python-3-compatibility-imports
from __future__ import print_function

import argparse
import ctypes as ct
import os
import platform
import re
import sys
import time

from bcc import BPF
from bcc.utils import printb
from datetime import datetime
from time import strftime

#
# alimutexsnoop.py  Trace all processes that hold a mutex longer than a
#                   specified time interval (e.g. 500us).
#                   For Linux, uses BCC, eBPF. Embedded C.
#
# USAGE: alimutexsnoop [-h] [-t THRESHOLD] [-p PID]
#
# Author: Alvin Zheng
# Usage examples, shown verbatim as the argparse epilog (see _getParser).
_examples = """examples:
    alimutexsnoop           # trace all mutexes that are held for a long time
    alimutexsnoop -t 50     # set the time threshold (us), mutexes are held longer
                              than this time threshold will be printed
    alimutexsnoop -p 123    # only trace the specified process
"""
# Copyright / license notice. This is a bare string expression statement and
# has no runtime effect.
"""
    Copyright (c) 2019 Alvin Zheng, Alibaba, Inc.
    Licensed under the Apache License, Version 2.0 (the "License")
"""

def positive_int(val):
    """argparse ``type=`` converter that only accepts integers above zero.

    Raises argparse.ArgumentTypeError("must be an integer") when ``val``
    cannot be parsed by int(), and
    argparse.ArgumentTypeError("must be positive value") when it is <= 0.
    """
    try:
        number = int(val)
    except ValueError:
        raise argparse.ArgumentTypeError("must be an integer")
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be positive value")

def _getParser():
    """Build the command-line parser and return its bound ``parse_args``.

    The caller receives a callable rather than the parser itself, so an
    explicit argument list can be parsed later (Global.parse_args(...)).
    """
    parser = argparse.ArgumentParser(
        description="Trace all mutexes that have been held for a long time",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_examples,
    )
    # Hold-time threshold in microseconds; positive_int rejects <= 0.
    parser.add_argument(
        "-t", "--threshold", type=positive_int, default=1000,
        help="The mutexes held longer than this time threshold will be traced(us)")
    # Optional single-process filter; None when omitted.
    parser.add_argument(
        "-p", "--pid", type=positive_int,
        help="trace this PID only")
    return parser.parse_args


class Global:
    """Module-wide mutable state, used purely through class attributes.

    initialize(), main() and the perf-buffer callback read and write these.
    """
    # Bound ArgumentParser.parse_args, built once when the module is loaded.
    parse_args = _getParser()
    # Raw argument list and the parsed namespace; set by initialize().
    argv = None
    args = None
    # The BPF object; assigned by main() after initialize() succeeds.
    bpf = None

class Data(ct.Structure):
    """ctypes view of one ``struct data_t`` record from the perf buffer.

    The C struct is declared ``__attribute__((packed))``, hence ``_pack_ = 1``.
    Field order and widths must match it exactly; the generated C code
    asserts that both sides agree on the total size.
    """
    _TASK_COMM_LEN = 16      # TASK_COMM_LEN from linux/sched.h
    _pack_ = 1
    _fields_ = (
        # identity: task start time, packed pid/tgid, and the PID column value
        [("start_time", ct.c_ulonglong),
         ("pid_tgid", ct.c_ulonglong),
         ("pid", ct.c_uint)]
        # the struct mutex address being released
        + [("mutex_addr", ct.c_ulonglong)]
        # stack ids into the stack_traces map; negative means "no stack"
        + [("k_stack_id", ct.c_int),
           ("u_stack_id", ct.c_int)]
        # command name, fixed-size like task->comm
        + [("task", ct.c_char * _TASK_COMM_LEN)]
    )

def _embedded_c(args):
    """Generate C programs for mutex_lock and mutex_unlock.

    args -- parsed argparse namespace; args.threshold (microseconds) and
            args.pid (optional process id) are baked into the C source.
    Returns the eBPF C source text for BPF(text=...).
    """
    # Upper-case placeholder tokens in this template are substituted below.
    c = """
    EBPF_COMMENT
    #include <linux/sched.h>
    BPF_STATIC_ASSERT_DEF
    BPF_STACK_TRACE(stack_traces,655360);
    struct key_t{
        u64 pid_tgid;
        u64 mutex_addr;
    } __attribute__((packed));
    BPF_HASH(held_mutexes, struct key_t,u64);
    // mutex_lock() argument captured at function entry, keyed by pid_tgid:
    // a kretprobe cannot read the arguments of the probed function.
    BPF_HASH(lock_args, u64, u64);
    struct data_t{
        u64 start_time;
        u64 pid_tgid;
        u32 pid;
        u64 mutex_addr;
        int k_stack_id;
        int u_stack_id;
        char task[TASK_COMM_LEN];
    } __attribute__((packed));
    BPF_STATIC_ASSERT(sizeof(struct data_t) == CTYPES_SIZEOF_DATA);
    BPF_PERF_OUTPUT(results);

    // The kprobe__ prefix makes BCC attach this to mutex_lock() entry itself.
    int kprobe__mutex_lock(struct pt_regs *ctx, struct mutex *lock)
    {
        struct task_struct *task = (typeof(task))bpf_get_current_task();
        FILTER_SELF_LOCK
        if (FILTER_PID) { return 0; }
        u64 pid_tgid = bpf_get_current_pid_tgid();
        u64 addr = (u64)((void*)lock);
        lock_args.update(&pid_tgid, &addr);
        return 0;
    }
    // Runs when mutex_lock() returns: the lock is now held, so start the
    // hold-time clock for the address saved at entry.
    int kprobe_ret_mutex_lock(struct pt_regs *ctx)
    {
        u64 pid_tgid = bpf_get_current_pid_tgid();
        u64 *addr = lock_args.lookup(&pid_tgid);
        if (!addr) { return 0; }
        struct key_t key = {
            .pid_tgid=pid_tgid,
            .mutex_addr=*addr
        };
        lock_args.delete(&pid_tgid);
        u64 locktime = bpf_ktime_get_ns();
        // update() overwrites a stale entry left behind by a missed unlock.
        if (held_mutexes.update(&key, &locktime) != 0) {
            bpf_trace_printk("could not add held_mutexes, thread: %d, mutex: %llx \\n", (u32)(pid_tgid >> 32), key.mutex_addr);
            return 1;
        }
        return 0;
    }
    int kprobe_mutex_unlock(struct pt_regs * ctx, struct mutex * lock)
    {
        struct task_struct *task = (typeof(task))bpf_get_current_task();
        FILTER_SELF_UNLOCK
        if (FILTER_PID) { return 0; }
        u64 pid_tgid = bpf_get_current_pid_tgid();
        struct key_t key = {
            .pid_tgid=pid_tgid,
            .mutex_addr=(u64)((void*)lock)
        };
        u64 cur = bpf_ktime_get_ns();
        u64 * locktime=held_mutexes.lookup(&key);
        if(!locktime){
            bpf_trace_printk("cannot find the lock in thread: %d mutex: %p \\n",task->tgid,lock);
            return 1;
        }
        if(cur - (*locktime) > TIME_THRESHOLD){
            struct data_t data = {
                .start_time=task->start_time,
                .pid_tgid=pid_tgid,
                .pid=task->tgid,
                .mutex_addr=(u64)((void*)lock)
            };
            bpf_get_current_comm(&data.task,sizeof(data.task));
            data.k_stack_id=stack_traces.get_stackid(ctx,BPF_F_REUSE_STACKID);
            data.u_stack_id=stack_traces.get_stackid(ctx,BPF_F_REUSE_STACKID|BPF_F_USER_STACK);
            results.perf_submit(ctx,&data,sizeof(data));
        }
        held_mutexes.delete(&key);
        return 0;
    }
    """
    # TODO: this macro belongs in bcc/src/cc/export/helpers.h
    bpf_static_assert_def = r"""
    #ifndef BPF_STATIC_ASSERT
    #define BPF_STATIC_ASSERT(condition) __attribute__((unused)) \
    extern int bpf_static_assert[(condition) ? 1 : -1]
    #endif
    """
    # Never trace this tool's own process (it would observe its own locks).
    filter_self = r"""
    if(task->tgid==%d){
        return 0;
    }
    """ % os.getpid()
    # -p takes a process id, which is task->tgid in the kernel (task->pid is
    # the thread id), matching the /proc/PID liveness check in snoop().
    pid_filter = ("task->tgid != %d" % args.pid) if args.pid else '0'
    code_substitutions = [
        ('EBPF_COMMENT', _ebpf_comment()),
        ("BPF_STATIC_ASSERT_DEF", bpf_static_assert_def),
        ("CTYPES_SIZEOF_DATA", str(ct.sizeof(Data))),
        # threshold is in microseconds; bpf_ktime_get_ns() is in nanoseconds
        ("TIME_THRESHOLD", str(args.threshold * 1000)),  # us->ns
        ("FILTER_SELF_LOCK", filter_self),
        ("FILTER_SELF_UNLOCK", filter_self),
        ('FILTER_PID', pid_filter),
    ]
    for old, new in code_substitutions:
        c = c.replace(old, new)
    return c

def _ebpf_comment():
    """Build the C block comment prepended to the generated eBPF source.

    It records the script path, a timestamp, _embedded_c's docstring, the raw
    argv list and the parsed argparse namespace (one key per line).
    """
    banner = 'Created by %s at %s:\n\t%s' % (
        sys.argv[0], strftime("%Y-%m-%d %H:%M:%S %Z"), _embedded_c.__doc__)
    # Spread the namespace dict over several lines; order of replacement matters.
    pretty_args = str(vars(Global.args))
    for needle, replacement in (('{', '{\n\t'),
                                (', ', ',\n\t'),
                                ('}', ',\n }\n\n')):
        pretty_args = pretty_args.replace(needle, replacement)
    body = "\n %s\n\n ARGV = %s\n\n ARGS = %s/" % (
        banner, ' '.join(Global.argv), pretty_args)
    # Give every line of the body a "*" gutter and expand tabs; the opening
    # "/*" line is prepended afterwards so it is left untouched.
    gutter = body.replace('\n', '\n\t*').replace('\t', '    ')
    return "\n   /*" + gutter

def _print_header():
    print("%-16s %-6s  %-16s \n" % ("PCOMM", "PID", "MUTEX_ADDR"))

def _print_event(cpu, data, size):  # callback
    """Print one long-held mutex event followed by its kernel and user stacks.

    Perf-buffer callback handed to snoop(); ``data`` points at one record
    laid out as the Data ctypes structure. ``cpu`` and ``size`` come from
    the perf buffer and are not used here.
    """
    stack_traces = Global.bpf['stack_traces']
    e = ct.cast(data, ct.POINTER(Data)).contents
    # comm bytes are not guaranteed to be valid UTF-8; don't let a bad name
    # raise inside the callback.
    print("%-16s %-6d  %-16x" %
          (e.task.decode('utf-8', 'replace'), e.pid, e.mutex_addr))
    # print the kernel stack
    if e.k_stack_id >= 0:
        for addr in reversed(list(stack_traces.walk(e.k_stack_id))):
            printb(b"Kernel Stack:    %-16x  %s" % (addr, Global.bpf.ksym(addr)))
    else:
        print("No KERNEL STACK FOUND")
    # print the user stack. BPF.sym() expects a process id; the upper 32 bits
    # of pid_tgid hold the tgid (the raw 64-bit value is not a valid pid).
    tgid = e.pid_tgid >> 32
    if e.u_stack_id >= 0:
        for addr in reversed(list(stack_traces.walk(e.u_stack_id))):
            printb(b"USER Stack:      %-16x  %s" % (addr, Global.bpf.sym(addr, tgid)))
    else:
        print("No USER STACK FOUND")
    print("=================================")
    print()
# =============================
# Module: These functions are available for import
# =============================
def initialize(arg_list=None):
    """Parse arguments and compile the eBPF program that traces long-held mutexes.

    arg_list - list of args; if omitted (or None) the command-line args
               sys.argv[1:] are read at call time. It is passed to
               argparse.ArgumentParser.parse_args().

    Returns a tuple (return_code, result)
       os.EX_OK: result is the return value from BPF()
       os.EX_NOPERM: need CAP_SYS_ADMIN, result is error message
       os.EX_SOFTWARE: BPF() failed, result is error message
    """
    if arg_list is None:
        # Evaluated per call; a default of sys.argv[1:] would be frozen at import.
        arg_list = sys.argv[1:]
    Global.argv = arg_list
    Global.args = Global.parse_args(arg_list)
    if os.geteuid() != 0:
        return (os.EX_NOPERM, "Need sudo (CAP_SYS_ADMIN) for BPF() system call")
    c = _embedded_c(Global.args)
    try:
        return (os.EX_OK, BPF(text=c))
    except Exception as e:
        # Top-level boundary: report compile/load failures as a result code.
        return (os.EX_SOFTWARE, "BPF error: " + format(e))

def snoop(bpf, event_handler, poll_timeout_ms=1000):
    """Deliver long-held-mutex events to event_handler.

    bpf - result returned by successful initialize()
    event_handler - perf-buffer callback invoked once per reported event
    poll_timeout_ms - upper bound (ms) on each perf buffer poll. Without it
                      the poll blocks until an event arrives, and an idle
                      target's exit would never be noticed.

    With args.pid set, returns once /proc/<pid> no longer exists (the target
    process has exited); otherwise loops until interrupted.
    """
    bpf["results"].open_perf_buffer(event_handler)
    target_pid = Global.args.pid
    while True:
        bpf.perf_buffer_poll(timeout=poll_timeout_ms)
        if target_pid and not os.path.exists('/proc/%d' % target_pid):
            # the target process has exited
            break

# =============================
# Script: invoked as a script
# =============================
def main():
    """Run alimutexsnoop as a command-line tool.

    Compiles the eBPF program and attaches the mutex_lock return probe and
    the mutex_unlock entry probe. It then prints the column header and
    streams events through snoop(). snoop() returns when a -p target has
    exited; Ctrl-C exits cleanly.
    When initialize() fails, its message is written to stderr and the
    process exits with the os.EX_* code initialize() returned.
    """
    try:
        rc, buffer = initialize()
        if rc != os.EX_OK:
            print(buffer, file=sys.stderr)
            sys.exit(rc)
        Global.bpf = buffer
        Global.bpf.attach_kretprobe(event="mutex_lock", fn_name="kprobe_ret_mutex_lock")
        Global.bpf.attach_kprobe(event="mutex_unlock", fn_name="kprobe_mutex_unlock")
        _print_header()
        snoop(buffer, _print_event)
    except KeyboardInterrupt:
        print()
        sys.exit()

    return 0

if __name__ == "__main__":
    # Executed directly rather than imported: start tracing.
    main()
