use crate::abstract_domain::DataDomain;
use crate::abstract_domain::IntervalDomain;
use crate::abstract_domain::RegisterDomain;
use crate::abstract_domain::TryToInterval;
use crate::analysis::pointer_inference::PointerInference;
use crate::analysis::vsa_results::*;
use crate::intermediate_representation::*;
use crate::pipeline::AnalysisResults;
use crate::utils::log::CweWarning;
use crate::utils::log::LogMessage;
use crate::utils::symbol_utils::get_callsites;
use crate::utils::symbol_utils::get_symbol_map;
use crate::CweModule;
use serde::Deserialize;
use serde::Serialize;
/// Registration record of the CWE-789 (large memory allocation) check:
/// its module name, version string, and the entry point `check_cwe` that runs it.
pub static CWE_MODULE: CweModule = CweModule {
name: "CWE789",
version: "0.1",
run: check_cwe,
};
/// Configuration of the CWE-789 check, deserialized from the JSON parameters
/// passed to `check_cwe`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
pub struct Config {
/// A stack pointer assignment is reported when the lower bound of its
/// stack-relative offset is smaller than `-stack_threshold`.
/// NOTE(review): presumably measured in bytes — confirm.
stack_threshold: u64,
/// A call to one of `symbols` is reported when the upper bound of its
/// (absolute) size argument — for `calloc` the product of both arguments —
/// is larger than `heap_threshold`.
/// NOTE(review): presumably measured in bytes — confirm.
heap_threshold: u64,
/// Extern symbols (presumably allocator names such as `malloc`, `calloc`,
/// `realloc`) whose call sites are checked.
symbols: Vec<String>,
}
/// Returns `true` if `def` is a `Def::Assign` whose destination variable is `sp`.
///
/// Any other kind of definition yields `false`, regardless of what it writes.
fn is_assign_on_sp(def: &Def, sp: &Variable) -> bool {
    matches!(def, Def::Assign { var, .. } if var == sp)
}
/// Checks whether any relative (pointer) component of `interval` may lie more
/// than `threshold` below its base.
///
/// A component counts as exceeding when the lower bound of its offset interval,
/// read as a signed integer, is smaller than `-threshold`. Components that cannot
/// be turned into a concrete interval, or whose lower bound does not fit into an
/// `i128`, are ignored.
fn exceeds_threshold_on_stack(interval: DataDomain<IntervalDomain>, threshold: u64) -> bool {
    let limit = -i128::from(threshold);
    interval
        .get_relative_values()
        .values()
        .filter_map(|relative| relative.try_to_interval().ok())
        .filter_map(|bounds| bounds.start.try_to_i128().ok())
        .any(|lower_bound| lower_bound < limit)
}
/// Checks whether the absolute (non-pointer) part of `interval` may exceed `threshold`.
///
/// The upper bound of the absolute interval is read as an unsigned integer and
/// compared against `threshold`. Returns `false` if there is no absolute part,
/// if it cannot be turned into a concrete interval, or if its upper bound does
/// not fit into a `u128`.
fn exceeds_threshold_on_call(interval: DataDomain<IntervalDomain>, threshold: u64) -> bool {
    let upper_bound = interval
        .get_absolute_value()
        .and_then(|absolute| absolute.try_to_interval().ok())
        .and_then(|bounds| bounds.end.try_to_u128().ok());
    matches!(upper_bound, Some(end) if end > u128::from(threshold))
}
/// Computes the total allocation size of a `calloc` call as `nmemb * size`.
///
/// `parms[0]` is taken as the element count and `parms[1]` as the element size.
/// Both are evaluated by the pointer inference at the call `jmp_tid`. Returns
/// `None` if either argument cannot be evaluated.
///
/// Panics if `parms` has fewer than two entries.
fn multiply_args_for_calloc(
    pir: &PointerInference,
    jmp_tid: &Tid,
    parms: Vec<&Arg>,
) -> Option<DataDomain<IntervalDomain>> {
    let element_count = pir.eval_parameter_arg_at_call(jmp_tid, parms[0]);
    let element_size = pir.eval_parameter_arg_at_call(jmp_tid, parms[1]);
    element_count
        .zip(element_size)
        .map(|(count, size)| count.bin_op(BinOpType::IntMult, &size))
}
/// Builds the CWE-789 warning for the allocation site `allocation`.
///
/// `is_stack_allocation` selects whether the message mentions stack or heap
/// exhaustion. The warning carries the site's TID and address and no symbols.
fn generate_cwe_warning(allocation: &Tid, is_stack_allocation: bool) -> CweWarning {
    let memory_region = if is_stack_allocation {
        " stack "
    } else {
        " heap "
    };
    let description = format!(
        "(Large memory allocation) Potential{memory_region}memory exhaustion at 0x{}",
        allocation.address
    );
    CweWarning::new(CWE_MODULE.name, CWE_MODULE.version, description)
        .tids(vec![allocation.to_string()])
        .addresses(vec![allocation.address.clone()])
        .symbols(vec![])
}
pub fn check_cwe(
analysis_results: &AnalysisResults,
cwe_params: &serde_json::Value,
) -> (Vec<LogMessage>, Vec<CweWarning>) {
let project = analysis_results.project;
let config: Config = serde_json::from_value(cwe_params.clone()).unwrap();
let mut cwe_warnings = Vec::new();
let pir = analysis_results.pointer_inference.unwrap();
let symbol_map = get_symbol_map(project, &config.symbols);
'functions: for sub in project.program.term.subs.values() {
for (_, jump, symbol) in get_callsites(sub, &symbol_map) {
if let Some(interval) = match symbol.name.as_str() {
"calloc" => multiply_args_for_calloc(
pir,
&jump.tid,
vec![&symbol.parameters[0], &symbol.parameters[1]],
),
"realloc" => pir.eval_parameter_arg_at_call(&jump.tid, &symbol.parameters[1]),
_ => pir.eval_parameter_arg_at_call(&jump.tid, &symbol.parameters[0]),
} {
if exceeds_threshold_on_call(interval, config.heap_threshold) {
cwe_warnings.push(generate_cwe_warning(&jump.tid, false));
}
}
}
for blk in &sub.term.blocks {
let assign_on_sp: Vec<&Term<Def>> = blk
.term
.defs
.iter()
.filter(|x| is_assign_on_sp(&x.term, &project.stack_pointer_register))
.collect();
for assign in assign_on_sp {
if let Some(interval) = pir.eval_value_at_def(&assign.tid) {
if exceeds_threshold_on_stack(interval, config.stack_threshold) {
cwe_warnings.push(generate_cwe_warning(&assign.tid, true));
continue 'functions;
}
}
}
}
}
cwe_warnings.dedup();
(Vec::new(), cwe_warnings)
}