use super::fixpoint::Context as GeneralFPContext;
use super::graph::*;
use super::interprocedural_fixpoint_generic::*;
use crate::intermediate_representation::*;
use petgraph::graph::EdgeIndex;
use petgraph::graph::NodeIndex;
use std::marker::PhantomData;
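
/// The `Context` trait provides all the information needed to compute a forward
/// interprocedural fixpoint analysis:
/// the underlying control flow graph together with the transfer functions that
/// describe how an abstract value changes along `Def` terms, jumps, calls,
/// returns and stubs for calls to external symbols.
///
/// Implement it for a type holding the state of your analysis and pass the object
/// to [`create_computation`] (or one of its worklist-order variants) to obtain a
/// runnable fixpoint computation.
///
/// As a rough illustration, a trivial reachability analysis could implement the
/// trait as sketched below. The `ReachabilityContext` type is purely illustrative
/// and not part of this module: every transfer function simply passes the (empty)
/// value through, so a node ends up with a value if and only if it is reachable
/// from the nodes that were given start values.
///
/// ```ignore
/// struct ReachabilityContext<'a> {
///     graph: &'a Graph<'a>,
/// }
///
/// impl<'a> Context<'a> for ReachabilityContext<'a> {
///     // The value carries no information; a node is reachable iff it has a value.
///     type Value = ();
///
///     fn get_graph(&self) -> &Graph<'a> { self.graph }
///     fn merge(&self, _: &(), _: &()) -> Self::Value {}
///     fn update_def(&self, _: &(), _: &Term<Def>) -> Option<()> { Some(()) }
///     fn update_jump(&self, _: &(), _: &Term<Jmp>, _: Option<&Term<Jmp>>, _: &Term<Blk>) -> Option<()> { Some(()) }
///     fn update_call(&self, _: &(), _: &Term<Jmp>, _: &Node, _: &Option<String>) -> Option<()> { Some(()) }
///     fn update_return(&self, value: Option<&()>, _: Option<&()>, _: &Term<Jmp>, _: &Term<Jmp>, _: &Option<String>) -> Option<()> { value.copied() }
///     fn update_call_stub(&self, _: &(), _: &Term<Jmp>) -> Option<()> { Some(()) }
///     fn specialize_conditional(&self, _: &(), _: &Expression, _: &Term<Blk>, _: bool) -> Option<()> { Some(()) }
/// }
/// ```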
pub trait Context<'a> {
    /// The type of the abstract values that the analysis propagates along the graph.
    type Value: PartialEq + Eq + Clone;

    /// Get a reference to the control flow graph that the fixpoint is computed on.
    fn get_graph(&self) -> &Graph<'a>;

    /// Merge two values into one, e.g. at nodes where several control flow paths meet.
    fn merge(&self, value1: &Self::Value, value2: &Self::Value) -> Self::Value;

    /// Transition function for `Def` terms.
    /// Returns `None` if the state after the `Def` term is unreachable for the given value.
    fn update_def(&self, value: &Self::Value, def: &Term<Def>) -> Option<Self::Value>;

    /// Transition function for intraprocedural jumps to the given `target` block.
    /// If the jump is only taken because a preceding conditional jump was *not* taken,
    /// that conditional jump is given as `untaken_conditional`.
    fn update_jump(
        &self,
        value: &Self::Value,
        jump: &Term<Jmp>,
        untaken_conditional: Option<&Term<Jmp>>,
        target: &Term<Blk>,
    ) -> Option<Self::Value>;

    /// Transition function for calls to subroutines contained in the binary:
    /// compute the value at the start of the called subroutine (the `target` node)
    /// from the value at the call instruction.
    fn update_call(
        &self,
        value: &Self::Value,
        call: &Term<Jmp>,
        target: &Node,
        calling_convention: &Option<String>,
    ) -> Option<Self::Value>;

    /// Transition function for return instructions:
    /// combine the value at the end of the called subroutine (`value`)
    /// with the caller-side value before the call (`value_before_call`)
    /// to obtain the value at the return site of the call.
    fn update_return(
        &self,
        value: Option<&Self::Value>,
        value_before_call: Option<&Self::Value>,
        call_term: &Term<Jmp>,
        return_term: &Term<Jmp>,
        calling_convention: &Option<String>,
    ) -> Option<Self::Value>;

    /// Transition function for calls to subroutines not contained in the binary,
    /// i.e. calls to external symbols.
    fn update_call_stub(&self, value: &Self::Value, call: &Term<Jmp>) -> Option<Self::Value>;

    /// Specialize the value at the end of `block_before_condition`
    /// under the assumption that the given `condition` evaluates to `is_true`.
    /// This allows the analysis to refine values along the taken
    /// and the untaken branch of a conditional jump.
    fn specialize_conditional(
        &self,
        value: &Self::Value,
        condition: &Expression,
        block_before_condition: &Term<Blk>,
        is_true: bool,
    ) -> Option<Self::Value>;
}
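
/// A `GeneralizedContext` wraps a [`Context`] object
/// and implements the context trait of the generic fixpoint module on top of it,
/// so that the generic fixpoint algorithm can be used to solve the interprocedural problem.
/// The wrapped context object can be accessed through [`GeneralizedContext::get_context`].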
pub struct GeneralizedContext<'a, T: Context<'a>> {
    /// The wrapped context object.
    context: T,
    /// Phantom data to tie the struct to the lifetime of the control flow graph.
    _phantom_graph_reference: PhantomData<Graph<'a>>,
}
impl<'a, T: Context<'a>> GeneralizedContext<'a, T> {
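    /// Create a new generalized context that wraps the given context object.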
pub fn new(context: T) -> Self {
GeneralizedContext {
context,
_phantom_graph_reference: PhantomData,
}
}
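
    /// Get a reference to the wrapped context object.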
pub fn get_context(&self) -> &T {
&self.context
}
}
impl<'a, T: Context<'a>> GeneralFPContext for GeneralizedContext<'a, T> {
type EdgeLabel = Edge<'a>;
type NodeLabel = Node<'a>;
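    /// The node values of the generalized fixpoint problem:
    /// either a plain value or, at the artificial call-return combination nodes,
    /// a `CallFlowCombinator` that tracks the caller-side value and the value
    /// returned from the call separately.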
type NodeValue = NodeValue<T::Value>;
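
    /// Get a reference to the control flow graph from the wrapped context object.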
fn get_graph(&self) -> &Graph<'a> {
self.context.get_graph()
}
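
    /// Merge two node values.
    /// `CallFlowCombinator` values are merged componentwise.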
fn merge(&self, val1: &Self::NodeValue, val2: &Self::NodeValue) -> Self::NodeValue {
use NodeValue::*;
match (val1, val2) {
(Value(value1), Value(value2)) => Value(self.context.merge(value1, value2)),
(
CallFlowCombinator {
call_stub: call1,
interprocedural_flow: return1,
},
CallFlowCombinator {
call_stub: call2,
interprocedural_flow: return2,
},
) => CallFlowCombinator {
call_stub: merge_option(call1, call2, |v1, v2| self.context.merge(v1, v2)),
interprocedural_flow: merge_option(return1, return2, |v1, v2| {
self.context.merge(v1, v2)
}),
},
_ => panic!("Malformed CFG in fixpoint computation"),
}
}
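
    /// Forward edge transition function:
    /// compute the value at the end node of an edge from the value at its start node.
    /// Returns `None` if the end node is not reachable via this edge for the given value.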
fn update_edge(
&self,
node_value: &Self::NodeValue,
edge: EdgeIndex,
) -> Option<Self::NodeValue> {
let graph = self.context.get_graph();
let (start_node, end_node) = graph.edge_endpoints(edge).unwrap();
match graph.edge_weight(edge).unwrap() {
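            // A block edge connects the start and the end node of a basic block:
            // apply `update_def` to every `Def` term of the block in order.
            // If any `update_def` returns `None`, the end of the block is unreachable for this value.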
Edge::Block => {
let block_term = graph.node_weight(start_node).unwrap().get_block();
let value = node_value.unwrap_value();
let defs = &block_term.term.defs;
let end_val = defs.iter().try_fold(value.clone(), |accum, def| {
self.context.update_def(&accum, def)
});
end_val.map(NodeValue::Value)
}
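            // Pass the value along unchanged to the artificial call combination node.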
Edge::CallCombine(_) => Some(Self::NodeValue::Value(node_value.unwrap_value().clone())),
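            // A call edge into a called subroutine: apply the call transfer function of the wrapped context.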
Edge::Call(call) => self
.context
.update_call(
node_value.unwrap_value(),
call,
&graph[end_node],
&graph[end_node].get_sub().term.calling_convention,
)
.map(NodeValue::Value),
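            // Edge from the call site to the artificial call-return node:
            // remember the caller-side value so that it can be combined with the return value later.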
Edge::CrCallStub => Some(NodeValue::CallFlowCombinator {
call_stub: Some(node_value.unwrap_value().clone()),
interprocedural_flow: None,
}),
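            // Edge from the end of the called subroutine to the artificial call-return node:
            // remember the value returned by the callee.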
Edge::CrReturnStub => Some(NodeValue::CallFlowCombinator {
call_stub: None,
interprocedural_flow: Some(node_value.unwrap_value().clone()),
}),
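            // Combine the value returned by the callee with the caller-side value before the call
            // to obtain the value at the return site of the call.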
Edge::ReturnCombine(call_term) => match node_value {
NodeValue::Value(_) => panic!("Unexpected interprocedural fixpoint graph state"),
NodeValue::CallFlowCombinator {
call_stub,
interprocedural_flow,
} => {
let (return_from_block, return_from_sub) = match graph.node_weight(start_node) {
Some(Node::CallReturn {
call: _,
return_: (return_from_block, return_from_sub),
}) => (return_from_block, return_from_sub),
_ => panic!("Malformed Control flow graph"),
};
let return_from_jmp = &return_from_block.term.jmps[0];
self.context
.update_return(
interprocedural_flow.as_ref(),
call_stub.as_ref(),
call_term,
return_from_jmp,
&return_from_sub.term.calling_convention,
)
.map(NodeValue::Value)
}
},
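            // A call to an external symbol: apply the call stub transfer function of the wrapped context.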
Edge::ExternCallStub(call) => self
.context
.update_call_stub(node_value.unwrap_value(), call)
.map(NodeValue::Value),
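            // An intraprocedural jump:
            // first specialize the value according to the corresponding (un)taken conditional (if any),
            // then apply the jump transfer function of the wrapped context.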
Edge::Jump(jump, untaken_conditional) => {
let value_after_condition = if let Jmp::CBranch {
target: _,
condition,
} = &jump.term
{
let block = graph[start_node].get_block();
self.context.specialize_conditional(
node_value.unwrap_value(),
condition,
block,
true,
)
} else if let Some(untaken_conditional_jump) = untaken_conditional {
if let Jmp::CBranch {
target: _,
condition,
} = &untaken_conditional_jump.term
{
let block = graph[start_node].get_block();
self.context.specialize_conditional(
node_value.unwrap_value(),
condition,
block,
false,
)
} else {
panic!("Malformed control flow graph");
}
} else {
Some(node_value.unwrap_value().clone())
};
if let Some(value) = value_after_condition {
self.context
.update_jump(
&value,
jump,
*untaken_conditional,
graph[end_node].get_block(),
)
.map(NodeValue::Value)
} else {
None
}
}
}
}
}
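
/// Create a new fixpoint computation from a [`Context`] object
/// and an optional default value that gets assigned to all nodes of the graph.
///
/// A rough usage sketch is given below. `MyContext` stands for your own [`Context`]
/// implementation, and `entry_node`, `start_value` and `some_node` are placeholders;
/// the driver calls shown (`set_node_value`, `compute_with_max_steps`, `get_node_value`)
/// refer to the API of the generic fixpoint [`Computation`](super::fixpoint::Computation);
/// see `super::fixpoint` for the exact signatures.
///
/// ```ignore
/// let context = MyContext::new(&graph);
/// let mut computation = create_computation(context, None);
/// // Seed the analysis with a start value at the entry node,
/// computation.set_node_value(entry_node, NodeValue::Value(start_value));
/// // run it until a fixpoint is reached (or the step bound is hit),
/// computation.compute_with_max_steps(100);
/// // and read off the computed value at any node of interest.
/// let result = computation.get_node_value(some_node);
/// ```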
pub fn create_computation<'a, T: Context<'a>>(
problem: T,
default_value: Option<T::Value>,
) -> super::fixpoint::Computation<GeneralizedContext<'a, T>> {
let generalized_problem = GeneralizedContext::new(problem);
super::fixpoint::Computation::new(generalized_problem, default_value.map(NodeValue::Value))
}
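
/// Returns a list of all graph nodes ordered for use as a worklist priority order
/// in a bottom-up analysis, i.e. one where called functions are (generally)
/// stabilized before their callers.
///
/// To compute the order, all call edges are removed from (a clone of) the graph,
/// so that the remaining edges between functions lead from callees back to their call sites.
/// Concatenating the strongly connected components in the reverse topological order
/// returned by Kosaraju's algorithm then places the nodes of callees behind the nodes of their callers.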
pub fn create_bottom_up_worklist(graph: &Graph) -> Vec<NodeIndex> {
let mut graph = graph.clone();
graph.retain_edges(|frozen, edge| !matches!(frozen[edge], Edge::Call(..)));
petgraph::algo::kosaraju_scc(&graph)
.into_iter()
.flatten()
.collect()
}
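
/// Returns a list of all graph nodes ordered for use as a worklist priority order
/// in a top-down analysis, i.e. one where callers are (generally)
/// stabilized before the functions they call.
///
/// To compute the order, all return edges (`CrReturnStub`) are removed from (a clone of) the graph,
/// so that the remaining edges between functions lead from callers into their callees.
/// Concatenating the strongly connected components in the reverse topological order
/// returned by Kosaraju's algorithm then places the nodes of callees in front of the nodes of their callers.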
pub fn create_top_down_worklist(graph: &Graph) -> Vec<NodeIndex> {
let mut graph = graph.clone();
graph.retain_edges(|frozen, edge| !matches!(frozen[edge], Edge::CrReturnStub));
petgraph::algo::kosaraju_scc(&graph)
.into_iter()
.flatten()
.collect()
}
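
/// Create a new fixpoint computation from a [`Context`] object and an optional default value,
/// with the worklist initialized in a bottom-up order,
/// i.e. such that called functions are (generally) stabilized before their callers.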
pub fn create_computation_with_bottom_up_worklist_order<'a, T: Context<'a>>(
problem: T,
default_value: Option<T::Value>,
) -> super::fixpoint::Computation<GeneralizedContext<'a, T>> {
let priority_sorted_nodes: Vec<NodeIndex> = create_bottom_up_worklist(problem.get_graph());
let generalized_problem = GeneralizedContext::new(problem);
super::fixpoint::Computation::from_node_priority_list(
generalized_problem,
default_value.map(NodeValue::Value),
priority_sorted_nodes,
)
}
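
/// Create a new fixpoint computation from a [`Context`] object and an optional default value,
/// with the worklist initialized in a top-down order,
/// i.e. such that callers are (generally) stabilized before the functions they call.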
pub fn create_computation_with_top_down_worklist_order<'a, T: Context<'a>>(
problem: T,
default_value: Option<T::Value>,
) -> super::fixpoint::Computation<GeneralizedContext<'a, T>> {
let priority_sorted_nodes: Vec<NodeIndex> = create_top_down_worklist(problem.get_graph());
let generalized_problem = GeneralizedContext::new(problem);
super::fixpoint::Computation::from_node_priority_list(
generalized_problem,
default_value.map(NodeValue::Value),
priority_sorted_nodes,
)
}
#[cfg(test)]
mod tests {
use crate::{
analysis::{
expression_propagation::Context,
forward_interprocedural_fixpoint::{
create_computation_with_bottom_up_worklist_order,
create_computation_with_top_down_worklist_order,
},
},
expr,
intermediate_representation::*,
};
use std::collections::{BTreeMap, HashMap};
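
    /// Returns an empty block with the given name as its TID.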
fn new_block(name: &str) -> Term<Blk> {
Term {
tid: Tid::new(name),
term: Blk {
defs: vec![],
jmps: vec![],
indirect_jmp_targets: Vec::new(),
},
}
}
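
    /// Returns a mock project containing two functions:
    /// `caller_function`, whose first block calls `called_function`
    /// and whose second block jumps back to the first block,
    /// and `called_function`, which consists of a single returning block.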
fn mock_project() -> Project {
let mut callee_block = new_block("callee block");
callee_block.term.jmps.push(Term {
tid: Tid::new("ret"),
term: Jmp::Return(expr!("42:4")),
});
let called_function = Term {
tid: Tid::new("called_function"),
term: Sub {
name: "called_function".to_string(),
blocks: vec![callee_block],
calling_convention: Some("_stdcall".to_string()),
},
};
let mut caller_block_2 = new_block("caller_block_2");
let mut caller_block_1 = new_block("caller_block_1");
caller_block_1.term.jmps.push(Term {
tid: Tid::new("call"),
term: Jmp::Call {
target: called_function.tid.clone(),
return_: Some(caller_block_2.tid.clone()),
},
});
caller_block_2.term.jmps.push(Term {
tid: Tid::new("jmp"),
term: Jmp::Branch(caller_block_1.tid.clone()),
});
let caller_function = Term {
tid: Tid::new("caller_function"),
term: Sub {
name: "caller_function".to_string(),
blocks: vec![caller_block_1, caller_block_2],
calling_convention: Some("_stdcall".to_string()),
},
};
let mut project = Project::mock_x64();
project.program.term.subs = BTreeMap::from([
(caller_function.tid.clone(), caller_function.clone()),
(called_function.tid.clone(), called_function.clone()),
]);
project
}
#[test]
fn check_bottom_up_worklist() {
let project = mock_project();
let graph = crate::analysis::graph::get_program_cfg(&project.program);
let context = Context::new(&graph);
let comp = create_computation_with_bottom_up_worklist_order(context, Some(HashMap::new()));
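        // In the bottom-up worklist order the nodes of the called function come last.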
for node in comp.get_worklist()[6..].iter() {
match graph[*node] {
crate::analysis::graph::Node::BlkStart(_, sub)
| crate::analysis::graph::Node::BlkEnd(_, sub) => {
assert_eq!(sub.tid, Tid::new("called_function"))
}
_ => panic!(),
}
}
}
#[test]
fn check_top_down_worklist() {
let project = mock_project();
let graph = crate::analysis::graph::get_program_cfg(&project.program);
let context = Context::new(&graph);
let comp = create_computation_with_top_down_worklist_order(context, Some(HashMap::new()));
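        // In the top-down worklist order the nodes of the called function come first.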
for node in comp.get_worklist()[..2].iter() {
match graph[*node] {
crate::analysis::graph::Node::BlkStart(_, sub)
| crate::analysis::graph::Node::BlkEnd(_, sub) => {
assert_eq!(sub.tid, Tid::new("called_function"))
}
_ => panic!(),
}
}
}
}