#!/usr/bin/env python
"""
Command line interface
"""
import argparse
import io
import json
import os
import sys
from typing import Optional

from . import __version__
from .gridworld import GridWorld
from .runner import get_subclasses, check_entrypoint, get_player_class, get_world_class, run_episode, worlds


def _build_parser(world_classes):
    """Return the argument parser used by :func:`gridrunner`.

    :param world_classes: GridWorld subclasses sorted by class name; their
        names are the valid ``--world`` choices and the first one is the
        default.
    """
    parser = argparse.ArgumentParser(description=gridrunner.__doc__,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('infiles', type=argparse.FileType('r'), nargs='*',
                        help='world description JSON files, they must be compatible with the world type '
                             '(see --world option).')
    parser.add_argument('--name', '-n', type=str,
                        help='name of the player, default to the name of the player class')
    parser.add_argument('--path', '-p', type=str, default='.',
                        help="path of the player library, it's prepended to the sys.path variable")
    parser.add_argument('--entry', '-e', type=check_entrypoint, required=True,
                        help="object reference for a Player subclass in the form "
                             "'importable.module:object.attr'. "
                             "See <https://packaging.python.org/specifications/entry-points/#data-model> "
                             "for details.")
    parser.add_argument('--world', '-w', type=str, default=world_classes[0].__name__,
                        choices=[c.__name__ for c in world_classes],
                        help='class name of the world')
    parser.add_argument('--horizon', '-z', type=int, default=20,
                        help='maximum number of steps')
    # store_false: showing is the default, the flag turns it off. Storing under
    # 'show' keeps the meaning of the parsed value unambiguous.
    parser.add_argument('--noshow', dest='show', action='store_false',
                        help="prevent printing the world at each step")
    # A string default is passed through the FileType converter by argparse, so
    # '-' yields sys.stdout while the help shows a readable "(default: -)".
    parser.add_argument('--out', '-o', type=argparse.FileType('w'), default='-',
                        help="write output to file")
    parser.add_argument('--version', action='version', version='%(prog)s ' + __version__)
    parser.add_argument('--log', '-l', type=argparse.FileType('w'),
                        help="write the log of the games to file (JSON)")
    return parser


def _close_owned(stream):
    """Close *stream* unless it is ``None`` or one of the interpreter's standard streams.

    argparse.FileType returns sys.stdin / sys.stdout for '-', and those must
    stay open; every other stream was opened on our behalf and is ours to close.
    """
    if stream is None or any(stream is std for std in (sys.stdin, sys.stdout, sys.stderr)):
        return
    stream.close()


def gridrunner(*args):
    """
    Run episodes on worlds using the specified player.
    """
    # NOTE: the docstring above doubles as the --help description, so the
    # contract is documented here instead:
    #   args: command-line style arguments (e.g. sys.argv[1:]); argparse exits
    #         the process on invalid arguments or --help/--version.
    #   returns: 0 (the process exit status) once all episodes have run.
    world_classes = sorted(get_subclasses(GridWorld), key=lambda c: c.__name__)
    args_dict = vars(_build_parser(world_classes).parse_args(args))

    infiles = args_dict['infiles']
    outf: io.TextIOBase = args_dict['out']
    game_log: Optional[io.TextIOBase] = args_dict['log']
    try:
        # os.path.abspath('.') is already the current working directory.
        path = os.path.abspath(args_dict['path'])
        horizon = args_dict['horizon']
        show = args_dict['show']

        player_class = get_player_class(args_dict['entry'], path=path)
        world_class = get_world_class(args_dict['world'])
        name = args_dict['name']
        if name is None:
            name = player_class.__name__
        player = player_class(name=name)

        # The log file is a JSON array with one entry per episode.
        if game_log is not None:
            print('[', file=game_log)

        if infiles:
            world_iter = worlds(infiles, world_class)
        else:
            world = world_class.random()
            # show world definition
            print('-' * 10 + ' Playing on world: ' + '-' * 10, file=outf)
            world.to_JSON(outf)
            print('\n' + '-' * 40, file=outf)
            world_iter = [world]

        for index, world in enumerate(world_iter):
            glog = run_episode(world, player, horizon=horizon, show=show, outf=outf)
            if game_log is not None:
                if index > 0:
                    print(',', file=game_log)  # separate array elements
                json.dump(glog, game_log)

        if game_log is not None:
            print(']', file=game_log)
        return 0
    finally:
        # FileType opens these for us and nobody else holds a reference, so
        # close them here; closing also flushes the output and log files.
        for stream in [*infiles, outf, game_log]:
            _close_owned(stream)


def main():
    """Console entry point: run :func:`gridrunner` on the process arguments and exit with its status."""
    sys.exit(gridrunner(*sys.argv[1:]))


if __name__ == "__main__":
    main()