Skip to content
Snippets Groups Projects
Forked from Tessaris Sergio / wumpus
5 commits behind, 7 commits ahead of the upstream repository.
cli.py 3.52 KiB
#!/usr/bin/env python

"""
Command line interface
"""

import argparse
import io
import json
import os
import sys

from . import __version__
from .gridworld import GridWorld
from .runner import get_subclasses, check_entrypoint, get_player_class, get_world_class, run_episode, worlds


def _close_stream(stream):
    """Close *stream* unless it is ``None`` or one of the process' standard streams."""
    if stream is None:
        return
    # argparse.FileType maps '-' to sys.stdin / sys.stdout, and --out defaults
    # to sys.stdout: those belong to the interpreter, not to this function.
    if any(stream is std for std in (sys.stdin, sys.stdout, sys.stderr)):
        return
    stream.close()


def gridrunner(*args):
    """
    Run episodes on worlds using the specified player.
    """
    # NOTE: the docstring above is passed to argparse as the ``--help``
    # description, so it is deliberately kept to that single sentence.
    #
    # Parameters:
    #   *args -- the command line tokens (without the program name), e.g.
    #            ``gridrunner('--entry', 'mod:Player', 'world.json')``.
    # Returns:
    #   0 once every episode has been run.
    # Raises:
    #   SystemExit -- raised by argparse on invalid arguments, --help or
    #                 --version.
    #
    # Every file opened by argparse on behalf of this call is closed before
    # returning, and the JSON array written to --log is always terminated,
    # even when an episode raises.

    world_classes = sorted(get_subclasses(GridWorld), key=lambda c: c.__name__)

    parser = argparse.ArgumentParser(
        description=gridrunner.__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        'infiles', type=argparse.FileType('r'), nargs='*',
        help='world description JSON files, they must be compatible with the world type (see --world option).',
    )
    parser.add_argument(
        '--name', '-n', type=str,
        help='name of the player, default to the name of the player class',
    )
    parser.add_argument(
        '--path', '-p', type=str, default='.',
        help="path of the player library, it's prepended to the sys.path variable",
    )
    parser.add_argument(
        '--entry', '-e', type=check_entrypoint, required=True,
        help="object reference for a Player subclass in the form 'importable.module:object.attr'. See <https://packaging.python.org/specifications/entry-points/#data-model> for details.",
    )
    parser.add_argument(
        '--world', '-w', type=str, default=world_classes[0].__name__,
        choices=[c.__name__ for c in world_classes],
        help='class name of the world',
    )
    parser.add_argument(
        '--horizon', '-z', type=int, default=20,
        help='maximum number of steps',
    )
    # store_false: the value is True unless --noshow is given, so it is stored
    # under the positive name ``show`` to avoid a double negative downstream.
    parser.add_argument(
        '--noshow', dest='show', action='store_false',
        help="prevent the printing the world at each step",
    )
    parser.add_argument(
        '--out', '-o', type=argparse.FileType('w'), default=sys.stdout,
        help="write output to file",
    )
    parser.add_argument('--version', action='version', version='%(prog)s ' + __version__)
    parser.add_argument(
        '--log', '-l', type=argparse.FileType('w'),
        help="write the log of the games to file (JSON)",
    )
    options = parser.parse_args(args)

    infiles = options.infiles
    horizon = options.horizon
    show = options.show
    outf: io.TextIOBase = options.out
    game_log: io.TextIOBase = options.log

    try:
        # abspath('.') already resolves to the current working directory.
        player_class = get_player_class(options.entry, path=os.path.abspath(options.path))
        world_class = get_world_class(options.world)

        name = options.name if options.name is not None else player_class.__name__
        player = player_class(name=name)

        # The log is a JSON array holding one entry per episode.
        if game_log is not None:
            print('[', file=game_log)
        try:
            if infiles:
                first_entry = True
                for world in worlds(infiles, world_class):
                    glog = run_episode(world, player, horizon=horizon, show=show, outf=outf)
                    if game_log is not None:
                        # Separator goes before every entry but the first.
                        if not first_entry:
                            print(',', file=game_log)
                        first_entry = False
                        json.dump(glog, game_log)
            else:
                # No world file given: play a single randomly generated world.
                world = world_class.random()
                # show world definition
                print('-' * 10 + ' Playing on world: ' + '-' * 10, file=outf)
                world.to_JSON(outf)
                print('\n' + '-' * 40, file=outf)
                glog = run_episode(world, player, horizon=horizon, show=show, outf=outf)
                if game_log is not None:
                    json.dump(glog, game_log)
        finally:
            # Terminate the array even if an episode failed, so the episodes
            # logged so far remain parseable.
            if game_log is not None:
                print(']', file=game_log)
    finally:
        for stream in (*infiles, outf, game_log):
            _close_stream(stream)

    return 0


def main():
    """Console entry point: run ``gridrunner`` on the process arguments and exit with its result."""
    # Drop argv[0] (the program name); gridrunner expects only the options.
    cli_args = sys.argv[1:]
    exit_status = gridrunner(*cli_args)
    sys.exit(exit_status)


# Script entry point. This module uses package-relative imports, so it has to
# be executed as part of its package (e.g. with ``python -m <package>.cli``)
# rather than as a standalone file.
if __name__ == "__main__":
    main()