diff --git a/.coveragerc b/.coveragerc index 3b2dc41..c7e7a38 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,8 +1,7 @@ [report] +show_missing = True + exclude_lines = - def print_lattice - def print_report - def print_table def valid_date def __repr__ def __str__ @@ -12,14 +11,9 @@ exclude_lines = raise NotImplementedError [run] -source = ./pyt +source = + ./pyt + ./tests omit = - pyt/__main__.py - pyt/definition_chains.py - pyt/draw.py pyt/formatters/json.py pyt/formatters/text.py - pyt/github_search.py - pyt/liveness.py - pyt/repo_runner.py - pyt/save.py diff --git a/README.rst b/README.rst index 3883e33..db4ba9e 100644 --- a/README.rst +++ b/README.rst @@ -25,19 +25,7 @@ Static analysis of Python web applications based on theoretical foundations (Con Features -------- -* Detect Command injection - -* Detect SQL injection - -* Detect XSS - -* Detect directory traversal - -* Get a control flow graph - -* Get a def-use and/or a use-def chain - -* Search GitHub and analyse hits with PyT +* Detect command injection, SSRF, SQL injection, XSS, directory traveral etc. * A lot of customisation possible @@ -62,67 +50,48 @@ PyT can also be installed from source. To do so, clone the repo, and then run: python3 setup.py install -Usage from Source -================= - -Using it like a user ``python3 -m pyt -f example/vulnerable_code/XSS_call.py save -du`` - -Running the tests ``python3 -m tests`` - -Running an individual test file ``python3 -m unittest tests.import_test`` - -Running an individual test ``python3 -m unittest tests.import_test.ImportTest.test_import`` - - -Contributions -============= - -Join our slack group: https://pyt-dev.slack.com/ - ask for invite: mr.thalmann@gmail.com - -`Guidelines`_ - -.. _Guidelines: https://github.com/python-security/pyt/blob/master/CONTRIBUTIONS.md - - -Virtual env setup guide -======================= - -Create a directory to hold the virtual env and project - -``mkdir ~/a_folder`` - -``cd ~/a_folder`` - -Clone the project into the directory - -``git clone https://github.com/python-security/pyt.git`` - -Create the virtual environment - -``python3 -m venv ~/a_folder/`` - -Check that you have the right versions - -``python3 --version`` sample output ``Python 3.6.0`` - -``pip --version`` sample output ``pip 9.0.1 from /Users/kevinhock/a_folder/lib/python3.6/site-packages (python 3.6)`` - -Change to project directory - -``cd pyt`` - -Install dependencies - -``pip install -r requirements.txt`` - -``pip list`` sample output :: +Usage +===== + +.. code-block:: + + usage: python -m pyt [-h] [-f FILEPATH] [-a ADAPTOR] [-pr PROJECT_ROOT] + [-b BASELINE_JSON_FILE] [-j] [-m BLACKBOX_MAPPING_FILE] + [-t TRIGGER_WORD_FILE] [-o OUTPUT_FILE] [-trim] [-i] + + required arguments: + -f FILEPATH, --filepath FILEPATH + Path to the file that should be analysed. + + optional arguments: + -a ADAPTOR, --adaptor ADAPTOR + Choose a web framework adaptor: Flask(Default), + Django, Every or Pylons + -pr PROJECT_ROOT, --project-root PROJECT_ROOT + Add project root, only important when the entry file + is not at the root of the project. + -b BASELINE_JSON_FILE, --baseline BASELINE_JSON_FILE + Path of a baseline report to compare against (only + JSON-formatted files are accepted) + -j, --json Prints JSON instead of report. + -m BLACKBOX_MAPPING_FILE, --blackbox-mapping-file BLACKBOX_MAPPING_FILE + Input blackbox mapping file. + -t TRIGGER_WORD_FILE, --trigger-word-file TRIGGER_WORD_FILE + Input file with a list of sources and sinks + -o OUTPUT_FILE, --output OUTPUT_FILE + write report to filename + --ignore-nosec do not skip lines with # nosec comments + + print arguments: + -trim, --trim-reassigned-in + Trims the reassigned list to just the vulnerability + chain. + -i, --interactive Will ask you about each blackbox function call in + vulnerability chains. + +How It Works +============ - gitdb (0.6.4) - GitPython (2.0.8) - graphviz (0.4.10) - pip (9.0.1) - requests (2.10.0) - setuptools (28.8.0) - smmap (0.9.0) +You will find a README.rst in every directory in the pyt folder, `start here`_. -In the future, just type ``source ~/a_folder/bin/activate`` to start developing. +.. _start here: https://github.com/python-security/pyt/tree/re_organize_code/pyt diff --git a/pyt/README.rst b/pyt/README.rst new file mode 100644 index 0000000..7d69b9c --- /dev/null +++ b/pyt/README.rst @@ -0,0 +1,71 @@ +How It Works +============ + +`__main__.py`_ is where all the high-level steps happen. + +.. _\_\_main\_\_.py: https://github.com/python-security/pyt/blob/re_organize_code/pyt/__main__.py + +Step 1 + Parse command line arguments. + + `parse_args`_ in `usage.py`_ + + .. _parse_args: https://github.com/python-security/pyt/blob/re_organize_code/pyt/usage.py#L113 + .. _usage.py: https://github.com/python-security/pyt/blob/re_organize_code/pyt/usage.py + + +Step 2 + Generate the `Abstract Syntax Tree (AST)`_. + + Essentially done in these lines of code with the `ast`_ module: + + .. code-block:: python + + import ast + ast.parse(f.read()) + + `generate_ast`_ in `ast_helper.py`_ + + .. _Abstract Syntax Tree (AST): https://en.wikipedia.org/wiki/Abstract_syntax_tree + .. _ast: https://docs.python.org/3/library/ast.html + .. _generate_ast: https://github.com/python-security/pyt/blob/re_organize_code/pyt/core/ast_helper.py#L24 + .. _ast_helper.py: https://github.com/python-security/pyt/blob/re_organize_code/pyt/core/ast_helper.py + + +Step 3 + Pass the AST to create a `Control Flow Graph (CFG)`_ + + .. _Control Flow Graph (CFG): https://github.com/python-security/pyt/tree/re_organize_code/pyt/cfg + +Step 4 + Pass the CFG to a `Framework Adaptor`_, which will mark the arguments of certain functions as tainted sources. + + .. _Framework Adaptor: https://github.com/python-security/pyt/tree/re_organize_code/pyt/web_frameworks + +Step 5 + Perform `(modified-)reaching definitions analysis`_, to know where definitions reach. + + .. _\(modified\-\)reaching definitions analysis: https://github.com/python-security/pyt/tree/re_organize_code/pyt/analysis + +Step 6 + `Find vulnerabilities`_, by seeing where sources reach, and how. + + .. _Find vulnerabilities: https://github.com/python-security/pyt/tree/re_organize_code/pyt/vulnerabilities + +Step 7 + `Remove already known vulnerabilities`_ if a `baseline`_ (JSON file of a previous run of PyT) is provided. + + .. _Remove already known vulnerabilities: https://github.com/python-security/pyt/blob/re_organize_code/pyt/vulnerabilities/vulnerability_helper.py#L194 + .. _baseline: https://github.com/python-security/pyt/blob/re_organize_code/pyt/usage.py#L54 + +Step 8 + Output the results in either `text or JSON form`_, to stdout or the `output file`_. + + .. _text or JSON form: https://github.com/python-security/pyt/tree/re_organize_code/pyt/formatters + .. _output file: https://github.com/python-security/pyt/blob/re_organize_code/pyt/usage.py#L80 + +Here is an image from the `original thesis`_: + +.. image:: https://github.com/KevinHock/rtdpyt/blob/master/docs/img/overview.png + +.. _original thesis: http://projekter.aau.dk/projekter/files/239563289/final.pdf#page=62 diff --git a/pyt/__init__.py b/pyt/__init__.py index aa35dd9..e69de29 100644 --- a/pyt/__init__.py +++ b/pyt/__init__.py @@ -1,5 +0,0 @@ -from .__main__ import main - - -if __name__ == '__main__': - main() diff --git a/pyt/__main__.py b/pyt/__main__.py index b81e6f0..b230211 100644 --- a/pyt/__main__.py +++ b/pyt/__main__.py @@ -1,237 +1,38 @@ -"""This module is the comand line tool of pyt.""" +"""The comand line module of PyT.""" -import argparse import os import sys -from datetime import date -from pprint import pprint -from .argument_helpers import ( - default_blackbox_mapping_file, - default_trigger_word_file, - valid_date, - VulnerabilityFiles, - UImode -) -from .ast_helper import generate_ast -from .baseline import get_vulnerabilities_not_in_baseline -from .constraint_table import ( - initialize_constraint_table, - print_table +from .analysis.constraint_table import initialize_constraint_table +from .analysis.fixed_point import analyse +from .cfg import make_cfg +from .core.ast_helper import generate_ast +from .core.project_handler import ( + get_directory_modules, + get_modules ) -from .draw import draw_cfgs, draw_lattices -from .expr_visitor import make_cfg -from .fixed_point import analyse from .formatters import ( json, text ) -from .framework_adaptor import FrameworkAdaptor -from .framework_helper import ( +from .usage import parse_args +from .vulnerabilities import ( + find_vulnerabilities, + get_vulnerabilities_not_in_baseline, + UImode +) +from .web_frameworks import ( + FrameworkAdaptor, is_django_view_function, is_flask_route_function, is_function, is_function_without_leading_ ) -from .github_search import scan_github, set_github_api_token -from .lattice import print_lattice -from .liveness import LivenessAnalysis -from .project_handler import get_directory_modules, get_modules -from .reaching_definitions import ReachingDefinitionsAnalysis -from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis -from .repo_runner import get_repos -from .save import ( - cfg_to_file, - create_database, - def_use_chain_to_file, - lattice_to_file, - Output, - use_def_chain_to_file, - verbose_cfg_to_file, - vulnerabilities_to_file -) -from .vulnerabilities import find_vulnerabilities - - -def parse_args(args): - parser = argparse.ArgumentParser(prog='python -m pyt') - parser.set_defaults(which='') - - subparsers = parser.add_subparsers() - - entry_group = parser.add_mutually_exclusive_group(required=True) - entry_group.add_argument('-f', '--filepath', - help='Path to the file that should be analysed.', - type=str) - entry_group.add_argument('-gr', '--git-repos', - help='Takes a CSV file of git_url, path per entry.', - type=str) - - parser.add_argument('-pr', '--project-root', - help='Add project root, this is important when the entry' + - ' file is not at the root of the project.', type=str) - parser.add_argument('-d', '--draw-cfg', - help='Draw CFG and output as .pdf file.', - action='store_true') - parser.add_argument('-o', '--output-filename', - help='Output filename.', type=str) - parser.add_argument('-csv', '--csv-path', type=str, - help='Give the path of the csv file' - ' repos should be added to.') - - print_group = parser.add_mutually_exclusive_group() - print_group.add_argument('-p', '--print', - help='Prints the nodes of the CFG.', - action='store_true') - print_group.add_argument('-vp', '--verbose-print', - help='Verbose printing of -p.', action='store_true') - print_group.add_argument('-trim', '--trim-reassigned-in', - help='Trims the reassigned list to the vulnerability chain.', - action='store_true', - default=False) - print_group.add_argument('-i', '--interactive', - help='Will ask you about each vulnerability chain and blackbox nodes.', - action='store_true', - default=False) - - parser.add_argument('-t', '--trigger-word-file', - help='Input trigger word file.', - type=str, - default=default_trigger_word_file) - parser.add_argument('-m', '--blackbox-mapping-file', - help='Input blackbox mapping file.', - type=str, - default=default_blackbox_mapping_file) - parser.add_argument('-py2', '--python-2', - help='[WARNING, EXPERIMENTAL] Turns on Python 2 mode,' + - ' needed when target file(s) are written in Python 2.', action='store_true') - parser.add_argument('-l', '--log-level', - help='Choose logging level: CRITICAL, ERROR,' + - ' WARNING(Default), INFO, DEBUG, NOTSET.', type=str) - parser.add_argument('-a', '--adaptor', - help='Choose an adaptor: Flask(Default), Django, Every or Pylons', - type=str) - parser.add_argument('-db', '--create-database', - help='Creates a sql file that can be used to' + - ' create a database.', action='store_true') - parser.add_argument('-dl', '--draw-lattice', - nargs='+', help='Draws a lattice.') - parser.add_argument('-j', '--json', - help='Prints JSON instead of report.', - action='store_true', - default=False) - - analysis_group = parser.add_mutually_exclusive_group() - analysis_group.add_argument('-li', '--liveness', - help='Run liveness analysis. Default is' + - ' reaching definitions tainted version.', - action='store_true') - analysis_group.add_argument('-re', '--reaching', - help='Run reaching definitions analysis.' + - ' Default is reaching definitions' + - ' tainted version.', action='store_true') - analysis_group.add_argument('-rt', '--reaching-taint', - help='This is the default analysis:' + - ' reaching definitions tainted version.', - action='store_true') - - parser.add_argument('-ppm', '--print-project-modules', - help='Print project modules.', action='store_true') - parser.add_argument('-b', '--baseline', - help='path of a baseline report to compare against ' - '(only JSON-formatted files are accepted)', - type=str, - default=False) - parser.add_argument('--ignore-nosec', dest='ignore_nosec', action='store_true', - help='do not skip lines with # nosec comments') - - save_parser = subparsers.add_parser('save', help='Save menu.') - save_parser.set_defaults(which='save') - save_parser.add_argument('-fp', '--filename-prefix', - help='Filename prefix fx file_lattice.pyt', - type=str) - save_parser.add_argument('-du', '--def-use-chain', - help='Output the def-use chain(s) to file.', - action='store_true') - save_parser.add_argument('-ud', '--use-def-chain', - help='Output the use-def chain(s) to file', - action='store_true') - save_parser.add_argument('-cfg', '--control-flow-graph', - help='Output the CFGs to file.', - action='store_true') - save_parser.add_argument('-vcfg', '--verbose-control-flow-graph', - help='Output the verbose CFGs to file.', - action='store_true') - save_parser.add_argument('-an', '--analysis', - help='Output analysis results to file' + - ' in form of a constraint table.', - action='store_true') - save_parser.add_argument('-la', '--lattice', help='Output lattice(s) to file.', - action='store_true') - save_parser.add_argument('-vu', '--vulnerabilities', - help='Output vulnerabilities to file.', - action='store_true') - save_parser.add_argument('-all', '--save-all', - help='Output everything to file.', - action='store_true') - - - search_parser = subparsers.add_parser( - 'github_search', - help='Searches through github and runs PyT' - ' on found repositories. This can take some time.') - search_parser.set_defaults(which='search') - - search_parser.add_argument( - '-ss', '--search-string', required=True, - help='String for searching for repos on github.', type=str) - - search_parser.add_argument('-sd', '--start-date', - help='Start date for repo search. ' - 'Criteria used is Created Date.', - type=valid_date, - default=date(2010, 1, 1)) - return parser.parse_args(args) - - -def analyse_repo(args, github_repo, analysis_type, ui_mode, nosec_lines): - cfg_list = list() - directory = os.path.dirname(github_repo.path) - project_modules = get_modules(directory) - local_modules = get_directory_modules(directory) - tree = generate_ast(github_repo.path) - cfg = make_cfg( - tree, - project_modules, - local_modules, - github_repo.path - ) - cfg_list.append(cfg) - - initialize_constraint_table(cfg_list) - analyse(cfg_list, analysis_type=analysis_type) - vulnerabilities = find_vulnerabilities( - cfg_list, - analysis_type, - ui_mode, - VulnerabilityFiles( - args.blackbox_mapping_file, - args.trigger_word_file - ), - nosec_lines - ) - return vulnerabilities def main(command_line_args=sys.argv[1:]): args = parse_args(command_line_args) - analysis = ReachingDefinitionsTaintAnalysis - if args.liveness: - analysis = LivenessAnalysis - elif args.reaching: - analysis = ReachingDefinitionsAnalysis - ui_mode = UImode.NORMAL if args.interactive: ui_mode = UImode.INTERACTIVE @@ -239,45 +40,18 @@ def main(command_line_args=sys.argv[1:]): ui_mode = UImode.TRIM path = os.path.normpath(args.filepath) - cfg_list = list() + if args.ignore_nosec: nosec_lines = set() else: - file = open(path, "r") + file = open(path, 'r') lines = file.readlines() nosec_lines = set( - lineno for - (lineno, line) in enumerate(lines, start=1) - if '#nosec' in line or '# nosec' in line) - - if args.git_repos: - repos = get_repos(args.git_repos) - for repo in repos: - repo.clone() - vulnerabilities = analyse_repo(args, repo, analysis, ui_mode, nosec_lines) - if args.json: - json.report(vulnerabilities, sys.stdout) - else: - text.report(vulnerabilities, sys.stdout) - if not vulnerabilities: - repo.clean_up() - exit() - - - if args.which == 'search': - set_github_api_token() - scan_github( - args.search_string, - args.start_date, - analysis, - analyse_repo, - args.csv_path, - ui_mode, - args + lineno for + (lineno, line) in enumerate(lines, start=1) + if '#nosec' in line or '# nosec' in line ) - exit() - directory = None if args.project_root: directory = os.path.normpath(args.project_root) else: @@ -285,16 +59,15 @@ def main(command_line_args=sys.argv[1:]): project_modules = get_modules(directory) local_modules = get_directory_modules(directory) - tree = generate_ast(path, python_2=args.python_2) + tree = generate_ast(path) - cfg_list = list() cfg = make_cfg( tree, project_modules, local_modules, path ) - cfg_list.append(cfg) + cfg_list = [cfg] framework_route_criteria = is_flask_route_function if args.adaptor: if args.adaptor.lower().startswith('e'): @@ -304,81 +77,33 @@ def main(command_line_args=sys.argv[1:]): elif args.adaptor.lower().startswith('d'): framework_route_criteria = is_django_view_function # Add all the route functions to the cfg_list - FrameworkAdaptor(cfg_list, project_modules, local_modules, framework_route_criteria) + FrameworkAdaptor( + cfg_list, + project_modules, + local_modules, + framework_route_criteria + ) initialize_constraint_table(cfg_list) - - analyse(cfg_list, analysis_type=analysis) - + analyse(cfg_list) vulnerabilities = find_vulnerabilities( cfg_list, - analysis, ui_mode, - VulnerabilityFiles( - args.blackbox_mapping_file, - args.trigger_word_file - ), + args.blackbox_mapping_file, + args.trigger_word_file, nosec_lines ) - + if args.baseline: - vulnerabilities = get_vulnerabilities_not_in_baseline(vulnerabilities, args.baseline) - + vulnerabilities = get_vulnerabilities_not_in_baseline( + vulnerabilities, + args.baseline + ) + if args.json: - json.report(vulnerabilities, sys.stdout) + json.report(vulnerabilities, args.output_file) else: - text.report(vulnerabilities, sys.stdout) - - if args.draw_cfg: - if args.output_filename: - draw_cfgs(cfg_list, args.output_filename) - else: - draw_cfgs(cfg_list) - if args.print: - lattice = print_lattice(cfg_list, analysis) - - print_table(lattice) - for i, e in enumerate(cfg_list): - print('############## CFG number: ', i) - print(e) - if args.verbose_print: - for i, e in enumerate(cfg_list): - print('############## CFG number: ', i) - print(repr(e)) - - if args.print_project_modules: - print('############## PROJECT MODULES ##############') - pprint(project_modules) - - if args.create_database: - create_database(cfg_list, vulnerabilities) - if args.draw_lattice: - draw_lattices(cfg_list) - - # Output to file - if args.which == 'save': - if args.filename_prefix: - Output.filename_prefix = args.filename_prefix - if args.save_all: - def_use_chain_to_file(cfg_list) - use_def_chain_to_file(cfg_list) - cfg_to_file(cfg_list) - verbose_cfg_to_file(cfg_list) - lattice_to_file(cfg_list, analysis) - vulnerabilities_to_file(vulnerabilities) - else: - if args.def_use_chain: - def_use_chain_to_file(cfg_list) - if args.use_def_chain: - use_def_chain_to_file(cfg_list) - if args.control_flow_graph: - cfg_to_file(cfg_list) - if args.verbose_control_flow_graph: - verbose_cfg_to_file(cfg_list) - if args.lattice: - lattice_to_file(cfg_list, analysis) - if args.vulnerabilities: - vulnerabilities_to_file(vulnerabilities) + text.report(vulnerabilities, args.output_file) if __name__ == '__main__': diff --git a/pyt/analysis/README.rst b/pyt/analysis/README.rst new file mode 100644 index 0000000..e1ba9ef --- /dev/null +++ b/pyt/analysis/README.rst @@ -0,0 +1,88 @@ +This code is responsible for answering two questions: + + +Where do definitions reach? +=========================== + +Traditionally `reaching definitions`_, a classic dataflow-analysis, +has been used to answer this question. To understand reaching definitions, +watch this `wonderful YouTube video`_ and come back here. +We use reaching definitions, with one small modification, +a `reassignment check`_. + + +.. code-block:: python + + # Reassignment check + if cfg_node.left_hand_side not in cfg_node.right_hand_side_variables: + # Get previous assignments of cfg_node.left_hand_side and remove them from JOIN + arrow_result = self.arrow(JOIN, cfg_node.left_hand_side) + +As an example, + +.. code-block:: python + + image_name = request.args.get('image_name') + image_name = os.path.join(base_dir, image_name) + send_file(image_name) + +we still want to know that something from a request reached `send_file`. + + +.. _reaching definitions: https://en.wikipedia.org/wiki/Reaching_definition +.. _reassignment check: https://github.com/python-security/pyt/blob/re_organize_code/pyt/analysis/reaching_definitions_taint.py#L23-L26 +.. _wonderful YouTube video: https://www.youtube.com/watch?v=NVBQSR_HdL0 + + +How does a definition reach? +============================ + +After we know that a definition reaches a use that we are interested in, +we make what use called `definition-use chains`_ to figure out how the definition +reaches the use. This is necessary because there may be more than one path from +the definition to the use. Here is the code from `definition_chains.py`_: + +.. code-block:: python + + def build_def_use_chain( + cfg_nodes, + lattice + ): + def_use = defaultdict(list) + # For every node + for node in cfg_nodes: + # That's a definition + if isinstance(node, AssignmentNode): + # Get the uses + for variable in node.right_hand_side_variables: + # Loop through most of the nodes before it + for earlier_node in get_constraint_nodes(node, lattice): + # and add them to the 'uses list' of each earlier node, when applicable + # 'earlier node' here being a simplification + if variable in earlier_node.left_hand_side: + def_use[earlier_node].append(node) + return def_use + +.. _definition-use chains: https://en.wikipedia.org/wiki/Use-define_chain +.. _definition_chains.py: https://github.com/python-security/pyt/blob/re_organize_code/pyt/analysis/definition_chains.py#L16-L33 + + +Additional details +================== + +This folder probably will not change at all for the lifetime of the project, +unless we were to implement more advanced analyses like `solving string +constraints`_ or doing `alias analysis`_. Right now and in the foreseeable +future there are more pressing concerns, like handling web frameworks +and handling all AST node types in the `CFG construction`_. + +Stefan and Bruno like the `Schwartzbach notes`_, as you will see in some comments. +But looking these two algorithms up will yield countless results, my favorite is +this `amazing guy from YouTube`_. + + +.. _solving string constraints: https://zyh1121.github.io/z3str3Docs/inputLanguage.html +.. _alias analysis: https://www3.cs.stonybrook.edu/~liu/papers/Alias-DLS10.pdf +.. _CFG construction: https://github.com/python-security/pyt/tree/re_organize_code/pyt/cfg +.. _Schwartzbach notes: http://lara.epfl.ch/w/_media/sav08:schwartzbach.pdf +.. _amazing guy from YouTube: https://www.youtube.com/watch?v=NVBQSR_HdL0 diff --git a/pyt/constraint_table.py b/pyt/analysis/constraint_table.py similarity index 74% rename from pyt/constraint_table.py rename to pyt/analysis/constraint_table.py index de7a0ce..bc7466d 100644 --- a/pyt/constraint_table.py +++ b/pyt/analysis/constraint_table.py @@ -17,9 +17,3 @@ def constraint_join(cfg_nodes): for e in cfg_nodes: r = r | constraint_table[e] return r - - -def print_table(lattice): - print('Constraint table:') - for k, v in constraint_table.items(): - print(str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)])) diff --git a/pyt/analysis/definition_chains.py b/pyt/analysis/definition_chains.py new file mode 100644 index 0000000..fad898b --- /dev/null +++ b/pyt/analysis/definition_chains.py @@ -0,0 +1,33 @@ +from collections import defaultdict + +from .constraint_table import constraint_table +from ..core.node_types import AssignmentNode + + +def get_constraint_nodes( + node, + lattice +): + for n in lattice.get_elements(constraint_table[node]): + if n is not node: + yield n + + +def build_def_use_chain( + cfg_nodes, + lattice +): + def_use = defaultdict(list) + # For every node + for node in cfg_nodes: + # That's a definition + if isinstance(node, AssignmentNode): + # Get the uses + for variable in node.right_hand_side_variables: + # Loop through most of the nodes before it + for earlier_node in get_constraint_nodes(node, lattice): + # and add them to the 'uses list' of each earlier node, when applicable + # 'earlier node' here being a simplification + if variable in earlier_node.left_hand_side: + def_use[earlier_node].append(node) + return def_use diff --git a/pyt/fixed_point.py b/pyt/analysis/fixed_point.py similarity index 78% rename from pyt/fixed_point.py rename to pyt/analysis/fixed_point.py index b057493..e77086c 100644 --- a/pyt/fixed_point.py +++ b/pyt/analysis/fixed_point.py @@ -1,16 +1,17 @@ """This module implements the fixed point algorithm.""" from .constraint_table import constraint_table +from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis class FixedPointAnalysis(): """Run the fix point analysis.""" - def __init__(self, cfg, analysis): + def __init__(self, cfg): """Fixed point analysis. Analysis must be a dataflow analysis containing a 'fixpointmethod' - method that analyses one CFG node.""" - self.analysis = analysis(cfg) + method that analyses one CFG.""" + self.analysis = ReachingDefinitionsTaintAnalysis(cfg) self.cfg = cfg def fixpoint_runner(self): @@ -22,15 +23,15 @@ def fixpoint_runner(self): self.analysis.fixpointmethod(q[0]) # y = F_i(x_1, ..., x_n); y = constraint_table[q[0]] # y = q[0].new_constraint - if not self.analysis.equal(y, x_i): + if y != x_i: for node in self.analysis.dep(q[0]): # for (v in dep(v_i)) q.append(node) # q.append(v): constraint_table[q[0]] = y # q[0].old_constraint = q[0].new_constraint # x_i = y q = q[1:] # q = q.tail() # The list minus the head -def analyse(cfg_list, *, analysis_type): +def analyse(cfg_list): """Analyse a list of control flow graphs with a given analysis type.""" for cfg in cfg_list: - analysis = FixedPointAnalysis(cfg, analysis_type) + analysis = FixedPointAnalysis(cfg) analysis.fixpoint_runner() diff --git a/pyt/lattice.py b/pyt/analysis/lattice.py similarity index 74% rename from pyt/lattice.py rename to pyt/analysis/lattice.py index 125a8c2..5ffcec1 100644 --- a/pyt/lattice.py +++ b/pyt/analysis/lattice.py @@ -1,11 +1,21 @@ from .constraint_table import constraint_table +from ..core.node_types import AssignmentNode + + +def get_lattice_elements(cfg_nodes): + """Returns all assignment nodes as they are the only lattice elements + in the reaching definitions analysis. + """ + for node in cfg_nodes: + if isinstance(node, AssignmentNode): + yield node class Lattice: - def __init__(self, cfg_nodes, analysis_type): + def __init__(self, cfg_nodes): self.el2bv = dict() # Element to bitvector dictionary self.bv2el = list() # Bitvector to element list - for i, e in enumerate(analysis_type.get_lattice_elements(cfg_nodes)): + for i, e in enumerate(get_lattice_elements(cfg_nodes)): # Give each element a unique shift of 1 self.el2bv[e] = 0b1 << i self.bv2el.insert(0, e) @@ -37,15 +47,3 @@ def in_constraint(self, node1, node2): return False return constraint & value != 0 - - -def print_lattice(cfg_list, analysis_type): - nodes = list() - for cfg in cfg_list: - nodes.extend(cfg.nodes) - l = Lattice(nodes, analysis_type) - - print('Lattice:') - for k, v in l.el2bv.items(): - print(str(k) + ': ' + str(v)) - return l diff --git a/pyt/analysis/reaching_definitions_taint.py b/pyt/analysis/reaching_definitions_taint.py new file mode 100644 index 0000000..dad8e9e --- /dev/null +++ b/pyt/analysis/reaching_definitions_taint.py @@ -0,0 +1,51 @@ +from .constraint_table import ( + constraint_join, + constraint_table +) +from ..core.node_types import AssignmentNode +from .lattice import Lattice + + +class ReachingDefinitionsTaintAnalysis(): + def __init__(self, cfg): + self.cfg = cfg + self.lattice = Lattice(cfg.nodes) + + def fixpointmethod(self, cfg_node): + """The most important part of PyT, where we perform + the variant of reaching definitions to find where sources reach. + """ + JOIN = self.join(cfg_node) + # Assignment check + if isinstance(cfg_node, AssignmentNode): + arrow_result = JOIN + + # Reassignment check + if cfg_node.left_hand_side not in cfg_node.right_hand_side_variables: + # Get previous assignments of cfg_node.left_hand_side and remove them from JOIN + arrow_result = self.arrow(JOIN, cfg_node.left_hand_side) + + arrow_result = arrow_result | self.lattice.el2bv[cfg_node] + constraint_table[cfg_node] = arrow_result + # Default case + else: + constraint_table[cfg_node] = JOIN + + def join(self, cfg_node): + """Joins all constraints of the ingoing nodes and returns them. + This represents the JOIN auxiliary definition from Schwartzbach.""" + return constraint_join(cfg_node.ingoing) + + def arrow(self, JOIN, _id): + """Removes all previous assignments from JOIN that have the same left hand side. + This represents the arrow id definition from Schwartzbach.""" + r = JOIN + for node in self.lattice.get_elements(JOIN): + if node.left_hand_side == _id: + r = r ^ self.lattice.el2bv[node] + return r + + def dep(self, q_1): + """Represents the dep mapping from Schwartzbach.""" + for node in q_1.outgoing: + yield node diff --git a/pyt/analysis_base.py b/pyt/analysis_base.py deleted file mode 100644 index 8a4bbcf..0000000 --- a/pyt/analysis_base.py +++ /dev/null @@ -1,36 +0,0 @@ -"""This module contains a base class for the analysis component used in PyT.""" - -from abc import ( - ABCMeta, - abstractmethod -) - - -class AnalysisBase(metaclass=ABCMeta): - """Base class for fixed point analyses.""" - - annotated_cfg_nodes = dict() - - def __init__(self, cfg): - self.cfg = cfg - self.build_lattice(cfg) - - @staticmethod - @abstractmethod - def get_lattice_elements(cfg_nodes): - pass - - @abstractmethod - def equal(self, value, other): - """Define the equality for two constraint sets - that are defined by bitvectors.""" - pass - - @abstractmethod - def build_lattice(self, cfg): - pass - - @abstractmethod - def dep(self, q_1): - """Represents the dep mapping from Schwartzbach.""" - pass diff --git a/pyt/argument_helpers.py b/pyt/argument_helpers.py deleted file mode 100644 index 847636f..0000000 --- a/pyt/argument_helpers.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -from argparse import ArgumentTypeError -from collections import namedtuple -from datetime import datetime -from enum import Enum - - -default_blackbox_mapping_file = os.path.join( - os.path.dirname(__file__), - 'vulnerability_definitions', - 'blackbox_mapping.json' -) - - -default_trigger_word_file = os.path.join( - os.path.dirname(__file__), - 'vulnerability_definitions', - 'all_trigger_words.pyt' -) - - -def valid_date(s): - date_format = "%Y-%m-%d" - try: - return datetime.strptime(s, date_format).date() - except ValueError: - msg = "Not a valid date: '{0}'. Format: {1}".format(s, date_format) - raise ArgumentTypeError(msg) - - -class UImode(Enum): - INTERACTIVE = 0 - NORMAL = 1 - TRIM = 2 - - -VulnerabilityFiles = namedtuple( - 'VulnerabilityFiles', - ( - 'blackbox_mapping', - 'triggers' - ) -) diff --git a/pyt/baseline.py b/pyt/baseline.py deleted file mode 100644 index 1e3258a..0000000 --- a/pyt/baseline.py +++ /dev/null @@ -1,11 +0,0 @@ -import json - - -def get_vulnerabilities_not_in_baseline(vulnerabilities, baseline): - baseline = json.load(open(baseline)) - output = list() - vulnerabilities =[vuln for vuln in vulnerabilities] - for vuln in vulnerabilities: - if vuln.as_dict() not in baseline['vulnerabilities']: - output.append(vuln) - return(output) diff --git a/pyt/cfg/README.rst b/pyt/cfg/README.rst new file mode 100644 index 0000000..4f561c7 --- /dev/null +++ b/pyt/cfg/README.rst @@ -0,0 +1,155 @@ +`make_cfg`_ is what `__main__.py`_ calls, it takes the `Abstract Syntax Tree`_, creates an `ExprVisitor`_ and returns a `Control Flow Graph`_. + +.. _make_cfg: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/make_cfg.py#L22-L38 +.. _\_\_main\_\_.py: https://github.com/python-security/pyt/blob/re_organize_code/pyt/__main__.py#L33-L106 +.. _Abstract Syntax Tree: https://en.wikipedia.org/wiki/Abstract_syntax_tree +.. _Control Flow Graph: https://en.wikipedia.org/wiki/Control_flow_graph + +`stmt_visitor.py`_ and `expr_visitor.py`_ mirror the `abstract grammar`_ of Python. Statements can contain expressions, but not the other way around. This is why `ExprVisitor`_ inherits from `StmtVisitor`_, which inherits from `ast.NodeVisitor`_; from the standard library. + +.. _StmtVisitor: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/stmt_visitor.py#L55 +.. _ExprVisitor: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/expr_visitor.py#L33 + +This is how `ast.NodeVisitor`_ works: + +.. code-block:: python + + def visit(self, node): + """Visit a node.""" + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + +So as you'll see, there is a `visit\_` function for almost every AST node type. We keep track of all the nodes while we visit by adding them to self.nodes, connecting them via `ingoing and outgoing node attributes`_. + +.. _ingoing and outgoing node attributes: https://github.com/python-security/pyt/blob/re_organize_code/pyt/core/node_types.py#L27-L48 + +The two most illustrative functions are `stmt_star_handler`_ and expr_star_handler. expr_star_handler has not been merged to master so let's talk about `stmt_star_handler`_. + + +Handling an if: statement +========================= + +Example code + +.. code-block:: python + + if some_condition: + x = 5 + +This is the relevant part of the `abstract grammar`_ + +.. code-block:: python + + If(expr test, stmt* body, stmt* orelse) + # Note: stmt* means any number of statements. + + +Upon visiting an if: statement we will enter `visit_If`_ in `stmt_visitor.py`_. Since we know that the test is just one expression, we can just call self.visit() on it. The body could be an infinite number of statements, so we use the `stmt_star_handler`_ function. + +`stmt_star_handler`_ returns a namedtuple (`ConnectStatements`_) with the first statement, last_statements and break_statements of all of the statements that were in the body of the node. `stmt_star_handler`_ takes care of connecting each statement in the body to the next one. + +We then connect the test node to the first node in the body (if some_condition -> x = 5) and return a namedtuple (`ControlFlowNode`_) with the test, last_statements and break_statements. + + +.. _visit\_If: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/stmt_visitor.py#L208-L232 + +.. _ConnectStatements: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/stmt_visitor_helper.py#L15 + +.. _ControlFlowNode: https://github.com/python-security/pyt/blob/re_organize_code/pyt/core/node_types.py#L7 + +.. _stmt\_visitor.py: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/stmt_visitor.py + +.. _expr\_visitor.py: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/expr_visitor.py + +.. _stmt_star_handler: https://github.com/python-security/pyt/blob/re_organize_code/pyt/cfg/stmt_visitor.py#L60-L121 + + +.. code-block:: python + + def visit_If(self, node): + test = self.append_node(IfNode( + node.test, + node, + path=self.filenames[-1] + )) + + body_connect_stmts = self.stmt_star_handler(node.body) + # ... + test.connect(body_connect_stmts.first_statement) + + if node.orelse: + # ... + else: + # if there is no orelse, test needs an edge to the next_node + body_connect_stmts.last_statements.append(test) + + last_statements = remove_breaks(body_connect_stmts.last_statements) + + return ControlFlowNode( + test, + last_statements, + break_statements=body_connect_stmts.break_statements + ) + + +.. code-block:: python + + def stmt_star_handler( + self, + stmts + ): + """Handle stmt* expressions in an AST node. + Links all statements together in a list of statements. + Accounts for statements with multiple last nodes. + """ + break_nodes = list() + cfg_statements = list() + + first_node = None + node_not_to_step_past = self.nodes[-1] + + for stmt in stmts: + node = self.visit(stmt) + + if isinstance(node, ControlFlowNode): + break_nodes.extend(node.break_statements) + elif isinstance(node, BreakNode): + break_nodes.append(node) + + cfg_statements.append(node) + if not first_node: + if isinstance(node, ControlFlowNode): + first_node = node.test + else: + first_node = get_first_node( + node, + node_not_to_step_past + ) + + connect_nodes(cfg_statements) + + if first_node: + first_statement = first_node + else: + first_statement = get_first_statement(cfg_statements[0]) + + last_statements = get_last_statements(cfg_statements) + + return ConnectStatements( + first_statement=first_statement, + last_statements=last_statements, + break_statements=break_nodes + ) + + +.. _ast.NodeVisitor: https://docs.python.org/3/library/ast.html#ast.NodeVisitor +.. _abstract grammar: https://docs.python.org/3/library/ast.html#abstract-grammar + +References +========== + +For more information on AST nodes, see the `Green Tree Snakes`_ documentation. + +.. _Green Tree Snakes: https://greentreesnakes.readthedocs.io/en/latest/nodes.html diff --git a/pyt/cfg/__init__.py b/pyt/cfg/__init__.py new file mode 100644 index 0000000..30037a7 --- /dev/null +++ b/pyt/cfg/__init__.py @@ -0,0 +1,3 @@ +from .make_cfg import make_cfg + +__all__ = ['make_cfg'] diff --git a/pyt/alias_helper.py b/pyt/cfg/alias_helper.py similarity index 99% rename from pyt/alias_helper.py rename to pyt/cfg/alias_helper.py index 9d29444..a1c83ab 100644 --- a/pyt/alias_helper.py +++ b/pyt/cfg/alias_helper.py @@ -1,5 +1,6 @@ """This module contains alias helper functions for the expr_visitor module.""" + def as_alias_handler(alias_list): """Returns a list of all the names that will be called.""" list_ = list() diff --git a/pyt/expr_visitor.py b/pyt/cfg/expr_visitor.py similarity index 97% rename from pyt/expr_visitor.py rename to pyt/cfg/expr_visitor.py index 410d3b9..b4a9616 100644 --- a/pyt/expr_visitor.py +++ b/pyt/cfg/expr_visitor.py @@ -1,21 +1,12 @@ import ast -from .alias_helper import ( - handle_aliases_in_calls -) -from .ast_helper import ( +from .alias_helper import handle_aliases_in_calls +from ..core.ast_helper import ( Arguments, get_call_names_as_string ) -from .expr_visitor_helper import ( - BUILTINS, - CFG, - return_connection_handler, - SavedVariable -) -from .label_visitor import LabelVisitor -from .module_definitions import ModuleDefinitions -from .node_types import ( +from ..core.module_definitions import ModuleDefinitions +from ..core.node_types import ( AssignmentCallNode, AssignmentNode, BBorBInode, @@ -26,7 +17,15 @@ RestoreNode, ReturnNode ) -from .right_hand_side_visitor import RHSVisitor +from .expr_visitor_helper import ( + BUILTINS, + return_connection_handler, + SavedVariable +) +from ..helper_visitors import ( + LabelVisitor, + RHSVisitor +) from .stmt_visitor import StmtVisitor from .stmt_visitor_helper import CALL_IDENTIFIER @@ -564,22 +563,3 @@ def visit_Call(self, node): # Mark the call as a blackbox because we don't have the definition return self.add_blackbox_or_builtin_call(node, blackbox=True) return self.add_blackbox_or_builtin_call(node, blackbox=False) - - -def make_cfg( - node, - project_modules, - local_modules, - filename, - module_definitions=None -): - visitor = ExprVisitor( - node, - project_modules, - local_modules, filename, - module_definitions - ) - return CFG( - visitor.nodes, - visitor.blackbox_assignments - ) diff --git a/pyt/expr_visitor_helper.py b/pyt/cfg/expr_visitor_helper.py similarity index 56% rename from pyt/expr_visitor_helper.py rename to pyt/cfg/expr_visitor_helper.py index aebeccd..9667f7c 100644 --- a/pyt/expr_visitor_helper.py +++ b/pyt/cfg/expr_visitor_helper.py @@ -1,6 +1,6 @@ from collections import namedtuple -from .node_types import ConnectToExitNode +from ..core.node_types import ConnectToExitNode SavedVariable = namedtuple( @@ -33,24 +33,6 @@ ) -class CFG(): - def __init__(self, nodes, blackbox_assignments): - self.nodes = nodes - self.blackbox_assignments = blackbox_assignments - - def __repr__(self): - output = '' - for x, n in enumerate(self.nodes): - output = ''.join((output, 'Node: ' + str(x) + ' ' + repr(n), '\n\n')) - return output - - def __str__(self): - output = '' - for x, n in enumerate(self.nodes): - output = ''.join((output, 'Node: ' + str(x) + ' ' + str(n), '\n\n')) - return output - - def return_connection_handler(nodes, exit_node): """Connect all return statements to the Exit node.""" for function_body_node in nodes: diff --git a/pyt/cfg/make_cfg.py b/pyt/cfg/make_cfg.py new file mode 100644 index 0000000..eaa78c9 --- /dev/null +++ b/pyt/cfg/make_cfg.py @@ -0,0 +1,38 @@ +from .expr_visitor import ExprVisitor + + +class CFG(): + def __init__(self, nodes, blackbox_assignments): + self.nodes = nodes + self.blackbox_assignments = blackbox_assignments + + def __repr__(self): + output = '' + for x, n in enumerate(self.nodes): + output = ''.join((output, 'Node: ' + str(x) + ' ' + repr(n), '\n\n')) + return output + + def __str__(self): + output = '' + for x, n in enumerate(self.nodes): + output = ''.join((output, 'Node: ' + str(x) + ' ' + str(n), '\n\n')) + return output + + +def make_cfg( + tree, + project_modules, + local_modules, + filename, + module_definitions=None +): + visitor = ExprVisitor( + tree, + project_modules, + local_modules, filename, + module_definitions + ) + return CFG( + visitor.nodes, + visitor.blackbox_assignments + ) diff --git a/pyt/stmt_visitor.py b/pyt/cfg/stmt_visitor.py similarity index 99% rename from pyt/stmt_visitor.py rename to pyt/cfg/stmt_visitor.py index c855c9b..06a985e 100644 --- a/pyt/stmt_visitor.py +++ b/pyt/cfg/stmt_visitor.py @@ -9,17 +9,16 @@ not_as_alias_handler, retrieve_import_alias_mapping ) -from .ast_helper import ( +from ..core.ast_helper import ( generate_ast, get_call_names_as_string ) -from .label_visitor import LabelVisitor -from .module_definitions import ( +from ..core.module_definitions import ( LocalModuleDefinition, ModuleDefinition, ModuleDefinitions ) -from .node_types import ( +from ..core.node_types import ( AssignmentNode, AssignmentCallNode, BBorBInode, @@ -33,8 +32,14 @@ ReturnNode, TryNode ) -from .project_handler import get_directory_modules -from .right_hand_side_visitor import RHSVisitor +from ..core.project_handler import ( + get_directory_modules +) +from ..helper_visitors import ( + LabelVisitor, + RHSVisitor, + VarsVisitor +) from .stmt_visitor_helper import ( CALL_IDENTIFIER, ConnectStatements, @@ -45,7 +50,6 @@ get_last_statements, remove_breaks ) -from .vars_visitor import VarsVisitor class StmtVisitor(ast.NodeVisitor): diff --git a/pyt/stmt_visitor_helper.py b/pyt/cfg/stmt_visitor_helper.py similarity index 99% rename from pyt/stmt_visitor_helper.py rename to pyt/cfg/stmt_visitor_helper.py index 315c333..407df31 100644 --- a/pyt/stmt_visitor_helper.py +++ b/pyt/cfg/stmt_visitor_helper.py @@ -2,7 +2,7 @@ import random from collections import namedtuple -from .node_types import ( +from ..core.node_types import ( AssignmentCallNode, BBorBInode, BreakNode, diff --git a/pyt/core/README.rst b/pyt/core/README.rst new file mode 100644 index 0000000..3ba5b13 --- /dev/null +++ b/pyt/core/README.rst @@ -0,0 +1 @@ +Coming soon. diff --git a/pyt/ast_helper.py b/pyt/core/ast_helper.py similarity index 90% rename from pyt/ast_helper.py rename to pyt/core/ast_helper.py index 985eee7..e741ac5 100644 --- a/pyt/ast_helper.py +++ b/pyt/core/ast_helper.py @@ -8,7 +8,6 @@ BLACK_LISTED_CALL_NAMES = ['self'] recursive = False -python_2_mode = False def convert_to_3(path): # pragma: no cover @@ -22,17 +21,12 @@ def convert_to_3(path): # pragma: no cover exit(1) -def generate_ast(path, python_2=False): +def generate_ast(path): """Generate an Abstract Syntax Tree using the ast module. Args: path(str): The path to the file e.g. example/foo/bar.py - python_2(bool): Determines whether or not to call 2to3. """ - # If set, it stays set. - global python_2_mode - if python_2: # pragma: no cover - python_2_mode = True if os.path.isfile(path): with open(path, 'r') as f: try: @@ -40,8 +34,7 @@ def generate_ast(path, python_2=False): except SyntaxError: # pragma: no cover global recursive if not recursive: - if not python_2_mode: - convert_to_3(path) + convert_to_3(path) recursive = True return generate_ast(path) else: diff --git a/pyt/module_definitions.py b/pyt/core/module_definitions.py similarity index 97% rename from pyt/module_definitions.py rename to pyt/core/module_definitions.py index bde14ce..6ec197d 100644 --- a/pyt/module_definitions.py +++ b/pyt/core/module_definitions.py @@ -4,7 +4,8 @@ import ast -# Contains all project definitions for a program run: +# Contains all project definitions for a program run +# Only used in framework_adaptor.py, but modified here project_definitions = dict() diff --git a/pyt/node_types.py b/pyt/core/node_types.py similarity index 99% rename from pyt/node_types.py rename to pyt/core/node_types.py index 3819963..6cc2f1e 100644 --- a/pyt/node_types.py +++ b/pyt/core/node_types.py @@ -1,7 +1,7 @@ """This module contains all of the CFG nodes types.""" from collections import namedtuple -from .label_visitor import LabelVisitor +from ..helper_visitors import LabelVisitor ControlFlowNode = namedtuple( diff --git a/pyt/project_handler.py b/pyt/core/project_handler.py similarity index 100% rename from pyt/project_handler.py rename to pyt/core/project_handler.py diff --git a/pyt/definition_chains.py b/pyt/definition_chains.py deleted file mode 100644 index 8b813e8..0000000 --- a/pyt/definition_chains.py +++ /dev/null @@ -1,74 +0,0 @@ -import ast - -from .constraint_table import constraint_table -from .lattice import Lattice -from .node_types import AssignmentNode -from .reaching_definitions import ReachingDefinitionsAnalysis -from .vars_visitor import VarsVisitor - - -def get_vars(node): - vv = VarsVisitor() - if isinstance(node.ast_node, (ast.If, ast.While)): - vv.visit(node.ast_node.test) - elif isinstance(node.ast_node, (ast.ClassDef, ast.FunctionDef)): - return set() - else: - try: - vv.visit(node.ast_node) - except AttributeError: # If no ast_node - vv.result = list() - - vv.result = set(vv.result) - - # Filter out lvars: - for var in vv.result: - try: - if var in node.right_hand_side_variables: - yield var - except AttributeError: - yield var - - -def get_constraint_nodes(node, lattice): - for n in lattice.get_elements(constraint_table[node]): - if n is not node: - yield n - - -def build_use_def_chain(cfg_nodes): - use_def = dict() - lattice = Lattice(cfg_nodes, ReachingDefinitionsAnalysis) - - for node in cfg_nodes: - definitions = list() - for constraint_node in get_constraint_nodes(node, lattice): - for var in get_vars(node): - if var in constraint_node.left_hand_side: - definitions.append((var, constraint_node)) - use_def[node] = definitions - - return use_def - - -def build_def_use_chain(cfg_nodes): - def_use = dict() - lattice = Lattice(cfg_nodes, ReachingDefinitionsAnalysis) - - # For every node - for node in cfg_nodes: - # That's a definition - if isinstance(node, AssignmentNode): - # Make an empty list for it in def_use dict - def_use[node] = list() - - # Get its uses - for variable in node.right_hand_side_variables: - # Loop through most of the nodes before it - for earlier_node in get_constraint_nodes(node, lattice): - # and add to the 'uses list' of each earlier node, when applicable - # 'earlier node' here being a simplification - if variable in earlier_node.left_hand_side: - def_use[earlier_node].append(node) - - return def_use diff --git a/pyt/draw.py b/pyt/draw.py deleted file mode 100644 index 7dbf378..0000000 --- a/pyt/draw.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Draws CFG.""" - -import argparse -from itertools import permutations -from subprocess import call - -from graphviz import Digraph - -from .node_types import AssignmentNode - - -IGNORED_LABEL_NAME_CHARACHTERS = ':' - -cfg_styles = { - 'graph': { - 'fontsize': '16', - 'fontcolor': 'black', - 'bgcolor': 'transparent', - 'rankdir': 'TB', - 'splines': 'ortho', - 'margin': '0.01', - }, - 'nodes': { - 'fontname': 'Gotham', - 'shape': 'box', - 'fontcolor': 'black', - 'color': 'black', - 'style': 'filled', - 'fillcolor': 'transparent', - }, - 'edges': { - 'style': 'filled', - 'color': 'black', - 'arrowhead': 'normal', - 'fontname': 'Courier', - 'fontsize': '12', - 'fontcolor': 'black', - } -} - -lattice_styles = { - 'graph': { - 'fontsize': '16', - 'fontcolor': 'black', - 'bgcolor': 'transparent', - 'rankdir': 'TB', - 'splines': 'line', - 'margin': '0.01', - 'ranksep': '1', - }, - 'nodes': { - 'fontname': 'Gotham', - 'shape': 'none', - 'fontcolor': 'black', - 'color': 'black', - 'style': 'filled', - 'fillcolor': 'transparent', - }, - 'edges': { - 'style': 'filled', - 'color': 'black', - 'arrowhead': 'none', - 'fontname': 'Courier', - 'fontsize': '12', - 'fontcolor': 'black', - } -} - - -def apply_styles(graph, styles): - """Apply styles to graph.""" - graph.graph_attr.update( - ('graph' in styles and styles['graph']) or {} - ) - graph.node_attr.update( - ('nodes' in styles and styles['nodes']) or {} - ) - graph.edge_attr.update( - ('edges' in styles and styles['edges']) or {} - ) - return graph - - -def draw_cfg(cfg, output_filename='output'): - """Draw CFG and output as pdf.""" - graph = Digraph(format='pdf') - - for node in cfg.nodes: - stripped_label = node.label.replace(IGNORED_LABEL_NAME_CHARACHTERS, '') - - if 'Exit' in stripped_label: - graph.node(stripped_label, 'Exit', shape='none') - elif 'Entry' in stripped_label: - graph.node(stripped_label, 'Entry', shape='none') - else: - graph.node(stripped_label, stripped_label) - - for ingoing_node in node.ingoing: - graph.edge(ingoing_node.label.replace( - IGNORED_LABEL_NAME_CHARACHTERS, ''), stripped_label) - - graph = apply_styles(graph, cfg_styles) - graph.render(filename=output_filename) - - -class Node(): - def __init__(self, s, parent, children=None): - self.s = s - self.parent = parent - self.children = children - - def __str__(self): - return 'Node: ' + str(self.s) + ' Parent: ' + str(self.parent) + ' Children: ' + str(self.children) - - def __hash__(self): - return hash(str(self.s)) - - -def draw_node(l, graph, node): - node_label = str(node.s) - graph.node(node_label, node_label) - for child in node.children: - child_label = str(child.s) - graph.node(child_label, child_label) - if not (node_label, child_label) in l: - graph.edge(node_label, child_label, ) - l.append((node_label, child_label)) - draw_node(l, graph, child) - - -def make_lattice(s, length): - p = Node(s, None) - p.children = get_children(p, s, length) - return p - - -def get_children(p, s, length): - children = set() - if length < 0: - return children - for subset in permutations(s, length): - setsubset = set(subset) - append = True - for node in children: - if setsubset == node.s: - append = False - break - if append: - n = Node(setsubset, p) - n.children = get_children(n, setsubset, length-1) - children.add(n) - return children - - -def add_anchor(filename): - filename += '.dot' - out = list() - delimiter = '->' - with open(filename, 'r') as fd: - for line in fd: - if delimiter in line: - s = line.split(delimiter) - ss = s[0][:-1] - s[0] = ss + ':s ' - ss = s[1][:-1] - s[1] = ss + ':n\n' - s.insert(1, delimiter) - out.append(''.join(s)) - elif 'set()' in line: - out.append('"set()" [label="{}"]') - else: - out.append(line) - with open(filename, 'w') as fd: - for line in out: - fd.write(line) - - -def run_dot(filename): - filename += '.dot' - call(['dot', '-Tpdf', filename, '-o', filename.replace('.dot', '.pdf')]) - - -def draw_lattice(cfg, output_filename='output'): - """Draw CFG and output as pdf.""" - graph = Digraph(format='pdf') - - ll = [s.label for s in cfg.nodes if isinstance(s, AssignmentNode)] - root = make_lattice(ll, len(ll)-1) - l = list() - draw_node(l, graph, root) - - graph = apply_styles(graph, lattice_styles) - graph.render(filename=output_filename+'.dot') - - add_anchor(output_filename) - run_dot(output_filename) - - -def draw_lattice_from_labels(labels, output_filename): - graph = Digraph(format='pdf') - - root = make_lattice(labels, len(labels)-1) - l = list() - draw_node(l, graph, root) - - graph = apply_styles(graph, lattice_styles) - graph.render(filename=output_filename+'.dot') - - add_anchor(output_filename) - run_dot(output_filename) - - -def draw_lattices(cfg_list, output_prefix='output'): - for i, cfg in enumerate(cfg_list): - draw_lattice(cfg, output_prefix + '_' + str(i)) - - -def draw_cfgs(cfg_list, output_prefix='output'): - for i, cfg in enumerate(cfg_list): - draw_cfg(cfg, output_prefix + '_' + str(i)) - - -parser = argparse.ArgumentParser() -parser.add_argument('-l', '--labels', nargs='+', - help='Set of labels in lattice.') -parser.add_argument('-n', '--name', help='Specify filename.', type=str) -if __name__ == '__main__': - args = parser.parse_args() - - draw_lattice_from_labels(args.labels, args.name) diff --git a/pyt/github_search.py b/pyt/github_search.py deleted file mode 100644 index df0cb40..0000000 --- a/pyt/github_search.py +++ /dev/null @@ -1,278 +0,0 @@ -import re -import requests -import time -from abc import ABCMeta, abstractmethod -from datetime import date, datetime, timedelta - -from . import repo_runner -from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis -from .repo_runner import add_repo_to_csv, NoEntryPathError -from .save import save_repo_scan - - -DEFAULT_TIMEOUT_IN_SECONDS = 60 -GITHUB_API_URL = 'https://api.github.com' -GITHUB_OAUTH_TOKEN = None -NUMBER_OF_REQUESTS_ALLOWED_PER_MINUTE = 30 # Rate limit is 10 and 30 with auth -SEARCH_CODE_URL = GITHUB_API_URL + '/search/code' -SEARCH_REPO_URL = GITHUB_API_URL + '/search/repositories' - - -def set_github_api_token(): - global GITHUB_OAUTH_TOKEN - try: - GITHUB_OAUTH_TOKEN = open('github_access_token.pyt', - 'r').read().strip() - except FileNotFoundError: - print('Insert your GitHub access token' - ' in the github_access_token.pyt file in the pyt package' - ' if you want to use GitHub search.') - exit(0) - - -class Languages: - _prefix = 'language:' - python = _prefix + 'python' - javascript = _prefix + 'javascript' - # add others here - - -class Query: - def __init__(self, base_url, search_string, - language=None, repo=None, time_interval=None, per_page=100): - repo = self._repo_parameter(repo) - time_interval = self._time_interval_parameter(time_interval) - search_string = self._search_parameter(search_string) - per_page = self._per_page_parameter(per_page) - parameters = self._construct_parameters([search_string, - language, - repo, - time_interval, - per_page]) - self.query_string = self._construct_query(base_url, parameters) - - def _construct_query(self, base_url, parameters): - query = base_url - query += '+'.join(parameters) - return query - - def _construct_parameters(self, parameters): - r = list() - for p in parameters: - if p: - r.append(p) - return r - - def _search_parameter(self, search_string): - return '?q="' + search_string + '"' - - def _repo_parameter(self, repo): - if repo: - return 'repo:' + repo.name - else: - return None - - def _time_interval_parameter(self, created): - if created: - p = re.compile('\d\d\d\d-\d\d-\d\d \.\. \d\d\d\d-\d\d-\d\d') - m = p.match(created) - if m.group(): - return 'created:"' + m.group() + '"' - else: - print('The time interval parameter should be ' - 'of the form: "YYYY-MM-DD .. YYYY-MM-DD"') - exit(1) - return None - - def _per_page_parameter(self, per_page): - if per_page > 100: - print('The GitHub api does not allow pages with over 100 results.') - exit(1) - return '&per_page={}'.format(per_page) - - -class IncompleteResultsError(Exception): - pass - - -class RequestCounter: - def __init__(self, timeout=DEFAULT_TIMEOUT_IN_SECONDS): - self.timeout_in_seconds = timeout # timeout in seconds - self.counter = list() - - def append(self, request_time): - if len(self.counter) < NUMBER_OF_REQUESTS_ALLOWED_PER_MINUTE: - self.counter.append(request_time) - else: - delta = request_time - self.counter[0] - if delta.seconds < self.timeout_in_seconds: - print('Maximum requests "{}" reached' - ' timing out for {} seconds.' - .format(len(self.counter), - self.timeout_in_seconds - delta.seconds)) - self.timeout(self.timeout_in_seconds - delta.seconds) - self.counter.pop(0) # pop index 0 - self.counter.append(datetime.now()) - else: - self.counter.pop(0) # pop index 0 - self.counter.append(request_time) - - def timeout(self, time_in_seconds=DEFAULT_TIMEOUT_IN_SECONDS): - time.sleep(time_in_seconds) - - -class Search(metaclass=ABCMeta): - request_counter = RequestCounter() - - def __init__(self, query): - self.total_count = None - self.incomplete_results = None - self.results = list() - self._request(query.query_string) - - def _request(self, query_string): - Search.request_counter.append(datetime.now()) - - print('Making request: {}'.format(query_string)) - - headers = {'Authorization': 'token ' + GITHUB_OAUTH_TOKEN} - r = requests.get(query_string, headers=headers) - - json = r.json() - - if r.status_code != 200: - print('Bad request:') - print(r.status_code) - print(json) - Search.request_counter.timeout() - self._request(query_string) - return - - self.total_count = json['total_count'] - print('Number of results: {}.'.format(self.total_count)) - self.incomplete_results = json['incomplete_results'] - if self.incomplete_results: - raise IncompleteResultsError() - self.parse_results(json['items']) - - @abstractmethod - def parse_results(self, json_results): - pass - - -class SearchRepo(Search): - def parse_results(self, json_results): - for item in json_results: - self.results.append(Repo(item)) - - -class SearchCode(Search): - def parse_results(self, json_results): - for item in json_results: - self.results.append(File(item)) - - -class File: - def __init__(self, json): - self.name = json['name'] - self.repo = Repo(json['repository']) - - -class Repo: - def __init__(self, json): - self.url = json['html_url'] - self.name = json['full_name'] - - -def get_dates(start_date, end_date=date.today(), interval=7): - delta = end_date - start_date - for i in range(delta.days // interval): - yield (start_date + timedelta(days=(i * interval) - interval), - start_date + timedelta(days=i * interval)) - else: - # Take care of the remainder of days - yield (start_date + timedelta(days=i * interval), - start_date + timedelta(days=i * interval + - interval + - delta.days % interval)) - - -def scan_github(search_string, start_date, analysis_type, analyse_repo_func, csv_path, ui_mode, other_args): - analyse_repo = analyse_repo_func - for d in get_dates(start_date, interval=7): - q = Query(SEARCH_REPO_URL, search_string, - language=Languages.python, - time_interval=str(d[0]) + ' .. ' + str(d[1]), - per_page=100) - s = SearchRepo(q) - for repo in s.results: - q = Query(SEARCH_CODE_URL, 'app = Flask(__name__)', - Languages.python, repo) - s = SearchCode(q) - if s.results: - r = repo_runner.Repo(repo.url) - try: - r.clone() - except NoEntryPathError as err: - save_repo_scan(repo, r.path, vulnerabilities=None, error=err) - continue - except: - save_repo_scan(repo, r.path, vulnerabilities=None, error='Other Error Unknown while cloning :-(') - continue - try: - vulnerabilities = analyse_repo(other_args, r, analysis_type, ui_mode) - if vulnerabilities: - save_repo_scan(repo, r.path, vulnerabilities) - add_repo_to_csv(csv_path, r) - else: - save_repo_scan(repo, r.path, vulnerabilities=None) - r.clean_up() - except SyntaxError as err: - save_repo_scan(repo, r.path, vulnerabilities=None, error=err) - except IOError as err: - save_repo_scan(repo, r.path, vulnerabilities=None, error=err) - except AttributeError as err: - save_repo_scan(repo, r.path, vulnerabilities=None, error=err) - except: - save_repo_scan(repo, r.path, vulnerabilities=None, error='Other Error Unknown :-(') - -if __name__ == '__main__': - for x in get_dates(date(2010, 1, 1), interval=93): - print(x) - exit() - scan_github('flask', ReachingDefinitionsTaintAnalysis) - exit() - q = Query(SEARCH_REPO_URL, 'flask') - s = SearchRepo(q) - for repo in s.results[:3]: - q = Query(SEARCH_CODE_URL, 'app = Flask(__name__)', Languages.python, repo) - s = SearchCode(q) - r = repo_runner.Repo(repo.url) - r.clone() - print(r.path) - r.clean_up() - print(repo.name) - print(len(s.results)) - print([f.name for f in s.results]) - exit() - - r = RequestCounter('test', timeout=2) - for x in range(15): - r.append(datetime.now()) - exit() - - dates = get_dates(date(2010, 1, 1)) - for date in dates: - q = Query(SEARCH_REPO_URL, 'flask', - time_interval=str(date) + ' .. ' + str(date)) - print(q.query_string) - exit() - s = SearchRepo(q) - print(s.total_count) - print(s.incomplete_results) - print([r.URL for r in s.results]) - q = Query(SEARCH_CODE_URL, 'import flask', Languages.python, s.results[0]) - s = SearchCode(q) - #print(s.total_count) - #print(s.incomplete_results) - #print([f.name for f in s.results]) diff --git a/pyt/helper_visitors/README.rst b/pyt/helper_visitors/README.rst new file mode 100644 index 0000000..cce52a0 --- /dev/null +++ b/pyt/helper_visitors/README.rst @@ -0,0 +1 @@ +Documentation coming soon. diff --git a/pyt/helper_visitors/__init__.py b/pyt/helper_visitors/__init__.py new file mode 100644 index 0000000..ffd5f87 --- /dev/null +++ b/pyt/helper_visitors/__init__.py @@ -0,0 +1,10 @@ +from .label_visitor import LabelVisitor +from .right_hand_side_visitor import RHSVisitor +from .vars_visitor import VarsVisitor + + +__all__ = [ + 'LabelVisitor', + 'RHSVisitor', + 'VarsVisitor' +] diff --git a/pyt/label_visitor.py b/pyt/helper_visitors/label_visitor.py similarity index 100% rename from pyt/label_visitor.py rename to pyt/helper_visitors/label_visitor.py diff --git a/pyt/right_hand_side_visitor.py b/pyt/helper_visitors/right_hand_side_visitor.py similarity index 100% rename from pyt/right_hand_side_visitor.py rename to pyt/helper_visitors/right_hand_side_visitor.py diff --git a/pyt/vars_visitor.py b/pyt/helper_visitors/vars_visitor.py similarity index 98% rename from pyt/vars_visitor.py rename to pyt/helper_visitors/vars_visitor.py index f64abd9..bda24c9 100644 --- a/pyt/vars_visitor.py +++ b/pyt/helper_visitors/vars_visitor.py @@ -1,7 +1,7 @@ import ast import itertools -from .ast_helper import get_call_names +from ..core.ast_helper import get_call_names class VarsVisitor(ast.NodeVisitor): diff --git a/pyt/liveness.py b/pyt/liveness.py deleted file mode 100644 index 38935b9..0000000 --- a/pyt/liveness.py +++ /dev/null @@ -1,134 +0,0 @@ -import ast - -from .analysis_base import AnalysisBase -from .ast_helper import get_call_names_as_string -from .constraint_table import ( - constraint_join, - constraint_table -) -from .lattice import Lattice -from .node_types import ( - AssignmentNode, - BBorBInode, - EntryOrExitNode -) -from .vars_visitor import VarsVisitor - - -class LivenessAnalysis(AnalysisBase): - """Reaching definitions analysis rules implemented.""" - - def __init__(self, cfg): - super().__init__(cfg) - - def join(self, cfg_node): - """Joins all constraints of the ingoing nodes and returns them. - This represents the JOIN auxiliary definition from Schwartzbach.""" - return constraint_join(cfg_node.outgoing) - - def is_output(self, cfg_node): - if isinstance(cfg_node.ast_node, ast.Call): - call_name = get_call_names_as_string(cfg_node.ast_node.func) - if 'print' in call_name: - return True - return False - - def is_condition(self, cfg_node): - if isinstance(cfg_node.ast_node, (ast.If, ast.While)): - return True - elif self.is_output(cfg_node): - return True - return False - - def remove_id_assignment(self, JOIN, cfg_node): - lvars = list() - - if isinstance(cfg_node, BBorBInode): - lvars.append(cfg_node.left_hand_side) - else: - try: - for expr in cfg_node.ast_node.targets: - vv = VarsVisitor() - vv.visit(expr) - lvars.extend(vv.result) - except AttributeError: # If it is AugAssign - vv = VarsVisitor() - vv.visit(cfg_node.ast_node.target) - lvars.extend(vv.result) - for var in lvars: - if var in self.lattice.get_elements(JOIN): - # Remove var from JOIN - JOIN = JOIN ^ self.lattice.el2bv[var] - return JOIN - - def add_vars_assignment(self, JOIN, cfg_node): - rvars = list() - if isinstance(cfg_node, BBorBInode): - # A conscience decision was made not to include e.g. ~call_N's in RHS vars - rvars.extend(cfg_node.right_hand_side_variables) - else: - vv = VarsVisitor() - vv.visit(cfg_node.ast_node.value) - rvars.extend(vv.result) - for var in rvars: - # Add var to JOIN - JOIN = JOIN | self.lattice.el2bv[var] - return JOIN - - def add_vars_conditional(self, JOIN, cfg_node): - varse = None - if isinstance(cfg_node.ast_node, ast.While): - vv = VarsVisitor() - vv.visit(cfg_node.ast_node.test) - varse = vv.result - elif self.is_output(cfg_node): - vv = VarsVisitor() - vv.visit(cfg_node.ast_node) - varse = vv.result - elif isinstance(cfg_node.ast_node, ast.If): - vv = VarsVisitor() - vv.visit(cfg_node.ast_node.test) - varse = vv.result - - for var in varse: - JOIN = JOIN | self.lattice.el2bv[var] - - return JOIN - - def fixpointmethod(self, cfg_node): - if isinstance(cfg_node, EntryOrExitNode) and 'Exit' in cfg_node.label: - constraint_table[cfg_node] = 0 - elif isinstance(cfg_node, AssignmentNode): - JOIN = self.join(cfg_node) - JOIN = self.remove_id_assignment(JOIN, cfg_node) - JOIN = self.add_vars_assignment(JOIN, cfg_node) - constraint_table[cfg_node] = JOIN - elif self.is_condition(cfg_node): - JOIN = self.join(cfg_node) - JOIN = self.add_vars_conditional(JOIN, cfg_node) - constraint_table[cfg_node] = JOIN - else: - constraint_table[cfg_node] = self.join(cfg_node) - - def dep(self, q_1): - """Represents the dep mapping from Schwartzbach.""" - for node in q_1.outgoing: - yield node - - def get_lattice_elements(cfg_nodes): - """Returns all variables as they are the only lattice elements - in the liveness analysis. - This is a static method which is overwritten from the base class.""" - lattice_elements = set() # set() to avoid duplicates - for node in (node for node in cfg_nodes if node.ast_node): - vv = VarsVisitor() - vv.visit(node.ast_node) - for var in vv.result: - lattice_elements.add(var) - return lattice_elements - - def equal(self, value, other): - return value == other - - def build_lattice(self, cfg): - self.lattice = Lattice(cfg.nodes, LivenessAnalysis) diff --git a/pyt/reaching_definitions.py b/pyt/reaching_definitions.py deleted file mode 100644 index 3bf5d07..0000000 --- a/pyt/reaching_definitions.py +++ /dev/null @@ -1,20 +0,0 @@ -from .constraint_table import constraint_table -from .node_types import AssignmentNode -from .reaching_definitions_base import ReachingDefinitionsAnalysisBase - - -class ReachingDefinitionsAnalysis(ReachingDefinitionsAnalysisBase): - """Reaching definitions analysis rules implemented.""" - - def fixpointmethod(self, cfg_node): - JOIN = self.join(cfg_node) - # Assignment check - if isinstance(cfg_node, AssignmentNode): - arrow_result = JOIN - # Get previous assignments of cfg_node.left_hand_side and remove them from JOIN - arrow_result = self.arrow(JOIN, cfg_node.left_hand_side) - arrow_result = arrow_result | self.lattice.el2bv[cfg_node] - constraint_table[cfg_node] = arrow_result - # Default case - else: - constraint_table[cfg_node] = JOIN diff --git a/pyt/reaching_definitions_base.py b/pyt/reaching_definitions_base.py deleted file mode 100644 index 5fecaa6..0000000 --- a/pyt/reaching_definitions_base.py +++ /dev/null @@ -1,47 +0,0 @@ -from .analysis_base import AnalysisBase -from .constraint_table import constraint_join -from .lattice import Lattice -from .node_types import AssignmentNode - - -class ReachingDefinitionsAnalysisBase(AnalysisBase): - """Reaching definitions analysis rules implemented.""" - - def __init__(self, cfg): - super().__init__(cfg) - - def join(self, cfg_node): - """Joins all constraints of the ingoing nodes and returns them. - This represents the JOIN auxiliary definition from Schwartzbach.""" - return constraint_join(cfg_node.ingoing) - - def arrow(self, JOIN, _id): - """Removes all previous assignments from JOIN that have the same left hand side. - This represents the arrow id definition from Schwartzbach.""" - r = JOIN - for node in self.lattice.get_elements(JOIN): - if node.left_hand_side == _id: - r = r ^ self.lattice.el2bv[node] - return r - - def fixpointmethod(self, cfg_node): - raise NotImplementedError() - - def dep(self, q_1): - """Represents the dep mapping from Schwartzbach.""" - for node in q_1.outgoing: - yield node - - def get_lattice_elements(cfg_nodes): - """Returns all assignment nodes as they are the only lattice elements - in the reaching definitions analysis. - This is a static method which is overwritten from the base class.""" - for node in cfg_nodes: - if isinstance(node, AssignmentNode): - yield node - - def equal(self, value, other): - return value == other - - def build_lattice(self, cfg): - self.lattice = Lattice(cfg.nodes, ReachingDefinitionsAnalysisBase) diff --git a/pyt/reaching_definitions_taint.py b/pyt/reaching_definitions_taint.py deleted file mode 100644 index 3ca72ac..0000000 --- a/pyt/reaching_definitions_taint.py +++ /dev/null @@ -1,24 +0,0 @@ -from .constraint_table import constraint_table -from .node_types import AssignmentNode -from .reaching_definitions_base import ReachingDefinitionsAnalysisBase - - -class ReachingDefinitionsTaintAnalysis(ReachingDefinitionsAnalysisBase): - """Reaching definitions analysis rules implemented.""" - - def fixpointmethod(self, cfg_node): - JOIN = self.join(cfg_node) - # Assignment check - if isinstance(cfg_node, AssignmentNode): - arrow_result = JOIN - - # Reassignment check - if cfg_node.left_hand_side not in cfg_node.right_hand_side_variables: - # Get previous assignments of cfg_node.left_hand_side and remove them from JOIN - arrow_result = self.arrow(JOIN, cfg_node.left_hand_side) - - arrow_result = arrow_result | self.lattice.el2bv[cfg_node] - constraint_table[cfg_node] = arrow_result - # Default case - else: - constraint_table[cfg_node] = JOIN diff --git a/pyt/repo_runner.py b/pyt/repo_runner.py deleted file mode 100644 index cfb6329..0000000 --- a/pyt/repo_runner.py +++ /dev/null @@ -1,90 +0,0 @@ -"""This modules runs PyT on a CSV file of git repos.""" -import os -import shutil - -import git - - -DEFAULT_CSV_PATH = 'flask_open_source_apps.csv' - - -class NoEntryPathError(Exception): - pass - - -class Repo: - """Holder for a repo with git URL and - a path to where the analysis should start""" - - def __init__(self, URL, path=None): - self.URL = URL.strip() - if path: - self.path = path.strip() - else: - self.path = None - self.directory = None - - def clone(self): - """Clone repo and update path to match the current one""" - - r = self.URL.split('/')[-1].split('.') - if len(r) > 1: - self.directory = '.'.join(r[:-1]) - else: - self.directory = r[0] - - if self.directory not in os.listdir(): - git.Git().clone(self.URL) - - if self.path is None: - self._find_entry_path() - elif self.path[0] == '/': - self.path = self.path[1:] - self.path = os.path.join(self.directory, self.path) - else: - self.path = os.path.join(self.directory, self.path) - - def _find_entry_path(self): - for root, dirs, files in os.walk(self.directory): - for f in files: - if f.endswith('.py'): - with open(os.path.join(root, f), 'r') as fd: - if 'app = Flask(__name__)' in fd.read(): - self.path = os.path.join(root, f) - return - raise NoEntryPathError('No entry path found in repo {}.' - .format(self.URL)) - - def clean_up(self): - """Deletes the repo""" - shutil.rmtree(self.directory) - - -def get_repos(csv_path): - """Parses a CSV file containing repos.""" - repos = list() - with open(csv_path, 'r') as fd: - for line in fd: - url, path = line.split(',') - repos.append(Repo(url, path)) - return repos - - -def add_repo_to_file(path, repo): - try: - with open(path, 'a') as fd: - fd.write('{}{}, {}' - .format(os.linesep, repo.URL, repo.path)) - except FileNotFoundError: - print('-csv handle not used and fallback path not found: {}' - .format(DEFAULT_CSV_PATH)) - print('You need to specify the csv_path' - ' by using the "-csv" handle.') - exit(1) - - -def add_repo_to_csv(csv_path, repo): - if csv_path is None: - add_repo_to_file(DEFAULT_CSV_PATH, repo) - else: - add_repo_to_file(csv_path, repo) diff --git a/pyt/save.py b/pyt/save.py deleted file mode 100644 index 5f0e375..0000000 --- a/pyt/save.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -from datetime import datetime - -from .definition_chains import ( - build_def_use_chain, - build_use_def_chain -) -from .formatters import text -from .lattice import Lattice -from .node_types import Node - - -database_file_name = 'db.sql' -nodes_table_name = 'nodes' -vulnerabilities_table_name = 'vulnerabilities' - -def create_nodes_table(): - with open(database_file_name, 'a') as fd: - fd.write('DROP TABLE IF EXISTS ' + nodes_table_name + '\n') - fd.write('CREATE TABLE ' + nodes_table_name + '(id int,label varchar(255),line_number int, path varchar(255));') - -def create_vulnerabilities_table(): - with open(database_file_name, 'a') as fd: - fd.write('DROP TABLE IF EXISTS ' + vulnerabilities_table_name + '\n') - fd.write('CREATE TABLE ' + vulnerabilities_table_name + '(id int, source varchar(255), source_word varchar(255), sink varchar(255), sink_word varchar(255));') - -def quote(item): - if isinstance(item, Node): - item = item.label - return "'" + item.replace("'", "''") + "'" - -def insert_vulnerability(vulnerability): - with open(database_file_name, 'a') as fd: - fd.write('\nINSERT INTO ' + vulnerabilities_table_name + '\n') - fd.write('VALUES (') - fd.write(quote(vulnerability.__dict__['source']) + ',') - fd.write(quote(vulnerability.__dict__['source_trigger_word']) + ',') - fd.write(quote(vulnerability.__dict__['sink']) + ',') - fd.write(quote(vulnerability.__dict__['sink_trigger_word'])) - fd.write(');') - -def insert_node(node): - with open(database_file_name, 'a') as fd: - fd.write('\nINSERT INTO ' + nodes_table_name + '\n') - fd.write('VALUES (') - fd.write("'" + node.__dict__['label'].replace("'", "''") + "'" + ',') - line_number = node.__dict__['line_number'] - if line_number: - fd.write(str(line_number) + ',') - else: - fd.write('NULL,') - path = node.__dict__['path'] - if path: - fd.write("'" + path.replace("'", "''") + "'") - else: - fd.write('NULL') - fd.write(');') - -def create_database(cfg_list, vulnerabilities): - create_nodes_table() - for cfg in cfg_list: - for node in cfg.nodes: - insert_node(node) - create_vulnerabilities_table() - for vulnerability in vulnerabilities: - insert_vulnerability(vulnerability) - - -class Output(): - filename_prefix = None - - def __init__(self, title): - if Output.filename_prefix: - self.title = Output.filename_prefix + '_' + title - else: - self.title = title - - def __enter__(self): - self.fd = open(self.title, 'w') - return self.fd - - def __exit__(self, type, value, traceback): - self.fd.close() - - -def def_use_chain_to_file(cfg_list): - with Output('def-use_chain.pyt') as fd: - for i, cfg in enumerate(cfg_list): - fd.write('##### Def-use chain for CFG {} #####{}' - .format(i, os.linesep)) - def_use = build_def_use_chain(cfg.nodes) - for k, v in def_use.items(): - fd.write('Def: {} -> Use: [{}]{}' - .format(k.label, - ', '.join([n.label for n in v]), - os.linesep)) - - -def use_def_chain_to_file(cfg_list): - with Output('use-def_chain.pyt') as fd: - for i, cfg in enumerate(cfg_list): - fd.write('##### Use-def chain for CFG {} #####{}' - .format(i, os.linesep)) - def_use = build_use_def_chain(cfg.nodes) - for k, v in def_use.items(): - fd.write('Use: {} -> Def: [{}]{}' - .format(k.label, - ', '.join([n[1].label for n in v]), - os.linesep)) - - -def cfg_to_file(cfg_list): - with Output('control_flow_graph.pyt') as fd: - for i, cfg in enumerate(cfg_list): - fd.write('##### CFG {} #####{}'.format(i, os.linesep)) - for i, node in enumerate(cfg.nodes): - fd.write('Node {}: {}{}'.format(i, node.label, os.linesep)) - - -def verbose_cfg_to_file(cfg_list): - with Output('verbose_control_flow_graph.pyt') as fd: - for i, cfg in enumerate(cfg_list): - fd.write('##### CFG {} #####{}'.format(i, os.linesep)) - for i, node in enumerate(cfg.nodes): - fd.write('Node {}: {}{}'.format(i, repr(node), os.linesep)) - - -def lattice_to_file(cfg_list, analysis_type): - with Output('lattice.pyt') as fd: - for i, cfg in enumerate(cfg_list): - fd.write('##### Lattice for CFG {} #####{}'.format(i, os.linesep)) - l = Lattice(cfg.nodes, analysis_type) - - fd.write('# Elements to bitvector #{}'.format(os.linesep)) - for k, v in l.el2bv.items(): - fd.write('{} -> {}{}'.format(str(k), bin(v), os.linesep)) - - fd.write('# Bitvector to elements #{}'.format(os.linesep)) - for k, v in l.el2bv.items(): - fd.write('{} -> {}{}'.format(bin(v), str(k), os.linesep)) - - -def vulnerabilities_to_file(vulnerabilities): - with Output('vulnerabilities.pyt') as fd: - text.report(vulnerabilities, fd) - - -def save_repo_scan(repo, entry_path, vulnerabilities, error=None): - with open('scan.pyt', 'a') as fd: - fd.write('{}{}'.format(repo.name, os.linesep)) - fd.write('{}{}'.format(repo.url, os.linesep)) - fd.write('Entry file: {}{}'.format(entry_path, os.linesep)) - fd.write('Scanned: {}{}'.format(datetime.now(), os.linesep)) - if vulnerabilities: - text.report(vulnerabilities, fd) - else: - fd.write('No vulnerabilities found.{}'.format(os.linesep)) - if error: - fd.write('An Error occurred while scanning the repo: {}' - .format(str(error))) - fd.write(os.linesep) - fd.write(os.linesep) diff --git a/pyt/usage.py b/pyt/usage.py new file mode 100644 index 0000000..4930eb0 --- /dev/null +++ b/pyt/usage.py @@ -0,0 +1,131 @@ +import argparse +import os +import sys +from datetime import datetime + + +default_blackbox_mapping_file = os.path.join( + os.path.dirname(__file__), + 'vulnerability_definitions', + 'blackbox_mapping.json' +) + + +default_trigger_word_file = os.path.join( + os.path.dirname(__file__), + 'vulnerability_definitions', + 'all_trigger_words.pyt' +) + + +def valid_date(s): + date_format = "%Y-%m-%d" + try: + return datetime.strptime(s, date_format).date() + except ValueError: + msg = "Not a valid date: '{0}'. Format: {1}".format(s, date_format) + raise argparse.ArgumentTypeError(msg) + + +def _add_required_group(parser): + required_group = parser.add_argument_group('required arguments') + required_group.add_argument( + '-f', '--filepath', + help='Path to the file that should be analysed.', + type=str + ) + + +def _add_optional_group(parser): + optional_group = parser.add_argument_group('optional arguments') + + optional_group.add_argument( + '-a', '--adaptor', + help='Choose a web framework adaptor: ' + 'Flask(Default), Django, Every or Pylons', + type=str + ) + optional_group.add_argument( + '-pr', '--project-root', + help='Add project root, only important when the entry ' + 'file is not at the root of the project.', + type=str + ) + optional_group.add_argument( + '-b', '--baseline', + help='Path of a baseline report to compare against ' + '(only JSON-formatted files are accepted)', + type=str, + default=False, + metavar='BASELINE_JSON_FILE', + ) + optional_group.add_argument( + '-j', '--json', + help='Prints JSON instead of report.', + action='store_true', + default=False + ) + optional_group.add_argument( + '-m', '--blackbox-mapping-file', + help='Input blackbox mapping file.', + type=str, + default=default_blackbox_mapping_file + ) + optional_group.add_argument( + '-t', '--trigger-word-file', + help='Input file with a list of sources and sinks', + type=str, + default=default_trigger_word_file + ) + optional_group.add_argument( + '-o', '--output', + help='write report to filename', + dest='output_file', + action='store', + type=argparse.FileType('w'), + default=sys.stdout, + ) + optional_group.add_argument( + '--ignore-nosec', + dest='ignore_nosec', + action='store_true', + help='do not skip lines with # nosec comments' + ) + + +def _add_print_group(parser): + print_group = parser.add_argument_group('print arguments') + print_group.add_argument( + '-trim', '--trim-reassigned-in', + help='Trims the reassigned list to just the vulnerability chain.', + action='store_true', + default=True + ) + print_group.add_argument( + '-i', '--interactive', + help='Will ask you about each blackbox function call in vulnerability chains.', + action='store_true', + default=False + ) + + +def _check_required_and_mutually_exclusive_args(parser, args): + if args.filepath is None: + parser.error('The -f/--filepath argument is required') + + +def parse_args(args): + if len(args) == 0: + args.append('-h') + parser = argparse.ArgumentParser(prog='python -m pyt') + parser._action_groups.pop() + _add_required_group(parser) + _add_optional_group(parser) + _add_print_group(parser) + + args = parser.parse_args(args) + _check_required_and_mutually_exclusive_args( + parser, + args + ) + return args diff --git a/pyt/utils/log.py b/pyt/utils/log.py deleted file mode 100644 index 74a49f5..0000000 --- a/pyt/utils/log.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging - - -LOGGING_FMT = '%(levelname)3s] %(filename)s::%(funcName)s(%(lineno)d) - %(message)s' - - -def remove_other_handlers(to_keep=None): - for hdl in logger.handlers: - if hdl != to_keep: - logger.removeHandler(hdl) - - -def enable_logger(to_file=None): - logger.setLevel(logging.DEBUG) - ch = logging.StreamHandler() if not to_file else logging.FileHandler(to_file, mode='w') - ch.setLevel(logging.DEBUG) - fmt = logging.Formatter(LOGGING_FMT) - ch.setFormatter(fmt) - logger.addHandler(ch) - remove_other_handlers(ch) - -logger = logging.getLogger('pyt') -remove_other_handlers() diff --git a/pyt/vulnerabilities/README.rst b/pyt/vulnerabilities/README.rst new file mode 100644 index 0000000..3ba5b13 --- /dev/null +++ b/pyt/vulnerabilities/README.rst @@ -0,0 +1 @@ +Coming soon. diff --git a/pyt/vulnerabilities/__init__.py b/pyt/vulnerabilities/__init__.py new file mode 100644 index 0000000..992af18 --- /dev/null +++ b/pyt/vulnerabilities/__init__.py @@ -0,0 +1,12 @@ +from .vulnerabilities import find_vulnerabilities +from .vulnerability_helper import ( + get_vulnerabilities_not_in_baseline, + UImode +) + + +__all__ = [ + 'find_vulnerabilities', + 'get_vulnerabilities_not_in_baseline', + 'UImode' +] diff --git a/pyt/trigger_definitions_parser.py b/pyt/vulnerabilities/trigger_definitions_parser.py similarity index 99% rename from pyt/trigger_definitions_parser.py rename to pyt/vulnerabilities/trigger_definitions_parser.py index 7515da4..62cdbce 100644 --- a/pyt/trigger_definitions_parser.py +++ b/pyt/vulnerabilities/trigger_definitions_parser.py @@ -1,4 +1,3 @@ -import os from collections import namedtuple diff --git a/pyt/vulnerabilities.py b/pyt/vulnerabilities/vulnerabilities.py similarity index 88% rename from pyt/vulnerabilities.py rename to pyt/vulnerabilities/vulnerabilities.py index 1bc5775..b41ae37 100644 --- a/pyt/vulnerabilities.py +++ b/pyt/vulnerabilities/vulnerabilities.py @@ -2,71 +2,28 @@ import ast import json -from collections import namedtuple -from .argument_helpers import UImode -from .definition_chains import build_def_use_chain -from .lattice import Lattice -from .node_types import ( +from ..analysis.definition_chains import build_def_use_chain +from ..analysis.lattice import Lattice +from ..core.node_types import ( AssignmentNode, BBorBInode, IfNode, TaintedNode ) -from .right_hand_side_visitor import RHSVisitor +from ..helper_visitors import ( + RHSVisitor, + VarsVisitor +) from .trigger_definitions_parser import parse -from .vars_visitor import VarsVisitor from .vulnerability_helper import ( + Sanitiser, + TriggerNode, + Triggers, vuln_factory, - VulnerabilityType -) - - -Sanitiser = namedtuple( - 'Sanitiser', - ( - 'trigger_word', - 'cfg_node' - ) + VulnerabilityType, + UImode ) -Triggers = namedtuple( - 'Triggers', - ( - 'sources', - 'sinks', - 'sanitiser_dict' - ) -) - - -class TriggerNode(): - def __init__(self, trigger_word, sanitisers, cfg_node, secondary_nodes=[]): - self.trigger_word = trigger_word - self.sanitisers = sanitisers - self.cfg_node = cfg_node - self.secondary_nodes = secondary_nodes - - def append(self, cfg_node): - if not cfg_node == self.cfg_node: - if self.secondary_nodes and cfg_node not in self.secondary_nodes: - self.secondary_nodes.append(cfg_node) - elif not self.secondary_nodes: - self.secondary_nodes = [cfg_node] - - def __repr__(self): - output = 'TriggerNode(' - - if self.trigger_word: - output = '{} trigger_word is {}, '.format( - output, - self.trigger_word - ) - - return ( - output + - 'sanitisers are {}, '.format(self.sanitisers) + - 'cfg_node is {})\n'.format(self.cfg_node) - ) def identify_triggers( @@ -172,7 +129,7 @@ def append_node_if_reassigned( def find_triggers( nodes, trigger_words, - nosec_lines=set() + nosec_lines ): """Find triggers from the trigger_word_list in the nodes. @@ -437,7 +394,10 @@ def get_vulnerability( elif isinstance(cfg_node, IfNode): potential_sanitiser = cfg_node - def_use = build_def_use_chain(cfg.nodes) + def_use = build_def_use_chain( + cfg.nodes, + lattice + ) for chain in get_vulnerability_chains( source.cfg_node, sink.cfg_node, @@ -506,36 +466,37 @@ def find_vulnerabilities_in_cfg( def find_vulnerabilities( cfg_list, - analysis_type, ui_mode, - vulnerability_files, + blackbox_mapping_file, + source_sink_file, nosec_lines=set() ): """Find vulnerabilities in a list of CFGs from a trigger_word_file. Args: cfg_list(list[CFG]): the list of CFGs to scan. - analysis_type(AnalysisBase): analysis object used to create lattice. ui_mode(UImode): determines if we interact with the user or trim the nodes in the output, if at all. - vulnerability_files(VulnerabilityFiles): contains trigger words and blackbox_mapping files + blackbox_mapping_file(str) + source_sink_file(str) Returns: A list of vulnerabilities. """ vulnerabilities = list() - definitions = parse(vulnerability_files.triggers) - with open(vulnerability_files.blackbox_mapping) as infile: + definitions = parse(source_sink_file) + + with open(blackbox_mapping_file) as infile: blackbox_mapping = json.load(infile) for cfg in cfg_list: find_vulnerabilities_in_cfg( cfg, definitions, - Lattice(cfg.nodes, analysis_type), + Lattice(cfg.nodes), ui_mode, blackbox_mapping, vulnerabilities, nosec_lines ) - with open(vulnerability_files.blackbox_mapping, 'w') as outfile: + with open(blackbox_mapping_file, 'w') as outfile: json.dump(blackbox_mapping, outfile, indent=4) return vulnerabilities diff --git a/pyt/vulnerability_helper.py b/pyt/vulnerabilities/vulnerability_helper.py similarity index 67% rename from pyt/vulnerability_helper.py rename to pyt/vulnerabilities/vulnerability_helper.py index 832160e..1104de1 100644 --- a/pyt/vulnerability_helper.py +++ b/pyt/vulnerabilities/vulnerability_helper.py @@ -1,8 +1,8 @@ -"""This module contains vulnerability types and helpers. +"""This module contains vulnerability types, Enums, nodes and helpers.""" -It is only used in vulnerabilities.py -""" +import json from enum import Enum +from collections import namedtuple class VulnerabilityType(Enum): @@ -12,6 +12,12 @@ class VulnerabilityType(Enum): UNKNOWN = 3 +class UImode(Enum): + INTERACTIVE = 0 + NORMAL = 1 + TRIM = 2 + + def vuln_factory(vulnerability_type): if vulnerability_type == VulnerabilityType.UNKNOWN: return UnknownVulnerability @@ -62,9 +68,9 @@ def __str__(self): reassigned_str = _get_reassignment_str(self.reassignment_nodes) return ( 'File: {}\n' - ' > User input at line {}, trigger word "{}":\n' + ' > User input at line {}, source "{}":\n' '\t {}{}\nFile: {}\n' - ' > reaches line {}, trigger word "{}":\n' + ' > reaches line {}, sink "{}":\n' '\t{}'.format( self.source.path, self.source.line_number, self.source_trigger_word, @@ -134,3 +140,64 @@ def __str__(self): '\nThis vulnerability is unknown due to: ' + str(self.unknown_assignment) ) + + +Sanitiser = namedtuple( + 'Sanitiser', + ( + 'trigger_word', + 'cfg_node' + ) +) + + +Triggers = namedtuple( + 'Triggers', + ( + 'sources', + 'sinks', + 'sanitiser_dict' + ) +) + + +class TriggerNode(): + def __init__(self, trigger_word, sanitisers, cfg_node, secondary_nodes=[]): + self.trigger_word = trigger_word + self.sanitisers = sanitisers + self.cfg_node = cfg_node + self.secondary_nodes = secondary_nodes + + def append(self, cfg_node): + if not cfg_node == self.cfg_node: + if self.secondary_nodes and cfg_node not in self.secondary_nodes: + self.secondary_nodes.append(cfg_node) + elif not self.secondary_nodes: + self.secondary_nodes = [cfg_node] + + def __repr__(self): + output = 'TriggerNode(' + + if self.trigger_word: + output = '{} trigger_word is {}, '.format( + output, + self.trigger_word + ) + + return ( + output + + 'sanitisers are {}, '.format(self.sanitisers) + + 'cfg_node is {})\n'.format(self.cfg_node) + ) + + +def get_vulnerabilities_not_in_baseline( + vulnerabilities, + baseline_file +): + baseline = json.load(open(baseline_file)) + output = list() + for vuln in vulnerabilities: + if vuln.as_dict() not in baseline['vulnerabilities']: + output.append(vuln) + return(output) diff --git a/pyt/vulnerability_definitions/README.rst b/pyt/vulnerability_definitions/README.rst new file mode 100644 index 0000000..cce52a0 --- /dev/null +++ b/pyt/vulnerability_definitions/README.rst @@ -0,0 +1 @@ +Documentation coming soon. diff --git a/pyt/web_frameworks/README.rst b/pyt/web_frameworks/README.rst new file mode 100644 index 0000000..aa1b121 --- /dev/null +++ b/pyt/web_frameworks/README.rst @@ -0,0 +1,5 @@ +Coming soon. + + +Web frameworks +Sorry state of affairs diff --git a/pyt/web_frameworks/__init__.py b/pyt/web_frameworks/__init__.py new file mode 100644 index 0000000..e764e8a --- /dev/null +++ b/pyt/web_frameworks/__init__.py @@ -0,0 +1,20 @@ +from .framework_adaptor import ( + FrameworkAdaptor, + _get_func_nodes +) +from .framework_helper import ( + is_django_view_function, + is_flask_route_function, + is_function, + is_function_without_leading_ +) + + +__all__ = [ + 'FrameworkAdaptor', + 'is_django_view_function', + 'is_flask_route_function', + 'is_function', + 'is_function_without_leading_', + '_get_func_nodes' # Only used in framework_helper_test +] diff --git a/pyt/framework_adaptor.py b/pyt/web_frameworks/framework_adaptor.py similarity index 90% rename from pyt/framework_adaptor.py rename to pyt/web_frameworks/framework_adaptor.py index c7e5119..2bc4d7e 100644 --- a/pyt/framework_adaptor.py +++ b/pyt/web_frameworks/framework_adaptor.py @@ -2,10 +2,10 @@ import ast -from .ast_helper import Arguments -from .expr_visitor import make_cfg -from .module_definitions import project_definitions -from .node_types import ( +from ..cfg import make_cfg +from ..core.ast_helper import Arguments +from ..core.module_definitions import project_definitions +from ..core.node_types import ( AssignmentNode, TaintedNode ) @@ -16,7 +16,13 @@ class FrameworkAdaptor(): entry points in a framework and then taints their arguments. """ - def __init__(self, cfg_list, project_modules, local_modules, is_route_function): + def __init__( + self, + cfg_list, + project_modules, + local_modules, + is_route_function + ): self.cfg_list = cfg_list self.project_modules = project_modules self.local_modules = local_modules diff --git a/pyt/framework_helper.py b/pyt/web_frameworks/framework_helper.py similarity index 95% rename from pyt/framework_helper.py rename to pyt/web_frameworks/framework_helper.py index fd5996f..7156968 100644 --- a/pyt/framework_helper.py +++ b/pyt/web_frameworks/framework_helper.py @@ -1,12 +1,14 @@ """Provides helper functions that help with determining if a function is a route function.""" import ast -from .ast_helper import get_call_names +from ..core.ast_helper import get_call_names -def is_function(function): - """Always returns true because arg is always a function.""" - return True +def is_django_view_function(ast_node): + if len(ast_node.args.args): + first_arg_name = ast_node.args.args[0].arg + return first_arg_name == 'request' + return False def is_flask_route_function(ast_node): @@ -18,11 +20,9 @@ def is_flask_route_function(ast_node): return False -def is_django_view_function(ast_node): - if len(ast_node.args.args): - first_arg_name = ast_node.args.args[0].arg - return first_arg_name == 'request' - return False +def is_function(function): + """Always returns true because arg is always a function.""" + return True def is_function_without_leading_(ast_node): diff --git a/tests/__main__.py b/tests/__main__.py index 8fad545..5356acd 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -1,4 +1,8 @@ -from unittest import TestLoader, TestSuite, TextTestRunner +from unittest import ( + TestLoader, + TestSuite, + TextTestRunner +) test_suite = TestSuite() @@ -11,6 +15,6 @@ if result.wasSuccessful(): print('Success') exit(0) -else: +else: # pragma: no cover print('Failure') exit(1) diff --git a/tests/analysis/__init__.py b/tests/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/analysis/analysis_base_test_case.py b/tests/analysis/analysis_base_test_case.py new file mode 100644 index 0000000..e92efb1 --- /dev/null +++ b/tests/analysis/analysis_base_test_case.py @@ -0,0 +1,76 @@ +from collections import namedtuple + +from ..base_test_case import BaseTestCase + +from pyt.analysis.constraint_table import ( + constraint_table, + initialize_constraint_table +) +from pyt.analysis.fixed_point import FixedPointAnalysis +from pyt.analysis.lattice import Lattice + + +def clear_constraint_table(): + for key in list(constraint_table): + del constraint_table[key] + + +class AnalysisBaseTestCase(BaseTestCase): + connection = namedtuple( + 'connection', + ( + 'constraintset', + 'element' + ) + ) + + def setUp(self): + self.cfg = None + + def assertInCfg(self, connections, lattice): + """Assert that all connections in the connections list exists in the cfg, + as well as all connections not in the list do not exist. + + Args: + connections(list[tuples]): the node at index 0 of the tuple has + to be in the new_constraint set of the node + at index 1 of the tuple. + lattice(Lattice): The lattice we're analysing. + """ + for connection in connections: + self.assertEqual(lattice.in_constraint( + self.cfg.nodes[connection[0]], + self.cfg.nodes[connection[1]]), + True, + str(connection) + " expected to be connected") + nodes = len(self.cfg.nodes) + + for element in range(nodes): + for sets in range(nodes): + if (element, sets) not in connections: + self.assertEqual( + lattice.in_constraint( + self.cfg.nodes[element], + self.cfg.nodes[sets] + ), + False, + "(%s,%s)" % (self.cfg.nodes[element], self.cfg.nodes[sets]) + " expected to be disconnected" + ) + + def constraints(self, list_of_constraints, node_number): + for c in list_of_constraints: + yield (c, node_number) + + def run_analysis(self, path): + self.cfg_create_from_file(path) + clear_constraint_table() + initialize_constraint_table([self.cfg]) + self.analysis = FixedPointAnalysis(self.cfg) + self.analysis.fixpoint_runner() + return Lattice(self.cfg.nodes) + + def string_compare_alnum(self, output, expected_string): + return ( + [char for char in output if char.isalnum()] == + [char for char in expected_string if char.isalnum()] + ) diff --git a/tests/analysis/reaching_definitions_taint_test.py b/tests/analysis/reaching_definitions_taint_test.py new file mode 100644 index 0000000..20f50b5 --- /dev/null +++ b/tests/analysis/reaching_definitions_taint_test.py @@ -0,0 +1,103 @@ +from .analysis_base_test_case import AnalysisBaseTestCase + +from pyt.analysis.constraint_table import constraint_table + + +class ReachingDefinitionsTaintTest(AnalysisBaseTestCase): + # Note: the numbers in the test represent the line numbers of the assignments in the program. + def test_linear_program(self): + lattice = self.run_analysis('examples/example_inputs/linear.py') + + EXPECTED = [ + "Label: Entry module:", + "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", + "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: y = x - 1: Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: ~call_2 = ret_print(x): Label: ~call_2 = ret_print(x), Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: Exit module: Label: ~call_2 = ret_print(x), Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()" + ] + i = 0 + for k, v in constraint_table.items(): + row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) + self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) + i = i + 1 + + def test_if_program(self): + lattice = self.run_analysis('examples/example_inputs/if_program.py') + + EXPECTED = [ + "Label: Entry module:", + "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", + "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: if x > 0:: Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: y = x + 1: Label: y = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: ~call_2 = ret_print(x): Label: ~call_2 = ret_print(x), Label: y = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: Exit module: Label: ~call_2 = ret_print(x), Label: y = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_input()" + ] + i = 0 + for k, v in constraint_table.items(): + row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) + self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) + i = i + 1 + + def test_example(self): + lattice = self.run_analysis('examples/example_inputs/example.py') + + EXPECTED = [ + "Label: Entry module:", + "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", + "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: ~call_2 = ret_int(x): Label: ~call_2 = ret_int(x), Label: x = ~call_1, Label: ~call_1 = ret_input()", + "Label: x = ~call_2: Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: while x > 1:: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: y = x / 2: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: if y > 3:: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: x = x - y: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: z = x - 4: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: if z > 0:: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: x = x / 2: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: z = z - 1: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: ~call_3 = ret_print(x): Label: ~call_3 = ret_print(x), Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", + "Label: Exit module: Label: ~call_3 = ret_print(x), Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()" + ] + i = 0 + for k, v in constraint_table.items(): + row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) + self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) + i = i + 1 + + def test_func_with_params(self): + lattice = self.run_analysis('examples/example_inputs/function_with_params.py') + + self.assertInCfg([(1, 1), + (1, 2), (2, 2), + (1, 3), (2, 3), (3, 3), + (1, 4), (2, 4), (3, 4), (4, 4), + (1, 5), (2, 5), (3, 5), (4, 5), + *self.constraints([1, 2, 3, 4, 6], 6), + *self.constraints([1, 2, 3, 4, 6, 7], 7), + *self.constraints([1, 2, 3, 4, 6, 7], 8), + *self.constraints([2, 3, 4, 6, 7, 9], 9), + *self.constraints([2, 3, 4, 6, 7, 9], 10)], lattice) + + def test_while(self): + lattice = self.run_analysis('examples/example_inputs/while.py') + + EXPECTED = [ + "Label: Entry module: ", + "Label: ~call_2 = ret_input(): Label: ~call_2 = ret_input()", + "Label: ~call_1 = ret_int(~call_2): Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", + "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", + "Label: while x < 10:: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input(", + "Label: x = x + 1: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", + "Label: if x == 5:: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", + "Label: BreakNode: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", + "Label: x = 6: Label: x = 6, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", + "Label: ~call_3 = ret_print(x): Label: ~call_3 = ret_print(x), Label: x = 6, Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", + "Label: Exit module: Label: ~call_3 = ret_print(x), Label: x = 6, Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()" + ] + i = 0 + for k, v in constraint_table.items(): + row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) + self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) + i = i + 1 diff --git a/tests/analysis_base_test_case.py b/tests/analysis_base_test_case.py deleted file mode 100644 index cdd3318..0000000 --- a/tests/analysis_base_test_case.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest -from collections import namedtuple - -from .base_test_case import BaseTestCase -from pyt.constraint_table import initialize_constraint_table -from pyt.fixed_point import FixedPointAnalysis -from pyt.lattice import Lattice - - -class AnalysisBaseTestCase(BaseTestCase): - connection = namedtuple( - 'connection', - ( - 'constraintset', - 'element' - ) - ) - def setUp(self): - self.cfg = None - - def assertInCfg(self, connections, lattice): - """Assert that all connections in the connections list exists in the cfg, - as well as all connections not in the list do not exist. - - Args: - connections(list[tuples]): the node at index 0 of the tuple has - to be in the new_constraint set of the node - at index 1 of the tuple. - lattice(Lattice): The lattice we're analysing. - """ - for connection in connections: - self.assertEqual(lattice.in_constraint(self.cfg.nodes[connection[0]], self.cfg.nodes[connection[1]]), True, str(connection) + " expected to be connected") - nodes = len(self.cfg.nodes) - - for element in range(nodes): - for sets in range(nodes): - if (element, sets) not in connections: - self.assertEqual(lattice.in_constraint(self.cfg.nodes[element], self.cfg.nodes[sets]), False, "(%s,%s)" % (self.cfg.nodes[element], self.cfg.nodes[sets]) + " expected to be disconnected") - - def constraints(self, list_of_constraints, node_number): - for c in list_of_constraints: - yield (c,node_number) - - def run_analysis(self, path, analysis_type): - self.cfg_create_from_file(path) - initialize_constraint_table([self.cfg]) - self.analysis = FixedPointAnalysis(self.cfg, analysis_type) - self.analysis.fixpoint_runner() - return Lattice(self.cfg.nodes, analysis_type) diff --git a/tests/base_test_case.py b/tests/base_test_case.py index bbcb9d5..21b7c69 100644 --- a/tests/base_test_case.py +++ b/tests/base_test_case.py @@ -1,80 +1,29 @@ """A module that contains a base class that has helper methods for testing PyT.""" import unittest -from pyt.ast_helper import generate_ast -from pyt.expr_visitor import make_cfg -from pyt.module_definitions import project_definitions +from pyt.cfg import make_cfg +from pyt.core.ast_helper import generate_ast +from pyt.core.module_definitions import project_definitions class BaseTestCase(unittest.TestCase): """A base class that has helper methods for testing PyT.""" - def assertInCfg(self, connections): - """Asserts that all connections in the connections list exists in the cfg, - as well as that all connections not in the list do not exist. - - Args: - connections(list[tuple]): the node at index 0 of the tuple has - to be in the new_constraint set of the node - at index 1 of the tuple. - """ - for connection in connections: - self.assertIn(self.cfg.nodes[connection[0]], self.cfg.nodes[connection[1]].outgoing, str(connection) + " expected to be connected") - self.assertIn(self.cfg.nodes[connection[1]], self.cfg.nodes[connection[0]].ingoing, str(connection) + " expected to be connected") - - nodes = len(self.cfg.nodes) - - for element in range(nodes): - for sets in range(nodes): - if not (element, sets) in connections: - self.assertNotIn(self.cfg.nodes[element], self.cfg.nodes[sets].outgoing, "(%s <- %s)" % (element, sets) + " expected to be disconnected") - self.assertNotIn(self.cfg.nodes[sets], self.cfg.nodes[element].ingoing, "(%s <- %s)" % (sets, element) + " expected to be disconnected") - - def assertConnected(self, node, successor): - """Asserts that a node is connected to its successor. - This means that node has successor in its outgoing and - successor has node in its ingoing.""" - - self.assertIn(successor, node.outgoing, - '\n%s was NOT found in the outgoing list of %s containing: ' % (successor.label, node.label) + '[' + ', '.join([x.label for x in node.outgoing]) + ']') - - self.assertIn(node, successor.ingoing, - '\n%s was NOT found in the ingoing list of %s containing: ' % (node.label, successor.label) + '[' + ', '.join([x.label for x in successor.ingoing]) + ']') - - def assertNotConnected(self, node, successor): - """Asserts that a node is not connected to its successor. - This means that node does not the successor in its outgoing and - successor does not have the node in its ingoing.""" - - self.assertNotIn(successor, node.outgoing, - '\n%s was mistakenly found in the outgoing list of %s containing: ' % (successor.label, node.label) + '[' + ', '.join([x.label for x in node.outgoing]) + ']') - - self.assertNotIn(node, successor.ingoing, - '\n%s was mistakenly found in the ingoing list of %s containing: ' % (node.label, successor.label) + '[' + ', '.join([x.label for x in successor.ingoing]) + ']') - - def assertLineNumber(self, node, line_number): - self.assertEqual(node.line_number, line_number) - - def cfg_list_to_dict(self, list): - """This method converts the CFG list to a dict, making it easier to find nodes to test. - This method assumes that no nodes in the code have the same label""" - return {x.label: x for x in list} - def assert_length(self, _list, *, expected_length): actual_length = len(_list) self.assertEqual(expected_length, actual_length) - def cfg_create_from_file(self, filename, project_modules=list(), local_modules=list()): + def cfg_create_from_file( + self, + filename, + project_modules=list(), + local_modules=list() + ): project_definitions.clear() tree = generate_ast(filename) - self.cfg = make_cfg(tree, project_modules, local_modules, filename) - - def string_compare_alpha(self, output, expected_string): - return [char for char in output if char.isalpha()] \ - == \ - [char for char in expected_string if char.isalpha()] - - def string_compare_alnum(self, output, expected_string): - return [char for char in output if char.isalnum()] \ - == \ - [char for char in expected_string if char.isalnum()] + self.cfg = make_cfg( + tree, + project_modules, + local_modules, + filename + ) diff --git a/tests/cfg/__init__.py b/tests/cfg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cfg/cfg_base_test_case.py b/tests/cfg/cfg_base_test_case.py new file mode 100644 index 0000000..5d6d0f8 --- /dev/null +++ b/tests/cfg/cfg_base_test_case.py @@ -0,0 +1,50 @@ +from ..base_test_case import BaseTestCase + + +class CFGBaseTestCase(BaseTestCase): + + def assertInCfg(self, connections): + """Asserts that all connections in the connections list exists in the cfg, + as well as that all connections not in the list do not exist. + + Args: + connections(list[tuple]): the node at index 0 of the tuple has + to be in the new_constraint set of the node + at index 1 of the tuple. + """ + for connection in connections: + self.assertIn( + self.cfg.nodes[connection[0]], + self.cfg.nodes[connection[1]].outgoing, + str(connection) + " expected to be connected" + ) + self.assertIn( + self.cfg.nodes[connection[1]], + self.cfg.nodes[connection[0]].ingoing, + str(connection) + " expected to be connected" + ) + + nodes = len(self.cfg.nodes) + + for element in range(nodes): + for sets in range(nodes): + if not (element, sets) in connections: + self.assertNotIn( + self.cfg.nodes[element], + self.cfg.nodes[sets].outgoing, + "(%s <- %s)" % (element, sets) + " expected to be disconnected" + ) + self.assertNotIn( + self.cfg.nodes[sets], + self.cfg.nodes[element].ingoing, + "(%s <- %s)" % (sets, element) + " expected to be disconnected" + ) + + def assertLineNumber(self, node, line_number): + self.assertEqual(node.line_number, line_number) + + def cfg_list_to_dict(self, list): + """This method converts the CFG list to a dict, making it easier to find nodes to test. + This method assumes that no nodes in the code have the same label. + """ + return {x.label: x for x in list} diff --git a/tests/cfg_test.py b/tests/cfg/cfg_test.py similarity index 92% rename from tests/cfg_test.py rename to tests/cfg/cfg_test.py index f3a3a28..ceb9c5c 100644 --- a/tests/cfg_test.py +++ b/tests/cfg/cfg_test.py @@ -1,22 +1,22 @@ -from .base_test_case import BaseTestCase -from pyt.node_types import EntryOrExitNode, Node +from .cfg_base_test_case import CFGBaseTestCase +from pyt.core.node_types import ( + EntryOrExitNode, + Node +) -class CFGGeneralTest(BaseTestCase): + +class CFGGeneralTest(CFGBaseTestCase): def test_repr_cfg(self): self.cfg_create_from_file('examples/example_inputs/for_complete.py') self.nodes = self.cfg_list_to_dict(self.cfg.nodes) - #print(repr(self.cfg)) - def test_str_cfg(self): self.cfg_create_from_file('examples/example_inputs/for_complete.py') self.nodes = self.cfg_list_to_dict(self.cfg.nodes) - #print(self.cfg) - def test_no_tuples(self): self.cfg_create_from_file('examples/example_inputs/for_complete.py') @@ -35,7 +35,10 @@ def test_start_and_exit_nodes(self): node = 1 exit_node = 2 - self.assertInCfg([(1,0),(2,1)]) + self.assertInCfg([ + (node, start_node), + (exit_node, node) + ]) self.assertEqual(type(self.cfg.nodes[start_node]), EntryOrExitNode) self.assertEqual(type(self.cfg.nodes[exit_node]), EntryOrExitNode) @@ -57,7 +60,7 @@ def test_str_ignored(self): self.assertEqual(expected_label, actual_label) -class CFGForTest(BaseTestCase): +class CFGForTest(CFGBaseTestCase): def test_for_complete(self): self.cfg_create_from_file('examples/example_inputs/for_complete.py') @@ -72,7 +75,7 @@ def test_for_complete(self): next_node = 6 exit_node = 7 - self.assertEqual(self.cfg.nodes[for_node].label,'for x in range(3):') + self.assertEqual(self.cfg.nodes[for_node].label, 'for x in range(3):') self.assertEqual(self.cfg.nodes[body_1].label, '~call_1 = ret_print(x)') self.assertEqual(self.cfg.nodes[body_2].label, 'y += 1') self.assertEqual(self.cfg.nodes[else_body_1].label, "~call_2 = ret_print('Final: %s' % x)") @@ -107,15 +110,23 @@ def test_for_no_orelse(self): def test_for_tuple_target(self): self.cfg_create_from_file('examples/example_inputs/for_tuple_target.py') - self.assert_length(self.cfg.nodes, expected_length = 4) + self.assert_length(self.cfg.nodes, expected_length=4) entry_node = 0 for_node = 1 print_node = 2 exit_node = 3 - self.assertInCfg([(for_node,entry_node),(print_node,for_node),(for_node,print_node),(exit_node,for_node)]) - self.assertEqual(self.cfg.nodes[for_node].label, "for (x, y) in [(1, 2), (3, 4)]:") + self.assertInCfg([ + (for_node, entry_node), + (print_node, for_node), + (for_node, print_node), + (exit_node, for_node) + ]) + self.assertEqual( + self.cfg.nodes[for_node].label, + "for (x, y) in [(1, 2), (3, 4)]:" + ) def test_for_line_numbers(self): self.cfg_create_from_file('examples/example_inputs/for_complete.py') @@ -152,18 +163,21 @@ def test_for_func_iterator(self): _print = 7 _exit = 8 - self.assertInCfg([(_for, entry), - (_for, call_foo), - (_for, _print), - (entry_foo, _for), - (call_to_range, entry_foo), - (ret_foo, call_to_range), - (exit_foo, ret_foo), - (call_foo, exit_foo), - (_print, _for), - (_exit, _for)]) + self.assertInCfg([ + (_for, entry), + (_for, call_foo), + (_for, _print), + (entry_foo, _for), + (call_to_range, entry_foo), + (ret_foo, call_to_range), + (exit_foo, ret_foo), + (call_foo, exit_foo), + (_print, _for), + (_exit, _for) + ]) + -class CFGTryTest(BaseTestCase): +class CFGTryTest(CFGBaseTestCase): def connected(self, node, successor): return (successor, node) @@ -200,14 +214,14 @@ def test_orelse(self): print_a5 = 3 except_im = 4 except_im_body_1 = 5 - value_equal_call_2 = 6 # value = ~call_2 + value_equal_call_2 = 6 # value = ~call_2 print_wagyu = 7 save_node = 8 assign_to_temp = 9 assign_from_temp = 10 function_entry = 11 ret_of_subprocess_call = 12 - ret_does_this_kill_us_equal_call_5 = 13 # ret_does_this_kill_us = ~call_5 + ret_does_this_kill_us_equal_call_5 = 13 # ret_does_this_kill_us = ~call_5 function_exit = 14 restore_node = 15 return_handler = 16 @@ -271,7 +285,7 @@ def test_final(self): self.connected(print_final, _exit)]) -class CFGIfTest(BaseTestCase): +class CFGIfTest(CFGBaseTestCase): def test_if_complete(self): self.cfg_create_from_file('examples/example_inputs/if_complete.py') @@ -297,8 +311,18 @@ def test_if_complete(self): self.assertEqual(self.cfg.nodes[else_body].label, 'x += 4') self.assertEqual(self.cfg.nodes[next_node].label, 'x += 5') - - self.assertInCfg([(test, entry), (eliftest, test), (body_1, test), (body_2, body_1), (next_node, body_2), (else_body, eliftest), (elif_body, eliftest), (next_node, elif_body), (next_node, else_body), (exit_node, next_node)]) + self.assertInCfg([ + (test, entry), + (eliftest, test), + (body_1, test), + (body_2, body_1), + (next_node, body_2), + (else_body, eliftest), + (elif_body, eliftest), + (next_node, elif_body), + (next_node, else_body), + (exit_node, next_node) + ]) def test_single_if(self): self.cfg_create_from_file('examples/example_inputs/if.py') @@ -309,7 +333,13 @@ def test_single_if(self): test_node = 1 body_node = 2 exit_node = 3 - self.assertInCfg([(test_node,start_node), (body_node,test_node), (exit_node,test_node), (exit_node,body_node)]) + + self.assertInCfg([ + (test_node, start_node), + (body_node, test_node), + (exit_node, test_node), + (exit_node, body_node) + ]) def test_single_if_else(self): self.cfg_create_from_file('examples/example_inputs/if_else.py') @@ -321,7 +351,14 @@ def test_single_if_else(self): body_node = 2 else_body = 3 exit_node = 4 - self.assertInCfg([(test_node,start_node), (body_node,test_node), (else_body,test_node), (exit_node,else_body), (exit_node,body_node)]) + + self.assertInCfg([ + (test_node, start_node), + (body_node, test_node), + (else_body, test_node), + (exit_node, else_body), + (exit_node, body_node) + ]) def test_multiple_if_else(self): self.cfg_create_from_file('examples/example_inputs/multiple_if_else.py') @@ -409,7 +446,6 @@ def test_nested_if_else_elif(self): (_exit, elif_body) ]) - def test_if_line_numbers(self): self.cfg_create_from_file('examples/example_inputs/if_complete.py') @@ -443,10 +479,15 @@ def test_if_not(self): body = 2 _exit = 3 - self.assertInCfg([(1, 0), (2, 1), (3, 2), (3, 1)]) + self.assertInCfg([ + (_if, entry), + (body, _if), + (_exit, body), + (_exit, _if) + ]) -class CFGWhileTest(BaseTestCase): +class CFGWhileTest(CFGBaseTestCase): def test_while_complete(self): self.cfg_create_from_file('examples/example_inputs/while_complete.py') @@ -464,7 +505,16 @@ def test_while_complete(self): self.assertEqual(self.cfg.nodes[test].label, 'while x > 0:') - self.assertInCfg([(test, entry), (body_1, test), (else_body_1, test), ( body_2, body_1), (test, body_2), (else_body_2, else_body_1), (next_node, else_body_2), (exit_node, next_node)]) + self.assertInCfg([ + (test, entry), + (body_1, test), + (else_body_1, test), + (body_2, body_1), + (test, body_2), + (else_body_2, else_body_1), + (next_node, else_body_2), + (exit_node, next_node) + ]) def test_while_no_orelse(self): self.cfg_create_from_file('examples/example_inputs/while_no_orelse.py') @@ -478,7 +528,14 @@ def test_while_no_orelse(self): next_node = 4 exit_node = 5 - self.assertInCfg([(test, entry), (body_1, test), ( next_node, test), (body_2, body_1), (test, body_2), (exit_node, next_node)]) + self.assertInCfg([ + (test, entry), + (body_1, test), + (next_node, test), + (body_2, body_1), + (test, body_2), + (exit_node, next_node) + ]) def test_while_line_numbers(self): self.cfg_create_from_file('examples/example_inputs/while_complete.py') @@ -501,7 +558,7 @@ def test_while_line_numbers(self): self.assertLineNumber(next_stmt, 7) -class CFGAssignmentMultiTest(BaseTestCase): +class CFGAssignmentMultiTest(CFGBaseTestCase): def test_assignment_multi_target(self): self.cfg_create_from_file('examples/example_inputs/assignment_two_targets.py') @@ -509,7 +566,7 @@ def test_assignment_multi_target(self): start_node = 0 node = 1 node_2 = 2 - exit_node =3 + exit_node = 3 self.assertInCfg([(node, start_node), (node_2, node), (exit_node, node_2)]) @@ -520,15 +577,22 @@ def test_assignment_multi_target_call(self): self.cfg_create_from_file('examples/example_inputs/assignment_multiple_assign_call.py') self.assert_length(self.cfg.nodes, expected_length=6) - start_node = self.cfg.nodes[0] + + # start_node = self.cfg.nodes[0] assignment_to_call1 = self.cfg.nodes[1] assignment_to_x = self.cfg.nodes[2] assignment_to_call2 = self.cfg.nodes[3] assignment_to_y = self.cfg.nodes[4] - exit_node = self.cfg.nodes[5] + # exit_node = self.cfg.nodes[5] # This assert means N should be connected to N+1 - self.assertInCfg([(1,0),(2,1),(3,2),(4,3),(5,4)]) + self.assertInCfg([ + (1, 0), + (2, 1), + (3, 2), + (4, 3), + (5, 4) + ]) self.assertEqual(assignment_to_call1.label, '~call_1 = ret_int(5)') self.assertEqual(assignment_to_x.label, 'x = ~call_1') @@ -570,10 +634,10 @@ def test_multiple_assignment(self): self.assert_length(self.cfg.nodes, expected_length=4) - start_node = self.cfg.nodes[0] + # start_node = self.cfg.nodes[0] assign_y = self.cfg.nodes[1] assign_x = self.cfg.nodes[2] - exit_node = self.cfg.nodes[-1] + # exit_node = self.cfg.nodes[-1] self.assertEqual(assign_x.label, 'x = 5') self.assertEqual(assign_y.label, 'y = 5') @@ -605,7 +669,7 @@ def test_assignment_tuple_value(self): self.assertEqual(self.cfg.nodes[node].label, 'a = (x, y)') -class CFGComprehensionTest(BaseTestCase): +class CFGComprehensionTest(CFGBaseTestCase): def test_nodes(self): self.cfg_create_from_file('examples/example_inputs/comprehensions.py') @@ -652,7 +716,8 @@ def test_dict_comprehension_multi(self): self.assertEqual(listcomp.label, 'dd = {x + y : y for x in [1, 2, 3] for y in [4, 5, 6]}') -class CFGFunctionNodeTest(BaseTestCase): + +class CFGFunctionNodeTest(CFGBaseTestCase): def connected(self, node, successor): return (successor, node) @@ -1007,9 +1072,6 @@ def test_multiple_blackbox_calls_in_user_defined_call_after_if(self): (_exit, ret_send_file) ]) - - - def test_multiple_user_defined_calls_in_blackbox_call_after_if(self): path = 'examples/vulnerable_code/multiple_user_defined_calls_in_blackbox_call_after_if.py' self.cfg_create_from_file(path) @@ -1105,9 +1167,7 @@ def test_call_on_call(self): self.cfg_create_from_file(path) - - -class CFGCallWithAttributeTest(BaseTestCase): +class CFGCallWithAttributeTest(CFGBaseTestCase): def setUp(self): self.cfg_create_from_file('examples/example_inputs/call_with_attribute.py') @@ -1126,7 +1186,8 @@ def test_call_with_attribute_line_numbers(self): self.assertLineNumber(call, 5) -class CFGBreak(BaseTestCase): + +class CFGBreak(CFGBaseTestCase): """Break in while and for and other places""" def test_break(self): self.cfg_create_from_file('examples/example_inputs/while_break.py') @@ -1154,7 +1215,7 @@ def test_break(self): (_exit, print_next)]) -class CFGNameConstant(BaseTestCase): +class CFGNameConstant(CFGBaseTestCase): def setUp(self): self.cfg_create_from_file('examples/example_inputs/name_constant.py') @@ -1172,13 +1233,12 @@ def test_name_constant_if(self): self.assertEqual(expected_label, actual_label) -class CFGName(BaseTestCase): +class CFGName(CFGBaseTestCase): """Test is Name nodes are properly handled in different contexts""" def test_name_if(self): self.cfg_create_from_file('examples/example_inputs/name_if.py') - self.assert_length(self.cfg.nodes, expected_length=5) self.assertEqual(self.cfg.nodes[2].label, 'if x:') diff --git a/tests/import_test.py b/tests/cfg/import_test.py similarity index 98% rename from tests/import_test.py rename to tests/cfg/import_test.py index 94d829b..842fafa 100644 --- a/tests/import_test.py +++ b/tests/cfg/import_test.py @@ -1,9 +1,13 @@ import ast import os -from .base_test_case import BaseTestCase -from pyt.ast_helper import get_call_names_as_string -from pyt.project_handler import get_directory_modules, get_modules_and_packages +from ..base_test_case import BaseTestCase + +from pyt.core.ast_helper import get_call_names_as_string +from pyt.core.project_handler import ( + get_directory_modules, + get_modules_and_packages +) class ImportTest(BaseTestCase): @@ -184,17 +188,18 @@ def test_from_directory(self): self.cfg_create_from_file(file_path, project_modules, local_modules) - - EXPECTED = ["Entry module", - "Module Entry bar", - "Module Exit bar", - "temp_1_s = 'hey'", - "s = temp_1_s", - "Function Entry bar.H", - "ret_bar.H = s + 'end'", - "Exit bar.H", - "~call_1 = ret_bar.H", - "Exit module"] + EXPECTED = [ + "Entry module", + "Module Entry bar", + "Module Exit bar", + "temp_1_s = 'hey'", + "s = temp_1_s", + "Function Entry bar.H", + "ret_bar.H = s + 'end'", + "Exit bar.H", + "~call_1 = ret_bar.H", + "Exit module" + ] for node, expected_label in zip(self.cfg.nodes, EXPECTED): self.assertEqual(node.label, expected_label) @@ -313,7 +318,6 @@ def test_from_dot(self): 'c = ~call_1', 'Exit module'] - for node, expected_label in zip(self.cfg.nodes, EXPECTED): self.assertEqual(node.label, expected_label) @@ -338,7 +342,6 @@ def test_from_dot_dot(self): 'c = ~call_1', 'Exit module'] - for node, expected_label in zip(self.cfg.nodes, EXPECTED): self.assertEqual(node.label, expected_label) @@ -446,7 +449,6 @@ def test_multiple_functions_with_aliases(self): "c = ~call_3", "Exit module"] - for node, expected_label in zip(self.cfg.nodes, EXPECTED): self.assertEqual(node.label, expected_label) diff --git a/tests/nested_functions_test.py b/tests/cfg/nested_functions_test.py similarity index 91% rename from tests/nested_functions_test.py rename to tests/cfg/nested_functions_test.py index 7147bba..7a5f577 100644 --- a/tests/nested_functions_test.py +++ b/tests/cfg/nested_functions_test.py @@ -1,7 +1,11 @@ import os.path -from .base_test_case import BaseTestCase -from pyt.project_handler import get_directory_modules, get_modules_and_packages +from ..base_test_case import BaseTestCase + +from pyt.core.project_handler import ( + get_directory_modules, + get_modules_and_packages +) class NestedTest(BaseTestCase): diff --git a/tests/command_line_test.py b/tests/command_line_test.py deleted file mode 100644 index 25d06a7..0000000 --- a/tests/command_line_test.py +++ /dev/null @@ -1,35 +0,0 @@ -"""This just tests __main__.py""" -import sys -from contextlib import contextmanager -from io import StringIO - -from .base_test_case import BaseTestCase -from pyt.__main__ import parse_args - -@contextmanager -def capture_sys_output(): - capture_out, capture_err = StringIO(), StringIO() - current_out, current_err = sys.stdout, sys.stderr - try: - sys.stdout, sys.stderr = capture_out, capture_err - yield capture_out, capture_err - finally: - sys.stdout, sys.stderr = current_out, current_err - -class CommandLineTest(BaseTestCase): - def test_no_args(self): - with self.assertRaises(SystemExit): - with capture_sys_output() as (_, stderr): - parse_args([]) - - EXPECTED = """usage: python -m pyt [-h] (-f FILEPATH | -gr GIT_REPOS) [-pr PROJECT_ROOT] - [-d] [-o OUTPUT_FILENAME] [-csv CSV_PATH] - [-p | -vp | -trim | -i] [-t TRIGGER_WORD_FILE] - [-m BLACKBOX_MAPPING_FILE] [-py2] [-l LOG_LEVEL] - [-a ADAPTOR] [-db] [-dl DRAW_LATTICE [DRAW_LATTICE ...]] - [-j] [-li | -re | -rt] [-ppm] [-b BASELINE] - [--ignore-nosec] - {save,github_search} ...\n""" + \ - "python -m pyt: error: one of the arguments " + \ - "-f/--filepath -gr/--git-repos is required\n" - self.assertEqual(stderr.getvalue(), EXPECTED) diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/project_handler_test.py b/tests/core/project_handler_test.py similarity index 84% rename from tests/project_handler_test.py rename to tests/core/project_handler_test.py index a5105dc..b6657cd 100644 --- a/tests/project_handler_test.py +++ b/tests/core/project_handler_test.py @@ -1,13 +1,13 @@ import os import unittest -from pprint import pprint -from pyt.project_handler import ( +from pyt.core.project_handler import ( get_modules, get_modules_and_packages, is_python_file ) + class ProjectHandlerTest(unittest.TestCase): """Tests for the project handler.""" @@ -28,12 +28,12 @@ def test_get_modules(self): modules = get_modules(project_folder) app_path = os.path.join(project_folder, 'app.py') - utils_path = os.path.join(project_folder,'utils.py') + utils_path = os.path.join(project_folder, 'utils.py') exceptions_path = os.path.join(project_folder, 'exceptions.py') some_path = os.path.join(project_folder, folder, 'some.py') indhold_path = os.path.join(project_folder, folder, directory, 'indhold.py') - relative_folder_name = '.' + folder + # relative_folder_name = '.' + folder app_name = project_namespace + '.' + 'app' utils_name = project_namespace + '.' + 'utils' exceptions_name = project_namespace + '.' + 'exceptions' @@ -66,8 +66,8 @@ def test_get_modules_and_packages(self): folder_path = os.path.join(project_folder, folder) app_path = os.path.join(project_folder, 'app.py') exceptions_path = os.path.join(project_folder, 'exceptions.py') - utils_path = os.path.join(project_folder,'utils.py') - directory_path = os.path.join(project_folder, folder, directory) + utils_path = os.path.join(project_folder, 'utils.py') + # directory_path = os.path.join(project_folder, folder, directory) some_path = os.path.join(project_folder, folder, 'some.py') indhold_path = os.path.join(project_folder, folder, directory, 'indhold.py') @@ -75,15 +75,23 @@ def test_get_modules_and_packages(self): app_name = project_namespace + '.' + 'app' exceptions_name = project_namespace + '.' + 'exceptions' utils_name = project_namespace + '.' + 'utils' - relative_directory_name = '.' + folder + '.' + directory + # relative_directory_name = '.' + folder + '.' + directory some_name = project_namespace + '.' + folder + '.some' indhold_name = project_namespace + '.' + folder + '.' + directory + '.indhold' - folder_tuple = (relative_folder_name[1:], folder_path, relative_folder_name) + folder_tuple = ( + relative_folder_name[1:], + folder_path, + relative_folder_name + ) app_tuple = (app_name, app_path) exceptions_tuple = (exceptions_name, exceptions_path) utils_tuple = (utils_name, utils_path) - directory_tuple = (relative_directory_name[1:], directory_path, relative_directory_name) + # directory_tuple = ( + # relative_directory_name[1:], + # directory_path, + # relative_directory_name + # ) some_tuple = (some_name, some_path) indhold_tuple = (indhold_name, indhold_path) diff --git a/tests/github_search_test.py b/tests/github_search_test.py deleted file mode 100644 index f5797f3..0000000 --- a/tests/github_search_test.py +++ /dev/null @@ -1,11 +0,0 @@ -import unittest -from datetime import date - -from pyt.github_search import get_dates - - -class GetDatesTest(unittest.TestCase): - def test_range_shorter_than_interval(self): - date_range = get_dates(date(2016,12,12), date(2016,12,13), 7) - - diff --git a/tests/helper_visitors/__init__.py b/tests/helper_visitors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/label_visitor_test.py b/tests/helper_visitors/label_visitor_test.py similarity index 61% rename from tests/label_visitor_test.py rename to tests/helper_visitors/label_visitor_test.py index 6ee44bb..0f2d2f7 100644 --- a/tests/label_visitor_test.py +++ b/tests/helper_visitors/label_visitor_test.py @@ -1,7 +1,7 @@ import ast import unittest -from pyt.label_visitor import LabelVisitor +from pyt.helper_visitors import LabelVisitor class LabelVisitorTestCase(unittest.TestCase): @@ -14,64 +14,60 @@ def perform_labeling_on_expression(self, expr): return label + class LabelVisitorTest(LabelVisitorTestCase): def test_assign(self): label = self.perform_labeling_on_expression('a = 1') - self.assertEqual(label.result,'a = 1') + self.assertEqual(label.result, 'a = 1') def test_augassign(self): label = self.perform_labeling_on_expression('a +=2') - self.assertEqual(label.result,'a += 2') + self.assertEqual(label.result, 'a += 2') def test_compare_simple(self): label = self.perform_labeling_on_expression('a > b') - self.assertEqual(label.result,'a > b') + self.assertEqual(label.result, 'a > b') def test_compare_multi(self): label = self.perform_labeling_on_expression('a > b > c') - self.assertEqual(label.result,'a > b > c') + self.assertEqual(label.result, 'a > b > c') def test_binop(self): label = self.perform_labeling_on_expression('a / b') - self.assertEqual(label.result,'a / b') + self.assertEqual(label.result, 'a / b') def test_call_no_arg(self): label = self.perform_labeling_on_expression('range()') - self.assertEqual(label.result,'range()') - + self.assertEqual(label.result, 'range()') def test_call_single_arg(self): label = self.perform_labeling_on_expression('range(5)') - self.assertEqual(label.result,'range(5)') + self.assertEqual(label.result, 'range(5)') def test_call_multi_arg(self): - label = self.perform_labeling_on_expression('range(1,5)') - self.assertEqual(label.result,'range(1, 5)') + label = self.perform_labeling_on_expression('range(1, 5)') + self.assertEqual(label.result, 'range(1, 5)') def test_tuple_one_element(self): label = self.perform_labeling_on_expression('(1)') - self.assertEqual(label.result,'1') + self.assertEqual(label.result, '1') def test_tuple_two_elements(self): - label = self.perform_labeling_on_expression('(1,2)') - self.assertEqual(label.result,'(1, 2)') + label = self.perform_labeling_on_expression('(1, 2)') + self.assertEqual(label.result, '(1, 2)') def test_empty_tuple(self): label = self.perform_labeling_on_expression('()') - self.assertEqual(label.result,'()') + self.assertEqual(label.result, '()') def test_empty_list(self): label = self.perform_labeling_on_expression('[]') - self.assertEqual(label.result,'[]') + self.assertEqual(label.result, '[]') def test_list_one_element(self): label = self.perform_labeling_on_expression('[1]') - self.assertEqual(label.result,'[1]') + self.assertEqual(label.result, '[1]') def test_list_two_elements(self): - label = self.perform_labeling_on_expression('[1,2]') - self.assertEqual(label.result,'[1, 2]') - - - - + label = self.perform_labeling_on_expression('[1, 2]') + self.assertEqual(label.result, '[1, 2]') diff --git a/tests/vars_visitor_test.py b/tests/helper_visitors/vars_visitor_test.py similarity index 97% rename from tests/vars_visitor_test.py rename to tests/helper_visitors/vars_visitor_test.py index 4f1c7c1..849206b 100644 --- a/tests/vars_visitor_test.py +++ b/tests/helper_visitors/vars_visitor_test.py @@ -1,7 +1,7 @@ import ast import unittest -from pyt.vars_visitor import VarsVisitor +from pyt.helper_visitors import VarsVisitor class VarsVisitorTestCase(unittest.TestCase): diff --git a/tests/lattice_test.py b/tests/lattice_test.py deleted file mode 100644 index 5164c8a..0000000 --- a/tests/lattice_test.py +++ /dev/null @@ -1,143 +0,0 @@ -from .base_test_case import BaseTestCase -from pyt.constraint_table import constraint_table -from pyt.lattice import Lattice -from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis - - -class LatticeTest(BaseTestCase): - - class AnalysisType: - @staticmethod - def get_lattice_elements(cfg_nodes): - for node in cfg_nodes: - if node.lattice_element == True: - yield node - def equality(self, value): - return self.value == value - - class Node: - def __init__(self, value, lattice_element): - self.value = value - self.lattice_element = lattice_element - def __str__(self): - return str(self.value) - - def test_generate_integer_elements(self): - one = self.Node(1, True) - two = self.Node(2, True) - three = self.Node(3, True) - a = self.Node('a', False) - b = self.Node('b', False) - c = self.Node('c', False) - cfg_nodes = [one, two, three, a, b, c] - lattice = Lattice(cfg_nodes, self.AnalysisType) - - self.assertEqual(lattice.el2bv[one], 0b1) - self.assertEqual(lattice.el2bv[two], 0b10) - self.assertEqual(lattice.el2bv[three], 0b100) - - self.assertEqual(lattice.bv2el[0], three) - self.assertEqual(lattice.bv2el[1], two) - self.assertEqual(lattice.bv2el[2], one) - - def test_join(self): - # join not used at the moment - return - - a = self.Node('x = 1', True) - b = self.Node('print(x)', False) - c = self.Node('x = 3', True) - d = self.Node('y = x', True) - - lattice = Lattice([a, c, d], self.AnalysisType) - - # Constraint results after fixpoint: - lattice.table[a] = 0b0001 - lattice.table[b] = 0b0001 - lattice.table[c] = 0b0010 - lattice.table[d] = 0b1010 - - r = lattice.join([a,c], [c]) - self.assertEqual(r, 0b11) - r = lattice.join([a, c], [d, c]) - self.assertEqual(r, 0b1011) - r = lattice.join([a], [c]) - self.assertEqual(r, 0b11) - r = lattice.join([c], [d]) - self.assertEqual(r, 0b1010) - r = lattice.join([], [a]) - self.assertEqual(r, 0b1) - r = lattice.join([a,c,d], [a,c,d]) - self.assertEqual(r, 0b1011) - r = lattice.join([d,c], []) - self.assertEqual(r, 0b1010) - - def test_meet(self): - # meet not used on lattice atm - return - - a = self.Node('x = 1', True) - b = self.Node('print(x)', False) - c = self.Node('x = 3', True) - d = self.Node('y = x', True) - - lattice = Lattice([a, c, d], self.AnalysisType) - - # Constraint results after fixpoint: - lattice.table[a] = 0b0001 - lattice.table[b] = 0b0001 - lattice.table[c] = 0b0010 - lattice.table[d] = 0b1010 - - r = lattice.meet([a,c], [c,d]) - self.assertEqual(r, 0b10) - r = lattice.meet([a], [d]) - self.assertEqual(r, 0b0) - r = lattice.meet([a,c,d], [a,c]) - self.assertEqual(r, 0b011) - r = lattice.meet([c,d], [a,d]) - self.assertEqual(r, 0b1010) - r = lattice.meet([], []) - self.assertEqual(r, 0b0) - r = lattice.meet([a], []) - self.assertEqual(r, 0b0) - - def test_in_constraint(self): - a = self.Node('x = 1', True) - b = self.Node('print(x)', False) - c = self.Node('x = 3', True) - d = self.Node('y = x', True) - - lattice = Lattice([a, c, d], self.AnalysisType) - - constraint_table[a] = 0b001 - constraint_table[b] = 0b001 - constraint_table[c] = 0b010 - constraint_table[d] = 0b110 - - self.assertEqual(lattice.in_constraint(a, b), True) - self.assertEqual(lattice.in_constraint(a, a), True) - self.assertEqual(lattice.in_constraint(a, d), False) - self.assertEqual(lattice.in_constraint(a, c), False) - self.assertEqual(lattice.in_constraint(c, d), True) - self.assertEqual(lattice.in_constraint(d, d), True) - self.assertEqual(lattice.in_constraint(c, c), True) - self.assertEqual(lattice.in_constraint(c, a), False) - self.assertEqual(lattice.in_constraint(c, b), False) - - def test_get_elements(self): - a = self.Node('x = 1', True) - b = self.Node('print(x)', False) - c = self.Node('x = 3', True) - d = self.Node('y = x', True) - - lattice = Lattice([a, c, d], self.AnalysisType) - - self.assertEqual(set(lattice.get_elements(0b111)), {a,c,d}) - self.assertEqual(set(lattice.get_elements(0b0)), set()) - self.assertEqual(set(lattice.get_elements(0b1)), {a}) - self.assertEqual(set(lattice.get_elements(0b10)), {c}) - self.assertEqual(set(lattice.get_elements(0b100)), {d}) - self.assertEqual(set(lattice.get_elements(0b11)), {a,c}) - self.assertEqual(set(lattice.get_elements(0b101)), {a,d}) - self.assertEqual(set(lattice.get_elements(0b110)), {c,d}) diff --git a/tests/liveness_test.py b/tests/liveness_test.py deleted file mode 100644 index da6f08a..0000000 --- a/tests/liveness_test.py +++ /dev/null @@ -1,33 +0,0 @@ -from .analysis_base_test_case import AnalysisBaseTestCase -from pyt.constraint_table import constraint_table -from pyt.liveness import LivenessAnalysis - - -class LivenessTest(AnalysisBaseTestCase): - def test_example(self): - lattice = self.run_analysis('examples/example_inputs/example.py', LivenessAnalysis) - - x = 0b1 # 1 - y = 0b10 # 2 - z = 0b100 # 4 - - lattice.el2bv['x'] = x - lattice.el2bv['y'] = y - lattice.el2bv['z'] = z - - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[0]]), []) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[1]]), []) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[2]]), []) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[3]]), ['x']) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[4]]), ['x']) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[5]]), ['x']) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[6]]), ['x']) - self.assertEqual(set(lattice.get_elements(constraint_table[self.cfg.nodes[7]])), set(['x','y'])) - self.assertEqual(set(lattice.get_elements(constraint_table[self.cfg.nodes[8]])), set(['x','y'])) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[9]]), ['x']) - self.assertEqual(set(lattice.get_elements(constraint_table[self.cfg.nodes[10]])), set(['x','z'])) - self.assertEqual(set(lattice.get_elements(constraint_table[self.cfg.nodes[11]])), set(['x','z'])) - self.assertEqual(set(lattice.get_elements(constraint_table[self.cfg.nodes[12]])), set(['x','z'])) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[13]]), ['x']) - self.assertEqual(lattice.get_elements(constraint_table[self.cfg.nodes[14]]), []) - self.assertEqual(len(lattice.el2bv), 3) diff --git a/tests/main_test.py b/tests/main_test.py new file mode 100644 index 0000000..eea6ff4 --- /dev/null +++ b/tests/main_test.py @@ -0,0 +1,60 @@ +import mock + +from .base_test_case import BaseTestCase +from pyt.__main__ import main + + +class MainTest(BaseTestCase): + @mock.patch('pyt.__main__.parse_args') + @mock.patch('pyt.__main__.find_vulnerabilities') + @mock.patch('pyt.__main__.text') + def test_text_output(self, mock_text, mock_find_vulnerabilities, mock_parse_args): + mock_find_vulnerabilities.return_value = 'stuff' + example_file = 'examples/vulnerable_code/inter_command_injection.py' + output_file = 'mocked_outfile' + + mock_parse_args.return_value = mock.Mock( + autospec=True, + filepath=example_file, + project_root=None, + baseline=None, + json=None, + output_file=output_file + ) + main([ + 'parse_args is mocked' + ]) + assert mock_text.report.call_count == 1 + # This with: makes no sense + with self.assertRaises(AssertionError): + assert mock_text.report.assert_called_with( + mock_find_vulnerabilities.return_value, + mock_parse_args.return_value.output_file + ) + + @mock.patch('pyt.__main__.parse_args') + @mock.patch('pyt.__main__.find_vulnerabilities') + @mock.patch('pyt.__main__.json') + def test_json_output(self, mock_json, mock_find_vulnerabilities, mock_parse_args): + mock_find_vulnerabilities.return_value = 'stuff' + example_file = 'examples/vulnerable_code/inter_command_injection.py' + output_file = 'mocked_outfile' + + mock_parse_args.return_value = mock.Mock( + autospec=True, + filepath=example_file, + project_root=None, + baseline=None, + json=True, + output_file=output_file + ) + main([ + 'parse_args is mocked' + ]) + assert mock_json.report.call_count == 1 + # This with: makes no sense + with self.assertRaises(AssertionError): + assert mock_json.report.assert_called_with( + mock_find_vulnerabilities.return_value, + mock_parse_args.return_value.output_file + ) diff --git a/tests/reaching_definitions_taint_test.py b/tests/reaching_definitions_taint_test.py deleted file mode 100644 index dc18dcd..0000000 --- a/tests/reaching_definitions_taint_test.py +++ /dev/null @@ -1,111 +0,0 @@ -from .analysis_base_test_case import AnalysisBaseTestCase -from pyt.constraint_table import constraint_table -from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis - - -class ReachingDefinitionsTaintTest(AnalysisBaseTestCase): - # Note: the numbers in the test represent the line numbers of the assignments in the program. - def test_linear_program(self): - constraint_table = {} - lattice = self.run_analysis('examples/example_inputs/linear.py', ReachingDefinitionsTaintAnalysis) - - EXPECTED = [ - "Label: Entry module:", - "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", - "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: y = x - 1: Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: ~call_2 = ret_print(x): Label: ~call_2 = ret_print(x), Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: Exit module: Label: ~call_2 = ret_print(x), Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()" - ] - i = 0 - for k, v in constraint_table.items(): - row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) - self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) - i = i + 1 - - - def test_if_program(self): - constraint_table = {} - lattice = self.run_analysis('examples/example_inputs/if_program.py', ReachingDefinitionsTaintAnalysis) - - EXPECTED = [ - "Label: Entry module:", - "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", - "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: if x > 0:: Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: y = x + 1: Label: y = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: ~call_2 = ret_print(x): Label: ~call_2 = ret_print(x), Label: y = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: Exit module: Label: ~call_2 = ret_print(x), Label: y = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_input()" - ] - i = 0 - for k, v in constraint_table.items(): - row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) - self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) - i = i + 1 - - def test_example(self): - constraint_table = {} - lattice = self.run_analysis('examples/example_inputs/example.py', ReachingDefinitionsTaintAnalysis) - - EXPECTED = [ - "Label: Entry module:", - "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", - "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: ~call_2 = ret_int(x): Label: ~call_2 = ret_int(x), Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: x = ~call_2: Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: while x > 1:: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: y = x / 2: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: if y > 3:: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: x = x - y: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: z = x - 4: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: if z > 0:: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: x = x / 2: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: z = z - 1: Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: ~call_3 = ret_print(x): Label: ~call_3 = ret_print(x), Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: Exit module: Label: ~call_3 = ret_print(x), Label: z = z - 1, Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()" - ] - i = 0 - for k, v in constraint_table.items(): - row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) - self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) - i = i + 1 - - def test_func_with_params(self): - lattice = self.run_analysis('examples/example_inputs/function_with_params.py', ReachingDefinitionsTaintAnalysis) - - self.assertInCfg([(1,1), - (1,2), (2,2), - (1,3), (2,3), (3,3), - (1,4), (2,4), (3,4), (4,4), - (1,5), (2,5), (3,5), (4,5), - *self.constraints([1,2,3,4,6], 6), - *self.constraints([1,2,3,4,6,7], 7), - *self.constraints([1,2,3,4,6,7], 8), - *self.constraints([2,3,4,6,7,9], 9), - *self.constraints([2,3,4,6,7,9], 10)], lattice) - - def test_while(self): - constraint_table = {} - lattice = self.run_analysis('examples/example_inputs/while.py', ReachingDefinitionsTaintAnalysis) - - EXPECTED = [ - "Label: Entry module: ", - "Label: ~call_2 = ret_input(): Label: ~call_2 = ret_input()", - "Label: ~call_1 = ret_int(~call_2): Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", - "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", - "Label: while x < 10:: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input(", - "Label: x = x + 1: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", - "Label: if x == 5:: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", - "Label: BreakNode: Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", - "Label: x = 6: Label: x = 6, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", - "Label: ~call_3 = ret_print(x): Label: ~call_3 = ret_print(x), Label: x = 6, Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()", - "Label: Exit module: Label: ~call_3 = ret_print(x), Label: x = 6, Label: x = x + 1, Label: x = ~call_1, Label: ~call_1 = ret_int(~call_2), Label: ~call_2 = ret_input()" - ] - i = 0 - for k, v in constraint_table.items(): - row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) - self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) - i = i + 1 - - def test_join(self): - pass diff --git a/tests/reaching_definitions_test.py b/tests/reaching_definitions_test.py deleted file mode 100644 index 7552489..0000000 --- a/tests/reaching_definitions_test.py +++ /dev/null @@ -1,50 +0,0 @@ -from .analysis_base_test_case import AnalysisBaseTestCase -from pyt.constraint_table import constraint_table -from pyt.reaching_definitions import ReachingDefinitionsAnalysis - - -class ReachingDefinitionsTest(AnalysisBaseTestCase): - def test_linear_program(self): - constraint_table = {} - lattice = self.run_analysis('examples/example_inputs/linear.py', ReachingDefinitionsAnalysis) - - EXPECTED = [ - "Label: Entry module: ", - "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", - "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: y = x - 1: Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: ~call_2 = ret_print(x): Label: ~call_2 = ret_print(x), Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: Exit module: Label: ~call_2 = ret_print(x), Label: y = x - 1, Label: x = ~call_1, Label: ~call_1 = ret_input()", - ] - i = 0 - for k, v in constraint_table.items(): - row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) - self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) - i = i + 1 - - def test_example(self): - constraint_table = {} - lattice = self.run_analysis('examples/example_inputs/example.py', ReachingDefinitionsAnalysis) - - EXPECTED = [ - "Label: Entry module: ", - "Label: ~call_1 = ret_input(): Label: ~call_1 = ret_input()", - "Label: x = ~call_1: Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: ~call_2 = ret_int(x): Label: ~call_2 = ret_int(x), Label: x = ~call_1, Label: ~call_1 = ret_input()", - "Label: x = ~call_2: Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: while x > 1:: Label: z = z - 1, Label: x = x / 2, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: y = x / 2: Label: z = z - 1, Label: x = x / 2, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: if y > 3:: Label: z = z - 1, Label: x = x / 2, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: x = x - y: Label: z = z - 1, Label: x = x - y, Label: y = x / 2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: z = x - 4: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: if z > 0:: Label: x = x / 2, Label: z = x - 4, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: x = x / 2: Label: x = x / 2, Label: z = x - 4, Label: y = x / 2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: z = z - 1: Label: z = z - 1, Label: x = x / 2, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: ~call_3 = ret_print(x): Label: ~call_3 = ret_print(x), Label: z = z - 1, Label: x = x / 2, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - "Label: Exit module: Label: ~call_3 = ret_print(x), Label: z = z - 1, Label: x = x / 2, Label: x = x - y, Label: y = x / 2, Label: x = ~call_2, Label: ~call_2 = ret_int(x), Label: ~call_1 = ret_input()", - ] - i = 0 - for k, v in constraint_table.items(): - row = str(k) + ': ' + ','.join([str(n) for n in lattice.get_elements(v)]) - self.assertTrue(self.string_compare_alnum(row, EXPECTED[i])) - i = i + 1 diff --git a/tests/usage_test.py b/tests/usage_test.py new file mode 100644 index 0000000..cae390e --- /dev/null +++ b/tests/usage_test.py @@ -0,0 +1,95 @@ +import sys +from contextlib import contextmanager +from io import StringIO + +from .base_test_case import BaseTestCase +from pyt.usage import parse_args + + +@contextmanager +def capture_sys_output(): + capture_out, capture_err = StringIO(), StringIO() + current_out, current_err = sys.stdout, sys.stderr + try: + sys.stdout, sys.stderr = capture_out, capture_err + yield capture_out, capture_err + finally: + sys.stdout, sys.stderr = current_out, current_err + + +class UsageTest(BaseTestCase): + def test_no_args(self): + with self.assertRaises(SystemExit): + with capture_sys_output() as (stdout, _): + parse_args([]) + + self.maxDiff = None + + EXPECTED = """usage: python -m pyt [-h] [-f FILEPATH] [-a ADAPTOR] [-pr PROJECT_ROOT] + [-b BASELINE_JSON_FILE] [-j] [-m BLACKBOX_MAPPING_FILE] + [-t TRIGGER_WORD_FILE] [-o OUTPUT_FILE] [--ignore-nosec] + [-trim] [-i] + +required arguments: + -f FILEPATH, --filepath FILEPATH + Path to the file that should be analysed. + +optional arguments: + -a ADAPTOR, --adaptor ADAPTOR + Choose a web framework adaptor: Flask(Default), + Django, Every or Pylons + -pr PROJECT_ROOT, --project-root PROJECT_ROOT + Add project root, only important when the entry file + is not at the root of the project. + -b BASELINE_JSON_FILE, --baseline BASELINE_JSON_FILE + Path of a baseline report to compare against (only + JSON-formatted files are accepted) + -j, --json Prints JSON instead of report. + -m BLACKBOX_MAPPING_FILE, --blackbox-mapping-file BLACKBOX_MAPPING_FILE + Input blackbox mapping file. + -t TRIGGER_WORD_FILE, --trigger-word-file TRIGGER_WORD_FILE + Input file with a list of sources and sinks + -o OUTPUT_FILE, --output OUTPUT_FILE + write report to filename + --ignore-nosec do not skip lines with # nosec comments + +print arguments: + -trim, --trim-reassigned-in + Trims the reassigned list to just the vulnerability + chain. + -i, --interactive Will ask you about each blackbox function call in + vulnerability chains.\n""" + + self.assertEqual(stdout.getvalue(), EXPECTED) + + def test_valid_args_but_no_filepath(self): + with self.assertRaises(SystemExit): + with capture_sys_output() as (_, stderr): + parse_args(['-j']) + + EXPECTED = """usage: python -m pyt [-h] [-f FILEPATH] [-a ADAPTOR] [-pr PROJECT_ROOT] + [-b BASELINE_JSON_FILE] [-j] [-m BLACKBOX_MAPPING_FILE] + [-t TRIGGER_WORD_FILE] [-o OUTPUT_FILE] [--ignore-nosec] + [-trim] [-i] +python -m pyt: error: The -f/--filepath argument is required\n""" + + self.assertEqual(stderr.getvalue(), EXPECTED) + +# def test_using_both_mutually_exclusive_args(self): +# with self.assertRaises(SystemExit): +# with capture_sys_output() as (_, stderr): +# parse_args(['-f', 'foo.py', '-trim', '--interactive']) + +# EXPECTED = """usage: python -m pyt [-h] [-f FILEPATH] [-a ADAPTOR] [-pr PROJECT_ROOT] +# [-b BASELINE_JSON_FILE] [-j] [-m BLACKBOX_MAPPING_FILE] +# [-t TRIGGER_WORD_FILE] [-o OUTPUT_FILE] [-trim] [-i] +# python -m pyt: error: argument -i/--interactive: not allowed with argument -trim/--trim-reassigned-in\n""" + +# self.assertEqual(stderr.getvalue(), EXPECTED) + + def test_normal_usage(self): + with capture_sys_output() as (stdout, stderr): + parse_args(['-f', 'foo.py']) + + self.assertEqual(stdout.getvalue(), '') + self.assertEqual(stderr.getvalue(), '') diff --git a/tests/vulnerabilities/__init__.py b/tests/vulnerabilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/vulnerabilities_across_files_test.py b/tests/vulnerabilities/vulnerabilities_across_files_test.py similarity index 86% rename from tests/vulnerabilities_across_files_test.py rename to tests/vulnerabilities/vulnerabilities_across_files_test.py index 7492aee..d8bd384 100644 --- a/tests/vulnerabilities_across_files_test.py +++ b/tests/vulnerabilities/vulnerabilities_across_files_test.py @@ -1,22 +1,28 @@ import os -from .base_test_case import BaseTestCase -from pyt.argument_helpers import ( +from .vulnerabilities_base_test_case import VulnerabilitiesBaseTestCase + +from pyt.analysis.constraint_table import initialize_constraint_table +from pyt.analysis.fixed_point import analyse +from pyt.core.project_handler import ( + get_directory_modules, + get_modules +) +from pyt.usage import ( default_blackbox_mapping_file, - default_trigger_word_file, - UImode, - VulnerabilityFiles + default_trigger_word_file +) +from pyt.vulnerabilities import ( + find_vulnerabilities, + UImode +) +from pyt.web_frameworks import ( + FrameworkAdaptor, + is_flask_route_function ) -from pyt.constraint_table import initialize_constraint_table -from pyt.fixed_point import analyse -from pyt.framework_adaptor import FrameworkAdaptor -from pyt.framework_helper import is_flask_route_function -from pyt.project_handler import get_directory_modules, get_modules -from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis -from pyt.vulnerabilities import find_vulnerabilities -class EngineTest(BaseTestCase): +class EngineTest(VulnerabilitiesBaseTestCase): def run_analysis(self, path): path = os.path.normpath(path) @@ -31,16 +37,13 @@ def run_analysis(self, path): initialize_constraint_table(cfg_list) - analyse(cfg_list, analysis_type=ReachingDefinitionsTaintAnalysis) + analyse(cfg_list) return find_vulnerabilities( cfg_list, - ReachingDefinitionsTaintAnalysis, UImode.NORMAL, - VulnerabilityFiles( - default_blackbox_mapping_file, - default_trigger_word_file - ) + default_blackbox_mapping_file, + default_trigger_word_file ) def test_find_vulnerabilities_absolute_from_file_command_injection(self): @@ -62,7 +65,7 @@ def test_blackbox_library_call(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code_across_files/blackbox_library_call.py - > User input at line 12, trigger word "request.args.get(": + > User input at line 12, source "request.args.get(": ~call_1 = ret_request.args.get('suggestion') Reassigned in: File: examples/vulnerable_code_across_files/blackbox_library_call.py @@ -74,7 +77,7 @@ def test_blackbox_library_call(self): File: examples/vulnerable_code_across_files/blackbox_library_call.py > Line 16: hey = command File: examples/vulnerable_code_across_files/blackbox_library_call.py - > reaches line 17, trigger word "subprocess.call(": + > reaches line 17, sink "subprocess.call(": ~call_3 = ret_subprocess.call(hey, shell=True) This vulnerability is unknown due to: Label: ~call_2 = ret_scrypt.encrypt('echo ' + param + ' >> ' + 'menu.txt', 'password') """ @@ -87,7 +90,7 @@ def test_builtin_with_user_defined_inner(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/nested_functions_code/builtin_with_user_defined_inner.py - > User input at line 16, trigger word "form[": + > User input at line 16, source "form[": req_param = request.form['suggestion'] Reassigned in: File: examples/nested_functions_code/builtin_with_user_defined_inner.py @@ -109,7 +112,7 @@ def test_builtin_with_user_defined_inner(self): File: examples/nested_functions_code/builtin_with_user_defined_inner.py > Line 19: foo = ~call_1 File: examples/nested_functions_code/builtin_with_user_defined_inner.py - > reaches line 20, trigger word "subprocess.call(": + > reaches line 20, sink "subprocess.call(": ~call_3 = ret_subprocess.call(foo, shell=True) This vulnerability is unknown due to: Label: ~call_1 = ret_scrypt.encrypt(~call_2) """ @@ -121,7 +124,7 @@ def test_sink_with_result_of_blackbox_nested(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py - > User input at line 12, trigger word "form[": + > User input at line 12, source "form[": req_param = request.form['suggestion'] Reassigned in: File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py @@ -131,13 +134,13 @@ def test_sink_with_result_of_blackbox_nested(self): File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py > Line 13: result = ~call_1 File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py - > reaches line 14, trigger word "subprocess.call(": + > reaches line 14, sink "subprocess.call(": ~call_3 = ret_subprocess.call(result, shell=True) This vulnerability is unknown due to: Label: ~call_2 = ret_scrypt.encrypt(req_param) """ OTHER_EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py - > User input at line 12, trigger word "form[": + > User input at line 12, source "form[": req_param = request.form['suggestion'] Reassigned in: File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py @@ -147,13 +150,14 @@ def test_sink_with_result_of_blackbox_nested(self): File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py > Line 13: result = ~call_1 File: examples/nested_functions_code/sink_with_result_of_blackbox_nested.py - > reaches line 14, trigger word "subprocess.call(": + > reaches line 14, sink "subprocess.call(": ~call_3 = ret_subprocess.call(result, shell=True) This vulnerability is unknown due to: Label: ~call_1 = ret_scrypt.encrypt(~call_2) """ - self.assertTrue(self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION) - or - self.string_compare_alpha(vulnerability_description, OTHER_EXPECTED_VULNERABILITY_DESCRIPTION)) + self.assertTrue( + self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION) or + self.string_compare_alpha(vulnerability_description, OTHER_EXPECTED_VULNERABILITY_DESCRIPTION) + ) def test_sink_with_result_of_user_defined_nested(self): vulnerabilities = self.run_analysis('examples/nested_functions_code/sink_with_result_of_user_defined_nested.py') @@ -161,7 +165,7 @@ def test_sink_with_result_of_user_defined_nested(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/nested_functions_code/sink_with_result_of_user_defined_nested.py - > User input at line 16, trigger word "form[": + > User input at line 16, source "form[": req_param = request.form['suggestion'] Reassigned in: File: examples/nested_functions_code/sink_with_result_of_user_defined_nested.py @@ -195,7 +199,7 @@ def test_sink_with_result_of_user_defined_nested(self): File: examples/nested_functions_code/sink_with_result_of_user_defined_nested.py > Line 17: result = ~call_1 File: examples/nested_functions_code/sink_with_result_of_user_defined_nested.py - > reaches line 18, trigger word "subprocess.call(": + > reaches line 18, sink "subprocess.call(": ~call_3 = ret_subprocess.call(result, shell=True) """ self.assertTrue(self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION)) @@ -206,7 +210,7 @@ def test_sink_with_blackbox_inner(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/nested_functions_code/sink_with_blackbox_inner.py - > User input at line 12, trigger word "form[": + > User input at line 12, source "form[": req_param = request.form['suggestion'] Reassigned in: File: examples/nested_functions_code/sink_with_blackbox_inner.py @@ -214,14 +218,14 @@ def test_sink_with_blackbox_inner(self): File: examples/nested_functions_code/sink_with_blackbox_inner.py > Line 14: ~call_2 = ret_scrypt.encypt(~call_3) File: examples/nested_functions_code/sink_with_blackbox_inner.py - > reaches line 14, trigger word "subprocess.call(": + > reaches line 14, sink "subprocess.call(": ~call_1 = ret_subprocess.call(~call_2, shell=True) This vulnerability is unknown due to: Label: ~call_2 = ret_scrypt.encypt(~call_3) """ OTHER_EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/nested_functions_code/sink_with_blackbox_inner.py - > User input at line 12, trigger word "form[": + > User input at line 12, source "form[": req_param = request.form['suggestion'] Reassigned in: File: examples/nested_functions_code/sink_with_blackbox_inner.py @@ -229,13 +233,14 @@ def test_sink_with_blackbox_inner(self): File: examples/nested_functions_code/sink_with_blackbox_inner.py > Line 14: ~call_2 = ret_scrypt.encypt(~call_3) File: examples/nested_functions_code/sink_with_blackbox_inner.py - > reaches line 14, trigger word "subprocess.call(": + > reaches line 14, sink "subprocess.call(": ~call_1 = ret_subprocess.call(~call_2, shell=True) This vulnerability is unknown due to: Label: ~call_3 = ret_scrypt.encypt(req_param) """ - self.assertTrue(self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION) - or - self.string_compare_alpha(vulnerability_description, OTHER_EXPECTED_VULNERABILITY_DESCRIPTION)) + self.assertTrue( + self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION) or + self.string_compare_alpha(vulnerability_description, OTHER_EXPECTED_VULNERABILITY_DESCRIPTION) + ) def test_sink_with_user_defined_inner(self): vulnerabilities = self.run_analysis('examples/nested_functions_code/sink_with_user_defined_inner.py') @@ -243,7 +248,7 @@ def test_sink_with_user_defined_inner(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/nested_functions_code/sink_with_user_defined_inner.py - > User input at line 16, trigger word "form[": + > User input at line 16, source "form[": req_param = request.form['suggestion'] Reassigned in: File: examples/nested_functions_code/sink_with_user_defined_inner.py @@ -275,7 +280,7 @@ def test_sink_with_user_defined_inner(self): File: examples/nested_functions_code/sink_with_user_defined_inner.py > Line 18: ~call_2 = ret_outer File: examples/nested_functions_code/sink_with_user_defined_inner.py - > reaches line 18, trigger word "subprocess.call(": + > reaches line 18, sink "subprocess.call(": ~call_1 = ret_subprocess.call(~call_2, shell=True) """ self.assertTrue(self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION)) diff --git a/tests/vulnerabilities/vulnerabilities_base_test_case.py b/tests/vulnerabilities/vulnerabilities_base_test_case.py new file mode 100644 index 0000000..dcf088a --- /dev/null +++ b/tests/vulnerabilities/vulnerabilities_base_test_case.py @@ -0,0 +1,10 @@ +from ..base_test_case import BaseTestCase + + +class VulnerabilitiesBaseTestCase(BaseTestCase): + + def string_compare_alpha(self, output, expected_string): + return ( + [char for char in output if char.isalpha()] == + [char for char in expected_string if char.isalpha()] + ) diff --git a/tests/vulnerabilities_test.py b/tests/vulnerabilities/vulnerabilities_test.py similarity index 87% rename from tests/vulnerabilities_test.py rename to tests/vulnerabilities/vulnerabilities_test.py index f3d7727..5e40a60 100644 --- a/tests/vulnerabilities_test.py +++ b/tests/vulnerabilities/vulnerabilities_test.py @@ -1,39 +1,31 @@ import os -from .base_test_case import BaseTestCase +from .vulnerabilities_base_test_case import VulnerabilitiesBaseTestCase -from pyt import ( - trigger_definitions_parser, - vulnerabilities -) -from pyt.argument_helpers import ( +from pyt.analysis.constraint_table import initialize_constraint_table +from pyt.analysis.fixed_point import analyse +from pyt.core.node_types import Node +from pyt.usage import ( default_blackbox_mapping_file, - default_trigger_word_file, + default_trigger_word_file +) +from pyt.vulnerabilities import ( + find_vulnerabilities, + trigger_definitions_parser, UImode, - VulnerabilityFiles + vulnerabilities ) -from pyt.constraint_table import initialize_constraint_table -from pyt.fixed_point import analyse -from pyt.framework_adaptor import FrameworkAdaptor -from pyt.framework_helper import ( +from pyt.web_frameworks import ( + FrameworkAdaptor, is_django_view_function, is_flask_route_function, is_function ) -from pyt.node_types import Node -from pyt.reaching_definitions_taint import ReachingDefinitionsTaintAnalysis - -class EngineTest(BaseTestCase): - def run_empty(self): - return - - def get_lattice_elements(self, cfg_nodes): - """Dummy analysis method""" - return cfg_nodes +class EngineTest(VulnerabilitiesBaseTestCase): def test_parse(self): - definitions = vulnerabilities.parse( + definitions = trigger_definitions_parser.parse( trigger_word_file=os.path.join( os.getcwd(), 'pyt', @@ -93,7 +85,11 @@ def test_find_triggers(self): XSS1 = cfg_list[1] trigger_words = [('get', [])] - l = vulnerabilities.find_triggers(XSS1.nodes, trigger_words) + l = vulnerabilities.find_triggers( + XSS1.nodes, + trigger_words, + nosec_lines=set() + ) self.assert_length(l, expected_length=1) def test_find_sanitiser_nodes(self): @@ -129,16 +125,13 @@ def run_analysis(self, path): FrameworkAdaptor(cfg_list, [], [], is_flask_route_function) initialize_constraint_table(cfg_list) - analyse(cfg_list, analysis_type=ReachingDefinitionsTaintAnalysis) + analyse(cfg_list) - return vulnerabilities.find_vulnerabilities( + return find_vulnerabilities( cfg_list, - ReachingDefinitionsTaintAnalysis, UImode.NORMAL, - VulnerabilityFiles( - default_blackbox_mapping_file, - default_trigger_word_file - ) + default_blackbox_mapping_file, + default_trigger_word_file ) def test_find_vulnerabilities_assign_other_var(self): @@ -159,7 +152,7 @@ def test_XSS_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/XSS.py - > User input at line 6, trigger word "request.args.get(": + > User input at line 6, source "request.args.get(": ~call_1 = ret_request.args.get('param', 'not set') Reassigned in: File: examples/vulnerable_code/XSS.py @@ -171,7 +164,7 @@ def test_XSS_result(self): File: examples/vulnerable_code/XSS.py > Line 10: ret_XSS1 = resp File: examples/vulnerable_code/XSS.py - > reaches line 9, trigger word "replace(": + > reaches line 9, sink "replace(": ~call_4 = ret_html.replace('{{ param }}', param) """ @@ -183,13 +176,13 @@ def test_command_injection_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/command_injection.py - > User input at line 15, trigger word "form[": + > User input at line 15, source "form[": param = request.form['suggestion'] Reassigned in: File: examples/vulnerable_code/command_injection.py > Line 16: command = 'echo ' + param + ' >> ' + 'menu.txt' File: examples/vulnerable_code/command_injection.py - > reaches line 18, trigger word "subprocess.call(": + > reaches line 18, sink "subprocess.call(": ~call_1 = ret_subprocess.call(command, shell=True) """ @@ -201,7 +194,7 @@ def test_path_traversal_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/path_traversal.py - > User input at line 15, trigger word "request.args.get(": + > User input at line 15, source "request.args.get(": ~call_1 = ret_request.args.get('image_name') Reassigned in: File: examples/vulnerable_code/path_traversal.py @@ -227,7 +220,7 @@ def test_path_traversal_result(self): File: examples/vulnerable_code/path_traversal.py > Line 19: foo = ~call_2 File: examples/vulnerable_code/path_traversal.py - > reaches line 20, trigger word "send_file(": + > reaches line 20, sink "send_file(": ~call_4 = ret_send_file(foo) """ @@ -239,7 +232,7 @@ def test_ensure_saved_scope(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/ensure_saved_scope.py - > User input at line 15, trigger word "request.args.get(": + > User input at line 15, source "request.args.get(": ~call_1 = ret_request.args.get('image_name') Reassigned in: File: examples/vulnerable_code/ensure_saved_scope.py @@ -265,7 +258,7 @@ def test_ensure_saved_scope(self): File: examples/vulnerable_code/ensure_saved_scope.py > Line 19: foo = ~call_2 File: examples/vulnerable_code/ensure_saved_scope.py - > reaches line 20, trigger word "send_file(": + > reaches line 20, sink "send_file(": ~call_4 = ret_send_file(image_name) """ @@ -277,7 +270,7 @@ def test_path_traversal_sanitised_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/path_traversal_sanitised.py - > User input at line 8, trigger word "request.args.get(": + > User input at line 8, source "request.args.get(": ~call_1 = ret_request.args.get('image_name') Reassigned in: File: examples/vulnerable_code/path_traversal_sanitised.py @@ -291,7 +284,7 @@ def test_path_traversal_sanitised_result(self): File: examples/vulnerable_code/path_traversal_sanitised.py > Line 12: ret_cat_picture = ~call_3 File: examples/vulnerable_code/path_traversal_sanitised.py - > reaches line 12, trigger word "send_file(": + > reaches line 12, sink "send_file(": ~call_3 = ret_send_file(~call_4) This vulnerability is sanitised by: Label: ~call_2 = ret_image_name.replace('..', '') """ @@ -304,7 +297,7 @@ def test_path_traversal_sanitised_2_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/path_traversal_sanitised_2.py - > User input at line 8, trigger word "request.args.get(": + > User input at line 8, source "request.args.get(": ~call_1 = ret_request.args.get('image_name') Reassigned in: File: examples/vulnerable_code/path_traversal_sanitised_2.py @@ -314,7 +307,7 @@ def test_path_traversal_sanitised_2_result(self): File: examples/vulnerable_code/path_traversal_sanitised_2.py > Line 12: ret_cat_picture = ~call_2 File: examples/vulnerable_code/path_traversal_sanitised_2.py - > reaches line 12, trigger word "send_file(": + > reaches line 12, sink "send_file(": ~call_2 = ret_send_file(~call_3) This vulnerability is potentially sanitised by: Label: if '..' in image_name: """ @@ -327,7 +320,7 @@ def test_sql_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/sql/sqli.py - > User input at line 26, trigger word "request.args.get(": + > User input at line 26, source "request.args.get(": ~call_1 = ret_request.args.get('param', 'not set') Reassigned in: File: examples/vulnerable_code/sql/sqli.py @@ -335,7 +328,7 @@ def test_sql_result(self): File: examples/vulnerable_code/sql/sqli.py > Line 27: result = ~call_2 File: examples/vulnerable_code/sql/sqli.py - > reaches line 27, trigger word "execute(": + > reaches line 27, sink "execute(": ~call_2 = ret_db.engine.execute(param) """ @@ -347,7 +340,7 @@ def test_XSS_form_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/XSS_form.py - > User input at line 14, trigger word "form[": + > User input at line 14, source "form[": data = request.form['my_text'] Reassigned in: File: examples/vulnerable_code/XSS_form.py @@ -357,7 +350,7 @@ def test_XSS_form_result(self): File: examples/vulnerable_code/XSS_form.py > Line 17: ret_example2_action = resp File: examples/vulnerable_code/XSS_form.py - > reaches line 15, trigger word "replace(": + > reaches line 15, sink "replace(": ~call_2 = ret_html1.replace('{{ data }}', data) """ @@ -369,7 +362,7 @@ def test_XSS_url_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/XSS_url.py - > User input at line 4, trigger word "Framework function URL parameter": + > User input at line 4, source "Framework function URL parameter": url Reassigned in: File: examples/vulnerable_code/XSS_url.py @@ -381,7 +374,7 @@ def test_XSS_url_result(self): File: examples/vulnerable_code/XSS_url.py > Line 10: ret_XSS1 = resp File: examples/vulnerable_code/XSS_url.py - > reaches line 9, trigger word "replace(": + > reaches line 9, sink "replace(": ~call_3 = ret_html.replace('{{ param }}', param) """ @@ -397,7 +390,7 @@ def test_XSS_reassign_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/XSS_reassign.py - > User input at line 6, trigger word "request.args.get(": + > User input at line 6, source "request.args.get(": ~call_1 = ret_request.args.get('param', 'not set') Reassigned in: File: examples/vulnerable_code/XSS_reassign.py @@ -411,7 +404,7 @@ def test_XSS_reassign_result(self): File: examples/vulnerable_code/XSS_reassign.py > Line 12: ret_XSS1 = resp File: examples/vulnerable_code/XSS_reassign.py - > reaches line 11, trigger word "replace(": + > reaches line 11, sink "replace(": ~call_4 = ret_html.replace('{{ param }}', param) """ @@ -423,7 +416,7 @@ def test_XSS_sanitised_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/XSS_sanitised.py - > User input at line 7, trigger word "request.args.get(": + > User input at line 7, source "request.args.get(": ~call_1 = ret_request.args.get('param', 'not set') Reassigned in: File: examples/vulnerable_code/XSS_sanitised.py @@ -439,7 +432,7 @@ def test_XSS_sanitised_result(self): File: examples/vulnerable_code/XSS_sanitised.py > Line 13: ret_XSS1 = resp File: examples/vulnerable_code/XSS_sanitised.py - > reaches line 12, trigger word "replace(": + > reaches line 12, sink "replace(": ~call_5 = ret_html.replace('{{ param }}', param) This vulnerability is sanitised by: Label: ~call_2 = ret_Markup.escape(param) """ @@ -456,7 +449,7 @@ def test_XSS_variable_assign_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/XSS_variable_assign.py - > User input at line 6, trigger word "request.args.get(": + > User input at line 6, source "request.args.get(": ~call_1 = ret_request.args.get('param', 'not set') Reassigned in: File: examples/vulnerable_code/XSS_variable_assign.py @@ -470,7 +463,7 @@ def test_XSS_variable_assign_result(self): File: examples/vulnerable_code/XSS_variable_assign.py > Line 12: ret_XSS1 = resp File: examples/vulnerable_code/XSS_variable_assign.py - > reaches line 11, trigger word "replace(": + > reaches line 11, sink "replace(": ~call_4 = ret_html.replace('{{ param }}', other_var) """ @@ -482,7 +475,7 @@ def test_XSS_variable_multiple_assign_result(self): vulnerability_description = str(vulnerabilities[0]) EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/XSS_variable_multiple_assign.py - > User input at line 6, trigger word "request.args.get(": + > User input at line 6, source "request.args.get(": ~call_1 = ret_request.args.get('param', 'not set') Reassigned in: File: examples/vulnerable_code/XSS_variable_multiple_assign.py @@ -500,17 +493,14 @@ def test_XSS_variable_multiple_assign_result(self): File: examples/vulnerable_code/XSS_variable_multiple_assign.py > Line 17: ret_XSS1 = resp File: examples/vulnerable_code/XSS_variable_multiple_assign.py - > reaches line 15, trigger word "replace(": + > reaches line 15, sink "replace(": ~call_4 = ret_html.replace('{{ param }}', another_one) """ self.assertTrue(self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION)) -class EngineDjangoTest(BaseTestCase): - def run_empty(self): - return - +class EngineDjangoTest(VulnerabilitiesBaseTestCase): def run_analysis(self, path): self.cfg_create_from_file(path) cfg_list = [self.cfg] @@ -518,7 +508,7 @@ def run_analysis(self, path): FrameworkAdaptor(cfg_list, [], [], is_django_view_function) initialize_constraint_table(cfg_list) - analyse(cfg_list, analysis_type=ReachingDefinitionsTaintAnalysis) + analyse(cfg_list) trigger_word_file = os.path.join( 'pyt', @@ -526,14 +516,11 @@ def run_analysis(self, path): 'django_trigger_words.pyt' ) - return vulnerabilities.find_vulnerabilities( + return find_vulnerabilities( cfg_list, - ReachingDefinitionsTaintAnalysis, UImode.NORMAL, - VulnerabilityFiles( - default_blackbox_mapping_file, - trigger_word_file - ) + default_blackbox_mapping_file, + trigger_word_file ) def test_django_view_param(self): @@ -543,22 +530,19 @@ def test_django_view_param(self): EXPECTED_VULNERABILITY_DESCRIPTION = """ File: examples/vulnerable_code/django_XSS.py - > User input at line 4, trigger word "Framework function URL parameter": + > User input at line 4, source "Framework function URL parameter": param Reassigned in: File: examples/vulnerable_code/django_XSS.py > Line 5: ret_xss1 = ~call_1 File: examples/vulnerable_code/django_XSS.py - > reaches line 5, trigger word "render(": + > reaches line 5, sink "render(": ~call_1 = ret_render(request, 'templates/xss.html', 'param'param) """ self.assertTrue(self.string_compare_alpha(vulnerability_description, EXPECTED_VULNERABILITY_DESCRIPTION)) -class EngineEveryTest(BaseTestCase): - def run_empty(self): - return - +class EngineEveryTest(VulnerabilitiesBaseTestCase): def run_analysis(self, path): self.cfg_create_from_file(path) cfg_list = [self.cfg] @@ -566,7 +550,7 @@ def run_analysis(self, path): FrameworkAdaptor(cfg_list, [], [], is_function) initialize_constraint_table(cfg_list) - analyse(cfg_list, analysis_type=ReachingDefinitionsTaintAnalysis) + analyse(cfg_list) trigger_word_file = os.path.join( 'pyt', @@ -574,14 +558,11 @@ def run_analysis(self, path): 'all_trigger_words.pyt' ) - return vulnerabilities.find_vulnerabilities( + return find_vulnerabilities( cfg_list, - ReachingDefinitionsTaintAnalysis, UImode.NORMAL, - VulnerabilityFiles( - default_blackbox_mapping_file, - trigger_word_file - ) + default_blackbox_mapping_file, + trigger_word_file ) def test_self_is_not_tainted(self): diff --git a/tests/web_frameworks/__init__.py b/tests/web_frameworks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/framework_helper_test.py b/tests/web_frameworks/framework_helper_test.py similarity index 89% rename from tests/framework_helper_test.py rename to tests/web_frameworks/framework_helper_test.py index 9086e55..5ee8d21 100644 --- a/tests/framework_helper_test.py +++ b/tests/web_frameworks/framework_helper_test.py @@ -1,10 +1,11 @@ -from .base_test_case import BaseTestCase -from pyt.framework_adaptor import _get_func_nodes -from pyt.framework_helper import ( +from ..base_test_case import BaseTestCase + +from pyt.web_frameworks import ( is_django_view_function, is_flask_route_function, is_function, is_function_without_leading_, + _get_func_nodes ) @@ -12,7 +13,6 @@ class FrameworkEngineTest(BaseTestCase): def test_find_flask_functions(self): self.cfg_create_from_file('examples/example_inputs/django_flask_and_normal_functions.py') - cfg_list = [self.cfg] funcs = _get_func_nodes() i = 0 @@ -23,11 +23,9 @@ def test_find_flask_functions(self): # So it is supposed to be 1, because foo is not an app.route self.assertEqual(i, 1) - def test_find_every_function_without_leading_underscore(self): self.cfg_create_from_file('examples/example_inputs/django_flask_and_normal_functions.py') - cfg_list = [self.cfg] funcs = _get_func_nodes() i = 0 @@ -40,7 +38,6 @@ def test_find_every_function_without_leading_underscore(self): def test_find_every_function(self): self.cfg_create_from_file('examples/example_inputs/django_flask_and_normal_functions.py') - cfg_list = [self.cfg] funcs = _get_func_nodes() i = 0 @@ -53,7 +50,6 @@ def test_find_every_function(self): def test_find_django_functions(self): self.cfg_create_from_file('examples/example_inputs/django_flask_and_normal_functions.py') - cfg_list = [self.cfg] funcs = _get_func_nodes() i = 0 @@ -67,7 +63,6 @@ def test_find_django_functions(self): def test_find_django_views(self): self.cfg_create_from_file('examples/example_inputs/django_views.py') - cfg_list = [self.cfg] funcs = _get_func_nodes() i = 0 diff --git a/tox.ini b/tox.ini index 7039314..f85b156 100644 --- a/tox.ini +++ b/tox.ini @@ -7,5 +7,6 @@ deps = -rrequirements-dev.txt commands = coverage erase coverage run tests - coverage report --show-missing --fail-under 89 + coverage report --include=tests/* --fail-under 100 + coverage report --include=pyt/* --fail-under 88 pre-commit run