#! /usr/bin/env python

"""
Script for inspecting Athena++ restart files.

Input: name of the .rst restart file to inspect.

The format parameters below (struct type codes and padding sizes) are
configured for some common compilers on little-endian machines. If the output
does not make sense, try adjusting them.
"""

# Python standard modules
import argparse
import os
import struct

# Parameters describing the binary layout of the restart file.

# struct format codes; the '<' prefix selects little-endian byte order with
# standard (platform-independent) sizes
type_int, type_int64, type_uint64, type_real = '<i', '<q', '<Q', '<d'

# Padding read after mesh_size.nx3 and inside each logical location record;
# these presumably mirror compiler-inserted struct padding, which may differ
# between compilers/ABIs -- adjust if the output looks wrong
type_regionsize_pad = type_logicallocation_pad = '4c'

# Column widths for the name, value, length, and format fields in the report
name_len, val_len, len_len, fmt_len = 15, 24, 1, 2

# Main function
def main(**kwargs):
  """Report the header values and section sizes of an Athena++ restart file.

  Keyword arguments:
    filename -- path of the .rst file to inspect.

  Raises:
    ValueError -- if the '<par_end>' line that ends the input-parameter text
      is never found.
  """

  # Open in binary mode: everything after the parameter text is raw binary,
  # so struct.unpack needs bytes and tell() must report true byte offsets
  with open(kwargs['filename'], 'rb') as f:

    # Skip the text input parameters, which end at the '<par_end>' line
    line = f.readline()
    while b'<par_end>' not in line:
      if not line:
        # readline() returns b'' only at end of file; stop instead of
        # looping forever on a file without the marker
        raise ValueError('no <par_end> marker found in '
                         + repr(kwargs['filename']))
      line = f.readline()
    parameter_end = f.tell()

    # Read and report on header
    print('')
    print('Header data:')
    num_blocks = read_val(f, 'nbtotal', type_int)
    read_val(f, 'root_level', type_int)
    read_val(f, 'mesh_size.x1min', type_real)
    read_val(f, 'mesh_size.x2min', type_real)
    read_val(f, 'mesh_size.x3min', type_real)
    read_val(f, 'mesh_size.x1max', type_real)
    read_val(f, 'mesh_size.x2max', type_real)
    read_val(f, 'mesh_size.x3max', type_real)
    read_val(f, 'mesh_size.x1rat', type_real)
    read_val(f, 'mesh_size.x2rat', type_real)
    read_val(f, 'mesh_size.x3rat', type_real)
    read_val(f, 'mesh_size.nx1', type_int)
    read_val(f, 'mesh_size.nx2', type_int)
    read_val(f, 'mesh_size.nx3', type_int)
    skip_val(f, '[padding]', type_regionsize_pad)
    read_val(f, 'time', type_real)
    read_val(f, 'dt', type_real)
    read_val(f, 'ncycle', type_int)
    block_size = read_val(f, 'datasize', type_uint64)
    header_end = f.tell()

    # Locate end of file
    f.seek(0, os.SEEK_END)
    file_size = f.tell()

  # Calculate sizes
  parameter_size = parameter_end
  header_size = header_end - parameter_end
  # struct accepts a byte-order prefix only as the first character, so strip
  # the per-field '<' markers; the result uses native byte order and
  # alignment, presumably to match the compiler's in-memory record layout
  logical_location_fmt = (type_int64 * 3 + type_logicallocation_pad
                          + type_int).replace('<', '')
  # One location record plus one real-valued cost per meshblock
  list_size = num_blocks * (struct.calcsize(logical_location_fmt)
                            + struct.calcsize(type_real))
  block_data_size = num_blocks * block_size
  # Whatever bytes remain are attributed to user mesh data
  user_mesh_size = file_size - header_end - list_size - block_data_size

  # Report on sizes, right-aligned to the width of the largest number
  max_len = len(repr(file_size))
  size_str = '{0:' + repr(max_len) + '}'
  print('')
  print('File size:             ' + size_str.format(file_size))
  print('  Input parameters:    ' + size_str.format(parameter_size))
  print('  Header:              ' + size_str.format(header_size))
  print('  User mesh data:      ' + size_str.format(user_mesh_size))
  print('  Locations and costs: ' + size_str.format(list_size))
  print('  Meshblock data:      ' + size_str.format(block_data_size))
  print('')

# Function for reading and reporting single number
def read_val(f, name, fmt):
  """Read one value of struct format fmt from f, print it, and return it.

  The printed line shows name, the repr of the value, the number of bytes
  consumed, and fmt, each left-justified to the module-level column widths.
  Only the first unpacked item is returned.
  """
  num_bytes = struct.calcsize(fmt)
  raw = f.read(num_bytes)
  value = struct.unpack(fmt, raw)[0]
  template = ('  {0:<{w_name}} = {1:<{w_val}} '
              '(length {2:<{w_len}}, {3:<{w_fmt}})')
  print(template.format(name, repr(value), num_bytes, fmt,
                        w_name=name_len, w_val=val_len,
                        w_len=len_len, w_fmt=fmt_len))
  return value

# Function for skipping and reporting padding bytes
def skip_val(f, name, fmt):
  """Consume calcsize(fmt) bytes from f and print a report line labeled name.

  The label is left-justified across the name and value columns that
  read_val prints (plus its 3-character ' = ' separator), followed by the
  number of bytes skipped and fmt.

  Raises:
    struct.error -- if fewer than calcsize(fmt) bytes remain in f.
  """
  length = struct.calcsize(fmt)
  data = f.read(length)
  # Validate the read so a truncated file fails loudly here rather than
  # silently misaligning every later field
  if len(data) != length:
    raise struct.error('expected {0} bytes of {1}, got {2}'.format(
        length, name, len(data)))
  template = '  {0:<{w_name}} (length {1:<{w_len}}, {2:<{w_fmt}})'
  print(template.format(name, length, fmt,
                        w_name=name_len + 3 + val_len,
                        w_len=len_len, w_fmt=fmt_len))

# Execute main function
if __name__ == '__main__':
  # Command-line entry point: a single positional argument naming the file
  cli = argparse.ArgumentParser()
  cli.add_argument('filename', help='name of restart file to inspect')
  cli_args = cli.parse_args()
  main(**vars(cli_args))