use super::context::BoundsMetadata;
use super::Context;
use super::Data;
use crate::abstract_domain::*;
use crate::analysis::function_signature::FunctionSignature;
use crate::intermediate_representation::Project;
use crate::prelude::*;
use std::collections::BTreeMap;
/// Per-function state tracking the known lower and upper bounds of memory objects.
///
/// Bounds are stored as offsets relative to the start of the object identified by the
/// abstract ID. A bound whose value is `Top` is unknown (or has been reset after a warning
/// was already reported for it).
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct State {
// Abstract ID of the current function's stack frame (the stack pointer register at function start).
stack_id: AbstractIdentifier,
// Lower bound (offset) for each tracked object ID.
object_lower_bounds: DomainMap<AbstractIdentifier, BitvectorDomain, UnionMergeStrategy>,
// Upper bound (offset) for each tracked object ID.
object_upper_bounds: DomainMap<AbstractIdentifier, BitvectorDomain, UnionMergeStrategy>,
}
impl State {
pub fn new(function_tid: &Tid, function_sig: &FunctionSignature, project: &Project) -> State {
let stack_id =
AbstractIdentifier::from_var(function_tid.clone(), &project.stack_pointer_register);
let stack_upper_bound: i64 = match project.stack_pointer_register.name.as_str() {
"ESP" => 4,
"RSP" => 8,
_ => 0,
};
let stack_upper_bound = std::cmp::max(
stack_upper_bound,
function_sig.get_stack_params_total_size(&project.stack_pointer_register),
);
let object_lower_bounds = BTreeMap::from([(
stack_id.clone(),
BitvectorDomain::new_top(stack_id.bytesize()),
)]);
let object_upper_bounds = BTreeMap::from([(
stack_id.clone(),
Bitvector::from_i64(stack_upper_bound)
.into_resize_signed(stack_id.bytesize())
.into(),
)]);
State {
stack_id,
object_lower_bounds: object_lower_bounds.into(),
object_upper_bounds: object_upper_bounds.into(),
}
}
pub fn check_address_access(
&mut self,
address: &Data,
value_size: ByteSize,
context: &Context,
) -> Vec<String> {
let mut out_of_bounds_access_warnings = Vec::new();
for (id, offset) in address.get_relative_values() {
if !self.object_lower_bounds.contains_key(id) {
self.compute_bounds_of_id(id, context);
}
if let Ok((lower_offset, upper_offset)) = offset.try_to_offset_interval() {
if let Ok(lower_bound) = self.object_lower_bounds.get(id).unwrap().try_to_offset() {
if lower_bound > lower_offset {
out_of_bounds_access_warnings.push(format!("For the object ID {id} access to the offset {lower_offset} may be smaller than the lower object bound of {lower_bound}."));
if let (
Some(BoundsMetadata {
source: Some(source),
..
}),
_,
) = context.compute_bounds_of_id(id, &self.stack_id)
{
out_of_bounds_access_warnings.push(format!("The object bound is based on the possible source value {:#} for the object ID.", source.to_json_compact()));
let call_sequence_tids = collect_tids_for_cwe_warning(
source.get_if_unique_target().unwrap().0,
self,
context,
);
out_of_bounds_access_warnings
.push(format!("Relevant callgraph TIDs: [{call_sequence_tids}]"));
} else {
let mut callgraph_tids = format!("{}", self.stack_id.get_tid());
for call_tid in id.get_path_hints() {
callgraph_tids += &format!(", {call_tid}");
}
out_of_bounds_access_warnings
.push(format!("Relevant callgraph TIDs: [{callgraph_tids}]",));
}
self.object_lower_bounds
.insert(id.clone(), BitvectorDomain::new_top(address.bytesize()));
}
}
if let Ok(upper_bound) = self.object_upper_bounds.get(id).unwrap().try_to_offset() {
if upper_bound < upper_offset + (u64::from(value_size) as i64) {
out_of_bounds_access_warnings.push(format!("For the object ID {} access to the offset {} (size {}) may overflow the upper object bound of {}.",
id,
upper_offset,
u64::from(value_size),
upper_bound,
));
if let (
_,
Some(BoundsMetadata {
source: Some(source),
..
}),
) = context.compute_bounds_of_id(id, &self.stack_id)
{
out_of_bounds_access_warnings.push(format!("The object bound is based on the possible source value {:#} for the object ID.", source.to_json_compact()));
let call_sequence_tids = collect_tids_for_cwe_warning(
source.get_if_unique_target().unwrap().0,
self,
context,
);
out_of_bounds_access_warnings
.push(format!("Relevant callgraph TIDs: [{call_sequence_tids}]"));
} else {
let mut callgraph_tids = format!("{}", self.stack_id.get_tid());
for call_tid in id.get_path_hints() {
callgraph_tids += &format!(", {call_tid}");
}
out_of_bounds_access_warnings
.push(format!("Relevant callgraph TIDs: [{callgraph_tids}]",));
}
self.object_upper_bounds
.insert(id.clone(), BitvectorDomain::new_top(address.bytesize()));
}
}
}
}
out_of_bounds_access_warnings
}
fn compute_bounds_of_id(&mut self, object_id: &AbstractIdentifier, context: &Context) {
let (lower_bound, upper_bound) = context.compute_bounds_of_id(object_id, &self.stack_id);
let lower_bound = match lower_bound {
Some(bound_metadata) => Bitvector::from_i64(bound_metadata.resulting_bound)
.into_resize_signed(object_id.bytesize())
.into(),
None => BitvectorDomain::new_top(object_id.bytesize()),
};
let upper_bound = match upper_bound {
Some(bound_metadata) => Bitvector::from_i64(bound_metadata.resulting_bound)
.into_resize_signed(object_id.bytesize())
.into(),
None => BitvectorDomain::new_top(object_id.bytesize()),
};
self.object_lower_bounds
.insert(object_id.clone(), lower_bound);
self.object_upper_bounds
.insert(object_id.clone(), upper_bound);
}
}
impl AbstractDomain for State {
    /// Merge two states by merging their lower-bound maps and their upper-bound maps.
    ///
    /// Both maps are merged through their `DomainMap::merge` (with `UnionMergeStrategy`).
    /// The stack ID is taken from `self`; both states are assumed to belong to the same
    /// function.
    fn merge(&self, other: &Self) -> Self {
        let merged_lower = self.object_lower_bounds.merge(&other.object_lower_bounds);
        let merged_upper = self.object_upper_bounds.merge(&other.object_upper_bounds);
        Self {
            stack_id: self.stack_id.clone(),
            object_lower_bounds: merged_lower,
            object_upper_bounds: merged_upper,
        }
    }

    /// A state is never reported as `Top`.
    fn is_top(&self) -> bool {
        false
    }
}
impl State {
    /// Render the state as a compact JSON object.
    ///
    /// The object contains the stack ID under `"stack_id"`. It also contains one array of
    /// `"<id>: <bound>"` strings for each bound map, under `"lower_bounds"` and
    /// `"upper_bounds"`.
    // Not called from regular code; presumably kept for debugging output.
    #[allow(dead_code)]
    pub fn to_json_compact(&self) -> serde_json::Value {
        use serde_json::*;
        let mut json_fields = Map::new();
        json_fields.insert(
            "stack_id".to_string(),
            Value::String(self.stack_id.to_string()),
        );
        for (field_name, bound_map) in [
            ("lower_bounds", &self.object_lower_bounds),
            ("upper_bounds", &self.object_upper_bounds),
        ] {
            let mut entries = Vec::new();
            for (id, bound) in bound_map.iter() {
                entries.push(Value::String(format!("{id}: {bound}")));
            }
            json_fields.insert(field_name.to_string(), Value::Array(entries));
        }
        Value::Object(json_fields)
    }
}
/// Build the comma-separated list of TIDs that is attached to a CWE warning for `id`.
///
/// The list starts with the function that `id` belongs to. That is either the TID of `id`
/// itself (if it names a function), or the function containing the root call among the path
/// hints (or the TID of `id` if there are no path hints). Next come the path hints of `id`.
/// If that function differs from the current function (the TID of `state.stack_id`), the
/// call sequences found in the call graph between the two are appended last.
///
/// # Panics
///
/// Panics if no function contains the root call TID.
fn collect_tids_for_cwe_warning(
    id: &AbstractIdentifier,
    state: &State,
    context: &Context,
) -> String {
    use crate::analysis::callgraph::find_call_sequences_to_target;
    let program = &context.project.program.term;
    let id_tid = id.get_tid();
    let path_hints = id.get_path_hints();
    let caller_tid = if program.subs.contains_key(id_tid) {
        id_tid.clone()
    } else {
        let root_call_tid = path_hints.last().unwrap_or(id_tid);
        program
            .find_sub_containing_jump(root_call_tid)
            .expect("Caller corresponding to call does not exist.")
    };
    let current_function_tid = state.stack_id.get_tid();
    let mut tids = vec![caller_tid.clone()];
    tids.extend(path_hints.iter().cloned());
    if caller_tid != *current_function_tid {
        tids.extend(find_call_sequences_to_target(
            &context.callgraph,
            &caller_tid,
            current_function_tid,
        ));
    }
    // `tids` always contains at least `caller_tid`, so the joined string is never empty.
    tids.iter()
        .map(|tid| tid.to_string())
        .collect::<Vec<_>>()
        .join(", ")
}
#[cfg(test)]
pub mod tests {
use super::*;
use crate::{intermediate_representation::*, variable};
// The initial state tracks only the stack frame of `func` (based on `RSP`).
// Its lower bound is `Top` and its upper bound is expected to be 8.
#[test]
fn test_new() {
let context = Context::mock_x64();
let state = State::new(
&Tid::new("func"),
&FunctionSignature::mock_x64(),
context.project,
);
let stack_id = AbstractIdentifier::from_var(Tid::new("func"), &variable!("RSP:8"));
assert_eq!(state.stack_id, stack_id);
assert_eq!(state.object_lower_bounds.len(), 1);
assert_eq!(state.object_upper_bounds.len(), 1);
assert_eq!(
*state.object_lower_bounds.get(&stack_id).unwrap(),
BitvectorDomain::new_top(ByteSize::new(8))
);
assert_eq!(
*state.object_upper_bounds.get(&stack_id).unwrap(),
Bitvector::from_i64(8).into()
);
}
// The accesses below must run in this order: a reported violation resets the bound to `Top`,
// which changes the outcome of later accesses.
#[test]
fn test_check_address_access() {
let context = Context::mock_x64();
let mut state = State::new(
&Tid::new("func"),
&FunctionSignature::mock_x64(),
context.project,
);
let stack_id = AbstractIdentifier::from_var(Tid::new("func"), &variable!("RSP:8"));
// Offset -12 with size 8 ends at -4, below the upper bound 8; the lower bound is `Top`.
let address = Data::from_target(stack_id.clone(), Bitvector::from_i64(-12).into());
assert!(state
.check_address_access(&address, ByteSize::new(8), &context)
.is_empty());
// Offset 4 with size 8 ends at 12 > 8: one overflow line plus one "Relevant callgraph TIDs" line.
let address = Data::from_target(stack_id.clone(), Bitvector::from_i64(4).into());
assert_eq!(
state
.check_address_access(&address, ByteSize::new(8), &context)
.len(),
2
);
// The upper bound was reset to `Top` by the previous warning, so no further warnings appear.
let address = Data::from_target(stack_id, Bitvector::from_i64(8).into());
assert!(state
.check_address_access(&address, ByteSize::new(8), &context)
.is_empty());
}
// A mocked malloc call with object size 42 should yield the bounds [0, 42] for its return object.
// The stored bound count of 2 is the stack frame plus the new object.
#[test]
fn test_compute_bounds_of_id() {
let mut context = Context::mock_x64();
context
.malloc_tid_to_object_size_map
.insert(Tid::new("malloc_call"), Data::from(Bitvector::from_i64(42)));
context
.call_to_caller_fn_map
.insert(Tid::new("malloc_call"), Tid::new("main"));
let mut state = State::new(
&Tid::new("func"),
&FunctionSignature::mock_x64(),
context.project,
);
state.compute_bounds_of_id(&AbstractIdentifier::mock("malloc_call", "RAX", 8), &context);
assert_eq!(state.object_lower_bounds.len(), 2);
assert_eq!(
state.object_lower_bounds[&AbstractIdentifier::mock("malloc_call", "RAX", 8)],
Bitvector::from_i64(0).into()
);
assert_eq!(
state.object_upper_bounds[&AbstractIdentifier::mock("malloc_call", "RAX", 8)],
Bitvector::from_i64(42).into()
);
}
}