#!/usr/bin/python
# @lint-avoid-python-3-compatibility-imports
#
# alijbd2stat    Summarize filesystem jbd2 stat.
#       For Linux, uses BCC, eBPF.
#
# USAGE: alijbd2stat [-h] [-t thresh] [-d device]
#
# Copyright (c) 2019-2021 Alibaba Group.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 2019/08/05 Xiaoguang Wang Created this.

from __future__ import print_function
from bcc import BPF
import ctypes as ct
import argparse
import time

# arguments
examples = """examples:
    ./alijbd2stat          # summarize filesystem jbd2 stat
    ./alijbd2stat -d sda3  # inspect specified device /dev/sda3
    ./alijbd2stat -t 10    # show jbd2 handle's context when it runs more than 10ms
"""


def _thresh_ms(value):
    """Validate the -t argument and return it as a canonical decimal string.

    The threshold is later pasted verbatim into the BPF C source, so it must
    be a plain non-negative integer: anything else would either fail to
    compile with an unhelpful clang error or be misread (a leading zero such
    as "010" would be parsed as octal by the C compiler).

    A string (not an int) is returned on purpose: the script tests
    `if args.thresh_time:` further down, and an explicit "-t 0" must stay
    truthy there so that it keeps meaning "report every handle".
    """
    try:
        msecs = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "threshold must be an integer number of milliseconds, got %r" % value)
    if msecs < 0:
        raise argparse.ArgumentTypeError(
            "threshold must not be negative, got %r" % value)
    return str(msecs)


parser = argparse.ArgumentParser(
    description="Summarize filesystem jbd2 stat",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=examples)
parser.add_argument("-d", "--device", help="inspect specified device")
# default=0 is not a string, so argparse does not run it through _thresh_ms;
# a bare "-t" (no value) still yields None because of nargs="?".
parser.add_argument("-t", "--thresh_time", nargs="?", default=0, type=_thresh_ms,
    help="show jbd2 handle's context when its run time is greater than this value")
# Hidden option: print the generated BPF program and exit.
parser.add_argument("--ebpf", action="store_true",
    help=argparse.SUPPRESS)
args = parser.parse_args()
# Set to a true value to dump the generated BPF program before loading it.
debug = 0

# define BPF program
#
# The C source below is compiled by BCC. It declares:
#   - struct handle_stat / struct transaction_run_stat: the event records
#     sent to user space (mirrored by the ctypes classes later in this file);
#   - handle_stat_map: per-task scratch state, keyed by the task_struct
#     pointer, carried from handle start to handle stop;
#   - stack_traces: stack storage for the kernel stack captured at stop time;
#   - jbd2_handle_stat / transaction_stat: the two perf output channels.
#
# Probes:
#   - trace_start_this_handle: records pre_start_time and the task's
#     sched_info.run_delay when a handle is being started;
#   - tracepoint jbd2:jbd2_handle_start: fills in start_time, type, line_no
#     and tid for the pending record;
#   - trace_jbd2_journal_stop: computes the run time, drops records below
#     the threshold, otherwise emits the record together with a stack id;
#   - tracepoint jbd2:jbd2_run_stats: forwards per-transaction statistics.
#
# FILTER_DEV and FILTER_THRESH_TIME are placeholders; they are replaced with
# real C expressions by str.replace() further down, before compilation.
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <linux/sched.h>
#include <linux/blkdev.h>
#include <linux/jbd2.h>

struct handle_stat {
	dev_t dev;
	u32 type;
	u32 line_no;
	unsigned long tid;
	unsigned long pre_start_time;
	unsigned long start_time;
	unsigned long end_time;
	u64 sched_delay;
	u64 kernel_ip;
	int kernel_stack_id;
	u32 pid;
	char comm[TASK_COMM_LEN];
};

struct transaction_run_stat {
	dev_t dev;
	u32 handle_count;
	u32 blocks;
	u32 blocks_logged;
	unsigned long tid;
	unsigned long wait;
	unsigned long request_delay;
	unsigned long running;
	unsigned long locked;
	unsigned long flushing;
	unsigned long logging;
};

/*
struct handle_info_per_type {
	u32 count
};
*/

BPF_HASH(handle_stat_map, struct task_struct *, struct handle_stat);
BPF_STACK_TRACE(stack_traces, 1024);
BPF_PERF_OUTPUT(jbd2_handle_stat);
BPF_PERF_OUTPUT(transaction_stat);


int trace_start_this_handle(struct pt_regs *ctx, journal_t *journal,
		handle_t *handle, gfp_t gfp_mask)
{
	struct handle_stat s;
	struct task_struct *t = (struct task_struct *)bpf_get_current_task();
	dev_t dev = journal->j_fs_dev->bd_dev;

	if (FILTER_DEV)
		return 0;

	memset(&s, 0, sizeof(struct handle_stat));
	s.dev = dev;
	s.pre_start_time = bpf_ktime_get_ns();
	s.sched_delay = t->sched_info.run_delay;

	handle_stat_map.update(&t, &s);
	return 0;
}

TRACEPOINT_PROBE(jbd2, jbd2_handle_start)
{
	struct handle_stat *s;
	struct task_struct *t = (struct task_struct *)bpf_get_current_task();
	dev_t dev = args->dev;

	if (FILTER_DEV)
		return 0;

	s = handle_stat_map.lookup(&t);
	if (s == NULL)
		return 0;

	s->start_time = bpf_ktime_get_ns();
	s->type = args->type;
	s->line_no = args->line_no;
	s->tid = args->tid;
	return 0;
}

int trace_jbd2_journal_stop(struct pt_regs *ctx, handle_t *handle)
{
	transaction_t *transaction = handle->h_transaction;
	journal_t *journal;
	struct handle_stat *s;
	struct handle_stat s2 = {};
	struct task_struct *t = (struct task_struct *)bpf_get_current_task();
	unsigned long end_time = bpf_ktime_get_ns();
	unsigned long run_time;

	if (transaction == NULL) {
		handle_stat_map.delete(&t);
		return 0;
	}

	journal = transaction->t_journal;
	dev_t dev = journal->j_fs_dev->bd_dev;
	if (FILTER_DEV)
		return 0;

	if (handle->h_ref >= 2)
		return 0;

	s = handle_stat_map.lookup(&t);
	if (s == NULL)
		return 0;

	/*
	 * This could happen, jbd2_journal_start_reserved() miss a trace
	 * jbd2_handle_start().
	 */
	if (!s->start_time)
		s->start_time = s->pre_start_time;


	run_time = end_time - s->pre_start_time;
	if (run_time < FILTER_THRESH_TIME * 1000 * 1000LLU) {
		handle_stat_map.delete(&t);
		return 0;
	}

	s2.dev = s->dev;
	s2.type = s->type;
	s2.line_no = s->line_no;
	s2.tid = s->tid;
	s2.pre_start_time = s->pre_start_time;
	s2.start_time = s->start_time;
	s2.end_time = end_time;
	s2.pid = bpf_get_current_pid_tgid();
	s2.sched_delay = t->sched_info.run_delay - s->sched_delay;
	bpf_get_current_comm(&s2.comm, sizeof(s2.comm));

	s2.kernel_stack_id = stack_traces.get_stackid(ctx, 0);
	if (s2.kernel_stack_id >= 0) {
		// populate extras to fix the kernel stack
		u64 ip = PT_REGS_IP(ctx);
		u64 page_offset;

		// if ip isn't sane, leave key ips as zero for later checking
		#if defined(CONFIG_X86_64) && defined(__PAGE_OFFSET_BASE)
			// x64, 4.16, ..., 4.11, etc., but some earlier kernel didn't have it
			page_offset = __PAGE_OFFSET_BASE;
		#elif defined(CONFIG_X86_64) && defined(__PAGE_OFFSET_BASE_L4)
			// x64, 4.17, and later
		#if defined(CONFIG_DYNAMIC_MEMORY_LAYOUT) && defined(CONFIG_X86_5LEVEL)
			page_offset = __PAGE_OFFSET_BASE_L5;
		#else
			page_offset = __PAGE_OFFSET_BASE_L4;
		#endif
		#else
			// earlier x86_64 kernels, e.g., 4.6, comes here
			// arm64, s390, powerpc, x86_32
			page_offset = PAGE_OFFSET;
		#endif

		if (ip > page_offset)
			s2.kernel_ip = ip;
	}

	jbd2_handle_stat.perf_submit(ctx, &s2, sizeof(struct handle_stat));
	handle_stat_map.delete(&t);
	return 0;
}

TRACEPOINT_PROBE(jbd2, jbd2_run_stats)
{
	struct transaction_run_stat s;
	dev_t dev = args->dev;;

	if (FILTER_DEV)
		return 0;

	memset(&s, 0, sizeof(struct transaction_run_stat));
	s.dev = args->dev;
	s.tid = args->tid;
	s.wait = args->wait;
	s.request_delay = args->request_delay;
	s.running = args->running;
	s.locked = args->locked;
	s.flushing = args->flushing;
	s.logging = args->logging;
	s.handle_count = args->handle_count;
	s.blocks = args->blocks;
	s.blocks_logged = args->blocks_logged;

	transaction_stat.perf_submit(args, &s, sizeof(struct transaction_run_stat));
	return 0;
}

"""
# Maps a device number, encoded as (major << 20) + minor, to the device name
# listed in /proc/partitions. Filled in by init_dev_name().
devid2name = {}
# Number of transaction rows printed so far; drives periodic header output.
num_trans = 0


def init_dev_name(path="/proc/partitions"):
    """Populate the global devid2name table from a partitions listing.

    Each data row is expected to look like "major minor #blocks name".
    Blank lines, the column-title row and any row that does not start with
    two numeric fields followed by at least two more fields are skipped
    instead of raising.

    Args:
        path: file to parse; defaults to the kernel's /proc/partitions.

    Raises:
        IOError/OSError: if `path` cannot be opened.
    """
    # `with` guarantees the file is closed even if parsing raises.
    with open(path) as partitions:
        for line in partitions:
            fields = line.split()
            if len(fields) < 4:
                continue
            major, minor, name = fields[0], fields[1], fields[3]
            if not (major.isdigit() and minor.isdigit()):
                continue
            # Same encoding the BPF side reports in its `dev` fields.
            devid2name[(int(major) << 20) + int(minor)] = name

init_dev_name()

# Reverse lookup (device name -> device number) used to resolve -d/--device.
name2devid = {name: devid for devid, name in devid2name.items()}

if args.device:
    # Accept both "sda3" and "/dev/sda3"; /proc/partitions lists bare names.
    wanted_dev = args.device
    if wanted_dev.startswith("/dev/"):
        wanted_dev = wanted_dev[len("/dev/"):]
    if wanted_dev not in name2devid:
        # Report a usage error instead of dying with a KeyError traceback.
        parser.error("unknown device '%s': expected a name listed in "
                     "/proc/partitions (e.g. sda3)" % args.device)
    # Probes bail out early when the event's device is not the requested one.
    bpf_text = bpf_text.replace('FILTER_DEV', 'dev != %u' % name2devid[wanted_dev])
else:
    # No device filter: the "skip this event" condition is always false.
    bpf_text = bpf_text.replace('FILTER_DEV', '0')

if args.thresh_time:
    bpf_text = bpf_text.replace('FILTER_THRESH_TIME', '%s' % args.thresh_time)
else:
    # No threshold given: fall back to 9999 ms, which in practice silences
    # the per-handle stack reports.
    bpf_text = bpf_text.replace('FILTER_THRESH_TIME', '9999')

# Dump the final BPF source (after the placeholder substitutions above) when
# debugging is enabled or --ebpf was given; with --ebpf, stop right after
# printing so nothing is compiled or attached.
if debug or args.ebpf:
    print(bpf_text)
    if args.ebpf:
        exit()

# load BPF program
b = BPF(text=bpf_text)

# Attach the two kprobe handlers defined in bpf_text to the kernel functions
# that begin and end a jbd2 handle.
b.attach_kprobe(event="start_this_handle", fn_name="trace_start_this_handle")
b.attach_kprobe(event="jbd2_journal_stop", fn_name="trace_jbd2_journal_stop")

print("Tracing jbd2 stats... Hit Ctrl-C to end.")

# Length of the `comm` buffer in struct handle_stat.
TASK_COMM_LEN = 16

# Short aliases for the two integer widths used by the event records.
_U32 = ct.c_uint
_U64 = ct.c_ulonglong


class HandleStat(ct.Structure):
    """ctypes view of `struct handle_stat` emitted on jbd2_handle_stat.

    Field order and widths follow the C declaration in bpf_text so that a
    raw perf-buffer sample can be cast to this type directly.
    """

    _fields_ = [
        ("dev", _U32),
        ("type", _U32),
        ("line_no", _U32),
        ("tid", _U64),
        ("pre_start_time", _U64),
        ("start_time", _U64),
        ("end_time", _U64),
        ("sched_delay", _U64),
        ("kernel_ip", _U64),
        ("kernel_stack_id", ct.c_int),
        ("pid", _U32),
        ("comm", ct.c_char * TASK_COMM_LEN),
    ]

# Short aliases for the two integer widths used by the event records.
_U32 = ct.c_uint
_U64 = ct.c_ulonglong


class Data(ct.Structure):
    """ctypes view of `struct transaction_run_stat` emitted on transaction_stat.

    Field order and widths follow the C declaration in bpf_text so that a
    raw perf-buffer sample can be cast to this type directly.
    """

    _fields_ = [
        ("dev", _U32),
        ("handle_count", _U32),
        ("blocks", _U32),
        ("blocks_logged", _U32),
        ("tid", _U64),
        ("wait", _U64),
        ("request_delay", _U64),
        ("running", _U64),
        ("locked", _U64),
        ("flushing", _U64),
        ("logging", _U64),
    ]

class trans_stat:
    """Plain record holding one transaction's statistics for a device.

    All counters start at zero; only `dev` is supplied by the caller.
    NOTE(review): nothing in the visible part of this file instantiates
    this class -- confirm whether it is still needed.
    """

    def __init__(self, dev):
        self.dev = dev
        self.handle_count = 0
        self.blocks = 0
        self.blocks_logged = 0
        self.tid = 0
        # Was misspelled "slef.wait", which raised NameError on construction.
        self.wait = 0
        self.request_delay = 0
        self.running = 0
        self.locked = 0
        self.flushing = 0
        self.logging = 0

def stack_id_err(stack_id):
    """Return True when `stack_id` is an error code other than -EFAULT.

    -EFAULT in get_stackid normally means the stack trace is simply not
    available (such as asking for a kernel stack while in userspace code),
    so it is not treated as an error; non-negative ids are valid.
    """
    # Imported here because the file does not import errno at module level;
    # without it this function raised NameError on its first call.
    import errno
    return (stack_id < 0) and (stack_id != -errno.EFAULT)

def aksym(addr):
    """Return the kernel symbol for `addr` with a "_[k]" marker appended."""
    symbol = b.ksym(addr)
    kernel_marker = "_[k]".encode()
    return symbol + kernel_marker

def print_stack(cpu, data, size):
    """Perf-buffer callback for jbd2_handle_stat: report one slow handle.

    Prints a summary line (command, pid, transaction id, time between
    pre-start and start in usecs, total run time in usecs, scheduling delay
    in msecs) followed by the kernel stack captured when the handle stopped.

    Args:
        cpu: CPU the sample came from (unused).
        data: raw pointer to a `struct handle_stat` sample.
        size: sample size in bytes (unused).
    """
    k = ct.cast(data, ct.POINTER(HandleStat)).contents

    # fix kernel stack: the probed instruction pointer is put back on top
    # when the BPF side recorded one.
    kernel_stack = []
    if k.kernel_stack_id >= 0:
        kernel_stack = list(stack_traces.walk(k.kernel_stack_id))
        # the later IP checking
        if k.kernel_ip:
            kernel_stack.insert(0, k.kernel_ip)

    if stack_id_err(k.kernel_stack_id):
        print("    [Missed Kernel Stack]")
        return

    # comm and the resolved symbols are bytes; decode them so Python 3 does
    # not print b'...' reprs. Floor division keeps the values integral on
    # both Python 2 and 3.
    print("comm: %s pid: %u tid: %lu pre_start_time: %lu run_time: %lu sched_delay: %lu" %
          (k.comm.decode('utf-8', 'replace'), k.pid, k.tid,
           (k.start_time - k.pre_start_time) // 1000,
           (k.end_time - k.pre_start_time) // 1000,
           k.sched_delay // 1000 // 1000))
    for addr in kernel_stack:
        print("    %s" % aksym(addr).decode('utf-8', 'replace'))

def print_header():
    """Print a blank line followed by the transaction table's column titles."""
    # (width, title) per column; each title is left-justified to its width
    # and columns are separated by a single space.
    columns = (
        (8, "dev"),
        (12, "tid"),
        (4, "wait"),
        (13, "request_delay"),
        (7, "running"),
        (6, "locked"),
        (8, "flushing"),
        (7, "logging"),
        (12, "handle_count"),
        (6, "blocks"),
        (13, "blocks_logged"),
    )
    titles = " ".join("%-*s" % (width, title) for width, title in columns)
    print("\n" + titles)

# process event
def print_event(cpu, data, size):
    """Perf-buffer callback for transaction_stat: print one transaction row.

    The column header is re-printed before every tenth row.

    Args:
        cpu: CPU the sample came from (unused).
        data: raw pointer to a `struct transaction_run_stat` sample.
        size: sample size in bytes (unused).
    """
    global num_trans

    event = ct.cast(data, ct.POINTER(Data)).contents
    if (num_trans % 10) == 0:
        print_header()

    # devid2name is a snapshot taken at startup; a device that appeared
    # later is shown as "major:minor" instead of raising KeyError inside
    # the callback.
    dev_name = devid2name.get(event.dev)
    if dev_name is None:
        dev_name = "%d:%d" % (event.dev >> 20, event.dev & ((1 << 20) - 1))

    print("%-8s %-12lu %-4lu %-13lu %-7lu %-6lu %-8lu %-7lu %-12lu %-6lu %-13lu" %
          (dev_name, event.tid, event.wait, event.request_delay, event.running,
           event.locked, event.flushing, event.logging, event.handle_count,
           event.blocks, event.blocks_logged))
    num_trans += 1

label = "usecs"

# Table holding the kernel stacks referenced by jbd2_handle_stat samples.
stack_traces = b["stack_traces"]

# Route each perf output channel to its printer.
for channel, callback in (("jbd2_handle_stat", print_stack),
                          ("transaction_stat", print_event)):
    b[channel].open_perf_buffer(callback, page_cnt=64)

# Keep draining events until the user hits Ctrl-C.
while True:
    try:
        b.perf_buffer_poll()
    except KeyboardInterrupt:
        exit()
