1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
//! Helper functions for common tasks utilizing extern symbols,
//! e.g. searching for calls to a specific extern symbol.

use std::collections::{HashMap, HashSet};

use crate::intermediate_representation::*;

/// Look up an extern symbol by name and return its TID together with its name.
///
/// The extern symbols of `prog` are scanned in iteration order and the first
/// symbol whose name equals `name` is returned. Both returned references
/// borrow from `prog`, not from the `name` argument.
///
/// Returns `None` if no extern symbol with the given name exists.
pub fn find_symbol<'a>(prog: &'a Term<Program>, name: &str) -> Option<(&'a Tid, &'a str)> {
    prog.term
        .extern_symbols
        .iter()
        // Pure predicate: no state is mutated while searching.
        .find(|(_tid, sym)| name == sym.name)
        .map(|(_tid, sym)| (&sym.tid, sym.name.as_str()))
}

/// Collect all direct calls in `sub` whose target is one of the given external symbols.
///
/// Every `Jmp::Call` jump in every block of the subroutine is inspected.
/// If its target TID is a key of `symbols`, a triple of
/// (caller name, callsite TID, callee name) is added to the result,
/// where the caller name is the name of `sub`, the callsite TID is the TID
/// of the jump term and the callee name is the value stored in `symbols`
/// for the target TID.
///
/// The returned vector lists the callsites in the order in which they are
/// encountered (blocks in order, jumps in order inside each block).
/// It is empty if no matching call exists.
pub fn get_calls_to_symbols<'a, 'b>(
    sub: &'a Term<Sub>,
    symbols: &'b HashMap<&'a Tid, &'a str>,
) -> Vec<(&'a str, &'a Tid, &'a str)> {
    // The caller name is the same for every callsite of this subroutine.
    let caller: &'a str = sub.term.name.as_str();
    sub.term
        .blocks
        .iter()
        .flat_map(|blk| blk.term.jmps.iter())
        .filter_map(|jmp| match &jmp.term {
            // A single map lookup both tests membership and yields the callee name.
            Jmp::Call { target, .. } => symbols
                .get(target)
                .map(|callee| (caller, &jmp.tid, *callee)),
            _ => None,
        })
        .collect()
}

/// Build a map from TIDs to the extern symbols with the requested names.
///
/// For every name in `symbols_to_find` the extern symbols of the program are
/// scanned and the first symbol with that name (if any) is added to the map,
/// keyed by a clone of its TID. Names without a matching extern symbol are
/// silently skipped.
///
/// Each requested name triggers its own scan, so the cost is
/// O(|symbols_to_find| x |extern_symbols|). Prefer [`get_symbol_map_fast`]
/// if speed matters.
pub fn get_symbol_map<'a>(
    project: &'a Project,
    symbols_to_find: &[String],
) -> HashMap<Tid, &'a ExternSymbol> {
    let extern_symbols = &project.program.term.extern_symbols;
    symbols_to_find
        .iter()
        .filter_map(|wanted| {
            extern_symbols
                .iter()
                // Stop at the first extern symbol carrying the requested name.
                .find(|(_tid, candidate)| candidate.name == *wanted)
                .map(|(_tid, candidate)| (candidate.tid.clone(), candidate))
        })
        .collect()
}

/// Build a map from TIDs to the extern symbols with the requested names.
///
/// The extern symbols of the program are traversed once. Every symbol whose
/// name is contained in `symbols_to_find` is added to the map, keyed by a
/// clone of its TID.
///
/// Because the name check is a set lookup, this is O(|extern_symbols|)
/// instead of the O(|symbols_to_find|x|extern_symbols|) of
/// [`get_symbol_map`]. Prefer this function if `symbols_to_find` is huge.
pub fn get_symbol_map_fast<'a>(
    project: &'a Project,
    symbols_to_find: &HashSet<String>,
) -> HashMap<Tid, &'a ExternSymbol> {
    let extern_symbols = &project.program.term.extern_symbols;
    extern_symbols
        .iter()
        // Keep only the symbols whose name was requested ...
        .filter(|(_tid, symbol)| symbols_to_find.contains(&symbol.name))
        // ... and key each of them by its own TID.
        .map(|(_tid, symbol)| (symbol.tid.clone(), symbol))
        .collect()
}

/// Collect all calls in `sub` that target a TID contained as a key in `symbol_map`.
///
/// Every `Jmp::Call` jump in every block of the subroutine is checked.
/// For each jump whose target TID is a key of the map, the result contains
/// a triple of
/// - the block containing the call,
/// - the jump term representing the call itself,
/// - the extern symbol stored in `symbol_map` for the target TID.
///
/// The callsites are returned in the order in which they are encountered
/// (blocks in order, jumps in order inside each block).
pub fn get_callsites<'a>(
    sub: &'a Term<Sub>,
    symbol_map: &HashMap<Tid, &'a ExternSymbol>,
) -> Vec<(&'a Term<Blk>, &'a Term<Jmp>, &'a ExternSymbol)> {
    sub.term
        .blocks
        .iter()
        // Pair every jump with the block it belongs to.
        .flat_map(|blk| blk.term.jmps.iter().map(move |jmp| (blk, jmp)))
        .filter_map(|(blk, jmp)| match &jmp.term {
            Jmp::Call { target, .. } => symbol_map
                .get(target)
                .map(|symbol| (blk, jmp, *symbol)),
            _ => None,
        })
        .collect()
}